Browse Source

update tls config generation

Darien Raymond 7 years ago
parent
commit
bdab1af29a

+ 2 - 2
transport/internet/kcp/dialer.go

@@ -86,8 +86,8 @@ func DialKCP(ctx context.Context, dest net.Destination) (internet.Connection, er
 
 	var iConn internet.Connection = session
 
-	if config := v2tls.ConfigFromContext(ctx, v2tls.WithDestination(dest)); config != nil {
-		tlsConn := tls.Client(iConn, config.GetTLSConfig())
+	if config := v2tls.ConfigFromContext(ctx); config != nil {
+		tlsConn := tls.Client(iConn, config.GetTLSConfig(v2tls.WithDestination(dest)))
 		iConn = tlsConn
 	}
 

+ 2 - 2
transport/internet/tcp/dialer.go

@@ -27,8 +27,8 @@ func Dial(ctx context.Context, dest net.Destination) (internet.Connection, error
 		return nil, err
 	}
 
-	if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest), tls.WithNextProto("h2")); config != nil {
-		conn = tls.Client(conn, config.GetTLSConfig())
+	if config := tls.ConfigFromContext(ctx); config != nil {
+		conn = tls.Client(conn, config.GetTLSConfig(tls.WithDestination(dest), tls.WithNextProto("h2")))
 	}
 
 	tcpSettings := getTCPSettingsFromContext(ctx)

+ 2 - 2
transport/internet/tcp/hub.go

@@ -39,8 +39,8 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, handler
 		addConn:  handler,
 	}
 
-	if config := tls.ConfigFromContext(ctx, tls.WithNextProto("h2")); config != nil {
-		l.tlsConfig = config.GetTLSConfig()
+	if config := tls.ConfigFromContext(ctx); config != nil {
+		l.tlsConfig = config.GetTLSConfig(tls.WithNextProto("h2"))
 	}
 
 	if tcpSettings.HeaderSettings != nil {

+ 15 - 13
transport/internet/tls/config.go

@@ -25,7 +25,7 @@ func (c *Config) BuildCertificates() []tls.Certificate {
 	return certs
 }
 
-func (c *Config) GetTLSConfig() *tls.Config {
+func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
 	config := &tls.Config{
 		ClientSessionCache: globalSessionCache,
 		NextProtos:         []string{"http/1.1"},
@@ -34,6 +34,10 @@ func (c *Config) GetTLSConfig() *tls.Config {
 		return config
 	}
 
+	for _, opt := range opts {
+		opt(config)
+	}
+
 	config.InsecureSkipVerify = c.AllowInsecure
 	config.Certificates = c.BuildCertificates()
 	config.BuildNameToCertificate()
@@ -47,10 +51,10 @@ func (c *Config) GetTLSConfig() *tls.Config {
 	return config
 }
 
-type Option func(*Config)
+type Option func(*tls.Config)
 
 func WithDestination(dest net.Destination) Option {
-	return func(config *Config) {
+	return func(config *tls.Config) {
 		if dest.Address.Family().IsDomain() && len(config.ServerName) == 0 {
 			config.ServerName = dest.Address.Domain()
 		}
@@ -58,23 +62,21 @@ func WithDestination(dest net.Destination) Option {
 }
 
 func WithNextProto(protocol ...string) Option {
-	return func(config *Config) {
-		if len(config.NextProtocol) == 0 {
-			config.NextProtocol = protocol
+	return func(config *tls.Config) {
+		if len(config.NextProtos) == 0 {
+			config.NextProtos = protocol
 		}
 	}
 }
 
-func ConfigFromContext(ctx context.Context, opts ...Option) *Config {
+func ConfigFromContext(ctx context.Context) *Config {
 	securitySettings := internet.SecuritySettingsFromContext(ctx)
 	if securitySettings == nil {
 		return nil
 	}
-	if config, ok := securitySettings.(*Config); ok {
-		for _, opt := range opts {
-			opt(config)
-		}
-		return config
+	config, ok := securitySettings.(*Config)
+	if !ok {
+		return nil
 	}
-	return nil
+	return config
 }

+ 2 - 2
transport/internet/websocket/dialer.go

@@ -41,9 +41,9 @@ func dialWebsocket(ctx context.Context, dest net.Destination) (net.Conn, error)
 
 	protocol := "ws"
 
-	if config := tls.ConfigFromContext(ctx, tls.WithDestination(dest)); config != nil {
+	if config := tls.ConfigFromContext(ctx); config != nil {
 		protocol = "wss"
-		dialer.TLSClientConfig = config.GetTLSConfig()
+		dialer.TLSClientConfig = config.GetTLSConfig(tls.WithDestination(dest))
 	}
 
 	host := dest.NetAddr()