浏览代码

fix pipe closing logic for inbound proxies.

Darien Raymond 7 年之前
父节点
当前提交
7fa4bb434b

+ 2 - 2
common/functions/functions.go

@@ -5,8 +5,8 @@ import "v2ray.com/core/common"
 // Task is a function that may return an error.
 // Task is a function that may return an error.
 type Task func() error
 type Task func() error
 
 
-// CloseOnSuccess returns a Task to run a follow task if pre-condition passes, otherwise the error in pre-condition is returned.
-func CloseOnSuccess(pre func() error, followup Task) Task {
+// OnSuccess returns a Task to run a follow task if pre-condition passes, otherwise the error in pre-condition is returned.
+func OnSuccess(pre func() error, followup Task) Task {
 	return func() error {
 	return func() error {
 		if err := pre(); err != nil {
 		if err := pre(); err != nil {
 			return err
 			return err

+ 2 - 2
proxy/dokodemo/dokodemo.go

@@ -9,6 +9,7 @@ import (
 	"v2ray.com/core"
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/signal"
 	"v2ray.com/core/common/signal"
 	"v2ray.com/core/proxy"
 	"v2ray.com/core/proxy"
@@ -79,7 +80,6 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 	}
 	}
 
 
 	requestDone := func() error {
 	requestDone := func() error {
-		defer common.Close(link.Writer)
 		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
 		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
 
 
 		chunkReader := buf.NewReader(conn)
 		chunkReader := buf.NewReader(conn)
@@ -118,7 +118,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 		return nil
 		return nil
 	}
 	}
 
 
-	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)
 		return newError("connection ends").Base(err)

+ 1 - 1
proxy/freedom/freedom.go

@@ -136,7 +136,7 @@ func (h *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia
 		return nil
 		return nil
 	}
 	}
 
 
-	if err := signal.ExecuteParallel(ctx, requestDone, functions.CloseOnSuccess(responseDone, functions.Close(output))); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(output))); err != nil {
 		return newError("connection ends").Base(err)
 		return newError("connection ends").Base(err)
 	}
 	}
 
 

+ 2 - 2
proxy/http/server.go

@@ -16,6 +16,7 @@ import (
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/errors"
+	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	http_proto "v2ray.com/core/common/protocol/http"
 	http_proto "v2ray.com/core/common/protocol/http"
@@ -192,7 +193,6 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 	}
 	}
 
 
 	requestDone := func() error {
 	requestDone := func() error {
-		defer common.Close(link.Writer)
 		defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly)
 		defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly)
 
 
 		v2reader := buf.NewReader(conn)
 		v2reader := buf.NewReader(conn)
@@ -210,7 +210,7 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 		return nil
 		return nil
 	}
 	}
 
 
-	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)
 		return newError("connection ends").Base(err)

+ 1 - 1
proxy/shadowsocks/client.go

@@ -158,7 +158,7 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
 			return nil
 			return nil
 		}
 		}
 
 
-		if err := signal.ExecuteParallel(ctx, requestDone, functions.CloseOnSuccess(responseDone, functions.Close(link.Writer))); err != nil {
+		if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(link.Writer))); err != nil {
 			return newError("connection ends").Base(err)
 			return newError("connection ends").Base(err)
 		}
 		}
 
 

+ 2 - 2
proxy/shadowsocks/server.go

@@ -7,6 +7,7 @@ import (
 	"v2ray.com/core"
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/protocol"
@@ -207,7 +208,6 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 
 
 	requestDone := func() error {
 	requestDone := func() error {
 		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
 		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
-		defer common.Close(link.Writer)
 
 
 		if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
 		if err := buf.Copy(bodyReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
 			return newError("failed to transport all TCP request").Base(err)
 			return newError("failed to transport all TCP request").Base(err)
@@ -216,7 +216,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 		return nil
 		return nil
 	}
 	}
 
 
-	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)
 		return newError("connection ends").Base(err)

+ 1 - 1
proxy/socks/client.go

@@ -130,7 +130,7 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
 		}
 		}
 	}
 	}
 
 
-	if err := signal.ExecuteParallel(ctx, requestFunc, functions.CloseOnSuccess(responseFunc, functions.Close(link.Writer))); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestFunc, functions.OnSuccess(responseFunc, functions.Close(link.Writer))); err != nil {
 		return newError("connection ends").Base(err)
 		return newError("connection ends").Base(err)
 	}
 	}
 
 

+ 2 - 2
proxy/socks/server.go

@@ -8,6 +8,7 @@ import (
 	"v2ray.com/core"
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/protocol"
@@ -139,7 +140,6 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 
 
 	requestDone := func() error {
 	requestDone := func() error {
 		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
 		defer timer.SetTimeout(plcy.Timeouts.DownlinkOnly)
-		defer common.Close(link.Writer) // nolint: errcheck
 
 
 		v2reader := buf.NewReader(reader)
 		v2reader := buf.NewReader(reader)
 		if err := buf.Copy(v2reader, link.Writer, buf.UpdateActivity(timer)); err != nil {
 		if err := buf.Copy(v2reader, link.Writer, buf.UpdateActivity(timer)); err != nil {
@@ -160,7 +160,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 		return nil
 		return nil
 	}
 	}
 
 
-	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)
 		return newError("connection ends").Base(err)

+ 2 - 3
proxy/vmess/inbound/inbound.go

@@ -13,6 +13,7 @@ import (
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/errors"
+	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/protocol"
@@ -168,8 +169,6 @@ func (h *Handler) RemoveUser(ctx context.Context, email string) error {
 }
 }
 
 
 func transferRequest(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output buf.Writer) error {
 func transferRequest(timer signal.ActivityUpdater, session *encoding.ServerSession, request *protocol.RequestHeader, input io.Reader, output buf.Writer) error {
-	defer common.Close(output)
-
 	bodyReader := session.DecodeRequestBody(request, input)
 	bodyReader := session.DecodeRequestBody(request, input)
 	if err := buf.Copy(bodyReader, output, buf.UpdateActivity(timer)); err != nil {
 	if err := buf.Copy(bodyReader, output, buf.UpdateActivity(timer)); err != nil {
 		return newError("failed to transfer request").Base(err)
 		return newError("failed to transfer request").Base(err)
@@ -295,7 +294,7 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
 		return transferResponse(timer, session, request, response, link.Reader, writer)
 		return transferResponse(timer, session, request, response, link.Reader, writer)
 	}
 	}
 
 
-	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)
 		return newError("connection ends").Base(err)

+ 1 - 1
proxy/vmess/outbound/outbound.go

@@ -161,7 +161,7 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia
 		return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
 		return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
 	}
 	}
 
 
-	if err := signal.ExecuteParallel(ctx, requestDone, functions.CloseOnSuccess(responseDone, functions.Close(output))); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(output))); err != nil {
 		return newError("connection ends").Base(err)
 		return newError("connection ends").Base(err)
 	}
 	}