Browse Source

refactor multibuffer

Darien Raymond 7 years ago
parent
commit
842a089dad

+ 2 - 4
app/dispatcher/stats_test.go

@@ -32,12 +32,10 @@ func TestStatsWriter(t *testing.T) {
 		Writer:  buf.Discard,
 	}
 
-	var mb buf.MultiBuffer
-	common.Must2(mb.Write([]byte("abcd")))
+	mb := buf.MergeBytes(nil, []byte("abcd"))
 	common.Must(writer.WriteMultiBuffer(mb))
 
-	mb = buf.ReleaseMulti(mb)
-	common.Must2(mb.Write([]byte("efg")))
+	mb = buf.MergeBytes(nil, []byte("efg"))
 	common.Must(writer.WriteMultiBuffer(mb))
 
 	if c.Value() != 7 {

+ 1 - 2
app/reverse/portal.go

@@ -251,8 +251,7 @@ func (w *PortalWorker) heartbeat() error {
 
 	b, err := proto.Marshal(msg)
 	common.Must(err)
-	var mb buf.MultiBuffer
-	common.Must2(mb.Write(b))
+	mb := buf.MergeBytes(nil, b)
 	return w.writer.WriteMultiBuffer(mb)
 }
 

+ 70 - 79
common/buf/multi_buffer.go

@@ -8,21 +8,9 @@ import (
 	"v2ray.com/core/common/serial"
 )
 
-// ReadAllToMultiBuffer reads all content from the reader into a MultiBuffer, until EOF.
-func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) {
-	mb := make(MultiBuffer, 0, 128)
-
-	if _, err := mb.ReadFrom(reader); err != nil {
-		ReleaseMulti(mb)
-		return nil, err
-	}
-
-	return mb, nil
-}
-
 // ReadAllToBytes reads all content from the reader into a byte array, until EOF.
 func ReadAllToBytes(reader io.Reader) ([]byte, error) {
-	mb, err := ReadAllToMultiBuffer(reader)
+	mb, err := ReadFrom(reader)
 	if err != nil {
 		return nil, err
 	}
@@ -30,7 +18,8 @@ func ReadAllToBytes(reader io.Reader) ([]byte, error) {
 		return nil, nil
 	}
 	b := make([]byte, mb.Len())
-	common.Must2(mb.Read(b))
+	mb, _, err = SplitBytes(mb, b)
+	common.Must(err)
 	ReleaseMulti(mb)
 	return b, nil
 }
@@ -47,6 +36,23 @@ func MergeMulti(dest MultiBuffer, src MultiBuffer) (MultiBuffer, MultiBuffer) {
 	return dest, src[:0]
 }
 
+func MergeBytes(dest MultiBuffer, src []byte) MultiBuffer {
+	n := len(dest)
+	if n > 0 && !(dest)[n-1].IsFull() {
+		nBytes, _ := (dest)[n-1].Write(src)
+		src = src[nBytes:]
+	}
+
+	for len(src) > 0 {
+		b := New()
+		nBytes, _ := b.Write(src)
+		src = src[nBytes:]
+		dest = append(dest, b)
+	}
+
+	return dest
+}
+
 // ReleaseMulti release all content of the MultiBuffer, and returns an empty MultiBuffer.
 func ReleaseMulti(mb MultiBuffer) MultiBuffer {
 	for i := range mb {
@@ -69,93 +75,42 @@ func (mb MultiBuffer) Copy(b []byte) int {
 	return total
 }
 
-// ReadFrom implements io.ReaderFrom.
-func (mb *MultiBuffer) ReadFrom(reader io.Reader) (int64, error) {
-	totalBytes := int64(0)
-
+// ReadFrom reads all content from reader until EOF.
+func ReadFrom(reader io.Reader) (MultiBuffer, error) {
+	mb := make(MultiBuffer, 0, 16)
 	for {
 		b := New()
 		_, err := b.ReadFullFrom(reader, Size)
 		if b.IsEmpty() {
 			b.Release()
 		} else {
-			*mb = append(*mb, b)
+			mb = append(mb, b)
 		}
-		totalBytes += int64(b.Len())
 		if err != nil {
 			if errors.Cause(err) == io.EOF || errors.Cause(err) == io.ErrUnexpectedEOF {
-				return totalBytes, nil
+				return mb, nil
 			}
-			return totalBytes, err
+			return mb, err
 		}
 	}
 }
 
-// Read implements io.Reader.
-func (mb *MultiBuffer) Read(b []byte) (int, error) {
-	if mb.IsEmpty() {
-		return 0, io.EOF
-	}
-	endIndex := len(*mb)
+func SplitBytes(mb MultiBuffer, b []byte) (MultiBuffer, int, error) {
 	totalBytes := 0
-	for i, bb := range *mb {
+
+	for len(mb) > 0 {
+		bb := mb[0]
 		nBytes, _ := bb.Read(b)
 		totalBytes += nBytes
 		b = b[nBytes:]
-		if bb.IsEmpty() {
-			bb.Release()
-			(*mb)[i] = nil
-		} else {
-			endIndex = i
+		if !bb.IsEmpty() {
 			break
 		}
-	}
-	*mb = (*mb)[endIndex:]
-	return totalBytes, nil
-}
-
-// WriteTo implements io.WriterTo.
-func (mb *MultiBuffer) WriteTo(writer io.Writer) (int64, error) {
-	defer func() {
-		*mb = ReleaseMulti(*mb)
-	}()
-
-	totalBytes := int64(0)
-	for _, b := range *mb {
-		nBytes, err := writer.Write(b.Bytes())
-		totalBytes += int64(nBytes)
-		if err != nil {
-			return totalBytes, err
-		}
-	}
-
-	return totalBytes, nil
-}
-
-// Write implements io.Writer.
-func (mb *MultiBuffer) Write(b []byte) (int, error) {
-	totalBytes := len(b)
-
-	n := len(*mb)
-	if n > 0 && !(*mb)[n-1].IsFull() {
-		nBytes, _ := (*mb)[n-1].Write(b)
-		b = b[nBytes:]
-	}
-
-	for len(b) > 0 {
-		bb := New()
-		nBytes, _ := bb.Write(b)
-		b = b[nBytes:]
-		*mb = append(*mb, bb)
+		bb.Release()
+		mb = mb[1:]
 	}
 
-	return totalBytes, nil
-}
-
-// WriteMultiBuffer implements Writer.
-func (mb *MultiBuffer) WriteMultiBuffer(b MultiBuffer) error {
-	*mb, _ = MergeMulti(*mb, b)
-	return nil
+	return mb, totalBytes, nil
 }
 
 // Len returns the total number of bytes in the MultiBuffer.
@@ -223,3 +178,39 @@ func (mb *MultiBuffer) SplitFirst() *Buffer {
 	*mb = (*mb)[1:]
 	return b
 }
+
+type MultiBufferContainer struct {
+	MultiBuffer
+}
+
+func (c *MultiBufferContainer) Read(b []byte) (int, error) {
+	if c.MultiBuffer.IsEmpty() {
+		return 0, io.EOF
+	}
+
+	mb, nBytes, err := SplitBytes(c.MultiBuffer, b)
+	c.MultiBuffer = mb
+	return nBytes, err
+}
+
+func (c *MultiBufferContainer) ReadMultiBuffer() (MultiBuffer, error) {
+	mb := c.MultiBuffer
+	c.MultiBuffer = nil
+	return mb, nil
+}
+
+func (c *MultiBufferContainer) Write(b []byte) (int, error) {
+	c.MultiBuffer = MergeBytes(c.MultiBuffer, b)
+	return len(b), nil
+}
+
+func (c *MultiBufferContainer) WriteMultiBuffer(b MultiBuffer) error {
+	mb, _ := MergeMulti(c.MultiBuffer, b)
+	c.MultiBuffer = mb
+	return nil
+}
+
+func (c *MultiBufferContainer) Close() error {
+	c.MultiBuffer = ReleaseMulti(c.MultiBuffer)
+	return nil
+}

+ 2 - 10
common/buf/multi_buffer_test.go

@@ -21,7 +21,7 @@ func TestMultiBufferRead(t *testing.T) {
 	mb := MultiBuffer{b1, b2}
 
 	bs := make([]byte, 32)
-	nBytes, err := mb.Read(bs)
+	_, nBytes, err := SplitBytes(mb, bs)
 	assert(err, IsNil)
 	assert(nBytes, Equals, 4)
 	assert(bs[:nBytes], Equals, []byte("abcd"))
@@ -43,16 +43,8 @@ func TestMultiBufferSliceBySizeLarge(t *testing.T) {
 	lb := make([]byte, 8*1024)
 	common.Must2(io.ReadFull(rand.Reader, lb))
 
-	var mb MultiBuffer
-	common.Must2(mb.Write(lb))
+	mb := MergeBytes(nil, lb)
 
 	mb2 := mb.SliceBySize(1024)
 	assert(mb2.Len(), Equals, int32(1024))
 }
-
-func TestInterface(t *testing.T) {
-	assert := With(t)
-
-	assert((*MultiBuffer)(nil), Implements, (*io.WriterTo)(nil))
-	assert((*MultiBuffer)(nil), Implements, (*io.ReaderFrom)(nil))
-}

+ 4 - 3
common/buf/reader.go

@@ -46,8 +46,9 @@ func (r *BufferedReader) ReadByte() (byte, error) {
 // Read implements io.Reader. It reads from internal buffer first (if available) and then reads from the underlying reader.
 func (r *BufferedReader) Read(b []byte) (int, error) {
 	if !r.Buffer.IsEmpty() {
-		nBytes, err := r.Buffer.Read(b)
+		buffer, nBytes, err := SplitBytes(r.Buffer, b)
 		common.Must(err)
+		r.Buffer = buffer
 		if r.Buffer.IsEmpty() {
 			r.Buffer = nil
 		}
@@ -59,12 +60,12 @@ func (r *BufferedReader) Read(b []byte) (int, error) {
 		return 0, err
 	}
 
-	nBytes, err := mb.Read(b)
+	mb, nBytes, err := SplitBytes(mb, b)
 	common.Must(err)
 	if !mb.IsEmpty() {
 		r.Buffer = mb
 	}
-	return nBytes, err
+	return nBytes, nil
 }
 
 // ReadMultiBuffer implements Reader.

+ 1 - 2
common/buf/reader_test.go

@@ -69,8 +69,7 @@ func TestReadByte(t *testing.T) {
 		t.Error("unexpected byte: ", b, " want a")
 	}
 
-	var mb MultiBuffer
-	nBytes, err := reader.WriteTo(&mb)
+	nBytes, err := reader.WriteTo(DiscardBytes)
 	common.Must(err)
 	if nBytes != 3 {
 		t.Error("unexpect bytes written: ", nBytes)

+ 3 - 3
common/buf/readv_test.go

@@ -33,8 +33,7 @@ func TestReadvReader(t *testing.T) {
 
 	go func() {
 		writer := NewWriter(conn)
-		var mb MultiBuffer
-		common.Must2(mb.Write(data))
+		mb := MergeBytes(nil, data)
 
 		if err := writer.WriteMultiBuffer(mb); err != nil {
 			t.Fatal("failed to write data: ", err)
@@ -58,7 +57,8 @@ func TestReadvReader(t *testing.T) {
 	}
 
 	rdata := make([]byte, size)
-	common.Must2(rmb.Read(rdata))
+	_, _, err = SplitBytes(rmb, rdata)
+	common.Must(err)
 
 	if err := compare.BytesEqualWithDetail(data, rdata); err != nil {
 		t.Fatal(err)

+ 6 - 5
common/buf/writer.go

@@ -134,15 +134,16 @@ func (w *BufferedWriter) WriteMultiBuffer(b MultiBuffer) error {
 		return w.writer.WriteMultiBuffer(b)
 	}
 
-	defer ReleaseMulti(b)
+	reader := MultiBufferContainer{
+		MultiBuffer: b,
+	}
+	defer reader.Close()
 
-	for !b.IsEmpty() {
+	for !reader.MultiBuffer.IsEmpty() {
 		if w.buffer == nil {
 			w.buffer = New()
 		}
-		if _, err := w.buffer.ReadFrom(&b); err != nil {
-			return err
-		}
+		common.Must2(w.buffer.ReadFrom(&reader))
 		if w.buffer.IsFull() {
 			if err := w.flushInternal(); err != nil {
 				return err

+ 11 - 5
common/crypto/auth.go

@@ -194,7 +194,7 @@ func (r *AuthenticationReader) readInternal(soft bool, mb *buf.MultiBuffer) erro
 		return err
 	}
 
-	common.Must2(mb.Write(rb))
+	*mb = buf.MergeBytes(*mb, rb)
 	return nil
 }
 
@@ -279,11 +279,17 @@ func (w *AuthenticationWriter) writeStream(mb buf.MultiBuffer) error {
 	payloadSize := buf.Size - int32(w.auth.Overhead()) - w.sizeParser.SizeBytes() - maxPadding
 	mb2Write := make(buf.MultiBuffer, 0, len(mb)+10)
 
+	temp := buf.New()
+	defer temp.Release()
+
+	rawBytes := temp.Extend(payloadSize)
+
 	for {
-		b := buf.New()
-		common.Must2(b.ReadFrom(io.LimitReader(&mb, int64(payloadSize))))
-		eb, err := w.seal(b.Bytes())
-		b.Release()
+		nb, nBytes, err := buf.SplitBytes(mb, rawBytes)
+		common.Must(err)
+		mb = nb
+
+		eb, err := w.seal(rawBytes[:nBytes])
 
 		if err != nil {
 			buf.ReleaseMulti(mb2Write)

+ 2 - 3
common/crypto/auth_test.go

@@ -30,8 +30,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
 	rawPayload := make([]byte, payloadSize)
 	rand.Read(rawPayload)
 
-	var payload buf.MultiBuffer
-	payload.Write(rawPayload)
+	payload := buf.MergeBytes(nil, rawPayload)
 	assert(payload.Len(), Equals, int32(payloadSize))
 
 	cache := bytes.NewBuffer(nil)
@@ -66,7 +65,7 @@ func TestAuthenticationReaderWriter(t *testing.T) {
 	assert(mb.Len(), Equals, int32(payloadSize))
 
 	mbContent := make([]byte, payloadSize)
-	mb.Read(mbContent)
+	buf.SplitBytes(mb, mbContent)
 	assert(mbContent, Equals, rawPayload)
 
 	_, err = reader.ReadMultiBuffer()

+ 1 - 1
common/net/connection.go

@@ -100,7 +100,7 @@ func (c *connection) Write(b []byte) (int, error) {
 
 	l := len(b)
 	mb := make(buf.MultiBuffer, 0, l/buf.Size+1)
-	common.Must2(mb.Write(b))
+	mb = buf.MergeBytes(mb, b)
 	return l, c.writer.WriteMultiBuffer(mb)
 }
 

+ 4 - 6
common/platform/ctlcmd/ctlcmd.go

@@ -17,8 +17,8 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) {
 		return nil, newError("v2ctl doesn't exist").Base(err)
 	}
 
-	errBuffer := buf.MultiBuffer{}
-	outBuffer := buf.MultiBuffer{}
+	var errBuffer buf.MultiBufferContainer
+	var outBuffer buf.MultiBufferContainer
 
 	cmd := exec.Command(v2ctl, args...)
 	cmd.Stderr = &errBuffer
@@ -35,12 +35,10 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) {
 	if err := cmd.Wait(); err != nil {
 		msg := "failed to execute v2ctl"
 		if errBuffer.Len() > 0 {
-			msg += ": " + errBuffer.String()
+			msg += ": " + errBuffer.MultiBuffer.String()
 		}
-		buf.ReleaseMulti(errBuffer)
-		buf.ReleaseMulti(outBuffer)
 		return nil, newError(msg).Base(err)
 	}
 
-	return outBuffer, nil
+	return outBuffer.MultiBuffer, nil
 }

+ 7 - 14
main/confloader/external/external.go

@@ -12,16 +12,6 @@ import (
 
 //go:generate errorgen
 
-type ClosableMultiBuffer struct {
-	buf.MultiBuffer
-}
-
-func (c *ClosableMultiBuffer) Close() error {
-	buf.ReleaseMulti(c.MultiBuffer)
-	c.MultiBuffer = nil
-	return nil
-}
-
 func loadConfigFile(configFile string) (io.ReadCloser, error) {
 	if configFile == "stdin:" {
 		return os.Stdin, nil
@@ -32,7 +22,9 @@ func loadConfigFile(configFile string) (io.ReadCloser, error) {
 		if err != nil {
 			return nil, err
 		}
-		return &ClosableMultiBuffer{content}, nil
+		return &buf.MultiBufferContainer{
+			MultiBuffer: content,
+		}, nil
 	}
 
 	fixedFile := os.ExpandEnv(configFile)
@@ -42,12 +34,13 @@ func loadConfigFile(configFile string) (io.ReadCloser, error) {
 	}
 	defer file.Close()
 
-	content, err := buf.ReadAllToMultiBuffer(file)
+	content, err := buf.ReadFrom(file)
 	if err != nil {
 		return nil, newError("failed to load config file: ", fixedFile).Base(err).AtWarning()
 	}
-	return &ClosableMultiBuffer{content}, nil
-
+	return &buf.MultiBufferContainer{
+		MultiBuffer: content,
+	}, nil
 }
 
 func init() {

+ 4 - 1
main/json/config_json.go

@@ -7,6 +7,7 @@ import (
 
 	"v2ray.com/core"
 	"v2ray.com/core/common"
+	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/platform/ctlcmd"
 )
 
@@ -19,7 +20,9 @@ func init() {
 			if err != nil {
 				return nil, newError("failed to execute v2ctl to convert config file.").Base(err).AtWarning()
 			}
-			return core.LoadConfig("protobuf", "", &jsonContent)
+			return core.LoadConfig("protobuf", "", &buf.MultiBufferContainer{
+				MultiBuffer: jsonContent,
+			})
 		},
 	}))
 }

+ 1 - 2
proxy/http/server.go

@@ -185,8 +185,7 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 	}
 
 	if reader.Buffered() > 0 {
-		var payload buf.MultiBuffer
-		_, err := payload.ReadFrom(&io.LimitedReader{R: reader, N: int64(reader.Buffered())})
+		payload, err := buf.ReadFrom(io.LimitReader(reader, int64(reader.Buffered())))
 		if err != nil {
 			return err
 		}

+ 2 - 3
proxy/shadowsocks/ota.go

@@ -92,8 +92,7 @@ func (v *ChunkReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 		return nil, newError("invalid auth")
 	}
 
-	var mb buf.MultiBuffer
-	common.Must2(mb.Write(payload))
+	mb := buf.MergeBytes(nil, payload)
 
 	return mb, nil
 }
@@ -117,7 +116,7 @@ func (w *ChunkWriter) WriteMultiBuffer(mb buf.MultiBuffer) error {
 	defer buf.ReleaseMulti(mb)
 
 	for {
-		payloadLen, _ := mb.Read(w.buffer[2+AuthSize:])
+		mb, payloadLen, _ := buf.SplitBytes(mb, w.buffer[2+AuthSize:])
 		binary.BigEndian.PutUint16(w.buffer, uint16(payloadLen))
 		w.auth.Authenticate(w.buffer[2+AuthSize:2+AuthSize+payloadLen], w.buffer[2:])
 		if err := buf.WriteAllBytes(w.writer, w.buffer[:2+AuthSize+payloadLen]); err != nil {

+ 24 - 12
transport/internet/kcp/connection.go

@@ -1,6 +1,7 @@
 package kcp
 
 import (
+	"bytes"
 	"io"
 	"net"
 	"runtime"
@@ -8,7 +9,6 @@ import (
 	"sync/atomic"
 	"time"
 
-	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/signal"
 	"v2ray.com/core/common/signal/semaphore"
@@ -364,12 +364,8 @@ func (c *Connection) waitForDataOutput() error {
 
 // Write implements io.Writer.
 func (c *Connection) Write(b []byte) (int, error) {
-	// This involves multiple copies of the buffer. But we don't expect this method to be used often.
-	// Only wrapped connections such as TLS and WebSocket will call into this.
-	// TODO: improve efficiency.
-	var mb buf.MultiBuffer
-	common.Must2(mb.Write(b))
-	if err := c.WriteMultiBuffer(mb); err != nil {
+	reader := bytes.NewReader(b)
+	if err := c.writeMultiBufferInternal(reader); err != nil {
 		return 0, err
 	}
 	return len(b), nil
@@ -377,8 +373,15 @@ func (c *Connection) Write(b []byte) (int, error) {
 
 // WriteMultiBuffer implements buf.Writer.
 func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
-	defer buf.ReleaseMulti(mb)
+	reader := &buf.MultiBufferContainer{
+		MultiBuffer: mb,
+	}
+	defer reader.Close()
+
+	return c.writeMultiBufferInternal(reader)
+}
 
+func (c *Connection) writeMultiBufferInternal(reader io.Reader) error {
 	updatePending := false
 	defer func() {
 		if updatePending {
@@ -386,19 +389,28 @@ func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
 		}
 	}()
 
+	var b *buf.Buffer
+	defer b.Release()
+
 	for {
 		for {
 			if c == nil || c.State() != StateActive {
 				return io.ErrClosedPipe
 			}
 
-			if !c.sendingWorker.Push(&mb) {
+			if b == nil {
+				b = buf.New()
+				_, err := b.ReadFrom(io.LimitReader(reader, int64(c.mss)))
+				if err != nil {
+					return nil
+				}
+			}
+
+			if !c.sendingWorker.Push(b) {
 				break
 			}
 			updatePending = true
-			if mb.IsEmpty() {
-				return nil
-			}
+			b = nil
 		}
 
 		if updatePending {

+ 1 - 1
transport/internet/kcp/receiving.go

@@ -209,7 +209,7 @@ func (w *ReceivingWorker) Read(b []byte) int {
 	if mb.IsEmpty() {
 		return 0
 	}
-	nBytes, err := mb.Read(b)
+	mb, nBytes, err := buf.SplitBytes(mb, b)
 	common.Must(err)
 	if !mb.IsEmpty() {
 		w.leftOver = mb

+ 1 - 5
transport/internet/kcp/sending.go

@@ -2,10 +2,8 @@ package kcp
 
 import (
 	"container/list"
-	"io"
 	"sync"
 
-	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 )
 
@@ -262,7 +260,7 @@ func (w *SendingWorker) ProcessSegment(current uint32, seg *AckSegment, rto uint
 	}
 }
 
-func (w *SendingWorker) Push(mb *buf.MultiBuffer) bool {
+func (w *SendingWorker) Push(b *buf.Buffer) bool {
 	w.Lock()
 	defer w.Unlock()
 
@@ -274,8 +272,6 @@ func (w *SendingWorker) Push(mb *buf.MultiBuffer) bool {
 		return false
 	}
 
-	b := buf.New()
-	common.Must2(b.ReadFrom(io.LimitReader(mb, int64(w.conn.mss))))
 	w.window.Push(w.nextNumber, b)
 	w.nextNumber++
 	return true