Browse Source

only yield goroutine after second write

Darien Raymond 7 years ago
parent
commit
120058310a
2 changed files with 23 additions and 10 deletions
  1. 16 7
      transport/pipe/impl.go
  2. 7 3
      transport/pipe/pipe_test.go

+ 16 - 7
transport/pipe/impl.go

@@ -41,6 +41,7 @@ type pipe struct {
 }
 
 var errBufferFull = errors.New("buffer full")
+var errSlowDown = errors.New("slow down")
 
 func (p *pipe) getState(forRead bool) error {
 	switch p.state {
@@ -122,11 +123,11 @@ func (p *pipe) writeMultiBufferInternal(mb buf.MultiBuffer) error {
 
 	if p.data == nil {
 		p.data = mb
-	} else {
-		p.data, _ = buf.MergeMulti(p.data, mb)
+		return nil
 	}
 
-	return nil
+	p.data, _ = buf.MergeMulti(p.data, mb)
+	return errSlowDown
 }
 
 func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
@@ -136,17 +137,25 @@ func (p *pipe) WriteMultiBuffer(mb buf.MultiBuffer) error {
 
 	for {
 		err := p.writeMultiBufferInternal(mb)
-		switch {
-		case err == nil:
+		if err == nil {
+			p.readSignal.Signal()
+			return nil
+		}
+
+		if err == errSlowDown {
 			p.readSignal.Signal()
 
 			// Yield current goroutine. Hopefully the reading counterpart can pick up the payload.
 			runtime.Gosched()
 			return nil
-		case err == errBufferFull && p.option.discardOverflow:
+		}
+
+		if err == errBufferFull && p.option.discardOverflow {
 			buf.ReleaseMulti(mb)
 			return nil
-		case err != errBufferFull:
+		}
+
+		if err != errBufferFull {
 			buf.ReleaseMulti(mb)
 			p.readSignal.Signal()
 			return err

+ 7 - 3
transport/pipe/pipe_test.go

@@ -17,14 +17,18 @@ func TestPipeReadWrite(t *testing.T) {
 	assert := With(t)
 
 	pReader, pWriter := New(WithSizeLimit(1024))
-	payload := []byte{'a', 'b', 'c', 'd'}
+
 	b := buf.New()
-	b.Write(payload)
+	b.WriteString("abcd")
 	assert(pWriter.WriteMultiBuffer(buf.MultiBuffer{b}), IsNil)
 
+	b2 := buf.New()
+	b2.WriteString("efg")
+	assert(pWriter.WriteMultiBuffer(buf.MultiBuffer{b2}), IsNil)
+
 	rb, err := pReader.ReadMultiBuffer()
 	assert(err, IsNil)
-	assert(rb.String(), Equals, b.String())
+	assert(rb.String(), Equals, "abcdefg")
 }
 
 func TestPipeCloseError(t *testing.T) {