Darien Raymond 7 лет назад
Родитель
Сommit
048ffbc7dc

+ 3 - 10
transport/internet/kcp/dialer.go

@@ -77,16 +77,9 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
 
 	var iConn internet.Connection = session
 
-	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
-		switch securitySettings := securitySettings.(type) {
-		case *v2tls.Config:
-			if dest.Address.Family().IsDomain() {
-				securitySettings.OverrideServerNameIfEmpty(dest.Address.Domain())
-			}
-			config := securitySettings.GetTLSConfig()
-			tlsConn := tls.Client(iConn, config)
-			iConn = tlsConn
-		}
+	if config := v2tls.ConfigFromContext(ctx, v2tls.WithDestination(dest)); config != nil {
+		tlsConn := tls.Client(iConn, config.GetTLSConfig())
+		iConn = tlsConn
 	}
 
 	return iConn, nil

+ 4 - 6
transport/internet/kcp/listener.go

@@ -59,13 +59,11 @@ func NewListener(ctx context.Context, address net.Address, port net.Port, addCon
 		config:   kcpSettings,
 		addConn:  addConn,
 	}
-	securitySettings := internet.SecuritySettingsFromContext(ctx)
-	if securitySettings != nil {
-		switch securitySettings := securitySettings.(type) {
-		case *v2tls.Config:
-			l.tlsConfig = securitySettings.GetTLSConfig()
-		}
+
+	if config := v2tls.ConfigFromContext(ctx); config != nil {
+		l.tlsConfig = config.GetTLSConfig()
 	}
+
 	hub, err := udp.ListenUDP(address, port, udp.ListenOption{Callback: l.OnReceive, Concurrency: 2})
 	if err != nil {
 		return nil, err

+ 4 - 10
transport/internet/tcp/dialer.go

@@ -19,22 +19,16 @@ func getTCPSettingsFromContext(ctx context.Context) *Config {
 }
 
 func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error) {
-	log.Trace(newError("dailing TCP to ", dest))
+	log.Trace(newError("dialing TCP to ", dest))
 	src := internet.DialerSourceFromContext(ctx)
 
 	conn, err := internet.DialSystem(ctx, src, dest)
 	if err != nil {
 		return nil, err
 	}
-	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
-		tlsConfig, ok := securitySettings.(*tls.Config)
-		if ok {
-			if dest.Address.Family().IsDomain() {
-				tlsConfig.OverrideServerNameIfEmpty(dest.Address.Domain())
-			}
-			config := tlsConfig.GetTLSConfig()
-			conn = tls.Client(conn, config)
-		}
+
+	if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest)); config != nil {
+		conn = tls.Client(conn, config.GetTLSConfig())
 	}
 
 	tcpSettings := getTCPSettingsFromContext(ctx)

+ 4 - 5
transport/internet/tcp/hub.go

@@ -37,12 +37,11 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, addConn
 		config:   tcpSettings,
 		addConn:  addConn,
 	}
-	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
-		tlsConfig, ok := securitySettings.(*tls.Config)
-		if ok {
-			l.tlsConfig = tlsConfig.GetTLSConfig()
-		}
+
+	if config := tls.ConfigFromContext(ctx); config != nil {
+		l.tlsConfig = config.GetTLSConfig()
 	}
+
 	if tcpSettings.HeaderSettings != nil {
 		headerConfig, err := tcpSettings.HeaderSettings.GetInstance()
 		if err != nil {

+ 24 - 3
transport/internet/tls/config.go

@@ -1,9 +1,12 @@
 package tls
 
 import (
+	"context"
 	"crypto/tls"
 
 	"v2ray.com/core/app/log"
+	"v2ray.com/core/common/net"
+	"v2ray.com/core/transport/internet"
 )
 
 var (
@@ -42,8 +45,26 @@ func (c *Config) GetTLSConfig() *tls.Config {
 	return config
 }
 
-func (c *Config) OverrideServerNameIfEmpty(serverName string) {
-	if len(c.ServerName) == 0 {
-		c.ServerName = serverName
+type Option func(*Config)
+
+func WithDestination(dest net.Destination) Option {
+	return func(config *Config) {
+		if dest.Address.Family().IsDomain() && len(config.ServerName) == 0 {
+			config.ServerName = dest.Address.Domain()
+		}
+	}
+}
+
+func ConfigFromContext(ctx context.Context, opts ...Option) *Config {
+	securitySettings := internet.SecuritySettingsFromContext(ctx)
+	if securitySettings == nil {
+		return nil
+	}
+	if config, ok := securitySettings.(*Config); ok {
+		for _, opt := range opts {
+			opt(config)
+		}
+		return config
 	}
+	return nil
 }

+ 3 - 9
transport/internet/websocket/dialer.go

@@ -42,15 +42,9 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error)
 
 	protocol := "ws"
 
-	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
-		tlsConfig, ok := securitySettings.(*tls.Config)
-		if ok {
-			protocol = "wss"
-			if dest.Address.Family().IsDomain() {
-				tlsConfig.OverrideServerNameIfEmpty(dest.Address.Domain())
-			}
-			dialer.TLSClientConfig = tlsConfig.GetTLSConfig()
-		}
+	if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest)); config != nil {
+		protocol = "wss"
+		dialer.TLSClientConfig = config.GetTLSConfig()
 	}
 
 	host := dest.NetAddr()

+ 2 - 5
transport/internet/websocket/hub.go

@@ -59,11 +59,8 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, addConn i
 		config:  wsSettings,
 		addConn: addConn,
 	}
-	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
-		tlsConfig, ok := securitySettings.(*v2tls.Config)
-		if ok {
-			l.tlsConfig = tlsConfig.GetTLSConfig()
-		}
+	if config := v2tls.ConfigFromContext(ctx); config != nil {
+		l.tlsConfig = config.GetTLSConfig()
 	}
 
 	err := l.listenws(address, port)