Explorar o código

implement WriteMultiBuffer

Darien Raymond %!s(int64=8) %!d(string=hai) anos
pai
achega
b3e6994e52

+ 49 - 4
transport/internet/kcp/connection.go

@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"v2ray.com/core/app/log"
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/predicate"
 )
@@ -343,10 +344,15 @@ func (v *Connection) Write(b []byte) (int, error) {
 			return totalWritten, io.ErrClosedPipe
 		}
 
-		nBytes := v.sendingWorker.Push(b[totalWritten:])
-		v.dataUpdater.WakeUp()
-		if nBytes > 0 {
-			totalWritten += nBytes
+		for {
+			rb := v.sendingWorker.Push()
+			if rb == nil {
+				break
+			}
+			common.Must(rb.Reset(func(bb []byte) (int, error) {
+				return copy(bb[:v.mss], b[totalWritten:]), nil
+			}))
+			totalWritten += rb.Len()
 			if totalWritten == len(b) {
 				return totalWritten, nil
 			}
@@ -370,6 +376,45 @@ func (v *Connection) Write(b []byte) (int, error) {
 	}
 }
 
+func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
+	defer mb.Release()
+
+	for {
+		if v == nil || v.State() != StateActive {
+			return io.ErrClosedPipe
+		}
+
+		for {
+			rb := v.sendingWorker.Push()
+			if rb == nil {
+				break
+			}
+			common.Must(rb.Reset(func(bb []byte) (int, error) {
+				return mb.Read(bb[:v.mss])
+			}))
+			if mb.IsEmpty() {
+				return nil
+			}
+		}
+
+		duration := time.Minute
+		if !v.wd.IsZero() {
+			duration = time.Until(v.wd)
+			if duration < 0 {
+				return ErrIOTimeout
+			}
+		}
+
+		select {
+		case <-v.dataOutput:
+		case <-time.After(duration):
+			if !v.wd.IsZero() && v.wd.Before(time.Now()) {
+				return ErrIOTimeout
+			}
+		}
+	}
+}
+
 func (v *Connection) SetState(state State) {
 	current := v.Elapsed()
 	atomic.StoreInt32((*int32)(&v.state), int32(state))

+ 1 - 2
transport/internet/kcp/receiving.go

@@ -214,8 +214,7 @@ func (w *ReceivingWorker) ReadMultiBuffer() buf.MultiBuffer {
 		}
 		w.window.Advance()
 		w.nextNumber++
-		mb.Append(seg.Data)
-		seg.Data = nil
+		mb.Append(seg.Detach())
 		seg.Release()
 	}
 

+ 18 - 14
transport/internet/kcp/segment.go

@@ -1,7 +1,6 @@
 package kcp
 
 import (
-	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/serial"
 )
@@ -44,8 +43,8 @@ type DataSegment struct {
 	Timestamp   uint32
 	Number      uint32
 	SendingNext uint32
-	Data        *buf.Buffer
 
+	payload  *buf.Buffer
 	timeout  uint32
 	transmit uint32
 }
@@ -62,13 +61,17 @@ func (v *DataSegment) Command() Command {
 	return CommandData
 }
 
-func (v *DataSegment) SetData(data []byte) {
-	if v.Data == nil {
-		v.Data = buf.New()
+func (v *DataSegment) Detach() *buf.Buffer {
+	r := v.payload
+	v.payload = nil
+	return r
+}
+
+func (v *DataSegment) Data() *buf.Buffer {
+	if v.payload == nil {
+		v.payload = buf.New()
 	}
-	common.Must(v.Data.Reset(func(b []byte) (int, error) {
-		return copy(b, data), nil
-	}))
+	return v.payload
 }
 
 func (v *DataSegment) Bytes() buf.Supplier {
@@ -78,19 +81,19 @@ func (v *DataSegment) Bytes() buf.Supplier {
 		b = serial.Uint32ToBytes(v.Timestamp, b)
 		b = serial.Uint32ToBytes(v.Number, b)
 		b = serial.Uint32ToBytes(v.SendingNext, b)
-		b = serial.Uint16ToBytes(uint16(v.Data.Len()), b)
-		b = append(b, v.Data.Bytes()...)
+		b = serial.Uint16ToBytes(uint16(v.payload.Len()), b)
+		b = append(b, v.payload.Bytes()...)
 		return len(b), nil
 	}
 }
 
 func (v *DataSegment) ByteSize() int {
-	return 2 + 1 + 1 + 4 + 4 + 4 + 2 + v.Data.Len()
+	return 2 + 1 + 1 + 4 + 4 + 4 + 2 + v.payload.Len()
 }
 
 func (v *DataSegment) Release() {
-	v.Data.Release()
-	v.Data = nil
+	v.payload.Release()
+	v.payload = nil
 }
 
 type AckSegment struct {
@@ -233,7 +236,8 @@ func ReadSegment(buf []byte) (Segment, []byte) {
 		if len(buf) < dataLen {
 			return nil, nil
 		}
-		seg.SetData(buf[:dataLen])
+		seg.Data().Clear()
+		seg.Data().Append(buf[:dataLen])
 		buf = buf[dataLen:]
 
 		return seg, buf

+ 4 - 9
transport/internet/kcp/segment_test.go

@@ -3,7 +3,6 @@ package kcp_test
 import (
 	"testing"
 
-	"v2ray.com/core/common/buf"
 	. "v2ray.com/core/transport/internet/kcp"
 	. "v2ray.com/ext/assert"
 )
@@ -19,15 +18,13 @@ func TestBadSegment(t *testing.T) {
 func TestDataSegment(t *testing.T) {
 	assert := With(t)
 
-	b := buf.NewLocal(512)
-	b.Append([]byte{'a', 'b', 'c', 'd'})
 	seg := &DataSegment{
 		Conv:        1,
 		Timestamp:   3,
 		Number:      4,
 		SendingNext: 5,
-		Data:        b,
 	}
+	seg.Data().Append([]byte{'a', 'b', 'c', 'd'})
 
 	nBytes := seg.ByteSize()
 	bytes := make([]byte, nBytes)
@@ -41,21 +38,19 @@ func TestDataSegment(t *testing.T) {
 	assert(seg2.Timestamp, Equals, seg.Timestamp)
 	assert(seg2.SendingNext, Equals, seg.SendingNext)
 	assert(seg2.Number, Equals, seg.Number)
-	assert(seg2.Data.Bytes(), Equals, seg.Data.Bytes())
+	assert(seg2.Data().Bytes(), Equals, seg.Data().Bytes())
 }
 
 func Test1ByteDataSegment(t *testing.T) {
 	assert := With(t)
 
-	b := buf.NewLocal(512)
-	b.AppendBytes('a')
 	seg := &DataSegment{
 		Conv:        1,
 		Timestamp:   3,
 		Number:      4,
 		SendingNext: 5,
-		Data:        b,
 	}
+	seg.Data().AppendBytes('a')
 
 	nBytes := seg.ByteSize()
 	bytes := make([]byte, nBytes)
@@ -69,7 +64,7 @@ func Test1ByteDataSegment(t *testing.T) {
 	assert(seg2.Timestamp, Equals, seg.Timestamp)
 	assert(seg2.SendingNext, Equals, seg.SendingNext)
 	assert(seg2.Number, Equals, seg.Number)
-	assert(seg2.Data.Bytes(), Equals, seg.Data.Bytes())
+	assert(seg2.Data().Bytes(), Equals, seg.Data().Bytes())
 }
 
 func TestACKSegment(t *testing.T) {

+ 9 - 16
transport/internet/kcp/sending.go

@@ -2,6 +2,8 @@ package kcp
 
 import (
 	"sync"
+
+	"v2ray.com/core/common/buf"
 )
 
 type SendingWindow struct {
@@ -62,9 +64,8 @@ func (sw *SendingWindow) IsFull() bool {
 	return sw.len == sw.cap
 }
 
-func (sw *SendingWindow) Push(number uint32, data []byte) {
+func (sw *SendingWindow) Push(number uint32) *buf.Buffer {
 	pos := (sw.start + sw.len) % sw.cap
-	sw.data[pos].SetData(data)
 	sw.data[pos].Number = number
 	sw.data[pos].timeout = 0
 	sw.data[pos].transmit = 0
@@ -75,6 +76,7 @@ func (sw *SendingWindow) Push(number uint32, data []byte) {
 	}
 	sw.last = pos
 	sw.len++
+	return sw.data[pos].Data()
 }
 
 func (sw *SendingWindow) FirstNumber() uint32 {
@@ -224,7 +226,6 @@ func (v *SendingWorker) ProcessReceivingNextWithoutLock(nextNumber uint32) {
 	v.FindFirstUnacknowledged()
 }
 
-// Private: Visible for testing.
 func (v *SendingWorker) FindFirstUnacknowledged() {
 	first := v.firstUnacknowledged
 	if !v.window.IsEmpty() {
@@ -283,24 +284,16 @@ func (v *SendingWorker) ProcessSegment(current uint32, seg *AckSegment, rto uint
 	}
 }
 
-func (v *SendingWorker) Push(b []byte) int {
-	nBytes := 0
+func (v *SendingWorker) Push() *buf.Buffer {
 	v.Lock()
 	defer v.Unlock()
 
-	for len(b) > 0 && !v.window.IsFull() {
-		var size int
-		if len(b) > int(v.conn.mss) {
-			size = int(v.conn.mss)
-		} else {
-			size = len(b)
-		}
-		v.window.Push(v.nextNumber, b[:size])
+	if !v.window.IsFull() {
+		b := v.window.Push(v.nextNumber)
 		v.nextNumber++
-		b = b[size:]
-		nBytes += size
+		return b
 	}
-	return nBytes
+	return nil
 }
 
 // Private: Visible for testing.

+ 5 - 5
transport/internet/kcp/sending_test.go

@@ -11,9 +11,9 @@ func TestSendingWindow(t *testing.T) {
 	assert := With(t)
 
 	window := NewSendingWindow(5, nil, nil)
-	window.Push(0, []byte{})
-	window.Push(1, []byte{})
-	window.Push(2, []byte{})
+	window.Push(0)
+	window.Push(1)
+	window.Push(2)
 	assert(window.Len(), Equals, 3)
 
 	window.Remove(1)
@@ -27,11 +27,11 @@ func TestSendingWindow(t *testing.T) {
 	window.Remove(0)
 	assert(window.Len(), Equals, 0)
 
-	window.Push(4, []byte{})
+	window.Push(4)
 	assert(window.Len(), Equals, 1)
 	assert(window.FirstNumber(), Equals, uint32(4))
 
-	window.Push(5, []byte{})
+	window.Push(5)
 	assert(window.Len(), Equals, 2)
 
 	window.Remove(1)