Browse Source

multi buffer

Darien Raymond 8 năm trước cách đây
mục cha
commit
f506a39d32

+ 9 - 20
app/proxyman/mux/mux.go

@@ -180,31 +180,20 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool
 }
 
 func drain(reader *Reader) error {
-	for {
-		data, more, err := reader.Read()
-		if err != nil {
-			return err
-		}
-		data.Release()
-		if !more {
-			return nil
-		}
+	data, err := reader.Read()
+	if err != nil {
+		return err
 	}
+	data.Release()
+	return nil
 }
 
 func pipe(reader *Reader, writer buf.Writer) error {
-	for {
-		data, more, err := reader.Read()
-		if err != nil {
-			return err
-		}
-		if err := writer.Write(data); err != nil {
-			return err
-		}
-		if !more {
-			return nil
-		}
+	data, err := reader.Read()
+	if err != nil {
+		return err
 	}
+	return writer.Write(data)
 }
 
 func (m *Client) handleStatueKeepAlive(meta *FrameMetadata, reader *Reader) error {

+ 3 - 4
app/proxyman/mux/mux_test.go

@@ -20,7 +20,7 @@ func TestReaderWriter(t *testing.T) {
 
 	payload := buf.New()
 	payload.AppendBytes('a', 'b', 'c', 'd')
-	assert.Error(writer.Write(payload)).IsNil()
+	assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil()
 
 	writer.Close()
 
@@ -32,10 +32,9 @@ func TestReaderWriter(t *testing.T) {
 	assert.Destination(meta.Target).Equals(dest)
 	assert.Byte(byte(meta.Option)).Equals(byte(OptionData))
 
-	data, more, err := reader.Read()
+	data, err := reader.Read()
 	assert.Error(err).IsNil()
-	assert.Bool(more).IsFalse()
-	assert.String(data.String()).Equals("abcd")
+	assert.String(data[0].String()).Equals("abcd")
 
 	meta, err = reader.ReadMetadata()
 	assert.Error(err).IsNil()

+ 21 - 23
app/proxyman/mux/reader.go

@@ -8,9 +8,8 @@ import (
 )
 
 type Reader struct {
-	reader          io.Reader
-	remainingLength int
-	buffer          *buf.Buffer
+	reader io.Reader
+	buffer *buf.Buffer
 }
 
 func NewReader(reader buf.Reader) *Reader {
@@ -38,28 +37,27 @@ func (r *Reader) ReadMetadata() (*FrameMetadata, error) {
 	return ReadFrameFrom(b.Bytes())
 }
 
-func (r *Reader) Read() (*buf.Buffer, bool, error) {
-	b := buf.New()
-	var dataLen int
-	if r.remainingLength > 0 {
-		dataLen = r.remainingLength
-		r.remainingLength = 0
-	} else {
-		if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, 2)); err != nil {
-			return nil, false, err
-		}
-		dataLen = int(serial.BytesToUint16(b.Bytes()))
-		b.Clear()
-	}
-
-	if dataLen > buf.Size {
-		r.remainingLength = dataLen - buf.Size
-		dataLen = buf.Size
+func (r *Reader) Read() (buf.MultiBuffer, error) {
+	r.buffer.Clear()
+	if err := r.buffer.AppendSupplier(buf.ReadFullFrom(r.reader, 2)); err != nil {
+		return nil, err
 	}
 
-	if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, dataLen)); err != nil {
-		return nil, false, err
+	dataLen := int(serial.BytesToUint16(r.buffer.Bytes()))
+	mb := buf.NewMultiBuffer()
+	for dataLen > 0 {
+		b := buf.New()
+		readLen := buf.Size
+		if dataLen < readLen {
+			readLen = dataLen
+		}
+		if err := b.AppendSupplier(buf.ReadFullFrom(r.reader, readLen)); err != nil {
+			mb.Release()
+			return nil, err
+		}
+		dataLen -= readLen
+		mb.Append(b)
 	}
 
-	return b, (r.remainingLength > 0), nil
+	return mb, nil
 }

+ 12 - 30
app/proxyman/mux/writer.go

@@ -29,7 +29,7 @@ func NewResponseWriter(id uint16, writer buf.Writer) *Writer {
 	}
 }
 
-func (w *Writer) writeInternal(b *buf.Buffer) error {
+func (w *Writer) Write(mb buf.MultiBuffer) error {
 	meta := FrameMetadata{
 		SessionID: w.id,
 		Target:    w.dest,
@@ -41,42 +41,21 @@ func (w *Writer) writeInternal(b *buf.Buffer) error {
 		meta.SessionStatus = SessionStatusNew
 	}
 
-	if b.Len() > 0 {
+	if mb.Len() > 0 {
 		meta.Option.Add(OptionData)
 	}
 
 	frame := buf.New()
 	frame.AppendSupplier(meta.AsSupplier())
 
-	if b.Len() > 0 {
-		frame.AppendSupplier(serial.WriteUint16(0))
-		lengthBytes := frame.BytesFrom(-2)
+	mb2 := buf.NewMultiBuffer()
+	mb2.Append(frame)
 
-		nBytes, err := frame.Write(b.Bytes())
-		if err != nil {
-			frame.Release()
-			return err
-		}
-
-		serial.Uint16ToBytes(uint16(nBytes), lengthBytes[:0])
-		b.SliceFrom(nBytes)
-	}
-
-	return w.writer.Write(frame)
-}
-
-func (w *Writer) Write(b *buf.Buffer) error {
-	defer b.Release()
-
-	if err := w.writeInternal(b); err != nil {
-		return err
+	if mb.Len() > 0 {
+		frame.AppendSupplier(serial.WriteUint16(uint16(mb.Len())))
+		mb2.AppendMulti(mb)
 	}
-	for !b.IsEmpty() {
-		if err := w.writeInternal(b); err != nil {
-			return err
-		}
-	}
-	return nil
+	return w.writer.Write(mb2)
 }
 
 func (w *Writer) Close() {
@@ -88,5 +67,8 @@ func (w *Writer) Close() {
 	frame := buf.New()
 	frame.AppendSupplier(meta.AsSupplier())
 
-	w.writer.Write(frame)
+	mb := buf.NewMultiBuffer()
+	mb.Append(frame)
+
+	w.writer.Write(mb)
 }

+ 10 - 3
app/proxyman/outbound/handler.go

@@ -129,7 +129,7 @@ type Connection struct {
 	remoteAddr net.Addr
 
 	reader io.Reader
-	writer io.Writer
+	writer buf.Writer
 }
 
 func NewConnection(stream ray.Ray) *Connection {
@@ -144,7 +144,7 @@ func NewConnection(stream ray.Ray) *Connection {
 			Port: 0,
 		},
 		reader: buf.ToBytesReader(stream.InboundOutput()),
-		writer: buf.ToBytesWriter(stream.InboundInput()),
+		writer: stream.InboundInput(),
 	}
 }
 
@@ -161,7 +161,14 @@ func (v *Connection) Write(b []byte) (int, error) {
 	if v.closed {
 		return 0, io.ErrClosedPipe
 	}
-	return v.writer.Write(b)
+	return buf.ToBytesWriter(v.writer).Write(b)
+}
+
+func (v *Connection) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) {
+	if v.closed {
+		return 0, io.ErrClosedPipe
+	}
+	return mb.Len(), v.writer.Write(mb)
 }
 
 // Close implements net.Conn.Close().

+ 4 - 3
common/buf/io.go

@@ -11,19 +11,19 @@ import (
 // Reader extends io.Reader with alloc.Buffer.
 type Reader interface {
 	// Read reads content from underlying reader, and put it into an alloc.Buffer.
-	Read() (*Buffer, error)
+	Read() (MultiBuffer, error)
 }
 
 var ErrReadTimeout = newError("IO timeout")
 
 type TimeoutReader interface {
-	ReadTimeout(time.Duration) (*Buffer, error)
+	ReadTimeout(time.Duration) (MultiBuffer, error)
 }
 
 // Writer extends io.Writer with alloc.Buffer.
 type Writer interface {
 	// Write writes an alloc.Buffer into underlying writer.
-	Write(*Buffer) error
+	Write(MultiBuffer) error
 }
 
 // ReadFrom creates a Supplier to read from a given io.Reader.
@@ -78,6 +78,7 @@ func PipeUntilEOF(timer signal.ActivityTimer, reader Reader, writer Writer) erro
 func NewReader(reader io.Reader) Reader {
 	return &BytesToBufferReader{
 		reader: reader,
+		buffer: NewLocal(32 * 1024),
 	}
 }
 

+ 6 - 25
common/buf/merge_reader.go

@@ -3,7 +3,6 @@ package buf
 type MergingReader struct {
 	reader        Reader
 	timeoutReader TimeoutReader
-	leftover      *Buffer
 }
 
 func NewMergingReader(reader Reader) Reader {
@@ -13,41 +12,23 @@ func NewMergingReader(reader Reader) Reader {
 	}
 }
 
-func (r *MergingReader) Read() (*Buffer, error) {
-	if r.leftover != nil {
-		b := r.leftover
-		r.leftover = nil
-		return b, nil
-	}
-
-	b, err := r.reader.Read()
+func (r *MergingReader) Read() (MultiBuffer, error) {
+	mb, err := r.reader.Read()
 	if err != nil {
 		return nil, err
 	}
 
-	if b.IsFull() {
-		return b, nil
-	}
-
 	if r.timeoutReader == nil {
-		return b, nil
+		return mb, nil
 	}
 
 	for {
-		b2, err := r.timeoutReader.ReadTimeout(0)
+		mb2, err := r.timeoutReader.ReadTimeout(0)
 		if err != nil {
 			break
 		}
-
-		nBytes := b.Append(b2.Bytes())
-		b2.SliceFrom(nBytes)
-		if b2.IsEmpty() {
-			b2.Release()
-		} else {
-			r.leftover = b2
-			break
-		}
+		mb.AppendMulti(mb2)
 	}
 
-	return b, nil
+	return mb, nil
 }

+ 4 - 4
common/buf/merge_reader_test.go

@@ -16,18 +16,18 @@ func TestMergingReader(t *testing.T) {
 	stream := ray.NewStream(context.Background())
 	b1 := New()
 	b1.AppendBytes('a', 'b', 'c')
-	stream.Write(b1)
+	stream.Write(NewMultiBufferValue(b1))
 
 	b2 := New()
 	b2.AppendBytes('e', 'f', 'g')
-	stream.Write(b2)
+	stream.Write(NewMultiBufferValue(b2))
 
 	b3 := New()
 	b3.AppendBytes('h', 'i', 'j')
-	stream.Write(b3)
+	stream.Write(NewMultiBufferValue(b3))
 
 	reader := NewMergingReader(stream)
 	b, err := reader.Read()
 	assert.Error(err).IsNil()
-	assert.String(b.String()).Equals("abcefghij")
+	assert.Int(b.Len()).Equals(9)
 }

+ 88 - 0
common/buf/multi_buffer.go

@@ -0,0 +1,88 @@
+package buf
+
+import (
+	"io"
+	"net"
+)
+
+type MultiBufferWriter interface {
+	WriteMultiBuffer(MultiBuffer) (int, error)
+}
+
+type MultiBuffer []*Buffer
+
+func NewMultiBuffer() MultiBuffer {
+	return MultiBuffer(make([]*Buffer, 0, 8))
+}
+
+func NewMultiBufferValue(b ...*Buffer) MultiBuffer {
+	return MultiBuffer(b)
+}
+
+func (b *MultiBuffer) Append(buf *Buffer) {
+	*b = append(*b, buf)
+}
+
+func (b *MultiBuffer) AppendMulti(mb MultiBuffer) {
+	*b = append(*b, mb...)
+}
+
+func (mb *MultiBuffer) Read(b []byte) (int, error) {
+	if len(*mb) == 0 {
+		return 0, io.EOF
+	}
+	endIndex := len(*mb)
+	totalBytes := 0
+	for i, bb := range *mb {
+		nBytes, err := bb.Read(b)
+		totalBytes += nBytes
+		if err != nil {
+			return totalBytes, err
+		}
+		b = b[nBytes:]
+		if bb.IsEmpty() {
+			bb.Release()
+		} else {
+			endIndex = i
+			break
+		}
+	}
+	*mb = (*mb)[endIndex:]
+	return totalBytes, nil
+}
+
+func (mb MultiBuffer) WriteTo(writer io.Writer) (int, error) {
+	if mw, ok := writer.(MultiBufferWriter); ok {
+		return mw.WriteMultiBuffer(mb)
+	}
+	bs := make([][]byte, len(mb))
+	for i, b := range mb {
+		bs[i] = b.Bytes()
+	}
+	nbs := net.Buffers(bs)
+	nBytes, err := nbs.WriteTo(writer)
+	return int(nBytes), err
+}
+
+func (mb MultiBuffer) Len() int {
+	size := 0
+	for _, b := range mb {
+		size += b.Len()
+	}
+	return size
+}
+
+func (mb MultiBuffer) IsEmpty() bool {
+	for _, b := range mb {
+		if !b.IsEmpty() {
+			return false
+		}
+	}
+	return true
+}
+
+func (mb MultiBuffer) Release() {
+	for _, b := range mb {
+		b.Release()
+	}
+}

+ 25 - 0
common/buf/multi_buffer_test.go

@@ -0,0 +1,25 @@
+package buf_test
+
+import (
+	"testing"
+
+	. "v2ray.com/core/common/buf"
+	"v2ray.com/core/testing/assert"
+)
+
+func TestMultiBufferRead(t *testing.T) {
+	assert := assert.On(t)
+
+	b1 := New()
+	b1.AppendBytes('a', 'b')
+
+	b2 := New()
+	b2.AppendBytes('c', 'd')
+	mb := NewMultiBufferValue(b1, b2)
+
+	bs := make([]byte, 32)
+	nBytes, err := mb.Read(bs)
+	assert.Error(err).IsNil()
+	assert.Int(nBytes).Equals(4)
+	assert.Bytes(bs[:nBytes]).Equals([]byte("abcd"))
+}

+ 11 - 31
common/buf/reader.go

@@ -4,48 +4,28 @@ import "io"
 
 // BytesToBufferReader is a Reader that adjusts its reading speed automatically.
 type BytesToBufferReader struct {
-	reader      io.Reader
-	largeBuffer *Buffer
-	highVolumn  bool
+	reader io.Reader
+	buffer *Buffer
 }
 
 // Read implements Reader.Read().
-func (v *BytesToBufferReader) Read() (*Buffer, error) {
-	if v.highVolumn && v.largeBuffer.IsEmpty() {
-		if v.largeBuffer == nil {
-			v.largeBuffer = NewLocal(32 * 1024)
-		}
-		err := v.largeBuffer.AppendSupplier(ReadFrom(v.reader))
-		if err != nil {
-			return nil, err
-		}
-		if v.largeBuffer.Len() < Size {
-			v.highVolumn = false
-		}
-	}
-
-	buffer := New()
-	if !v.largeBuffer.IsEmpty() {
-		err := buffer.AppendSupplier(ReadFrom(v.largeBuffer))
-		return buffer, err
-	}
-
-	err := buffer.AppendSupplier(ReadFrom(v.reader))
-	if err != nil {
-		buffer.Release()
+func (v *BytesToBufferReader) Read() (MultiBuffer, error) {
+	if err := v.buffer.Reset(ReadFrom(v.reader)); err != nil {
 		return nil, err
 	}
 
-	if buffer.IsFull() {
-		v.highVolumn = true
+	mb := NewMultiBuffer()
+	for !v.buffer.IsEmpty() {
+		b := New()
+		b.AppendSupplier(ReadFrom(v.buffer))
+		mb.Append(b)
 	}
-
-	return buffer, nil
+	return mb, nil
 }
 
 type bufferToBytesReader struct {
 	stream  Reader
-	current *Buffer
+	current MultiBuffer
 	err     error
 }
 

+ 2 - 9
common/buf/reader_test.go

@@ -15,14 +15,7 @@ func TestAdaptiveReader(t *testing.T) {
 	buffer := bytes.NewBuffer(rawContent)
 
 	reader := NewReader(buffer)
-	b1, err := reader.Read()
+	b, err := reader.Read()
 	assert.Error(err).IsNil()
-	assert.Bool(b1.IsFull()).IsTrue()
-	assert.Int(b1.Len()).Equals(Size)
-	assert.Int(buffer.Len()).Equals(cap(rawContent) - Size)
-
-	b2, err := reader.Read()
-	assert.Error(err).IsNil()
-	assert.Bool(b2.IsFull()).IsTrue()
-	assert.Int(buffer.Len()).Equals(1007616)
+	assert.Int(b.Len()).Equals(32 * 1024)
 }

+ 18 - 27
common/buf/writer.go

@@ -8,39 +8,30 @@ type BufferToBytesWriter struct {
 }
 
 // Write implements Writer.Write(). Write() takes ownership of the given buffer.
-func (v *BufferToBytesWriter) Write(buffer *Buffer) error {
-	defer buffer.Release()
-	for {
-		nBytes, err := v.writer.Write(buffer.Bytes())
-		if err != nil {
-			return err
-		}
-		if nBytes == buffer.Len() {
-			break
-		}
-		buffer.SliceFrom(nBytes)
-	}
-	return nil
+func (v *BufferToBytesWriter) Write(buffer MultiBuffer) error {
+	_, err := buffer.WriteTo(v.writer)
+	//buffer.Release()
+	return err
 }
 
 type bytesToBufferWriter struct {
 	writer Writer
 }
 
-func (v *bytesToBufferWriter) Write(payload []byte) (int, error) {
-	bytesWritten := 0
-	size := len(payload)
-	for size > 0 {
-		buffer := New()
-		nBytes, _ := buffer.Write(payload)
-		size -= nBytes
-		payload = payload[nBytes:]
-		bytesWritten += nBytes
-		err := v.writer.Write(buffer)
-		if err != nil {
-			return bytesWritten, err
-		}
+func (w *bytesToBufferWriter) Write(payload []byte) (int, error) {
+	mb := NewMultiBuffer()
+	for p := payload; len(p) > 0; {
+		b := New()
+		nBytes, _ := b.Write(p)
+		p = p[nBytes:]
+		mb.Append(b)
+	}
+	if err := w.writer.Write(mb); err != nil {
+		return 0, err
 	}
+	return len(payload), nil
+}
 
-	return bytesWritten, nil
+func (w *bytesToBufferWriter) WriteMulteBuffer(mb MultiBuffer) (int, error) {
+	return mb.Len(), w.writer.Write(mb)
 }

+ 1 - 1
common/buf/writer_test.go

@@ -20,7 +20,7 @@ func TestWriter(t *testing.T) {
 	writeBuffer := bytes.NewBuffer(make([]byte, 0, 1024*1024))
 
 	writer := NewWriter(NewBufferedWriter(writeBuffer))
-	err := writer.Write(lb)
+	err := writer.Write(NewMultiBufferValue(lb))
 	assert.Error(err).IsNil()
 	assert.Bytes(expectedBytes).Equals(writeBuffer.Bytes())
 }

+ 25 - 5
common/crypto/auth.go

@@ -215,14 +215,34 @@ func NewAuthenticationWriter(auth Authenticator, writer io.Writer, sizeMask Uint
 	}
 }
 
-func (v *AuthenticationWriter) Write(b []byte) (int, error) {
-	cipherChunk, err := v.auth.Seal(v.buffer[2:2], b)
+func (w *AuthenticationWriter) Write(b []byte) (int, error) {
+	cipherChunk, err := w.auth.Seal(w.buffer[2:2], b)
 	if err != nil {
 		return 0, err
 	}
 
-	size := uint16(len(cipherChunk)) ^ v.sizeMask.Next()
-	serial.Uint16ToBytes(size, v.buffer[:0])
-	_, err = v.writer.Write(v.buffer[:2+len(cipherChunk)])
+	size := uint16(len(cipherChunk)) ^ w.sizeMask.Next()
+	serial.Uint16ToBytes(size, w.buffer[:0])
+	_, err = w.writer.Write(w.buffer[:2+len(cipherChunk)])
 	return len(b), err
 }
+
+func (w *AuthenticationWriter) WriteMultiBuffer(mb buf.MultiBuffer) (int, error) {
+	const StartIndex = 17 * 1024
+	var totalBytes int
+	for {
+		payloadLen, err := mb.Read(w.buffer[StartIndex:])
+		if err != nil {
+			return 0, err
+		}
+		nBytes, err := w.Write(w.buffer[StartIndex : StartIndex+payloadLen])
+		totalBytes += nBytes
+		if err != nil {
+			return totalBytes, err
+		}
+		if mb.IsEmpty() {
+			break
+		}
+	}
+	return totalBytes, nil
+}

+ 3 - 1
proxy/blackhole/config.go

@@ -28,7 +28,9 @@ func (v *NoneResponse) WriteTo(buf.Writer) {}
 func (v *HTTPResponse) WriteTo(writer buf.Writer) {
 	b := buf.NewLocal(512)
 	b.AppendSupplier(serial.WriteString(http403response))
-	writer.Write(b)
+	mb := buf.NewMultiBuffer()
+	mb.Append(b)
+	writer.Write(mb)
 }
 
 // GetInternalResponse converts response settings from proto to internal data structure.

+ 21 - 7
proxy/shadowsocks/ota.go

@@ -68,7 +68,7 @@ func NewChunkReader(reader io.Reader, auth *Authenticator) *ChunkReader {
 	}
 }
 
-func (v *ChunkReader) Read() (*buf.Buffer, error) {
+func (v *ChunkReader) Read() (buf.MultiBuffer, error) {
 	buffer := buf.New()
 	if err := buffer.AppendSupplier(buf.ReadFullFrom(v.reader, 2)); err != nil {
 		buffer.Release()
@@ -100,7 +100,10 @@ func (v *ChunkReader) Read() (*buf.Buffer, error) {
 	}
 	buffer.SliceFrom(AuthSize)
 
-	return buffer, nil
+	mb := buf.NewMultiBuffer()
+	mb.Append(buffer)
+
+	return mb, nil
 }
 
 type ChunkWriter struct {
@@ -117,11 +120,22 @@ func NewChunkWriter(writer io.Writer, auth *Authenticator) *ChunkWriter {
 	}
 }
 
-func (v *ChunkWriter) Write(payload *buf.Buffer) error {
+func (w *ChunkWriter) Write(mb buf.MultiBuffer) error {
+	defer mb.Release()
+
+	for _, b := range mb {
+		if err := w.writeInternal(b); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (w *ChunkWriter) writeInternal(payload *buf.Buffer) error {
 	totalLength := payload.Len()
-	serial.Uint16ToBytes(uint16(totalLength), v.buffer[:0])
-	v.auth.Authenticate(payload.Bytes())(v.buffer[2:])
-	copy(v.buffer[2+AuthSize:], payload.Bytes())
-	_, err := v.writer.Write(v.buffer[:2+AuthSize+payload.Len()])
+	serial.Uint16ToBytes(uint16(totalLength), w.buffer[:0])
+	w.auth.Authenticate(payload.Bytes())(w.buffer[2:])
+	copy(w.buffer[2+AuthSize:], payload.Bytes())
+	_, err := w.writer.Write(w.buffer[:2+AuthSize+payload.Len()])
 	return err
 }

+ 2 - 2
proxy/shadowsocks/ota_test.go

@@ -18,7 +18,7 @@ func TestNormalChunkReading(t *testing.T) {
 		[]byte{21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36})))
 	payload, err := reader.Read()
 	assert.Error(err).IsNil()
-	assert.Bytes(payload.Bytes()).Equals([]byte{11, 12, 13, 14, 15, 16, 17, 18})
+	assert.Bytes(payload[0].Bytes()).Equals([]byte{11, 12, 13, 14, 15, 16, 17, 18})
 }
 
 func TestNormalChunkWriting(t *testing.T) {
@@ -30,7 +30,7 @@ func TestNormalChunkWriting(t *testing.T) {
 
 	b := buf.NewLocal(256)
 	b.Append([]byte{11, 12, 13, 14, 15, 16, 17, 18})
-	err := writer.Write(b)
+	err := writer.Write(buf.NewMultiBufferValue(b))
 	assert.Error(err).IsNil()
 	assert.Bytes(buffer.Bytes()).Equals([]byte{0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18})
 }

+ 16 - 5
proxy/shadowsocks/protocol.go

@@ -362,7 +362,7 @@ type UDPReader struct {
 	User   *protocol.User
 }
 
-func (v *UDPReader) Read() (*buf.Buffer, error) {
+func (v *UDPReader) Read() (buf.MultiBuffer, error) {
 	buffer := buf.NewSmall()
 	err := buffer.AppendSupplier(buf.ReadFrom(v.Reader))
 	if err != nil {
@@ -374,7 +374,9 @@ func (v *UDPReader) Read() (*buf.Buffer, error) {
 		buffer.Release()
 		return nil, err
 	}
-	return payload, nil
+	mb := buf.NewMultiBuffer()
+	mb.Append(payload)
+	return mb, nil
 }
 
 type UDPWriter struct {
@@ -382,12 +384,21 @@ type UDPWriter struct {
 	Request *protocol.RequestHeader
 }
 
-func (v *UDPWriter) Write(buffer *buf.Buffer) error {
-	payload, err := EncodeUDPPacket(v.Request, buffer)
+func (w *UDPWriter) Write(mb buf.MultiBuffer) error {
+	for _, b := range mb {
+		if err := w.writeInternal(b); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (w *UDPWriter) writeInternal(buffer *buf.Buffer) error {
+	payload, err := EncodeUDPPacket(w.Request, buffer)
 	if err != nil {
 		return err
 	}
-	_, err = v.Writer.Write(payload.Bytes())
+	_, err = w.Writer.Write(payload.Bytes())
 	payload.Release()
 	return err
 }

+ 6 - 6
proxy/shadowsocks/protocol_test.go

@@ -66,7 +66,7 @@ func TestTCPRequest(t *testing.T) {
 	writer, err := WriteTCPRequest(request, cache)
 	assert.Error(err).IsNil()
 
-	writer.Write(data)
+	writer.Write(buf.NewMultiBufferValue(data))
 
 	decodedRequest, reader, err := ReadTCPSession(request.User, cache)
 	assert.Error(err).IsNil()
@@ -75,7 +75,7 @@ func TestTCPRequest(t *testing.T) {
 
 	decodedData, err := reader.Read()
 	assert.Error(err).IsNil()
-	assert.String(decodedData.String()).Equals("test string")
+	assert.String(decodedData[0].String()).Equals("test string")
 }
 
 func TestUDPReaderWriter(t *testing.T) {
@@ -106,19 +106,19 @@ func TestUDPReaderWriter(t *testing.T) {
 
 	b := buf.New()
 	b.AppendSupplier(serial.WriteString("test payload"))
-	err := writer.Write(b)
+	err := writer.Write(buf.NewMultiBufferValue(b))
 	assert.Error(err).IsNil()
 
 	payload, err := reader.Read()
 	assert.Error(err).IsNil()
-	assert.String(payload.String()).Equals("test payload")
+	assert.String(payload[0].String()).Equals("test payload")
 
 	b = buf.New()
 	b.AppendSupplier(serial.WriteString("test payload 2"))
-	err = writer.Write(b)
+	err = writer.Write(buf.NewMultiBufferValue(b))
 	assert.Error(err).IsNil()
 
 	payload, err = reader.Read()
 	assert.Error(err).IsNil()
-	assert.String(payload.String()).Equals("test payload 2")
+	assert.String(payload[0].String()).Equals("test payload 2")
 }

+ 37 - 35
proxy/shadowsocks/server.go

@@ -75,52 +75,54 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 
 	reader := buf.NewReader(conn)
 	for {
-		payload, err := reader.Read()
+		mpayload, err := reader.Read()
 		if err != nil {
 			break
 		}
 
-		request, data, err := DecodeUDPPacket(v.user, payload)
-		if err != nil {
-			if source, ok := proxy.SourceFromContext(ctx); ok {
-				log.Trace(newError("dropping invalid UDP packet from: ", source).Base(err))
-				log.Access(source, "", log.AccessRejected, err)
+		for _, payload := range mpayload {
+			request, data, err := DecodeUDPPacket(v.user, payload)
+			if err != nil {
+				if source, ok := proxy.SourceFromContext(ctx); ok {
+					log.Trace(newError("dropping invalid UDP packet from: ", source).Base(err))
+					log.Access(source, "", log.AccessRejected, err)
+				}
+				payload.Release()
+				continue
 			}
-			payload.Release()
-			continue
-		}
 
-		if request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Disabled {
-			log.Trace(newError("client payload enables OTA but server doesn't allow it"))
-			payload.Release()
-			continue
-		}
+			if request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Disabled {
+				log.Trace(newError("client payload enables OTA but server doesn't allow it"))
+				payload.Release()
+				continue
+			}
 
-		if !request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Enabled {
-			log.Trace(newError("client payload disables OTA but server forces it"))
-			payload.Release()
-			continue
-		}
+			if !request.Option.Has(RequestOptionOneTimeAuth) && v.account.OneTimeAuth == Account_Enabled {
+				log.Trace(newError("client payload disables OTA but server forces it"))
+				payload.Release()
+				continue
+			}
 
-		dest := request.Destination()
-		if source, ok := proxy.SourceFromContext(ctx); ok {
-			log.Access(source, dest, log.AccessAccepted, "")
-		}
-		log.Trace(newError("tunnelling request to ", dest))
+			dest := request.Destination()
+			if source, ok := proxy.SourceFromContext(ctx); ok {
+				log.Access(source, dest, log.AccessAccepted, "")
+			}
+			log.Trace(newError("tunnelling request to ", dest))
 
-		ctx = protocol.ContextWithUser(ctx, request.User)
-		udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) {
-			defer payload.Release()
+			ctx = protocol.ContextWithUser(ctx, request.User)
+			udpServer.Dispatch(ctx, dest, data, func(payload *buf.Buffer) {
+				defer payload.Release()
 
-			data, err := EncodeUDPPacket(request, payload)
-			if err != nil {
-				log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning())
-				return
-			}
-			defer data.Release()
+				data, err := EncodeUDPPacket(request, payload)
+				if err != nil {
+					log.Trace(newError("failed to encode UDP packet").Base(err).AtWarning())
+					return
+				}
+				defer data.Release()
 
-			conn.Write(data.Bytes())
-		})
+				conn.Write(data.Bytes())
+			})
+		}
 	}
 
 	return nil

+ 13 - 8
proxy/socks/protocol.go

@@ -347,7 +347,7 @@ func NewUDPReader(reader io.Reader) *UDPReader {
 	return &UDPReader{reader: reader}
 }
 
-func (r *UDPReader) Read() (*buf.Buffer, error) {
+func (r *UDPReader) Read() (buf.MultiBuffer, error) {
 	b := buf.NewSmall()
 	if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil {
 		return nil, err
@@ -358,7 +358,9 @@ func (r *UDPReader) Read() (*buf.Buffer, error) {
 	}
 	b.Clear()
 	b.Append(data)
-	return b, nil
+	mb := buf.NewMultiBuffer()
+	mb.Append(b)
+	return mb, nil
 }
 
 type UDPWriter struct {
@@ -373,12 +375,15 @@ func NewUDPWriter(request *protocol.RequestHeader, writer io.Writer) *UDPWriter
 	}
 }
 
-func (w *UDPWriter) Write(b *buf.Buffer) error {
-	eb := EncodeUDPPacket(w.request, b.Bytes())
-	b.Release()
-	defer eb.Release()
-	if _, err := w.writer.Write(eb.Bytes()); err != nil {
-		return err
+func (w *UDPWriter) Write(mb buf.MultiBuffer) error {
+	defer mb.Release()
+
+	for _, b := range mb {
+		eb := EncodeUDPPacket(w.request, b.Bytes())
+		defer eb.Release()
+		if _, err := w.writer.Write(eb.Bytes()); err != nil {
+			return err
+		}
 	}
 	return nil
 }

+ 2 - 2
proxy/socks/protocol_test.go

@@ -24,11 +24,11 @@ func TestUDPEncoding(t *testing.T) {
 	content := []byte{'a'}
 	payload := buf.New()
 	payload.Append(content)
-	assert.Error(writer.Write(payload)).IsNil()
+	assert.Error(writer.Write(buf.NewMultiBufferValue(payload))).IsNil()
 
 	reader := NewUDPReader(b)
 
 	decodedPayload, err := reader.Read()
 	assert.Error(err).IsNil()
-	assert.Bytes(decodedPayload.Bytes()).Equals(content)
+	assert.Bytes(decodedPayload[0].Bytes()).Equals(content)
 }

+ 25 - 22
proxy/socks/server.go

@@ -159,38 +159,41 @@ func (v *Server) handleUDPPayload(ctx context.Context, conn internet.Connection,
 
 	reader := buf.NewReader(conn)
 	for {
-		payload, err := reader.Read()
+		mpayload, err := reader.Read()
 		if err != nil {
 			return err
 		}
-		request, data, err := DecodeUDPPacket(payload.Bytes())
 
-		if err != nil {
-			log.Trace(newError("failed to parse UDP request").Base(err))
-			continue
-		}
+		for _, payload := range mpayload {
+			request, data, err := DecodeUDPPacket(payload.Bytes())
 
-		if len(data) == 0 {
-			continue
-		}
+			if err != nil {
+				log.Trace(newError("failed to parse UDP request").Base(err))
+				continue
+			}
 
-		log.Trace(newError("send packet to ", request.Destination(), " with ", len(data), " bytes").AtDebug())
-		if source, ok := proxy.SourceFromContext(ctx); ok {
-			log.Access(source, request.Destination, log.AccessAccepted, "")
-		}
+			if len(data) == 0 {
+				continue
+			}
+
+			log.Trace(newError("send packet to ", request.Destination(), " with ", len(data), " bytes").AtDebug())
+			if source, ok := proxy.SourceFromContext(ctx); ok {
+				log.Access(source, request.Destination, log.AccessAccepted, "")
+			}
 
-		dataBuf := buf.NewSmall()
-		dataBuf.Append(data)
-		udpServer.Dispatch(ctx, request.Destination(), dataBuf, func(payload *buf.Buffer) {
-			defer payload.Release()
+			dataBuf := buf.NewSmall()
+			dataBuf.Append(data)
+			udpServer.Dispatch(ctx, request.Destination(), dataBuf, func(payload *buf.Buffer) {
+				defer payload.Release()
 
-			log.Trace(newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug())
+				log.Trace(newError("writing back UDP response with ", payload.Len(), " bytes").AtDebug())
 
-			udpMessage := EncodeUDPPacket(request, payload.Bytes())
-			defer udpMessage.Release()
+				udpMessage := EncodeUDPPacket(request, payload.Bytes())
+				defer udpMessage.Release()
 
-			conn.Write(udpMessage.Bytes())
-		})
+				conn.Write(udpMessage.Bytes())
+			})
+		}
 	}
 }
 

+ 1 - 1
proxy/vmess/inbound/inbound.go

@@ -166,7 +166,7 @@ func transferResponse(timer signal.ActivityTimer, session *encoding.ServerSessio
 	}
 
 	if request.Option.Has(protocol.RequestOptionChunkStream) {
-		if err := bodyWriter.Write(buf.NewLocal(8)); err != nil {
+		if err := bodyWriter.Write(buf.NewMultiBuffer()); err != nil {
 			return err
 		}
 	}

+ 1 - 1
proxy/vmess/outbound/outbound.go

@@ -133,7 +133,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 		}
 
 		if request.Option.Has(protocol.RequestOptionChunkStream) {
-			if err := bodyWriter.Write(buf.NewLocal(8)); err != nil {
+			if err := bodyWriter.Write(buf.NewMultiBuffer()); err != nil {
 				return err
 			}
 		}

+ 5 - 3
transport/internet/udp/dispatcher.go

@@ -57,7 +57,7 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination v2net.Destination
 	inboundRay, existing := v.getInboundRay(ctx, destination)
 	outputStream := inboundRay.InboundInput()
 	if outputStream != nil {
-		if err := outputStream.Write(payload); err != nil {
+		if err := outputStream.Write(buf.NewMultiBufferValue(payload)); err != nil {
 			v.RemoveRay(destination)
 		}
 	}
@@ -71,10 +71,12 @@ func (v *Dispatcher) Dispatch(ctx context.Context, destination v2net.Destination
 
 func handleInput(input ray.InputStream, callback ResponseCallback) {
 	for {
-		data, err := input.Read()
+		mb, err := input.Read()
 		if err != nil {
 			break
 		}
-		callback(data)
+		for _, b := range mb {
+			callback(b)
+		}
 	}
 }

+ 5 - 5
transport/ray/direct.go

@@ -42,7 +42,7 @@ func (v *directRay) InboundOutput() InputStream {
 }
 
 type Stream struct {
-	buffer chan *buf.Buffer
+	buffer chan buf.MultiBuffer
 	ctx    context.Context
 	close  chan bool
 	err    chan bool
@@ -51,13 +51,13 @@ type Stream struct {
 func NewStream(ctx context.Context) *Stream {
 	return &Stream{
 		ctx:    ctx,
-		buffer: make(chan *buf.Buffer, bufferSize),
+		buffer: make(chan buf.MultiBuffer, bufferSize),
 		close:  make(chan bool),
 		err:    make(chan bool),
 	}
 }
 
-func (v *Stream) Read() (*buf.Buffer, error) {
+func (v *Stream) Read() (buf.MultiBuffer, error) {
 	select {
 	case <-v.ctx.Done():
 		return nil, io.ErrClosedPipe
@@ -79,7 +79,7 @@ func (v *Stream) Read() (*buf.Buffer, error) {
 	}
 }
 
-func (v *Stream) ReadTimeout(timeout time.Duration) (*buf.Buffer, error) {
+func (v *Stream) ReadTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
 	select {
 	case <-v.ctx.Done():
 		return nil, io.ErrClosedPipe
@@ -107,7 +107,7 @@ func (v *Stream) ReadTimeout(timeout time.Duration) (*buf.Buffer, error) {
 	}
 }
 
-func (v *Stream) Write(data *buf.Buffer) (err error) {
+func (v *Stream) Write(data buf.MultiBuffer) (err error) {
 	if data.IsEmpty() {
 		return
 	}

+ 3 - 3
transport/ray/direct_test.go

@@ -16,7 +16,7 @@ func TestStreamIO(t *testing.T) {
 	stream := NewStream(context.Background())
 	b1 := buf.New()
 	b1.AppendBytes('a')
-	assert.Error(stream.Write(b1)).IsNil()
+	assert.Error(stream.Write(buf.NewMultiBufferValue(b1))).IsNil()
 
 	_, err := stream.Read()
 	assert.Error(err).IsNil()
@@ -27,7 +27,7 @@ func TestStreamIO(t *testing.T) {
 
 	b2 := buf.New()
 	b2.AppendBytes('b')
-	err = stream.Write(b2)
+	err = stream.Write(buf.NewMultiBufferValue(b2))
 	assert.Error(err).Equals(io.ErrClosedPipe)
 }
 
@@ -37,7 +37,7 @@ func TestStreamClose(t *testing.T) {
 	stream := NewStream(context.Background())
 	b1 := buf.New()
 	b1.AppendBytes('a')
-	assert.Error(stream.Write(b1)).IsNil()
+	assert.Error(stream.Write(buf.NewMultiBufferValue(b1))).IsNil()
 
 	stream.Close()