瀏覽代碼

split read and write signal

Darien Raymond 8 年之前
父節點
當前提交
ad35fc7028
共有 1 個文件被更改,包括 39 次插入22 次删除
  1. 39 22
      transport/ray/direct.go

+ 39 - 22
transport/ray/direct.go

@@ -54,20 +54,22 @@ func init() {
 }
 
 type Stream struct {
-	access sync.RWMutex
-	data   buf.MultiBuffer
-	size   uint64
-	ctx    context.Context
-	wakeup chan bool
-	close  bool
-	err    bool
+	access      sync.RWMutex
+	data        buf.MultiBuffer
+	size        uint64
+	ctx         context.Context
+	readSignal  chan bool
+	writeSignal chan bool
+	close       bool
+	err         bool
 }
 
 func NewStream(ctx context.Context) *Stream {
 	return &Stream{
-		ctx:    ctx,
-		wakeup: make(chan bool, 1),
-		size:   0,
+		ctx:         ctx,
+		readSignal:  make(chan bool, 1),
+		writeSignal: make(chan bool, 1),
+		size:        0,
 	}
 }
 
@@ -110,14 +112,14 @@ func (s *Stream) Read() (buf.MultiBuffer, error) {
 		}
 
 		if mb != nil {
-			s.wakeUp()
+			s.notifyRead()
 			return mb, nil
 		}
 
 		select {
 		case <-s.ctx.Done():
 			return nil, io.EOF
-		case <-s.wakeup:
+		case <-s.writeSignal:
 		}
 	}
 }
@@ -130,7 +132,7 @@ func (s *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
 		}
 
 		if mb != nil {
-			s.wakeUp()
+			s.notifyRead()
 			return mb, nil
 		}
 
@@ -139,7 +141,7 @@ func (s *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
 			return nil, io.EOF
 		case <-time.After(timeout):
 			return nil, buf.ErrReadTimeout
-		case <-s.wakeup:
+		case <-s.writeSignal:
 		}
 	}
 }
@@ -149,13 +151,18 @@ func (s *Stream) Write(data buf.MultiBuffer) error {
 		return nil
 	}
 
-L:
 	for streamSizeLimit > 0 && s.size >= streamSizeLimit {
 		select {
 		case <-s.ctx.Done():
 			return io.ErrClosedPipe
-		case <-s.wakeup:
-			break L
+		case <-s.readSignal:
+			s.access.RLock()
+			if s.err || s.close {
+				data.Release()
+				s.access.RUnlock()
+				return io.ErrClosedPipe
+			}
+			s.access.RUnlock()
 		}
 	}
 
@@ -173,14 +180,21 @@ L:
 		s.data.AppendMulti(data)
 	}
 	s.size += uint64(data.Len())
-	s.wakeUp()
+	s.notifyWrite()
 
 	return nil
 }
 
-func (s *Stream) wakeUp() {
+func (s *Stream) notifyRead() {
 	select {
-	case s.wakeup <- true:
+	case s.readSignal <- true:
+	default:
+	}
+}
+
+func (s *Stream) notifyWrite() {
+	select {
+	case s.writeSignal <- true:
 	default:
 	}
 }
@@ -188,7 +202,8 @@ func (s *Stream) wakeUp() {
 func (s *Stream) Close() {
 	s.access.Lock()
 	s.close = true
-	s.wakeUp()
+	s.notifyRead()
+	s.notifyWrite()
 	s.access.Unlock()
 }
 
@@ -198,7 +213,9 @@ func (s *Stream) CloseError() {
 	if s.data != nil {
 		s.data.Release()
 		s.data = nil
+		s.size = 0
 	}
-	s.wakeUp()
+	s.notifyRead()
+	s.notifyWrite()
 	s.access.Unlock()
 }