Przeglądaj źródła

adopt context in listeners

Darien Raymond 8 lat temu
rodzic
commit
7792237b50

+ 1 - 12
transport/internet/kcp/kcp_test.go

@@ -10,7 +10,6 @@ import (
 	"time"
 
 	v2net "v2ray.com/core/common/net"
-	"v2ray.com/core/common/serial"
 	"v2ray.com/core/testing/assert"
 	"v2ray.com/core/transport/internet"
 	. "v2ray.com/core/transport/internet/kcp"
@@ -19,17 +18,7 @@ import (
 func TestDialAndListen(t *testing.T) {
 	assert := assert.On(t)
 
-	listerner, err := NewListener(v2net.LocalHostIP, v2net.Port(0), internet.ListenOptions{
-		Stream: &internet.StreamConfig{
-			Protocol: internet.TransportProtocol_MKCP,
-			TransportSettings: []*internet.TransportConfig{
-				{
-					Protocol: internet.TransportProtocol_MKCP,
-					Settings: serial.ToTypedMessage(&Config{}),
-				},
-			},
-		},
-	})
+	listerner, err := NewListener(internet.ContextWithTransportSettings(context.Background(), &Config{}), v2net.LocalHostIP, v2net.Port(0))
 	assert.Error(err).IsNil()
 	port := v2net.Port(listerner.Addr().(*net.UDPAddr).Port)
 

+ 7 - 14
transport/internet/kcp/listener.go

@@ -1,6 +1,7 @@
 package kcp
 
 import (
+	"context"
 	"crypto/cipher"
 	"crypto/tls"
 	"io"
@@ -90,12 +91,8 @@ type Listener struct {
 	security      cipher.AEAD
 }
 
-func NewListener(address v2net.Address, port v2net.Port, options internet.ListenOptions) (*Listener, error) {
-	networkSettings, err := options.Stream.GetEffectiveTransportSettings()
-	if err != nil {
-		log.Error("KCP|Listener: Failed to get KCP settings: ", err)
-		return nil, err
-	}
+func NewListener(ctx context.Context, address v2net.Address, port v2net.Port) (*Listener, error) {
+	networkSettings := internet.TransportSettingsFromContext(ctx)
 	kcpSettings := networkSettings.(*Config)
 	kcpSettings.ConnectionReuse = &ConnectionReuse{Enable: false}
 
@@ -119,12 +116,8 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen
 		closed:        make(chan bool),
 		config:        kcpSettings,
 	}
-	if options.Stream != nil && options.Stream.HasSecuritySettings() {
-		securitySettings, err := options.Stream.GetEffectiveSecuritySettings()
-		if err != nil {
-			log.Error("KCP|Listener: Failed to get security settings: ", err)
-			return nil, err
-		}
+	securitySettings := internet.SecuritySettingsFromContext(ctx)
+	if securitySettings != nil {
 		switch securitySettings := securitySettings.(type) {
 		case *v2tls.Config:
 			l.tlsConfig = securitySettings.GetTLSConfig()
@@ -295,8 +288,8 @@ func (v *Writer) Close() error {
 	return nil
 }
 
-func ListenKCP(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) {
-	return NewListener(address, port, options)
+func ListenKCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
+	return NewListener(ctx, address, port)
 }
 
 func init() {

+ 4 - 11
transport/internet/tcp/hub.go

@@ -1,6 +1,7 @@
 package tcp
 
 import (
+	"context"
 	"crypto/tls"
 	"net"
 	"sync"
@@ -34,7 +35,7 @@ type TCPListener struct {
 	config        *Config
 }
 
-func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) {
+func ListenTCP(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
 	listener, err := net.ListenTCP("tcp", &net.TCPAddr{
 		IP:   address.IP(),
 		Port: int(port),
@@ -43,10 +44,7 @@ func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOp
 		return nil, err
 	}
 	log.Info("TCP|Listener: Listening on ", address, ":", port)
-	networkSettings, err := options.Stream.GetEffectiveTransportSettings()
-	if err != nil {
-		return nil, err
-	}
+	networkSettings := internet.TransportSettingsFromContext(ctx)
 	tcpSettings := networkSettings.(*Config)
 
 	l := &TCPListener{
@@ -55,12 +53,7 @@ func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOp
 		awaitingConns: make(chan *ConnectionWithError, 32),
 		config:        tcpSettings,
 	}
-	if options.Stream != nil && options.Stream.HasSecuritySettings() {
-		securitySettings, err := options.Stream.GetEffectiveSecuritySettings()
-		if err != nil {
-			log.Error("TCP: Failed to get security config: ", err)
-			return nil, err
-		}
+	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
 		tlsConfig, ok := securitySettings.(*v2tls.Config)
 		if ok {
 			l.tlsConfig = tlsConfig.GetTLSConfig()

+ 17 - 9
transport/internet/tcp_hub.go

@@ -3,6 +3,8 @@ package internet
 import (
 	"net"
 
+	"context"
+
 	"v2ray.com/core/app/log"
 	"v2ray.com/core/common/errors"
 	v2net "v2ray.com/core/common/net"
@@ -21,11 +23,7 @@ func RegisterTransportListener(protocol TransportProtocol, listener ListenFunc)
 	return nil
 }
 
-type ListenFunc func(address v2net.Address, port v2net.Port, options ListenOptions) (Listener, error)
-type ListenOptions struct {
-	Stream       *StreamConfig
-	RecvOrigDest bool
-}
+type ListenFunc func(ctx context.Context, address v2net.Address, port v2net.Port) (Listener, error)
 
 type Listener interface {
 	Accept() (Connection, error)
@@ -40,15 +38,25 @@ type TCPHub struct {
 }
 
 func ListenTCP(address v2net.Address, port v2net.Port, callback ConnectionHandler, settings *StreamConfig) (*TCPHub, error) {
-	options := ListenOptions{
-		Stream: settings,
-	}
+	ctx := context.Background()
 	protocol := settings.GetEffectiveProtocol()
+	transportSettings, err := settings.GetEffectiveTransportSettings()
+	if err != nil {
+		return nil, err
+	}
+	ctx = ContextWithTransportSettings(ctx, transportSettings)
+	if settings != nil && settings.HasSecuritySettings() {
+		securitySettings, err := settings.GetEffectiveSecuritySettings()
+		if err != nil {
+			return nil, err
+		}
+		ctx = ContextWithSecuritySettings(ctx, securitySettings)
+	}
 	listenFunc := transportListenerCache[protocol]
 	if listenFunc == nil {
 		return nil, errors.New("Internet|TCPHub: ", protocol, " listener not registered.")
 	}
-	listener, err := listenFunc(address, port, options)
+	listener, err := listenFunc(ctx, address, port)
 	if err != nil {
 		return nil, errors.Base(err).Message("Internet|TCPHub: Failed to listen on address: ", address, ":", port)
 	}

+ 5 - 11
transport/internet/websocket/hub.go

@@ -1,6 +1,7 @@
 package websocket
 
 import (
+	"context"
 	"crypto/tls"
 	"net"
 	"net/http"
@@ -59,11 +60,8 @@ type Listener struct {
 	config        *Config
 }
 
-func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOptions) (internet.Listener, error) {
-	networkSettings, err := options.Stream.GetEffectiveTransportSettings()
-	if err != nil {
-		return nil, err
-	}
+func ListenWS(ctx context.Context, address v2net.Address, port v2net.Port) (internet.Listener, error) {
+	networkSettings := internet.TransportSettingsFromContext(ctx)
 	wsSettings := networkSettings.(*Config)
 
 	l := &Listener{
@@ -71,18 +69,14 @@ func ListenWS(address v2net.Address, port v2net.Port, options internet.ListenOpt
 		awaitingConns: make(chan *ConnectionWithError, 32),
 		config:        wsSettings,
 	}
-	if options.Stream != nil && options.Stream.HasSecuritySettings() {
-		securitySettings, err := options.Stream.GetEffectiveSecuritySettings()
-		if err != nil {
-			return nil, errors.Base(err).Message("WebSocket: Failed to create apply TLS config.")
-		}
+	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
 		tlsConfig, ok := securitySettings.(*v2tls.Config)
 		if ok {
 			l.tlsConfig = tlsConfig.GetTLSConfig()
 		}
 	}
 
-	err = l.listenws(address, port)
+	err := l.listenws(address, port)
 
 	return l, err
 }

+ 15 - 45
transport/internet/websocket/ws_test.go

@@ -1,15 +1,12 @@
 package websocket_test
 
 import (
-	"testing"
-	"time"
-
 	"bytes"
-
 	"context"
+	"testing"
+	"time"
 
 	v2net "v2ray.com/core/common/net"
-	"v2ray.com/core/common/serial"
 	"v2ray.com/core/testing/assert"
 	tlsgen "v2ray.com/core/testing/tls"
 	"v2ray.com/core/transport/internet"
@@ -19,19 +16,9 @@ import (
 
 func Test_listenWSAndDial(t *testing.T) {
 	assert := assert.On(t)
-	listen, err := ListenWS(v2net.DomainAddress("localhost"), 13146, internet.ListenOptions{
-		Stream: &internet.StreamConfig{
-			Protocol: internet.TransportProtocol_WebSocket,
-			TransportSettings: []*internet.TransportConfig{
-				{
-					Protocol: internet.TransportProtocol_WebSocket,
-					Settings: serial.ToTypedMessage(&Config{
-						Path: "ws",
-					}),
-				},
-			},
-		},
-	})
+	listen, err := ListenWS(internet.ContextWithTransportSettings(context.Background(), &Config{
+		Path: "ws",
+	}), v2net.DomainAddress("localhost"), 13146)
 	assert.Error(err).IsNil()
 	go func() {
 		for {
@@ -99,33 +86,6 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
 		assert.Fail("Too slow")
 	}()
 
-	listen, err := ListenWS(v2net.DomainAddress("localhost"), 13143, internet.ListenOptions{
-		Stream: &internet.StreamConfig{
-			SecurityType: serial.GetMessageType(new(v2tls.Config)),
-			SecuritySettings: []*serial.TypedMessage{serial.ToTypedMessage(&v2tls.Config{
-				Certificate: []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()},
-			})},
-			Protocol: internet.TransportProtocol_WebSocket,
-			TransportSettings: []*internet.TransportConfig{
-				{
-					Protocol: internet.TransportProtocol_WebSocket,
-					Settings: serial.ToTypedMessage(&Config{
-						Path: "wss",
-						ConnectionReuse: &ConnectionReuse{
-							Enable: true,
-						},
-					}),
-				},
-			},
-		},
-	})
-	assert.Error(err).IsNil()
-	go func() {
-		conn, err := listen.Accept()
-		assert.Error(err).IsNil()
-		conn.Close()
-		listen.Close()
-	}()
 	ctx := internet.ContextWithTransportSettings(context.Background(), &Config{
 		Path: "wss",
 		ConnectionReuse: &ConnectionReuse{
@@ -134,7 +94,17 @@ func Test_listenWSAndDial_TLS(t *testing.T) {
 	})
 	ctx = internet.ContextWithSecuritySettings(ctx, &v2tls.Config{
 		AllowInsecure: true,
+		Certificate:   []*v2tls.Certificate{tlsgen.GenerateCertificateForTest()},
 	})
+	listen, err := ListenWS(ctx, v2net.DomainAddress("localhost"), 13143)
+	assert.Error(err).IsNil()
+	go func() {
+		conn, err := listen.Accept()
+		assert.Error(err).IsNil()
+		conn.Close()
+		listen.Close()
+	}()
+
 	conn, err := Dial(ctx, v2net.TCPDestination(v2net.DomainAddress("localhost"), 13143))
 	assert.Error(err).IsNil()
 	conn.Close()