浏览代码

refactor auth reader

Darien Raymond 9 年之前
父节点
当前提交
a92df58017
共有 2 个文件被更改,包括 108 次插入41 次删除
  1. 102 40
      common/crypto/auth.go
  2. 6 1
      proxy/vmess/inbound/inbound.go

+ 102 - 40
common/crypto/auth.go

@@ -4,34 +4,86 @@ import (
 	"crypto/cipher"
 	"errors"
 	"io"
+
 	"v2ray.com/core/common/alloc"
 	"v2ray.com/core/common/serial"
 )
 
 var (
 	ErrAuthenticationFailed = errors.New("Authentication failed.")
-	errInsufficientBuffer   = errors.New("Insufficient buffer.")
+
+	errInsufficientBuffer = errors.New("Insufficient buffer.")
+	errInvalidNonce       = errors.New("Invalid nonce.")
 )
 
-type BytesGenerator func() []byte
+type BytesGenerator interface {
+	Next() []byte
+}
+
+type NoOpBytesGenerator struct {
+	buffer [1]byte
+}
+
+func (v NoOpBytesGenerator) Next() []byte {
+	return v.buffer[:0]
+}
+
+type StaticBytesGenerator struct {
+	Content []byte
+}
+
+func (v StaticBytesGenerator) Next() []byte {
+	return v.Content
+}
+
+type Authenticator interface {
+	NonceSize() int
+	Overhead() int
+	Open(dst, cipherText []byte) ([]byte, error)
+	Seal(dst, plainText []byte) ([]byte, error)
+}
+
+type AEADAuthenticator struct {
+	cipher.AEAD
+	NonceGenerator          BytesGenerator
+	AdditionalDataGenerator BytesGenerator
+}
+
+func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
+	iv := v.NonceGenerator.Next()
+	if len(iv) != v.AEAD.NonceSize() {
+		return nil, errInvalidNonce
+	}
+
+	additionalData := v.AdditionalDataGenerator.Next()
+	return v.AEAD.Open(dst, iv, cipherText, additionalData)
+}
+
+func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
+	iv := v.NonceGenerator.Next()
+	if len(iv) != v.AEAD.NonceSize() {
+		return nil, errInvalidNonce
+	}
+
+	additionalData := v.AdditionalDataGenerator.Next()
+	return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
+}
 
 type AuthenticationReader struct {
-	aead     cipher.AEAD
-	buffer   *alloc.Buffer
-	reader   io.Reader
-	ivGen    BytesGenerator
-	extraGen BytesGenerator
+	auth   Authenticator
+	buffer *alloc.Buffer
+	reader io.Reader
 
-	chunk []byte
+	chunk      []byte
+	aggressive bool
 }
 
-func NewAuthenticationReader(aead cipher.AEAD, reader io.Reader, ivGen BytesGenerator, extraGen BytesGenerator) *AuthenticationReader {
+func NewAuthenticationReader(auth Authenticator, reader io.Reader, aggressive bool) *AuthenticationReader {
 	return &AuthenticationReader{
-		aead:     aead,
-		buffer:   alloc.NewLocalBuffer(32 * 1024),
-		reader:   reader,
-		ivGen:    ivGen,
-		extraGen: extraGen,
+		auth:       auth,
+		buffer:     alloc.NewLocalBuffer(32 * 1024),
+		reader:     reader,
+		aggressive: aggressive,
 	}
 }
 
@@ -43,11 +95,11 @@ func (v *AuthenticationReader) NextChunk() error {
 	if size > v.buffer.Len()-2 {
 		return errInsufficientBuffer
 	}
-	if size == v.aead.Overhead() {
+	if size == v.auth.Overhead() {
 		return io.EOF
 	}
 	cipherChunk := v.buffer.BytesRange(2, size+2)
-	plainChunk, err := v.aead.Open(cipherChunk, v.ivGen(), cipherChunk, v.extraGen())
+	plainChunk, err := v.auth.Open(cipherChunk, cipherChunk)
 	if err != nil {
 		return err
 	}
@@ -57,6 +109,9 @@ func (v *AuthenticationReader) NextChunk() error {
 }
 
 func (v *AuthenticationReader) CopyChunk(b []byte) int {
+	if len(v.chunk) == 0 {
+		return 0
+	}
 	nBytes := copy(b, v.chunk)
 	if nBytes == len(v.chunk) {
 		v.chunk = nil
@@ -72,49 +127,56 @@ func (v *AuthenticationReader) Read(b []byte) (int, error) {
 		return nBytes, nil
 	}
 
-	err := v.NextChunk()
-	if err == errInsufficientBuffer {
-		_, err = v.buffer.FillFrom(v.reader)
-	}
-
-	if err != nil {
-		return 0, err
-	}
-
 	totalBytes := 0
 	for {
-		totalBytes += v.CopyChunk(b)
-		if len(b) == 0 {
-			break
+		err := v.NextChunk()
+		if err == errInsufficientBuffer {
+			if totalBytes > 0 {
+				return totalBytes, nil
+			}
+			leftover := v.buffer.Bytes()
+			v.buffer.SetBytesFunc(func(b []byte) int {
+				return copy(b, leftover)
+			})
+			_, err = v.buffer.FillFrom(v.reader)
 		}
-		if err := v.NextChunk(); err != nil {
-			break
+
+		if err != nil {
+			return 0, err
+		}
+
+		nBytes := v.CopyChunk(b)
+		b = b[nBytes:]
+		totalBytes += nBytes
+
+		if !v.aggressive {
+			return totalBytes, nil
 		}
 	}
-	return totalBytes, nil
 }
 
 type AuthenticationWriter struct {
-	aead     cipher.AEAD
+	auth     Authenticator
 	buffer   []byte
 	writer   io.Writer
 	ivGen    BytesGenerator
 	extraGen BytesGenerator
 }
 
-func NewAuthenticationWriter(aead cipher.AEAD, writer io.Writer, ivGen BytesGenerator, extraGen BytesGenerator) *AuthenticationWriter {
+func NewAuthenticationWriter(auth Authenticator, writer io.Writer) *AuthenticationWriter {
 	return &AuthenticationWriter{
-		aead:     aead,
-		buffer:   make([]byte, 32*1024),
-		writer:   writer,
-		ivGen:    ivGen,
-		extraGen: extraGen,
+		auth:   auth,
+		buffer: make([]byte, 32*1024),
+		writer: writer,
 	}
 }
 
 func (v *AuthenticationWriter) Write(b []byte) (int, error) {
-	cipherChunk := v.aead.Seal(v.buffer[2:], v.ivGen(), b, v.extraGen())
+	cipherChunk, err := v.auth.Seal(v.buffer[2:], b)
+	if err != nil {
+		return 0, err
+	}
 	serial.Uint16ToBytes(uint16(len(cipherChunk)), b[:0])
-	_, err := v.writer.Write(v.buffer[:2+len(cipherChunk)])
+	_, err = v.writer.Write(v.buffer[:2+len(cipherChunk)])
 	return len(b), err
 }

+ 6 - 1
proxy/vmess/inbound/inbound.go

@@ -190,7 +190,12 @@ func (v *VMessInboundHandler) HandleConnection(connection internet.Connection) {
 		bodyReader := session.DecodeRequestBody(reader)
 		var requestReader v2io.Reader
 		if request.Option.Has(protocol.RequestOptionChunkStream) {
-			authReader := crypto.NewAuthenticationReader(new(encoding.FnvAuthenticator), bodyReader, func() []byte { return nil }, func() []byte { return nil })
+			auth := &crypto.AEADAuthenticator{
+				AEAD:                    new(encoding.FnvAuthenticator),
+				NonceGenerator:          crypto.NoOpBytesGenerator{},
+				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
+			}
+			authReader := crypto.NewAuthenticationReader(auth, bodyReader, request.Command == protocol.RequestCommandTCP)
 			requestReader = v2io.NewAdaptiveReader(authReader)
 		} else {
 			requestReader = v2io.NewAdaptiveReader(bodyReader)