Browse Source

refactor dialer

Darien Raymond 8 năm trước cách đây
mục cha
commit
9cbc9b7170

+ 2 - 2
transport/internet/dialer.go

@@ -53,6 +53,6 @@ func Dial(ctx context.Context, dest v2net.Destination) (Connection, error) {
 }
 
 // DialSystem calls system dialer to create a network connection.
-func DialSystem(src v2net.Address, dest v2net.Destination) (net.Conn, error) {
-	return effectiveSystemDialer.Dial(src, dest)
+func DialSystem(ctx context.Context, src v2net.Address, dest v2net.Destination) (net.Conn, error) {
+	return effectiveSystemDialer.Dial(ctx, src, dest)
 }

+ 2 - 1
transport/internet/dialer_test.go

@@ -7,6 +7,7 @@ import (
 	"v2ray.com/core/testing/assert"
 	"v2ray.com/core/testing/servers/tcp"
 	. "v2ray.com/core/transport/internet"
+	"context"
 )
 
 func TestDialWithLocalAddr(t *testing.T) {
@@ -17,7 +18,7 @@ func TestDialWithLocalAddr(t *testing.T) {
 	assert.Error(err).IsNil()
 	defer server.Close()
 
-	conn, err := DialSystem(net.LocalHostIP, net.TCPDestination(net.LocalHostIP, dest.Port))
+	conn, err := DialSystem(context.Background(), net.LocalHostIP, net.TCPDestination(net.LocalHostIP, dest.Port))
 	assert.Error(err).IsNil()
 	assert.String(conn.RemoteAddr().String()).Equals("127.0.0.1:" + dest.Port.String())
 	conn.Close()

+ 1 - 1
transport/internet/kcp/dialer.go

@@ -117,7 +117,7 @@ func DialKCP(ctx context.Context, dest v2net.Destination) (internet.Connection,
 	id := internal.NewConnectionID(src, dest)
 	conn := globalPool.Get(id)
 	if conn == nil {
-		rawConn, err := internet.DialSystem(src, dest)
+		rawConn, err := internet.DialSystem(ctx, src, dest)
 		if err != nil {
 			log.Error("KCP|Dialer: Failed to dial to dest: ", err)
 			return nil, err

+ 6 - 4
transport/internet/system_dialer.go

@@ -4,6 +4,8 @@ import (
 	"net"
 	"time"
 
+	"context"
+
 	v2net "v2ray.com/core/common/net"
 )
 
@@ -12,13 +14,13 @@ var (
 )
 
 type SystemDialer interface {
-	Dial(source v2net.Address, destination v2net.Destination) (net.Conn, error)
+	Dial(ctx context.Context, source v2net.Address, destination v2net.Destination) (net.Conn, error)
 }
 
 type DefaultSystemDialer struct {
 }
 
-func (v *DefaultSystemDialer) Dial(src v2net.Address, dest v2net.Destination) (net.Conn, error) {
+func (v *DefaultSystemDialer) Dial(ctx context.Context, src v2net.Address, dest v2net.Destination) (net.Conn, error) {
 	dialer := &net.Dialer{
 		Timeout:   time.Second * 60,
 		DualStack: true,
@@ -38,7 +40,7 @@ func (v *DefaultSystemDialer) Dial(src v2net.Address, dest v2net.Destination) (n
 		}
 		dialer.LocalAddr = addr
 	}
-	return dialer.Dial(dest.Network.SystemString(), dest.NetAddr())
+	return dialer.DialContext(ctx, dest.Network.SystemString(), dest.NetAddr())
 }
 
 type SystemDialerAdapter interface {
@@ -55,7 +57,7 @@ func WithAdapter(dialer SystemDialerAdapter) SystemDialer {
 	}
 }
 
-func (v *SimpleSystemDialer) Dial(src v2net.Address, dest v2net.Destination) (net.Conn, error) {
+func (v *SimpleSystemDialer) Dial(ctx context.Context, src v2net.Address, dest v2net.Destination) (net.Conn, error) {
 	return v.adapter.Dial(dest.Network.SystemString(), dest.NetAddr())
 }
 

+ 1 - 1
transport/internet/tcp/dialer.go

@@ -31,7 +31,7 @@ func Dial(ctx context.Context, dest v2net.Destination) (internet.Connection, err
 	}
 	if conn == nil {
 		var err error
-		conn, err = internet.DialSystem(src, dest)
+		conn, err = internet.DialSystem(ctx, src, dest)
 		if err != nil {
 			return nil, err
 		}

+ 1 - 1
transport/internet/udp/dialer.go

@@ -13,7 +13,7 @@ func init() {
 	common.Must(internet.RegisterTransportDialer(internet.TransportProtocol_UDP,
 		func(ctx context.Context, dest v2net.Destination) (internet.Connection, error) {
 			src := internet.DialerSourceFromContext(ctx)
-			conn, err := internet.DialSystem(src, dest)
+			conn, err := internet.DialSystem(ctx, src, dest)
 			if err != nil {
 				return nil, err
 			}

+ 1 - 1
transport/internet/websocket/dialer.go

@@ -47,7 +47,7 @@ func dialWebsocket(ctx context.Context, dest v2net.Destination) (net.Conn, error
 	wsSettings := internet.TransportSettingsFromContext(ctx).(*Config)
 
 	commonDial := func(network, addr string) (net.Conn, error) {
-		return internet.DialSystem(src, dest)
+		return internet.DialSystem(ctx, src, dest)
 	}
 
 	dialer := websocket.Dialer{