Selaa lähdekoodia

stop data updating thread when there is no data

Darien Raymond 9 vuotta sitten
vanhempi
commit
e023859ef0

+ 25 - 0
common/predicate/predicate.go

@@ -0,0 +1,25 @@
+package predicate
+
+type Predicate func() bool
+
+func All(predicates ...Predicate) Predicate {
+	return func() bool {
+		for _, p := range predicates {
+			if !p() {
+				return false
+			}
+		}
+		return true
+	}
+}
+
+func Any(predicates ...Predicate) Predicate {
+	return func() bool {
+		for _, p := range predicates {
+			if p() {
+				return true
+			}
+		}
+		return false
+	}
+}

+ 67 - 12
transport/internet/kcp/connection.go

@@ -10,6 +10,7 @@ import (
 
 	"v2ray.com/core/common/alloc"
 	"v2ray.com/core/common/log"
+	"v2ray.com/core/common/predicate"
 	"v2ray.com/core/transport/internet"
 )
 
@@ -119,6 +120,45 @@ func (this *RoundTripInfo) SmoothedTime() uint32 {
 	return this.srtt
 }
 
+type Updater struct {
+	interval        time.Duration
+	shouldContinue  predicate.Predicate
+	shouldTerminate predicate.Predicate
+	updateFunc      func()
+	notifier        chan bool
+}
+
+func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater {
+	u := &Updater{
+		interval:        time.Duration(interval) * time.Millisecond,
+		shouldContinue:  shouldContinue,
+		shouldTerminate: shouldTerminate,
+		updateFunc:      updateFunc,
+		notifier:        make(chan bool, 1),
+	}
+	go u.Run()
+	return u
+}
+
+func (this *Updater) WakeUp() {
+	select {
+	case this.notifier <- true:
+	default:
+	}
+}
+
+func (this *Updater) Run() {
+	for <-this.notifier {
+		if this.shouldTerminate() {
+			return
+		}
+		for this.shouldContinue() {
+			this.updateFunc()
+			time.Sleep(this.interval)
+		}
+	}
+}
+
 // Connection is a KCP connection over UDP.
 type Connection struct {
 	block          internet.Authenticator
@@ -147,6 +187,9 @@ type Connection struct {
 	fastresend        uint32
 	congestionControl bool
 	output            *BufferedSegmentWriter
+
+	dataUpdater *Updater
+	pingUpdater *Updater
 }
 
 // NewConnection create a new KCP connection between local and remote.
@@ -182,7 +225,18 @@ func NewConnection(conv uint16, writerCloser io.WriteCloser, local *net.UDPAddr,
 	conn.congestionControl = config.Congestion
 	conn.sendingWorker = NewSendingWorker(conn)
 
-	go conn.updateTask()
+	conn.dataUpdater = NewUpdater(
+		conn.interval,
+		predicate.Any(conn.sendingWorker.UpdateNecessary, conn.receivingWorker.UpdateNecessary),
+		func() bool {
+			return conn.State() == StateTerminated
+		},
+		conn.updateTask)
+	conn.pingUpdater = NewUpdater(
+		3000, // 3 seconds
+		func() bool { return conn.State() != StateTerminated },
+		func() bool { return conn.State() == StateTerminated },
+		conn.updateTask)
 
 	return conn
 }
@@ -240,6 +294,7 @@ func (this *Connection) Write(b []byte) (int, error) {
 		}
 
 		nBytes := this.sendingWorker.Push(b[totalWritten:])
+		this.dataUpdater.WakeUp()
 		if nBytes > 0 {
 			totalWritten += nBytes
 			if totalWritten == len(b) {
@@ -278,16 +333,24 @@ func (this *Connection) SetState(state State) {
 	switch state {
 	case StateReadyToClose:
 		this.receivingWorker.CloseRead()
+		this.dataUpdater.WakeUp()
 	case StatePeerClosed:
 		this.sendingWorker.CloseWrite()
+		this.dataUpdater.WakeUp()
 	case StateTerminating:
 		this.receivingWorker.CloseRead()
 		this.sendingWorker.CloseWrite()
+		this.dataUpdater.interval = time.Second
+		this.dataUpdater.WakeUp()
 	case StatePeerTerminating:
 		this.sendingWorker.CloseWrite()
+		this.dataUpdater.WakeUp()
 	case StateTerminated:
 		this.receivingWorker.CloseRead()
 		this.sendingWorker.CloseWrite()
+		this.dataUpdater.interval = time.Second
+		this.dataUpdater.WakeUp()
+		this.Terminate()
 	}
 }
 
@@ -366,16 +429,7 @@ func (this *Connection) SetWriteDeadline(t time.Time) error {
 
 // kcp update, input loop
 func (this *Connection) updateTask() {
-	for this.State() != StateTerminated {
-		this.flush()
-
-		interval := time.Duration(this.Config.Tti.GetValue()) * time.Millisecond
-		if this.State() == StateTerminating {
-			interval = time.Second
-		}
-		time.Sleep(interval)
-	}
-	this.Terminate()
+	this.flush()
 }
 
 func (this *Connection) FetchInputFrom(conn io.Reader) {
@@ -408,7 +462,7 @@ func (this *Connection) Terminate() {
 	}
 	log.Info("KCP|Connection: Terminating connection to ", this.RemoteAddr())
 
-	this.SetState(StateTerminated)
+	//this.SetState(StateTerminated)
 	this.dataInputCond.Broadcast()
 	this.dataOutputCond.Broadcast()
 	this.writer.Close()
@@ -434,6 +488,7 @@ func (this *Connection) OnPeerClosed() {
 func (this *Connection) Input(data []byte) int {
 	current := this.Elapsed()
 	atomic.StoreUint32(&this.lastIncomingTime, current)
+	this.dataUpdater.WakeUp()
 
 	var seg Segment
 	for {

+ 2 - 0
transport/internet/kcp/connection_test.go

@@ -34,6 +34,8 @@ func TestConnectionReadTimeout(t *testing.T) {
 	nBytes, err := conn.Read(b)
 	assert.Int(nBytes).Equals(0)
 	assert.Error(err).IsNotNil()
+
+	conn.Terminate()
 }
 
 func TestConnectionReadWrite(t *testing.T) {

+ 4 - 0
transport/internet/kcp/receiving.go

@@ -213,3 +213,7 @@ func (this *ReceivingWorker) Write(seg Segment) {
 
 func (this *ReceivingWorker) CloseRead() {
 }
+
+func (this *ReceivingWorker) UpdateNecessary() bool {
+	return len(this.acklist.numbers) > 0
+}

+ 4 - 0
transport/internet/kcp/sending.go

@@ -338,3 +338,7 @@ func (this *SendingWorker) IsEmpty() bool {
 
 	return this.window.IsEmpty()
 }
+
+func (this *SendingWorker) UpdateNecessary() bool {
+	return !this.IsEmpty()
+}