Bladeren bron

better way to run tasks in parallel

Darien Raymond 7 jaren geleden
bovenliggende
commit
0caf92726b

+ 5 - 5
common/platform/ctlcmd/ctlcmd.go

@@ -39,16 +39,16 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) {
 	}
 
 	var content buf.MultiBuffer
-	loadTask := signal.ExecuteAsync(func() error {
+	loadTask := func() error {
 		c, err := buf.ReadAllToMultiBuffer(stdoutReader)
 		if err != nil {
 			return err
 		}
 		content = c
 		return nil
-	})
+	}
 
-	waitTask := signal.ExecuteAsync(func() error {
+	waitTask := func() error {
 		if err := cmd.Wait(); err != nil {
 			msg := "failed to execute v2ctl"
 			if errBuffer.Len() > 0 {
@@ -57,9 +57,9 @@ func Run(args []string, input io.Reader) (buf.MultiBuffer, error) {
 			return newError(msg).Base(err)
 		}
 		return nil
-	})
+	}
 
-	if err := signal.ErrorOrFinish2(context.Background(), loadTask, waitTask); err != nil {
+	if err := signal.ExecuteParallel(context.Background(), loadTask, waitTask); err != nil {
 		return nil, err
 	}
 

+ 23 - 32
common/signal/exec.go

@@ -4,14 +4,6 @@ import (
 	"context"
 )
 
-func executeAndFulfill(f func() error, done chan<- error) {
-	err := f()
-	if err != nil {
-		done <- err
-	}
-	close(done)
-}
-
 // Execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
 func Execute(tasks ...func() error) error {
 	for _, task := range tasks {
@@ -22,35 +14,34 @@ func Execute(tasks ...func() error) error {
 	return nil
 }
 
-// ExecuteAsync executes a function asynchronously and return its result.
-func ExecuteAsync(f func() error) <-chan error {
+// ExecuteParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
+func ExecuteParallel(ctx context.Context, tasks ...func() error) error {
+	n := len(tasks)
+	s := NewSemaphore(n)
 	done := make(chan error, 1)
-	go executeAndFulfill(f, done)
-	return done
-}
 
-func ErrorOrFinish1(ctx context.Context, c <-chan error) error {
-	select {
-	case <-ctx.Done():
-		return ctx.Err()
-	case err := <-c:
-		return err
+	for _, task := range tasks {
+		<-s.Wait()
+		go func(f func() error) {
+			if err := f(); err != nil {
+				select {
+				case done <- err:
+				default:
+				}
+			}
+			s.Signal()
+		}(task)
 	}
-}
 
-func ErrorOrFinish2(ctx context.Context, c1, c2 <-chan error) error {
-	select {
-	case <-ctx.Done():
-		return ctx.Err()
-	case err := <-c1:
-		if err != nil {
-			return err
-		}
-		return ErrorOrFinish1(ctx, c2)
-	case err := <-c2:
-		if err != nil {
+	for i := 0; i < n; i++ {
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case err := <-done:
 			return err
+		case <-s.Wait():
 		}
-		return ErrorOrFinish1(ctx, c1)
 	}
+
+	return nil
 }

+ 5 - 5
proxy/dokodemo/dokodemo.go

@@ -75,7 +75,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 		return newError("failed to dispatch request").Base(err)
 	}
 
-	requestDone := signal.ExecuteAsync(func() error {
+	requestDone := func() error {
 		defer inboundRay.InboundInput().Close()
 		defer timer.SetTimeout(d.policy().Timeouts.DownlinkOnly)
 
@@ -86,9 +86,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 		}
 
 		return nil
-	})
+	}
 
-	responseDone := signal.ExecuteAsync(func() error {
+	responseDone := func() error {
 		defer timer.SetTimeout(d.policy().Timeouts.UplinkOnly)
 
 		var writer buf.Writer
@@ -113,9 +113,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 		}
 
 		return nil
-	})
+	}
 
-	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
 		inboundRay.InboundInput().CloseError()
 		inboundRay.InboundOutput().CloseError()
 		return newError("connection ends").Base(err)

+ 5 - 5
proxy/freedom/freedom.go

@@ -109,7 +109,7 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 	ctx, cancel := context.WithCancel(ctx)
 	timer := signal.CancelAfterInactivity(ctx, cancel, h.policy().Timeouts.ConnectionIdle)
 
-	requestDone := signal.ExecuteAsync(func() error {
+	requestDone := func() error {
 		defer timer.SetTimeout(h.policy().Timeouts.DownlinkOnly)
 
 		var writer buf.Writer
@@ -123,9 +123,9 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 		}
 
 		return nil
-	})
+	}
 
-	responseDone := signal.ExecuteAsync(func() error {
+	responseDone := func() error {
 		defer timer.SetTimeout(h.policy().Timeouts.UplinkOnly)
 
 		v2reader := buf.NewReader(conn)
@@ -134,9 +134,9 @@ func (h *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 		}
 
 		return nil
-	})
+	}
 
-	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
 		input.CloseError()
 		output.CloseError()
 		return newError("connection ends").Base(err)

+ 10 - 10
proxy/http/server.go

@@ -182,15 +182,15 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 		reader = nil
 	}
 
-	requestDone := signal.ExecuteAsync(func() error {
+	requestDone := func() error {
 		defer ray.InboundInput().Close()
 		defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly)
 
 		v2reader := buf.NewReader(conn)
 		return buf.Copy(v2reader, ray.InboundInput(), buf.UpdateActivity(timer))
-	})
+	}
 
-	responseDone := signal.ExecuteAsync(func() error {
+	responseDone := func() error {
 		defer timer.SetTimeout(s.policy().Timeouts.UplinkOnly)
 
 		v2writer := buf.NewWriter(conn)
@@ -199,9 +199,9 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 		}
 
 		return nil
-	})
+	}
 
-	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
 		ray.InboundInput().CloseError()
 		ray.InboundOutput().CloseError()
 		return newError("connection ends").Base(err)
@@ -251,7 +251,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
 
 	var result error = errWaitAnother
 
-	requestDone := signal.ExecuteAsync(func() error {
+	requestDone := func() error {
 		request.Header.Set("Connection", "close")
 
 		requestWriter := buf.NewBufferedWriter(ray.InboundInput())
@@ -260,9 +260,9 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
 			return newError("failed to write whole request").Base(err).AtWarning()
 		}
 		return nil
-	})
+	}
 
-	responseDone := signal.ExecuteAsync(func() error {
+	responseDone := func() error {
 		responseReader := bufio.NewReaderSize(buf.NewBufferedReader(ray.InboundOutput()), buf.Size)
 		response, err := http.ReadResponse(responseReader, request)
 		if err == nil {
@@ -296,9 +296,9 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
 			return newError("failed to write response").Base(err).AtWarning()
 		}
 		return nil
-	})
+	}
 
-	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
 		input.CloseError()
 		output.CloseError()
 		return newError("connection ends").Base(err)

+ 10 - 10
proxy/shadowsocks/client.go

@@ -105,12 +105,12 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
 			return err
 		}
 
-		requestDone := signal.ExecuteAsync(func() error {
+		requestDone := func() error {
 			defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
 			return buf.Copy(outboundRay.OutboundInput(), bodyWriter, buf.UpdateActivity(timer))
-		})
+		}
 
-		responseDone := signal.ExecuteAsync(func() error {
+		responseDone := func() error {
 			defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
 
 			responseReader, err := ReadTCPResponse(user, conn)
@@ -119,9 +119,9 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
 			}
 
 			return buf.Copy(responseReader, outboundRay.OutboundOutput(), buf.UpdateActivity(timer))
-		})
+		}
 
-		if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
+		if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
 			return newError("connection ends").Base(err)
 		}
 
@@ -135,16 +135,16 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
 			Request: request,
 		})
 
-		requestDone := signal.ExecuteAsync(func() error {
+		requestDone := func() error {
 			defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
 
 			if err := buf.Copy(outboundRay.OutboundInput(), writer, buf.UpdateActivity(timer)); err != nil {
 				return newError("failed to transport all UDP request").Base(err)
 			}
 			return nil
-		})
+		}
 
-		responseDone := signal.ExecuteAsync(func() error {
+		responseDone := func() error {
 			defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
 
 			reader := &UDPReader{
@@ -156,9 +156,9 @@ func (c *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
 				return newError("failed to transport all UDP response").Base(err)
 			}
 			return nil
-		})
+		}
 
-		if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
+		if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
 			return newError("connection ends").Base(err)
 		}
 

+ 5 - 5
proxy/shadowsocks/server.go

@@ -172,7 +172,7 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 		return err
 	}
 
-	responseDone := signal.ExecuteAsync(func() error {
+	responseDone := func() error {
 		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
 
 		bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))
@@ -200,9 +200,9 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 		}
 
 		return nil
-	})
+	}
 
-	requestDone := signal.ExecuteAsync(func() error {
+	requestDone := func() error {
 		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
 		defer ray.InboundInput().Close()
 
@@ -211,9 +211,9 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 		}
 
 		return nil
-	})
+	}
 
-	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
 		ray.InboundInput().CloseError()
 		ray.InboundOutput().CloseError()
 		return newError("connection ends").Base(err)

+ 1 - 3
proxy/socks/client.go

@@ -130,9 +130,7 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy.
 		}
 	}
 
-	requestDone := signal.ExecuteAsync(requestFunc)
-	responseDone := signal.ExecuteAsync(responseFunc)
-	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestFunc, responseFunc); err != nil {
 		return newError("connection ends").Base(err)
 	}
 

+ 5 - 5
proxy/socks/server.go

@@ -137,7 +137,7 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 	input := ray.InboundInput()
 	output := ray.InboundOutput()
 
-	requestDone := signal.ExecuteAsync(func() error {
+	requestDone := func() error {
 		defer timer.SetTimeout(s.policy().Timeouts.DownlinkOnly)
 		defer input.Close()
 
@@ -147,9 +147,9 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 		}
 
 		return nil
-	})
+	}
 
-	responseDone := signal.ExecuteAsync(func() error {
+	responseDone := func() error {
 		defer timer.SetTimeout(s.policy().Timeouts.UplinkOnly)
 
 		v2writer := buf.NewWriter(writer)
@@ -158,9 +158,9 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 		}
 
 		return nil
-	})
+	}
 
-	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
 		input.CloseError()
 		output.CloseError()
 		return newError("connection ends").Base(err)

+ 5 - 5
proxy/vmess/inbound/inbound.go

@@ -280,12 +280,12 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
 	input := ray.InboundInput()
 	output := ray.InboundOutput()
 
-	requestDone := signal.ExecuteAsync(func() error {
+	requestDone := func() error {
 		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
 		return transferRequest(timer, session, request, reader, input)
-	})
+	}
 
-	responseDone := signal.ExecuteAsync(func() error {
+	responseDone := func() error {
 		writer := buf.NewBufferedWriter(buf.NewWriter(connection))
 		defer writer.Flush()
 		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
@@ -294,9 +294,9 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
 			Command: h.generateCommand(ctx, request),
 		}
 		return transferResponse(timer, session, request, response, output, writer)
-	})
+	}
 
-	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
 		input.CloseError()
 		output.CloseError()
 		return newError("connection ends").Base(err)

+ 5 - 5
proxy/vmess/outbound/outbound.go

@@ -104,7 +104,7 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 	ctx, cancel := context.WithCancel(ctx)
 	timer := signal.CancelAfterInactivity(ctx, cancel, sessionPolicy.Timeouts.ConnectionIdle)
 
-	requestDone := signal.ExecuteAsync(func() error {
+	requestDone := func() error {
 		defer timer.SetTimeout(sessionPolicy.Timeouts.DownlinkOnly)
 
 		writer := buf.NewBufferedWriter(buf.NewWriter(conn))
@@ -140,9 +140,9 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 		}
 
 		return nil
-	})
+	}
 
-	responseDone := signal.ExecuteAsync(func() error {
+	responseDone := func() error {
 		defer timer.SetTimeout(sessionPolicy.Timeouts.UplinkOnly)
 
 		reader := buf.NewBufferedReader(buf.NewReader(conn))
@@ -156,9 +156,9 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 		bodyReader := session.DecodeResponseBody(request, reader)
 
 		return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
-	})
+	}
 
-	if err := signal.ErrorOrFinish2(ctx, requestDone, responseDone); err != nil {
+	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
 		return newError("connection ends").Base(err)
 	}