Ver Fonte

refine ctlcmd

Darien Raymond há 7 anos atrás
pai
commit
f751bb610c

+ 43 - 16
common/buf/multi_buffer.go

@@ -3,7 +3,6 @@ package buf
 import (
 	"io"
 	"net"
-	"os"
 
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/errors"
@@ -14,22 +13,12 @@ import (
 func ReadAllToMultiBuffer(reader io.Reader) (MultiBuffer, error) {
 	mb := NewMultiBufferCap(128)
 
-	for {
-		b := New()
-		err := b.Reset(ReadFrom(reader))
-		if b.IsEmpty() {
-			b.Release()
-		} else {
-			mb.Append(b)
-		}
-		if err != nil {
-			if errors.Cause(err) == io.EOF || errors.Cause(err) == os.ErrClosed {
-				return mb, nil
-			}
-			mb.Release()
-			return nil, err
-		}
+	if _, err := mb.ReadFrom(reader); err != nil {
+		mb.Release()
+		return nil, err
 	}
+
+	return mb, nil
 }
 
 // ReadSizeToMultiBuffer reads specific number of bytes from reader into a MultiBuffer.
@@ -102,6 +91,28 @@ func (mb MultiBuffer) Copy(b []byte) int {
 	return total
 }
 
+// ReadFrom implements io.ReaderFrom.
+func (mb *MultiBuffer) ReadFrom(reader io.Reader) (int64, error) {
+	totalBytes := int64(0)
+
+	for {
+		b := New()
+		err := b.Reset(ReadFrom(reader))
+		if b.IsEmpty() {
+			b.Release()
+		} else {
+			mb.Append(b)
+		}
+		totalBytes += int64(b.Len())
+		if err != nil {
+			if errors.Cause(err) == io.EOF {
+				return totalBytes, nil
+			}
+			return totalBytes, err
+		}
+	}
+}
+
 // Read implements io.Reader.
 func (mb *MultiBuffer) Read(b []byte) (int, error) {
 	if mb.Len() == 0 {
@@ -125,6 +136,22 @@ func (mb *MultiBuffer) Read(b []byte) (int, error) {
 	return totalBytes, nil
 }
 
+// WriteTo implements io.WriterTo.
+func (mb *MultiBuffer) WriteTo(writer io.Writer) (int64, error) {
+	defer mb.Release()
+
+	totalBytes := int64(0)
+	for _, b := range *mb {
+		nBytes, err := writer.Write(b.Bytes())
+		totalBytes += int64(nBytes)
+		if err != nil {
+			return totalBytes, err
+		}
+	}
+
+	return totalBytes, nil
+}
+
 // Write implements io.Writer.
 func (mb *MultiBuffer) Write(b []byte) (int, error) {
 	totalBytes := len(b)

+ 8 - 0
common/buf/multi_buffer_test.go

@@ -2,6 +2,7 @@ package buf_test
 
 import (
 	"crypto/rand"
+	"io"
 	"testing"
 
 	"v2ray.com/core/common"
@@ -48,3 +49,10 @@ func TestMultiBufferSliceBySizeLarge(t *testing.T) {
 	mb2 := mb.SliceBySize(4 * 1024)
 	assert(mb2.Len(), Equals, int32(4*1024))
 }
+
+func TestInterface(t *testing.T) {
+	assert := With(t)
+
+	assert((*MultiBuffer)(nil), Implements, (*io.WriterTo)(nil))
+	assert((*MultiBuffer)(nil), Implements, (*io.ReaderFrom)(nil))
+}

+ 12 - 33
common/platform/ctlcmd/ctlcmd.go

@@ -1,14 +1,12 @@
 package ctlcmd
 
 import (
-	"context"
 	"io"
 	"os"
 	"os/exec"
 
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/platform"
-	"v2ray.com/core/common/signal"
 )
 
 //go:generate go run $GOPATH/src/v2ray.com/core/common/errors/errorgen/main.go -pkg ctlcmd -path Command,Platform,CtlCmd
@@ -19,49 +17,30 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) {
 		return nil, newError("v2ctl doesn't exist").Base(err)
 	}
 
-	errBuffer := &buf.MultiBuffer{}
+	errBuffer := buf.MultiBuffer{}
+	outBuffer := buf.MultiBuffer{}
 
 	cmd := exec.Command(v2ctl, args...)
-	cmd.Stderr = errBuffer
+	cmd.Stderr = &errBuffer
+	cmd.Stdout = &outBuffer
 	cmd.SysProcAttr = getSysProcAttr()
 	if input != nil {
 		cmd.Stdin = input
 	}
 
-	stdoutReader, err := cmd.StdoutPipe()
-	if err != nil {
-		return nil, newError("failed to get stdout from v2ctl").Base(err)
-	}
-	defer stdoutReader.Close()
-
 	if err := cmd.Start(); err != nil {
 		return nil, newError("failed to start v2ctl").Base(err)
 	}
 
-	var content buf.MultiBuffer
-	loadTask := func() error {
-		c, err := buf.ReadAllToMultiBuffer(stdoutReader)
-		if err != nil {
-			return newError("failed to read config").Base(err)
+	if err := cmd.Wait(); err != nil {
+		msg := "failed to execute v2ctl"
+		if errBuffer.Len() > 0 {
+			msg += ": " + errBuffer.String()
 		}
-		content = c
-		return nil
-	}
-
-	waitTask := func() error {
-		if err := cmd.Wait(); err != nil {
-			msg := "failed to execute v2ctl"
-			if errBuffer.Len() > 0 {
-				msg += ": " + errBuffer.String()
-			}
-			return newError(msg).Base(err)
-		}
-		return nil
-	}
-
-	if err := signal.ExecuteParallel(context.Background(), loadTask, waitTask); err != nil {
-		return nil, err
+		errBuffer.Release()
+		outBuffer.Release()
+		return nil, newError(msg).Base(err)
 	}
 
-	return content, nil
+	return outBuffer, nil
 }