Преглед на файлове

system listener for both TCP and UDP

Darien Raymond преди 7 години
родител
ревизия
20251bf499

+ 3 - 0
common/net/system.go

@@ -8,6 +8,8 @@ var DialUDP = net.DialUDP
 var DialUnix = net.DialUnix
 var Dial = net.Dial
 
+type ListenConfig = net.ListenConfig
+
 var Listen = net.Listen
 var ListenTCP = net.ListenTCP
 var ListenUDP = net.ListenUDP
@@ -25,6 +27,7 @@ var CIDRMask = net.CIDRMask
 
 type Addr = net.Addr
 type Conn = net.Conn
+type PacketConn = net.PacketConn
 
 type TCPAddr = net.TCPAddr
 type TCPConn = net.TCPConn

+ 7 - 5
testing/servers/tcp/tcp.go

@@ -18,7 +18,7 @@ type Server struct {
 	ShouldClose  bool
 	SendFirst    []byte
 	Listen       net.Address
-	listener     *net.TCPListener
+	listener     net.Listener
 }
 
 func (server *Server) Start() (net.Destination, error) {
@@ -30,17 +30,19 @@ func (server *Server) StartContext(ctx context.Context) (net.Destination, error)
 	if listenerAddr == nil {
 		listenerAddr = net.LocalHostIP
 	}
-	listener, err := internet.ListenSystemTCP(ctx, &net.TCPAddr{
+	listener, err := internet.ListenSystem(ctx, &net.TCPAddr{
 		IP:   listenerAddr.IP(),
 		Port: int(server.Port),
 	})
 	if err != nil {
 		return net.Destination{}, err
 	}
-	server.Port = net.Port(listener.Addr().(*net.TCPAddr).Port)
-	server.listener = listener
-	go server.acceptConnections(listener)
+
 	localAddr := listener.Addr().(*net.TCPAddr)
+	server.Port = net.Port(localAddr.Port)
+	server.listener = listener
+	go server.acceptConnections(listener.(*net.TCPListener))
+
 	return net.TCPDestination(net.IPAddress(localAddr.IP), net.Port(localAddr.Port)), nil
 }
 

+ 0 - 2
transport/internet/context.go

@@ -11,8 +11,6 @@ type key int
 const (
 	streamSettingsKey key = iota
 	dialerSrcKey
-	transportSettingsKey
-	securitySettingsKey
 )
 
 func ContextWithStreamSettings(ctx context.Context, streamSettings *MemoryStreamConfig) context.Context {

+ 8 - 4
transport/internet/dialer.go

@@ -41,11 +41,15 @@ func Dial(ctx context.Context, dest net.Destination) (Connection, error) {
 		return dialer(ctx, dest)
 	}
 
-	udpDialer := transportDialerCache["udp"]
-	if udpDialer == nil {
-		return nil, newError("UDP dialer not registered").AtError()
+	if dest.Network == net.Network_UDP {
+		udpDialer := transportDialerCache["udp"]
+		if udpDialer == nil {
+			return nil, newError("UDP dialer not registered").AtError()
+		}
+		return udpDialer(ctx, dest)
 	}
-	return udpDialer(ctx, dest)
+
+	return nil, newError("unknown network ", dest.Network)
 }
 
 // DialSystem calls system dialer to create a network connection.

+ 1 - 1
transport/internet/http/hub.go

@@ -117,7 +117,7 @@ func Listen(ctx context.Context, address net.Address, port net.Port, handler int
 
 	listener.server = server
 	go func() {
-		tcpListener, err := internet.ListenSystemTCP(ctx, &net.TCPAddr{
+		tcpListener, err := internet.ListenSystem(ctx, &net.TCPAddr{
 			IP:   address.IP(),
 			Port: int(port),
 		})

+ 9 - 0
transport/internet/sockopt.go

@@ -8,3 +8,12 @@ func isTCPSocket(network string) bool {
 		return false
 	}
 }
+
+func isUDPSocket(network string) bool {
+	switch network {
+	case "udp", "udp4", "udp6":
+		return true
+	default:
+		return false
+	}
+}

+ 12 - 4
transport/internet/system_dialer.go

@@ -20,18 +20,26 @@ type SystemDialer interface {
 type DefaultSystemDialer struct {
 }
 
+func getSocketSettings(ctx context.Context) *SocketConfig {
+	streamSettings := StreamSettingsFromContext(ctx)
+	if streamSettings != nil && streamSettings.SocketSettings != nil {
+		return streamSettings.SocketSettings
+	}
+
+	return nil
+}
+
 func (DefaultSystemDialer) Dial(ctx context.Context, src net.Address, dest net.Destination) (net.Conn, error) {
 	dialer := &net.Dialer{
 		Timeout:   time.Second * 60,
 		DualStack: true,
 	}
 
-	streamSettings := StreamSettingsFromContext(ctx)
-	if streamSettings != nil && streamSettings.SocketSettings != nil {
-		config := streamSettings.SocketSettings
+	sockopts := getSocketSettings(ctx)
+	if sockopts != nil {
 		dialer.Control = func(network, address string, c syscall.RawConn) error {
 			return c.Control(func(fd uintptr) {
-				if err := applyOutboundSocketOptions(network, address, fd, config); err != nil {
+				if err := applyOutboundSocketOptions(network, address, fd, sockopts); err != nil {
 					newError("failed to apply socket options").Base(err).WriteToLog(session.ExportIDToError(ctx))
 				}
 			})

+ 30 - 20
transport/internet/system_listener.go

@@ -2,38 +2,48 @@ package internet
 
 import (
 	"context"
+	"syscall"
 
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/session"
 )
 
 var (
-	effectiveTCPListener = DefaultTCPListener{}
+	effectiveListener = DefaultListener{}
 )
 
-type DefaultTCPListener struct{}
+type DefaultListener struct{}
 
-func (tl *DefaultTCPListener) Listen(ctx context.Context, addr *net.TCPAddr) (*net.TCPListener, error) {
-	l, err := net.ListenTCP("tcp", addr)
-	if err != nil {
-		return nil, err
-	}
+func (*DefaultListener) Listen(ctx context.Context, addr net.Addr) (net.Listener, error) {
+	var lc net.ListenConfig
 
-	streamSettings := StreamSettingsFromContext(ctx)
-	if streamSettings != nil && streamSettings.SocketSettings != nil {
-		config := streamSettings.SocketSettings
-		rawConn, err := l.SyscallConn()
-		if err != nil {
-			return nil, err
+	sockopt := getSocketSettings(ctx)
+	if sockopt != nil {
+		lc.Control = func(network, address string, c syscall.RawConn) error {
+			return c.Control(func(fd uintptr) {
+				if err := applyInboundSocketOptions(network, fd, sockopt); err != nil {
+					newError("failed to apply socket options to incoming connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
+				}
+			})
 		}
-		if err := rawConn.Control(func(fd uintptr) {
-			if err := applyInboundSocketOptions("tcp", fd, config); err != nil {
-				newError("failed to apply socket options to incoming connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
-			}
-		}); err != nil {
-			return nil, err
+	}
+
+	return lc.Listen(ctx, addr.Network(), addr.String())
+}
+
+func (*DefaultListener) ListenPacket(ctx context.Context, addr net.Addr) (net.PacketConn, error) {
+	var lc net.ListenConfig
+
+	sockopt := getSocketSettings(ctx)
+	if sockopt != nil {
+		lc.Control = func(network, address string, c syscall.RawConn) error {
+			return c.Control(func(fd uintptr) {
+				if err := applyInboundSocketOptions(network, fd, sockopt); err != nil {
+					newError("failed to apply socket options to incoming connection").Base(err).WriteToLog(session.ExportIDToError(ctx))
+				}
+			})
 		}
 	}
 
-	return l, nil
+	return lc.ListenPacket(ctx, addr.Network(), addr.String())
 }

+ 2 - 2
transport/internet/tcp/hub.go

@@ -14,7 +14,7 @@ import (
 
 // Listener is an internet.Listener that listens for TCP connections.
 type Listener struct {
-	listener   *net.TCPListener
+	listener   net.Listener
 	tlsConfig  *gotls.Config
 	authConfig internet.ConnectionAuthenticator
 	config     *Config
@@ -23,7 +23,7 @@ type Listener struct {
 
 // ListenTCP creates a new Listener based on configurations.
 func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler internet.ConnHandler) (internet.Listener, error) {
-	listener, err := internet.ListenSystemTCP(ctx, &net.TCPAddr{
+	listener, err := internet.ListenSystem(ctx, &net.TCPAddr{
 		IP:   address.IP(),
 		Port: int(port),
 	})

+ 6 - 2
transport/internet/tcp_hub.go

@@ -54,6 +54,10 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler
 	return listener, nil
 }
 
-func ListenSystemTCP(ctx context.Context, addr *net.TCPAddr) (*net.TCPListener, error) {
-	return effectiveTCPListener.Listen(ctx, addr)
+func ListenSystem(ctx context.Context, addr net.Addr) (net.Listener, error) {
+	return effectiveListener.Listen(ctx, addr)
+}
+
+func ListenSystemPacket(ctx context.Context, addr net.Addr) (net.PacketConn, error) {
+	return effectiveListener.ListenPacket(ctx, addr)
 }

+ 1 - 1
transport/internet/websocket/hub.go

@@ -85,7 +85,7 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, addConn i
 }
 
 func listenTCP(ctx context.Context, address net.Address, port net.Port, tlsConfig *tls.Config) (net.Listener, error) {
-	listener, err := internet.ListenSystemTCP(ctx, &net.TCPAddr{
+	listener, err := internet.ListenSystem(ctx, &net.TCPAddr{
 		IP:   address.IP(),
 		Port: int(port),
 	})