Browse Source

remove closure on ReadFullFrom

Darien Raymond 7 years ago
parent
commit
58e2ed3381

+ 3 - 3
app/dns/udpns.go

@@ -6,15 +6,15 @@ import (
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
 
 
-	"v2ray.com/core/common/session"
-	"v2ray.com/core/features/routing"
-
 	"github.com/miekg/dns"
 	"github.com/miekg/dns"
+
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/session"
 	"v2ray.com/core/common/signal/pubsub"
 	"v2ray.com/core/common/signal/pubsub"
 	"v2ray.com/core/common/task"
 	"v2ray.com/core/common/task"
+	"v2ray.com/core/features/routing"
 	"v2ray.com/core/transport/internet/udp"
 	"v2ray.com/core/transport/internet/udp"
 )
 )
 
 

+ 17 - 0
common/buf/buffer.go

@@ -167,6 +167,23 @@ func (b *Buffer) Read(data []byte) (int, error) {
 	return nBytes, nil
 	return nBytes, nil
 }
 }
 
 
+// ReadFrom implements io.ReaderFrom.
+func (b *Buffer) ReadFrom(reader io.Reader) (int64, error) {
+	n, err := reader.Read(b.v[b.end:])
+	b.end += int32(n)
+	return int64(n), err
+}
+
+func (b *Buffer) ReadFullFrom(reader io.Reader, size int32) (int64, error) {
+	end := b.end + size
+	if end > int32(len(b.v)) {
+		return 0, newError("out of bound: ", end)
+	}
+	n, err := io.ReadFull(reader, b.v[b.end:end])
+	b.end += int32(n)
+	return int64(n), err
+}
+
 // String returns the string form of this Buffer.
 // String returns the string form of this Buffer.
 func (b *Buffer) String() string {
 func (b *Buffer) String() string {
 	return string(b.Bytes())
 	return string(b.Bytes())

+ 20 - 2
common/buf/buffer_test.go

@@ -1,12 +1,13 @@
 package buf_test
 package buf_test
 
 
 import (
 import (
+	"bytes"
+	"crypto/rand"
 	"testing"
 	"testing"
 
 
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
-	"v2ray.com/core/common/compare"
-
 	. "v2ray.com/core/common/buf"
 	. "v2ray.com/core/common/buf"
+	"v2ray.com/core/common/compare"
 	"v2ray.com/core/common/serial"
 	"v2ray.com/core/common/serial"
 	. "v2ray.com/ext/assert"
 	. "v2ray.com/ext/assert"
 )
 )
@@ -73,6 +74,23 @@ func TestBufferSlice(t *testing.T) {
 	}
 	}
 }
 }
 
 
+func TestBufferReadFullFrom(t *testing.T) {
+	payload := make([]byte, 1024)
+	common.Must2(rand.Read(payload))
+
+	reader := bytes.NewReader(payload)
+	b := New()
+	n, err := b.ReadFullFrom(reader, 1024)
+	common.Must(err)
+	if n != 1024 {
+		t.Error("expect reading 1024 bytes, but actually ", n)
+	}
+
+	if err := compare.BytesEqualWithDetail(payload, b.Bytes()); err != nil {
+		t.Error(err)
+	}
+}
+
 func BenchmarkNewBuffer(b *testing.B) {
 func BenchmarkNewBuffer(b *testing.B) {
 	for i := 0; i < b.N; i++ {
 	for i := 0; i < b.N; i++ {
 		buffer := New()
 		buffer := New()

+ 0 - 14
common/buf/io.go

@@ -26,20 +26,6 @@ type Writer interface {
 	WriteMultiBuffer(MultiBuffer) error
 	WriteMultiBuffer(MultiBuffer) error
 }
 }
 
 
-// ReadFrom creates a Supplier to read from a given io.Reader.
-func ReadFrom(reader io.Reader) Supplier {
-	return func(b []byte) (int, error) {
-		return reader.Read(b)
-	}
-}
-
-// ReadFullFrom creates a Supplier to read full buffer from a given io.Reader.
-func ReadFullFrom(reader io.Reader, size int32) Supplier {
-	return func(b []byte) (int, error) {
-		return io.ReadFull(reader, b[:size])
-	}
-}
-
 // WriteAllBytes ensures all bytes are written into the given writer.
 // WriteAllBytes ensures all bytes are written into the given writer.
 func WriteAllBytes(writer io.Writer, payload []byte) error {
 func WriteAllBytes(writer io.Writer, payload []byte) error {
 	for len(payload) > 0 {
 	for len(payload) > 0 {

+ 2 - 2
common/buf/multi_buffer.go

@@ -79,7 +79,7 @@ func (mb *MultiBuffer) ReadFrom(reader io.Reader) (int64, error) {
 
 
 	for {
 	for {
 		b := New()
 		b := New()
-		err := b.Reset(ReadFullFrom(reader, Size))
+		_, err := b.ReadFullFrom(reader, Size)
 		if b.IsEmpty() {
 		if b.IsEmpty() {
 			b.Release()
 			b.Release()
 		} else {
 		} else {
@@ -220,7 +220,7 @@ func (mb *MultiBuffer) SliceBySize(size int32) MultiBuffer {
 	*mb = (*mb)[endIndex:]
 	*mb = (*mb)[endIndex:]
 	if endIndex == 0 && len(*mb) > 0 {
 	if endIndex == 0 && len(*mb) > 0 {
 		b := New()
 		b := New()
-		common.Must(b.Reset(ReadFullFrom((*mb)[0], size)))
+		common.Must2(b.ReadFullFrom((*mb)[0], size))
 		return NewMultiBufferValue(b)
 		return NewMultiBufferValue(b)
 	}
 	}
 	return slice
 	return slice

+ 1 - 1
common/buf/reader.go

@@ -10,7 +10,7 @@ import (
 func readOne(r io.Reader) (*Buffer, error) {
 func readOne(r io.Reader) (*Buffer, error) {
 	b := New()
 	b := New()
 	for i := 0; i < 64; i++ {
 	for i := 0; i < 64; i++ {
-		err := b.Reset(ReadFrom(r))
+		_, err := b.ReadFrom(r)
 		if !b.IsEmpty() {
 		if !b.IsEmpty() {
 			return b, nil
 			return b, nil
 		}
 		}

+ 3 - 2
common/buf/writer.go

@@ -140,7 +140,7 @@ func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
 		if w.buffer == nil {
 		if w.buffer == nil {
 			w.buffer = New()
 			w.buffer = New()
 		}
 		}
-		if err := w.buffer.AppendSupplier(ReadFrom(&b)); err != nil {
+		if _, err := w.buffer.ReadFrom(&b); err != nil {
 			return err
 			return err
 		}
 		}
 		if w.buffer.IsFull() {
 		if w.buffer.IsFull() {
@@ -248,7 +248,8 @@ func (noOpWriter) ReadFrom(reader io.Reader) (int64, error) {
 
 
 	totalBytes := int64(0)
 	totalBytes := int64(0)
 	for {
 	for {
-		err := b.Reset(ReadFrom(reader))
+		b.Clear()
+		_, err := b.ReadFrom(reader)
 		totalBytes += int64(b.Len())
 		totalBytes += int64(b.Len())
 		if err != nil {
 		if err != nil {
 			if errors.Cause(err) == io.EOF {
 			if errors.Cause(err) == io.EOF {

+ 2 - 2
common/buf/writer_test.go

@@ -17,7 +17,7 @@ func TestWriter(t *testing.T) {
 	assert := With(t)
 	assert := With(t)
 
 
 	lb := New()
 	lb := New()
-	assert(lb.AppendSupplier(ReadFrom(rand.Reader)), IsNil)
+	common.Must2(lb.ReadFrom(rand.Reader))
 
 
 	expectedBytes := append([]byte(nil), lb.Bytes()...)
 	expectedBytes := append([]byte(nil), lb.Bytes()...)
 
 
@@ -54,7 +54,7 @@ func TestDiscardBytes(t *testing.T) {
 	assert := With(t)
 	assert := With(t)
 
 
 	b := New()
 	b := New()
-	common.Must(b.Reset(ReadFullFrom(rand.Reader, Size)))
+	common.Must2(b.ReadFullFrom(rand.Reader, Size))
 
 
 	nBytes, err := io.Copy(DiscardBytes, b)
 	nBytes, err := io.Copy(DiscardBytes, b)
 	assert(nBytes, Equals, int64(Size))
 	assert(nBytes, Equals, int64(Size))

+ 3 - 5
common/crypto/auth.go

@@ -132,7 +132,7 @@ var errSoft = newError("waiting for more data")
 
 
 func (r *AuthenticationReader) readBuffer(size int32, padding int32) (*buf.Buffer, error) {
 func (r *AuthenticationReader) readBuffer(size int32, padding int32) (*buf.Buffer, error) {
 	b := buf.New()
 	b := buf.New()
-	if err := b.Reset(buf.ReadFullFrom(r.reader, size)); err != nil {
+	if _, err := b.ReadFullFrom(r.reader, size); err != nil {
 		b.Release()
 		b.Release()
 		return nil, err
 		return nil, err
 	}
 	}
@@ -270,7 +270,7 @@ func (w *AuthenticationWriter) seal(b *buf.Buffer) (*buf.Buffer, error) {
 	}
 	}
 	if paddingSize > 0 {
 	if paddingSize > 0 {
 		// With size of the chunk and padding length encrypted, the content of padding doesn't matter much.
 		// With size of the chunk and padding length encrypted, the content of padding doesn't matter much.
-		common.Must(eb.AppendSupplier(buf.ReadFullFrom(w.randReader, int32(paddingSize))))
+		common.Must2(eb.ReadFullFrom(w.randReader, int32(paddingSize)))
 	}
 	}
 
 
 	return eb, nil
 	return eb, nil
@@ -289,9 +289,7 @@ func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
 
 
 	for {
 	for {
 		b := buf.New()
 		b := buf.New()
-		common.Must(b.Reset(func(bb []byte) (int, error) {
-			return mb.Read(bb[:payloadSize])
-		}))
+		common.Must2(b.ReadFrom(io.LimitReader(&mb, int64(payloadSize))))
 		eb, err := w.seal(b)
 		eb, err := w.seal(b)
 		b.Release()
 		b.Release()
 
 

+ 6 - 4
common/mux/frame.go

@@ -1,6 +1,7 @@
 package mux
 package mux
 
 
 import (
 import (
+	"encoding/binary"
 	"io"
 	"io"
 
 
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
@@ -9,6 +10,7 @@ import (
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/serial"
 	"v2ray.com/core/common/serial"
+	"v2ray.com/core/common/vio"
 )
 )
 
 
 type SessionStatus byte
 type SessionStatus byte
@@ -60,11 +62,11 @@ type FrameMetadata struct {
 }
 }
 
 
 func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
 func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
-	lenBytes := b.Bytes()
 	common.Must2(b.WriteBytes(0x00, 0x00))
 	common.Must2(b.WriteBytes(0x00, 0x00))
+	lenBytes := b.Bytes()
 
 
 	len0 := b.Len()
 	len0 := b.Len()
-	if err := b.AppendSupplier(serial.WriteUint16(f.SessionID)); err != nil {
+	if _, err := vio.WriteUint16(b, f.SessionID); err != nil {
 		return err
 		return err
 	}
 	}
 
 
@@ -84,7 +86,7 @@ func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
 	}
 	}
 
 
 	len1 := b.Len()
 	len1 := b.Len()
-	serial.Uint16ToBytes(uint16(len1-len0), lenBytes)
+	binary.BigEndian.PutUint16(lenBytes, uint16(len1-len0))
 	return nil
 	return nil
 }
 }
 
 
@@ -101,7 +103,7 @@ func (f *FrameMetadata) Unmarshal(reader io.Reader) error {
 	b := buf.New()
 	b := buf.New()
 	defer b.Release()
 	defer b.Release()
 
 
-	if err := b.Reset(buf.ReadFullFrom(reader, int32(metaLen))); err != nil {
+	if _, err := b.ReadFullFrom(reader, int32(metaLen)); err != nil {
 		return err
 		return err
 	}
 	}
 	return f.UnmarshalFromBuffer(b)
 	return f.UnmarshalFromBuffer(b)

+ 1 - 1
common/mux/reader.go

@@ -38,7 +38,7 @@ func (r *PacketReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	}
 	}
 
 
 	b := buf.New()
 	b := buf.New()
-	if err := b.Reset(buf.ReadFullFrom(r.reader, int32(size))); err != nil {
+	if _, err := b.ReadFullFrom(r.reader, int32(size)); err != nil {
 		b.Release()
 		b.Release()
 		return nil, err
 		return nil, err
 	}
 	}

+ 2 - 2
common/mux/writer.go

@@ -5,7 +5,7 @@ import (
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/protocol"
-	"v2ray.com/core/common/serial"
+	"v2ray.com/core/common/vio"
 )
 )
 
 
 type Writer struct {
 type Writer struct {
@@ -66,7 +66,7 @@ func writeMetaWithFrame(writer buf.Writer, meta FrameMetadata, data buf.MultiBuf
 	if err := meta.WriteTo(frame); err != nil {
 	if err := meta.WriteTo(frame); err != nil {
 		return err
 		return err
 	}
 	}
-	if err := frame.AppendSupplier(serial.WriteUint16(uint16(data.Len()))); err != nil {
+	if _, err := vio.WriteUint16(frame, uint16(data.Len())); err != nil {
 		return err
 		return err
 	}
 	}
 
 

+ 6 - 6
common/protocol/address.go

@@ -53,7 +53,7 @@ func NewAddressParser(options ...AddressOption) *AddressParser {
 }
 }
 
 
 func (p *AddressParser) readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) {
 func (p *AddressParser) readPort(b *buf.Buffer, reader io.Reader) (net.Port, error) {
-	if err := b.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil {
+	if _, err := b.ReadFullFrom(reader, 2); err != nil {
 		return 0, err
 		return 0, err
 	}
 	}
 	return net.PortFromBytes(b.BytesFrom(-2)), nil
 	return net.PortFromBytes(b.BytesFrom(-2)), nil
@@ -73,7 +73,7 @@ func isValidDomain(d string) bool {
 }
 }
 
 
 func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) {
 func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Address, error) {
-	if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
+	if _, err := b.ReadFullFrom(reader, 1); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
@@ -89,21 +89,21 @@ func (p *AddressParser) readAddress(b *buf.Buffer, reader io.Reader) (net.Addres
 
 
 	switch addrFamily {
 	switch addrFamily {
 	case net.AddressFamilyIPv4:
 	case net.AddressFamilyIPv4:
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 4)); err != nil {
+		if _, err := b.ReadFullFrom(reader, 4); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 		return net.IPAddress(b.BytesFrom(-4)), nil
 		return net.IPAddress(b.BytesFrom(-4)), nil
 	case net.AddressFamilyIPv6:
 	case net.AddressFamilyIPv6:
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 16)); err != nil {
+		if _, err := b.ReadFullFrom(reader, 16); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 		return net.IPAddress(b.BytesFrom(-16)), nil
 		return net.IPAddress(b.BytesFrom(-16)), nil
 	case net.AddressFamilyDomain:
 	case net.AddressFamilyDomain:
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, 1)); err != nil {
+		if _, err := b.ReadFullFrom(reader, 1); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 		domainLength := int32(b.Byte(b.Len() - 1))
 		domainLength := int32(b.Byte(b.Len() - 1))
-		if err := b.AppendSupplier(buf.ReadFullFrom(reader, domainLength)); err != nil {
+		if _, err := b.ReadFullFrom(reader, domainLength); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 		domain := string(b.BytesFrom(-domainLength))
 		domain := string(b.BytesFrom(-domainLength))

+ 0 - 7
common/serial/numbers.go

@@ -20,13 +20,6 @@ func ReadUint16(reader io.Reader) (uint16, error) {
 	return BytesToUint16(b[:]), nil
 	return BytesToUint16(b[:]), nil
 }
 }
 
 
-func WriteUint16(value uint16) func([]byte) (int, error) {
-	return func(b []byte) (int, error) {
-		Uint16ToBytes(value, b[:0])
-		return 2, nil
-	}
-}
-
 func Uint32ToBytes(value uint32, b []byte) []byte {
 func Uint32ToBytes(value uint32, b []byte) []byte {
 	return append(b, byte(value>>24), byte(value>>16), byte(value>>8), byte(value))
 	return append(b, byte(value>>24), byte(value>>16), byte(value>>8), byte(value))
 }
 }

+ 18 - 0
common/vio/serial.go

@@ -0,0 +1,18 @@
+package vio
+
+import (
+	"encoding/binary"
+	"io"
+)
+
+func WriteUint32(writer io.Writer, value uint32) (int, error) {
+	var b [4]byte
+	binary.BigEndian.PutUint32(b[:], value)
+	return writer.Write(b[:])
+}
+
+func WriteUint16(writer io.Writer, value uint16) (int, error) {
+	var b [2]byte
+	binary.BigEndian.PutUint16(b[:], value)
+	return writer.Write(b[:])
+}

+ 24 - 0
common/vio/serial_test.go

@@ -0,0 +1,24 @@
+package vio_test
+
+import (
+	"testing"
+
+	"v2ray.com/core/common"
+	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/compare"
+	"v2ray.com/core/common/vio"
+)
+
+func TestUint32Serial(t *testing.T) {
+	b := buf.New()
+	defer b.Release()
+
+	n, err := vio.WriteUint32(b, 10)
+	common.Must(err)
+	if n != 4 {
+		t.Error("expect 4 bytes writtng, but actually ", n)
+	}
+	if err := compare.BytesEqualWithDetail(b.Bytes(), []byte{0, 0, 0, 10}); err != nil {
+		t.Error(err)
+	}
+}

+ 4 - 4
proxy/shadowsocks/protocol.go

@@ -36,7 +36,7 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
 	ivLen := account.Cipher.IVSize()
 	ivLen := account.Cipher.IVSize()
 	var iv []byte
 	var iv []byte
 	if ivLen > 0 {
 	if ivLen > 0 {
-		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, ivLen)); err != nil {
+		if _, err := buffer.ReadFullFrom(reader, ivLen); err != nil {
 			return nil, nil, newError("failed to read IV").Base(err)
 			return nil, nil, newError("failed to read IV").Base(err)
 		}
 		}
 
 
@@ -85,7 +85,7 @@ func ReadTCPSession(user *protocol.MemoryUser, reader io.Reader) (*protocol.Requ
 		actualAuth := make([]byte, AuthSize)
 		actualAuth := make([]byte, AuthSize)
 		authenticator.Authenticate(buffer.Bytes())(actualAuth)
 		authenticator.Authenticate(buffer.Bytes())(actualAuth)
 
 
-		err := buffer.AppendSupplier(buf.ReadFullFrom(br, AuthSize))
+		_, err := buffer.ReadFullFrom(br, AuthSize)
 		if err != nil {
 		if err != nil {
 			return nil, nil, newError("Failed to read OTA").Base(err)
 			return nil, nil, newError("Failed to read OTA").Base(err)
 		}
 		}
@@ -196,7 +196,7 @@ func EncodeUDPPacket(request *protocol.RequestHeader, payload []byte) (*buf.Buff
 	buffer := buf.New()
 	buffer := buf.New()
 	ivLen := account.Cipher.IVSize()
 	ivLen := account.Cipher.IVSize()
 	if ivLen > 0 {
 	if ivLen > 0 {
-		common.Must(buffer.Reset(buf.ReadFullFrom(rand.Reader, ivLen)))
+		common.Must2(buffer.ReadFullFrom(rand.Reader, ivLen))
 	}
 	}
 	iv := buffer.Bytes()
 	iv := buffer.Bytes()
 
 
@@ -287,7 +287,7 @@ type UDPReader struct {
 
 
 func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 func (v *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	buffer := buf.New()
 	buffer := buf.New()
-	err := buffer.AppendSupplier(buf.ReadFrom(v.Reader))
+	_, err := buffer.ReadFrom(v.Reader)
 	if err != nil {
 	if err != nil {
 		buffer.Release()
 		buffer.Release()
 		return nil, err
 		return nil, err

+ 22 - 14
proxy/socks/protocol.go

@@ -7,7 +7,7 @@ import (
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/protocol"
-	"v2ray.com/core/common/serial"
+	"v2ray.com/core/common/vio"
 )
 )
 
 
 const (
 const (
@@ -49,7 +49,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
 
 
 	request := new(protocol.RequestHeader)
 	request := new(protocol.RequestHeader)
 
 
-	if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 2)); err != nil {
+	if _, err := buffer.ReadFullFrom(reader, 2); err != nil {
 		return nil, newError("insufficient header").Base(err)
 		return nil, newError("insufficient header").Base(err)
 	}
 	}
 
 
@@ -60,7 +60,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
 			return nil, newError("socks 4 is not allowed when auth is required.")
 			return nil, newError("socks 4 is not allowed when auth is required.")
 		}
 		}
 
 
-		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, 6)); err != nil {
+		if _, err := buffer.ReadFullFrom(reader, 6); err != nil {
 			return nil, newError("insufficient header").Base(err)
 			return nil, newError("insufficient header").Base(err)
 		}
 		}
 		port := net.PortFromBytes(buffer.BytesRange(2, 4))
 		port := net.PortFromBytes(buffer.BytesRange(2, 4))
@@ -94,7 +94,7 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
 
 
 	if version == socks5Version {
 	if version == socks5Version {
 		nMethod := int32(buffer.Byte(1))
 		nMethod := int32(buffer.Byte(1))
-		if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, nMethod)); err != nil {
+		if _, err := buffer.ReadFullFrom(reader, nMethod); err != nil {
 			return nil, newError("failed to read auth methods").Base(err)
 			return nil, newError("failed to read auth methods").Base(err)
 		}
 		}
 
 
@@ -127,7 +127,9 @@ func (s *ServerSession) Handshake(reader io.Reader, writer io.Writer) (*protocol
 				return nil, newError("failed to write auth response").Base(err)
 				return nil, newError("failed to write auth response").Base(err)
 			}
 			}
 		}
 		}
-		if err := buffer.Reset(buf.ReadFullFrom(reader, 3)); err != nil {
+
+		buffer.Clear()
+		if _, err := buffer.ReadFullFrom(reader, 3); err != nil {
 			return nil, newError("failed to read request").Base(err)
 			return nil, newError("failed to read request").Base(err)
 		}
 		}
 
 
@@ -185,21 +187,25 @@ func readUsernamePassword(reader io.Reader) (string, string, error) {
 	buffer := buf.New()
 	buffer := buf.New()
 	defer buffer.Release()
 	defer buffer.Release()
 
 
-	if err := buffer.Reset(buf.ReadFullFrom(reader, 2)); err != nil {
+	if _, err := buffer.ReadFullFrom(reader, 2); err != nil {
 		return "", "", err
 		return "", "", err
 	}
 	}
 	nUsername := int32(buffer.Byte(1))
 	nUsername := int32(buffer.Byte(1))
 
 
-	if err := buffer.Reset(buf.ReadFullFrom(reader, nUsername)); err != nil {
+	buffer.Clear()
+	if _, err := buffer.ReadFullFrom(reader, nUsername); err != nil {
 		return "", "", err
 		return "", "", err
 	}
 	}
 	username := buffer.String()
 	username := buffer.String()
 
 
-	if err := buffer.Reset(buf.ReadFullFrom(reader, 1)); err != nil {
+	buffer.Clear()
+	if _, err := buffer.ReadFullFrom(reader, 1); err != nil {
 		return "", "", err
 		return "", "", err
 	}
 	}
 	nPassword := int32(buffer.Byte(0))
 	nPassword := int32(buffer.Byte(0))
-	if err := buffer.Reset(buf.ReadFullFrom(reader, nPassword)); err != nil {
+
+	buffer.Clear()
+	if _, err := buffer.ReadFullFrom(reader, nPassword); err != nil {
 		return "", "", err
 		return "", "", err
 	}
 	}
 	password := buffer.String()
 	password := buffer.String()
@@ -254,7 +260,7 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po
 	defer buffer.Release()
 	defer buffer.Release()
 
 
 	common.Must2(buffer.WriteBytes(0x00, errCode))
 	common.Must2(buffer.WriteBytes(0x00, errCode))
-	common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value())))
+	common.Must2(vio.WriteUint16(buffer, port.Value()))
 	common.Must2(buffer.Write(address.IP()))
 	common.Must2(buffer.Write(address.IP()))
 	return buf.WriteAllBytes(writer, buffer.Bytes())
 	return buf.WriteAllBytes(writer, buffer.Bytes())
 }
 }
@@ -305,7 +311,7 @@ func NewUDPReader(reader io.Reader) *UDPReader {
 
 
 func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 func (r *UDPReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	b := buf.New()
 	b := buf.New()
-	if err := b.AppendSupplier(buf.ReadFrom(r.reader)); err != nil {
+	if _, err := b.ReadFrom(r.reader); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 	if _, err := DecodeUDPPacket(b); err != nil {
 	if _, err := DecodeUDPPacket(b); err != nil {
@@ -362,7 +368,8 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 		return nil, err
 		return nil, err
 	}
 	}
 
 
-	if err := b.Reset(buf.ReadFullFrom(reader, 2)); err != nil {
+	b.Clear()
+	if _, err := b.ReadFullFrom(reader, 2); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
@@ -374,7 +381,8 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 	}
 	}
 
 
 	if authByte == authPassword {
 	if authByte == authPassword {
-		if err := b.Reset(buf.ReadFullFrom(reader, 2)); err != nil {
+		b.Clear()
+		if _, err := b.ReadFullFrom(reader, 2); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
 		if b.Byte(1) != 0x00 {
 		if b.Byte(1) != 0x00 {
@@ -398,7 +406,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 	}
 	}
 
 
 	b.Clear()
 	b.Clear()
-	if err := b.AppendSupplier(buf.ReadFullFrom(reader, 3)); err != nil {
+	if _, err := b.ReadFullFrom(reader, 3); err != nil {
 		return nil, err
 		return nil, err
 	}
 	}
 
 

+ 4 - 3
proxy/vmess/encoding/client.go

@@ -80,7 +80,7 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
 	}
 	}
 
 
 	if padingLen > 0 {
 	if padingLen > 0 {
-		common.Must(buffer.AppendSupplier(buf.ReadFullFrom(rand.Reader, int32(padingLen))))
+		common.Must2(buffer.ReadFullFrom(rand.Reader, int32(padingLen)))
 	}
 	}
 
 
 	{
 	{
@@ -164,7 +164,7 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
 	buffer := buf.New()
 	buffer := buf.New()
 	defer buffer.Release()
 	defer buffer.Release()
 
 
-	if err := buffer.AppendSupplier(buf.ReadFullFrom(c.responseReader, 4)); err != nil {
+	if _, err := buffer.ReadFullFrom(c.responseReader, 4); err != nil {
 		return nil, newError("failed to read response header").Base(err)
 		return nil, newError("failed to read response header").Base(err)
 	}
 	}
 
 
@@ -180,7 +180,8 @@ func (c *ClientSession) DecodeResponseHeader(reader io.Reader) (*protocol.Respon
 		cmdID := buffer.Byte(2)
 		cmdID := buffer.Byte(2)
 		dataLen := int32(buffer.Byte(3))
 		dataLen := int32(buffer.Byte(3))
 
 
-		if err := buffer.Reset(buf.ReadFullFrom(c.responseReader, dataLen)); err != nil {
+		buffer.Clear()
+		if _, err := buffer.ReadFullFrom(c.responseReader, dataLen); err != nil {
 			return nil, newError("failed to read response command").Base(err)
 			return nil, newError("failed to read response command").Base(err)
 		}
 		}
 		command, err := UnmarshalCommand(cmdID, buffer.Bytes())
 		command, err := UnmarshalCommand(cmdID, buffer.Bytes())

+ 5 - 4
proxy/vmess/encoding/server.go

@@ -125,7 +125,7 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 	buffer := buf.New()
 	buffer := buf.New()
 	defer buffer.Release()
 	defer buffer.Release()
 
 
-	if err := buffer.AppendSupplier(buf.ReadFullFrom(reader, protocol.IDBytesLen)); err != nil {
+	if _, err := buffer.ReadFullFrom(reader, protocol.IDBytesLen); err != nil {
 		return nil, newError("failed to read request header").Base(err)
 		return nil, newError("failed to read request header").Base(err)
 	}
 	}
 
 
@@ -140,7 +140,8 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 	aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:])
 	aesStream := crypto.NewAesDecryptionStream(vmessAccount.ID.CmdKey(), iv[:])
 	decryptor := crypto.NewCryptionReader(aesStream, reader)
 	decryptor := crypto.NewCryptionReader(aesStream, reader)
 
 
-	if err := buffer.Reset(buf.ReadFullFrom(decryptor, 38)); err != nil {
+	buffer.Clear()
+	if _, err := buffer.ReadFullFrom(decryptor, 38); err != nil {
 		return nil, newError("failed to read request header").Base(err)
 		return nil, newError("failed to read request header").Base(err)
 	}
 	}
 
 
@@ -178,12 +179,12 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 	}
 	}
 
 
 	if padingLen > 0 {
 	if padingLen > 0 {
-		if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, int32(padingLen))); err != nil {
+		if _, err := buffer.ReadFullFrom(decryptor, int32(padingLen)); err != nil {
 			return nil, newError("failed to read padding").Base(err)
 			return nil, newError("failed to read padding").Base(err)
 		}
 		}
 	}
 	}
 
 
-	if err := buffer.AppendSupplier(buf.ReadFullFrom(decryptor, 4)); err != nil {
+	if _, err := buffer.ReadFullFrom(decryptor, 4); err != nil {
 		return nil, newError("failed to read checksum").Base(err)
 		return nil, newError("failed to read checksum").Base(err)
 	}
 	}
 
 

+ 1 - 1
testing/servers/tcp/tcp.go

@@ -69,7 +69,7 @@ func (server *Server) handleConnection(conn net.Conn) {
 
 
 		for {
 		for {
 			b := buf.New()
 			b := buf.New()
-			if err := b.AppendSupplier(buf.ReadFrom(conn)); err != nil {
+			if _, err := b.ReadFrom(conn); err != nil {
 				if err == io.EOF {
 				if err == io.EOF {
 					return nil
 					return nil
 				}
 				}

+ 4 - 4
transport/internet/domainsocket/listener_test.go

@@ -28,7 +28,7 @@ func TestListen(t *testing.T) {
 		defer conn.Close()
 		defer conn.Close()
 
 
 		b := buf.New()
 		b := buf.New()
-		common.Must(b.Reset(buf.ReadFrom(conn)))
+		common.Must2(b.ReadFrom(conn))
 		assert(b.String(), Equals, "Request")
 		assert(b.String(), Equals, "Request")
 
 
 		common.Must2(conn.Write([]byte("Response")))
 		common.Must2(conn.Write([]byte("Response")))
@@ -44,7 +44,7 @@ func TestListen(t *testing.T) {
 	assert(err, IsNil)
 	assert(err, IsNil)
 
 
 	b := buf.New()
 	b := buf.New()
-	common.Must(b.Reset(buf.ReadFrom(conn)))
+	common.Must2(b.ReadFrom(conn))
 
 
 	assert(b.String(), Equals, "Response")
 	assert(b.String(), Equals, "Response")
 }
 }
@@ -67,7 +67,7 @@ func TestListenAbstract(t *testing.T) {
 		defer conn.Close()
 		defer conn.Close()
 
 
 		b := buf.New()
 		b := buf.New()
-		common.Must(b.Reset(buf.ReadFrom(conn)))
+		common.Must2(b.ReadFrom(conn))
 		assert(b.String(), Equals, "Request")
 		assert(b.String(), Equals, "Request")
 
 
 		common.Must2(conn.Write([]byte("Response")))
 		common.Must2(conn.Write([]byte("Response")))
@@ -83,7 +83,7 @@ func TestListenAbstract(t *testing.T) {
 	assert(err, IsNil)
 	assert(err, IsNil)
 
 
 	b := buf.New()
 	b := buf.New()
-	common.Must(b.Reset(buf.ReadFrom(conn)))
+	common.Must2(b.ReadFrom(conn))
 
 
 	assert(b.String(), Equals, "Response")
 	assert(b.String(), Equals, "Response")
 }
 }

+ 1 - 1
transport/internet/headers/http/http.go

@@ -60,7 +60,7 @@ func (*HeaderReader) Read(reader io.Reader) (*buf.Buffer, error) {
 	totalBytes := int32(0)
 	totalBytes := int32(0)
 	endingDetected := false
 	endingDetected := false
 	for totalBytes < maxHeaderLength {
 	for totalBytes < maxHeaderLength {
-		err := buffer.AppendSupplier(buf.ReadFrom(reader))
+		_, err := buffer.ReadFrom(reader)
 		if err != nil {
 		if err != nil {
 			buffer.Release()
 			buffer.Release()
 			return nil, err
 			return nil, err

+ 5 - 3
transport/internet/http/http_test.go

@@ -39,7 +39,7 @@ func TestHTTPConnection(t *testing.T) {
 			defer b.Release()
 			defer b.Release()
 
 
 			for {
 			for {
-				if err := b.Reset(buf.ReadFrom(conn)); err != nil {
+				if _, err := b.ReadFrom(conn); err != nil {
 					return
 					return
 				}
 				}
 				nBytes, err := conn.Write(b.Bytes())
 				nBytes, err := conn.Write(b.Bytes())
@@ -76,13 +76,15 @@ func TestHTTPConnection(t *testing.T) {
 	assert(nBytes, Equals, N)
 	assert(nBytes, Equals, N)
 	assert(err, IsNil)
 	assert(err, IsNil)
 
 
-	assert(b2.Reset(buf.ReadFullFrom(conn, N)), IsNil)
+	b2.Clear()
+	common.Must2(b2.ReadFullFrom(conn, N))
 	assert(b2.Bytes(), Equals, b1)
 	assert(b2.Bytes(), Equals, b1)
 
 
 	nBytes, err = conn.Write(b1)
 	nBytes, err = conn.Write(b1)
 	assert(nBytes, Equals, N)
 	assert(nBytes, Equals, N)
 	assert(err, IsNil)
 	assert(err, IsNil)
 
 
-	assert(b2.Reset(buf.ReadFullFrom(conn, N)), IsNil)
+	b2.Clear()
+	common.Must2(b2.ReadFullFrom(conn, N))
 	assert(b2.Bytes(), Equals, b1)
 	assert(b2.Bytes(), Equals, b1)
 }
 }

+ 1 - 1
transport/internet/kcp/dialer.go

@@ -23,7 +23,7 @@ func fetchInput(ctx context.Context, input io.Reader, reader PacketReader, conn
 	go func() {
 	go func() {
 		for {
 		for {
 			payload := buf.New()
 			payload := buf.New()
-			if err := payload.Reset(buf.ReadFrom(input)); err != nil {
+			if _, err := payload.ReadFrom(input); err != nil {
 				payload.Release()
 				payload.Release()
 				close(cache)
 				close(cache)
 				return
 				return

+ 2 - 3
transport/internet/kcp/sending.go

@@ -2,6 +2,7 @@ package kcp
 
 
 import (
 import (
 	"container/list"
 	"container/list"
+	"io"
 	"sync"
 	"sync"
 
 
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
@@ -274,9 +275,7 @@ func (w *SendingWorker) Push(mb *buf.MultiBuffer) bool {
 	}
 	}
 
 
 	b := buf.New()
 	b := buf.New()
-	common.Must(b.Reset(func(v []byte) (int, error) {
-		return mb.Read(v[:w.conn.mss])
-	}))
+	common.Must2(b.ReadFrom(io.LimitReader(mb, int64(w.conn.mss))))
 	w.window.Push(w.nextNumber, b)
 	w.window.Push(w.nextNumber, b)
 	w.nextNumber++
 	w.nextNumber++
 	return true
 	return true

+ 1 - 1
transport/internet/sockopt_test.go

@@ -40,7 +40,7 @@ func TestTCPFastOpen(t *testing.T) {
 	common.Must(err)
 	common.Must(err)
 
 
 	b := buf.New()
 	b := buf.New()
-	common.Must(b.Reset(buf.ReadFrom(conn)))
+	common.Must2(b.ReadFrom(conn))
 	if err := compare.BytesEqualWithDetail(b.Bytes(), []byte("abcd")); err != nil {
 	if err := compare.BytesEqualWithDetail(b.Bytes(), []byte("abcd")); err != nil {
 		t.Fatal(err)
 		t.Fatal(err)
 	}
 	}