浏览代码

high performance sending queue

v2ray 9 年之前
父节点
当前提交
3925b62751

+ 13 - 9
transport/internet/kcp/connection.go

@@ -77,12 +77,12 @@ func NewConnection(conv uint32, writerCloser io.WriteCloser, local *net.UDPAddr,
 	conn.block = block
 	conn.writer = writerCloser
 	conn.since = nowMillisec()
+	conn.writeBufferSize = effectiveConfig.WriteBuffer / effectiveConfig.Mtu
 
 	mtu := effectiveConfig.Mtu - uint32(block.HeaderSize()) - headerSize
-	conn.kcp = NewKCP(conv, mtu, effectiveConfig.GetSendingWindowSize(), effectiveConfig.GetReceivingWindowSize(), conn.output)
+	conn.kcp = NewKCP(conv, mtu, effectiveConfig.GetSendingWindowSize(), effectiveConfig.GetReceivingWindowSize(), conn.writeBufferSize, conn.output)
 	conn.kcp.NoDelay(effectiveConfig.Tti, 2, effectiveConfig.Congestion)
 	conn.kcp.current = conn.Elapsed()
-	conn.writeBufferSize = effectiveConfig.WriteBuffer / effectiveConfig.Mtu
 
 	go conn.updateTask()
 
@@ -133,6 +133,7 @@ func (this *Connection) Write(b []byte) (int, error) {
 		this.state == ConnStateClosed {
 		return 0, io.ErrClosedPipe
 	}
+	totalWritten := 0
 
 	for {
 		this.RLock()
@@ -140,23 +141,26 @@ func (this *Connection) Write(b []byte) (int, error) {
 			this.state == ConnStatePeerClosed ||
 			this.state == ConnStateClosed {
 			this.RUnlock()
-			return 0, io.ErrClosedPipe
+			return totalWritten, io.ErrClosedPipe
 		}
 		this.RUnlock()
 
 		this.kcpAccess.Lock()
-		if this.kcp.WaitSnd() < this.writeBufferSize {
-			nBytes := len(b)
-			this.kcp.Send(b)
+		nBytes := this.kcp.Send(b[totalWritten:])
+		if nBytes > 0 {
 			this.kcp.current = this.Elapsed()
 			this.kcp.flush()
-			this.kcpAccess.Unlock()
-			return nBytes, nil
+			totalWritten += nBytes
+			if totalWritten == len(b) {
+				this.kcpAccess.Unlock()
+				return totalWritten, nil
+			}
 		}
+
 		this.kcpAccess.Unlock()
 
 		if !this.wd.IsZero() && this.wd.Before(time.Now()) {
-			return 0, errTimeout
+			return totalWritten, errTimeout
 		}
 
 		// Sending windows is 1024 for the moment. This amount is not gonna sent in 1 sec.

+ 12 - 38
transport/internet/kcp/kcp.go

@@ -146,7 +146,7 @@ type KCP struct {
 	ts_probe, probe_wait                   uint32
 	dead_link, incr                        uint32
 
-	snd_queue []*Segment
+	snd_queue *SendingQueue
 	rcv_queue []*Segment
 	snd_buf   []*Segment
 	rcv_buf   *ReceivingWindow
@@ -161,7 +161,7 @@ type KCP struct {
 
 // NewKCP create a new kcp control object, 'conv' must equal in two endpoint
 // from the same connection.
-func NewKCP(conv uint32, mtu uint32, sendingWindowSize uint32, receivingWindowSize uint32, output Output) *KCP {
+func NewKCP(conv uint32, mtu uint32, sendingWindowSize uint32, receivingWindowSize uint32, sendingQueueSize uint32, output Output) *KCP {
 	kcp := new(KCP)
 	kcp.conv = conv
 	kcp.snd_wnd = sendingWindowSize
@@ -177,6 +177,7 @@ func NewKCP(conv uint32, mtu uint32, sendingWindowSize uint32, receivingWindowSi
 	kcp.dead_link = IKCP_DEADLINK
 	kcp.output = output
 	kcp.rcv_buf = NewReceivingWindow(receivingWindowSize)
+	kcp.snd_queue = NewSendingQueue(sendingQueueSize)
 	return kcp
 }
 
@@ -232,26 +233,8 @@ func (kcp *KCP) DumpReceivingBuf() {
 
 // Send is user/upper level send, returns below zero for error
 func (kcp *KCP) Send(buffer []byte) int {
-	var count int
-	if len(buffer) == 0 {
-		return -1
-	}
-
-	if len(buffer) < int(kcp.mss) {
-		count = 1
-	} else {
-		count = (len(buffer) + int(kcp.mss) - 1) / int(kcp.mss)
-	}
-
-	if count > 255 {
-		return -2
-	}
-
-	if count == 0 {
-		count = 1
-	}
-
-	for i := 0; i < count; i++ {
+	nBytes := 0
+	for len(buffer) > 0 && !kcp.snd_queue.IsFull() {
 		var size int
 		if len(buffer) > int(kcp.mss) {
 			size = int(kcp.mss)
@@ -260,11 +243,11 @@ func (kcp *KCP) Send(buffer []byte) int {
 		}
 		seg := NewSegment()
 		seg.data.Append(buffer[:size])
-		seg.frg = uint32(count - i - 1)
-		kcp.snd_queue = append(kcp.snd_queue, seg)
+		kcp.snd_queue.Push(seg)
 		buffer = buffer[size:]
+		nBytes += size
 	}
-	return 0
+	return nBytes
 }
 
 // https://tools.ietf.org/html/rfc6298
@@ -572,12 +555,8 @@ func (kcp *KCP) flush() {
 		cwnd = _imin_(kcp.cwnd, cwnd)
 	}
 
-	count = 0
-	for k := range kcp.snd_queue {
-		if _itimediff(kcp.snd_nxt, cwnd) >= 0 {
-			break
-		}
-		newseg := kcp.snd_queue[k]
+	for !kcp.snd_queue.IsEmpty() && _itimediff(kcp.snd_nxt, cwnd) < 0 {
+		newseg := kcp.snd_queue.Pop()
 		newseg.conv = kcp.conv
 		newseg.cmd = IKCP_CMD_PUSH
 		newseg.wnd = seg.wnd
@@ -589,9 +568,7 @@ func (kcp *KCP) flush() {
 		newseg.xmit = 0
 		kcp.snd_buf = append(kcp.snd_buf, newseg)
 		kcp.snd_nxt++
-		count++
 	}
-	kcp.snd_queue = kcp.snd_queue[count:]
 
 	// calculate resent
 	resent := uint32(kcp.fastresend)
@@ -774,14 +751,11 @@ func (kcp *KCP) NoDelay(interval uint32, resend int, congestionControl bool) int
 
 // WaitSnd gets how many packet is waiting to be sent
 func (kcp *KCP) WaitSnd() uint32 {
-	return uint32(len(kcp.snd_buf) + len(kcp.snd_queue))
+	return uint32(len(kcp.snd_buf)) + kcp.snd_queue.Len()
 }
 
 func (this *KCP) ClearSendQueue() {
-	for _, seg := range this.snd_queue {
-		seg.Release()
-	}
-	this.snd_queue = nil
+	this.snd_queue.Clear()
 
 	for _, seg := range this.snd_buf {
 		seg.Release()

+ 60 - 0
transport/internet/kcp/sending.go

@@ -0,0 +1,60 @@
+package kcp
+
+type SendingQueue struct {
+	start uint32
+	cap   uint32
+	len   uint32
+	list  []*Segment
+}
+
+func NewSendingQueue(size uint32) *SendingQueue {
+	return &SendingQueue{
+		start: 0,
+		cap:   size,
+		list:  make([]*Segment, size),
+		len:   0,
+	}
+}
+
+func (this *SendingQueue) IsFull() bool {
+	return this.len == this.cap
+}
+
+func (this *SendingQueue) IsEmpty() bool {
+	return this.len == 0
+}
+
+func (this *SendingQueue) Pop() *Segment {
+	if this.IsEmpty() {
+		return nil
+	}
+	seg := this.list[this.start]
+	this.list[this.start] = nil
+	this.len--
+	this.start++
+	if this.start == this.cap {
+		this.start = 0
+	}
+	return seg
+}
+
+func (this *SendingQueue) Push(seg *Segment) {
+	if this.IsFull() {
+		return
+	}
+	this.list[(this.start+this.len)%this.cap] = seg
+	this.len++
+}
+
+func (this *SendingQueue) Clear() {
+	for i := uint32(0); i < this.len; i++ {
+		this.list[(i+this.start)%this.cap].Release()
+		this.list[(i+this.start)%this.cap] = nil
+	}
+	this.start = 0
+	this.len = 0
+}
+
+func (this *SendingQueue) Len() uint32 {
+	return this.len
+}

+ 64 - 0
transport/internet/kcp/sending_test.go

@@ -0,0 +1,64 @@
+package kcp_test
+
+import (
+	"testing"
+
+	"github.com/v2ray/v2ray-core/testing/assert"
+	. "github.com/v2ray/v2ray-core/transport/internet/kcp"
+)
+
+func TestSendingQueue(t *testing.T) {
+	assert := assert.On(t)
+
+	queue := NewSendingQueue(3)
+
+	seg0 := &Segment{}
+	seg1 := &Segment{}
+	seg2 := &Segment{}
+	seg3 := &Segment{}
+
+	assert.Bool(queue.IsEmpty()).IsTrue()
+	assert.Bool(queue.IsFull()).IsFalse()
+
+	queue.Push(seg0)
+	assert.Bool(queue.IsEmpty()).IsFalse()
+
+	queue.Push(seg1)
+	queue.Push(seg2)
+
+	assert.Bool(queue.IsFull()).IsTrue()
+
+	assert.Pointer(queue.Pop()).Equals(seg0)
+
+	queue.Push(seg3)
+	assert.Bool(queue.IsFull()).IsTrue()
+
+	assert.Pointer(queue.Pop()).Equals(seg1)
+	assert.Pointer(queue.Pop()).Equals(seg2)
+	assert.Pointer(queue.Pop()).Equals(seg3)
+	assert.Int(int(queue.Len())).Equals(0)
+}
+
+func TestSendingQueueClear(t *testing.T) {
+	assert := assert.On(t)
+
+	queue := NewSendingQueue(3)
+
+	seg0 := &Segment{}
+	seg1 := &Segment{}
+	seg2 := &Segment{}
+	seg3 := &Segment{}
+
+	queue.Push(seg0)
+	assert.Bool(queue.IsEmpty()).IsFalse()
+
+	queue.Clear()
+	assert.Bool(queue.IsEmpty()).IsTrue()
+
+	queue.Push(seg1)
+	queue.Push(seg2)
+	queue.Push(seg3)
+
+	queue.Clear()
+	assert.Bool(queue.IsEmpty()).IsTrue()
+}