Explorar o código

fix ray stream

Darien Raymond %!s(int64=8) %!d(string=hai) anos
pai
achega
06c92e492d
Modificáronse 2 ficheiros con 63 adicións e 57 borrados
  1. 21 57
      transport/ray/direct.go
  2. 42 0
      transport/ray/direct_test.go

+ 21 - 57
transport/ray/direct.go

@@ -2,8 +2,6 @@ package ray
 
 import (
 	"io"
-	"sync"
-	"time"
 
 	"v2ray.com/core/common/buf"
 )
@@ -42,8 +40,6 @@ func (v *directRay) InboundOutput() InputStream {
 }
 
 type Stream struct {
-	access sync.RWMutex
-	closed bool
 	buffer chan *buf.Buffer
 }
 
@@ -54,72 +50,40 @@ func NewStream() *Stream {
 }
 
 func (v *Stream) Read() (*buf.Buffer, error) {
-	if v.buffer == nil {
-		return nil, io.EOF
-	}
-	v.access.RLock()
-	if v.buffer == nil {
-		v.access.RUnlock()
-		return nil, io.EOF
-	}
-	channel := v.buffer
-	v.access.RUnlock()
-	result, open := <-channel
+	buffer, open := <-v.buffer
 	if !open {
 		return nil, io.EOF
 	}
-	return result, nil
+	return buffer, nil
 }
 
-func (v *Stream) Write(data *buf.Buffer) error {
-	for !v.closed {
-		err := v.TryWriteOnce(data)
-		if err != io.ErrNoProgress {
-			return err
+func (v *Stream) Write(data *buf.Buffer) (err error) {
+	defer func() {
+		if r := recover(); r != nil {
+			err = io.ErrClosedPipe
 		}
-	}
-	return io.ErrClosedPipe
-}
+	}()
 
-func (v *Stream) TryWriteOnce(data *buf.Buffer) error {
-	v.access.RLock()
-	defer v.access.RUnlock()
-	if v.closed {
-		return io.ErrClosedPipe
-	}
-	select {
-	case v.buffer <- data:
-		return nil
-	case <-time.After(2 * time.Second):
-		return io.ErrNoProgress
-	}
+	v.buffer <- data
+	return nil
 }
 
 func (v *Stream) Close() {
-	if v.closed {
-		return
-	}
-	v.access.Lock()
-	defer v.access.Unlock()
-	if v.closed {
-		return
-	}
-	v.closed = true
+	defer swallowPanic()
+
 	close(v.buffer)
 }
 
 func (v *Stream) Release() {
-	if v.buffer == nil {
-		return
-	}
-	v.Close()
-	v.access.Lock()
-	defer v.access.Unlock()
-	if v.buffer == nil {
-		return
-	}
-	for data := range v.buffer {
-		data.Release()
+	defer swallowPanic()
+
+	close(v.buffer)
+
+	for b := range v.buffer {
+		b.Release()
 	}
-	v.buffer = nil
+}
+
+func swallowPanic() {
+	recover()
 }

+ 42 - 0
transport/ray/direct_test.go

@@ -0,0 +1,42 @@
+package ray_test
+
+import (
+	"io"
+	"testing"
+
+	"v2ray.com/core/common/buf"
+	"v2ray.com/core/testing/assert"
+	. "v2ray.com/core/transport/ray"
+)
+
+func TestStreamIO(t *testing.T) {
+	assert := assert.On(t)
+
+	stream := NewStream()
+	assert.Error(stream.Write(buf.New())).IsNil()
+
+	_, err := stream.Read()
+	assert.Error(err).IsNil()
+
+	stream.Close()
+	_, err = stream.Read()
+	assert.Error(err).Equals(io.EOF)
+
+	err = stream.Write(buf.New())
+	assert.Error(err).Equals(io.ErrClosedPipe)
+}
+
+func TestStreamClose(t *testing.T) {
+	assert := assert.On(t)
+
+	stream := NewStream()
+	assert.Error(stream.Write(buf.New())).IsNil()
+
+	stream.Close()
+
+	_, err := stream.Read()
+	assert.Error(err).IsNil()
+
+	_, err = stream.Read()
+	assert.Error(err).Equals(io.EOF)
+}