浏览代码

refine stream handling

Darien Raymond 8 年之前
父节点
当前提交
49210d8362

+ 54 - 12
app/dispatcher/impl/default.go

@@ -1,6 +1,8 @@
 package impl
 
 import (
+	"time"
+
 	"v2ray.com/core/app"
 	"v2ray.com/core/app/dispatcher"
 	"v2ray.com/core/app/proxyman"
@@ -48,7 +50,6 @@ func (v *DefaultDispatcher) Release() {
 }
 
 func (v *DefaultDispatcher) DispatchToOutbound(session *proxy.SessionInfo) ray.InboundRay {
-	direct := ray.NewRay()
 	dispatcher := v.ohm.GetDefaultHandler()
 	destination := session.Destination
 
@@ -65,26 +66,32 @@ func (v *DefaultDispatcher) DispatchToOutbound(session *proxy.SessionInfo) ray.I
 		}
 	}
 
+	direct := ray.NewRay()
+	var waitFunc func() error
 	if session.Inbound != nil && session.Inbound.AllowPassiveConnection {
-		go dispatcher.Dispatch(destination, buf.NewLocal(32), direct)
+		waitFunc = noOpWait()
 	} else {
-		go v.FilterPacketAndDispatch(destination, direct, dispatcher)
+		wdi := &waitDataInspector{
+			hasData: make(chan bool, 1),
+		}
+		direct.AddInspector(wdi)
+		waitFunc = waitForData(wdi)
 	}
 
+	go v.waitAndDispatch(waitFunc, destination, direct, dispatcher)
+
 	return direct
 }
 
-// FilterPacketAndDispatch waits for a payload from source and starts dispatching.
-// Private: Visible for testing.
-func (v *DefaultDispatcher) FilterPacketAndDispatch(destination v2net.Destination, link ray.OutboundRay, dispatcher proxy.OutboundHandler) {
-	payload, err := link.OutboundInput().Read()
-	if err != nil {
-		log.Info("DefaultDispatcher: No payload towards ", destination, ", stopping now.")
-		link.OutboundInput().Release()
-		link.OutboundOutput().Release()
+func (v *DefaultDispatcher) waitAndDispatch(wait func() error, destination v2net.Destination, link ray.OutboundRay, dispatcher proxy.OutboundHandler) {
+	if err := wait(); err != nil {
+		log.Info("DefaultDispatcher: Failed precondition: ", err)
+		link.OutboundInput().ForceClose()
+		link.OutboundOutput().Close()
 		return
 	}
-	dispatcher.Dispatch(destination, payload, link)
+
+	dispatcher.Dispatch(destination, link)
 }
 
 type DefaultDispatcherFactory struct{}
@@ -100,3 +107,38 @@ func (v DefaultDispatcherFactory) AppId() app.ID {
 func init() {
 	app.RegisterApplicationFactory(serial.GetMessageType(new(dispatcher.Config)), DefaultDispatcherFactory{})
 }
+
+type waitDataInspector struct {
+	hasData chan bool
+}
+
+func (wdi *waitDataInspector) Input(*buf.Buffer) {
+	select {
+	case wdi.hasData <- true:
+	default:
+	}
+}
+
+func (wdi *waitDataInspector) WaitForData() bool {
+	select {
+	case <-wdi.hasData:
+		return true
+	case <-time.After(time.Minute):
+		return false
+	}
+}
+
+func waitForData(wdi *waitDataInspector) func() error {
+	return func() error {
+		if wdi.WaitForData() {
+			return nil
+		}
+		return errors.New("DefaultDispatcher: No data.")
+	}
+}
+
+func noOpWait() func() error {
+	return func() error {
+		return nil
+	}
+}

+ 1 - 1
app/proxy/proxy.go

@@ -50,7 +50,7 @@ func (v *OutboundProxy) Dial(src v2net.Address, dest v2net.Destination, options
 	}
 	log.Info("Proxy: Dialing to ", dest)
 	stream := ray.NewRay()
-	go handler.Dispatch(dest, nil, stream)
+	go handler.Dispatch(dest, stream)
 	return NewConnection(src, dest, stream), nil
 }
 

+ 1 - 4
proxy/blackhole/blackhole.go

@@ -3,7 +3,6 @@ package blackhole
 
 import (
 	"v2ray.com/core/app"
-	"v2ray.com/core/common/buf"
 	v2net "v2ray.com/core/common/net"
 	"v2ray.com/core/proxy"
 	"v2ray.com/core/transport/ray"
@@ -28,9 +27,7 @@ func New(space app.Space, config *Config, meta *proxy.OutboundHandlerMeta) (prox
 }
 
 // Dispatch implements OutboundHandler.Dispatch().
-func (v *Handler) Dispatch(destination v2net.Destination, payload *buf.Buffer, ray ray.OutboundRay) {
-	payload.Release()
-
+func (v *Handler) Dispatch(destination v2net.Destination, ray ray.OutboundRay) {
 	v.response.WriteTo(ray.OutboundOutput())
 	ray.OutboundOutput().Close()
 

+ 1 - 9
proxy/freedom/freedom.go

@@ -67,10 +67,9 @@ func (v *Handler) ResolveIP(destination v2net.Destination) v2net.Destination {
 	return newDest
 }
 
-func (v *Handler) Dispatch(destination v2net.Destination, payload *buf.Buffer, ray ray.OutboundRay) {
+func (v *Handler) Dispatch(destination v2net.Destination, ray ray.OutboundRay) {
 	log.Info("Freedom: Opening connection to ", destination)
 
-	defer payload.Release()
 	input := ray.OutboundInput()
 	output := ray.OutboundOutput()
 	defer input.ForceClose()
@@ -96,13 +95,6 @@ func (v *Handler) Dispatch(destination v2net.Destination, payload *buf.Buffer, r
 
 	conn.SetReusable(false)
 
-	if !payload.IsEmpty() {
-		if _, err := conn.Write(payload.Bytes()); err != nil {
-			log.Warning("Freedom: Failed to write to destination: ", destination, ": ", err)
-			return
-		}
-	}
-
 	requestDone := signal.ExecuteAsync(func() error {
 		defer input.ForceClose()
 

+ 2 - 1
proxy/freedom/freedom_test.go

@@ -53,9 +53,10 @@ func TestSinglePacket(t *testing.T) {
 	data2Send := "Data to be sent to remote"
 	payload := buf.NewLocal(2048)
 	payload.Append([]byte(data2Send))
+	traffic.InboundInput().Write(payload)
 
 	fmt.Println(tcpServerAddr.Network, tcpServerAddr.Address, tcpServerAddr.Port)
-	go freedom.Dispatch(tcpServerAddr, payload, traffic)
+	go freedom.Dispatch(tcpServerAddr, traffic)
 	traffic.InboundInput().Close()
 
 	respPayload, err := traffic.InboundOutput().Read()

+ 1 - 2
proxy/proxy.go

@@ -2,7 +2,6 @@
 package proxy
 
 import (
-	"v2ray.com/core/common/buf"
 	v2net "v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/transport/internet"
@@ -58,5 +57,5 @@ type InboundHandler interface {
 // An OutboundHandler handles outbound network connection for V2Ray.
 type OutboundHandler interface {
 	// Dispatch sends one or more Packets to its destination.
-	Dispatch(destination v2net.Destination, payload *buf.Buffer, ray ray.OutboundRay)
+	Dispatch(destination v2net.Destination, ray ray.OutboundRay)
 }

+ 1 - 16
proxy/shadowsocks/client.go

@@ -35,9 +35,7 @@ func NewClient(config *ClientConfig, space app.Space, meta *proxy.OutboundHandle
 }
 
 // Dispatch implements OutboundHandler.Dispatch().
-func (v *Client) Dispatch(destination v2net.Destination, payload *buf.Buffer, ray ray.OutboundRay) {
-	defer payload.Release()
-
+func (v *Client) Dispatch(destination v2net.Destination, ray ray.OutboundRay) {
 	network := destination.Network
 
 	var server *protocol.ServerSpec
@@ -99,13 +97,6 @@ func (v *Client) Dispatch(destination v2net.Destination, payload *buf.Buffer, ra
 			return
 		}
 
-		if !payload.IsEmpty() {
-			if err := bodyWriter.Write(payload); err != nil {
-				log.Info("Shadowsocks|Client: Failed to write payload: ", err)
-				return
-			}
-		}
-
 		bufferedWriter.SetBuffered(false)
 
 		requestDone := signal.ExecuteAsync(func() error {
@@ -143,12 +134,6 @@ func (v *Client) Dispatch(destination v2net.Destination, payload *buf.Buffer, ra
 			Writer:  conn,
 			Request: request,
 		}
-		if !payload.IsEmpty() {
-			if err := writer.Write(payload); err != nil {
-				log.Info("Shadowsocks|Client: Failed to write payload: ", err)
-				return
-			}
-		}
 
 		requestDone := signal.ExecuteAsync(func() error {
 			defer ray.OutboundInput().ForceClose()

+ 17 - 9
proxy/vmess/outbound/outbound.go

@@ -1,10 +1,13 @@
 package outbound
 
 import (
+	"time"
+
 	"v2ray.com/core/app"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/bufio"
+	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/log"
 	v2net "v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
@@ -26,10 +29,9 @@ type VMessOutboundHandler struct {
 }
 
 // Dispatch implements OutboundHandler.Dispatch().
-func (v *VMessOutboundHandler) Dispatch(target v2net.Destination, payload *buf.Buffer, ray ray.OutboundRay) {
-	defer payload.Release()
-	defer ray.OutboundInput().ForceClose()
-	defer ray.OutboundOutput().Close()
+func (v *VMessOutboundHandler) Dispatch(target v2net.Destination, outboundRay ray.OutboundRay) {
+	defer outboundRay.OutboundInput().ForceClose()
+	defer outboundRay.OutboundOutput().Close()
 
 	var rec *protocol.ServerSpec
 	var conn internet.Connection
@@ -77,8 +79,8 @@ func (v *VMessOutboundHandler) Dispatch(target v2net.Destination, payload *buf.B
 		request.Option.Set(protocol.RequestOptionConnectionReuse)
 	}
 
-	input := ray.OutboundInput()
-	output := ray.OutboundOutput()
+	input := outboundRay.OutboundInput()
+	output := outboundRay.OutboundOutput()
 
 	session := encoding.NewClientSession(protocol.DefaultIDHash)
 
@@ -93,11 +95,17 @@ func (v *VMessOutboundHandler) Dispatch(target v2net.Destination, payload *buf.B
 		bodyWriter := session.EncodeRequestBody(request, writer)
 		defer bodyWriter.Release()
 
-		if !payload.IsEmpty() {
-			if err := bodyWriter.Write(payload); err != nil {
-				return err
+		firstPayload, err := input.ReadTimeout(time.Millisecond * 500)
+		if err != nil && err != ray.ErrReadTimeout {
+			return errors.Base(err).Message("VMess|Outbound: Failed to get first payload.")
+		}
+		if !firstPayload.IsEmpty() {
+			if err := bodyWriter.Write(firstPayload); err != nil {
+				return errors.Base(err).Message("VMess|Outbound: Failed to write first payload.")
 			}
+			firstPayload.Release()
 		}
+
 		writer.SetBuffered(false)
 
 		if err := buf.PipeUntilEOF(input, bodyWriter); err != nil {

+ 36 - 0
transport/ray/direct.go

@@ -1,8 +1,11 @@
 package ray
 
 import (
+	"errors"
 	"io"
 
+	"time"
+
 	"v2ray.com/core/common/buf"
 )
 
@@ -10,6 +13,8 @@ const (
 	bufferSize = 512
 )
 
+var ErrReadTimeout = errors.New("Ray: timeout.")
+
 // NewRay creates a new Ray for direct traffic transport.
 func NewRay() Ray {
 	return &directRay{
@@ -39,10 +44,19 @@ func (v *directRay) InboundOutput() InputStream {
 	return v.Output
 }
 
+func (v *directRay) AddInspector(inspector Inspector) {
+	if inspector == nil {
+		return
+	}
+	v.Input.inspector.AddInspector(inspector)
+	v.Output.inspector.AddInspector(inspector)
+}
+
 type Stream struct {
 	buffer    chan *buf.Buffer
 	srcClose  chan bool
 	destClose chan bool
+	inspector *InspectorChain
 }
 
 func NewStream() *Stream {
@@ -50,6 +64,7 @@ func NewStream() *Stream {
 		buffer:    make(chan *buf.Buffer, bufferSize),
 		srcClose:  make(chan bool),
 		destClose: make(chan bool),
+		inspector: &InspectorChain{},
 	}
 }
 
@@ -71,6 +86,26 @@ func (v *Stream) Read() (*buf.Buffer, error) {
 	}
 }
 
+func (v *Stream) ReadTimeout(timeout time.Duration) (*buf.Buffer, error) {
+	select {
+	case <-v.destClose:
+		return nil, io.ErrClosedPipe
+	case b := <-v.buffer:
+		return b, nil
+	default:
+		select {
+		case b := <-v.buffer:
+			return b, nil
+		case <-v.srcClose:
+			return nil, io.EOF
+		case <-v.destClose:
+			return nil, io.ErrClosedPipe
+		case <-time.After(timeout):
+			return nil, ErrReadTimeout
+		}
+	}
+}
+
 func (v *Stream) Write(data *buf.Buffer) (err error) {
 	if data.IsEmpty() {
 		return
@@ -88,6 +123,7 @@ func (v *Stream) Write(data *buf.Buffer) (err error) {
 		case <-v.srcClose:
 			return io.ErrClosedPipe
 		case v.buffer <- data:
+			v.inspector.Input(data)
 			return nil
 		}
 	}

+ 36 - 0
transport/ray/inspector.go

@@ -0,0 +1,36 @@
+package ray
+
+import (
+	"sync"
+
+	"v2ray.com/core/common/buf"
+)
+
+type Inspector interface {
+	Input(*buf.Buffer)
+}
+
+type NoOpInspector struct{}
+
+func (NoOpInspector) Input(*buf.Buffer) {}
+
+type InspectorChain struct {
+	sync.RWMutex
+	chain []Inspector
+}
+
+func (ic *InspectorChain) AddInspector(inspector Inspector) {
+	ic.Lock()
+	defer ic.Unlock()
+
+	ic.chain = append(ic.chain, inspector)
+}
+
+func (ic *InspectorChain) Input(b *buf.Buffer) {
+	ic.RLock()
+	defer ic.RUnlock()
+
+	for _, inspector := range ic.chain {
+		inspector.Input(b)
+	}
+}

+ 3 - 0
transport/ray/ray.go

@@ -1,6 +1,7 @@
 package ray
 
 import "v2ray.com/core/common/buf"
+import "time"
 
 // OutboundRay is a transport interface for outbound connections.
 type OutboundRay interface {
@@ -31,10 +32,12 @@ type InboundRay interface {
 type Ray interface {
 	InboundRay
 	OutboundRay
+	AddInspector(Inspector)
 }
 
 type InputStream interface {
 	buf.Reader
+	ReadTimeout(time.Duration) (*buf.Buffer, error)
 	ForceClose()
 }