Browse Source

unified task package

Darien Raymond 7 năm trước cách đây
mục cha
commit
13f3c356ca

+ 3 - 3
app/dns/nameserver.go

@@ -11,7 +11,7 @@ import (
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/dice"
 	"v2ray.com/core/common/net"
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/task"
 	"v2ray.com/core/transport/internet/udp"
 )
 
@@ -42,7 +42,7 @@ type UDPNameServer struct {
 	address   net.Destination
 	requests  map[uint16]*PendingRequest
 	udpServer *udp.Dispatcher
-	cleanup   *signal.PeriodicTask
+	cleanup   *task.Periodic
 }
 
 func NewUDPNameServer(address net.Destination, dispatcher core.Dispatcher) *UDPNameServer {
@@ -51,7 +51,7 @@ func NewUDPNameServer(address net.Destination, dispatcher core.Dispatcher) *UDPN
 		requests:  make(map[uint16]*PendingRequest),
 		udpServer: udp.NewDispatcher(dispatcher),
 	}
-	s.cleanup = &signal.PeriodicTask{
+	s.cleanup = &task.Periodic{
 		Interval: time.Minute,
 		Execute:  s.Cleanup,
 	}

+ 3 - 3
app/dns/server.go

@@ -11,7 +11,7 @@ import (
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/net"
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/task"
 )
 
 const (
@@ -33,7 +33,7 @@ type Server struct {
 	hosts   map[string]net.IP
 	records map[string]*DomainRecord
 	servers []NameServer
-	task    *signal.PeriodicTask
+	task    *task.Periodic
 }
 
 func New(ctx context.Context, config *Config) (*Server, error) {
@@ -42,7 +42,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 		servers: make([]NameServer, len(config.NameServers)),
 		hosts:   config.GetInternalHosts(),
 	}
-	server.task = &signal.PeriodicTask{
+	server.task = &task.Periodic{
 		Interval: time.Minute * 10,
 		Execute: func() error {
 			server.cleanup()

+ 3 - 3
app/proxyman/inbound/dynamic.go

@@ -10,7 +10,7 @@ import (
 	"v2ray.com/core/app/proxyman/mux"
 	"v2ray.com/core/common/dice"
 	"v2ray.com/core/common/net"
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/task"
 	"v2ray.com/core/proxy"
 )
 
@@ -25,7 +25,7 @@ type DynamicInboundHandler struct {
 	worker         []worker
 	lastRefresh    time.Time
 	mux            *mux.Server
-	task           *signal.PeriodicTask
+	task           *task.Periodic
 }
 
 func NewDynamicInboundHandler(ctx context.Context, tag string, receiverConfig *proxyman.ReceiverConfig, proxyConfig interface{}) (*DynamicInboundHandler, error) {
@@ -39,7 +39,7 @@ func NewDynamicInboundHandler(ctx context.Context, tag string, receiverConfig *p
 		v:              v,
 	}
 
-	h.task = &signal.PeriodicTask{
+	h.task = &task.Periodic{
 		Interval: time.Minute * time.Duration(h.receiverConfig.AllocationStrategy.GetRefreshValue()),
 		Execute:  h.refresh,
 	}

+ 0 - 23
common/functions/functions.go

@@ -1,23 +0,0 @@
-package functions
-
-import "v2ray.com/core/common"
-
-// Task is a function that may return an error.
-type Task func() error
-
-// OnSuccess returns a Task to run a follow task if pre-condition passes, otherwise the error in pre-condition is returned.
-func OnSuccess(pre func() error, followup Task) Task {
-	return func() error {
-		if err := pre(); err != nil {
-			return err
-		}
-		return followup()
-	}
-}
-
-// Close returns a Task to close the object.
-func Close(obj interface{}) Task {
-	return func() error {
-		return common.Close(obj)
-	}
-}

+ 9 - 0
common/task/common.go

@@ -0,0 +1,9 @@
+package task
+
+import "v2ray.com/core/common"
+
+func Close(v interface{}) Task {
+	return func() error {
+		return common.Close(v)
+	}
+}

+ 6 - 6
common/signal/task.go → common/task/periodic.go

@@ -1,12 +1,12 @@
-package signal
+package task
 
 import (
 	"sync"
 	"time"
 )
 
-// PeriodicTask is a task that runs periodically.
-type PeriodicTask struct {
+// Periodic is a task that runs periodically.
+type Periodic struct {
 	// Interval of the task being run
 	Interval time.Duration
 	// Execute is the task function
@@ -19,7 +19,7 @@ type PeriodicTask struct {
 	closed bool
 }
 
-func (t *PeriodicTask) checkedExecute() error {
+func (t *Periodic) checkedExecute() error {
 	t.access.Lock()
 	defer t.access.Unlock()
 
@@ -41,7 +41,7 @@ func (t *PeriodicTask) checkedExecute() error {
 }
 
 // Start implements common.Runnable. Start must not be called multiple times without Close being called.
-func (t *PeriodicTask) Start() error {
+func (t *Periodic) Start() error {
 	t.access.Lock()
 	t.closed = false
 	t.access.Unlock()
@@ -55,7 +55,7 @@ func (t *PeriodicTask) Start() error {
 }
 
 // Close implements common.Closable.
-func (t *PeriodicTask) Close() error {
+func (t *Periodic) Close() error {
 	t.access.Lock()
 	defer t.access.Unlock()
 

+ 5 - 4
common/signal/task_test.go → common/task/periodic_test.go

@@ -1,19 +1,20 @@
-package signal_test
+package task_test
 
 import (
 	"testing"
 	"time"
 
-	"v2ray.com/core/common"
-	. "v2ray.com/core/common/signal"
+	. "v2ray.com/core/common/task"
 	. "v2ray.com/ext/assert"
+
+	"v2ray.com/core/common"
 )
 
 func TestPeriodicTaskStop(t *testing.T) {
 	assert := With(t)
 
 	value := 0
-	task := &PeriodicTask{
+	task := &Periodic{
 		Interval: time.Second * 2,
 		Execute: func() error {
 			value++

+ 137 - 0
common/task/task.go

@@ -0,0 +1,137 @@
+package task
+
+import (
+	"context"
+
+	"v2ray.com/core/common/signal"
+)
+
+type Task func() error
+
+type executionContext struct {
+	ctx       context.Context
+	task      Task
+	onSuccess Task
+	onFailure Task
+}
+
+func (c *executionContext) executeTask() error {
+	if c.ctx == nil && c.task == nil {
+		return nil
+	}
+
+	if c.ctx == nil {
+		return c.task()
+	}
+
+	if c.task == nil {
+		<-c.ctx.Done()
+		return c.ctx.Err()
+	}
+
+	return executeParallel(func() error {
+		<-c.ctx.Done()
+		return c.ctx.Err()
+	}, c.task)
+}
+
+func (c *executionContext) run() error {
+	err := c.executeTask()
+	if err == nil && c.onSuccess != nil {
+		return c.onSuccess()
+	}
+	if err != nil && c.onFailure != nil {
+		return c.onFailure()
+	}
+	return err
+}
+
+type ExecutionOption func(*executionContext)
+
+func WithContext(ctx context.Context) ExecutionOption {
+	return func(c *executionContext) {
+		c.ctx = ctx
+	}
+}
+
+func Parallel(tasks ...Task) ExecutionOption {
+	return func(c *executionContext) {
+		c.task = func() error {
+			return executeParallel(tasks...)
+		}
+	}
+}
+
+func Sequential(tasks ...Task) ExecutionOption {
+	return func(c *executionContext) {
+		c.task = func() error {
+			return execute(tasks...)
+		}
+	}
+}
+
+func OnSuccess(task Task) ExecutionOption {
+	return func(c *executionContext) {
+		c.onSuccess = task
+	}
+}
+
+func OnFailure(task Task) ExecutionOption {
+	return func(c *executionContext) {
+		c.onFailure = task
+	}
+}
+
+func Single(task Task, opts ExecutionOption) Task {
+	return Run(append([]ExecutionOption{Sequential(task)}, opts)...)
+}
+
+func Run(opts ...ExecutionOption) Task {
+	var c executionContext
+	for _, opt := range opts {
+		opt(&c)
+	}
+	return func() error {
+		return c.run()
+	}
+}
+
+// execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
+func execute(tasks ...Task) error {
+	for _, task := range tasks {
+		if err := task(); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
+func executeParallel(tasks ...Task) error {
+	n := len(tasks)
+	s := signal.NewSemaphore(n)
+	done := make(chan error, 1)
+
+	for _, task := range tasks {
+		<-s.Wait()
+		go func(f func() error) {
+			if err := f(); err != nil {
+				select {
+				case done <- err:
+				default:
+				}
+			}
+			s.Signal()
+		}(task)
+	}
+
+	for i := 0; i < n; i++ {
+		select {
+		case err := <-done:
+			return err
+		case <-s.Wait():
+		}
+	}
+
+	return nil
+}

+ 43 - 0
common/task/task_test.go

@@ -0,0 +1,43 @@
+package task_test
+
+import (
+	"context"
+	"errors"
+	"testing"
+	"time"
+
+	. "v2ray.com/core/common/task"
+	. "v2ray.com/ext/assert"
+)
+
+func TestExecuteParallel(t *testing.T) {
+	assert := With(t)
+
+	err := Run(Parallel(func() error {
+		time.Sleep(time.Millisecond * 200)
+		return errors.New("test")
+	}, func() error {
+		time.Sleep(time.Millisecond * 500)
+		return errors.New("test2")
+	}))()
+
+	assert(err.Error(), Equals, "test")
+}
+
+func TestExecuteParallelContextCancel(t *testing.T) {
+	assert := With(t)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	err := Run(WithContext(ctx), Parallel(func() error {
+		time.Sleep(time.Millisecond * 2000)
+		return errors.New("test")
+	}, func() error {
+		time.Sleep(time.Millisecond * 5000)
+		return errors.New("test2")
+	}, func() error {
+		cancel()
+		return nil
+	}))()
+
+	assert(err.Error(), HasSubstring, "canceled")
+}

+ 5 - 2
proxy/dokodemo/dokodemo.go

@@ -9,9 +9,9 @@ import (
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
-	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/task"
 	"v2ray.com/core/proxy"
 	"v2ray.com/core/transport/internet"
 	"v2ray.com/core/transport/internet/udp"
@@ -118,7 +118,10 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 		return nil
 	}
 
-	if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil {
+	if err := task.Run(task.WithContext(ctx),
+		task.Parallel(
+			task.Single(requestDone, task.OnSuccess(task.Close(link.Writer))),
+			responseDone))(); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 2 - 2
proxy/freedom/freedom.go

@@ -10,10 +10,10 @@ import (
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/dice"
-	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/retry"
 	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/task"
 	"v2ray.com/core/proxy"
 	"v2ray.com/core/transport/internet"
 )
@@ -136,7 +136,7 @@ func (h *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia
 		return nil
 	}
 
-	if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(output))); err != nil {
+	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, task.Single(responseDone, task.OnSuccess(task.Close(output)))))(); err != nil {
 		return newError("connection ends").Base(err)
 	}
 

+ 2 - 0
proxy/http/client.go

@@ -1,4 +1,6 @@
 package http
 
+/*
 type Client struct {
 }
+*/

+ 4 - 2
proxy/http/server.go

@@ -10,13 +10,14 @@ import (
 	"strings"
 	"time"
 
+	"v2ray.com/core/common/task"
+
 	"v2ray.com/core/transport/pipe"
 
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/errors"
-	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	http_proto "v2ray.com/core/common/protocol/http"
@@ -210,7 +211,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 		return nil
 	}
 
-	if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil {
+	var closeWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
+	if err := task.Run(task.WithContext(ctx), task.Parallel(closeWriter, responseDone))(); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 4 - 2
proxy/shadowsocks/client.go

@@ -3,10 +3,11 @@ package shadowsocks
 import (
 	"context"
 
+	"v2ray.com/core/common/task"
+
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
-	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/retry"
@@ -158,7 +159,8 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
 			return nil
 		}
 
-		if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(link.Writer))); err != nil {
+		var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
+		if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
 			return newError("connection ends").Base(err)
 		}
 

+ 4 - 2
proxy/shadowsocks/server.go

@@ -4,10 +4,11 @@ import (
 	"context"
 	"time"
 
+	"v2ray.com/core/common/task"
+
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
-	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
@@ -216,7 +217,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 		return nil
 	}
 
-	if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil {
+	var requestDoneAndCloseWriter = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
+	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDoneAndCloseWriter, responseDone))(); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 4 - 2
proxy/socks/client.go

@@ -4,10 +4,11 @@ import (
 	"context"
 	"time"
 
+	"v2ray.com/core/common/task"
+
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
-	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/retry"
@@ -130,7 +131,8 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
 		}
 	}
 
-	if err := signal.ExecuteParallel(ctx, requestFunc, functions.OnSuccess(responseFunc, functions.Close(link.Writer))); err != nil {
+	var responseDonePost = task.Single(responseFunc, task.OnSuccess(task.Close(link.Writer)))
+	if err := task.Run(task.WithContext(ctx), task.Parallel(requestFunc, responseDonePost))(); err != nil {
 		return newError("connection ends").Base(err)
 	}
 

+ 4 - 2
proxy/socks/server.go

@@ -5,10 +5,11 @@ import (
 	"io"
 	"time"
 
+	"v2ray.com/core/common/task"
+
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
-	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
@@ -160,7 +161,8 @@ func (s *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 		return nil
 	}
 
-	if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil {
+	var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
+	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 3 - 3
proxy/vmess/encoding/server.go

@@ -19,7 +19,7 @@ import (
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/serial"
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/task"
 	"v2ray.com/core/proxy/vmess"
 )
 
@@ -33,7 +33,7 @@ type sessionId struct {
 type SessionHistory struct {
 	sync.RWMutex
 	cache map[sessionId]time.Time
-	task  *signal.PeriodicTask
+	task  *task.Periodic
 }
 
 // NewSessionHistory creates a new SessionHistory object.
@@ -41,7 +41,7 @@ func NewSessionHistory() *SessionHistory {
 	h := &SessionHistory{
 		cache: make(map[sessionId]time.Time, 128),
 	}
-	h.task = &signal.PeriodicTask{
+	h.task = &task.Periodic{
 		Interval: time.Second * 30,
 		Execute: func() error {
 			h.removeExpiredEntries()

+ 4 - 2
proxy/vmess/inbound/inbound.go

@@ -9,11 +9,12 @@ import (
 	"sync"
 	"time"
 
+	"v2ray.com/core/common/task"
+
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/errors"
-	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
@@ -294,7 +295,8 @@ func (h *Handler) Process(ctx context.Context, network net.Network, connection i
 		return transferResponse(timer, session, request, response, link.Reader, writer)
 	}
 
-	if err := signal.ExecuteParallel(ctx, functions.OnSuccess(requestDone, functions.Close(link.Writer)), responseDone); err != nil {
+	var requestDonePost = task.Single(requestDone, task.OnSuccess(task.Close(link.Writer)))
+	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDonePost, responseDone))(); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 4 - 2
proxy/vmess/outbound/outbound.go

@@ -6,12 +6,13 @@ import (
 	"context"
 	"time"
 
+	"v2ray.com/core/common/task"
+
 	"v2ray.com/core/transport/pipe"
 
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
-	"v2ray.com/core/common/functions"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/retry"
@@ -161,7 +162,8 @@ func (v *Handler) Process(ctx context.Context, link *core.Link, dialer proxy.Dia
 		return buf.Copy(bodyReader, output, buf.UpdateActivity(timer))
 	}
 
-	if err := signal.ExecuteParallel(ctx, requestDone, functions.OnSuccess(responseDone, functions.Close(output))); err != nil {
+	var responseDonePost = task.Single(responseDone, task.OnSuccess(task.Close(output)))
+	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDonePost))(); err != nil {
 		return newError("connection ends").Base(err)
 	}
 

+ 3 - 3
proxy/vmess/vmess.go

@@ -14,7 +14,7 @@ import (
 
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/protocol"
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/task"
 )
 
 const (
@@ -34,7 +34,7 @@ type TimedUserValidator struct {
 	userHash map[[16]byte]indexTimePair
 	hasher   protocol.IDHash
 	baseTime protocol.Timestamp
-	task     *signal.PeriodicTask
+	task     *task.Periodic
 }
 
 type indexTimePair struct {
@@ -49,7 +49,7 @@ func NewTimedUserValidator(hasher protocol.IDHash) *TimedUserValidator {
 		hasher:   hasher,
 		baseTime: protocol.Timestamp(time.Now().Unix() - cacheDurationSec*2),
 	}
-	tuv.task = &signal.PeriodicTask{
+	tuv.task = &task.Periodic{
 		Interval: updateInterval,
 		Execute: func() error {
 			tuv.updateUserHash()