task.go 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. package task
  2. import (
  3. "context"
  4. "strings"
  5. "v2ray.com/core/common"
  6. "v2ray.com/core/common/signal/semaphore"
  7. )
// Task is a unit of work: a function that takes no arguments and signals
// failure by returning a non-nil error.
type Task func() error
  9. type MultiError []error
  10. func (e MultiError) Error() string {
  11. var r strings.Builder
  12. common.Must2(r.WriteString("multierr: "))
  13. for _, err := range e {
  14. common.Must2(r.WriteString(err.Error()))
  15. common.Must2(r.WriteString(" | "))
  16. }
  17. return r.String()
  18. }
// executionContext collects the settings that ExecutionOption functions write
// into it before the tasks are run.
type executionContext struct {
	// ctx is the optional context; it stays nil unless WithContext was applied.
	ctx context.Context
	// tasks is the list of tasks to execute.
	tasks []Task
	// onSuccess, when non-nil, is called after the tasks finish without error.
	onSuccess Task
	// onFailure, when non-nil, is called after the tasks finish with an error.
	onFailure Task
}
  25. func (c *executionContext) executeTask() error {
  26. if len(c.tasks) == 0 {
  27. return nil
  28. }
  29. // Reuse current goroutine if we only have one task to run.
  30. if len(c.tasks) == 1 && c.ctx == nil {
  31. return c.tasks[0]()
  32. }
  33. ctx := context.Background()
  34. if c.ctx != nil {
  35. ctx = c.ctx
  36. }
  37. return executeParallel(ctx, c.tasks)
  38. }
  39. func (c *executionContext) run() error {
  40. err := c.executeTask()
  41. if err == nil && c.onSuccess != nil {
  42. return c.onSuccess()
  43. }
  44. if err != nil && c.onFailure != nil {
  45. return c.onFailure()
  46. }
  47. return err
  48. }
// ExecutionOption is a function that configures an executionContext. Run
// applies the options it is given in order.
type ExecutionOption func(*executionContext)
  50. func WithContext(ctx context.Context) ExecutionOption {
  51. return func(c *executionContext) {
  52. c.ctx = ctx
  53. }
  54. }
  55. func Parallel(tasks ...Task) ExecutionOption {
  56. return func(c *executionContext) {
  57. c.tasks = append(c.tasks, tasks...)
  58. }
  59. }
  60. // Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
  61. // Once a task returns an error, the following tasks will not run.
  62. func Sequential(tasks ...Task) ExecutionOption {
  63. return func(c *executionContext) {
  64. switch len(tasks) {
  65. case 0:
  66. return
  67. case 1:
  68. c.tasks = append(c.tasks, tasks[0])
  69. default:
  70. c.tasks = append(c.tasks, func() error {
  71. return execute(tasks...)
  72. })
  73. }
  74. }
  75. }
  76. func SequentialAll(tasks ...Task) ExecutionOption {
  77. return func(c *executionContext) {
  78. switch len(tasks) {
  79. case 0:
  80. return
  81. case 1:
  82. c.tasks = append(c.tasks, tasks[0])
  83. default:
  84. c.tasks = append(c.tasks, func() error {
  85. var merr MultiError
  86. for _, task := range tasks {
  87. if err := task(); err != nil {
  88. merr = append(merr, err)
  89. }
  90. }
  91. if len(merr) == 0 {
  92. return nil
  93. }
  94. return merr
  95. })
  96. }
  97. }
  98. }
  99. func OnSuccess(task Task) ExecutionOption {
  100. return func(c *executionContext) {
  101. c.onSuccess = task
  102. }
  103. }
  104. func OnFailure(task Task) ExecutionOption {
  105. return func(c *executionContext) {
  106. c.onFailure = task
  107. }
  108. }
  109. func Single(task Task, opts ...ExecutionOption) Task {
  110. return Run(append([]ExecutionOption{Sequential(task)}, opts...)...)
  111. }
  112. func Run(opts ...ExecutionOption) Task {
  113. var c executionContext
  114. for _, opt := range opts {
  115. opt(&c)
  116. }
  117. return func() error {
  118. return c.run()
  119. }
  120. }
  121. // execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
  122. func execute(tasks ...Task) error {
  123. for _, task := range tasks {
  124. if err := task(); err != nil {
  125. return err
  126. }
  127. }
  128. return nil
  129. }
  130. // executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
  131. func executeParallel(ctx context.Context, tasks []Task) error {
  132. n := len(tasks)
  133. s := semaphore.New(n)
  134. done := make(chan error, 1)
  135. for _, task := range tasks {
  136. <-s.Wait()
  137. go func(f Task) {
  138. err := f()
  139. if err == nil {
  140. s.Signal()
  141. return
  142. }
  143. select {
  144. case done <- err:
  145. default:
  146. }
  147. }(task)
  148. }
  149. for i := 0; i < n; i++ {
  150. select {
  151. case err := <-done:
  152. return err
  153. case <-ctx.Done():
  154. return ctx.Err()
  155. case <-s.Wait():
  156. }
  157. }
  158. return nil
  159. }