Darien Raymond 7 лет назад
Родитель
Сommit
a059ee2c00
2 измененных файлов с 24 добавлено и 14 удалено
  1. 7 5
      proxy/socks/client.go
  2. 17 9
      proxy/socks/server.go

+ 7 - 5
proxy/socks/client.go

@@ -49,7 +49,7 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy.
 	var server *protocol.ServerSpec
 	var conn internet.Connection
 
-	err := retry.ExponentialBackoff(5, 100).On(func() error {
+	if err := retry.ExponentialBackoff(5, 100).On(func() error {
 		server = c.serverPicker.PickServer()
 		dest := server.Destination()
 		rawConn, err := dialer.Dial(ctx, dest)
@@ -59,13 +59,15 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy.
 		conn = rawConn
 
 		return nil
-	})
-
-	if err != nil {
+	}); err != nil {
 		return newError("failed to find an available destination").Base(err)
 	}
 
-	defer conn.Close()
+	defer func() {
+		if err := conn.Close(); err != nil {
+			newError("failed to closed connection").Base(err).WithContext(ctx).WriteToLog()
+		}
+	}()
 
 	p := c.policyManager.ForLevel(0)
 

+ 17 - 9
proxy/socks/server.go

@@ -41,6 +41,7 @@ func (s *Server) policy() core.Policy {
 	return p
 }
 
+// Network implements proxy.Inbound.
 func (s *Server) Network() net.NetworkList {
 	list := net.NetworkList{
 		Network: []net.Network{net.Network_TCP},
@@ -51,6 +52,7 @@ func (s *Server) Network() net.NetworkList {
 	return list
 }
 
+// Process implements proxy.Inbound.
 func (s *Server) Process(ctx context.Context, network net.Network, conn internet.Connection, dispatcher core.Dispatcher) error {
 	switch network {
 	case net.Network_TCP:
@@ -63,7 +65,10 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 }
 
 func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispatcher core.Dispatcher) error {
-	conn.SetReadDeadline(time.Now().Add(s.policy().Timeouts.Handshake))
+	if err := conn.SetReadDeadline(time.Now().Add(s.policy().Timeouts.Handshake)); err != nil {
+		newError("failed to set deadline").Base(err).WithContext(ctx).WriteToLog()
+	}
+
 	reader := buf.NewBufferedReader(buf.NewReader(conn))
 
 	inboundDest, ok := proxy.InboundEntryPointFromContext(ctx)
@@ -87,7 +92,10 @@ func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispa
 		}
 		return newError("failed to read request").Base(err)
 	}
-	conn.SetReadDeadline(time.Time{})
+
+	if err := conn.SetReadDeadline(time.Time{}); err != nil {
+		newError("failed to clear deadline").Base(err).WithContext(ctx).WriteToLog()
+	}
 
 	if request.Command == protocol.RequestCommandTCP {
 		dest := request.Destination()
@@ -111,16 +119,16 @@ func (s *Server) processTCP(ctx context.Context, conn internet.Connection, dispa
 	return nil
 }
 
-func (*Server) handleUDP(c net.Conn) error {
+func (*Server) handleUDP(c io.Reader) error {
 	// The TCP connection closes after this method returns. We need to wait until
 	// the client closes it.
 	_, err := io.Copy(buf.DiscardBytes, c)
 	return err
 }
 
-func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher core.Dispatcher) error {
+func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writer, dest net.Destination, dispatcher core.Dispatcher) error {
 	ctx, cancel := context.WithCancel(ctx)
-	timer := signal.CancelAfterInactivity(ctx, cancel, v.policy().Timeouts.ConnectionIdle)
+	timer := signal.CancelAfterInactivity(ctx, cancel, s.policy().Timeouts.ConnectionIdle)
 
 	ray, err := dispatcher.Dispatch(ctx, dest)
 	if err != nil {
@@ -131,8 +139,8 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 	output := ray.InboundOutput()
 
 	requestDone := signal.ExecuteAsync(func() error {
-		defer timer.SetTimeout(v.policy().Timeouts.DownlinkOnly)
-		defer input.Close()
+		defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly)
+		defer common.Must(input.Close())
 
 		v2reader := buf.NewReader(reader)
 		if err := buf.Copy(v2reader, input, buf.UpdateActivity(timer)); err != nil {
@@ -143,7 +151,7 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 	})
 
 	responseDone := signal.ExecuteAsync(func() error {
-		defer timer.SetTimeout(v.policy().Timeouts.UplinkOnly)
+		defer timer.SetTimeout(s.policy().Timeouts.UplinkOnly)
 
 		v2writer := buf.NewWriter(writer)
 		if err := buf.Copy(output, v2writer, buf.UpdateActivity(timer)); err != nil {
@@ -162,7 +170,7 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 	return nil
 }
 
-func (v *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher core.Dispatcher) error {
+func (s *Server) handleUDPPayload(ctx context.Context, conn internet.Connection, dispatcher core.Dispatcher) error {
 	udpServer := udp.NewDispatcher(dispatcher)
 
 	if source, ok := proxy.SourceFromContext(ctx); ok {