Sfoglia il codice sorgente

handle read with data and error at same time

Darien Raymond 8 anni fa
parent
commit
2fdccb2424
4 ha cambiato i file con 72 aggiunte e 43 eliminazioni
  1. 10 14
      common/buf/copy.go
  2. 36 18
      common/buf/reader.go
  3. 21 8
      common/buf/writer.go
  4. 5 3
      common/buf/writer_test.go

+ 10 - 14
common/buf/copy.go

@@ -82,21 +82,17 @@ func CountSize(sc *SizeCounter) CopyOption {
 func copyInternal(reader Reader, writer Writer, handler *copyHandler) error {
 	for {
 		buffer, err := handler.readFrom(reader)
-		if err != nil {
-			return err
-		}
-
-		if buffer.IsEmpty() {
-			buffer.Release()
-			continue
+		if !buffer.IsEmpty() {
+			for _, handler := range handler.onData {
+				handler(buffer)
+			}
+
+			if werr := handler.writeTo(writer, buffer); werr != nil {
+				buffer.Release()
+				return werr
+			}
 		}
-
-		for _, handler := range handler.onData {
-			handler(buffer)
-		}
-
-		if err := handler.writeTo(writer, buffer); err != nil {
-			buffer.Release()
+		if err != nil {
 			return err
 		}
 	}

+ 36 - 18
common/buf/reader.go

@@ -6,6 +6,11 @@ import (
 	"v2ray.com/core/common/errors"
 )
 
+var (
+	_ Reader    = (*BytesToBufferReader)(nil)
+	_ io.Reader = (*BytesToBufferReader)(nil)
+)
+
 // BytesToBufferReader is a Reader that adjusts its reading speed automatically.
 type BytesToBufferReader struct {
 	io.Reader
@@ -37,15 +42,21 @@ func (r *BytesToBufferReader) ReadMultiBuffer() (MultiBuffer, error) {
 	}
 
 	nBytes, err := r.Reader.Read(r.buffer)
-	if err != nil {
-		return nil, err
+	if nBytes > 0 {
+		mb := NewMultiBufferCap(nBytes/Size + 1)
+		mb.Write(r.buffer[:nBytes])
+		return mb, err
 	}
-
-	mb := NewMultiBufferCap(nBytes/Size + 1)
-	mb.Write(r.buffer[:nBytes])
-	return mb, nil
+	return nil, err
 }
 
+var (
+	_ Reader        = (*BufferedReader)(nil)
+	_ io.Reader     = (*BufferedReader)(nil)
+	_ io.ByteReader = (*BufferedReader)(nil)
+	_ io.WriterTo   = (*BufferedReader)(nil)
+)
+
 type BufferedReader struct {
 	stream       Reader
 	legacyReader io.Reader
@@ -72,6 +83,12 @@ func (r *BufferedReader) IsBuffered() bool {
 	return r.buffered
 }
 
+func (r *BufferedReader) ReadByte() (byte, error) {
+	var b [1]byte
+	_, err := r.Read(b[:])
+	return b[0], err
+}
+
 func (r *BufferedReader) Read(b []byte) (int, error) {
 	if r.leftOver != nil {
 		nBytes, _ := r.leftOver.Read(b)
@@ -87,15 +104,14 @@ func (r *BufferedReader) Read(b []byte) (int, error) {
 	}
 
 	mb, err := r.stream.ReadMultiBuffer()
-	if err != nil {
-		return 0, err
-	}
-
-	nBytes, _ := mb.Read(b)
-	if !mb.IsEmpty() {
-		r.leftOver = mb
+	if mb != nil {
+		nBytes, _ := mb.Read(b)
+		if !mb.IsEmpty() {
+			r.leftOver = mb
+		}
+		return nBytes, err
 	}
-	return nBytes, nil
+	return 0, err
 }
 
 func (r *BufferedReader) ReadMultiBuffer() (MultiBuffer, error) {
@@ -120,11 +136,13 @@ func (r *BufferedReader) writeToInternal(writer io.Writer) (int64, error) {
 
 	for {
 		mb, err := r.stream.ReadMultiBuffer()
-		if err != nil {
-			return totalBytes, err
+		if mb != nil {
+			totalBytes += int64(mb.Len())
+			if werr := mbWriter.WriteMultiBuffer(mb); werr != nil {
+				return totalBytes, err
+			}
 		}
-		totalBytes += int64(mb.Len())
-		if err := mbWriter.WriteMultiBuffer(mb); err != nil {
+		if err != nil {
 			return totalBytes, err
 		}
 	}

+ 21 - 8
common/buf/writer.go

@@ -6,6 +6,12 @@ import (
 	"v2ray.com/core/common/errors"
 )
 
+var (
+	_ io.ReaderFrom = (*BufferToBytesWriter)(nil)
+	_ io.Writer     = (*BufferToBytesWriter)(nil)
+	_ Writer        = (*BufferToBytesWriter)(nil)
+)
+
 // BufferToBytesWriter is a Writer that writes alloc.Buffer into underlying writer.
 type BufferToBytesWriter struct {
 	io.Writer
@@ -33,6 +39,13 @@ func (w *BufferToBytesWriter) ReadFrom(reader io.Reader) (int64, error) {
 	return sc.Size, err
 }
 
+var (
+	_ io.ReaderFrom = (*BufferedWriter)(nil)
+	_ io.Writer     = (*BufferedWriter)(nil)
+	_ Writer        = (*BufferedWriter)(nil)
+	_ io.ByteWriter = (*BufferedWriter)(nil)
+)
+
 // BufferedWriter is a Writer with internal buffer.
 type BufferedWriter struct {
 	writer       Writer
@@ -54,6 +67,11 @@ func NewBufferedWriter(writer Writer) *BufferedWriter {
 	return w
 }
 
+func (w *BufferedWriter) WriteByte(c byte) error {
+	_, err := w.Write([]byte{c})
+	return err
+}
+
 // Write implements io.Writer.
 func (w *BufferedWriter) Write(b []byte) (int, error) {
 	if !w.buffered && w.legacyWriter != nil {
@@ -130,17 +148,12 @@ func (w *BufferedWriter) SetBuffered(f bool) error {
 
 // ReadFrom implements io.ReaderFrom.
 func (w *BufferedWriter) ReadFrom(reader io.Reader) (int64, error) {
-	var sc SizeCounter
-	if !w.buffer.IsEmpty() {
-		sc.Size += int64(w.buffer.Len())
-		if err := w.Flush(); err != nil {
-			return sc.Size, err
-		}
+	if err := w.SetBuffered(false); err != nil {
+		return 0, err
 	}
 
-	w.buffered = false
+	var sc SizeCounter
 	err := Copy(NewReader(reader), w, CountSize(&sc))
-
 	return sc.Size, err
 }
 

+ 5 - 3
common/buf/writer_test.go

@@ -37,15 +37,17 @@ func TestBytesWriterReadFrom(t *testing.T) {
 	assert := With(t)
 
 	cache := ray.NewStream(context.Background())
-	reader := bufio.NewReader(io.LimitReader(rand.Reader, 8192))
+	const size = 50000
+	reader := bufio.NewReader(io.LimitReader(rand.Reader, size))
 	writer := NewBufferedWriter(cache)
 	writer.SetBuffered(false)
-	_, err := reader.WriteTo(writer)
+	nBytes, err := reader.WriteTo(writer)
+	assert(nBytes, Equals, int64(size))
 	assert(err, IsNil)
 
 	mb, err := cache.ReadMultiBuffer()
 	assert(err, IsNil)
-	assert(mb.Len(), Equals, 8192)
+	assert(mb.Len(), Equals, size)
 }
 
 func TestDiscardBytes(t *testing.T) {