Ver Fonte

sending window

v2ray há 9 anos atrás
pai
commit
4c74e25319

+ 28 - 87
transport/internet/kcp/kcp.go

@@ -54,7 +54,7 @@ type KCP struct {
 
 	snd_queue *SendingQueue
 	rcv_queue *ReceivingQueue
-	snd_buf   []*DataSegment
+	snd_buf   *SendingWindow
 	rcv_buf   *ReceivingWindow
 
 	acklist *ACKList
@@ -82,6 +82,7 @@ func NewKCP(conv uint16, mtu uint32, sendingWindowSize uint32, receivingWindowSi
 	kcp.snd_queue = NewSendingQueue(sendingQueueSize)
 	kcp.rcv_queue = NewReceivingQueue()
 	kcp.acklist = NewACKList(kcp)
+	kcp.snd_buf = NewSendingWindow(kcp, sendingWindowSize)
 	kcp.cwnd = kcp.snd_wnd
 	return kcp
 }
@@ -194,8 +195,8 @@ func (kcp *KCP) update_ack(rtt int32) {
 
 func (kcp *KCP) shrink_buf() {
 	prevUna := kcp.snd_una
-	if len(kcp.snd_buf) > 0 {
-		seg := kcp.snd_buf[0]
+	if kcp.snd_buf.Len() > 0 {
+		seg := kcp.snd_buf.First()
 		kcp.snd_una = seg.Number
 	} else {
 		kcp.snd_una = kcp.snd_nxt
@@ -210,16 +211,7 @@ func (kcp *KCP) parse_ack(sn uint32) {
 		return
 	}
 
-	for k, seg := range kcp.snd_buf {
-		if sn == seg.Number {
-			kcp.snd_buf = append(kcp.snd_buf[:k], kcp.snd_buf[k+1:]...)
-			seg.Release()
-			break
-		}
-		if _itimediff(sn, seg.Number) < 0 {
-			break
-		}
-	}
+	kcp.snd_buf.Remove(sn - kcp.snd_una)
 }
 
 func (kcp *KCP) parse_fastack(sn uint32) {
@@ -227,26 +219,11 @@ func (kcp *KCP) parse_fastack(sn uint32) {
 		return
 	}
 
-	for _, seg := range kcp.snd_buf {
-		if _itimediff(sn, seg.Number) < 0 {
-			break
-		} else if sn != seg.Number {
-			seg.ackSkipped++
-		}
-	}
+	kcp.snd_buf.HandleFastAck(sn)
 }
 
 func (kcp *KCP) HandleReceivingNext(receivingNext uint32) {
-	count := 0
-	for _, seg := range kcp.snd_buf {
-		if _itimediff(receivingNext, seg.Number) > 0 {
-			seg.Release()
-			count++
-		} else {
-			break
-		}
-	}
-	kcp.snd_buf = kcp.snd_buf[count:]
+	kcp.snd_buf.Clear(receivingNext)
 }
 
 func (kcp *KCP) HandleSendingNext(sendingNext uint32) {
@@ -362,7 +339,6 @@ func (kcp *KCP) flush() {
 	}
 
 	current := kcp.current
-	lost := false
 
 	// flush acknowledges
 	if kcp.acklist.Flush() {
@@ -385,47 +361,13 @@ func (kcp *KCP) flush() {
 		seg.timeout = current
 		seg.ackSkipped = 0
 		seg.transmit = 0
-		kcp.snd_buf = append(kcp.snd_buf, seg)
+		kcp.snd_buf.Push(seg)
 		kcp.snd_nxt++
 	}
 
-	// calculate resent
-	resent := uint32(kcp.fastresend)
-	if kcp.fastresend <= 0 {
-		resent = 0xffffffff
-	}
-
 	// flush data segments
-	for _, segment := range kcp.snd_buf {
-		needsend := false
-		if segment.transmit == 0 {
-			needsend = true
-			segment.transmit++
-			segment.timeout = current + kcp.rx_rto
-		} else if _itimediff(current, segment.timeout) >= 0 {
-			needsend = true
-			segment.transmit++
-			segment.timeout = current + kcp.rx_rto
-			lost = true
-		} else if segment.ackSkipped >= resent {
-			needsend = true
-			segment.transmit++
-			segment.ackSkipped = 0
-			segment.timeout = current + kcp.rx_rto
-			lost = true
-		}
-
-		if needsend {
-			segment.Timestamp = current
-			segment.SendingNext = kcp.snd_una
-			segment.Opt = 0
-			if kcp.state == StateReadyToClose {
-				segment.Opt = SegmentOptionClose
-			}
-
-			kcp.output.Write(segment)
-			kcp.sendingUpdated = false
-		}
+	if kcp.snd_buf.Flush() {
+		kcp.sendingUpdated = false
 	}
 
 	if kcp.sendingUpdated || kcp.receivingUpdated || _itimediff(kcp.current, kcp.lastPingTime) >= 5000 {
@@ -447,18 +389,22 @@ func (kcp *KCP) flush() {
 	// flash remain segments
 	kcp.output.Flush()
 
-	if kcp.congestionControl {
-		if lost {
-			kcp.cwnd = 3 * kcp.cwnd / 4
-		} else {
-			kcp.cwnd += kcp.cwnd / 4
-		}
-		if kcp.cwnd < 4 {
-			kcp.cwnd = 4
-		}
-		if kcp.cwnd > kcp.snd_wnd {
-			kcp.cwnd = kcp.snd_wnd
-		}
+}
+
+func (kcp *KCP) HandleLost(lost bool) {
+	if !kcp.congestionControl {
+		return
+	}
+	if lost {
+		kcp.cwnd = 3 * kcp.cwnd / 4
+	} else {
+		kcp.cwnd += kcp.cwnd / 4
+	}
+	if kcp.cwnd < 4 {
+		kcp.cwnd = 4
+	}
+	if kcp.cwnd > kcp.snd_wnd {
+		kcp.cwnd = kcp.snd_wnd
 	}
 }
 
@@ -488,15 +434,10 @@ func (kcp *KCP) NoDelay(interval uint32, resend int, congestionControl bool) int
 
 // WaitSnd gets how many packet is waiting to be sent
 func (kcp *KCP) WaitSnd() uint32 {
-	return uint32(len(kcp.snd_buf)) + kcp.snd_queue.Len()
+	return uint32(kcp.snd_buf.Len()) + kcp.snd_queue.Len()
 }
 
 func (this *KCP) ClearSendQueue() {
 	this.snd_queue.Clear()
-
-	for _, seg := range this.snd_buf {
-		seg.Release()
-	}
-
-	this.snd_buf = nil
+	this.snd_buf.Clear(0xFFFFFFFF)
 }

+ 141 - 0
transport/internet/kcp/sending.go

@@ -1,5 +1,146 @@
 package kcp
 
+type SendingWindow struct {
+	start uint32
+	cap   uint32
+	len   uint32
+	last  uint32
+
+	data []*DataSegment
+	prev []uint32
+	next []uint32
+
+	kcp *KCP
+}
+
+func NewSendingWindow(kcp *KCP, size uint32) *SendingWindow {
+	window := &SendingWindow{
+		start: 0,
+		cap:   size,
+		len:   0,
+		last:  0,
+		data:  make([]*DataSegment, size),
+		prev:  make([]uint32, size),
+		next:  make([]uint32, size),
+	}
+	return window
+}
+
+func (this *SendingWindow) Len() int {
+	return int(this.len)
+}
+
+func (this *SendingWindow) Push(seg *DataSegment) {
+	pos := (this.start + this.len) % this.cap
+	this.data[pos] = seg
+	if this.len > 0 {
+		this.next[this.last] = pos
+		this.prev[pos] = this.last
+	}
+	this.last = pos
+	this.len++
+}
+
+func (this *SendingWindow) First() *DataSegment {
+	return this.data[this.start]
+}
+
+func (this *SendingWindow) Clear(una uint32) {
+	for this.Len() > 0 {
+		if this.data[this.start].Number < una {
+			this.Remove(0)
+		}
+	}
+}
+
+func (this *SendingWindow) Remove(idx uint32) {
+	pos := (this.start + idx) % this.cap
+	seg := this.data[pos]
+	seg.Release()
+	this.data[pos] = nil
+	if pos == this.start {
+		if this.len == 1 {
+			this.len = 0
+			this.start = 0
+			this.last = 0
+		} else {
+			delta := this.next[pos] - this.start
+			this.start = this.next[pos]
+			this.len -= delta
+		}
+	} else if pos == this.last {
+		this.last = this.prev[pos]
+	} else {
+		this.next[this.prev[pos]] = this.next[pos]
+		this.prev[this.next[pos]] = this.prev[pos]
+	}
+}
+
+func (this *SendingWindow) HandleFastAck(number uint32) {
+	for i := this.start; ; i = this.next[i] {
+		seg := this.data[i]
+		if _itimediff(number, seg.Number) < 0 {
+			break
+		}
+		if number != seg.Number {
+			seg.ackSkipped++
+		}
+		if i == this.last {
+			break
+		}
+	}
+}
+
+func (this *SendingWindow) Flush() bool {
+	current := this.kcp.current
+	resent := uint32(this.kcp.fastresend)
+	if this.kcp.fastresend <= 0 {
+		resent = 0xffffffff
+	}
+	lost := false
+	segSent := false
+
+	for i := this.start; ; i = this.next[i] {
+		segment := this.data[i]
+		needsend := false
+		if segment.transmit == 0 {
+			needsend = true
+			segment.transmit++
+			segment.timeout = current + this.kcp.rx_rto
+		} else if _itimediff(current, segment.timeout) >= 0 {
+			needsend = true
+			segment.transmit++
+			segment.timeout = current + this.kcp.rx_rto
+			lost = true
+		} else if segment.ackSkipped >= resent {
+			needsend = true
+			segment.transmit++
+			segment.ackSkipped = 0
+			segment.timeout = current + this.kcp.rx_rto
+			lost = true
+		}
+
+		if needsend {
+			segment.Timestamp = current
+			segment.SendingNext = this.kcp.snd_una
+			segment.Opt = 0
+			if this.kcp.state == StateReadyToClose {
+				segment.Opt = SegmentOptionClose
+			}
+
+			this.kcp.output.Write(segment)
+			segSent = true
+		}
+		if i == this.last {
+			break
+		}
+	}
+
+	this.kcp.HandleLost(lost)
+
+	return segSent
+}
+
 type SendingQueue struct {
 	start uint32
 	cap   uint32

+ 33 - 0
transport/internet/kcp/sending_test.go

@@ -62,3 +62,36 @@ func TestSendingQueueClear(t *testing.T) {
 	queue.Clear()
 	assert.Bool(queue.IsEmpty()).IsTrue()
 }
+
+func TestSendingWindow(t *testing.T) {
+	assert := assert.On(t)
+
+	window := NewSendingWindow(nil, 5)
+	window.Push(&DataSegment{
+		Number: 0,
+	})
+	window.Push(&DataSegment{
+		Number: 1,
+	})
+	window.Push(&DataSegment{
+		Number: 2,
+	})
+	assert.Int(window.Len()).Equals(3)
+
+	window.Remove(1)
+	assert.Int(window.Len()).Equals(3)
+	assert.Uint32(window.First().Number).Equals(0)
+
+	window.Remove(0)
+	assert.Int(window.Len()).Equals(1)
+	assert.Uint32(window.First().Number).Equals(2)
+
+	window.Remove(0)
+	assert.Int(window.Len()).Equals(0)
+
+	window.Push(&DataSegment{
+		Number: 4,
+	})
+	assert.Int(window.Len()).Equals(1)
+	assert.Uint32(window.First().Number).Equals(4)
+}