Parcourir la source

registerable dialer and listener

Darien Raymond il y a 8 ans
Parent
commit
21a15bbf74

+ 21 - 22
transport/internet/dialer.go

@@ -20,42 +20,41 @@ type DialerOptions struct {
 type Dialer func(src v2net.Address, dest v2net.Destination, options DialerOptions) (Connection, error)
 
 var (
-	TCPDialer   Dialer
-	KCPDialer   Dialer
-	UDPDialer   Dialer
-	WSDialer    Dialer
+	networkDialerCache = make(map[v2net.Network]Dialer)
+
 	ProxyDialer Dialer
 )
 
+func RegisterNetworkDialer(network v2net.Network, dialer Dialer) error {
+	if _, found := networkDialerCache[network]; found {
+		return errors.New("Internet|Dialer: ", network, " dialer already registered.")
+	}
+	networkDialerCache[network] = dialer
+	return nil
+}
+
 func Dial(src v2net.Address, dest v2net.Destination, options DialerOptions) (Connection, error) {
 	if options.Proxy.HasTag() && ProxyDialer != nil {
 		log.Info("Internet: Proxying outbound connection through: ", options.Proxy.Tag)
 		return ProxyDialer(src, dest, options)
 	}
 
-	var connection Connection
-	var err error
 	if dest.Network == v2net.Network_TCP {
-		switch options.Stream.Network {
-		case v2net.Network_TCP:
-			connection, err = TCPDialer(src, dest, options)
-		case v2net.Network_KCP:
-			connection, err = KCPDialer(src, dest, options)
-		case v2net.Network_WebSocket:
-			connection, err = WSDialer(src, dest, options)
-		default:
-			return nil, ErrUnsupportedStreamType
+		dialer := networkDialerCache[options.Stream.Network]
+		if dialer == nil {
+			return nil, errors.New("Internet|Dialer: ", options.Stream.Network, " dialer not registered.")
 		}
-		if err != nil {
-			return nil, err
-		}
-
-		return connection, nil
+		return dialer(src, dest, options)
 	}
 
-	return UDPDialer(src, dest, options)
+	udpDialer := networkDialerCache[v2net.Network_UDP]
+	if udpDialer == nil {
+		return nil, errors.New("Internet|Dialer: UDP dialer not registered.")
+	}
+	return udpDialer(src, dest, options)
 }
 
-func DialToDest(src v2net.Address, dest v2net.Destination) (net.Conn, error) {
+// DialSystem calls system dialer to create a network connection.
+func DialSystem(src v2net.Address, dest v2net.Destination) (net.Conn, error) {
 	return effectiveSystemDialer.Dial(src, dest)
 }

+ 1 - 1
transport/internet/dialer_test.go

@@ -17,7 +17,7 @@ func TestDialWithLocalAddr(t *testing.T) {
 	assert.Error(err).IsNil()
 	defer server.Close()
 
-	conn, err := DialToDest(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, dest.Port))
+	conn, err := DialSystem(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, dest.Port))
 	assert.Error(err).IsNil()
 	assert.String(conn.RemoteAddr().String()).Equals("127.0.0.1:" + dest.Port.String())
 	conn.Close()

+ 3 - 2
transport/internet/kcp/dialer.go

@@ -8,6 +8,7 @@ import (
 
 	"crypto/cipher"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/dice"
 	"v2ray.com/core/common/errors"
@@ -115,7 +116,7 @@ func DialKCP(src v2net.Address, dest v2net.Destination, options internet.DialerO
 	id := internal.NewConnectionID(src, dest)
 	conn := globalPool.Get(id)
 	if conn == nil {
-		rawConn, err := internet.DialToDest(src, dest)
+		rawConn, err := internet.DialSystem(src, dest)
 		if err != nil {
 			log.Error("KCP|Dialer: Failed to dial to dest: ", err)
 			return nil, err
@@ -172,5 +173,5 @@ func DialKCP(src v2net.Address, dest v2net.Destination, options internet.DialerO
 }
 
 func init() {
-	internet.KCPDialer = DialKCP
+	common.Must(internet.RegisterNetworkDialer(v2net.Network_KCP, DialKCP))
 }

+ 2 - 1
transport/internet/kcp/listener.go

@@ -9,6 +9,7 @@ import (
 
 	"crypto/cipher"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/log"
@@ -297,5 +298,5 @@ func ListenKCP(address v2net.Address, port v2net.Port, options internet.ListenOp
 }
 
 func init() {
-	internet.KCPListenFunc = ListenKCP
+	common.Must(internet.RegisterNetworkListener(v2net.Network_KCP, ListenKCP))
 }

+ 3 - 2
transport/internet/tcp/dialer.go

@@ -4,6 +4,7 @@ import (
 	"crypto/tls"
 	"net"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/log"
 	v2net "v2ray.com/core/common/net"
@@ -34,7 +35,7 @@ func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOpti
 	}
 	if conn == nil {
 		var err error
-		conn, err = internet.DialToDest(src, dest)
+		conn, err = internet.DialSystem(src, dest)
 		if err != nil {
 			return nil, err
 		}
@@ -69,5 +70,5 @@ func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOpti
 }
 
 func init() {
-	internet.TCPDialer = Dial
+	common.Must(internet.RegisterNetworkDialer(v2net.Network_TCP, Dial))
 }

+ 2 - 1
transport/internet/tcp/hub.go

@@ -6,6 +6,7 @@ import (
 	"sync"
 	"time"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/log"
 	v2net "v2ray.com/core/common/net"
@@ -158,5 +159,5 @@ func (v *TCPListener) Close() error {
 }
 
 func init() {
-	internet.TCPListenFunc = ListenTCP
+	common.Must(internet.RegisterNetworkListener(v2net.Network_TCP, ListenTCP))
 }

+ 14 - 18
transport/internet/tcp_hub.go

@@ -13,11 +13,17 @@ import (
 var (
 	ErrClosedConnection = errors.New("Connection already closed.")
 
-	KCPListenFunc ListenFunc
-	TCPListenFunc ListenFunc
-	WSListenFunc  ListenFunc
+	networkListenerCache = make(map[v2net.Network]ListenFunc)
 )
 
+func RegisterNetworkListener(network v2net.Network, listener ListenFunc) error {
+	if _, found := networkListenerCache[network]; found {
+		return errors.New("Internet|TCPHub: ", network, " listener already registered.")
+	}
+	networkListenerCache[network] = listener
+	return nil
+}
+
 type ListenFunc func(address v2net.Address, port v2net.Port, options ListenOptions) (Listener, error)
 type ListenOptions struct {
 	Stream *StreamConfig
@@ -37,26 +43,16 @@ type TCPHub struct {
 }
 
 func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamConfig) (*TCPHub, error) {
-	var listener Listener
-	var err error
 	options := ListenOptions{
 		Stream: settings,
 	}
-	switch settings.Network {
-	case v2net.Network_TCP:
-		listener, err = TCPListenFunc(address, port, options)
-	case v2net.Network_KCP:
-		listener, err = KCPListenFunc(address, port, options)
-	case v2net.Network_WebSocket:
-		listener, err = WSListenFunc(address, port, options)
-	default:
-		log.Error("Internet|Listener: Unknown stream type: ", settings.Network)
-		err = ErrUnsupportedStreamType
+	listenFunc := networkListenerCache[settings.Network]
+	if listenFunc == nil {
+		return nil, errors.New("Internet|TCPHub: ", settings.Network, " listener not registered.")
 	}
-
+	listener, err := listenFunc(address, port, options)
 	if err != nil {
-		log.Warning("Internet|Listener: Failed to listen on ", address, ":", port)
-		return nil, err
+		return nil, errors.Base(err).Message("Interent|TCPHub: Failed to listen: ")
 	}
 
 	hub := &TCPHub{

+ 12 - 10
transport/internet/udp/connection.go

@@ -3,6 +3,7 @@ package udp
 import (
 	"net"
 
+	"v2ray.com/core/common"
 	v2net "v2ray.com/core/common/net"
 	"v2ray.com/core/transport/internet"
 )
@@ -18,14 +19,15 @@ func (v *Connection) Reusable() bool {
 func (v *Connection) SetReusable(b bool) {}
 
 func init() {
-	internet.UDPDialer = func(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (internet.Connection, error) {
-		conn, err := internet.DialToDest(src, dest)
-		if err != nil {
-			return nil, err
-		}
-		// TODO: handle dialer options
-		return &Connection{
-			UDPConn: *(conn.(*net.UDPConn)),
-		}, nil
-	}
+	common.Must(internet.RegisterNetworkDialer(v2net.Network_UDP,
+		func(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (internet.Connection, error) {
+			conn, err := internet.DialSystem(src, dest)
+			if err != nil {
+				return nil, err
+			}
+			// TODO: handle dialer options
+			return &Connection{
+				UDPConn: *(conn.(*net.UDPConn)),
+			}, nil
+		}))
 }

+ 3 - 2
transport/internet/websocket/dialer.go

@@ -5,6 +5,7 @@ import (
 	"net"
 
 	"github.com/gorilla/websocket"
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/log"
 	v2net "v2ray.com/core/common/net"
 	"v2ray.com/core/transport/internet"
@@ -46,7 +47,7 @@ func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOpti
 }
 
 func init() {
-	internet.WSDialer = Dial
+	common.Must(internet.RegisterNetworkDialer(v2net.Network_WebSocket, Dial))
 }
 
 func wsDial(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (*wsconn, error) {
@@ -57,7 +58,7 @@ func wsDial(src v2net.Address, dest v2net.Destination, options internet.DialerOp
 	wsSettings := networkSettings.(*Config)
 
 	commonDial := func(network, addr string) (net.Conn, error) {
-		return internet.DialToDest(src, dest)
+		return internet.DialSystem(src, dest)
 	}
 
 	dialer := websocket.Dialer{

+ 2 - 1
transport/internet/websocket/hub.go

@@ -8,6 +8,7 @@ import (
 	"sync"
 	"time"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/log"
 	v2net "v2ray.com/core/common/net"
@@ -197,5 +198,5 @@ func (v *WSListener) Close() error {
 }
 
 func init() {
-	internet.WSListenFunc = ListenWS
+	common.Must(internet.RegisterNetworkListener(v2net.Network_WebSocket, ListenWS))
 }