浏览代码

don't copy leftoever if at head

Darien Raymond 8 年之前
父节点
当前提交
2897df5a7a
共有 2 个文件被更改,包括 10 次插入1 次删除
  1. 3 1
      common/crypto/auth.go
  2. 7 0
      common/crypto/auth_test.go

+ 3 - 1
common/crypto/auth.go

@@ -128,8 +128,10 @@ func (v *AuthenticationReader) CopyChunk(b []byte) int {
 }
 }
 
 
 func (v *AuthenticationReader) EnsureChunk() error {
 func (v *AuthenticationReader) EnsureChunk() error {
+	atHead := false
 	if v.buffer.IsEmpty() {
 	if v.buffer.IsEmpty() {
 		v.buffer.Clear()
 		v.buffer.Clear()
+		atHead = true
 	}
 	}
 
 
 	for {
 	for {
@@ -139,7 +141,7 @@ func (v *AuthenticationReader) EnsureChunk() error {
 		}
 		}
 
 
 		leftover := v.buffer.Bytes()
 		leftover := v.buffer.Bytes()
-		if len(leftover) > 0 {
+		if !atHead && len(leftover) > 0 {
 			common.Must(v.buffer.Reset(func(b []byte) (int, error) {
 			common.Must(v.buffer.Reset(func(b []byte) (int, error) {
 				return copy(b, leftover), nil
 				return copy(b, leftover), nil
 			}))
 			}))

+ 7 - 0
common/crypto/auth_test.go

@@ -91,6 +91,8 @@ func TestAuthenticationReaderWriterPartial(t *testing.T) {
 		AdditionalDataGenerator: &NoOpBytesGenerator{},
 		AdditionalDataGenerator: &NoOpBytesGenerator{},
 	}, cache)
 	}, cache)
 
 
+	writer.Write([]byte{'a', 'b', 'c', 'd'})
+
 	nBytes, err := writer.Write(payload)
 	nBytes, err := writer.Write(payload)
 	assert.Error(err).IsNil()
 	assert.Error(err).IsNil()
 	assert.Int(nBytes).Equals(len(payload))
 	assert.Int(nBytes).Equals(len(payload))
@@ -122,6 +124,11 @@ func TestAuthenticationReaderWriterPartial(t *testing.T) {
 	actualPayload := make([]byte, 7*1024)
 	actualPayload := make([]byte, 7*1024)
 	nBytes, err = reader.Read(actualPayload)
 	nBytes, err = reader.Read(actualPayload)
 	assert.Error(err).IsNil()
 	assert.Error(err).IsNil()
+	assert.Int(nBytes).Equals(4)
+	assert.Bytes(actualPayload[:nBytes]).Equals([]byte{'a', 'b', 'c', 'd'})
+
+	nBytes, err = reader.Read(actualPayload)
+	assert.Error(err).IsNil()
 	assert.Int(nBytes).Equals(len(actualPayload))
 	assert.Int(nBytes).Equals(len(actualPayload))
 	assert.Bytes(actualPayload[:nBytes]).Equals(payload[:nBytes])
 	assert.Bytes(actualPayload[:nBytes]).Equals(payload[:nBytes])