Przeglądaj źródła

fix concurrent access to tls config

Darien Raymond 7 lat temu
rodzic
commit
9a9b6f9077
1 zmienionych plików z 15 dodań i 1 usunięć
  1. 15 1
      transport/internet/tls/config.go

+ 15 - 1
transport/internet/tls/config.go

@@ -4,6 +4,7 @@ import (
 	"context"
 	"crypto/tls"
 	"crypto/x509"
+	"sync"
 	"time"
 
 	"v2ray.com/core/common/net"
@@ -77,10 +78,17 @@ func (c *Config) getCustomCA() []*Certificate {
 }
 
 func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
+	var access sync.RWMutex
+
 	return func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
 		domain := hello.ServerName
 		certExpired := false
-		if certificate, found := c.NameToCertificate[domain]; found {
+
+		access.RLock()
+		certificate, found := c.NameToCertificate[domain]
+		access.RUnlock()
+
+		if found {
 			if !isCertificateExpired(certificate) {
 				return certificate, nil
 			}
@@ -90,6 +98,7 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
 		if certExpired {
 			newCerts := make([]tls.Certificate, 0, len(c.Certificates))
 
+			access.Lock()
 			for _, certificate := range c.Certificates {
 				if !isCertificateExpired(&certificate) {
 					newCerts = append(newCerts, certificate)
@@ -97,6 +106,7 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
 			}
 
 			c.Certificates = newCerts
+			access.Unlock()
 		}
 
 		var issuedCertificate *tls.Certificate
@@ -110,8 +120,10 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
 					continue
 				}
 
+				access.Lock()
 				c.Certificates = append(c.Certificates, *newCert)
 				issuedCertificate = &c.Certificates[len(c.Certificates)-1]
+				access.Unlock()
 				break
 			}
 		}
@@ -120,7 +132,9 @@ func getGetCertificateFunc(c *tls.Config, ca []*Certificate) func(hello *tls.Cli
 			return nil, newError("failed to create a new certificate for ", domain)
 		}
 
+		access.Lock()
 		c.BuildNameToCertificate()
+		access.Unlock()
 
 		return issuedCertificate, nil
 	}