Darien Raymond 8 år sedan
förälder
incheckning
af88016320

+ 2 - 2
transport/internet/kcp/dialer.go

@@ -135,10 +135,10 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
 	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
 		switch securitySettings := securitySettings.(type) {
 		case *v2tls.Config:
-			config := securitySettings.GetTLSConfig()
 			if dest.Address.Family().IsDomain() {
-				config.ServerName = dest.Address.Domain()
+				securitySettings.OverrideServerNameIfEmpty(dest.Address.Domain())
 			}
+			config := securitySettings.GetTLSConfig()
 			tlsConn := tls.Client(iConn, config)
 			iConn = tlsConn
 		}

+ 2 - 2
transport/internet/tcp/dialer.go

@@ -29,10 +29,10 @@ func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error
 	if securitySettings := internet.SecuritySettingsFromContext(ctx); securitySettings != nil {
 		tlsConfig, ok := securitySettings.(*tls.Config)
 		if ok {
-			config := tlsConfig.GetTLSConfig()
 			if dest.Address.Family().IsDomain() {
-				config.ServerName = dest.Address.Domain()
+				tlsConfig.OverrideServerNameIfEmpty(dest.Address.Domain())
 			}
+			config := tlsConfig.GetTLSConfig()
 			conn = tls.Client(conn, config)
 		}
 	}

+ 15 - 9
transport/internet/tls/config.go

@@ -10,9 +10,9 @@ var (
 	globalSessionCache = tls.NewLRUClientSessionCache(128)
 )
 
-func (v *Config) BuildCertificates() []tls.Certificate {
-	certs := make([]tls.Certificate, 0, len(v.Certificate))
-	for _, entry := range v.Certificate {
+func (c *Config) BuildCertificates() []tls.Certificate {
+	certs := make([]tls.Certificate, 0, len(c.Certificate))
+	for _, entry := range c.Certificate {
 		keyPair, err := tls.X509KeyPair(entry.Certificate, entry.Key)
 		if err != nil {
 			log.Trace(newError("ignoring invalid X509 key pair").Base(err).AtWarning())
@@ -23,21 +23,27 @@ func (v *Config) BuildCertificates() []tls.Certificate {
 	return certs
 }
 
-func (v *Config) GetTLSConfig() *tls.Config {
+func (c *Config) GetTLSConfig() *tls.Config {
 	config := &tls.Config{
 		ClientSessionCache: globalSessionCache,
 		NextProtos:         []string{"http/1.1"},
 	}
-	if v == nil {
+	if c == nil {
 		return config
 	}
 
-	config.InsecureSkipVerify = v.AllowInsecure
-	config.Certificates = v.BuildCertificates()
+	config.InsecureSkipVerify = c.AllowInsecure
+	config.Certificates = c.BuildCertificates()
 	config.BuildNameToCertificate()
-	if len(v.ServerName) > 0 {
-		config.ServerName = v.ServerName
+	if len(c.ServerName) > 0 {
+		config.ServerName = c.ServerName
 	}
 
 	return config
 }
+
+func (c *Config) OverrideServerNameIfEmpty(serverName string) {
+	if len(c.ServerName) == 0 {
+		c.ServerName = serverName
+	}
+}

+ 2 - 2
transport/internet/websocket/dialer.go

@@ -46,10 +46,10 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error)
 		tlsConfig, ok := securitySettings.(*tls.Config)
 		if ok {
 			protocol = "wss"
-			dialer.TLSClientConfig = tlsConfig.GetTLSConfig()
 			if dest.Address.Family().IsDomain() {
-				dialer.TLSClientConfig.ServerName = dest.Address.Domain()
+				tlsConfig.OverrideServerNameIfEmpty(dest.Address.Domain())
 			}
+			dialer.TLSClientConfig = tlsConfig.GetTLSConfig()
 		}
 	}