Explorar o código

refine connection.read

v2ray %!s(int64=9) %!d(string=hai) anos
pai
achega
56ce062154

+ 28 - 4
transport/internet/kcp/connection.go

@@ -99,9 +99,11 @@ func (this *RountTripInfo) SmoothedTime() uint32 {
 type Connection struct {
 	block         Authenticator
 	local, remote net.Addr
+	rd            time.Time
 	wd            time.Time // write deadline
 	writer        io.WriteCloser
 	since         int64
+	dataInputCond *sync.Cond
 
 	conv             uint16
 	state            State
@@ -133,6 +135,7 @@ func NewConnection(conv uint16, writerCloser io.WriteCloser, local *net.UDPAddr,
 	conn.block = block
 	conn.writer = writerCloser
 	conn.since = nowMillisec()
+	conn.dataInputCond = sync.NewCond(new(sync.Mutex))
 
 	authWriter := &AuthenticationWriter{
 		Authenticator: block,
@@ -167,10 +170,28 @@ func (this *Connection) Read(b []byte) (int, error) {
 		return 0, io.EOF
 	}
 
-	if this.State() == StateTerminating || this.State() == StateTerminated {
-		return 0, io.EOF
+	for {
+		if this.State() == StateReadyToClose || this.State() == StateTerminating || this.State() == StateTerminated {
+			return 0, io.EOF
+		}
+		nBytes := this.receivingWorker.Read(b)
+		if nBytes > 0 {
+			return nBytes, nil
+		}
+		var timer *time.Timer
+		if !this.rd.IsZero() && this.rd.Before(time.Now()) {
+			timer = time.AfterFunc(this.rd.Sub(time.Now()), this.dataInputCond.Signal)
+		}
+		this.dataInputCond.L.Lock()
+		this.dataInputCond.Wait()
+		this.dataInputCond.L.Unlock()
+		if timer != nil {
+			timer.Stop()
+		}
+		if !this.rd.IsZero() && this.rd.Before(time.Now()) {
+			return 0, errTimeout
+		}
 	}
-	return this.receivingWorker.Read(b)
 }
 
 // Write implements the Conn Write method.
@@ -226,6 +247,8 @@ func (this *Connection) Close() error {
 		return errClosedConnection
 	}
 
+	this.dataInputCond.Broadcast()
+
 	state := this.State()
 	if state == StateReadyToClose ||
 		state == StateTerminating ||
@@ -276,7 +299,7 @@ func (this *Connection) SetReadDeadline(t time.Time) error {
 	if this == nil || this.State() != StateActive {
 		return errClosedConnection
 	}
-	this.receivingWorker.SetReadDeadline(t)
+	this.rd = t
 	return nil
 }
 
@@ -371,6 +394,7 @@ func (this *Connection) Input(data []byte) int {
 			this.HandleOption(seg.Opt)
 			this.receivingWorker.ProcessSegment(seg)
 			atomic.StoreUint32(&this.lastPayloadTime, current)
+			this.dataInputCond.Signal()
 		case *AckSegment:
 			this.HandleOption(seg.Opt)
 			this.sendingWorker.ProcessSegment(current, seg)

+ 66 - 104
transport/internet/kcp/receiving.go

@@ -1,9 +1,7 @@
 package kcp
 
 import (
-	"io"
 	"sync"
-	"time"
 
 	"github.com/v2ray/v2ray-core/common/alloc"
 )
@@ -58,101 +56,68 @@ func (this *ReceivingWindow) Advance() {
 }
 
 type ReceivingQueue struct {
-	sync.Mutex
-	closed  bool
-	cache   *alloc.Buffer
-	queue   chan *alloc.Buffer
-	timeout time.Time
+	start uint32
+	cap   uint32
+	len   uint32
+	data  []*alloc.Buffer
 }
 
 func NewReceivingQueue(size uint32) *ReceivingQueue {
 	return &ReceivingQueue{
-		queue: make(chan *alloc.Buffer, size),
+		cap:  size,
+		data: make([]*alloc.Buffer, size),
 	}
 }
 
-func (this *ReceivingQueue) Read(buf []byte) (int, error) {
-	if this.closed {
-		return 0, io.EOF
-	}
+func (this *ReceivingQueue) IsEmpty() bool {
+	return this.len == 0
+}
 
-	if this.cache.Len() > 0 {
-		nBytes, err := this.cache.Read(buf)
-		if this.cache.IsEmpty() {
-			this.cache.Release()
-			this.cache = nil
-		}
-		return nBytes, err
-	}
+func (this *ReceivingQueue) IsFull() bool {
+	return this.len == this.cap
+}
 
-	var totalBytes int
+func (this *ReceivingQueue) Read(buf []byte) int {
+	if this.IsEmpty() {
+		return 0
+	}
 
-L:
-	for totalBytes < len(buf) {
-		timeToSleep := time.Millisecond
-		select {
-		case payload, open := <-this.queue:
-			if !open {
-				return totalBytes, io.EOF
-			}
-			nBytes, err := payload.Read(buf)
-			totalBytes += nBytes
-			if err != nil {
-				return totalBytes, err
-			}
-			if !payload.IsEmpty() {
-				this.cache = payload
-			}
-			buf = buf[nBytes:]
-		case <-time.After(timeToSleep):
-			if totalBytes > 0 {
-				break L
+	totalBytes := 0
+	lenBuf := len(buf)
+	for !this.IsEmpty() && totalBytes < lenBuf {
+		payload := this.data[this.start]
+		nBytes, _ := payload.Read(buf)
+		buf = buf[nBytes:]
+		totalBytes += nBytes
+		if payload.IsEmpty() {
+			payload.Release()
+			this.data[this.start] = nil
+			this.start++
+			if this.start == this.cap {
+				this.start = 0
 			}
-			if !this.timeout.IsZero() && this.timeout.Before(time.Now()) {
-				return totalBytes, errTimeout
-			}
-			timeToSleep += 500 * time.Millisecond
-			if timeToSleep > 5*time.Second {
-				timeToSleep = 5 * time.Second
+			this.len--
+			if this.len == 0 {
+				this.start = 0
 			}
 		}
 	}
-
-	return totalBytes, nil
+	return totalBytes
 }
 
-func (this *ReceivingQueue) Put(payload *alloc.Buffer) bool {
-	if this.closed {
-		payload.Release()
-		return false
-	}
-
-	select {
-	case this.queue <- payload:
-		return true
-	default:
-		return false
-	}
-}
-
-func (this *ReceivingQueue) SetReadDeadline(t time.Time) error {
-	this.timeout = t
-	return nil
+func (this *ReceivingQueue) Put(payload *alloc.Buffer) {
+	this.data[(this.start+this.len)%this.cap] = payload
+	this.len++
 }
 
 func (this *ReceivingQueue) Close() {
-	this.Lock()
-	defer this.Unlock()
-
-	if this.closed {
-		return
+	for i := uint32(0); i < this.len; i++ {
+		this.data[(this.start+i)%this.cap].Release()
+		this.data[(this.start+i)%this.cap] = nil
 	}
-	this.closed = true
-	close(this.queue)
 }
 
 type AckList struct {
-	sync.Mutex
 	writer     SegmentWriter
 	timestamps []uint32
 	numbers    []uint32
@@ -169,18 +134,12 @@ func NewAckList(writer SegmentWriter) *AckList {
 }
 
 func (this *AckList) Add(number uint32, timestamp uint32) {
-	this.Lock()
-	defer this.Unlock()
-
 	this.timestamps = append(this.timestamps, timestamp)
 	this.numbers = append(this.numbers, number)
 	this.nextFlush = append(this.nextFlush, 0)
 }
 
 func (this *AckList) Clear(una uint32) {
-	this.Lock()
-	defer this.Unlock()
-
 	count := 0
 	for i := 0; i < len(this.numbers); i++ {
 		if this.numbers[i] >= una {
@@ -201,14 +160,12 @@ func (this *AckList) Clear(una uint32) {
 
 func (this *AckList) Flush(current uint32, rto uint32) {
 	seg := NewAckSegment()
-	this.Lock()
 	for i := 0; i < len(this.numbers) && !seg.IsFull(); i++ {
 		if this.nextFlush[i] <= current {
 			seg.PutNumber(this.numbers[i], this.timestamps[i])
 			this.nextFlush[i] = current + rto/2
 		}
 	}
-	this.Unlock()
 	if seg.Count > 0 {
 		this.writer.Write(seg)
 		seg.Release()
@@ -216,14 +173,14 @@ func (this *AckList) Flush(current uint32, rto uint32) {
 }
 
 type ReceivingWorker struct {
-	conn        *Connection
-	queue       *ReceivingQueue
-	window      *ReceivingWindow
-	windowMutex sync.Mutex
-	acklist     *AckList
-	updated     bool
-	nextNumber  uint32
-	windowSize  uint32
+	sync.Mutex
+	conn       *Connection
+	queue      *ReceivingQueue
+	window     *ReceivingWindow
+	acklist    *AckList
+	updated    bool
+	nextNumber uint32
+	windowSize uint32
 }
 
 func NewReceivingWorker(kcp *Connection) *ReceivingWorker {
@@ -239,35 +196,35 @@ func NewReceivingWorker(kcp *Connection) *ReceivingWorker {
 }
 
 func (this *ReceivingWorker) ProcessSendingNext(number uint32) {
+	this.Lock()
+	defer this.Unlock()
+
 	this.acklist.Clear(number)
 }
 
 func (this *ReceivingWorker) ProcessSegment(seg *DataSegment) {
+	this.Lock()
+	defer this.Unlock()
+
 	number := seg.Number
 	idx := number - this.nextNumber
 	if idx >= this.windowSize {
 		return
 	}
-	this.ProcessSendingNext(seg.SendingNext)
+	this.acklist.Clear(seg.SendingNext)
 	this.acklist.Add(number, seg.Timestamp)
-	this.windowMutex.Lock()
-	defer this.windowMutex.Unlock()
 
 	if !this.window.Set(idx, seg) {
 		seg.Release()
 	}
 
-	for {
+	for !this.queue.IsFull() {
 		seg := this.window.RemoveFirst()
 		if seg == nil {
 			break
 		}
 
-		if !this.queue.Put(seg.Data) {
-			this.window.Set(0, seg)
-			break
-		}
-
+		this.queue.Put(seg.Data)
 		seg.Data = nil
 		seg.Release()
 		this.window.Advance()
@@ -276,15 +233,17 @@ func (this *ReceivingWorker) ProcessSegment(seg *DataSegment) {
 	}
 }
 
-func (this *ReceivingWorker) Read(b []byte) (int, error) {
-	return this.queue.Read(b)
-}
+func (this *ReceivingWorker) Read(b []byte) int {
+	this.Lock()
+	defer this.Unlock()
 
-func (this *ReceivingWorker) SetReadDeadline(t time.Time) {
-	this.queue.SetReadDeadline(t)
+	return this.queue.Read(b)
 }
 
 func (this *ReceivingWorker) Flush(current uint32) {
+	this.Lock()
+	defer this.Unlock()
+
 	this.acklist.Flush(current, this.conn.roundTrip.Timeout())
 }
 
@@ -301,6 +260,9 @@ func (this *ReceivingWorker) Write(seg Segment) {
 }
 
 func (this *ReceivingWorker) CloseRead() {
+	this.Lock()
+	defer this.Unlock()
+
 	this.queue.Close()
 }
 

+ 7 - 27
transport/internet/kcp/receiving_test.go

@@ -1,9 +1,7 @@
 package kcp_test
 
 import (
-	"io"
 	"testing"
-	"time"
 
 	"github.com/v2ray/v2ray-core/common/alloc"
 	"github.com/v2ray/v2ray-core/testing/assert"
@@ -42,35 +40,17 @@ func TestRecivingQueue(t *testing.T) {
 	assert := assert.On(t)
 
 	queue := NewReceivingQueue(2)
-	assert.Bool(queue.Put(alloc.NewSmallBuffer().Clear().AppendString("abcd"))).IsTrue()
-	assert.Bool(queue.Put(alloc.NewSmallBuffer().Clear().AppendString("efg"))).IsTrue()
-	assert.Bool(queue.Put(alloc.NewSmallBuffer().Clear().AppendString("more content"))).IsFalse()
+	queue.Put(alloc.NewSmallBuffer().Clear().AppendString("abcd"))
+	queue.Put(alloc.NewSmallBuffer().Clear().AppendString("efg"))
+	assert.Bool(queue.IsFull()).IsTrue()
 
 	b := make([]byte, 1024)
-	nBytes, err := queue.Read(b)
-	assert.Error(err).IsNil()
+	nBytes := queue.Read(b)
 	assert.Int(nBytes).Equals(7)
 	assert.String(string(b[:nBytes])).Equals("abcdefg")
 
-	assert.Bool(queue.Put(alloc.NewSmallBuffer().Clear().AppendString("1"))).IsTrue()
+	queue.Put(alloc.NewSmallBuffer().Clear().AppendString("1"))
 	queue.Close()
-	nBytes, err = queue.Read(b)
-	assert.Error(err).Equals(io.EOF)
-}
-
-func TestRecivingQueueTimeout(t *testing.T) {
-	assert := assert.On(t)
-
-	queue := NewReceivingQueue(2)
-	assert.Bool(queue.Put(alloc.NewSmallBuffer().Clear().AppendString("abcd"))).IsTrue()
-	queue.SetReadDeadline(time.Now().Add(time.Second))
-
-	b := make([]byte, 1024)
-	nBytes, err := queue.Read(b)
-	assert.Error(err).IsNil()
-	assert.Int(nBytes).Equals(4)
-	assert.String(string(b[:nBytes])).Equals("abcd")
-
-	nBytes, err = queue.Read(b)
-	assert.Error(err).IsNotNil()
+	nBytes = queue.Read(b)
+	assert.Int(nBytes).Equals(0)
 }