Browse Source

packet mode reader and writer

Darien Raymond 8 năm trước cách đây
mục cha
commit
939fae00e9

+ 24 - 0
common/buf/multi_buffer.go

@@ -61,6 +61,21 @@ func (mb *MultiBuffer) Read(b []byte) (int, error) {
 	return totalBytes, nil
 }
 
+func (mb *MultiBuffer) Write(b []byte) {
+	n := len(*mb)
+	if n > 0 && !(*mb)[n-1].IsFull() {
+		nBytes, _ := (*mb)[n-1].Write(b)
+		b = b[nBytes:]
+	}
+
+	for len(b) > 0 {
+		bb := New()
+		nBytes, _ := bb.Write(b)
+		b = b[nBytes:]
+		mb.Append(bb)
+	}
+}
+
 // Len returns the total number of bytes in the MultiBuffer.
 func (mb MultiBuffer) Len() int {
 	size := 0
@@ -112,3 +127,12 @@ func (mb *MultiBuffer) SliceBySize(size int) MultiBuffer {
 	*mb = (*mb)[endIndex:]
 	return slice
 }
+
+func (mb *MultiBuffer) SplitFirst() *Buffer {
+	if len(*mb) == 0 {
+		return nil
+	}
+	b := (*mb)[0]
+	*mb = (*mb)[1:]
+	return b
+}

+ 2 - 8
common/buf/reader.go

@@ -20,14 +20,8 @@ func (r *BytesToBufferReader) Read() (MultiBuffer, error) {
 	}
 
 	mb := NewMultiBuffer()
-	p := r.buffer[:nBytes]
-	for len(p) > 0 {
-		b := New()
-		nBytes, _ := b.Write(p)
-		mb.Append(b)
-		p = p[nBytes:]
-	}
-	return mb, nil
+	mb.Write(r.buffer[:nBytes])
+  return mb, nil
 }
 
 type readerAdpater struct {

+ 1 - 6
common/buf/writer.go

@@ -72,12 +72,7 @@ type bytesToBufferWriter struct {
 // Write implements io.Writer.
 func (w *bytesToBufferWriter) Write(payload []byte) (int, error) {
 	mb := NewMultiBuffer()
-	for p := payload; len(p) > 0; {
-		b := New()
-		nBytes, _ := b.Write(p)
-		p = p[nBytes:]
-		mb.Append(b)
-	}
+	mb.Write(payload)
 	if err := w.writer.Write(mb); err != nil {
 		return 0, err
 	}

+ 65 - 12
common/crypto/auth.go

@@ -73,19 +73,21 @@ type AuthenticationReader struct {
 	reader     io.Reader
 	sizeParser ChunkSizeDecoder
 	size       int
+	mode       StreamMode
 }
 
 const (
 	readerBufferSize = 32 * 1024
 )
 
-func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader) *AuthenticationReader {
+func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, mode StreamMode) *AuthenticationReader {
 	return &AuthenticationReader{
 		auth:       auth,
 		buffer:     buf.NewLocal(readerBufferSize),
 		reader:     reader,
 		sizeParser: sizeParser,
 		size:       -1,
+		mode:       mode,
 	}
 }
 
@@ -151,23 +153,36 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
 	}
 
 	mb := buf.NewMultiBuffer()
-
-	appendBytes := func(b []byte) {
-		for len(b) > 0 {
-			buffer := buf.New()
-			n, _ := buffer.Write(b)
-			b = b[n:]
-			mb.Append(buffer)
+	if r.mode == ModeStream {
+		mb.Write(b)
+	} else {
+		var bb *buf.Buffer
+		if len(b) < buf.Size {
+			bb = buf.New()
+		} else {
+			bb = buf.NewLocal(len(b))
 		}
+		bb.Append(b)
+		mb.Append(bb)
 	}
-	appendBytes(b)
 
 	for r.buffer.Len() >= r.sizeParser.SizeBytes() {
 		b, err := r.readChunk(false)
 		if err != nil {
 			break
 		}
-		appendBytes(b)
+		if r.mode == ModeStream {
+			mb.Write(b)
+		} else {
+			var bb *buf.Buffer
+			if len(b) < buf.Size {
+				bb = buf.New()
+			} else {
+				bb = buf.NewLocal(len(b))
+			}
+			bb.Append(b)
+			mb.Append(bb)
+		}
 	}
 
 	return mb, nil
@@ -179,15 +194,17 @@ type AuthenticationWriter struct {
 	buffer     *buf.Buffer
 	writer     io.Writer
 	sizeParser ChunkSizeEncoder
+	mode       StreamMode
 }
 
-func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer) *AuthenticationWriter {
+func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, mode StreamMode) *AuthenticationWriter {
 	return &AuthenticationWriter{
 		auth:       auth,
 		payload:    make([]byte, 1024),
 		buffer:     buf.NewLocal(readerBufferSize),
 		writer:     writer,
 		sizeParser: sizeParser,
+		mode:       mode,
 	}
 }
 
@@ -211,7 +228,7 @@ func (w *AuthenticationWriter) flush() error {
 	return err
 }
 
-func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error {
+func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
 	defer mb.Release()
 
 	for {
@@ -232,3 +249,39 @@ func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error {
 	}
 	return nil
 }
+
+func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
+	defer mb.Release()
+
+	for {
+		b := mb.SplitFirst()
+		if b == nil {
+			b = buf.New()
+		}
+		if w.buffer.Len() > readerBufferSize-b.Len()-128 {
+			if err := w.flush(); err != nil {
+				b.Release()
+				return err
+			}
+		}
+		w.append(b.Bytes())
+		b.Release()
+		if mb.IsEmpty() {
+			break
+		}
+	}
+
+	if !w.buffer.IsEmpty() {
+		return w.flush()
+	}
+
+	return nil
+}
+
+func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error {
+	if w.mode == ModeStream {
+		return w.writeStream(mb)
+	}
+
+	return w.writePacket(mb)
+}

+ 60 - 2
common/crypto/auth_test.go

@@ -39,7 +39,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
 			Content: iv,
 		},
 		AdditionalDataGenerator: &NoOpBytesGenerator{},
-	}, PlainChunkSizeParser{}, cache)
+	}, PlainChunkSizeParser{}, cache, ModeStream)
 
 	assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil()
 	assert.Int(cache.Len()).Equals(83360)
@@ -52,7 +52,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
 			Content: iv,
 		},
 		AdditionalDataGenerator: &NoOpBytesGenerator{},
-	}, PlainChunkSizeParser{}, cache)
+	}, PlainChunkSizeParser{}, cache, ModeStream)
 
 	mb := buf.NewMultiBuffer()
 
@@ -70,3 +70,61 @@ func TestAuthenticationReaderWriter(t *testing.T) {
 	_, err = reader.Read()
 	assert.Error(err).Equals(io.EOF)
 }
+
+func TestAuthenticationReaderWriterPacket(t *testing.T) {
+	assert := assert.On(t)
+
+	key := make([]byte, 16)
+	rand.Read(key)
+	block, err := aes.NewCipher(key)
+	assert.Error(err).IsNil()
+
+	aead, err := cipher.NewGCM(block)
+	assert.Error(err).IsNil()
+
+	cache := buf.NewLocal(1024)
+	iv := make([]byte, 12)
+	rand.Read(iv)
+
+	writer := NewAuthenticationWriter(&AEADAuthenticator{
+		AEAD: aead,
+		NonceGenerator: &StaticBytesGenerator{
+			Content: iv,
+		},
+		AdditionalDataGenerator: &NoOpBytesGenerator{},
+	}, PlainChunkSizeParser{}, cache, ModePacket)
+
+	payload := buf.NewMultiBuffer()
+	pb1 := buf.New()
+	pb1.Append([]byte("abcd"))
+	payload.Append(pb1)
+
+	pb2 := buf.New()
+	pb2.Append([]byte("efgh"))
+	payload.Append(pb2)
+
+	assert.Error(writer.Write(payload)).IsNil()
+	assert.Int(cache.Len()).GreaterThan(0)
+	assert.Error(writer.Write(buf.NewMultiBuffer())).IsNil()
+	assert.Error(err).IsNil()
+
+	reader := NewAuthenticationReader(&AEADAuthenticator{
+		AEAD: aead,
+		NonceGenerator: &StaticBytesGenerator{
+			Content: iv,
+		},
+		AdditionalDataGenerator: &NoOpBytesGenerator{},
+	}, PlainChunkSizeParser{}, cache, ModePacket)
+
+	mb, err := reader.Read()
+	assert.Error(err).IsNil()
+
+	b1 := mb.SplitFirst()
+	assert.String(b1.String()).Equals("abcd")
+	b2 := mb.SplitFirst()
+	assert.String(b2.String()).Equals("efgh")
+	assert.Bool(mb.IsEmpty()).IsTrue()
+
+	_, err = reader.Read()
+	assert.Error(err).Equals(io.EOF)
+}

+ 30 - 0
proxy/vmess/encoding/auth.go

@@ -6,6 +6,8 @@ import (
 
 	"golang.org/x/crypto/sha3"
 
+	"v2ray.com/core/common/crypto"
+	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/serial"
 )
 
@@ -16,6 +18,26 @@ func Authenticate(b []byte) uint32 {
 	return fnv1hash.Sum32()
 }
 
+type NoOpAuthenticator struct{}
+
+func (NoOpAuthenticator) NonceSize() int {
+	return 0
+}
+
+func (NoOpAuthenticator) Overhead() int {
+	return 0
+}
+
+// Seal implements AEAD.Seal().
+func (NoOpAuthenticator) Seal(dst, nonce, plaintext, additionalData []byte) []byte {
+	return append(dst[:0], plaintext...)
+}
+
+// Open implements AEAD.Open().
+func (NoOpAuthenticator) Open(dst, nonce, ciphertext, additionalData []byte) ([]byte, error) {
+	return append(dst[:0], ciphertext...), nil
+}
+
 // FnvAuthenticator is an AEAD based on Fnv hash.
 type FnvAuthenticator struct {
 }
@@ -86,3 +108,11 @@ func (s *ShakeSizeParser) Encode(size uint16, b []byte) []byte {
 	mask := s.next()
 	return serial.Uint16ToBytes(mask^size, b[:0])
 }
+
+func GetStreamMode(request *protocol.RequestHeader) crypto.StreamMode {
+	if request.Command == protocol.RequestCommandTCP {
+		return crypto.ModeStream
+	}
+
+	return crypto.ModePacket
+}

+ 26 - 8
proxy/vmess/encoding/client.go

@@ -123,7 +123,15 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 	}
 	if request.Security.Is(protocol.SecurityType_NONE) {
 		if request.Option.Has(protocol.RequestOptionChunkStream) {
-			return crypto.NewChunkStreamWriter(sizeParser, writer)
+			if request.Command == protocol.RequestCommandTCP {
+				return crypto.NewChunkStreamWriter(sizeParser, writer)
+			}
+			auth := &crypto.AEADAuthenticator{
+				AEAD:                    new(NoOpAuthenticator),
+				NonceGenerator:          crypto.NoOpBytesGenerator{},
+				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
+			}
+			return crypto.NewAuthenticationWriter(auth, sizeParser, writer, crypto.ModePacket)
 		}
 
 		return buf.NewWriter(writer)
@@ -138,7 +146,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 				NonceGenerator:          crypto.NoOpBytesGenerator{},
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
-			return crypto.NewAuthenticationWriter(auth, sizeParser, cryptionWriter)
+			return crypto.NewAuthenticationWriter(auth, sizeParser, cryptionWriter, GetStreamMode(request))
 		}
 
 		return buf.NewWriter(cryptionWriter)
@@ -156,7 +164,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationWriter(auth, sizeParser, writer)
+		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request))
 	}
 
 	if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) {
@@ -170,7 +178,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationWriter(auth, sizeParser, writer)
+		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request))
 	}
 
 	panic("Unknown security type.")
@@ -221,7 +229,17 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read
 	}
 	if request.Security.Is(protocol.SecurityType_NONE) {
 		if request.Option.Has(protocol.RequestOptionChunkStream) {
-			return crypto.NewChunkStreamReader(sizeParser, reader)
+			if request.Command == protocol.RequestCommandTCP {
+				return crypto.NewChunkStreamReader(sizeParser, reader)
+			}
+
+			auth := &crypto.AEADAuthenticator{
+				AEAD:                    new(NoOpAuthenticator),
+				NonceGenerator:          crypto.NoOpBytesGenerator{},
+				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
+			}
+
+			return crypto.NewAuthenticationReader(auth, sizeParser, reader, crypto.ModePacket)
 		}
 
 		return buf.NewReader(reader)
@@ -234,7 +252,7 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read
 				NonceGenerator:          crypto.NoOpBytesGenerator{},
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
-			return crypto.NewAuthenticationReader(auth, sizeParser, v.responseReader)
+			return crypto.NewAuthenticationReader(auth, sizeParser, v.responseReader, GetStreamMode(request))
 		}
 
 		return buf.NewReader(v.responseReader)
@@ -252,7 +270,7 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationReader(auth, sizeParser, reader)
+		return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request))
 	}
 
 	if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) {
@@ -266,7 +284,7 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationReader(auth, sizeParser, reader)
+		return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request))
 	}
 
 	panic("Unknown security type.")

+ 26 - 8
proxy/vmess/encoding/server.go

@@ -240,7 +240,16 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
 	}
 	if request.Security.Is(protocol.SecurityType_NONE) {
 		if request.Option.Has(protocol.RequestOptionChunkStream) {
-			return crypto.NewChunkStreamReader(sizeParser, reader)
+			if request.Command == protocol.RequestCommandTCP {
+				return crypto.NewChunkStreamReader(sizeParser, reader)
+			}
+
+			auth := &crypto.AEADAuthenticator{
+				AEAD:                    new(NoOpAuthenticator),
+				NonceGenerator:          crypto.NoOpBytesGenerator{},
+				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
+			}
+			return crypto.NewAuthenticationReader(auth, sizeParser, reader, crypto.ModePacket)
 		}
 
 		return buf.NewReader(reader)
@@ -255,7 +264,7 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
 				NonceGenerator:          crypto.NoOpBytesGenerator{},
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
-			return crypto.NewAuthenticationReader(auth, sizeParser, cryptionReader)
+			return crypto.NewAuthenticationReader(auth, sizeParser, cryptionReader, GetStreamMode(request))
 		}
 
 		return buf.NewReader(cryptionReader)
@@ -273,7 +282,7 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationReader(auth, sizeParser, reader)
+		return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request))
 	}
 
 	if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) {
@@ -287,7 +296,7 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationReader(auth, sizeParser, reader)
+		return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request))
 	}
 
 	panic("Unknown security type.")
@@ -317,7 +326,16 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
 	}
 	if request.Security.Is(protocol.SecurityType_NONE) {
 		if request.Option.Has(protocol.RequestOptionChunkStream) {
-			return crypto.NewChunkStreamWriter(sizeParser, writer)
+			if request.Command == protocol.RequestCommandTCP {
+				return crypto.NewChunkStreamWriter(sizeParser, writer)
+			}
+
+			auth := &crypto.AEADAuthenticator{
+				AEAD:                    new(NoOpAuthenticator),
+				NonceGenerator:          &crypto.NoOpBytesGenerator{},
+				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
+			}
+			return crypto.NewAuthenticationWriter(auth, sizeParser, writer, crypto.ModePacket)
 		}
 
 		return buf.NewWriter(writer)
@@ -330,7 +348,7 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
 				NonceGenerator:          crypto.NoOpBytesGenerator{},
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
-			return crypto.NewAuthenticationWriter(auth, sizeParser, v.responseWriter)
+			return crypto.NewAuthenticationWriter(auth, sizeParser, v.responseWriter, GetStreamMode(request))
 		}
 
 		return buf.NewWriter(v.responseWriter)
@@ -348,7 +366,7 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationWriter(auth, sizeParser, writer)
+		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request))
 	}
 
 	if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) {
@@ -362,7 +380,7 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationWriter(auth, sizeParser, writer)
+		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request))
 	}
 
 	panic("Unknown security type.")