فهرست منبع

refine buffer interface

Darien Raymond 7 سال پیش
والد
کامیت
206f52affc

+ 4 - 4
app/proxyman/mux/frame.go

@@ -59,21 +59,21 @@ type FrameMetadata struct {
 
 func (f FrameMetadata) WriteTo(b *buf.Buffer) error {
 	lenBytes := b.Bytes()
-	common.Must2(b.AppendBytes(0x00, 0x00))
+	common.Must2(b.WriteBytes(0x00, 0x00))
 
 	len0 := b.Len()
 	if err := b.AppendSupplier(serial.WriteUint16(f.SessionID)); err != nil {
 		return err
 	}
 
-	common.Must2(b.AppendBytes(byte(f.SessionStatus), byte(f.Option)))
+	common.Must2(b.WriteBytes(byte(f.SessionStatus), byte(f.Option)))
 
 	if f.SessionStatus == SessionStatusNew {
 		switch f.Target.Network {
 		case net.Network_TCP:
-			common.Must2(b.AppendBytes(byte(TargetNetworkTCP)))
+			common.Must2(b.WriteBytes(byte(TargetNetworkTCP)))
 		case net.Network_UDP:
-			common.Must2(b.AppendBytes(byte(TargetNetworkUDP)))
+			common.Must2(b.WriteBytes(byte(TargetNetworkUDP)))
 		}
 
 		if err := addrParser.WriteAddressPort(b, f.Target.Address, f.Target.Port); err != nil {

+ 1 - 1
app/proxyman/mux/reader.go

@@ -17,7 +17,7 @@ func ReadMetadata(reader io.Reader) (*FrameMetadata, error) {
 		return nil, newError("invalid metalen ", metaLen).AtError()
 	}
 
-	b := buf.NewSize(int32(metaLen))
+	b := buf.New()
 	defer b.Release()
 
 	if err := b.Reset(buf.ReadFullFrom(reader, int32(metaLen))); err != nil {

+ 5 - 5
common/buf/buffer.go

@@ -35,11 +35,6 @@ func (b *Buffer) Clear() {
 	b.end = 0
 }
 
-// AppendBytes appends one or more bytes to the end of the buffer.
-func (b *Buffer) AppendBytes(bytes ...byte) (int, error) {
-	return b.Write(bytes)
-}
-
 // AppendSupplier appends the content of a BytesWriter to the buffer.
 func (b *Buffer) AppendSupplier(writer Supplier) error {
 	nBytes, err := writer(b.v[b.end:])
@@ -145,6 +140,11 @@ func (b *Buffer) Write(data []byte) (int, error) {
 	return nBytes, nil
 }
 
+// WriteBytes appends one or more bytes to the end of the buffer.
+func (b *Buffer) WriteBytes(bytes ...byte) (int, error) {
+	return b.Write(bytes)
+}
+
 // Read implements io.Reader.Read().
 func (b *Buffer) Read(data []byte) (int, error) {
 	if b.Len() == 0 {

+ 5 - 22
common/buf/multi_buffer.go

@@ -21,26 +21,6 @@ func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) {
 	return mb, nil
 }
 
-// ReadSizeToMultiBuffer reads specific number of bytes from reader into a MultiBuffer.
-func ReadSizeToMultiBuffer(reader io.Reader, size int32) (MultiBuffer, error) {
-	mb := NewMultiBufferCap(32)
-
-	for size > 0 {
-		bSize := size
-		if bSize > Size {
-			bSize = Size
-		}
-		b := NewSize(bSize)
-		if err := b.Reset(ReadFullFrom(reader, bSize)); err != nil {
-			mb.Release()
-			return nil, err
-		}
-		size -= bSize
-		mb.Append(b)
-	}
-	return mb, nil
-}
-
 // ReadAllToBytes reads all content from the reader into a byte array, until EOF.
 func ReadAllToBytes(reader io.Reader) ([]byte, error) {
 	mb, err := ReadAllToMultiBuffer(reader)
@@ -100,7 +80,7 @@ func (mb *MultiBuffer) ReadFrom(reader io.Reader) (int64, error) {
 
 	for {
 		b := New()
-		err := b.Reset(ReadFrom(reader))
+		err := b.Reset(ReadFullFrom(reader, Size))
 		if b.IsEmpty() {
 			b.Release()
 		} else {
@@ -108,7 +88,7 @@ func (mb *MultiBuffer) ReadFrom(reader io.Reader) (int64, error) {
 		}
 		totalBytes += int64(b.Len())
 		if err != nil {
-			if errors.Cause(err) == io.EOF {
+			if errors.Cause(err) == io.EOF || errors.Cause(err) == io.ErrUnexpectedEOF {
 				return totalBytes, nil
 			}
 			return totalBytes, err
@@ -178,6 +158,9 @@ func (mb *MultiBuffer) Write(b []byte) (int, error) {
 // WriteMultiBuffer implements Writer.
 func (mb *MultiBuffer) WriteMultiBuffer(b MultiBuffer) error {
 	*mb = append(*mb, b...)
+	for i := range b {
+		b[i] = nil
+	}
 	return nil
 }
 

+ 3 - 3
common/buf/multi_buffer_test.go

@@ -14,10 +14,10 @@ func TestMultiBufferRead(t *testing.T) {
 	assert := With(t)
 
 	b1 := New()
-	b1.AppendBytes('a', 'b')
+	b1.WriteBytes('a', 'b')
 
 	b2 := New()
-	b2.AppendBytes('c', 'd')
+	b2.WriteBytes('c', 'd')
 	mb := NewMultiBufferValue(b1, b2)
 
 	bs := make([]byte, 32)
@@ -32,7 +32,7 @@ func TestMultiBufferAppend(t *testing.T) {
 
 	var mb MultiBuffer
 	b := New()
-	b.AppendBytes('a', 'b')
+	b.WriteBytes('a', 'b')
 	mb.Append(b)
 	assert(mb.Len(), Equals, int32(2))
 }

+ 4 - 4
common/buf/reader_test.go

@@ -37,9 +37,9 @@ func TestBytesReaderWriteTo(t *testing.T) {
 	pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024))
 	reader := &BufferedReader{Reader: pReader}
 	b1 := New()
-	b1.AppendBytes('a', 'b', 'c')
+	b1.WriteBytes('a', 'b', 'c')
 	b2 := New()
-	b2.AppendBytes('e', 'f', 'g')
+	b2.WriteBytes('e', 'f', 'g')
 	assert(pWriter.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
 	pWriter.Close()
 
@@ -64,9 +64,9 @@ func TestBytesReaderMultiBuffer(t *testing.T) {
 	pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024))
 	reader := &BufferedReader{Reader: pReader}
 	b1 := New()
-	b1.AppendBytes('a', 'b', 'c')
+	b1.WriteBytes('a', 'b', 'c')
 	b2 := New()
-	b2.AppendBytes('e', 'f', 'g')
+	b2.WriteBytes('e', 'f', 'g')
 	assert(pWriter.WriteMultiBuffer(NewMultiBufferValue(b1, b2)), IsNil)
 	pWriter.Close()
 

+ 2 - 2
common/crypto/chunk_test.go

@@ -19,11 +19,11 @@ func TestChunkStreamIO(t *testing.T) {
 	reader := NewChunkStreamReader(PlainChunkSizeParser{}, cache)
 
 	b := buf.New()
-	b.AppendBytes('a', 'b', 'c', 'd')
+	b.WriteBytes('a', 'b', 'c', 'd')
 	common.Must(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)))
 
 	b = buf.New()
-	b.AppendBytes('e', 'f', 'g')
+	b.WriteBytes('e', 'f', 'g')
 	common.Must(writer.WriteMultiBuffer(buf.NewMultiBufferValue(b)))
 
 	common.Must(writer.WriteMultiBuffer(buf.MultiBuffer{}))

+ 2 - 1
proxy/http/server.go

@@ -182,7 +182,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 	}
 
 	if reader.Buffered() > 0 {
-		payload, err := buf.ReadSizeToMultiBuffer(reader, int32(reader.Buffered()))
+		var payload buf.MultiBuffer
+		_, err := payload.ReadFrom(&io.LimitedReader{R: reader, N: int64(reader.Buffered())})
 		if err != nil {
 			return err
 		}

+ 1 - 1
proxy/shadowsocks/ota_test.go

@@ -12,7 +12,7 @@ func TestNormalChunkReading(t *testing.T) {
 	assert := With(t)
 
 	buffer := buf.New()
-	buffer.AppendBytes(
+	buffer.WriteBytes(
 		0, 8, 39, 228, 69, 96, 133, 39, 254, 26, 201, 70, 11, 12, 13, 14, 15, 16, 17, 18)
 	reader := NewChunkReader(buffer, NewAuthenticator(ChunkKeyGenerator(
 		[]byte{21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36})))

+ 7 - 7
proxy/socks/protocol.go

@@ -241,7 +241,7 @@ func writeSocks5Response(writer io.Writer, errCode byte, address net.Address, po
 	buffer := buf.New()
 	defer buffer.Release()
 
-	common.Must2(buffer.AppendBytes(socks5Version, errCode, 0x00 /* reserved */))
+	common.Must2(buffer.WriteBytes(socks5Version, errCode, 0x00 /* reserved */))
 	if err := addrParser.WriteAddressPort(buffer, address, port); err != nil {
 		return err
 	}
@@ -253,7 +253,7 @@ func writeSocks4Response(writer io.Writer, errCode byte, address net.Address, po
 	buffer := buf.New()
 	defer buffer.Release()
 
-	common.Must2(buffer.AppendBytes(0x00, errCode))
+	common.Must2(buffer.WriteBytes(0x00, errCode))
 	common.Must(buffer.AppendSupplier(serial.WriteUint16(port.Value())))
 	common.Must2(buffer.Write(address.IP()))
 	return buf.WriteAllBytes(writer, buffer.Bytes())
@@ -286,7 +286,7 @@ func DecodeUDPPacket(packet *buf.Buffer) (*protocol.RequestHeader, error) {
 
 func EncodeUDPPacket(request *protocol.RequestHeader, data []byte) (*buf.Buffer, error) {
 	b := buf.New()
-	common.Must2(b.AppendBytes(0, 0, 0 /* Fragment */))
+	common.Must2(b.WriteBytes(0, 0, 0 /* Fragment */))
 	if err := addrParser.WriteAddressPort(b, request.Address, request.Port); err != nil {
 		b.Release()
 		return nil, err
@@ -348,7 +348,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 	b := buf.New()
 	defer b.Release()
 
-	common.Must2(b.AppendBytes(socks5Version, 0x01, authByte))
+	common.Must2(b.WriteBytes(socks5Version, 0x01, authByte))
 	if authByte == authPassword {
 		rawAccount, err := request.User.GetTypedAccount()
 		if err != nil {
@@ -356,9 +356,9 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 		}
 		account := rawAccount.(*Account)
 
-		common.Must2(b.AppendBytes(0x01, byte(len(account.Username))))
+		common.Must2(b.WriteBytes(0x01, byte(len(account.Username))))
 		common.Must2(b.Write([]byte(account.Username)))
-		common.Must2(b.AppendBytes(byte(len(account.Password))))
+		common.Must2(b.WriteBytes(byte(len(account.Password))))
 		common.Must2(b.Write([]byte(account.Password)))
 	}
 
@@ -392,7 +392,7 @@ func ClientHandshake(request *protocol.RequestHeader, reader io.Reader, writer i
 	if request.Command == protocol.RequestCommandUDP {
 		command = byte(cmdUDPPort)
 	}
-	common.Must2(b.AppendBytes(socks5Version, command, 0x00 /* reserved */))
+	common.Must2(b.WriteBytes(socks5Version, command, 0x00 /* reserved */))
 	if err := addrParser.WriteAddressPort(b, request.Address, request.Port); err != nil {
 		return nil, err
 	}

+ 5 - 5
proxy/vmess/encoding/client.go

@@ -69,14 +69,14 @@ func (c *ClientSession) EncodeRequestHeader(header *protocol.RequestHeader, writ
 	buffer := buf.New()
 	defer buffer.Release()
 
-	buffer.AppendBytes(Version)
-	buffer.Write(c.requestBodyIV[:])
-	buffer.Write(c.requestBodyKey[:])
-	buffer.AppendBytes(c.responseHeader, byte(header.Option))
+	common.Must2(buffer.WriteBytes(Version))
+	common.Must2(buffer.Write(c.requestBodyIV[:]))
+	common.Must2(buffer.Write(c.requestBodyKey[:]))
+	common.Must2(buffer.WriteBytes(c.responseHeader, byte(header.Option)))
 
 	padingLen := dice.Roll(16)
 	security := byte(padingLen<<4) | byte(header.Security)
-	buffer.AppendBytes(security, byte(0), byte(header.Command))
+	common.Must2(buffer.WriteBytes(security, byte(0), byte(header.Command)))
 
 	if header.Command != protocol.RequestCommandMux {
 		if err := addrParser.WriteAddressPort(buffer, header.Address, header.Port); err != nil {

+ 1 - 1
transport/internet/kcp/segment_test.go

@@ -50,7 +50,7 @@ func Test1ByteDataSegment(t *testing.T) {
 		Number:      4,
 		SendingNext: 5,
 	}
-	seg.Data().AppendBytes('a')
+	seg.Data().WriteBytes('a')
 
 	nBytes := seg.ByteSize()
 	bytes := make([]byte, nBytes)

+ 1 - 1
transport/internet/udp/dispatcher_test.go

@@ -58,7 +58,7 @@ func TestSameDestinationDispatching(t *testing.T) {
 	dest := net.UDPDestination(net.LocalHostIP, 53)
 
 	b := buf.New()
-	b.AppendBytes('a', 'b', 'c', 'd')
+	b.WriteBytes('a', 'b', 'c', 'd')
 
 	var msgCount uint32
 	dispatcher := NewDispatcher(td, func(ctx context.Context, payload *buf.Buffer) {

+ 1 - 1
transport/pipe/pipe_test.go

@@ -103,7 +103,7 @@ func TestPipeWriteMultiThread(t *testing.T) {
 		wg.Add(1)
 		go func() {
 			b := buf.New()
-			b.AppendBytes('a', 'b', 'c', 'd')
+			b.WriteBytes('a', 'b', 'c', 'd')
 			pWriter.WriteMultiBuffer(buf.NewMultiBufferValue(b))
 			wg.Done()
 		}()