Darien Raymond 7 years ago
parent
commit
5f93eee8b0
3 changed files with 45 additions and 8 deletions
  1. 14 6
      transport/pipe/impl.go
  2. 2 0
      transport/pipe/pipe.go
  3. 29 2
      transport/pipe/pipe_test.go

+ 14 - 6
transport/pipe/impl.go

@@ -8,6 +8,7 @@ import (
 
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/signal/done"
 )
 
 type state byte
@@ -23,6 +24,7 @@ type pipe struct {
 	data        buf.MultiBuffer
 	readSignal  *signal.Notifier
 	writeSignal *signal.Notifier
+	done        *done.Instance
 	limit       int32
 	state       state
 }
@@ -72,7 +74,10 @@ func (p *pipe) ReadMultiBuffer() (buf.MultiBuffer, error) {
 			return data, err
 		}
 
-		<-p.readSignal.Wait()
+		select {
+		case <-p.readSignal.Wait():
+		case <-p.done.Wait():
+		}
 	}
 }
 
@@ -87,6 +92,7 @@ func (p *pipe) ReadMultiBufferWithTimeout(d time.Duration) (buf.MultiBuffer, err
 
 		select {
 		case <-p.readSignal.Wait():
+		case <-p.done.Wait():
 		case <-timer:
 			return nil, buf.ErrReadTimeout
 		}
@@ -117,7 +123,11 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
 			return err
 		}
 
-		<-p.writeSignal.Wait()
+		select {
+		case <-p.writeSignal.Wait():
+		case <-p.done.Wait():
+			return io.ErrClosedPipe
+		}
 	}
 }
 
@@ -130,8 +140,7 @@ func (p *pipe) Close() error {
 	}
 
 	p.state = closed
-	p.readSignal.Signal()
-	p.writeSignal.Signal()
+	p.done.Close()
 	return nil
 }
 
@@ -150,6 +159,5 @@ func (p *pipe) CloseError() {
 		p.data = nil
 	}
 
-	p.readSignal.Signal()
-	p.writeSignal.Signal()
+	p.done.Close()
 }

+ 2 - 0
transport/pipe/pipe.go

@@ -5,6 +5,7 @@ import (
 
 	"v2ray.com/core"
 	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/signal/done"
 )
 
 // Option for creating new Pipes.
@@ -41,6 +42,7 @@ func New(opts ...Option) (*Reader, *Writer) {
 		limit:       -1,
 		readSignal:  signal.NewNotifier(),
 		writeSignal: signal.NewNotifier(),
+		done:        done.New(),
 	}
 
 	for _, opt := range opts {

+ 29 - 2
transport/pipe/pipe_test.go

@@ -2,12 +2,12 @@ package pipe_test
 
 import (
 	"io"
+	"sync"
 	"testing"
 	"time"
 
-	"v2ray.com/core/common/task"
-
 	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/task"
 	. "v2ray.com/core/transport/pipe"
 	. "v2ray.com/ext/assert"
 )
@@ -91,3 +91,30 @@ func TestPipeLimitZero(t *testing.T) {
 
 	assert(err, IsNil)
 }
+
+func TestPipeWriteMultiThread(t *testing.T) {
+	assert := With(t)
+
+	pReader, pWriter := New(WithSizeLimit(0))
+
+	var wg sync.WaitGroup
+
+	for i := 0; i < 10; i++ {
+		wg.Add(1)
+		go func() {
+			b := buf.New()
+			b.AppendBytes('a', 'b', 'c', 'd')
+			pWriter.WriteMultiBuffer(buf.NewMultiBufferValue(b))
+			wg.Done()
+		}()
+	}
+
+	time.Sleep(time.Millisecond * 100)
+
+	pWriter.Close()
+	wg.Wait()
+
+	b, err := pReader.ReadMultiBuffer()
+	assert(err, IsNil)
+	assert(b[0].Bytes(), Equals, []byte{'a', 'b', 'c', 'd'})
+}