Selaa lähdekoodia

don't copy leftoever if at head

Darien Raymond 8 vuotta sitten
vanhempi
commit
2897df5a7a
2 muutettua tiedostoa jossa 10 lisäystä ja 1 poistoa
  1. 3 1
      common/crypto/auth.go
  2. 7 0
      common/crypto/auth_test.go

+ 3 - 1
common/crypto/auth.go

@@ -128,8 +128,10 @@ func (v *AuthenticationReader) CopyChunk(b []byte) int {
 }
 
 func (v *AuthenticationReader) EnsureChunk() error {
+	atHead := false
 	if v.buffer.IsEmpty() {
 		v.buffer.Clear()
+		atHead = true
 	}
 
 	for {
@@ -139,7 +141,7 @@ func (v *AuthenticationReader) EnsureChunk() error {
 		}
 
 		leftover := v.buffer.Bytes()
-		if len(leftover) > 0 {
+		if !atHead && len(leftover) > 0 {
 			common.Must(v.buffer.Reset(func(b []byte) (int, error) {
 				return copy(b, leftover), nil
 			}))

+ 7 - 0
common/crypto/auth_test.go

@@ -91,6 +91,8 @@ func TestAuthenticationReaderWriterPartial(t *testing.T) {
 		AdditionalDataGenerator: &NoOpBytesGenerator{},
 	}, cache)
 
+	writer.Write([]byte{'a', 'b', 'c', 'd'})
+
 	nBytes, err := writer.Write(payload)
 	assert.Error(err).IsNil()
 	assert.Int(nBytes).Equals(len(payload))
@@ -122,6 +124,11 @@ func TestAuthenticationReaderWriterPartial(t *testing.T) {
 	actualPayload := make([]byte, 7*1024)
 	nBytes, err = reader.Read(actualPayload)
 	assert.Error(err).IsNil()
+	assert.Int(nBytes).Equals(4)
+	assert.Bytes(actualPayload[:nBytes]).Equals([]byte{'a', 'b', 'c', 'd'})
+
+	nBytes, err = reader.Read(actualPayload)
+	assert.Error(err).IsNil()
 	assert.Int(nBytes).Equals(len(actualPayload))
 	assert.Bytes(actualPayload[:nBytes]).Equals(payload[:nBytes])