Bläddra i källkod

refine locker in kcp connection

v2ray 9 år sedan
förälder
incheckning
78ef65e17b

+ 94 - 72
transport/internet/kcp/connection.go

@@ -5,6 +5,7 @@ import (
 	"io"
 	"net"
 	"sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/v2ray/v2ray-core/common/alloc"
@@ -18,7 +19,7 @@ var (
 	errClosedConnection = errors.New("Connection closed.")
 )
 
-type State int
+type State int32
 
 const (
 	StateActive       State = 0
@@ -37,9 +38,65 @@ func nowMillisec() int64 {
 	return now.Unix()*1000 + int64(now.Nanosecond()/1000000)
 }
 
+type RountTripInfo struct {
+	sync.RWMutex
+	variation uint32
+	srtt      uint32
+	rto       uint32
+	minRtt    uint32
+}
+
+func (this *RountTripInfo) Update(rtt uint32) {
+	if rtt > 0x7FFFFFFF {
+		return
+	}
+	this.Lock()
+	defer this.Unlock()
+
+	// https://tools.ietf.org/html/rfc6298
+	if this.srtt == 0 {
+		this.srtt = rtt
+		this.variation = rtt / 2
+	} else {
+		delta := rtt - this.srtt
+		if this.srtt > rtt {
+			delta = this.srtt - rtt
+		}
+		this.variation = (3*this.variation + delta) / 4
+		this.srtt = (7*this.srtt + rtt) / 8
+		if this.srtt < this.minRtt {
+			this.srtt = this.minRtt
+		}
+	}
+	var rto uint32
+	if this.minRtt < 4*this.variation {
+		rto = this.srtt + 4*this.variation
+	} else {
+		rto = this.srtt + this.variation
+	}
+
+	if rto > 10000 {
+		rto = 10000
+	}
+	this.rto = rto * 3 / 2
+}
+
+func (this *RountTripInfo) Timeout() uint32 {
+	this.RLock()
+	defer this.RUnlock()
+
+	return this.rto
+}
+
+func (this *RountTripInfo) SmoothedTime() uint32 {
+	this.RLock()
+	defer this.RUnlock()
+
+	return this.srtt
+}
+
 // Connection is a KCP connection over UDP.
 type Connection struct {
-	sync.RWMutex
 	block         Authenticator
 	local, remote net.Addr
 	wd            time.Time // write deadline
@@ -54,9 +111,9 @@ type Connection struct {
 	sendingUpdated   bool
 	lastPingTime     uint32
 
-	mss                        uint32
-	rx_rttvar, rx_srtt, rx_rto uint32
-	interval                   uint32
+	mss       uint32
+	roundTrip *RountTripInfo
+	interval  uint32
 
 	receivingWorker *ReceivingWorker
 	sendingWorker   *SendingWorker
@@ -85,7 +142,10 @@ func NewConnection(conv uint16, writerCloser io.WriteCloser, local *net.UDPAddr,
 	conn.output = NewSegmentWriter(authWriter)
 
 	conn.mss = authWriter.Mtu() - DataSegmentOverhead
-	conn.rx_rto = 100
+	conn.roundTrip = &RountTripInfo{
+		rto:    100,
+		minRtt: effectiveConfig.Tti,
+	}
 	conn.interval = effectiveConfig.Tti
 	conn.receivingWorker = NewReceivingWorker(conn)
 	conn.fastresend = 2
@@ -107,8 +167,7 @@ func (this *Connection) Read(b []byte) (int, error) {
 		return 0, io.EOF
 	}
 
-	state := this.State()
-	if state == StateTerminating || state == StateTerminated {
+	if this.State() == StateTerminating || this.State() == StateTerminated {
 		return 0, io.EOF
 	}
 	return this.receivingWorker.Read(b)
@@ -144,10 +203,8 @@ func (this *Connection) Write(b []byte) (int, error) {
 }
 
 func (this *Connection) SetState(state State) {
-	this.Lock()
-	this.state = state
-	this.stateBeginTime = this.Elapsed()
-	this.Unlock()
+	atomic.StoreInt32((*int32)(&this.state), int32(state))
+	atomic.StoreUint32(&this.stateBeginTime, this.Elapsed())
 
 	switch state {
 	case StateReadyToClose:
@@ -297,42 +354,10 @@ func (this *Connection) OnPeerClosed() {
 	}
 }
 
-// https://tools.ietf.org/html/rfc6298
-func (this *Connection) update_ack(rtt int32) {
-	this.Lock()
-	defer this.Unlock()
-
-	if this.rx_srtt == 0 {
-		this.rx_srtt = uint32(rtt)
-		this.rx_rttvar = uint32(rtt) / 2
-	} else {
-		delta := rtt - int32(this.rx_srtt)
-		if delta < 0 {
-			delta = -delta
-		}
-		this.rx_rttvar = (3*this.rx_rttvar + uint32(delta)) / 4
-		this.rx_srtt = (7*this.rx_srtt + uint32(rtt)) / 8
-		if this.rx_srtt < this.interval {
-			this.rx_srtt = this.interval
-		}
-	}
-	var rto uint32
-	if this.interval < 4*this.rx_rttvar {
-		rto = this.rx_srtt + 4*this.rx_rttvar
-	} else {
-		rto = this.rx_srtt + this.interval
-	}
-
-	if rto > 10000 {
-		rto = 10000
-	}
-	this.rx_rto = rto * 3 / 2
-}
-
 // Input when you received a low level packet (eg. UDP packet), call it
-func (kcp *Connection) Input(data []byte) int {
-	current := kcp.Elapsed()
-	kcp.lastIncomingTime = current
+func (this *Connection) Input(data []byte) int {
+	current := this.Elapsed()
+	atomic.StoreUint32(&this.lastIncomingTime, current)
 
 	var seg Segment
 	for {
@@ -343,26 +368,27 @@ func (kcp *Connection) Input(data []byte) int {
 
 		switch seg := seg.(type) {
 		case *DataSegment:
-			kcp.HandleOption(seg.Opt)
-			kcp.receivingWorker.ProcessSegment(seg)
-			kcp.lastPayloadTime = current
+			this.HandleOption(seg.Opt)
+			this.receivingWorker.ProcessSegment(seg)
+			atomic.StoreUint32(&this.lastPayloadTime, current)
 		case *AckSegment:
-			kcp.HandleOption(seg.Opt)
-			kcp.sendingWorker.ProcessSegment(current, seg)
-			kcp.lastPayloadTime = current
+			this.HandleOption(seg.Opt)
+			this.sendingWorker.ProcessSegment(current, seg)
+			atomic.StoreUint32(&this.lastPayloadTime, current)
 		case *CmdOnlySegment:
-			kcp.HandleOption(seg.Opt)
+			this.HandleOption(seg.Opt)
 			if seg.Cmd == SegmentCommandTerminated {
-				if kcp.state == StateActive ||
-					kcp.state == StateReadyToClose ||
-					kcp.state == StatePeerClosed {
-					kcp.SetState(StateTerminating)
-				} else if kcp.state == StateTerminating {
-					kcp.SetState(StateTerminated)
+				state := this.State()
+				if state == StateActive ||
+					state == StateReadyToClose ||
+					state == StatePeerClosed {
+					this.SetState(StateTerminating)
+				} else if state == StateTerminating {
+					this.SetState(StateTerminated)
 				}
 			}
-			kcp.sendingWorker.ProcessReceivingNext(seg.ReceivinNext)
-			kcp.receivingWorker.ProcessSendingNext(seg.SendingNext)
+			this.sendingWorker.ProcessReceivingNext(seg.ReceivinNext)
+			this.receivingWorker.ProcessSendingNext(seg.SendingNext)
 		default:
 		}
 	}
@@ -372,16 +398,15 @@ func (kcp *Connection) Input(data []byte) int {
 
 func (this *Connection) flush() {
 	current := this.Elapsed()
-	state := this.State()
 
-	if state == StateTerminated {
+	if this.State() == StateTerminated {
 		return
 	}
-	if state == StateActive && current-this.lastPayloadTime >= 30000 {
+	if this.State() == StateActive && current-this.lastPayloadTime >= 30000 {
 		this.Close()
 	}
 
-	if state == StateTerminating {
+	if this.State() == StateTerminating {
 		this.output.Write(&CmdOnlySegment{
 			Conv: this.conv,
 			Cmd:  SegmentCommandTerminated,
@@ -394,7 +419,7 @@ func (this *Connection) flush() {
 		return
 	}
 
-	if state == StateReadyToClose && current-this.stateBeginTime > 15000 {
+	if this.State() == StateReadyToClose && current-this.stateBeginTime > 15000 {
 		this.SetState(StateTerminating)
 	}
 
@@ -408,7 +433,7 @@ func (this *Connection) flush() {
 		seg.Cmd = SegmentCommandPing
 		seg.ReceivinNext = this.receivingWorker.nextNumber
 		seg.SendingNext = this.sendingWorker.firstUnacknowledged
-		if state == StateReadyToClose {
+		if this.State() == StateReadyToClose {
 			seg.Opt = SegmentOptionClose
 		}
 		this.output.Write(seg)
@@ -423,8 +448,5 @@ func (this *Connection) flush() {
 }
 
 func (this *Connection) State() State {
-	this.RLock()
-	defer this.RUnlock()
-
-	return this.state
+	return State(atomic.LoadInt32((*int32)(&this.state)))
 }

+ 1 - 1
transport/internet/kcp/receiving.go

@@ -285,7 +285,7 @@ func (this *ReceivingWorker) SetReadDeadline(t time.Time) {
 }
 
 func (this *ReceivingWorker) Flush(current uint32) {
-	this.acklist.Flush(current, this.conn.rx_rto)
+	this.acklist.Flush(current, this.conn.roundTrip.Timeout())
 }
 
 func (this *ReceivingWorker) Write(seg Segment) {

+ 3 - 3
transport/internet/kcp/sending.go

@@ -293,7 +293,7 @@ func (this *SendingWorker) ProcessSegment(current uint32, seg *AckSegment) {
 		timestamp := seg.TimestampList[i]
 		number := seg.NumberList[i]
 		if current-timestamp < 10000 {
-			this.conn.update_ack(int32(current - timestamp))
+			this.conn.roundTrip.Update(current - timestamp)
 		}
 		this.ProcessAck(number)
 		if maxack < number {
@@ -344,7 +344,7 @@ func (this *SendingWorker) PingNecessary() bool {
 }
 
 func (this *SendingWorker) OnPacketLoss(lossRate uint32) {
-	if !effectiveConfig.Congestion || this.conn.rx_srtt == 0 {
+	if !effectiveConfig.Congestion || this.conn.roundTrip.Timeout() == 0 {
 		return
 	}
 
@@ -383,7 +383,7 @@ func (this *SendingWorker) Flush(current uint32) {
 		this.nextNumber++
 	}
 
-	this.window.Flush(current, this.conn.fastresend, this.conn.rx_rto, cwnd)
+	this.window.Flush(current, this.conn.fastresend, this.conn.roundTrip.Timeout(), cwnd)
 }
 
 func (this *SendingWorker) CloseWrite() {