Bläddra i källkod

typed segments

v2ray 9 år sedan
förälder
incheckning
6915095a0a

+ 4 - 0
common/serial/numbers.go

@@ -16,6 +16,10 @@ func Uint32ToBytes(value uint32, b []byte) []byte {
 	return append(b, byte(value>>24), byte(value>>16), byte(value>>8), byte(value))
 }
 
+func Uint32ToString(value uint32) string {
+	return strconv.FormatUint(uint64(value), 10)
+}
+
 func IntToBytes(value int, b []byte) []byte {
 	return append(b, byte(value>>24), byte(value>>16), byte(value>>8), byte(value))
 }

+ 19 - 1
testing/assert/pointer.go

@@ -1,6 +1,8 @@
 package assert
 
 import (
+	"reflect"
+
 	"github.com/v2ray/v2ray-core/common/serial"
 )
 
@@ -26,7 +28,15 @@ func (subject *PointerSubject) Equals(expectation interface{}) {
 }
 
 func (subject *PointerSubject) IsNil() {
-	if subject.value != nil {
+	if subject.value == nil {
+		return
+	}
+
+	valueType := reflect.TypeOf(subject.value)
+	nilType := reflect.Zero(valueType)
+	realValue := reflect.ValueOf(subject.value)
+
+	if nilType != realValue {
 		subject.Fail("is", "nil")
 	}
 }
@@ -35,4 +45,12 @@ func (subject *PointerSubject) IsNotNil() {
 	if subject.value == nil {
 		subject.Fail("is not", "nil")
 	}
+
+	valueType := reflect.TypeOf(subject.value)
+	nilType := reflect.Zero(valueType)
+	realValue := reflect.ValueOf(subject.value)
+
+	if nilType == realValue {
+		subject.Fail("is not", "nil")
+	}
 }

+ 50 - 0
testing/assert/uint32.go

@@ -0,0 +1,50 @@
+package assert
+
+import (
+	"github.com/v2ray/v2ray-core/common/serial"
+)
+
+func (this *Assert) Uint32(value uint32) *Uint32Subject {
+	return &Uint32Subject{
+		Subject: Subject{
+			a:    this,
+			disp: serial.Uint32ToString(value),
+		},
+		value: value,
+	}
+}
+
+type Uint32Subject struct {
+	Subject
+	value uint32
+}
+
+func (subject *Uint32Subject) Equals(expectation uint32) {
+	if subject.value != expectation {
+		subject.Fail("is equal to", serial.Uint32ToString(expectation))
+	}
+}
+
+func (subject *Uint32Subject) GreaterThan(expectation uint32) {
+	if subject.value <= expectation {
+		subject.Fail("is greater than", serial.Uint32ToString(expectation))
+	}
+}
+
+func (subject *Uint32Subject) LessThan(expectation uint32) {
+	if subject.value >= expectation {
+		subject.Fail("is less than", serial.Uint32ToString(expectation))
+	}
+}
+
+func (subject *Uint32Subject) IsPositive() {
+	if subject.value <= 0 {
+		subject.Fail("is", "positive")
+	}
+}
+
+func (subject *Uint32Subject) IsNegative() {
+	if subject.value >= 0 {
+		subject.Fail("is not", "negative")
+	}
+}

+ 171 - 0
transport/internet/kcp/segment.go

@@ -0,0 +1,171 @@
+package kcp
+
+import (
+	"github.com/v2ray/v2ray-core/common/alloc"
+	"github.com/v2ray/v2ray-core/common/serial"
+)
+
+type SegmentCommand byte
+
+const (
+	SegmentCommandACK        SegmentCommand = 0
+	SegmentCommandData       SegmentCommand = 1
+	SegmentCommandTerminated SegmentCommand = 2
+)
+
+type SegmentOption byte
+
+const (
+	SegmentOptionClose SegmentOption = 1
+)
+
+type ISegment interface {
+	ByteSize() int
+	Bytes([]byte) []byte
+}
+
+type DataSegment struct {
+	Conv            uint16
+	Opt             SegmentOption
+	ReceivingWindow uint32
+	Timestamp       uint32
+	Number          uint32
+	Unacknowledged  uint32
+	Data            *alloc.Buffer
+
+	timeout    uint32
+	ackSkipped uint32
+	transmit   uint32
+}
+
+func (this *DataSegment) Bytes(b []byte) []byte {
+	b = serial.Uint16ToBytes(this.Conv, b)
+	b = append(b, byte(SegmentCommandData), byte(this.Opt))
+	b = serial.Uint32ToBytes(this.ReceivingWindow, b)
+	b = serial.Uint32ToBytes(this.Timestamp, b)
+	b = serial.Uint32ToBytes(this.Number, b)
+	b = serial.Uint32ToBytes(this.Unacknowledged, b)
+	b = serial.Uint16ToBytes(uint16(this.Data.Len()), b)
+	b = append(b, this.Data.Value...)
+	return b
+}
+
+func (this *DataSegment) ByteSize() int {
+	return 2 + 1 + 1 + 4 + 4 + 4 + 4 + 2 + this.Data.Len()
+}
+
+type ACKSegment struct {
+	Conv            uint16
+	Opt             SegmentOption
+	ReceivingWindow uint32
+	Unacknowledged  uint32
+	Count           byte
+	NumberList      []uint32
+	TimestampList   []uint32
+}
+
+func (this *ACKSegment) ByteSize() int {
+	return 2 + 1 + 1 + 4 + 4 + 1 + len(this.NumberList)*4 + len(this.TimestampList)*4
+}
+
+func (this *ACKSegment) Bytes(b []byte) []byte {
+	b = serial.Uint16ToBytes(this.Conv, b)
+	b = append(b, byte(SegmentCommandACK), byte(this.Opt))
+	b = serial.Uint32ToBytes(this.ReceivingWindow, b)
+	b = serial.Uint32ToBytes(this.Unacknowledged, b)
+	b = append(b, this.Count)
+	for i := byte(0); i < this.Count; i++ {
+		b = serial.Uint32ToBytes(this.NumberList[i], b)
+		b = serial.Uint32ToBytes(this.TimestampList[i], b)
+	}
+	return b
+}
+
+type TerminationSegment struct {
+	Conv uint16
+	Opt  SegmentOption
+}
+
+func (this *TerminationSegment) ByteSize() int {
+	return 2 + 1 + 1
+}
+
+func (this *TerminationSegment) Bytes(b []byte) []byte {
+	b = serial.Uint16ToBytes(this.Conv, b)
+	b = append(b, byte(SegmentCommandTerminated), byte(this.Opt))
+	return b
+}
+
+func ReadSegment(buf []byte) (ISegment, []byte) {
+	if len(buf) <= 12 {
+		return nil, nil
+	}
+
+	conv := serial.BytesToUint16(buf)
+	buf = buf[2:]
+
+	cmd := SegmentCommand(buf[0])
+	opt := SegmentOption(buf[1])
+	buf = buf[2:]
+
+	if cmd == SegmentCommandData {
+		seg := &DataSegment{
+			Conv: conv,
+			Opt:  opt,
+		}
+		seg.ReceivingWindow = serial.BytesToUint32(buf)
+		buf = buf[4:]
+
+		seg.Timestamp = serial.BytesToUint32(buf)
+		buf = buf[4:]
+
+		seg.Number = serial.BytesToUint32(buf)
+		buf = buf[4:]
+
+		seg.Unacknowledged = serial.BytesToUint32(buf)
+		buf = buf[4:]
+
+		len := serial.BytesToUint16(buf)
+		buf = buf[2:]
+
+		seg.Data = alloc.NewSmallBuffer().Clear().Append(buf[:len])
+		buf = buf[len:]
+
+		return seg, buf
+	}
+
+	if cmd == SegmentCommandACK {
+		seg := &ACKSegment{
+			Conv: conv,
+			Opt:  opt,
+		}
+		seg.ReceivingWindow = serial.BytesToUint32(buf)
+		buf = buf[4:]
+
+		seg.Unacknowledged = serial.BytesToUint32(buf)
+		buf = buf[4:]
+
+		seg.Count = buf[0]
+		buf = buf[1:]
+
+		seg.NumberList = make([]uint32, 0, seg.Count)
+		seg.TimestampList = make([]uint32, 0, seg.Count)
+
+		for i := 0; i < int(seg.Count); i++ {
+			seg.NumberList = append(seg.NumberList, serial.BytesToUint32(buf))
+			seg.TimestampList = append(seg.TimestampList, serial.BytesToUint32(buf[4:]))
+			buf = buf[8:]
+		}
+
+		return seg, buf
+	}
+
+	if cmd == SegmentCommandTerminated {
+		return &TerminationSegment{
+			Conv: conv,
+			Opt:  opt,
+		}, buf
+	}
+
+	return nil, nil
+}

+ 73 - 0
transport/internet/kcp/segment_test.go

@@ -0,0 +1,73 @@
+package kcp_test
+
+import (
+	"testing"
+
+	"github.com/v2ray/v2ray-core/common/alloc"
+	"github.com/v2ray/v2ray-core/testing/assert"
+	. "github.com/v2ray/v2ray-core/transport/internet/kcp"
+)
+
+func TestBadSegment(t *testing.T) {
+	assert := assert.On(t)
+
+	seg, buf := ReadSegment(nil)
+	assert.Pointer(seg).IsNil()
+	assert.Int(len(buf)).Equals(0)
+}
+
+func TestDataSegment(t *testing.T) {
+	assert := assert.On(t)
+
+	seg := &DataSegment{
+		Conv:            1,
+		ReceivingWindow: 2,
+		Timestamp:       3,
+		Number:          4,
+		Unacknowledged:  5,
+		Data:            alloc.NewSmallBuffer().Clear().Append([]byte{'a', 'b', 'c', 'd'}),
+	}
+
+	nBytes := seg.ByteSize()
+	bytes := seg.Bytes(nil)
+
+	assert.Int(len(bytes)).Equals(nBytes)
+
+	iseg, _ := ReadSegment(bytes)
+	seg2 := iseg.(*DataSegment)
+	assert.Uint16(seg2.Conv).Equals(seg.Conv)
+	assert.Uint32(seg2.ReceivingWindow).Equals(seg.ReceivingWindow)
+	assert.Uint32(seg2.Timestamp).Equals(seg.Timestamp)
+	assert.Uint32(seg2.Unacknowledged).Equals(seg.Unacknowledged)
+	assert.Uint32(seg2.Number).Equals(seg.Number)
+	assert.Bytes(seg2.Data.Value).Equals(seg.Data.Value)
+}
+
+func TestACKSegment(t *testing.T) {
+	assert := assert.On(t)
+
+	seg := &ACKSegment{
+		Conv:            1,
+		ReceivingWindow: 2,
+		Unacknowledged:  3,
+		Count:           5,
+		NumberList:      []uint32{1, 3, 5, 7, 9},
+		TimestampList:   []uint32{2, 4, 6, 8, 10},
+	}
+
+	nBytes := seg.ByteSize()
+	bytes := seg.Bytes(nil)
+
+	assert.Int(len(bytes)).Equals(nBytes)
+
+	iseg, _ := ReadSegment(bytes)
+	seg2 := iseg.(*ACKSegment)
+	assert.Uint16(seg2.Conv).Equals(seg.Conv)
+	assert.Uint32(seg2.ReceivingWindow).Equals(seg.ReceivingWindow)
+	assert.Uint32(seg2.Unacknowledged).Equals(seg.Unacknowledged)
+	assert.Byte(seg2.Count).Equals(seg.Count)
+	for i := byte(0); i < seg2.Count; i++ {
+		assert.Uint32(seg2.TimestampList[i]).Equals(seg.TimestampList[i])
+		assert.Uint32(seg2.NumberList[i]).Equals(seg.NumberList[i])
+	}
+}