| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186 |
- package task
- import (
- "context"
- "strings"
- "v2ray.com/core/common"
- "v2ray.com/core/common/signal/semaphore"
- )
// Task is a unit of work; a non-nil error result reports that it failed.
type Task func() error
- type MultiError []error
- func (e MultiError) Error() string {
- var r strings.Builder
- common.Must2(r.WriteString("multierr: "))
- for _, err := range e {
- common.Must2(r.WriteString(err.Error()))
- common.Must2(r.WriteString(" | "))
- }
- return r.String()
- }
// executionContext holds the configuration assembled by ExecutionOptions.
// Run builds one per call and applies the options to it.
type executionContext struct {
	// ctx, when non-nil, forces the parallel path in executeTask; if it is
	// done before the tasks finish, the wait ends with ctx.Err().
	ctx context.Context
	// tasks are the entries to execute; see executeTask for scheduling.
	tasks []Task
	// onSuccess runs after the tasks succeed; its result becomes the overall result.
	onSuccess Task
	// onFailure runs after the tasks fail; its result replaces the task error.
	onFailure Task
}
- func (c *executionContext) executeTask() error {
- if len(c.tasks) == 0 {
- return nil
- }
- // Reuse current goroutine if we only have one task to run.
- if len(c.tasks) == 1 && c.ctx == nil {
- return c.tasks[0]()
- }
- ctx := context.Background()
- if c.ctx != nil {
- ctx = c.ctx
- }
- return executeParallel(ctx, c.tasks)
- }
- func (c *executionContext) run() error {
- err := c.executeTask()
- if err == nil && c.onSuccess != nil {
- return c.onSuccess()
- }
- if err != nil && c.onFailure != nil {
- return c.onFailure()
- }
- return err
- }
// ExecutionOption configures an executionContext. Run applies the options in
// the order given.
type ExecutionOption func(*executionContext)
- func WithContext(ctx context.Context) ExecutionOption {
- return func(c *executionContext) {
- c.ctx = ctx
- }
- }
- func Parallel(tasks ...Task) ExecutionOption {
- return func(c *executionContext) {
- c.tasks = append(c.tasks, tasks...)
- }
- }
- // Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
- // Once a task returns an error, the following tasks will not run.
- func Sequential(tasks ...Task) ExecutionOption {
- return func(c *executionContext) {
- switch len(tasks) {
- case 0:
- return
- case 1:
- c.tasks = append(c.tasks, tasks[0])
- default:
- c.tasks = append(c.tasks, func() error {
- return execute(tasks...)
- })
- }
- }
- }
- func SequentialAll(tasks ...Task) ExecutionOption {
- return func(c *executionContext) {
- switch len(tasks) {
- case 0:
- return
- case 1:
- c.tasks = append(c.tasks, tasks[0])
- default:
- c.tasks = append(c.tasks, func() error {
- var merr MultiError
- for _, task := range tasks {
- if err := task(); err != nil {
- merr = append(merr, err)
- }
- }
- if len(merr) == 0 {
- return nil
- }
- return merr
- })
- }
- }
- }
- func OnSuccess(task Task) ExecutionOption {
- return func(c *executionContext) {
- c.onSuccess = task
- }
- }
- func OnFailure(task Task) ExecutionOption {
- return func(c *executionContext) {
- c.onFailure = task
- }
- }
- func Single(task Task, opts ...ExecutionOption) Task {
- return Run(append([]ExecutionOption{Sequential(task)}, opts...)...)
- }
- func Run(opts ...ExecutionOption) Task {
- var c executionContext
- for _, opt := range opts {
- opt(&c)
- }
- return func() error {
- return c.run()
- }
- }
- // execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
- func execute(tasks ...Task) error {
- for _, task := range tasks {
- if err := task(); err != nil {
- return err
- }
- }
- return nil
- }
- // executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
- func executeParallel(ctx context.Context, tasks []Task) error {
- n := len(tasks)
- s := semaphore.New(n)
- done := make(chan error, 1)
- for _, task := range tasks {
- <-s.Wait()
- go func(f Task) {
- err := f()
- if err == nil {
- s.Signal()
- return
- }
- select {
- case done <- err:
- default:
- }
- }(task)
- }
- for i := 0; i < n; i++ {
- select {
- case err := <-done:
- return err
- case <-ctx.Done():
- return ctx.Err()
- case <-s.Wait():
- }
- }
- return nil
- }
|