v2ray 9 роки тому
батько
коміт
a7f61af79b
3 змінених файлів з 198 додано та 45 видалено
  1. 109 0
      proxy/vmess/io/io_test.go
  2. 89 22
      proxy/vmess/io/reader.go
  3. 0 23
      proxy/vmess/io/writer_test.go

+ 109 - 0
proxy/vmess/io/io_test.go

@@ -0,0 +1,109 @@
+package io_test
+
+import (
+	"bytes"
+	"crypto/rand"
+	"io"
+	"testing"
+
+	"github.com/v2ray/v2ray-core/common/alloc"
+	v2io "github.com/v2ray/v2ray-core/common/io"
+	. "github.com/v2ray/v2ray-core/proxy/vmess/io"
+	v2testing "github.com/v2ray/v2ray-core/testing"
+	"github.com/v2ray/v2ray-core/testing/assert"
+)
+
+func TestAuthenticate(t *testing.T) {
+	v2testing.Current(t)
+
+	buffer := alloc.NewBuffer().Clear()
+	buffer.AppendBytes(1, 2, 3, 4)
+	Authenticate(buffer)
+	assert.Bytes(buffer.Value).Equals([]byte{0, 8, 87, 52, 168, 125, 1, 2, 3, 4})
+
+	b2, err := NewAuthChunkReader(buffer).Read()
+	assert.Error(err).IsNil()
+	assert.Bytes(b2.Value).Equals([]byte{1, 2, 3, 4})
+}
+
+func TestSingleIO(t *testing.T) {
+	v2testing.Current(t)
+
+	content := bytes.NewBuffer(make([]byte, 0, 1024*1024))
+
+	writer := NewAuthChunkWriter(v2io.NewAdaptiveWriter(content))
+	writer.Write(alloc.NewBuffer().Clear().AppendString("abcd"))
+	writer.Release()
+
+	reader := NewAuthChunkReader(content)
+	buffer, err := reader.Read()
+	assert.Error(err).IsNil()
+	assert.Bytes(buffer.Value).Equals([]byte("abcd"))
+}
+
+func TestLargeIO(t *testing.T) {
+	v2testing.Current(t)
+
+	content := make([]byte, 1024*1024)
+	rand.Read(content)
+
+	chunckContent := bytes.NewBuffer(make([]byte, 0, len(content)*2))
+	writer := NewAuthChunkWriter(v2io.NewAdaptiveWriter(chunckContent))
+	writeSize := 0
+	for {
+		chunkSize := 7 * 1024
+		if chunkSize+writeSize > len(content) {
+			chunkSize = len(content) - writeSize
+		}
+		writer.Write(alloc.NewBuffer().Clear().Append(content[writeSize : writeSize+chunkSize]))
+		writeSize += chunkSize
+		if writeSize == len(content) {
+			break
+		}
+
+		chunkSize = 8 * 1024
+		if chunkSize+writeSize > len(content) {
+			chunkSize = len(content) - writeSize
+		}
+		writer.Write(alloc.NewLargeBuffer().Clear().Append(content[writeSize : writeSize+chunkSize]))
+		writeSize += chunkSize
+		if writeSize == len(content) {
+			break
+		}
+
+		chunkSize = 63 * 1024
+		if chunkSize+writeSize > len(content) {
+			chunkSize = len(content) - writeSize
+		}
+		writer.Write(alloc.NewLargeBuffer().Clear().Append(content[writeSize : writeSize+chunkSize]))
+		writeSize += chunkSize
+		if writeSize == len(content) {
+			break
+		}
+
+		chunkSize = 64*1024 - 16
+		if chunkSize+writeSize > len(content) {
+			chunkSize = len(content) - writeSize
+		}
+		writer.Write(alloc.NewLargeBuffer().Clear().Append(content[writeSize : writeSize+chunkSize]))
+		writeSize += chunkSize
+		if writeSize == len(content) {
+			break
+		}
+	}
+	writer.Release()
+
+	actualContent := make([]byte, 0, len(content))
+	reader := NewAuthChunkReader(chunckContent)
+	for {
+		buffer, err := reader.Read()
+		if err == io.EOF {
+			break
+		}
+		assert.Error(err).IsNil()
+		actualContent = append(actualContent, buffer.Value...)
+	}
+
+	assert.Int(len(actualContent)).Equals(len(content))
+	assert.Bytes(actualContent).Equals(content)
+}

+ 89 - 22
proxy/vmess/io/reader.go

@@ -1,6 +1,7 @@
 package io
 
 import (
+	"hash"
 	"hash/fnv"
 	"io"
 
@@ -9,49 +10,115 @@ import (
 	"github.com/v2ray/v2ray-core/transport"
 )
 
+// @Private
+func AllocBuffer(size int) *alloc.Buffer {
+	if size < 8*1024-16 {
+		return alloc.NewBuffer()
+	}
+	return alloc.NewLargeBuffer()
+}
+
+// @Private
+type Validator struct {
+	actualAuth   hash.Hash32
+	expectedAuth uint32
+}
+
+func NewValidator(expectedAuth uint32) *Validator {
+	return &Validator{
+		actualAuth:   fnv.New32a(),
+		expectedAuth: expectedAuth,
+	}
+}
+
+func (this *Validator) Consume(b []byte) {
+	this.actualAuth.Write(b)
+}
+
+func (this *Validator) Validate() bool {
+	return this.actualAuth.Sum32() == this.expectedAuth
+}
+
 type AuthChunkReader struct {
-	reader io.Reader
+	reader      io.Reader
+	last        *alloc.Buffer
+	chunkLength int
+	validator   *Validator
 }
 
 func NewAuthChunkReader(reader io.Reader) *AuthChunkReader {
 	return &AuthChunkReader{
-		reader: reader,
+		reader:      reader,
+		chunkLength: -1,
 	}
 }
 
 func (this *AuthChunkReader) Read() (*alloc.Buffer, error) {
-	buffer := alloc.NewBuffer()
-	if _, err := io.ReadFull(this.reader, buffer.Value[:2]); err != nil {
+	var buffer *alloc.Buffer
+	if this.last != nil {
+		buffer = this.last
+		this.last = nil
+	} else {
+		buffer = AllocBuffer(this.chunkLength).Clear()
+	}
+
+	_, err := buffer.FillFrom(this.reader)
+	if err != nil {
 		buffer.Release()
 		return nil, err
 	}
 
-	length := serial.BytesLiteral(buffer.Value[:2]).Uint16Value()
-	if length <= 4 { // Length of authentication bytes.
-		return nil, io.EOF
+	if this.chunkLength == -1 {
+		for buffer.Len() < 6 {
+			_, err := buffer.FillFrom(this.reader)
+			if err != nil {
+				buffer.Release()
+				return nil, err
+			}
+		}
+		length := serial.BytesLiteral(buffer.Value[:2]).Uint16Value()
+		this.chunkLength = int(length) - 4
+		this.validator = NewValidator(serial.BytesLiteral(buffer.Value[2:6]).Uint32Value())
+		buffer.SliceFrom(6)
 	}
-	if length > 8*1024-16 {
-		buffer.Release()
-		buffer = alloc.NewLargeBuffer()
-	}
-	if _, err := io.ReadFull(this.reader, buffer.Value[:length]); err != nil {
+
+	if this.chunkLength == 0 {
 		buffer.Release()
-		return nil, err
+		return nil, io.EOF
 	}
-	buffer.Slice(0, int(length))
 
-	fnvHash := fnv.New32a()
-	fnvHash.Write(buffer.Value[4:])
-	expAuth := serial.BytesLiteral(fnvHash.Sum(nil))
-	actualAuth := serial.BytesLiteral(buffer.Value[:4])
-	if !actualAuth.Equals(expAuth) {
-		buffer.Release()
-		return nil, transport.ErrorCorruptedPacket
+	if buffer.Len() <= this.chunkLength {
+		this.validator.Consume(buffer.Value)
+		this.chunkLength -= buffer.Len()
+		if this.chunkLength == 0 {
+			if !this.validator.Validate() {
+				buffer.Release()
+				return nil, transport.ErrorCorruptedPacket
+			}
+			this.chunkLength = -1
+			this.validator = nil
+		}
+	} else {
+		this.validator.Consume(buffer.Value[:this.chunkLength])
+		if !this.validator.Validate() {
+			buffer.Release()
+			return nil, transport.ErrorCorruptedPacket
+		}
+		leftLength := buffer.Len() - this.chunkLength
+		this.last = AllocBuffer(leftLength).Clear()
+		this.last.Append(buffer.Value[this.chunkLength:])
+		buffer.Slice(0, this.chunkLength)
+
+		this.chunkLength = -1
+		this.validator = nil
 	}
-	buffer.SliceFrom(4)
+
 	return buffer, nil
 }
 
 func (this *AuthChunkReader) Release() {
 	this.reader = nil
+	this.last.Release()
+	this.last = nil
+	this.validator = nil
 }

+ 0 - 23
proxy/vmess/io/writer_test.go

@@ -1,23 +0,0 @@
-package io_test
-
-import (
-	"testing"
-
-	"github.com/v2ray/v2ray-core/common/alloc"
-	. "github.com/v2ray/v2ray-core/proxy/vmess/io"
-	v2testing "github.com/v2ray/v2ray-core/testing"
-	"github.com/v2ray/v2ray-core/testing/assert"
-)
-
-func TestAuthenticate(t *testing.T) {
-	v2testing.Current(t)
-
-	buffer := alloc.NewBuffer().Clear()
-	buffer.AppendBytes(1, 2, 3, 4)
-	Authenticate(buffer)
-	assert.Bytes(buffer.Value).Equals([]byte{0, 8, 87, 52, 168, 125, 1, 2, 3, 4})
-
-	b2, err := NewAuthChunkReader(buffer).Read()
-	assert.Error(err).IsNil()
-	assert.Bytes(b2.Value).Equals([]byte{1, 2, 3, 4})
-}