Bladeren bron

chunk stream in vmess

v2ray 9 jaren geleden
bovenliggende
commit
d3ff2b3698

+ 6 - 13
common/alloc/buffer.go

@@ -9,19 +9,6 @@ const (
 	DefaultOffset = 16
 )
 
-func Release(buffer *Buffer) {
-	if buffer != nil {
-		buffer.Release()
-	}
-}
-
-func Len(buffer *Buffer) int {
-	if buffer == nil {
-		return 0
-	}
-	return buffer.Len()
-}
-
 // Buffer is a recyclable allocation of a byte array. Buffer.Release() recycles
 // the buffer into an internal buffer pool, in order to recreate a buffer more
 // quickly.
@@ -34,6 +21,9 @@ type Buffer struct {
 
 // Release recycles the buffer into an internal buffer pool.
 func (b *Buffer) Release() {
+	if b == nil {
+		return
+	}
 	b.pool.free(b)
 	b.head = nil
 	b.Value = nil
@@ -96,6 +86,9 @@ func (b *Buffer) SliceBack(offset int) *Buffer {
 
 // Len returns the length of the buffer content.
 func (b *Buffer) Len() int {
+	if b == nil {
+		return 0
+	}
 	return len(b.Value)
 }
 

+ 1 - 1
common/io/reader.go

@@ -47,7 +47,7 @@ func (this *AdaptiveReader) Read() (*alloc.Buffer, error) {
 	}
 
 	if err != nil {
-		alloc.Release(buffer)
+		buffer.Release()
 		return nil, err
 	}
 	return buffer, nil

+ 8 - 7
common/io/transport.go

@@ -14,10 +14,10 @@ func RawReaderToChan(stream chan<- *alloc.Buffer, reader io.Reader) error {
 func ReaderToChan(stream chan<- *alloc.Buffer, reader Reader) error {
 	for {
 		buffer, err := reader.Read()
-		if alloc.Len(buffer) > 0 {
+		if buffer.Len() > 0 {
 			stream <- buffer
 		} else {
-			alloc.Release(buffer)
+			buffer.Release()
 		}
 
 		if err != nil {
@@ -26,13 +26,14 @@ func ReaderToChan(stream chan<- *alloc.Buffer, reader Reader) error {
 	}
 }
 
+func ChanToRawWriter(writer io.Writer, stream <-chan *alloc.Buffer) error {
+	return ChanToWriter(NewAdaptiveWriter(writer), stream)
+}
+
 // ChanToWriter dumps all content from a given chan to a writer until the chan is closed.
-func ChanToWriter(writer io.Writer, stream <-chan *alloc.Buffer) error {
+func ChanToWriter(writer Writer, stream <-chan *alloc.Buffer) error {
 	for buffer := range stream {
-		nBytes, err := writer.Write(buffer.Value)
-		if nBytes < buffer.Len() {
-			_, err = writer.Write(buffer.Value[nBytes:])
-		}
+		err := writer.Write(buffer)
 		buffer.Release()
 		if err != nil {
 			return err

+ 1 - 1
common/io/transport_test.go

@@ -30,7 +30,7 @@ func TestReaderAndWrite(t *testing.T) {
 	assert.Error(err).Equals(io.EOF)
 	close(transportChan)
 
-	err = ChanToWriter(writerBuffer, transportChan)
+	err = ChanToRawWriter(writerBuffer, transportChan)
 	assert.Error(err).IsNil()
 
 	assert.Bytes(buffer).Equals(writerBuffer.Bytes())

+ 29 - 0
common/io/writer.go

@@ -0,0 +1,29 @@
+package io
+
+import (
+	"io"
+
+	"github.com/v2ray/v2ray-core/common/alloc"
+)
+
+type Writer interface {
+	Write(*alloc.Buffer) error
+}
+
+type AdaptiveWriter struct {
+	writer io.Writer
+}
+
+func NewAdaptiveWriter(writer io.Writer) *AdaptiveWriter {
+	return &AdaptiveWriter{
+		writer: writer,
+	}
+}
+
+func (this *AdaptiveWriter) Write(buffer *alloc.Buffer) error {
+	nBytes, err := this.writer.Write(buffer.Value)
+	if nBytes < buffer.Len() {
+		_, err = this.writer.Write(buffer.Value[nBytes:])
+	}
+	return err
+}

+ 1 - 1
proxy/blackhole/blackhole.go

@@ -26,7 +26,7 @@ func (this *BlackHole) Dispatch(firstPacket v2net.Packet, ray ray.OutboundRay) e
 
 	close(ray.OutboundOutput())
 	if firstPacket.MoreChunks() {
-		v2io.ChanToWriter(ioutil.Discard, ray.OutboundInput())
+		v2io.ChanToRawWriter(ioutil.Discard, ray.OutboundInput())
 	}
 	return nil
 }

+ 1 - 1
proxy/dokodemo/dokodemo.go

@@ -147,6 +147,6 @@ func dumpInput(reader io.Reader, input chan<- *alloc.Buffer, finish *sync.Mutex)
 }
 
 func dumpOutput(writer io.Writer, output <-chan *alloc.Buffer, finish *sync.Mutex) {
-	v2io.ChanToWriter(writer, output)
+	v2io.ChanToRawWriter(writer, output)
 	finish.Unlock()
 }

+ 1 - 1
proxy/freedom/freedom.go

@@ -49,7 +49,7 @@ func (this *FreedomConnection) Dispatch(firstPacket v2net.Packet, ray ray.Outbou
 		writeMutex.Unlock()
 	} else {
 		go func() {
-			v2io.ChanToWriter(conn, input)
+			v2io.ChanToRawWriter(conn, input)
 			writeMutex.Unlock()
 		}()
 	}

+ 1 - 1
proxy/http/http.go

@@ -160,7 +160,7 @@ func (this *HttpProxyServer) transport(input io.Reader, output io.Writer, ray ra
 	}()
 
 	go func() {
-		v2io.ChanToWriter(output, ray.InboundOutput())
+		v2io.ChanToRawWriter(output, ray.InboundOutput())
 		wg.Done()
 	}()
 }

+ 3 - 3
proxy/shadowsocks/ota.go

@@ -69,14 +69,14 @@ func NewChunkReader(reader io.Reader, auth *Authenticator) *ChunkReader {
 func (this *ChunkReader) Read() (*alloc.Buffer, error) {
 	buffer := alloc.NewLargeBuffer()
 	if _, err := io.ReadFull(this.reader, buffer.Value[:2]); err != nil {
-		alloc.Release(buffer)
+		buffer.Release()
 		return nil, err
 	}
 	// There is a potential buffer overflow here. Large buffer is 64K bytes,
 	// while uin16 + 10 will be more than that
 	length := serial.BytesLiteral(buffer.Value[:2]).Uint16Value() + AuthSize
 	if _, err := io.ReadFull(this.reader, buffer.Value[:length]); err != nil {
-		alloc.Release(buffer)
+		buffer.Release()
 		return nil, err
 	}
 	buffer.Slice(0, int(length))
@@ -86,7 +86,7 @@ func (this *ChunkReader) Read() (*alloc.Buffer, error) {
 
 	actualAuthBytes := this.auth.Authenticate(nil, payload)
 	if !serial.BytesLiteral(authBytes).Equals(serial.BytesLiteral(actualAuthBytes)) {
-		alloc.Release(buffer)
+		buffer.Release()
 		log.Debug("AuthenticationReader: Unexpected auth: ", authBytes)
 		return nil, transport.ErrorCorruptedPacket
 	}

+ 1 - 1
proxy/shadowsocks/shadowsocks.go

@@ -199,7 +199,7 @@ func (this *Shadowsocks) handleConnection(conn *hub.TCPConn) {
 
 			writer.Write(payload.Value)
 			payload.Release()
-			v2io.ChanToWriter(writer, ray.InboundOutput())
+			v2io.ChanToRawWriter(writer, ray.InboundOutput())
 		}
 		writeFinish.Unlock()
 	}()

+ 1 - 1
proxy/socks/socks.go

@@ -277,7 +277,7 @@ func (this *SocksServer) transport(reader io.Reader, writer io.Writer, firstPack
 	}()
 
 	go func() {
-		v2io.ChanToWriter(writer, output)
+		v2io.ChanToRawWriter(writer, output)
 		outputFinish.Unlock()
 	}()
 	outputFinish.Lock()

+ 1 - 1
proxy/testing/mocks/inboundhandler.go

@@ -48,7 +48,7 @@ func (this *InboundConnectionHandler) Communicate(packet v2net.Packet) error {
 	}()
 
 	go func() {
-		v2io.ChanToWriter(this.ConnOutput, output)
+		v2io.ChanToRawWriter(this.ConnOutput, output)
 		writeFinish.Unlock()
 	}()
 

+ 1 - 1
proxy/testing/mocks/outboundhandler.go

@@ -33,7 +33,7 @@ func (this *OutboundConnectionHandler) Dispatch(packet v2net.Packet, ray ray.Out
 		writeFinish.Lock()
 
 		go func() {
-			v2io.ChanToWriter(this.ConnOutput, input)
+			v2io.ChanToRawWriter(this.ConnOutput, input)
 			writeFinish.Unlock()
 		}()
 

+ 21 - 8
proxy/vmess/inbound/inbound.go

@@ -17,6 +17,7 @@ import (
 	"github.com/v2ray/v2ray-core/proxy"
 	"github.com/v2ray/v2ray-core/proxy/internal"
 	"github.com/v2ray/v2ray-core/proxy/vmess"
+	vmessio "github.com/v2ray/v2ray-core/proxy/vmess/io"
 	"github.com/v2ray/v2ray-core/proxy/vmess/protocol"
 	"github.com/v2ray/v2ray-core/transport/hub"
 )
@@ -119,10 +120,21 @@ func (this *VMessInboundHandler) HandleConnection(connection *hub.TCPConn) {
 	this.generateCommand(buffer)
 
 	if data, open := <-output; open {
+		if request.IsChunkStream() {
+			vmessio.Authenticate(data)
+		}
 		buffer.Append(data.Value)
 		data.Release()
 		responseWriter.Write(buffer.Value)
-		go handleOutput(request, responseWriter, output, &writeFinish)
+		go func(finish *sync.Mutex) {
+			var writer v2io.Writer
+			writer = v2io.NewAdaptiveWriter(responseWriter)
+			if request.IsChunkStream() {
+				writer = vmessio.NewAuthChunkWriter(writer)
+			}
+			v2io.ChanToWriter(writer, output)
+			finish.Unlock()
+		}(&writeFinish)
 		writeFinish.Lock()
 	}
 
@@ -139,13 +151,14 @@ func handleInput(request *protocol.VMessRequest, reader io.Reader, input chan<-
 		log.Error("VMessIn: Failed to create AES decryption stream: ", err)
 		return
 	}
-	requestReader := v2crypto.NewCryptionReader(aesStream, reader)
-	v2io.RawReaderToChan(input, requestReader)
-}
-
-func handleOutput(request *protocol.VMessRequest, writer io.Writer, output <-chan *alloc.Buffer, finish *sync.Mutex) {
-	v2io.ChanToWriter(writer, output)
-	finish.Unlock()
+	descriptionReader := v2crypto.NewCryptionReader(aesStream, reader)
+	var requestReader v2io.Reader
+	if request.IsChunkStream() {
+		requestReader = vmessio.NewAuthChunkReader(descriptionReader)
+	} else {
+		requestReader = v2io.NewAdaptiveReader(descriptionReader)
+	}
+	v2io.ReaderToChan(input, requestReader)
 }
 
 func init() {

+ 46 - 0
proxy/vmess/io/reader.go

@@ -0,0 +1,46 @@
+package io
+
+import (
+	"hash/fnv"
+	"io"
+
+	"github.com/v2ray/v2ray-core/common/alloc"
+	"github.com/v2ray/v2ray-core/common/serial"
+	"github.com/v2ray/v2ray-core/transport"
+)
+
+type AuthChunkReader struct {
+	reader io.Reader
+}
+
+func NewAuthChunkReader(reader io.Reader) *AuthChunkReader {
+	return &AuthChunkReader{
+		reader: reader,
+	}
+}
+
+func (this *AuthChunkReader) Read() (*alloc.Buffer, error) {
+	buffer := alloc.NewBuffer()
+	if _, err := io.ReadFull(this.reader, buffer.Value[:2]); err != nil {
+		buffer.Release()
+		return nil, err
+	}
+
+	length := serial.BytesLiteral(buffer.Value[:2]).Uint16Value()
+	if _, err := io.ReadFull(this.reader, buffer.Value[:length]); err != nil {
+		buffer.Release()
+		return nil, err
+	}
+	buffer.Slice(0, int(length))
+
+	fnvHash := fnv.New32a()
+	fnvHash.Write(buffer.Value[4:])
+	expAuth := serial.BytesLiteral(fnvHash.Sum(nil))
+	actualAuth := serial.BytesLiteral(buffer.Value[:4])
+	if !actualAuth.Equals(expAuth) {
+		buffer.Release()
+		return nil, transport.ErrorCorruptedPacket
+	}
+	buffer.SliceFrom(4)
+	return buffer, nil
+}

+ 34 - 0
proxy/vmess/io/writer.go

@@ -0,0 +1,34 @@
+package io
+
+import (
+	"hash/fnv"
+
+	"github.com/v2ray/v2ray-core/common/alloc"
+	v2io "github.com/v2ray/v2ray-core/common/io"
+	"github.com/v2ray/v2ray-core/common/serial"
+)
+
+type AuthChunkWriter struct {
+	writer v2io.Writer
+}
+
+func NewAuthChunkWriter(writer v2io.Writer) *AuthChunkWriter {
+	return &AuthChunkWriter{
+		writer: writer,
+	}
+}
+
+func (this *AuthChunkWriter) Write(buffer *alloc.Buffer) error {
+	Authenticate(buffer)
+	return this.writer.Write(buffer)
+}
+
+func Authenticate(buffer *alloc.Buffer) {
+	fnvHash := fnv.New32a()
+	fnvHash.Write(buffer.Value)
+
+	buffer.SliceBack(4)
+	fnvHash.Sum(buffer.Value[:0])
+
+	buffer.Prepend(serial.Uint16Literal(uint16(buffer.Len())).Bytes())
+}

+ 23 - 0
proxy/vmess/io/writer_test.go

@@ -0,0 +1,23 @@
+package io_test
+
+import (
+	"testing"
+
+	"github.com/v2ray/v2ray-core/common/alloc"
+	. "github.com/v2ray/v2ray-core/proxy/vmess/io"
+	v2testing "github.com/v2ray/v2ray-core/testing"
+	"github.com/v2ray/v2ray-core/testing/assert"
+)
+
+func TestAuthenticate(t *testing.T) {
+	v2testing.Current(t)
+
+	buffer := alloc.NewBuffer().Clear()
+	buffer.AppendBytes(1, 2, 3, 4)
+	Authenticate(buffer)
+	assert.Bytes(buffer.Value).Equals([]byte{0, 8, 87, 52, 168, 125, 1, 2, 3, 4})
+
+	b2, err := NewAuthChunkReader(buffer).Read()
+	assert.Error(err).IsNil()
+	assert.Bytes(b2.Value).Equals([]byte{1, 2, 3, 4})
+}

+ 34 - 20
proxy/vmess/outbound/outbound.go

@@ -16,6 +16,7 @@ import (
 	v2net "github.com/v2ray/v2ray-core/common/net"
 	"github.com/v2ray/v2ray-core/proxy"
 	"github.com/v2ray/v2ray-core/proxy/internal"
+	vmessio "github.com/v2ray/v2ray-core/proxy/vmess/io"
 	"github.com/v2ray/v2ray-core/proxy/vmess/protocol"
 	"github.com/v2ray/v2ray-core/transport/ray"
 )
@@ -38,6 +39,9 @@ func (this *VMessOutboundHandler) Dispatch(firstPacket v2net.Packet, ray ray.Out
 		Address: firstPacket.Destination().Address(),
 		Port:    firstPacket.Destination().Port(),
 	}
+	if command == protocol.CmdUDP {
+		request.Option |= protocol.OptionChunk
+	}
 
 	buffer := alloc.NewSmallBuffer()
 	defer buffer.Release()                      // Buffer is released after communication finishes.
@@ -83,7 +87,7 @@ func (this *VMessOutboundHandler) startCommunicate(request *protocol.VMessReques
 	responseFinish.Lock()
 
 	go this.handleRequest(conn, request, firstPacket, input, &requestFinish)
-	go this.handleResponse(conn, request, dest, output, &responseFinish, (request.Command == protocol.CmdUDP))
+	go this.handleResponse(conn, request, dest, output, &responseFinish)
 
 	requestFinish.Lock()
 	conn.CloseWrite()
@@ -121,6 +125,10 @@ func (this *VMessOutboundHandler) handleRequest(conn net.Conn, request *protocol
 		return
 	}
 
+	if request.IsChunkStream() {
+		vmessio.Authenticate(firstChunk)
+	}
+
 	aesStream.XORKeyStream(firstChunk.Value, firstChunk.Value)
 	buffer.Append(firstChunk.Value)
 	firstChunk.Release()
@@ -132,7 +140,12 @@ func (this *VMessOutboundHandler) handleRequest(conn net.Conn, request *protocol
 	}
 
 	if moreChunks {
-		v2io.ChanToWriter(encryptRequestWriter, input)
+		var streamWriter v2io.Writer
+		streamWriter = v2io.NewAdaptiveWriter(encryptRequestWriter)
+		if request.IsChunkStream() {
+			streamWriter = vmessio.NewAuthChunkWriter(streamWriter)
+		}
+		v2io.ChanToWriter(streamWriter, input)
 	}
 	return
 }
@@ -141,7 +154,7 @@ func headerMatch(request *protocol.VMessRequest, responseHeader byte) bool {
 	return request.ResponseHeader == responseHeader
 }
 
-func (this *VMessOutboundHandler) handleResponse(conn net.Conn, request *protocol.VMessRequest, dest v2net.Destination, output chan<- *alloc.Buffer, finish *sync.Mutex, isUDP bool) {
+func (this *VMessOutboundHandler) handleResponse(conn net.Conn, request *protocol.VMessRequest, dest v2net.Destination, output chan<- *alloc.Buffer, finish *sync.Mutex) {
 	defer finish.Unlock()
 	defer close(output)
 	responseKey := md5.Sum(request.RequestKey[:])
@@ -154,39 +167,40 @@ func (this *VMessOutboundHandler) handleResponse(conn net.Conn, request *protoco
 	}
 	decryptResponseReader := v2crypto.NewCryptionReader(aesStream, conn)
 
-	buffer, err := v2io.ReadFrom(decryptResponseReader, nil)
+	buffer := alloc.NewSmallBuffer()
+	defer buffer.Release()
+	_, err = io.ReadFull(decryptResponseReader, buffer.Value[:4])
+
 	if err != nil {
 		log.Error("VMessOut: Failed to read VMess response (", buffer.Len(), " bytes): ", err)
-		buffer.Release()
 		return
 	}
-	if buffer.Len() < 4 || !headerMatch(request, buffer.Value[0]) {
+	if !headerMatch(request, buffer.Value[0]) {
 		log.Warning("VMessOut: unexepcted response header. The connection is probably hijacked.")
 		return
 	}
-	log.Info("VMessOut received ", buffer.Len()-4, " bytes from ", conn.RemoteAddr())
 
-	responseBegin := 4
 	if buffer.Value[2] != 0 {
+		command := buffer.Value[2]
 		dataLen := int(buffer.Value[3])
-		if buffer.Len() < dataLen+4 { // Rare case
-			diffBuffer := make([]byte, dataLen+4-buffer.Len())
-			io.ReadFull(decryptResponseReader, diffBuffer)
-			buffer.Append(diffBuffer)
+		_, err := io.ReadFull(decryptResponseReader, buffer.Value[:dataLen])
+		if err != nil {
+			log.Error("VMessOut: Failed to read response command: ", err)
+			return
 		}
-		command := buffer.Value[2]
-		data := buffer.Value[4 : 4+dataLen]
+		data := buffer.Value[:dataLen]
 		go this.handleCommand(dest, command, data)
-		responseBegin = 4 + dataLen
 	}
 
-	buffer.SliceFrom(responseBegin)
-	output <- buffer
-
-	if !isUDP {
-		v2io.RawReaderToChan(output, decryptResponseReader)
+	var reader v2io.Reader
+	if request.IsChunkStream() {
+		reader = vmessio.NewAuthChunkReader(decryptResponseReader)
+	} else {
+		reader = v2io.NewAdaptiveReader(decryptResponseReader)
 	}
 
+	v2io.ReaderToChan(output, reader)
+
 	return
 }
 

+ 0 - 75
proxy/vmess/protocol/io/validation.go

@@ -1,75 +0,0 @@
-package io
-
-import (
-	"errors"
-	"hash/fnv"
-	"io"
-
-	"github.com/v2ray/v2ray-core/common/alloc"
-	"github.com/v2ray/v2ray-core/transport"
-)
-
-var (
-	TruncatedPayload = errors.New("Truncated payload.")
-)
-
-type ValidationReader struct {
-	reader io.Reader
-	buffer *alloc.Buffer
-}
-
-func NewValidationReader(reader io.Reader) *ValidationReader {
-	return &ValidationReader{
-		reader: reader,
-		buffer: alloc.NewLargeBuffer().Clear(),
-	}
-}
-
-func (this *ValidationReader) Read(data []byte) (int, error) {
-	nBytes, err := this.reader.Read(data)
-	if err != nil {
-		return nBytes, err
-	}
-	nBytesActual := 0
-	dataActual := data[:]
-	for {
-		payload, rest, err := parsePayload(data)
-		if err != nil {
-			return nBytesActual, err
-		}
-		copy(dataActual, payload)
-		nBytesActual += len(payload)
-		dataActual = dataActual[nBytesActual:]
-		if len(rest) == 0 {
-			break
-		}
-		data = rest
-	}
-	return nBytesActual, nil
-}
-
-func parsePayload(data []byte) (payload []byte, rest []byte, err error) {
-	dataLen := len(data)
-	if dataLen < 6 {
-		err = TruncatedPayload
-		return
-	}
-	payloadLen := int(data[0])<<8 + int(data[1])
-	if dataLen < payloadLen+6 {
-		err = TruncatedPayload
-		return
-	}
-
-	payload = data[6 : 6+payloadLen]
-	rest = data[6+payloadLen:]
-
-	fnv1a := fnv.New32a()
-	fnv1a.Write(payload)
-	actualHash := fnv1a.Sum32()
-	expectedHash := uint32(data[2])<<24 + uint32(data[3])<<16 + uint32(data[4])<<8 + uint32(data[5])
-	if actualHash != expectedHash {
-		err = transport.ErrorCorruptedPacket
-		return
-	}
-	return
-}

+ 10 - 2
proxy/vmess/protocol/vmess.go

@@ -26,6 +26,8 @@ const (
 
 	Version = byte(0x01)
 
+	OptionChunk = byte(0x01)
+
 	blockSize = 16
 )
 
@@ -39,6 +41,7 @@ type VMessRequest struct {
 	RequestKey     []byte
 	ResponseHeader byte
 	Command        byte
+	Option         byte
 	Address        v2net.Address
 	Port           v2net.Port
 }
@@ -52,6 +55,10 @@ func (this *VMessRequest) Destination() v2net.Destination {
 	}
 }
 
+func (this *VMessRequest) IsChunkStream() bool {
+	return (this.Option & OptionChunk) == OptionChunk
+}
+
 // VMessRequestReader is a parser to read VMessRequest from a byte stream.
 type VMessRequestReader struct {
 	vUserSet UserSet
@@ -110,7 +117,8 @@ func (this *VMessRequestReader) Read(reader io.Reader) (*VMessRequest, error) {
 
 	request.RequestIV = append([]byte(nil), buffer.Value[1:17]...)   // 16 bytes
 	request.RequestKey = append([]byte(nil), buffer.Value[17:33]...) // 16 bytes
-	request.ResponseHeader = buffer.Value[33]                        // 1 byte + 3 bytes reserved.
+	request.ResponseHeader = buffer.Value[33]                        // 1 byte
+	request.Option = buffer.Value[34]                                // 1 byte + 2 bytes reserved
 	request.Command = buffer.Value[37]
 
 	request.Port = v2net.PortFromBytes(buffer.Value[38:40])
@@ -189,7 +197,7 @@ func (this *VMessRequest) ToBytes(timestampGenerator RandomTimestampGenerator, b
 	buffer.AppendBytes(this.Version)
 	buffer.Append(this.RequestIV)
 	buffer.Append(this.RequestKey)
-	buffer.AppendBytes(this.ResponseHeader, byte(0), byte(0), byte(0))
+	buffer.AppendBytes(this.ResponseHeader, this.Option, byte(0), byte(0))
 	buffer.AppendBytes(this.Command)
 	buffer.Append(this.Port.Bytes())