Browse Source

refactor tproxy handling in dokodemo

Darien Raymond 6 years ago
parent
commit
bb5a959876
1 changed files with 34 additions and 40 deletions
  1. 34 40
      proxy/dokodemo/dokodemo.go

+ 34 - 40
proxy/dokodemo/dokodemo.go

@@ -117,60 +117,54 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 		return nil
 	}
 
-	var tConn net.Conn
-	responseDone := func() error {
-		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
+	tproxyRequest := func() error {
+		return nil
+	}
 
-		var writer buf.Writer
-		if network == net.Network_TCP {
-			writer = buf.NewWriter(conn)
+	var writer buf.Writer
+	if network == net.Network_TCP {
+		writer = buf.NewWriter(conn)
+	} else {
+		//if we are in TPROXY mode, use linux's udp forging functionality
+		if !destinationOverridden {
+			writer = &buf.SequentialWriter{Writer: conn}
 		} else {
-			//if we are in TPROXY mode, use linux's udp forging functionality
-			if !destinationOverridden {
-				writer = &buf.SequentialWriter{Writer: conn}
-			} else {
-				sockopt := &internet.SocketConfig{
-					Tproxy: internet.SocketConfig_TProxy,
-				}
-				if dest.Address.Family().IsIP() {
-					sockopt.BindAddress = dest.Address.IP()
-					sockopt.BindPort = uint32(dest.Port)
-				}
-				var err error
-				tConn, err = internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt)
-				if err != nil {
-					return err
+			sockopt := &internet.SocketConfig{
+				Tproxy: internet.SocketConfig_TProxy,
+			}
+			if dest.Address.Family().IsIP() {
+				sockopt.BindAddress = dest.Address.IP()
+				sockopt.BindPort = uint32(dest.Port)
+			}
+			tConn, err := internet.DialSystem(ctx, net.DestinationFromAddr(conn.RemoteAddr()), sockopt)
+			if err != nil {
+				return err
+			}
+			defer tConn.Close()
+
+			writer = &buf.SequentialWriter{Writer: tConn}
+			tReader := buf.NewReader(tConn)
+			tproxyRequest = func() error {
+				if err := buf.Copy(tReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
+					return newError("failed to transport request (TPROXY conn)").Base(err)
 				}
-				writer = &buf.SequentialWriter{Writer: tConn}
-				tReader := buf.NewReader(tConn)
-				go func() {
-					defer tConn.Close()
-					defer common.Close(link.Writer)
-					if err := buf.Copy(tReader, link.Writer, buf.UpdateActivity(timer)); err != nil {
-						newError("failed to transport request (TPROXY conn)").Base(err).WriteToLog()
-					}
-				}()
+				return nil
 			}
 		}
+	}
+
+	responseDone := func() error {
+		defer timer.SetTimeout(plcy.Timeouts.UplinkOnly)
 
-		defer func() {
-			if tConn != nil {
-				tConn.Close()
-			}
-		}()
 		if err := buf.Copy(link.Reader, writer, buf.UpdateActivity(timer)); err != nil {
 			return newError("failed to transport response").Base(err)
 		}
-
 		return nil
 	}
 
-	if err := task.Run(ctx, task.OnSuccess(requestDone, task.Close(link.Writer)), responseDone); err != nil {
+	if err := task.Run(ctx, task.OnSuccess(requestDone, task.Close(link.Writer)), responseDone, tproxyRequest); err != nil {
 		common.Interrupt(link.Reader)
 		common.Interrupt(link.Writer)
-		if tConn != nil {
-			tConn.Close()
-		}
 		return newError("connection ends").Base(err)
 	}