Browse Source

simplify task execution

Darien Raymond 7 years ago
parent
commit
427679e66d

+ 2 - 1
common/task/common.go

@@ -2,7 +2,8 @@ package task
 
 import "v2ray.com/core/common"
 
-func Close(v interface{}) Task {
+// Close returns a func() that closes v.
+func Close(v interface{}) func() error {
 	return func() error {
 		return common.Close(v)
 	}

+ 7 - 103
common/task/task.go

@@ -6,121 +6,25 @@ import (
 	"v2ray.com/core/common/signal/semaphore"
 )
 
-type Task func() error
-
-type executionContext struct {
-	ctx       context.Context
-	tasks     []Task
-	onSuccess Task
-	onFailure Task
-}
-
-func (c *executionContext) executeTask() error {
-	if len(c.tasks) == 0 {
-		return nil
-	}
-
-	// Reuse current goroutine if we only have one task to run.
-	if len(c.tasks) == 1 && c.ctx == nil {
-		return c.tasks[0]()
-	}
-
-	ctx := context.Background()
-
-	if c.ctx != nil {
-		ctx = c.ctx
-	}
-
-	return executeParallel(ctx, c.tasks)
-}
-
-func (c *executionContext) run() error {
-	err := c.executeTask()
-	if err == nil && c.onSuccess != nil {
-		return c.onSuccess()
-	}
-	if err != nil && c.onFailure != nil {
-		return c.onFailure()
-	}
-	return err
-}
-
-type ExecutionOption func(*executionContext)
-
-func WithContext(ctx context.Context) ExecutionOption {
-	return func(c *executionContext) {
-		c.ctx = ctx
-	}
-}
-
-func Parallel(tasks ...Task) ExecutionOption {
-	return func(c *executionContext) {
-		c.tasks = append(c.tasks, tasks...)
-	}
-}
-
-// Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
-// Once a task returns an error, the following tasks will not run.
-func Sequential(tasks ...Task) ExecutionOption {
-	return func(c *executionContext) {
-		switch len(tasks) {
-		case 0:
-			return
-		case 1:
-			c.tasks = append(c.tasks, tasks[0])
-		default:
-			c.tasks = append(c.tasks, func() error {
-				return execute(tasks...)
-			})
-		}
-	}
-}
-
-func OnSuccess(task Task) ExecutionOption {
-	return func(c *executionContext) {
-		c.onSuccess = task
-	}
-}
-
-func OnFailure(task Task) ExecutionOption {
-	return func(c *executionContext) {
-		c.onFailure = task
-	}
-}
-
-func Single(task Task, opts ...ExecutionOption) Task {
-	return Run(append([]ExecutionOption{Sequential(task)}, opts...)...)
-}
-
-func Run(opts ...ExecutionOption) Task {
-	var c executionContext
-	for _, opt := range opts {
-		opt(&c)
-	}
+// OnSuccess executes g() after f() returns nil.
+func OnSuccess(f func() error, g func() error) func() error {
 	return func() error {
-		return c.run()
-	}
-}
-
-// execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
-func execute(tasks ...Task) error {
-	for _, task := range tasks {
-		if err := task(); err != nil {
+		if err := f(); err != nil {
 			return err
 		}
+		return g()
 	}
-	return nil
 }
 
-// executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
-func executeParallel(ctx context.Context, tasks []Task) error {
+// Run executes a list of tasks in parallel, returns the first error encountered or nil if all tasks pass.
+func Run(ctx context.Context, tasks ...func() error) error {
 	n := len(tasks)
 	s := semaphore.New(n)
 	done := make(chan error, 1)
 
 	for _, task := range tasks {
 		<-s.Wait()
-		go func(f Task) {
+		go func(f func() error) {
 			err := f()
 			if err == nil {
 				s.Signal()

+ 12 - 22
common/task/task_test.go

@@ -14,13 +14,14 @@ import (
 func TestExecuteParallel(t *testing.T) {
 	assert := With(t)
 
-	err := Run(Parallel(func() error {
-		time.Sleep(time.Millisecond * 200)
-		return errors.New("test")
-	}, func() error {
-		time.Sleep(time.Millisecond * 500)
-		return errors.New("test2")
-	}))()
+	err := Run(context.Background(),
+		func() error {
+			time.Sleep(time.Millisecond * 200)
+			return errors.New("test")
+		}, func() error {
+			time.Sleep(time.Millisecond * 500)
+			return errors.New("test2")
+		})
 
 	assert(err.Error(), Equals, "test")
 }
@@ -29,7 +30,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
 	assert := With(t)
 
 	ctx, cancel := context.WithCancel(context.Background())
-	err := Run(WithContext(ctx), Parallel(func() error {
+	err := Run(ctx, func() error {
 		time.Sleep(time.Millisecond * 2000)
 		return errors.New("test")
 	}, func() error {
@@ -38,7 +39,7 @@ func TestExecuteParallelContextCancel(t *testing.T) {
 	}, func() error {
 		cancel()
 		return nil
-	}))()
+	})
 
 	assert(err.Error(), HasSubstring, "canceled")
 }
@@ -48,7 +49,7 @@ func BenchmarkExecuteOne(b *testing.B) {
 		return nil
 	}
 	for i := 0; i < b.N; i++ {
-		common.Must(Run(Parallel(noop))())
+		common.Must(Run(context.Background(), noop))
 	}
 }
 
@@ -57,17 +58,6 @@ func BenchmarkExecuteTwo(b *testing.B) {
 		return nil
 	}
 	for i := 0; i < b.N; i++ {
-		common.Must(Run(Parallel(noop, noop))())
-	}
-}
-
-func BenchmarkExecuteContext(b *testing.B) {
-	noop := func() error {
-		return nil
-	}
-	background := context.Background()
-
-	for i := 0; i < b.N; i++ {
-		common.Must(Run(WithContext(background), Parallel(noop, noop))())
+		common.Must(Run(context.Background(), noop, noop))
 	}
 }

+ 1 - 4
proxy/dokodemo/dokodemo.go

@@ -147,10 +147,7 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 		return nil
 	}
 
-	if err := task.Run(task.WithContext(ctx),
-		task.Parallel(
-			task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))),
-			responseDone))(); err != nil {
+	if err := task.Run(ctx, task.OnSuccess(requestDone, task.Close(link.Writer)), responseDone); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 1 - 1
proxy/freedom/freedom.go

@@ -167,7 +167,7 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		return nil
 	}
 
-	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, task.Single(responseDone, task.OnSuccess(task.Close(output)))))(); err != nil {
+	if err := task.Run(ctx, requestDone, task.OnSuccess(responseDone, task.Close(output))); err != nil {
 		return newError("connection ends").Base(err)
 	}
 

+ 3 - 3
proxy/http/server.go

@@ -210,8 +210,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 		return nil
 	}
 
-	var closeWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
-	if err := task.Run(task.WithContext(ctx), task.Parallel(closeWriter, responseDone))(); err != nil {
+	var closeWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
+	if err := task.Run(ctx, closeWriter, responseDone); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)
@@ -307,7 +307,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
 		return nil
 	}
 
-	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDone))(); err != nil {
+	if err := task.Run(ctx, requestDone, responseDone); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 2 - 2
proxy/mtproto/client.go

@@ -62,8 +62,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		return buf.Copy(connReader, link.Writer)
 	}
 
-	var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
-	if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
+	var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
+	if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
 		return newError("connection ends").Base(err)
 	}
 

+ 2 - 2
proxy/mtproto/server.go

@@ -141,8 +141,8 @@ func (s *Server) Process(ctx context.Context, network net.Network, conn internet
 		return buf.Copy(link.Reader, writer, buf.UpdateActivity(timer))
 	}
 
-	var responseDoneAndCloseWriter = task.Single(response, task.OnSuccess(task.Close(link.Writer)))
-	if err := task.Run(task.WithContext(ctx), task.Parallel(request, responseDoneAndCloseWriter))(); err != nil {
+	var responseDoneAndCloseWriter = task.OnSuccess(response, task.Close(link.Writer))
+	if err := task.Run(ctx, request, responseDoneAndCloseWriter); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 4 - 4
proxy/shadowsocks/client.go

@@ -129,8 +129,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 			return buf.Copy(responseReader, link.Writer, buf.UpdateActivity(timer))
 		}
 
-		var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
-		if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
+		var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
+		if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
 			return newError("connection ends").Base(err)
 		}
 
@@ -167,8 +167,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 			return nil
 		}
 
-		var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
-		if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
+		var responseDoneAndCloseWriter = task.OnSuccess(responseDone, task.Close(link.Writer))
+		if err := task.Run(ctx, requestDone, responseDoneAndCloseWriter); err != nil {
 			return newError("connection ends").Base(err)
 		}
 

+ 2 - 2
proxy/shadowsocks/server.go

@@ -229,8 +229,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 		return nil
 	}
 
-	var requestDoneAndCloseWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
-	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDoneAndCloseWriter, responseDone))(); err != nil {
+	var requestDoneAndCloseWriter = task.OnSuccess(requestDone, task.Close(link.Writer))
+	if err := task.Run(ctx, requestDoneAndCloseWriter, responseDone); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 2 - 2
proxy/socks/client.go

@@ -137,8 +137,8 @@ func (c *Client) Process(ctx context.Context, link *transport.Link, dialer inter
 		}
 	}
 
-	var responseDonePost = task.Single(responseFunc, task.OnSuccess(task.Close(link.Writer)))
-	if err := task.Run(task.WithContext(ctx), task.Parallel(requestFunc, responseDonePost))(); err != nil {
+	var responseDonePost = task.OnSuccess(responseFunc, task.Close(link.Writer))
+	if err := task.Run(ctx, requestFunc, responseDonePost); err != nil {
 		return newError("connection ends").Base(err)
 	}
 

+ 2 - 2
proxy/socks/server.go

@@ -164,8 +164,8 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 		return nil
 	}
 
-	var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
-	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
+	var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
+	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 2 - 2
proxy/vmess/inbound/inbound.go

@@ -302,8 +302,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
 		return transferResponse(timer, svrSession, request, response, link.Reader, writer)
 	}
 
-	var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
-	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
+	var requestDonePost = task.OnSuccess(requestDone, task.Close(link.Writer))
+	if err := task.Run(ctx, requestDonePost, responseDone); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 2 - 2
proxy/vmess/outbound/outbound.go

@@ -161,8 +161,8 @@ func (v *Handler) Process(ctx context.Context, link *transport.Link, dialer inte
 		return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
 	}
 
-	var responseDonePost = task.Single(responseDone, task.OnSuccess(task.Close(output)))
-	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDonePost))(); err != nil {
+	var responseDonePost = task.OnSuccess(responseDone, task.Close(output))
+	if err := task.Run(ctx, requestDone, responseDonePost); err != nil {
 		return newError("connection ends").Base(err)
 	}
 

+ 2 - 2
testing/servers/tcp/tcp.go

@@ -64,7 +64,7 @@ func (server *Server) handleConnection(conn net.Conn) {
 	}
 
 	pReader, pWriter := pipe.New(pipe.WithoutSizeLimit())
-	err := task.Run(task.Parallel(func() error {
+	err := task.Run(context.Background(), func() error {
 		defer pWriter.Close() // nolint: errcheck
 
 		for {
@@ -96,7 +96,7 @@ func (server *Server) handleConnection(conn net.Conn) {
 				return err
 			}
 		}
-	}))()
+	})
 
 	if err != nil {
 		fmt.Println("failed to transfer data: ", err.Error())