浏览代码

fix auth reader

Darien Raymond 9 年之前
父节点
当前提交
417284ed99
共有 2 个文件被更改,包括 67 次插入1 次删除
  1. 3 1
      common/crypto/auth.go
  2. 64 0
      common/crypto/auth_test.go

+ 3 - 1
common/crypto/auth.go

@@ -139,7 +139,9 @@ func (v *AuthenticationReader) EnsureChunk() error {
 					return copy(b, leftover), nil
 				})
 			}
-			err = v.buffer.AppendSupplier(buf.ReadFrom(v.reader))
+			if err := v.buffer.AppendSupplier(buf.ReadFrom(v.reader)); err == nil {
+				continue
+			}
 		}
 		return err
 	}

+ 64 - 0
common/crypto/auth_test.go

@@ -0,0 +1,64 @@
+package crypto_test
+
+import (
+	"crypto/aes"
+	"crypto/cipher"
+	"crypto/rand"
+	"io"
+	"testing"
+
+	"v2ray.com/core/common/buf"
+	. "v2ray.com/core/common/crypto"
+	"v2ray.com/core/testing/assert"
+)
+
+func TestAuthenticationReaderWriter(t *testing.T) {
+	assert := assert.On(t)
+
+	key := make([]byte, 16)
+	rand.Read(key)
+	block, err := aes.NewCipher(key)
+	assert.Error(err).IsNil()
+
+	aead, err := cipher.NewGCM(block)
+	assert.Error(err).IsNil()
+
+	payload := make([]byte, 8*1024)
+	rand.Read(payload)
+
+	cache := buf.NewLocal(16 * 1024)
+	iv := make([]byte, 12)
+	rand.Read(iv)
+
+	writer := NewAuthenticationWriter(&AEADAuthenticator{
+		AEAD: aead,
+		NonceGenerator: &StaticBytesGenerator{
+			Content: iv,
+		},
+		AdditionalDataGenerator: &NoOpBytesGenerator{},
+	}, cache)
+
+	nBytes, err := writer.Write(payload)
+	assert.Error(err).IsNil()
+	assert.Int(nBytes).Equals(len(payload))
+	assert.Int(cache.Len()).GreaterThan(0)
+	_, err = writer.Write([]byte{})
+	assert.Error(err).IsNil()
+
+	reader := NewAuthenticationReader(&AEADAuthenticator{
+		AEAD: aead,
+		NonceGenerator: &StaticBytesGenerator{
+			Content: iv,
+		},
+		AdditionalDataGenerator: &NoOpBytesGenerator{},
+	}, cache, false)
+
+	actualPayload := make([]byte, 16*1024)
+	nBytes, err = reader.Read(actualPayload)
+	assert.Error(err).IsNil()
+	assert.Int(nBytes).Equals(len(payload))
+	//assert.Bytes(actualPayload[:nBytes]).Equals(payload)
+
+	_, err = reader.Read(actualPayload)
+	assert.Error(err).Equals(io.EOF)
+}