Переглянути джерело

fix race conditions in kcp

Darien Raymond 8 роки тому
батько
коміт
ebed271a92

+ 17 - 9
transport/internet/kcp/connection.go

@@ -116,7 +116,7 @@ func (v *RoundTripInfo) SmoothedTime() uint32 {
 }
 
 type Updater struct {
-	interval        time.Duration
+	interval        int64
 	shouldContinue  predicate.Predicate
 	shouldTerminate predicate.Predicate
 	updateFunc      func()
@@ -125,7 +125,7 @@ type Updater struct {
 
 func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater {
 	u := &Updater{
-		interval:        time.Duration(interval) * time.Millisecond,
+		interval:        int64(time.Duration(interval) * time.Millisecond),
 		shouldContinue:  shouldContinue,
 		shouldTerminate: shouldTerminate,
 		updateFunc:      updateFunc,
@@ -149,11 +149,19 @@ func (v *Updater) Run() {
 		}
 		for v.shouldContinue() {
 			v.updateFunc()
-			time.Sleep(v.interval)
+			time.Sleep(v.Interval())
 		}
 	}
 }
 
+func (u *Updater) Interval() time.Duration {
+	return time.Duration(atomic.LoadInt64(&u.interval))
+}
+
+func (u *Updater) SetInterval(d time.Duration) {
+	atomic.StoreInt64(&u.interval, int64(d))
+}
+
 type SystemConnection interface {
 	net.Conn
 	Id() internal.ConnectionID
@@ -342,14 +350,14 @@ func (v *Connection) SetState(state State) {
 	case StateTerminating:
 		v.receivingWorker.CloseRead()
 		v.sendingWorker.CloseWrite()
-		v.pingUpdater.interval = time.Second
+		v.pingUpdater.SetInterval(time.Second)
 	case StatePeerTerminating:
 		v.sendingWorker.CloseWrite()
-		v.pingUpdater.interval = time.Second
+		v.pingUpdater.SetInterval(time.Second)
 	case StateTerminated:
 		v.receivingWorker.CloseRead()
 		v.sendingWorker.CloseWrite()
-		v.pingUpdater.interval = time.Second
+		v.pingUpdater.SetInterval(time.Second)
 		v.dataUpdater.WakeUp()
 		v.pingUpdater.WakeUp()
 		go v.Terminate()
@@ -491,7 +499,7 @@ func (v *Connection) Input(segments []Segment) {
 		case *DataSegment:
 			v.HandleOption(seg.Option)
 			v.receivingWorker.ProcessSegment(seg)
-			if seg.Number == v.receivingWorker.nextNumber {
+			if v.receivingWorker.IsDataAvailable() {
 				v.OnDataInput()
 			}
 			v.dataUpdater.WakeUp()
@@ -573,8 +581,8 @@ func (v *Connection) Ping(current uint32, cmd Command) {
 	seg := NewCmdOnlySegment()
 	seg.Conv = v.conv
 	seg.Cmd = cmd
-	seg.ReceivinNext = v.receivingWorker.nextNumber
-	seg.SendingNext = v.sendingWorker.firstUnacknowledged
+	seg.ReceivinNext = v.receivingWorker.NextNumber()
+	seg.SendingNext = v.sendingWorker.FirstUnacknowledged()
 	seg.PeerRTO = v.roundTrip.Timeout()
 	if v.State() == StateReadyToClose {
 		seg.Option = SegmentOptionClose

+ 21 - 16
transport/internet/kcp/listener.go

@@ -79,7 +79,7 @@ func (o *ServerConnection) Id() internal.ConnectionID {
 // Listener defines a server listening for connections
 type Listener struct {
 	sync.Mutex
-	running       bool
+	closed        chan bool
 	sessions      map[ConnectionID]*Connection
 	awaitingConns chan *Connection
 	hub           *udp.Hub
@@ -116,7 +116,7 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen
 		},
 		sessions:      make(map[ConnectionID]*Connection),
 		awaitingConns: make(chan *Connection, 64),
-		running:       true,
+		closed:        make(chan bool),
 		config:        kcpSettings,
 	}
 	if options.Stream != nil && options.Stream.HasSecuritySettings() {
@@ -134,7 +134,9 @@ func NewListener(address v2net.Address, port v2net.Port, options internet.Listen
 	if err != nil {
 		return nil, err
 	}
+	l.Lock()
 	l.hub = hub
+	l.Unlock()
 	log.Info("KCP|Listener: listening on ", address, ":", port)
 	return l, nil
 }
@@ -148,12 +150,15 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina
 		return
 	}
 
-	if !v.running {
+	select {
+	case <-v.closed:
 		return
+	default:
 	}
+
 	v.Lock()
 	defer v.Unlock()
-	if !v.running {
+	if v.hub == nil {
 		return
 	}
 	if payload.Len() < 4 {
@@ -208,24 +213,22 @@ func (v *Listener) OnReceive(payload *buf.Buffer, src v2net.Destination, origina
 }
 
 func (v *Listener) Remove(id ConnectionID) {
-	if !v.running {
-		return
-	}
-	v.Lock()
-	defer v.Unlock()
-	if !v.running {
+	select {
+	case <-v.closed:
 		return
+	default:
+		v.Lock()
+		delete(v.sessions, id)
+		v.Unlock()
 	}
-	delete(v.sessions, id)
 }
 
 // Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn.
 func (v *Listener) Accept() (internet.Connection, error) {
 	for {
-		if !v.running {
-			return nil, ErrClosedListener
-		}
 		select {
+		case <-v.closed:
+			return nil, ErrClosedListener
 		case conn, open := <-v.awaitingConns:
 			if !open {
 				break
@@ -243,13 +246,15 @@ func (v *Listener) Accept() (internet.Connection, error) {
 
 // Close stops listening on the UDP address. Already Accepted connections are not closed.
 func (v *Listener) Close() error {
-	if !v.running {
+	select {
+	case <-v.closed:
 		return ErrClosedListener
+	default:
 	}
 	v.Lock()
 	defer v.Unlock()
 
-	v.running = false
+	close(v.closed)
 	close(v.awaitingConns)
 	for _, conn := range v.sessions {
 		go conn.Terminate()

+ 22 - 0
transport/internet/kcp/receiving.go

@@ -48,6 +48,10 @@ func (v *ReceivingWindow) RemoveFirst() *DataSegment {
 	return v.Remove(0)
 }
 
+func (w *ReceivingWindow) HasFirst() bool {
+	return w.list[w.Position(0)] != nil
+}
+
 func (v *ReceivingWindow) Advance() {
 	v.start++
 	if v.start == v.size {
@@ -163,7 +167,9 @@ func NewReceivingWorker(kcp *Connection) *ReceivingWorker {
 }
 
 func (v *ReceivingWorker) Release() {
+	v.Lock()
 	v.leftOver.Release()
+	v.Unlock()
 }
 
 func (v *ReceivingWorker) ProcessSendingNext(number uint32) {
@@ -228,6 +234,19 @@ func (v *ReceivingWorker) Read(b []byte) int {
 	return total
 }
 
+func (w *ReceivingWorker) IsDataAvailable() bool {
+	w.RLock()
+	defer w.RUnlock()
+	return w.window.HasFirst()
+}
+
+func (w *ReceivingWorker) NextNumber() uint32 {
+	w.RLock()
+	defer w.RUnlock()
+
+	return w.nextNumber
+}
+
 func (v *ReceivingWorker) Flush(current uint32) {
 	v.Lock()
 	defer v.Unlock()
@@ -250,5 +269,8 @@ func (v *ReceivingWorker) CloseRead() {
 }
 
 func (v *ReceivingWorker) UpdateNecessary() bool {
+	v.RLock()
+	defer v.RUnlock()
+
 	return len(v.acklist.numbers) > 0
 }

+ 17 - 3
transport/internet/kcp/sending.go

@@ -207,7 +207,9 @@ func NewSendingWorker(kcp *Connection) *SendingWorker {
 }
 
 func (v *SendingWorker) Release() {
+	v.Lock()
 	v.window.Release()
+	v.Unlock()
 }
 
 func (v *SendingWorker) ProcessReceivingNext(nextNumber uint32) {
@@ -336,7 +338,6 @@ func (v *SendingWorker) OnPacketLoss(lossRate uint32) {
 
 func (v *SendingWorker) Flush(current uint32) {
 	v.Lock()
-	defer v.Unlock()
 
 	cwnd := v.firstUnacknowledged + v.conn.Config.GetSendingInFlightSize()
 	if cwnd > v.remoteNextNumber {
@@ -348,11 +349,17 @@ func (v *SendingWorker) Flush(current uint32) {
 
 	if !v.window.IsEmpty() {
 		v.window.Flush(current, v.conn.roundTrip.Timeout(), cwnd)
-	} else if v.firstUnacknowledgedUpdated {
-		v.conn.Ping(current, CommandPing)
+		v.firstUnacknowledgedUpdated = false
 	}
 
+	updated := v.firstUnacknowledgedUpdated
 	v.firstUnacknowledgedUpdated = false
+
+	v.Unlock()
+
+	if updated {
+		v.conn.Ping(current, CommandPing)
+	}
 }
 
 func (v *SendingWorker) CloseWrite() {
@@ -372,3 +379,10 @@ func (v *SendingWorker) IsEmpty() bool {
 func (v *SendingWorker) UpdateNecessary() bool {
 	return !v.IsEmpty()
 }
+
+func (w *SendingWorker) FirstUnacknowledged() uint32 {
+	w.RLock()
+	defer w.RUnlock()
+
+	return w.firstUnacknowledged
+}