Browse Source

add more checks

Misaki Kasumi 1 year ago
parent
commit
9bee648078

+ 1 - 1
transport/internet/socket_activation_other.go

@@ -8,6 +8,6 @@ import (
 	"github.com/v2fly/v2ray-core/v5/common/net"
 )
 
-func activate_socket(address string) (net.Listener, error) {
+func activateSocket(address string) (net.Listener, error) {
 	return nil, fmt.Errorf("socket activation is not supported on this platform")
 }

+ 31 - 5
transport/internet/socket_activation_unix.go

@@ -4,19 +4,45 @@
 package internet
 
 import (
+	"fmt"
 	"os"
+	"path"
 	"strconv"
 	"syscall"
 
 	"github.com/v2fly/v2ray-core/v5/common/net"
 )
 
-func activate_socket(address string) (net.Listener, error) {
-	fd, err := strconv.Atoi(address[8:])
+func activateSocket(address string) (net.Listener, error) {
+	fd, err := strconv.Atoi(path.Base(address))
 	if err != nil {
 		return nil, err
 	}
-	// Ignore the fail of SetNonblock: it's merely an optimization so that Go can poll this fd.
-	_ = syscall.SetNonblock(fd, true)
-	return net.FileListener(os.NewFile(uintptr(fd), address))
+
+	err = syscall.SetNonblock(fd, true)
+	if err != nil {
+		return nil, err
+	}
+
+	acceptConn, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_ACCEPTCONN)
+	if err != nil {
+		return nil, err
+	}
+	if acceptConn == 0 {
+		return nil, fmt.Errorf("socket '%s' has not been marked to accept connections", address)
+	}
+
+	sockType, err := syscall.GetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_TYPE)
+	if err != nil {
+		return nil, err
+	}
+	if sockType != syscall.SOCK_STREAM {
+		// XXX: currently only stream socks are supported
+		return nil, fmt.Errorf("socket '%s' is not a stream socket", address)
+	}
+
+	file := os.NewFile(uintptr(fd), address)
+	defer file.Close()
+
+	return net.FileListener(file)
 }

+ 1 - 1
transport/internet/system_listener.go

@@ -97,7 +97,7 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 			}
 		} else if strings.HasPrefix(address, "/dev/fd/") {
 			// socket activation
-			l, err = activate_socket(address)
+			l, err = activateSocket(address)
 			if err != nil {
 				return nil, err
 			}