Преглед изворни кода

support packet transfer type in mux

Darien Raymond пре 8 година
родитељ
комит
7a4bab4940

+ 31 - 23
app/proxyman/mux/mux.go

@@ -15,6 +15,7 @@ import (
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/proxy"
 	"v2ray.com/core/transport/ray"
 )
@@ -173,37 +174,37 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool
 	return true
 }
 
-func drain(reader *Reader) error {
-	buf.Copy(reader, buf.Discard)
+func drain(reader io.Reader) error {
+	buf.Copy(NewStreamReader(reader), buf.Discard)
 	return nil
 }
 
-func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader *Reader) error {
+func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader io.Reader) error {
 	if meta.Option.Has(OptionData) {
 		return drain(reader)
 	}
 	return nil
 }
 
-func (m *Client) handleStatusNew(meta *FrameMetadata, reader *Reader) error {
+func (m *Client) handleStatusNew(meta *FrameMetadata, reader io.Reader) error {
 	if meta.Option.Has(OptionData) {
 		return drain(reader)
 	}
 	return nil
 }
 
-func (m *Client) handleStatusKeep(meta *FrameMetadata, reader *Reader) error {
+func (m *Client) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error {
 	if !meta.Option.Has(OptionData) {
 		return nil
 	}
 
 	if s, found := m.sessionManager.Get(meta.SessionID); found {
-		return buf.Copy(reader, s.output, buf.IgnoreWriterError())
+		return buf.Copy(s.NewReader(reader), s.output, buf.IgnoreWriterError())
 	}
 	return drain(reader)
 }
 
-func (m *Client) handleStatusEnd(meta *FrameMetadata, reader *Reader) error {
+func (m *Client) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error {
 	if s, found := m.sessionManager.Get(meta.SessionID); found {
 		s.CloseDownlink()
 		s.output.Close()
@@ -217,9 +218,11 @@ func (m *Client) handleStatusEnd(meta *FrameMetadata, reader *Reader) error {
 func (m *Client) fetchOutput() {
 	defer m.cancel()
 
-	reader := NewReader(m.inboundRay.InboundOutput())
+	reader := buf.ToBytesReader(m.inboundRay.InboundOutput())
+	metaReader := NewMetadataReader(reader)
+
 	for {
-		meta, err := reader.ReadMetadata()
+		meta, err := metaReader.Read()
 		if err != nil {
 			if errors.Cause(err) != io.EOF {
 				log.Trace(newError("failed to read metadata").Base(err))
@@ -289,7 +292,7 @@ type ServerWorker struct {
 }
 
 func handle(ctx context.Context, s *Session, output buf.Writer) {
-	writer := NewResponseWriter(s.ID, output)
+	writer := NewResponseWriter(s.ID, output, s.transferType)
 	if err := buf.Copy(s.input, writer); err != nil {
 		log.Trace(newError("session ", s.ID, " ends: ").Base(err))
 	}
@@ -297,14 +300,14 @@ func handle(ctx context.Context, s *Session, output buf.Writer) {
 	s.CloseDownlink()
 }
 
-func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *Reader) error {
+func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader io.Reader) error {
 	if meta.Option.Has(OptionData) {
 		return drain(reader)
 	}
 	return nil
 }
 
-func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *Reader) error {
+func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader io.Reader) error {
 	log.Trace(newError("received request for ", meta.Target))
 	inboundRay, err := w.dispatcher.Dispatch(ctx, meta.Target)
 	if err != nil {
@@ -314,30 +317,34 @@ func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata,
 		return newError("failed to dispatch request.").Base(err)
 	}
 	s := &Session{
-		input:  inboundRay.InboundOutput(),
-		output: inboundRay.InboundInput(),
-		parent: w.sessionManager,
-		ID:     meta.SessionID,
+		input:        inboundRay.InboundOutput(),
+		output:       inboundRay.InboundInput(),
+		parent:       w.sessionManager,
+		ID:           meta.SessionID,
+		transferType: protocol.TransferTypeStream,
+	}
+	if meta.Target.Network == net.Network_UDP {
+		s.transferType = protocol.TransferTypePacket
 	}
 	w.sessionManager.Add(s)
 	go handle(ctx, s, w.outboundRay.OutboundOutput())
 	if meta.Option.Has(OptionData) {
-		return buf.Copy(reader, s.output, buf.IgnoreWriterError())
+		return buf.Copy(s.NewReader(reader), s.output, buf.IgnoreWriterError())
 	}
 	return nil
 }
 
-func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader *Reader) error {
+func (w *ServerWorker) handleStatusKeep(meta *FrameMetadata, reader io.Reader) error {
 	if !meta.Option.Has(OptionData) {
 		return nil
 	}
 	if s, found := w.sessionManager.Get(meta.SessionID); found {
-		return buf.Copy(reader, s.output, buf.IgnoreWriterError())
+		return buf.Copy(s.NewReader(reader), s.output, buf.IgnoreWriterError())
 	}
 	return drain(reader)
 }
 
-func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *Reader) error {
+func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader io.Reader) error {
 	if s, found := w.sessionManager.Get(meta.SessionID); found {
 		s.CloseUplink()
 		s.output.Close()
@@ -348,8 +355,9 @@ func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *Reader) erro
 	return nil
 }
 
-func (w *ServerWorker) handleFrame(ctx context.Context, reader *Reader) error {
-	meta, err := reader.ReadMetadata()
+func (w *ServerWorker) handleFrame(ctx context.Context, reader io.Reader) error {
+	metaReader := NewMetadataReader(reader)
+	meta, err := metaReader.Read()
 	if err != nil {
 		return newError("failed to read metadata").Base(err)
 	}
@@ -375,7 +383,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *Reader) error {
 
 func (w *ServerWorker) run(ctx context.Context) {
 	input := w.outboundRay.OutboundInput()
-	reader := NewReader(input)
+	reader := buf.ToBytesReader(input)
 
 	defer w.sessionManager.Close()
 

+ 37 - 17
app/proxyman/mux/mux_test.go

@@ -2,28 +2,45 @@ package mux_test
 
 import (
 	"context"
+	"io"
 	"testing"
 
 	. "v2ray.com/core/app/proxyman/mux"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/testing/assert"
 	"v2ray.com/core/transport/ray"
 )
 
+func readAll(reader buf.Reader) (buf.MultiBuffer, error) {
+	mb := buf.NewMultiBuffer()
+	for {
+		b, err := reader.Read()
+		if err == io.EOF {
+			break
+		}
+		if err != nil {
+			return nil, err
+		}
+		mb.AppendMulti(b)
+	}
+	return mb, nil
+}
+
 func TestReaderWriter(t *testing.T) {
 	assert := assert.On(t)
 
 	stream := ray.NewStream(context.Background())
 
 	dest := net.TCPDestination(net.DomainAddress("v2ray.com"), 80)
-	writer := NewWriter(1, dest, stream)
+	writer := NewWriter(1, dest, stream, protocol.TransferTypeStream)
 
 	dest2 := net.TCPDestination(net.LocalHostIP, 443)
-	writer2 := NewWriter(2, dest2, stream)
+	writer2 := NewWriter(2, dest2, stream, protocol.TransferTypeStream)
 
 	dest3 := net.TCPDestination(net.LocalHostIPv6, 18374)
-	writer3 := NewWriter(3, dest3, stream)
+	writer3 := NewWriter(3, dest3, stream, protocol.TransferTypeStream)
 
 	writePayload := func(writer *Writer, payload ...byte) error {
 		b := buf.New()
@@ -43,73 +60,76 @@ func TestReaderWriter(t *testing.T) {
 	assert.Error(writePayload(writer2, 'y')).IsNil()
 	writer2.Close()
 
-	reader := NewReader(stream)
-	meta, err := reader.ReadMetadata()
+	bytesReader := buf.ToBytesReader(stream)
+	metaReader := NewMetadataReader(bytesReader)
+	streamReader := NewStreamReader(bytesReader)
+
+	meta, err := metaReader.Read()
 	assert.Error(err).IsNil()
 	assert.Uint16(meta.SessionID).Equals(1)
 	assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew))
 	assert.Destination(meta.Target).Equals(dest)
 	assert.Byte(byte(meta.Option)).Equals(byte(OptionData))
 
-	data, err := reader.Read()
+	data, err := readAll(streamReader)
 	assert.Error(err).IsNil()
 	assert.Int(len(data)).Equals(1)
 	assert.String(data[0].String()).Equals("abcd")
 
-	meta, err = reader.ReadMetadata()
+	meta, err = metaReader.Read()
 	assert.Error(err).IsNil()
 	assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew))
 	assert.Uint16(meta.SessionID).Equals(2)
 	assert.Byte(byte(meta.Option)).Equals(0)
 	assert.Destination(meta.Target).Equals(dest2)
 
-	meta, err = reader.ReadMetadata()
+	meta, err = metaReader.Read()
 	assert.Error(err).IsNil()
 	assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusKeep))
 	assert.Uint16(meta.SessionID).Equals(1)
 	assert.Byte(byte(meta.Option)).Equals(1)
 
-	data, err = reader.Read()
+	data, err = readAll(streamReader)
 	assert.Error(err).IsNil()
 	assert.Int(len(data)).Equals(1)
 	assert.String(data[0].String()).Equals("efgh")
 
-	meta, err = reader.ReadMetadata()
+	meta, err = metaReader.Read()
 	assert.Error(err).IsNil()
 	assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusNew))
 	assert.Uint16(meta.SessionID).Equals(3)
 	assert.Byte(byte(meta.Option)).Equals(1)
 	assert.Destination(meta.Target).Equals(dest3)
 
-	data, err = reader.Read()
+	data, err = readAll(streamReader)
 	assert.Error(err).IsNil()
 	assert.Int(len(data)).Equals(1)
 	assert.String(data[0].String()).Equals("x")
 
-	meta, err = reader.ReadMetadata()
+	meta, err = metaReader.Read()
 	assert.Error(err).IsNil()
 	assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd))
 	assert.Uint16(meta.SessionID).Equals(1)
 	assert.Byte(byte(meta.Option)).Equals(0)
 
-	meta, err = reader.ReadMetadata()
+	meta, err = metaReader.Read()
 	assert.Error(err).IsNil()
 	assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd))
 	assert.Uint16(meta.SessionID).Equals(3)
 	assert.Byte(byte(meta.Option)).Equals(0)
 
-	meta, err = reader.ReadMetadata()
+	meta, err = metaReader.Read()
 	assert.Error(err).IsNil()
 	assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusKeep))
 	assert.Uint16(meta.SessionID).Equals(2)
 	assert.Byte(byte(meta.Option)).Equals(1)
 
-	data, err = reader.Read()
+	data, err = readAll(streamReader)
 	assert.Error(err).IsNil()
 	assert.Int(len(data)).Equals(1)
 	assert.String(data[0].String()).Equals("y")
 
-	meta, err = reader.ReadMetadata()
+	meta, err = metaReader.Read()
 	assert.Error(err).IsNil()
 	assert.Byte(byte(meta.SessionStatus)).Equals(byte(SessionStatusEnd))
 	assert.Uint16(meta.SessionID).Equals(2)
@@ -117,6 +137,6 @@ func TestReaderWriter(t *testing.T) {
 
 	stream.Close()
 
-	meta, err = reader.ReadMetadata()
+	meta, err = metaReader.Read()
 	assert.Error(err).IsNotNil()
 }

+ 63 - 28
app/proxyman/mux/reader.go

@@ -7,57 +7,93 @@ import (
 	"v2ray.com/core/common/serial"
 )
 
-type Reader struct {
-	reader   io.Reader
-	buffer   *buf.Buffer
-	leftOver int
+type MetadataReader struct {
+	reader io.Reader
+	buffer []byte
 }
 
-func NewReader(reader buf.Reader) *Reader {
-	return &Reader{
-		reader:   buf.ToBytesReader(reader),
-		buffer:   buf.NewLocal(1024),
-		leftOver: -1,
+func NewMetadataReader(reader io.Reader) *MetadataReader {
+	return &MetadataReader{
+		reader: reader,
+		buffer: make([]byte, 1024),
 	}
 }
 
-func (r *Reader) ReadMetadata() (*FrameMetadata, error) {
-	r.leftOver = -1
-
-	b := r.buffer
-	b.Clear()
-
-	if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, 2)); err != nil {
+func (r *MetadataReader) Read() (*FrameMetadata, error) {
+	metaLen, err := serial.ReadUint16(r.reader)
+	if err != nil {
 		return nil, err
 	}
-	metaLen := serial.BytesToUint16(b.Bytes())
 	if metaLen > 512 {
 		return nil, newError("invalid metalen ", metaLen).AtWarning()
 	}
-	b.Clear()
-	if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, int(metaLen))); err != nil {
+
+	if _, err := io.ReadFull(r.reader, r.buffer[:metaLen]); err != nil {
 		return nil, err
 	}
-	return ReadFrameFrom(b.Bytes())
+	return ReadFrameFrom(r.buffer)
+}
+
+type PacketReader struct {
+	reader io.Reader
+	eof    bool
 }
 
-func (r *Reader) readSize() error {
-	if err := r.buffer.Reset(buf.ReadFullFrom(r.reader, 2)); err != nil {
-		return err
+func NewPacketReader(reader io.Reader) *PacketReader {
+	return &PacketReader{
+		reader: reader,
+		eof:    false,
 	}
-	r.leftOver = int(serial.BytesToUint16(r.buffer.Bytes()))
-	return nil
 }
 
-func (r *Reader) Read() (buf.MultiBuffer, error) {
+func (r *PacketReader) Read() (buf.MultiBuffer, error) {
+	if r.eof {
+		return nil, io.EOF
+	}
+
+	size, err := serial.ReadUint16(r.reader)
+	if err != nil {
+		return nil, err
+	}
+
+	var b *buf.Buffer
+	if size <= buf.Size {
+		b = buf.New()
+	} else {
+		b = buf.NewLocal(int(size))
+	}
+	if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, int(size))); err != nil {
+		b.Release()
+		return nil, err
+	}
+	r.eof = true
+	return buf.NewMultiBufferValue(b), nil
+}
+
+type StreamReader struct {
+	reader   io.Reader
+	leftOver int
+}
+
+func NewStreamReader(reader io.Reader) *StreamReader {
+	return &StreamReader{
+		reader:   reader,
+		leftOver: -1,
+	}
+}
+
+func (r *StreamReader) Read() (buf.MultiBuffer, error) {
 	if r.leftOver == 0 {
 		r.leftOver = -1
 		return nil, io.EOF
 	}
+
 	if r.leftOver == -1 {
-		if err := r.readSize(); err != nil {
+		size, err := serial.ReadUint16(r.reader)
+		if err != nil {
 			return nil, err
 		}
+		r.leftOver = int(size)
 	}
 
 	mb := buf.NewMultiBuffer()
@@ -79,6 +115,5 @@ func (r *Reader) Read() (buf.MultiBuffer, error) {
 			break
 		}
 	}
-
 	return mb, nil
 }

+ 11 - 0
app/proxyman/mux/session.go

@@ -1,8 +1,11 @@
 package mux
 
 import (
+	"io"
 	"sync"
 
+	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/transport/ray"
 )
 
@@ -119,6 +122,7 @@ type Session struct {
 	ID             uint16
 	uplinkClosed   bool
 	downlinkClosed bool
+	transferType   protocol.TransferType
 }
 
 func (s *Session) CloseUplink() {
@@ -142,3 +146,10 @@ func (s *Session) CloseDownlink() {
 		s.parent.Remove(s.ID)
 	}
 }
+
+func (s *Session) NewReader(reader io.Reader) buf.Reader {
+	if s.transferType == protocol.TransferTypeStream {
+		return NewStreamReader(reader)
+	}
+	return NewPacketReader(reader)
+}

+ 31 - 18
app/proxyman/mux/writer.go

@@ -5,30 +5,34 @@ import (
 
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/serial"
 )
 
 type Writer struct {
-	id       uint16
-	dest     net.Destination
-	writer   buf.Writer
-	followup bool
+	id           uint16
+	dest         net.Destination
+	writer       buf.Writer
+	followup     bool
+	transferType protocol.TransferType
 }
 
-func NewWriter(id uint16, dest net.Destination, writer buf.Writer) *Writer {
+func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType) *Writer {
 	return &Writer{
-		id:       id,
-		dest:     dest,
-		writer:   writer,
-		followup: false,
+		id:           id,
+		dest:         dest,
+		writer:       writer,
+		followup:     false,
+		transferType: transferType,
 	}
 }
 
-func NewResponseWriter(id uint16, writer buf.Writer) *Writer {
+func NewResponseWriter(id uint16, writer buf.Writer, transferType protocol.TransferType) *Writer {
 	return &Writer{
-		id:       id,
-		writer:   writer,
-		followup: true,
+		id:           id,
+		writer:       writer,
+		followup:     true,
+		transferType: transferType,
 	}
 }
 
@@ -82,13 +86,22 @@ func (w *Writer) Write(mb buf.MultiBuffer) error {
 		return w.writeMetaOnly()
 	}
 
-	const chunkSize = 8 * 1024
-	for !mb.IsEmpty() {
-		slice := mb.SliceBySize(chunkSize)
-		if err := w.writeData(slice); err != nil {
-			return err
+	if w.transferType == protocol.TransferTypeStream {
+		const chunkSize = 8 * 1024
+		for !mb.IsEmpty() {
+			slice := mb.SliceBySize(chunkSize)
+			if err := w.writeData(slice); err != nil {
+				return err
+			}
+		}
+	} else {
+		for _, b := range mb {
+			if err := w.writeData(buf.NewMultiBufferValue(b)); err != nil {
+				return err
+			}
 		}
 	}
+
 	return nil
 }
 

+ 30 - 36
common/crypto/auth.go

@@ -5,6 +5,7 @@ import (
 	"io"
 
 	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/protocol"
 )
 
 type BytesGenerator interface {
@@ -60,34 +61,27 @@ func (v *AEADAuthenticator) Seal(dst, plainText []byte) ([]byte, error) {
 	return v.AEAD.Seal(dst, iv, plainText, additionalData), nil
 }
 
-type StreamMode int
-
-const (
-	ModeStream StreamMode = iota
-	ModePacket
-)
-
 type AuthenticationReader struct {
-	auth       Authenticator
-	buffer     *buf.Buffer
-	reader     io.Reader
-	sizeParser ChunkSizeDecoder
-	size       int
-	mode       StreamMode
+	auth         Authenticator
+	buffer       *buf.Buffer
+	reader       io.Reader
+	sizeParser   ChunkSizeDecoder
+	size         int
+	transferType protocol.TransferType
 }
 
 const (
 	readerBufferSize = 32 * 1024
 )
 
-func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, mode StreamMode) *AuthenticationReader {
+func NewAuthenticationReader(auth Authenticator, sizeParser ChunkSizeDecoder, reader io.Reader, transferType protocol.TransferType) *AuthenticationReader {
 	return &AuthenticationReader{
-		auth:       auth,
-		buffer:     buf.NewLocal(readerBufferSize),
-		reader:     reader,
-		sizeParser: sizeParser,
-		size:       -1,
-		mode:       mode,
+		auth:         auth,
+		buffer:       buf.NewLocal(readerBufferSize),
+		reader:       reader,
+		sizeParser:   sizeParser,
+		size:         -1,
+		transferType: transferType,
 	}
 }
 
@@ -153,7 +147,7 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
 	}
 
 	mb := buf.NewMultiBuffer()
-	if r.mode == ModeStream {
+	if r.transferType == protocol.TransferTypeStream {
 		mb.Write(b)
 	} else {
 		var bb *buf.Buffer
@@ -171,7 +165,7 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
 		if err != nil {
 			break
 		}
-		if r.mode == ModeStream {
+		if r.transferType == protocol.TransferTypeStream {
 			mb.Write(b)
 		} else {
 			var bb *buf.Buffer
@@ -189,22 +183,22 @@ func (r *AuthenticationReader) Read() (buf.MultiBuffer, error) {
 }
 
 type AuthenticationWriter struct {
-	auth       Authenticator
-	payload    []byte
-	buffer     *buf.Buffer
-	writer     io.Writer
-	sizeParser ChunkSizeEncoder
-	mode       StreamMode
+	auth         Authenticator
+	payload      []byte
+	buffer       *buf.Buffer
+	writer       io.Writer
+	sizeParser   ChunkSizeEncoder
+	transferType protocol.TransferType
 }
 
-func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, mode StreamMode) *AuthenticationWriter {
+func NewAuthenticationWriter(auth Authenticator, sizeParser ChunkSizeEncoder, writer io.Writer, transferType protocol.TransferType) *AuthenticationWriter {
 	return &AuthenticationWriter{
-		auth:       auth,
-		payload:    make([]byte, 1024),
-		buffer:     buf.NewLocal(readerBufferSize),
-		writer:     writer,
-		sizeParser: sizeParser,
-		mode:       mode,
+		auth:         auth,
+		payload:      make([]byte, 1024),
+		buffer:       buf.NewLocal(readerBufferSize),
+		writer:       writer,
+		sizeParser:   sizeParser,
+		transferType: transferType,
 	}
 }
 
@@ -279,7 +273,7 @@ func (w *AuthenticationWriter) writePacket(mb buf.MultiBuffer) error {
 }
 
 func (w *AuthenticationWriter) Write(mb buf.MultiBuffer) error {
-	if w.mode == ModeStream {
+	if w.transferType == protocol.TransferTypeStream {
 		return w.writeStream(mb)
 	}
 

+ 5 - 4
common/crypto/auth_test.go

@@ -9,6 +9,7 @@ import (
 
 	"v2ray.com/core/common/buf"
 	. "v2ray.com/core/common/crypto"
+	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/testing/assert"
 )
 
@@ -39,7 +40,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
 			Content: iv,
 		},
 		AdditionalDataGenerator: &NoOpBytesGenerator{},
-	}, PlainChunkSizeParser{}, cache, ModeStream)
+	}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream)
 
 	assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil()
 	assert.Int(cache.Len()).Equals(83360)
@@ -52,7 +53,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
 			Content: iv,
 		},
 		AdditionalDataGenerator: &NoOpBytesGenerator{},
-	}, PlainChunkSizeParser{}, cache, ModeStream)
+	}, PlainChunkSizeParser{}, cache, protocol.TransferTypeStream)
 
 	mb := buf.NewMultiBuffer()
 
@@ -92,7 +93,7 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
 			Content: iv,
 		},
 		AdditionalDataGenerator: &NoOpBytesGenerator{},
-	}, PlainChunkSizeParser{}, cache, ModePacket)
+	}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket)
 
 	payload := buf.NewMultiBuffer()
 	pb1 := buf.New()
@@ -114,7 +115,7 @@ func TestAuthenticationReaderWriterPacket(t *testing.T) {
 			Content: iv,
 		},
 		AdditionalDataGenerator: &NoOpBytesGenerator{},
-	}, PlainChunkSizeParser{}, cache, ModePacket)
+	}, PlainChunkSizeParser{}, cache, protocol.TransferTypePacket)
 
 	mb, err := reader.Read()
 	assert.Error(err).IsNil()

+ 8 - 0
common/protocol/headers.go

@@ -15,6 +15,14 @@ const (
 	RequestCommandUDP = RequestCommand(0x02)
 )
 
+func (c RequestCommand) TransferType() TransferType {
+	if c == RequestCommandTCP {
+		return TransferTypeStream
+	}
+
+	return TransferTypePacket
+}
+
 // RequestOption is the options of a request.
 type RequestOption byte
 

+ 8 - 0
common/protocol/payload.go

@@ -0,0 +1,8 @@
+package protocol
+
+type TransferType int
+
+const (
+	TransferTypeStream TransferType = 0
+	TransferTypePacket TransferType = 1
+)

+ 9 - 0
common/serial/numbers.go

@@ -1,6 +1,7 @@
 package serial
 
 import "strconv"
+import "io"
 
 // Uint16ToBytes serializes an uint16 into bytes in big endian order.
 func Uint16ToBytes(value uint16, b []byte) []byte {
@@ -11,6 +12,14 @@ func Uint16ToString(value uint16) string {
 	return strconv.Itoa(int(value))
 }
 
+func ReadUint16(reader io.Reader) (uint16, error) {
+	var b [2]byte
+	if _, err := io.ReadFull(reader, b[:]); err != nil {
+		return 0, err
+	}
+	return BytesToUint16(b[:]), nil
+}
+
 func WriteUint16(value uint16) func([]byte) (int, error) {
 	return func(b []byte) (int, error) {
 		b = Uint16ToBytes(value, b[:0])

+ 0 - 10
proxy/vmess/encoding/auth.go

@@ -6,8 +6,6 @@ import (
 
 	"golang.org/x/crypto/sha3"
 
-	"v2ray.com/core/common/crypto"
-	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/serial"
 )
 
@@ -108,11 +106,3 @@ func (s *ShakeSizeParser) Encode(size uint16, b []byte) []byte {
 	mask := s.next()
 	return serial.Uint16ToBytes(mask^size, b[:0])
 }
-
-func GetStreamMode(request *protocol.RequestHeader) crypto.StreamMode {
-	if request.Command == protocol.RequestCommandTCP {
-		return crypto.ModeStream
-	}
-
-	return crypto.ModePacket
-}

+ 8 - 8
proxy/vmess/encoding/client.go

@@ -131,7 +131,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 				NonceGenerator:          crypto.NoOpBytesGenerator{},
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
-			return crypto.NewAuthenticationWriter(auth, sizeParser, writer, crypto.ModePacket)
+			return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket)
 		}
 
 		return buf.NewWriter(writer)
@@ -146,7 +146,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 				NonceGenerator:          crypto.NoOpBytesGenerator{},
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
-			return crypto.NewAuthenticationWriter(auth, sizeParser, cryptionWriter, GetStreamMode(request))
+			return crypto.NewAuthenticationWriter(auth, sizeParser, cryptionWriter, request.Command.TransferType())
 		}
 
 		return buf.NewWriter(cryptionWriter)
@@ -164,7 +164,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request))
+		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType())
 	}
 
 	if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) {
@@ -178,7 +178,7 @@ func (v *ClientSession) EncodeRequestBody(request *protocol.RequestHeader, write
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request))
+		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType())
 	}
 
 	panic("Unknown security type.")
@@ -239,7 +239,7 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
 
-			return crypto.NewAuthenticationReader(auth, sizeParser, reader, crypto.ModePacket)
+			return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket)
 		}
 
 		return buf.NewReader(reader)
@@ -252,7 +252,7 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read
 				NonceGenerator:          crypto.NoOpBytesGenerator{},
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
-			return crypto.NewAuthenticationReader(auth, sizeParser, v.responseReader, GetStreamMode(request))
+			return crypto.NewAuthenticationReader(auth, sizeParser, v.responseReader, request.Command.TransferType())
 		}
 
 		return buf.NewReader(v.responseReader)
@@ -270,7 +270,7 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request))
+		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType())
 	}
 
 	if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) {
@@ -284,7 +284,7 @@ func (v *ClientSession) DecodeResponseBody(request *protocol.RequestHeader, read
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request))
+		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType())
 	}
 
 	panic("Unknown security type.")

+ 8 - 8
proxy/vmess/encoding/server.go

@@ -249,7 +249,7 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
 				NonceGenerator:          crypto.NoOpBytesGenerator{},
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
-			return crypto.NewAuthenticationReader(auth, sizeParser, reader, crypto.ModePacket)
+			return crypto.NewAuthenticationReader(auth, sizeParser, reader, protocol.TransferTypePacket)
 		}
 
 		return buf.NewReader(reader)
@@ -264,7 +264,7 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
 				NonceGenerator:          crypto.NoOpBytesGenerator{},
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
-			return crypto.NewAuthenticationReader(auth, sizeParser, cryptionReader, GetStreamMode(request))
+			return crypto.NewAuthenticationReader(auth, sizeParser, cryptionReader, request.Command.TransferType())
 		}
 
 		return buf.NewReader(cryptionReader)
@@ -282,7 +282,7 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request))
+		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType())
 	}
 
 	if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) {
@@ -296,7 +296,7 @@ func (v *ServerSession) DecodeRequestBody(request *protocol.RequestHeader, reade
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationReader(auth, sizeParser, reader, GetStreamMode(request))
+		return crypto.NewAuthenticationReader(auth, sizeParser, reader, request.Command.TransferType())
 	}
 
 	panic("Unknown security type.")
@@ -335,7 +335,7 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
 				NonceGenerator:          &crypto.NoOpBytesGenerator{},
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
-			return crypto.NewAuthenticationWriter(auth, sizeParser, writer, crypto.ModePacket)
+			return crypto.NewAuthenticationWriter(auth, sizeParser, writer, protocol.TransferTypePacket)
 		}
 
 		return buf.NewWriter(writer)
@@ -348,7 +348,7 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
 				NonceGenerator:          crypto.NoOpBytesGenerator{},
 				AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 			}
-			return crypto.NewAuthenticationWriter(auth, sizeParser, v.responseWriter, GetStreamMode(request))
+			return crypto.NewAuthenticationWriter(auth, sizeParser, v.responseWriter, request.Command.TransferType())
 		}
 
 		return buf.NewWriter(v.responseWriter)
@@ -366,7 +366,7 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request))
+		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType())
 	}
 
 	if request.Security.Is(protocol.SecurityType_CHACHA20_POLY1305) {
@@ -380,7 +380,7 @@ func (v *ServerSession) EncodeResponseBody(request *protocol.RequestHeader, writ
 			},
 			AdditionalDataGenerator: crypto.NoOpBytesGenerator{},
 		}
-		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, GetStreamMode(request))
+		return crypto.NewAuthenticationWriter(auth, sizeParser, writer, request.Command.TransferType())
 	}
 
 	panic("Unknown security type.")