Darien Raymond 7 years ago
parent
commit
54e2244c5a
1 changed files with 12 additions and 18 deletions
  1. 12 18
      common/task/task.go

+ 12 - 18
common/task/task.go

@@ -10,29 +10,23 @@ type Task func() error
 
 type executionContext struct {
 	ctx       context.Context
-	task      Task
+	tasks     []Task
 	onSuccess Task
 	onFailure Task
 }
 
 func (c *executionContext) executeTask() error {
-	if c.ctx == nil && c.task == nil {
+	if len(c.tasks) == 0 {
 		return nil
 	}
 
-	if c.ctx == nil {
-		return c.task()
-	}
+	ctx := context.Background()
 
-	if c.task == nil {
-		<-c.ctx.Done()
-		return c.ctx.Err()
+	if c.ctx != nil {
+		ctx = c.ctx
 	}
 
-	return executeParallel(func() error {
-		<-c.ctx.Done()
-		return c.ctx.Err()
-	}, c.task)
+	return executeParallel(ctx, c.tasks)
 }
 
 func (c *executionContext) run() error {
@@ -56,17 +50,15 @@ func WithContext(ctx context.Context) ExecutionOption {
 
 func Parallel(tasks ...Task) ExecutionOption {
 	return func(c *executionContext) {
-		c.task = func() error {
-			return executeParallel(tasks...)
-		}
+		c.tasks = append(c.tasks, tasks...)
 	}
 }
 
 func Sequential(tasks ...Task) ExecutionOption {
 	return func(c *executionContext) {
-		c.task = func() error {
+		c.tasks = append(c.tasks, func() error {
 			return execute(tasks...)
-		}
+		})
 	}
 }
 
@@ -107,7 +99,7 @@ func execute(tasks ...Task) error {
 }
 
 // executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
-func executeParallel(tasks ...Task) error {
+func executeParallel(ctx context.Context, tasks []Task) error {
 	n := len(tasks)
 	s := semaphore.New(n)
 	done := make(chan error, 1)
@@ -129,6 +121,8 @@ func executeParallel(tasks ...Task) error {
 		select {
 		case err := <-done:
 			return err
+		case <-ctx.Done():
+			return ctx.Err()
 		case <-s.Wait():
 		}
 	}