Bladeren bron

simplify auth reader

Darien Raymond 8 jaren geleden
bovenliggende
commit
bcfcba396b
2 gewijzigde bestanden met toevoegingen van 27 en 98 verwijderingen
  1. 23 98
      common/crypto/auth.go
  2. 4 0
      common/crypto/auth_test.go

+ 23 - 98
common/crypto/auth.go

@@ -90,133 +90,58 @@ func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
 
 type AuthenticationReader struct {
 	auth         Authenticator
-	buffer       *buf.Buffer
-	reader       io.Reader
+	reader       *buf.BufferedReader
 	sizeParser   ChunkSizeDecoder
-	size         int
 	transferType protocol.TransferType
 }
 
-const (
-	readerBufferSize = 32 * 1024
-)
-
 func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType) *AuthenticationReader {
 	return &AuthenticationReader{
 		auth:         auth,
-		buffer:       buf.NewLocal(readerBufferSize),
-		reader:       reader,
+		reader:       buf.NewBufferedReader(buf.NewReader(reader)),
 		sizeParser:   sizeParser,
-		size:         -1,
 		transferType: transferType,
 	}
 }
 
-func (r *AuthenticationReader) readSize() error {
-	if r.size >= 0 {
-		return nil
-	}
-
-	sizeBytes := r.sizeParser.SizeBytes()
-	if r.buffer.Len() < sizeBytes {
-		if r.buffer.IsEmpty() {
-			r.buffer.Clear()
-		} else {
-			common.Must(r.buffer.Reset(buf.ReadFrom(r.buffer)))
-		}
-
-		delta := sizeBytes - r.buffer.Len()
-		if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil {
-			return err
-		}
-	}
-	size, err := r.sizeParser.Decode(r.buffer.BytesTo(sizeBytes))
+func (r *AuthenticationReader) readSize() (int, error) {
+	sizeBytes := make([]byte, r.sizeParser.SizeBytes())
+	_, err := io.ReadFull(r.reader, sizeBytes)
 	if err != nil {
-		return err
+		return 0, err
 	}
-	r.size = int(size)
-	r.buffer.SliceFrom(sizeBytes)
-	return nil
+	size, err := r.sizeParser.Decode(sizeBytes)
+	return int(size), err
 }
 
-func (r *AuthenticationReader) readChunk(waitForData bool) ([]byte, error) {
-	if err := r.readSize(); err != nil {
+func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
+	size, err := r.readSize()
+	if err != nil {
 		return nil, err
 	}
-	if r.size > readerBufferSize-r.sizeParser.SizeBytes() {
-		return nil, newError("size too large ", r.size).AtWarning()
-	}
 
-	if r.size == r.auth.Overhead() {
+	if size == r.auth.Overhead() {
 		return nil, io.EOF
 	}
 
-	if r.buffer.Len() < r.size {
-		if !waitForData {
-			return nil, io.ErrNoProgress
-		}
-
-		if r.buffer.IsEmpty() {
-			r.buffer.Clear()
-		} else {
-			common.Must(r.buffer.Reset(buf.ReadFrom(r.buffer)))
-		}
-
-		delta := r.size - r.buffer.Len()
-		if err := r.buffer.AppendSupplier(buf.ReadAtLeastFrom(r.reader, delta)); err != nil {
-			return nil, err
-		}
+	var b *buf.Buffer
+	if size <= buf.Size {
+		b = buf.New()
+	} else {
+		b = buf.NewLocal(size)
 	}
-
-	b, err := r.auth.Open(r.buffer.BytesTo(0), r.buffer.BytesTo(r.size))
-	if err != nil {
+	if err := b.Reset(buf.ReadFullFrom(r.reader, size)); err != nil {
+		b.Release()
 		return nil, err
 	}
-	r.buffer.SliceFrom(r.size)
-	r.size = -1
-	return b, nil
-}
 
-func (r *AuthenticationReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
-	b, err := r.readChunk(true)
+	rb, err := r.auth.Open(b.BytesTo(0), b.BytesTo(size))
 	if err != nil {
+		b.Release()
 		return nil, err
 	}
-
-	var mb buf.MultiBuffer
-	if r.transferType == protocol.TransferTypeStream {
-		mb.Write(b)
-	} else {
-		var bb *buf.Buffer
-		if len(b) <= buf.Size {
-			bb = buf.New()
-		} else {
-			bb = buf.NewLocal(len(b))
-		}
-		bb.Append(b)
-		mb.Append(bb)
-	}
-
-	for r.buffer.Len() >= r.sizeParser.SizeBytes() {
-		b, err := r.readChunk(false)
-		if err != nil {
-			break
-		}
-		if r.transferType == protocol.TransferTypeStream {
-			mb.Write(b)
-		} else {
-			var bb *buf.Buffer
-			if len(b) <= buf.Size {
-				bb = buf.New()
-			} else {
-				bb = buf.NewLocal(len(b))
-			}
-			bb.Append(b)
-			mb.Append(bb)
-		}
-	}
-
-	return mb, nil
+	b.Slice(0, len(rb))
+	return buf.NewMultiBufferValue(b), nil
 }
 
 type AuthenticationWriter struct {

+ 4 - 0
common/crypto/auth_test.go

@@ -122,6 +122,10 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
 
 	b1 := mb.SplitFirst()
 	assert(b1.String(), Equals, "abcd")
+	assert(mb.IsEmpty(), IsTrue)
+
+	mb, err = reader.ReadMultiBuffer()
+	assert(err, IsNil)
 	b2 := mb.SplitFirst()
 	assert(b2.String(), Equals, "efgh")
 	assert(mb.IsEmpty(), IsTrue)