Ver código fonte

simplify ray stream

Darien Raymond 8 anos atrás
pai
commit
2f565bfd5e

+ 0 - 34
common/buf/merge_reader.go

@@ -1,34 +0,0 @@
-package buf
-
-type MergingReader struct {
-	reader        Reader
-	timeoutReader TimeoutReader
-}
-
-func NewMergingReader(reader Reader) Reader {
-	return &MergingReader{
-		reader:        reader,
-		timeoutReader: reader.(TimeoutReader),
-	}
-}
-
-func (r *MergingReader) Read() (MultiBuffer, error) {
-	mb, err := r.reader.Read()
-	if err != nil {
-		return nil, err
-	}
-
-	if r.timeoutReader == nil {
-		return mb, nil
-	}
-
-	for {
-		mb2, err := r.timeoutReader.ReadTimeout(0)
-		if err != nil {
-			break
-		}
-		mb.AppendMulti(mb2)
-	}
-
-	return mb, nil
-}

+ 0 - 33
common/buf/merge_reader_test.go

@@ -1,33 +0,0 @@
-package buf_test
-
-import (
-	"testing"
-
-	"context"
-
-	. "v2ray.com/core/common/buf"
-	"v2ray.com/core/testing/assert"
-	"v2ray.com/core/transport/ray"
-)
-
-func TestMergingReader(t *testing.T) {
-	assert := assert.On(t)
-
-	stream := ray.NewStream(context.Background())
-	b1 := New()
-	b1.AppendBytes('a', 'b', 'c')
-	stream.Write(NewMultiBufferValue(b1))
-
-	b2 := New()
-	b2.AppendBytes('e', 'f', 'g')
-	stream.Write(NewMultiBufferValue(b2))
-
-	b3 := New()
-	b3.AppendBytes('h', 'i', 'j')
-	stream.Write(NewMultiBufferValue(b3))
-
-	reader := NewMergingReader(stream)
-	b, err := reader.Read()
-	assert.Error(err).IsNil()
-	assert.Int(b.Len()).Equals(9)
-}

+ 1 - 1
common/buf/multi_buffer.go

@@ -13,7 +13,7 @@ type MultiBufferReader interface {
 type MultiBuffer []*Buffer
 
 func NewMultiBuffer() MultiBuffer {
-	return MultiBuffer(make([]*Buffer, 0, 8))
+	return MultiBuffer(make([]*Buffer, 0, 32))
 }
 
 func NewMultiBufferValue(b ...*Buffer) MultiBuffer {

+ 24 - 2
proxy/freedom/freedom.go

@@ -4,6 +4,7 @@ package freedom
 
 import (
 	"context"
+	"io"
 	"runtime"
 	"time"
 
@@ -112,8 +113,13 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 	ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
 
 	requestDone := signal.ExecuteAsync(func() error {
-		v2writer := buf.NewWriter(conn)
-		if err := buf.PipeUntilEOF(timer, input, v2writer); err != nil {
+		var writer buf.Writer
+		if destination.Network == net.Network_TCP {
+			writer = buf.NewWriter(conn)
+		} else {
+			writer = &seqWriter{writer: conn}
+		}
+		if err := buf.PipeUntilEOF(timer, input, writer); err != nil {
 			return err
 		}
 		return nil
@@ -145,3 +151,19 @@ func init() {
 		return New(ctx, config.(*Config))
 	}))
 }
+
+type seqWriter struct {
+	writer io.Writer
+}
+
+func (w *seqWriter) Write(mb buf.MultiBuffer) error {
+	defer mb.Release()
+
+	for _, b := range mb {
+		if _, err := w.writer.Write(b.Bytes()); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}

+ 1 - 2
proxy/shadowsocks/client.go

@@ -105,8 +105,7 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
 		}
 
 		requestDone := signal.ExecuteAsync(func() error {
-			mergedInput := buf.NewMergingReader(outboundRay.OutboundInput())
-			if err := buf.PipeUntilEOF(timer, mergedInput, bodyWriter); err != nil {
+			if err := buf.PipeUntilEOF(timer, outboundRay.OutboundInput(), bodyWriter); err != nil {
 				return err
 			}
 			return nil

+ 2 - 3
proxy/shadowsocks/server.go

@@ -160,8 +160,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 			return newError("failed to write response").Base(err)
 		}
 
-		mergeReader := buf.NewMergingReader(ray.InboundOutput())
-		payload, err := mergeReader.Read()
+		payload, err := ray.InboundOutput().Read()
 		if err != nil {
 			return err
 		}
@@ -174,7 +173,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 			return err
 		}
 
-		if err := buf.PipeUntilEOF(timer, mergeReader, responseWriter); err != nil {
+		if err := buf.PipeUntilEOF(timer, ray.InboundOutput(), responseWriter); err != nil {
 			return newError("failed to transport all TCP response").Base(err)
 		}
 

+ 2 - 6
proxy/vmess/inbound/inbound.go

@@ -140,12 +140,8 @@ func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSessio
 
 	bodyWriter := session.EncodeResponseBody(request, output)
 
-	var reader buf.Reader = input
-	if request.Command == protocol.RequestCommandTCP {
-		reader = buf.NewMergingReader(input)
-	}
 	// Optimize for small response packet
-	data, err := reader.Read()
+	data, err := input.Read()
 	if err != nil {
 		return err
 	}
@@ -161,7 +157,7 @@ func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSessio
 		}
 	}
 
-	if err := buf.PipeUntilEOF(timer, reader, bodyWriter); err != nil {
+	if err := buf.PipeUntilEOF(timer, input, bodyWriter); err != nil {
 		return err
 	}
 

+ 1 - 6
proxy/vmess/outbound/outbound.go

@@ -123,12 +123,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 			return err
 		}
 
-		var inputReader buf.Reader = input
-		if request.Command == protocol.RequestCommandTCP {
-			inputReader = buf.NewMergingReader(input)
-		}
-
-		if err := buf.PipeUntilEOF(timer, inputReader, bodyWriter); err != nil {
+		if err := buf.PipeUntilEOF(timer, input, bodyWriter); err != nil {
 			return err
 		}
 

+ 87 - 79
transport/ray/direct.go

@@ -3,15 +3,12 @@ package ray
 import (
 	"context"
 	"io"
+	"sync"
 	"time"
 
 	"v2ray.com/core/common/buf"
 )
 
-const (
-	bufferSize = 512
-)
-
 // NewRay creates a new Ray for direct traffic transport.
 func NewRay(ctx context.Context) Ray {
 	return &directRay{
@@ -42,121 +39,132 @@ func (v *directRay) InboundOutput() InputStream {
 }
 
 type Stream struct {
-	buffer chan buf.MultiBuffer
+	access sync.Mutex
+	data   buf.MultiBuffer
 	ctx    context.Context
-	close  chan bool
-	err    chan bool
+	wakeup chan bool
+	close  bool
+	err    bool
 }
 
 func NewStream(ctx context.Context) *Stream {
 	return &Stream{
 		ctx:    ctx,
-		buffer: make(chan buf.MultiBuffer, bufferSize),
-		close:  make(chan bool),
-		err:    make(chan bool),
+		wakeup: make(chan bool, 1),
 	}
 }
 
-func (v *Stream) Read() (buf.MultiBuffer, error) {
-	select {
-	case <-v.ctx.Done():
-		return nil, io.ErrClosedPipe
-	case <-v.err:
+func (s *Stream) getData() (buf.MultiBuffer, error) {
+	s.access.Lock()
+	defer s.access.Unlock()
+
+	if s.data != nil {
+		mb := s.data
+		s.data = nil
+		return mb, nil
+	}
+
+	if s.close {
+		return nil, io.EOF
+	}
+
+	if s.err {
 		return nil, io.ErrClosedPipe
-	case b := <-v.buffer:
-		return b, nil
-	default:
+	}
+
+	return nil, nil
+}
+
+func (s *Stream) Read() (buf.MultiBuffer, error) {
+	for {
+		mb, err := s.getData()
+		if err != nil {
+			return nil, err
+		}
+
+		if mb != nil {
+			return mb, nil
+		}
+
 		select {
-		case <-v.ctx.Done():
-			return nil, io.ErrClosedPipe
-		case b := <-v.buffer:
-			return b, nil
-		case <-v.close:
-			return nil, io.EOF
-		case <-v.err:
+		case <-s.ctx.Done():
 			return nil, io.ErrClosedPipe
+		case <-s.wakeup:
 		}
 	}
 }
 
-func (v *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
-	select {
-	case <-v.ctx.Done():
-		return nil, io.ErrClosedPipe
-	case <-v.err:
-		return nil, io.ErrClosedPipe
-	case b := <-v.buffer:
-		return b, nil
-	default:
-		if timeout == 0 {
-			return nil, buf.ErrReadTimeout
+func (s *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
+	for {
+		mb, err := s.getData()
+		if err != nil {
+			return nil, err
+		}
+
+		if mb != nil {
+			return mb, nil
 		}
 
 		select {
-		case <-v.ctx.Done():
-			return nil, io.ErrClosedPipe
-		case b := <-v.buffer:
-			return b, nil
-		case <-v.close:
-			return nil, io.EOF
-		case <-v.err:
+		case <-s.ctx.Done():
 			return nil, io.ErrClosedPipe
 		case <-time.After(timeout):
 			return nil, buf.ErrReadTimeout
+		case <-s.wakeup:
 		}
 	}
 }
 
-func (v *Stream) Write(data buf.MultiBuffer) (err error) {
+func (s *Stream) Write(data buf.MultiBuffer) (err error) {
 	if data.IsEmpty() {
 		return
 	}
 
-	select {
-	case <-v.ctx.Done():
-		return io.ErrClosedPipe
-	case <-v.err:
+	s.access.Lock()
+	defer s.access.Unlock()
+
+	if s.err {
+		data.Release()
 		return io.ErrClosedPipe
-	case <-v.close:
+	}
+	if s.close {
+		data.Release()
 		return io.ErrClosedPipe
-	default:
-		select {
-		case <-v.ctx.Done():
-			return io.ErrClosedPipe
-		case <-v.err:
-			return io.ErrClosedPipe
-		case <-v.close:
-			return io.ErrClosedPipe
-		case v.buffer <- data:
-			return nil
-		}
 	}
-}
 
-func (v *Stream) Close() {
-	defer swallowPanic()
+	if s.data == nil {
+		s.data = data
+	} else {
+		s.data.AppendMulti(data)
+	}
+	s.wakeUp()
 
-	close(v.close)
+	return nil
 }
 
-func (v *Stream) CloseError() {
-	defer swallowPanic()
+func (s *Stream) wakeUp() {
+	select {
+	case s.wakeup <- true:
+	default:
+	}
+}
 
-	close(v.err)
+func (s *Stream) Close() {
+	s.access.Lock()
+	s.close = true
+	s.wakeUp()
+	s.access.Unlock()
+}
 
-	n := len(v.buffer)
-	for i := 0; i < n; i++ {
-		select {
-		case b := <-v.buffer:
-			b.Release()
-		default:
-			return
-		}
+func (s *Stream) CloseError() {
+	s.access.Lock()
+	s.err = true
+	if s.data != nil {
+		s.data.Release()
+		s.data = nil
 	}
+	s.wakeUp()
+	s.access.Unlock()
 }
 
 func (v *Stream) Release() {}
-
-func swallowPanic() {
-	recover()
-}