Darien Raymond 8 лет назад
Родитель
Сommit
8868fe68ed
4 измененных файлов с 127 добавлено и 44 удалено
  1. 83 29
      common/buf/reader.go
  2. 22 3
      common/buf/writer.go
  3. 19 0
      common/buf/writer_test.go
  4. 3 12
      proxy/http/server.go

+ 83 - 29
common/buf/reader.go

@@ -1,6 +1,10 @@
 package buf
 
-import "io"
+import (
+	"io"
+
+	"v2ray.com/core/common/errors"
+)
 
 // BytesToBufferReader is a Reader that adjusts its reading speed automatically.
 type BytesToBufferReader struct {
@@ -9,20 +13,40 @@ type BytesToBufferReader struct {
 }
 
 // Read implements Reader.Read().
-func (v *BytesToBufferReader) Read() (MultiBuffer, error) {
-	if err := v.buffer.Reset(ReadFrom(v.reader)); err != nil {
+func (r *BytesToBufferReader) Read() (MultiBuffer, error) {
+	if err := r.buffer.Reset(ReadFrom(r.reader)); err != nil {
 		return nil, err
 	}
 
 	mb := NewMultiBuffer()
-	for !v.buffer.IsEmpty() {
+	for !r.buffer.IsEmpty() {
 		b := New()
-		b.AppendSupplier(ReadFrom(v.buffer))
+		b.AppendSupplier(ReadFrom(r.buffer))
 		mb.Append(b)
 	}
 	return mb, nil
 }
 
+func (r *BytesToBufferReader) WriteTo(writer io.Writer) (int64, error) {
+	totalBytes := int64(0)
+	eof := false
+	for !eof {
+		if err := r.buffer.Reset(ReadFrom(r.reader)); err != nil {
+			if errors.Cause(err) == io.EOF {
+				eof = true
+			} else {
+				return totalBytes, err
+			}
+		}
+		nBytes, err := writer.Write(r.buffer.Bytes())
+		totalBytes += int64(nBytes)
+		if err != nil {
+			return totalBytes, err
+		}
+	}
+	return totalBytes, nil
+}
+
 type readerAdpater struct {
 	MultiBufferReader
 }
@@ -38,45 +62,75 @@ type bufferToBytesReader struct {
 }
 
 // fill fills in the internal buffer.
-func (v *bufferToBytesReader) fill() {
-	b, err := v.stream.Read()
+func (r *bufferToBytesReader) fill() {
+	b, err := r.stream.Read()
 	if err != nil {
-		v.err = err
+		r.err = err
 		return
 	}
-	v.current = b
+	r.current = b
 }
 
-func (v *bufferToBytesReader) Read(b []byte) (int, error) {
-	if v.err != nil {
-		return 0, v.err
+func (r *bufferToBytesReader) Read(b []byte) (int, error) {
+	if r.err != nil {
+		return 0, r.err
 	}
 
-	if v.current == nil {
-		v.fill()
-		if v.err != nil {
-			return 0, v.err
+	if r.current == nil {
+		r.fill()
+		if r.err != nil {
+			return 0, r.err
 		}
 	}
-	nBytes, err := v.current.Read(b)
-	if v.current.IsEmpty() {
-		v.current.Release()
-		v.current = nil
+	nBytes, err := r.current.Read(b)
+	if r.current.IsEmpty() {
+		r.current.Release()
+		r.current = nil
 	}
 	return nBytes, err
 }
 
-func (v *bufferToBytesReader) ReadMultiBuffer() (MultiBuffer, error) {
-	if v.err != nil {
-		return nil, v.err
+func (r *bufferToBytesReader) ReadMultiBuffer() (MultiBuffer, error) {
+	if r.err != nil {
+		return nil, r.err
 	}
-	if v.current == nil {
-		v.fill()
-		if v.err != nil {
-			return nil, v.err
+	if r.current == nil {
+		r.fill()
+		if r.err != nil {
+			return nil, r.err
 		}
 	}
-	b := v.current
-	v.current = nil
+	b := r.current
+	r.current = nil
 	return b, nil
 }
+
+func (r *bufferToBytesReader) writeToInternal(writer io.Writer) (int64, error) {
+	if r.err != nil {
+		return 0, r.err
+	}
+
+	mbWriter := NewWriter(writer)
+	totalBytes := int64(0)
+	for {
+		if r.current == nil {
+			r.fill()
+			if r.err != nil {
+				return totalBytes, r.err
+			}
+		}
+		totalBytes := int64(r.current.Len())
+		if err := mbWriter.Write(r.current); err != nil {
+			return totalBytes, err
+		}
+		r.current = nil
+	}
+}
+
+func (r *bufferToBytesReader) WriteTo(writer io.Writer) (int64, error) {
+	nBytes, err := r.writeToInternal(writer)
+	if errors.Cause(err) == io.EOF {
+		return nBytes, nil
+	}
+	return nBytes, err
+}

+ 22 - 3
common/buf/writer.go

@@ -8,8 +8,8 @@ type BufferToBytesWriter struct {
 }
 
 // Write implements Writer.Write(). Write() takes ownership of the given buffer.
-func (v *BufferToBytesWriter) Write(mb MultiBuffer) error {
-	if mw, ok := v.writer.(MultiBufferWriter); ok {
+func (w *BufferToBytesWriter) Write(mb MultiBuffer) error {
+	if mw, ok := w.writer.(MultiBufferWriter); ok {
 		_, err := mw.WriteMultiBuffer(mb)
 		return err
 	}
@@ -17,7 +17,7 @@ func (v *BufferToBytesWriter) Write(mb MultiBuffer) error {
 	defer mb.Release()
 
 	bs := mb.ToNetBuffers()
-	_, err := bs.WriteTo(v.writer)
+	_, err := bs.WriteTo(w.writer)
 	return err
 }
 
@@ -42,3 +42,22 @@ func (w *bytesToBufferWriter) Write(payload []byte) (int, error) {
 func (w *bytesToBufferWriter) WriteMulteBuffer(mb MultiBuffer) (int, error) {
 	return mb.Len(), w.writer.Write(mb)
 }
+
+func (w *bytesToBufferWriter) ReadFrom(reader io.Reader) (int64, error) {
+	mbReader := NewReader(reader)
+	totalBytes := int64(0)
+	eof := false
+	for !eof {
+		mb, err := mbReader.Read()
+		if err == io.EOF {
+			eof = true
+		} else if err != nil {
+			return totalBytes, err
+		}
+		totalBytes += int64(mb.Len())
+		if err := w.writer.Write(mb); err != nil {
+			return totalBytes, err
+		}
+	}
+	return totalBytes, nil
+}

+ 19 - 0
common/buf/writer_test.go

@@ -1,12 +1,17 @@
 package buf_test
 
 import (
+	"bufio"
 	"bytes"
 	"crypto/rand"
 	"testing"
 
+	"context"
+	"io"
+
 	. "v2ray.com/core/common/buf"
 	"v2ray.com/core/testing/assert"
+	"v2ray.com/core/transport/ray"
 )
 
 func TestWriter(t *testing.T) {
@@ -24,3 +29,17 @@ func TestWriter(t *testing.T) {
 	assert.Error(err).IsNil()
 	assert.Bytes(expectedBytes).Equals(writeBuffer.Bytes())
 }
+
+func TestBytesWriterReadFrom(t *testing.T) {
+	assert := assert.On(t)
+
+	cache := ray.NewStream(context.Background())
+	reader := bufio.NewReader(io.LimitReader(rand.Reader, 8192))
+	_, err := reader.WriteTo(ToBytesWriter(cache))
+	assert.Error(err).IsNil()
+
+	mb, err := cache.Read()
+	assert.Error(err).IsNil()
+	assert.Int(mb.Len()).Equals(8192)
+	assert.Int(len(mb)).Equals(4)
+}

+ 3 - 12
proxy/http/server.go

@@ -235,12 +235,8 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, rea
 	requestDone := signal.ExecuteAsync(func() error {
 		request.Header.Set("Connection", "close")
 
-		requestWriter := buf.NewBufferedWriter(buf.ToBytesWriter(ray.InboundInput()))
-		err := request.Write(requestWriter)
-		if err != nil {
-			return err
-		}
-		if err := requestWriter.Flush(); err != nil {
+		requestWriter := buf.ToBytesWriter(ray.InboundInput())
+		if err := request.Write(requestWriter); err != nil {
 			return err
 		}
 		return nil
@@ -271,12 +267,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, rea
 			response.Header.Set("Connection", "close")
 			response.Header.Set("Proxy-Connection", "close")
 		}
-		responseWriter := buf.NewBufferedWriter(writer)
-		if err := response.Write(responseWriter); err != nil {
-			return err
-		}
-
-		if err := responseWriter.Flush(); err != nil {
+		if err := response.Write(writer); err != nil {
 			return err
 		}
 		return nil