瀏覽代碼

refine authentication reader

Darien Raymond 8 年之前
父節點
當前提交
47c3646162
共有 2 個文件被更改,包括 34 次插入27 次删除
  1. 21 25
      common/crypto/auth.go
  2. 13 2
      common/crypto/auth_test.go

+ 21 - 25
common/crypto/auth.go

@@ -2,20 +2,16 @@ package crypto
 
 import (
 	"crypto/cipher"
-	"errors"
 	"io"
 
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/serial"
 )
 
 var (
-	ErrAuthenticationFailed = errors.New("Authentication failed.")
-
 	errInsufficientBuffer = errors.New("Insufficient buffer.")
-	errInvalidNonce       = errors.New("Invalid nonce.")
-	errInvalidLength      = errors.New("Invalid buffer size.")
 )
 
 type BytesGenerator interface {
@@ -54,7 +50,7 @@ type AEADAuthenticator struct {
 func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
 	iv := v.NonceGenerator.Next()
 	if len(iv) != v.AEAD.NonceSize() {
-		return nil, errInvalidNonce
+		return nil, errors.New("Crypto:AEADAuthenticator: Invalid nonce size: ", len(iv))
 	}
 
 	additionalData := v.AdditionalDataGenerator.Next()
@@ -64,7 +60,7 @@ func (v *AEADAuthenticator) Open(dst, cipherText []byte) ([]byte, error) {
 func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
 	iv := v.NonceGenerator.Next()
 	if len(iv) != v.AEAD.NonceSize() {
-		return nil, errInvalidNonce
+		return nil, errors.New("Crypto:AEADAuthenticator: Invalid nonce size: ", len(iv))
 	}
 
 	additionalData := v.AdditionalDataGenerator.Next()
@@ -100,13 +96,13 @@ func (v *AuthenticationReader) NextChunk() error {
 		return errInsufficientBuffer
 	}
 	if size > readerBufferSize-2 {
-		return errInvalidLength
+		return errors.New("Crypto:AuthenticationReader: Size too large: ", size)
 	}
 	if size == v.auth.Overhead() {
 		return io.EOF
 	}
 	if size < v.auth.Overhead() {
-		return errors.New("AuthenticationReader: invalid packet size.")
+		return errors.New("AuthenticationReader: invalid packet size:", size)
 	}
 	cipherChunk := v.buffer.BytesRange(2, size+2)
 	plainChunk, err := v.auth.Open(cipherChunk[:0], cipherChunk)
@@ -132,26 +128,26 @@ func (v *AuthenticationReader) CopyChunk(b []byte) int {
 }
 
 func (v *AuthenticationReader) EnsureChunk() error {
+	if v.buffer.IsEmpty() {
+		v.buffer.Clear()
+	}
+
 	for {
 		err := v.NextChunk()
-		if err == nil {
-			return nil
+		if err != errInsufficientBuffer {
+			return err
 		}
-		if err == errInsufficientBuffer {
-			if v.buffer.IsEmpty() {
-				v.buffer.Clear()
-			} else {
-				leftover := v.buffer.Bytes()
-				common.Must(v.buffer.Reset(func(b []byte) (int, error) {
-					return copy(b, leftover), nil
-				}))
-			}
-			err = v.buffer.AppendSupplier(buf.ReadFrom(v.reader))
-			if err == nil {
-				continue
-			}
+
+		leftover := v.buffer.Bytes()
+		if len(leftover) > 0 {
+			common.Must(v.buffer.Reset(func(b []byte) (int, error) {
+				return copy(b, leftover), nil
+			}))
+		}
+
+		if err := v.buffer.AppendSupplier(buf.ReadFrom(v.reader)); err != nil {
+			return err
 		}
-		return err
 	}
 }
 

+ 13 - 2
common/crypto/auth_test.go

@@ -7,6 +7,8 @@ import (
 	"io"
 	"testing"
 
+	"time"
+
 	"v2ray.com/core/common/buf"
 	. "v2ray.com/core/common/crypto"
 	"v2ray.com/core/testing/assert"
@@ -77,10 +79,10 @@ func TestAuthenticationReaderWriterPartial(t *testing.T) {
 	payload := make([]byte, 8*1024)
 	rand.Read(payload)
 
-	cache := buf.NewLocal(16 * 1024)
 	iv := make([]byte, 12)
 	rand.Read(iv)
 
+	cache := buf.NewLocal(16 * 1024)
 	writer := NewAuthenticationWriter(&AEADAuthenticator{
 		AEAD: aead,
 		NonceGenerator: &StaticBytesGenerator{
@@ -96,13 +98,22 @@ func TestAuthenticationReaderWriterPartial(t *testing.T) {
 	_, err = writer.Write([]byte{})
 	assert.Error(err).IsNil()
 
+	pr, pw := io.Pipe()
+	go func() {
+		pw.Write(cache.BytesTo(1024))
+		time.Sleep(time.Second * 2)
+		pw.Write(cache.BytesFrom(1024))
+		time.Sleep(time.Second * 2)
+		pw.Close()
+	}()
+
 	reader := NewAuthenticationReader(&AEADAuthenticator{
 		AEAD: aead,
 		NonceGenerator: &StaticBytesGenerator{
 			Content: iv,
 		},
 		AdditionalDataGenerator: &NoOpBytesGenerator{},
-	}, cache)
+	}, pr)
 
 	actualPayload := make([]byte, 7*1024)
 	nBytes, err = reader.Read(actualPayload)