Selaa lähdekoodia

only try issuing new certificate when user provide custom CA

Darien Raymond 7 vuotta sitten
vanhempi
commit
abee8bddf3
1 muutettua tiedostoa jossa 45 lisäystä ja 33 poistoa
  1. 45 33
      transport/internet/tls/config.go

+ 45 - 33
transport/internet/tls/config.go

@@ -58,6 +58,15 @@ func issueCertificate(rawCA *Certificate, domain string) (*tls.Certificate, erro
 	return &cert, err
 }
 
+func (c *Config) hasCustomCA() bool {
+	for _, certificate := range c.Certificate {
+		if certificate.Usage == Certificate_AUTHORITY_ISSUE {
+			return true
+		}
+	}
+	return false
+}
+
 func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
 	config := &tls.Config{
 		ClientSessionCache: globalSessionCache,
@@ -74,53 +83,56 @@ func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
 	config.InsecureSkipVerify = c.AllowInsecure
 	config.Certificates = c.BuildCertificates()
 	config.BuildNameToCertificate()
-	config.GetCertificate = func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
-		domain := hello.ServerName
-		certExpired := false
-		if certificate, found := config.NameToCertificate[domain]; found {
-			if !isCertificateExpired(certificate) {
-				return certificate, nil
+	if c.hasCustomCA() {
+		config.GetCertificate = func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+			domain := hello.ServerName
+			certExpired := false
+			if certificate, found := config.NameToCertificate[domain]; found {
+				if !isCertificateExpired(certificate) {
+					return certificate, nil
+				}
+				certExpired = true
 			}
-			certExpired = true
-		}
 
-		if certExpired {
-			newCerts := make([]tls.Certificate, 0, len(config.Certificates))
+			if certExpired {
+				newCerts := make([]tls.Certificate, 0, len(config.Certificates))
 
-			for _, certificate := range config.Certificates {
-				if !isCertificateExpired(&certificate) {
-					newCerts = append(newCerts, certificate)
+				for _, certificate := range config.Certificates {
+					if !isCertificateExpired(&certificate) {
+						newCerts = append(newCerts, certificate)
+					}
 				}
+
+				config.Certificates = newCerts
 			}
 
-			config.Certificates = newCerts
-		}
+			var issuedCertificate *tls.Certificate
 
-		var issuedCertificate *tls.Certificate
+			// Create a new certificate from existing CA if possible
+			for _, rawCert := range c.Certificate {
+				if rawCert.Usage == Certificate_AUTHORITY_ISSUE {
+					newCert, err := issueCertificate(rawCert, domain)
+					if err != nil {
+						newError("failed to issue new certificate for ", domain).Base(err).WriteToLog()
+						continue
+					}
 
-		// Create a new certificate from existing CA if possible
-		for _, rawCert := range c.Certificate {
-			if rawCert.Usage == Certificate_AUTHORITY_ISSUE {
-				newCert, err := issueCertificate(rawCert, domain)
-				if err != nil {
-					newError("failed to issue new certificate for ", domain).Base(err).WriteToLog()
-					continue
+					config.Certificates = append(config.Certificates, *newCert)
+					issuedCertificate = &config.Certificates[len(config.Certificates)-1]
+					break
 				}
-
-				config.Certificates = append(config.Certificates, *newCert)
-				issuedCertificate = &config.Certificates[len(config.Certificates)-1]
-				break
 			}
-		}
 
-		if issuedCertificate == nil {
-			return nil, newError("failed to create a new certificate for ", domain)
-		}
+			if issuedCertificate == nil {
+				return nil, newError("failed to create a new certificate for ", domain)
+			}
 
-		config.BuildNameToCertificate()
+			config.BuildNameToCertificate()
 
-		return issuedCertificate, nil
+			return issuedCertificate, nil
+		}
 	}
+
 	if len(c.ServerName) > 0 {
 		config.ServerName = c.ServerName
 	}