瀏覽代碼

simplify receiving worker

Darien Raymond 9 年之前
父節點
當前提交
2839ce7a88
共有 2 個文件被更改,包括 32 次插入101 次删除
  1. 32 81
      transport/internet/kcp/receiving.go
  2. 0 20
      transport/internet/kcp/receiving_test.go

+ 32 - 81
transport/internet/kcp/receiving.go

@@ -55,68 +55,6 @@ func (this *ReceivingWindow) Advance() {
 	}
 }
 
-type ReceivingQueue struct {
-	start uint32
-	cap   uint32
-	len   uint32
-	data  []*alloc.Buffer
-}
-
-func NewReceivingQueue(size uint32) *ReceivingQueue {
-	return &ReceivingQueue{
-		cap:  size,
-		data: make([]*alloc.Buffer, size),
-	}
-}
-
-func (this *ReceivingQueue) IsEmpty() bool {
-	return this.len == 0
-}
-
-func (this *ReceivingQueue) IsFull() bool {
-	return this.len == this.cap
-}
-
-func (this *ReceivingQueue) Read(buf []byte) int {
-	if this.IsEmpty() {
-		return 0
-	}
-
-	totalBytes := 0
-	lenBuf := len(buf)
-	for !this.IsEmpty() && totalBytes < lenBuf {
-		payload := this.data[this.start]
-		nBytes, _ := payload.Read(buf)
-		buf = buf[nBytes:]
-		totalBytes += nBytes
-		if payload.IsEmpty() {
-			payload.Release()
-			this.data[this.start] = nil
-			this.start++
-			if this.start == this.cap {
-				this.start = 0
-			}
-			this.len--
-			if this.len == 0 {
-				this.start = 0
-			}
-		}
-	}
-	return totalBytes
-}
-
-func (this *ReceivingQueue) Put(payload *alloc.Buffer) {
-	this.data[(this.start+this.len)%this.cap] = payload
-	this.len++
-}
-
-func (this *ReceivingQueue) Close() {
-	for i := uint32(0); i < this.len; i++ {
-		this.data[(this.start+i)%this.cap].Release()
-		this.data[(this.start+i)%this.cap] = nil
-	}
-}
-
 type AckList struct {
 	writer     SegmentWriter
 	timestamps []uint32
@@ -176,7 +114,7 @@ func (this *AckList) Flush(current uint32, rto uint32) {
 type ReceivingWorker struct {
 	sync.RWMutex
 	conn       *Connection
-	queue      *ReceivingQueue
+	leftOver   *alloc.Buffer
 	window     *ReceivingWindow
 	acklist    *AckList
 	updated    bool
@@ -185,10 +123,9 @@ type ReceivingWorker struct {
 }
 
 func NewReceivingWorker(kcp *Connection) *ReceivingWorker {
-	windowSize := effectiveConfig.GetReceivingWindowSize()
+	windowSize := effectiveConfig.GetReceivingQueueSize()
 	worker := &ReceivingWorker{
 		conn:       kcp,
-		queue:      NewReceivingQueue(effectiveConfig.GetReceivingQueueSize()),
 		window:     NewReceivingWindow(windowSize),
 		windowSize: windowSize,
 	}
@@ -218,27 +155,45 @@ func (this *ReceivingWorker) ProcessSegment(seg *DataSegment) {
 	if !this.window.Set(idx, seg) {
 		seg.Release()
 	}
+}
+
+func (this *ReceivingWorker) Read(b []byte) int {
+	this.Lock()
+	defer this.Unlock()
+
+	total := 0
+	if this.leftOver != nil {
+		nBytes := copy(b, this.leftOver.Value)
+		if nBytes < this.leftOver.Len() {
+			this.leftOver.SliceFrom(nBytes)
+			return nBytes
+		}
+		this.leftOver.Release()
+		this.leftOver = nil
+		total += nBytes
+	}
 
-	for !this.queue.IsFull() {
+	for total < len(b) {
 		seg := this.window.RemoveFirst()
 		if seg == nil {
 			break
 		}
-
-		this.queue.Put(seg.Data)
-		seg.Data = nil
-		seg.Release()
 		this.window.Advance()
 		this.nextNumber++
 		this.updated = true
-	}
-}
 
-func (this *ReceivingWorker) Read(b []byte) int {
-	this.Lock()
-	defer this.Unlock()
-
-	return this.queue.Read(b)
+		nBytes := copy(b[total:], seg.Data.Value)
+		total += nBytes
+		if nBytes < seg.Data.Len() {
+			seg.Data.SliceFrom(nBytes)
+			this.leftOver = seg.Data
+			seg.Data = nil
+			seg.Release()
+			break
+		}
+		seg.Release()
+	}
+	return total
 }
 
 func (this *ReceivingWorker) Flush(current uint32) {
@@ -261,10 +216,6 @@ func (this *ReceivingWorker) Write(seg Segment) {
 }
 
 func (this *ReceivingWorker) CloseRead() {
-	this.Lock()
-	defer this.Unlock()
-
-	this.queue.Close()
 }
 
 func (this *ReceivingWorker) PingNecessary() bool {

+ 0 - 20
transport/internet/kcp/receiving_test.go

@@ -3,7 +3,6 @@ package kcp_test
 import (
 	"testing"
 
-	"v2ray.com/core/common/alloc"
 	"v2ray.com/core/testing/assert"
 	. "v2ray.com/core/transport/internet/kcp"
 )
@@ -35,22 +34,3 @@ func TestRecivingWindow(t *testing.T) {
 	assert.Pointer(window.Remove(1)).Equals(seg2)
 	assert.Pointer(window.Remove(2)).Equals(seg3)
 }
-
-func TestRecivingQueue(t *testing.T) {
-	assert := assert.On(t)
-
-	queue := NewReceivingQueue(2)
-	queue.Put(alloc.NewLocalBuffer(512).Clear().AppendString("abcd"))
-	queue.Put(alloc.NewLocalBuffer(512).Clear().AppendString("efg"))
-	assert.Bool(queue.IsFull()).IsTrue()
-
-	b := make([]byte, 1024)
-	nBytes := queue.Read(b)
-	assert.Int(nBytes).Equals(7)
-	assert.String(string(b[:nBytes])).Equals("abcdefg")
-
-	queue.Put(alloc.NewLocalBuffer(512).Clear().AppendString("1"))
-	queue.Close()
-	nBytes = queue.Read(b)
-	assert.Int(nBytes).Equals(0)
-}