浏览代码

fix passive connection in mux. fix #1167

Darien Raymond 7 年之前
父节点
当前提交
b05597df3a

+ 11 - 1
app/dispatcher/default.go

@@ -28,7 +28,7 @@ type cachedReader struct {
 }
 
 func (r *cachedReader) Cache(b *buf.Buffer) {
-	mb, _ := r.reader.ReadMultiBufferWithTimeout(time.Millisecond * 100)
+	mb, _ := r.reader.ReadMultiBufferTimeout(time.Millisecond * 100)
 	if !mb.IsEmpty() {
 		common.Must(r.cache.WriteMultiBuffer(mb))
 	}
@@ -47,6 +47,16 @@ func (r *cachedReader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	return r.reader.ReadMultiBuffer()
 }
 
+func (r *cachedReader) ReadMultiBufferTimeout(timeout time.Duration) (buf.MultiBuffer, error) {
+	if !r.cache.IsEmpty() {
+		mb := r.cache
+		r.cache = nil
+		return mb, nil
+	}
+
+	return r.reader.ReadMultiBufferTimeout(timeout)
+}
+
 func (r *cachedReader) CloseError() {
 	r.cache.Release()
 	r.reader.CloseError()

+ 10 - 12
app/proxyman/mux/mux.go

@@ -147,17 +147,17 @@ func (m *Client) monitor() {
 	}
 }
 
-func copyFirstPayload(reader *pipe.Reader, writer *Writer) error {
-	data, err := reader.ReadMultiBufferWithTimeout(time.Millisecond * 200)
-	if err == buf.ErrReadTimeout {
-		return writer.writeMetaOnly()
+func writeFirstPayload(reader buf.Reader, writer *Writer) error {
+	err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*200)
+	if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout {
+		return writer.WriteMultiBuffer(buf.MultiBuffer{})
 	}
 
 	if err != nil {
 		return err
 	}
 
-	return writer.WriteMultiBuffer(data)
+	return nil
 }
 
 func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
@@ -172,13 +172,11 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
 	defer writer.Close() // nolint: errcheck
 
 	newError("dispatching request to ", dest).WriteToLog(session.ExportIDToError(ctx))
-	if pReader, ok := s.input.(*pipe.Reader); ok {
-		if err := copyFirstPayload(pReader, writer); err != nil {
-			newError("failed to fetch first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
-			writer.hasError = true
-			pipe.CloseError(s.input)
-			return
-		}
+	if err := writeFirstPayload(s.input, writer); err != nil {
+		newError("failed to write first payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
+		writer.hasError = true
+		pipe.CloseError(s.input)
+		return
 	}
 
 	if err := buf.Copy(s.input, writer); err != nil {

+ 15 - 0
common/buf/copy.go

@@ -2,6 +2,7 @@ package buf
 
 import (
 	"io"
+	"time"
 
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/signal"
@@ -112,3 +113,17 @@ func Copy(reader Reader, writer Writer, options ...CopyOption) error {
 	}
 	return nil
 }
+
+var ErrNotTimeoutReader = newError("not a TimeoutReader")
+
+func CopyOnceTimeout(reader Reader, writer Writer, timeout time.Duration) error {
+	timeoutReader, ok := reader.(TimeoutReader)
+	if !ok {
+		return ErrNotTimeoutReader
+	}
+	mb, err := timeoutReader.ReadMultiBufferTimeout(timeout)
+	if err != nil {
+		return err
+	}
+	return writer.WriteMultiBuffer(mb)
+}

+ 1 - 1
common/buf/io.go

@@ -16,7 +16,7 @@ var ErrReadTimeout = newError("IO timeout")
 
 // TimeoutReader is a reader that returns error if Read() operation takes longer than the given timeout.
 type TimeoutReader interface {
-	ReadTimeout(time.Duration) (MultiBuffer, error)
+	ReadMultiBufferTimeout(time.Duration) (MultiBuffer, error)
 }
 
 // Writer extends io.Writer with MultiBuffer.

+ 2 - 12
proxy/vmess/outbound/outbound.go

@@ -9,8 +9,6 @@ import (
 	"v2ray.com/core/common/session"
 	"v2ray.com/core/common/task"
 
-	"v2ray.com/core/transport/pipe"
-
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
@@ -118,16 +116,8 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia
 		}
 
 		bodyWriter := session.EncodeRequestBody(request, writer)
-		if tReader, ok := input.(*pipe.Reader); ok {
-			firstPayload, err := tReader.ReadMultiBufferWithTimeout(time.Millisecond * 500)
-			if err != nil && err != buf.ErrReadTimeout {
-				return newError("failed to get first payload").Base(err)
-			}
-			if !firstPayload.IsEmpty() {
-				if err := bodyWriter.WriteMultiBuffer(firstPayload); err != nil {
-					return newError("failed to write first payload").Base(err)
-				}
-			}
+		if err := buf.CopyOnceTimeout(input, bodyWriter, time.Millisecond*500); err != nil && err != buf.ErrNotTimeoutReader && err != buf.ErrReadTimeout {
+			return newError("failed to write first payload").Base(err)
 		}
 
 		if err := writer.SetBuffered(false); err != nil {

+ 1 - 1
transport/pipe/impl.go

@@ -81,7 +81,7 @@ func (p *pipe) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	}
 }
 
-func (p *pipe) ReadMultiBufferWithTimeout(d time.Duration) (buf.MultiBuffer, error) {
+func (p *pipe) ReadMultiBufferTimeout(d time.Duration) (buf.MultiBuffer, error) {
 	timer := time.After(d)
 	for {
 		data, err := p.readMultiBufferInternal()

+ 7 - 0
transport/pipe/pipe_test.go

@@ -118,3 +118,10 @@ func TestPipeWriteMultiThread(t *testing.T) {
 	assert(err, IsNil)
 	assert(b[0].Bytes(), Equals, []byte{'a', 'b', 'c', 'd'})
 }
+
+func TestInterfaces(t *testing.T) {
+	assert := With(t)
+
+	assert((*Reader)(nil), Implements, (*buf.Reader)(nil))
+	assert((*Reader)(nil), Implements, (*buf.TimeoutReader)(nil))
+}

+ 3 - 3
transport/pipe/reader.go

@@ -16,9 +16,9 @@ func (r *Reader) ReadMultiBuffer() (buf.MultiBuffer, error) {
 	return r.pipe.ReadMultiBuffer()
 }
 
-// ReadMultiBufferWithTimeout reads content from a pipe within the given duration, or returns buf.ErrTimeout otherwise.
-func (r *Reader) ReadMultiBufferWithTimeout(d time.Duration) (buf.MultiBuffer, error) {
-	return r.pipe.ReadMultiBufferWithTimeout(d)
+// ReadMultiBufferTimeout reads content from a pipe within the given duration, or returns buf.ErrTimeout otherwise.
+func (r *Reader) ReadMultiBufferTimeout(d time.Duration) (buf.MultiBuffer, error) {
+	return r.pipe.ReadMultiBufferTimeout(d)
 }
 
 // CloseError sets the pipe to error state. Both reading and writing from/to the pipe will return io.ErrClosedPipe.