task.go 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. package task
  2. import (
  3. "context"
  4. "v2ray.com/core/common/signal/semaphore"
  5. )
  // Task is a unit of work. A non-nil return value reports that the work failed.
  6. type Task func() error
  // executionContext holds everything the ExecutionOption values configure:
  // the tasks to run, an optional context, and optional completion hooks.
  7. type executionContext struct {
  // ctx is optional; executeTask substitutes context.Background() when it is nil.
  8. ctx context.Context
  // tasks are the units of work handed to executeTask.
  9. tasks []Task
  // onSuccess, when set, is called by run after the tasks finish without error.
  10. onSuccess Task
  // onFailure, when set, is called by run after the tasks finish with an error.
  11. onFailure Task
  12. }
  13. func (c *executionContext) executeTask() error {
  14. if len(c.tasks) == 0 {
  15. return nil
  16. }
  17. // Reuse current goroutine if we only have one task to run.
  18. if len(c.tasks) == 1 && c.ctx == nil {
  19. return c.tasks[0]()
  20. }
  21. ctx := context.Background()
  22. if c.ctx != nil {
  23. ctx = c.ctx
  24. }
  25. return executeParallel(ctx, c.tasks)
  26. }
  27. func (c *executionContext) run() error {
  28. err := c.executeTask()
  29. if err == nil && c.onSuccess != nil {
  30. return c.onSuccess()
  31. }
  32. if err != nil && c.onFailure != nil {
  33. return c.onFailure()
  34. }
  35. return err
  36. }
  // ExecutionOption mutates an executionContext; Run applies each option it is given.
  37. type ExecutionOption func(*executionContext)
  38. func WithContext(ctx context.Context) ExecutionOption {
  39. return func(c *executionContext) {
  40. c.ctx = ctx
  41. }
  42. }
  43. func Parallel(tasks ...Task) ExecutionOption {
  44. return func(c *executionContext) {
  45. c.tasks = append(c.tasks, tasks...)
  46. }
  47. }
  48. // Sequential runs all tasks sequentially, and returns the first error encountered.Sequential
  49. // Once a task returns an error, the following tasks will not run.
  50. func Sequential(tasks ...Task) ExecutionOption {
  51. return func(c *executionContext) {
  52. switch len(tasks) {
  53. case 0:
  54. return
  55. case 1:
  56. c.tasks = append(c.tasks, tasks[0])
  57. default:
  58. c.tasks = append(c.tasks, func() error {
  59. return execute(tasks...)
  60. })
  61. }
  62. }
  63. }
  64. func OnSuccess(task Task) ExecutionOption {
  65. return func(c *executionContext) {
  66. c.onSuccess = task
  67. }
  68. }
  69. func OnFailure(task Task) ExecutionOption {
  70. return func(c *executionContext) {
  71. c.onFailure = task
  72. }
  73. }
  74. func Single(task Task, opts ...ExecutionOption) Task {
  75. return Run(append([]ExecutionOption{Sequential(task)}, opts...)...)
  76. }
  77. func Run(opts ...ExecutionOption) Task {
  78. var c executionContext
  79. for _, opt := range opts {
  80. opt(&c)
  81. }
  82. return func() error {
  83. return c.run()
  84. }
  85. }
  86. // execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
  87. func execute(tasks ...Task) error {
  88. for _, task := range tasks {
  89. if err := task(); err != nil {
  90. return err
  91. }
  92. }
  93. return nil
  94. }
  95. // executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
  96. func executeParallel(ctx context.Context, tasks []Task) error {
  97. n := len(tasks)
  98. s := semaphore.New(n)
  99. done := make(chan error, 1)
  100. for _, task := range tasks {
  101. <-s.Wait()
  102. go func(f Task) {
  103. err := f()
  104. if err == nil {
  105. s.Signal()
  106. return
  107. }
  108. select {
  109. case done <- err:
  110. default:
  111. }
  112. }(task)
  113. }
  114. for i := 0; i < n; i++ {
  115. select {
  116. case err := <-done:
  117. return err
  118. case <-ctx.Done():
  119. return ctx.Err()
  120. case <-s.Wait():
  121. }
  122. }
  123. return nil
  124. }