Ver Fonte

fix data input and output signal

Darien Raymond há 9 anos atrás
pai
commit
41fcffbfab
1 ficheiros alterados com 62 adições e 51 exclusões
  1. 62 51
      transport/internet/kcp/connection.go

+ 62 - 51
transport/internet/kcp/connection.go

@@ -168,15 +168,15 @@ type SystemConnection interface {
 
 // Connection is a KCP connection over UDP.
 type Connection struct {
-	conn           SystemConnection
-	connRecycler   internal.ConnectionRecyler
-	block          internet.Authenticator
-	rd             time.Time
-	wd             time.Time // write deadline
-	since          int64
-	dataInputCond  *sync.Cond
-	dataOutputCond *sync.Cond
-	Config         *Config
+	conn         SystemConnection
+	connRecycler internal.ConnectionRecyler
+	block        internet.Authenticator
+	rd           time.Time
+	wd           time.Time // write deadline
+	since        int64
+	dataInput    chan bool
+	dataOutput   chan bool
+	Config       *Config
 
 	conv             uint16
 	state            State
@@ -203,15 +203,15 @@ func NewConnection(conv uint16, sysConn SystemConnection, recycler internal.Conn
 	log.Info("KCP|Connection: creating connection ", conv)
 
 	conn := &Connection{
-		conv:           conv,
-		conn:           sysConn,
-		connRecycler:   recycler,
-		since:          nowMillisec(),
-		dataInputCond:  sync.NewCond(new(sync.Mutex)),
-		dataOutputCond: sync.NewCond(new(sync.Mutex)),
-		Config:         config,
-		output:         NewSegmentWriter(sysConn, config.GetMtu().GetValue()-uint32(sysConn.Overhead())),
-		mss:            config.GetMtu().GetValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead,
+		conv:         conv,
+		conn:         sysConn,
+		connRecycler: recycler,
+		since:        nowMillisec(),
+		dataInput:    make(chan bool, 1),
+		dataOutput:   make(chan bool, 1),
+		Config:       config,
+		output:       NewSegmentWriter(sysConn, config.GetMtu().GetValue()-uint32(sysConn.Overhead())),
+		mss:          config.GetMtu().GetValue() - uint32(sysConn.Overhead()) - DataSegmentOverhead,
 		roundTrip: &RoundTripInfo{
 			rto:    100,
 			minRtt: config.Tti.GetValue(),
@@ -247,6 +247,20 @@ func (v *Connection) Elapsed() uint32 {
 	return uint32(nowMillisec() - v.since)
 }
 
+func (v *Connection) OnDataInput() {
+	select {
+	case v.dataInput <- true:
+	default:
+	}
+}
+
+func (v *Connection) OnDataOutput() {
+	select {
+	case v.dataOutput <- true:
+	default:
+	}
+}
+
 // Read implements the Conn Read method.
 func (v *Connection) Read(b []byte) (int, error) {
 	if v == nil {
@@ -266,22 +280,20 @@ func (v *Connection) Read(b []byte) (int, error) {
 			return 0, io.EOF
 		}
 
-		var timer *time.Timer
+		duration := time.Duration(time.Minute)
 		if !v.rd.IsZero() {
-			duration := v.rd.Sub(time.Now())
-			if duration <= 0 {
+			duration = v.rd.Sub(time.Now())
+			if duration < 0 {
 				return 0, ErrIOTimeout
 			}
-			timer = time.AfterFunc(duration, v.dataInputCond.Signal)
-		}
-		v.dataInputCond.L.Lock()
-		v.dataInputCond.Wait()
-		v.dataInputCond.L.Unlock()
-		if timer != nil {
-			timer.Stop()
 		}
-		if !v.rd.IsZero() && v.rd.Before(time.Now()) {
-			return 0, ErrIOTimeout
+
+		select {
+		case <-v.dataInput:
+		case <-time.After(duration):
+			if !v.rd.IsZero() && v.rd.Before(time.Now()) {
+				return 0, ErrIOTimeout
+			}
 		}
 	}
 }
@@ -304,24 +316,20 @@ func (v *Connection) Write(b []byte) (int, error) {
 			}
 		}
 
-		var timer *time.Timer
-		if !v.wd.IsZero() {
-			duration := v.wd.Sub(time.Now())
-			if duration <= 0 {
+		duration := time.Duration(time.Minute)
+		if !v.rd.IsZero() {
+			duration = v.wd.Sub(time.Now())
+			if duration < 0 {
 				return totalWritten, ErrIOTimeout
 			}
-			timer = time.AfterFunc(duration, v.dataOutputCond.Signal)
-		}
-		v.dataOutputCond.L.Lock()
-		v.dataOutputCond.Wait()
-		v.dataOutputCond.L.Unlock()
-
-		if timer != nil {
-			timer.Stop()
 		}
 
-		if !v.wd.IsZero() && v.wd.Before(time.Now()) {
-			return totalWritten, ErrIOTimeout
+		select {
+		case <-v.dataOutput:
+		case <-time.After(duration):
+			if !v.wd.IsZero() && v.wd.Before(time.Now()) {
+				return totalWritten, ErrIOTimeout
+			}
 		}
 	}
 }
@@ -360,8 +368,8 @@ func (v *Connection) Close() error {
 		return ErrClosedConnection
 	}
 
-	v.dataInputCond.Broadcast()
-	v.dataOutputCond.Broadcast()
+	v.OnDataInput()
+	v.OnDataOutput()
 
 	state := v.State()
 	if state.Is(StateReadyToClose, StateTerminating, StateTerminated) {
@@ -447,8 +455,9 @@ func (v *Connection) Terminate() {
 	log.Info("KCP|Connection: Terminating connection to ", v.RemoteAddr())
 
 	//v.SetState(StateTerminated)
-	v.dataInputCond.Broadcast()
-	v.dataOutputCond.Broadcast()
+	v.OnDataInput()
+	v.OnDataOutput()
+
 	if v.Config.ConnectionReuse.IsEnabled() && v.reusable {
 		v.connRecycler.Put(v.conn.Id(), v.conn)
 	} else {
@@ -481,19 +490,21 @@ func (v *Connection) Input(segments []Segment) {
 
 	for _, seg := range segments {
 		if seg.Conversation() != v.conv {
-			return
+			break
 		}
 
 		switch seg := seg.(type) {
 		case *DataSegment:
 			v.HandleOption(seg.Option)
 			v.receivingWorker.ProcessSegment(seg)
-			v.dataInputCond.Signal()
+			if seg.Number == v.receivingWorker.nextNumber {
+				v.OnDataInput()
+			}
 			v.dataUpdater.WakeUp()
 		case *AckSegment:
 			v.HandleOption(seg.Option)
 			v.sendingWorker.ProcessSegment(current, seg, v.roundTrip.Timeout())
-			v.dataOutputCond.Signal()
+			v.OnDataOutput()
 			v.dataUpdater.WakeUp()
 		case *CmdOnlySegment:
 			v.HandleOption(seg.Option)