Kaynağa Gözat

close outbound connections when context is done

Darien Raymond 8 yıl önce
ebeveyn
işleme
fab20bb0cf

+ 8 - 1
app/proxyman/outbound/handler.go

@@ -65,7 +65,14 @@ func NewHandler(ctx context.Context, config *proxyman.OutboundHandlerConfig) (*H
 
 func (h *Handler) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) {
 	ctx = proxy.ContextWithDialer(ctx, h)
-	h.proxy.Process(ctx, outboundRay)
+	err := h.proxy.Process(ctx, outboundRay)
+	// Ensure outbound ray is properly closed.
+	if err != nil {
+		outboundRay.OutboundOutput().CloseError()
+	} else {
+		outboundRay.OutboundOutput().Close()
+	}
+	outboundRay.OutboundInput().CloseError()
 }
 
 func (h *Handler) Dial(ctx context.Context, dest v2net.Destination) (internet.Connection, error) {

+ 18 - 3
common/signal/exec.go

@@ -1,5 +1,9 @@
 package signal
 
+import (
+	"context"
+)
+
 func executeAndFulfill(f func() error, done chan<- error) {
 	err := f()
 	if err != nil {
@@ -14,17 +18,28 @@ func ExecuteAsync(f func() error) <-chan error {
 	return done
 }
 
-func ErrorOrFinish2(c1, c2 <-chan error) error {
+func ErrorOrFinish1(ctx context.Context, c <-chan error) error {
+	select {
+	case <-ctx.Done():
+		return ctx.Err()
+	case err := <-c:
+		return err
+	}
+}
+
+func ErrorOrFinish2(ctx context.Context, c1, c2 <-chan error) error {
 	select {
+	case <-ctx.Done():
+		return ctx.Err()
 	case err, failed := <-c1:
 		if failed {
 			return err
 		}
-		return <-c2
+		return ErrorOrFinish1(ctx, c2)
 	case err, failed := <-c2:
 		if failed {
 			return err
 		}
-		return <-c1
+		return ErrorOrFinish1(ctx, c1)
 	}
 }

+ 5 - 4
common/signal/exec_test.go

@@ -1,6 +1,7 @@
 package signal_test
 
 import (
+	"context"
 	"errors"
 	"testing"
 
@@ -16,7 +17,7 @@ func TestErrorOrFinish2_Error(t *testing.T) {
 	c := make(chan error, 1)
 
 	go func() {
-		c <- ErrorOrFinish2(c1, c2)
+		c <- ErrorOrFinish2(context.Background(), c1, c2)
 	}()
 
 	c1 <- errors.New("test")
@@ -32,7 +33,7 @@ func TestErrorOrFinish2_Error2(t *testing.T) {
 	c := make(chan error, 1)
 
 	go func() {
-		c <- ErrorOrFinish2(c1, c2)
+		c <- ErrorOrFinish2(context.Background(), c1, c2)
 	}()
 
 	c2 <- errors.New("test")
@@ -48,7 +49,7 @@ func TestErrorOrFinish2_NoneError(t *testing.T) {
 	c := make(chan error, 1)
 
 	go func() {
-		c <- ErrorOrFinish2(c1, c2)
+		c <- ErrorOrFinish2(context.Background(), c1, c2)
 	}()
 
 	close(c1)
@@ -71,7 +72,7 @@ func TestErrorOrFinish2_NoneError2(t *testing.T) {
 	c := make(chan error, 1)
 
 	go func() {
-		c <- ErrorOrFinish2(c1, c2)
+		c <- ErrorOrFinish2(context.Background(), c1, c2)
 	}()
 
 	close(c2)

+ 1 - 1
proxy/dokodemo/dokodemo.go

@@ -83,7 +83,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 		return nil
 	})
 
-	if err := signal.ErrorOrFinish2(requestDone, responseDone); err != nil {
+	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
 		inboundRay.InboundInput().CloseError()
 		inboundRay.InboundOutput().CloseError()
 		log.Info("Dokodemo: Connection ends with ", err)

+ 1 - 1
proxy/freedom/freedom.go

@@ -136,7 +136,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay) erro
 		return nil
 	})
 
-	if err := signal.ErrorOrFinish2(requestDone, responseDone); err != nil {
+	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
 		log.Info("Freedom: Connection ending with ", err)
 		input.CloseError()
 		output.CloseError()

+ 2 - 2
proxy/http/server.go

@@ -150,7 +150,7 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 		return nil
 	})
 
-	if err := signal.ErrorOrFinish2(requestDone, responseDone); err != nil {
+	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
 		log.Info("HTTP|Server: Connection ends with: ", err)
 		ray.InboundInput().CloseError()
 		ray.InboundOutput().CloseError()
@@ -246,7 +246,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, rea
 		return nil
 	})
 
-	if err := signal.ErrorOrFinish2(requestDone, responseDone); err != nil {
+	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
 		log.Info("HTTP|Server: Connecton ending with ", err)
 		input.CloseError()
 		output.CloseError()

+ 3 - 4
proxy/shadowsocks/client.go

@@ -61,6 +61,7 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay) error
 	}
 	log.Info("Shadowsocks|Client: Tunneling request to ", destination, " via ", server.Destination())
 
+	defer conn.Close()
 	conn.SetReusable(false)
 
 	request := &protocol.RequestHeader{
@@ -119,7 +120,7 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay) error
 			return nil
 		})
 
-		if err := signal.ErrorOrFinish2(requestDone, responseDone); err != nil {
+		if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
 			log.Info("Shadowsocks|Client: Connection ends with ", err)
 			outboundRay.OutboundInput().CloseError()
 			outboundRay.OutboundOutput().CloseError()
@@ -161,10 +162,8 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay) error
 			return nil
 		})
 
-		if err := signal.ErrorOrFinish2(requestDone, responseDone); err != nil {
+		if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
 			log.Info("Shadowsocks|Client: Connection ends with ", err)
-			outboundRay.OutboundInput().CloseError()
-			outboundRay.OutboundOutput().CloseError()
 			return err
 		}
 

+ 3 - 3
proxy/shadowsocks/server.go

@@ -69,6 +69,8 @@ func (s *Server) Network() net.NetworkList {
 }
 
 func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection) error {
+	conn.SetReusable(false)
+
 	switch network {
 	case net.Network_TCP:
 		return s.handleConnection(ctx, conn)
@@ -132,8 +134,6 @@ func (v *Server) handlerUDPPayload(ctx context.Context, conn internet.Connection
 }
 
 func (s *Server) handleConnection(ctx context.Context, conn internet.Connection) error {
-	conn.SetReusable(false)
-
 	timedReader := net.NewTimeOutReader(16, conn)
 	bufferedReader := bufio.NewReader(timedReader)
 	request, bodyReader, err := ReadTCPSession(s.user, bufferedReader)
@@ -195,7 +195,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection)
 		return nil
 	})
 
-	if err := signal.ErrorOrFinish2(requestDone, responseDone); err != nil {
+	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
 		log.Info("Shadowsocks|Server: Connection ends with ", err)
 		ray.InboundInput().CloseError()
 		ray.InboundOutput().CloseError()

+ 1 - 3
proxy/socks/client.go

@@ -108,10 +108,8 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay) error {
 
 	requestDone := signal.ExecuteAsync(requestFunc)
 	responseDone := signal.ExecuteAsync(responseFunc)
-	if err := signal.ErrorOrFinish2(requestDone, responseDone); err != nil {
+	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
 		log.Info("Socks|Client: Connection ends with ", err)
-		ray.OutboundInput().CloseError()
-		ray.OutboundOutput().CloseError()
 		return err
 	}
 

+ 1 - 2
proxy/socks/server.go

@@ -137,10 +137,9 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 			return err
 		}
 		return nil
-
 	})
 
-	if err := signal.ErrorOrFinish2(requestDone, responseDone); err != nil {
+	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
 		log.Info("Socks|Server: Connection ends with ", err)
 		input.CloseError()
 		output.CloseError()

+ 1 - 1
proxy/vmess/inbound/inbound.go

@@ -222,7 +222,7 @@ func (v *VMessInboundHandler) Process(ctx context.Context, network net.Network,
 		return transferResponse(session, request, response, output, writer)
 	})
 
-	if err := signal.ErrorOrFinish2(requestDone, responseDone); err != nil {
+	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
 		log.Info("VMess|Inbound: Connection ending with ", err)
 		connection.SetReusable(false)
 		input.CloseError()

+ 1 - 3
proxy/vmess/outbound/outbound.go

@@ -152,11 +152,9 @@ func (v *VMessOutboundHandler) Process(ctx context.Context, outboundRay ray.Outb
 		return nil
 	})
 
-	if err := signal.ErrorOrFinish2(requestDone, responseDone); err != nil {
+	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
 		log.Info("VMess|Outbound: Connection ending with ", err)
 		conn.SetReusable(false)
-		input.CloseError()
-		output.CloseError()
 		return err
 	}