Browse Source

refactored cert pin

Shelikhoo 4 years ago
parent
commit
ebb720804d
3 changed files with 37 additions and 28 deletions
  1. 1 13
      infra/control/certchainhash.go
  2. 0 15
      transport/internet/tls/config.go
  3. 36 0
      transport/internet/tls/pin.go

+ 1 - 13
infra/control/certchainhash.go

@@ -1,8 +1,6 @@
 package control
 
 import (
-	"encoding/base64"
-	"encoding/pem"
 	"flag"
 	"fmt"
 	"io/ioutil"
@@ -38,17 +36,7 @@ func (c CertificateChainHashCommand) Execute(args []string) error {
 	if err != nil {
 		return err
 	}
-	var certChain [][]byte
-	for {
-		block, remain := pem.Decode(certContent)
-		if block == nil {
-			break
-		}
-		certChain = append(certChain, block.Bytes)
-		certContent = remain
-	}
-	certChainHash := v2tls.GenerateCertChainHash(certChain)
-	certChainHashB64 := base64.StdEncoding.EncodeToString(certChainHash)
+	certChainHashB64 := v2tls.CalculatePEMCertChainSHA256Hash(certContent)
 	fmt.Println(certChainHashB64)
 	return nil
 }

+ 0 - 15
transport/internet/tls/config.go

@@ -4,7 +4,6 @@ package tls
 
 import (
 	"crypto/hmac"
-	"crypto/sha256"
 	"crypto/tls"
 	"crypto/x509"
 	"encoding/base64"
@@ -186,20 +185,6 @@ func (c *Config) verifyPeerCert(rawCerts [][]byte, verifiedChains [][]*x509.Cert
 	return nil
 }
 
-func GenerateCertChainHash(rawCerts [][]byte) []byte {
-	var hashValue []byte
-	for _, certValue := range rawCerts {
-		out := sha256.Sum256(certValue)
-		if hashValue == nil {
-			hashValue = out[:]
-		} else {
-			newHashValue := sha256.Sum256(append(hashValue, out[:]...))
-			hashValue = newHashValue[:]
-		}
-	}
-	return hashValue
-}
-
 // GetTLSConfig converts this Config into tls.Config.
 func (c *Config) GetTLSConfig(opts ...Option) *tls.Config {
 	root, err := c.getCertPool()

+ 36 - 0
transport/internet/tls/pin.go

@@ -0,0 +1,36 @@
+package tls
+
+import (
+	"crypto/sha256"
+	"encoding/base64"
+	"encoding/pem"
+)
+
+func CalculatePEMCertChainSHA256Hash(certContent []byte) string {
+	var certChain [][]byte
+	for {
+		block, remain := pem.Decode(certContent)
+		if block == nil {
+			break
+		}
+		certChain = append(certChain, block.Bytes)
+		certContent = remain
+	}
+	certChainHash := GenerateCertChainHash(certChain)
+	certChainHashB64 := base64.StdEncoding.EncodeToString(certChainHash)
+	return certChainHashB64
+}
+
+func GenerateCertChainHash(rawCerts [][]byte) []byte {
+	var hashValue []byte
+	for _, certValue := range rawCerts {
+		out := sha256.Sum256(certValue)
+		if hashValue == nil {
+			hashValue = out[:]
+		} else {
+			newHashValue := sha256.Sum256(append(hashValue, out[:]...))
+			hashValue = newHashValue[:]
+		}
+	}
+	return hashValue
+}