Explorar o código

update activity timer

Darien Raymond %!s(int64=7) %!d(string=hai) anos
pai
achega
ac6a0f7511

+ 2 - 2
app/commander/commander.go

@@ -10,7 +10,7 @@ import (
 	"google.golang.org/grpc"
 	"v2ray.com/core"
 	"v2ray.com/core/common"
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/signal/done"
 )
 
 // Commander is a V2Ray feature that provides gRPC methods to external clients.
@@ -64,7 +64,7 @@ func (c *Commander) Start() error {
 
 	listener := &OutboundListener{
 		buffer: make(chan net.Conn, 4),
-		done:   signal.NewDone(),
+		done:   done.New(),
 	}
 
 	go func() {

+ 2 - 1
app/commander/outbound.go

@@ -8,13 +8,14 @@ import (
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/signal/done"
 	"v2ray.com/core/transport/pipe"
 )
 
 // OutboundListener is a net.Listener for listening gRPC connections.
 type OutboundListener struct {
 	buffer chan net.Conn
-	done   *signal.Done
+	done   *done.Instance
 }
 
 func (l *OutboundListener) add(conn net.Conn) {

+ 5 - 5
app/proxyman/inbound/worker.go

@@ -13,7 +13,7 @@ import (
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/session"
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/signal/done"
 	"v2ray.com/core/proxy"
 	"v2ray.com/core/transport/internet"
 	"v2ray.com/core/transport/internet/tcp"
@@ -115,7 +115,7 @@ type udpConn struct {
 	output           func([]byte) (int, error)
 	remote           net.Addr
 	local            net.Addr
-	done             *signal.Done
+	done             *done.Instance
 	uplink           core.StatCounter
 	downlink         core.StatCounter
 }
@@ -223,7 +223,7 @@ type udpWorker struct {
 	uplinkCounter   core.StatCounter
 	downlinkCounter core.StatCounter
 
-	done       *signal.Done
+	done       *done.Instance
 	activeConn map[connID]*udpConn
 }
 
@@ -248,7 +248,7 @@ func (w *udpWorker) getConnection(id connID) (*udpConn, bool) {
 			IP:   w.address.IP(),
 			Port: int(w.port),
 		},
-		done:     signal.NewDone(),
+		done:     done.New(),
 		uplink:   w.uplinkCounter,
 		downlink: w.downlinkCounter,
 	}
@@ -305,7 +305,7 @@ func (w *udpWorker) removeConn(id connID) {
 
 func (w *udpWorker) Start() error {
 	w.activeConn = make(map[connID]*udpConn, 16)
-	w.done = signal.NewDone()
+	w.done = done.New()
 	h, err := udp.ListenUDP(w.address, w.port, w.callback, udp.HubReceiveOriginalDestination(w.recvOrigDest), udp.HubCapacity(256))
 	if err != nil {
 		return err

+ 3 - 3
app/proxyman/mux/mux.go

@@ -16,7 +16,7 @@ import (
 	"v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/signal/done"
 	"v2ray.com/core/proxy"
 	"v2ray.com/core/transport/pipe"
 )
@@ -77,7 +77,7 @@ func (m *ClientManager) onClientFinish() {
 type Client struct {
 	sessionManager *SessionManager
 	link           core.Link
-	done           *signal.Done
+	done           *done.Instance
 	manager        *ClientManager
 	concurrency    uint32
 }
@@ -100,7 +100,7 @@ func NewClient(pctx context.Context, p proxy.Outbound, dialer proxy.Dialer, m *C
 			Reader: downlinkReader,
 			Writer: upLinkWriter,
 		},
-		done:        signal.NewDone(),
+		done:        done.New(),
 		manager:     m,
 		concurrency: m.config.Concurrency,
 	}

+ 6 - 5
common/log/logger.go

@@ -7,7 +7,8 @@ import (
 	"time"
 
 	"v2ray.com/core/common/platform"
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/signal/done"
+	"v2ray.com/core/common/signal/semaphore"
 )
 
 // Writer is the interface for writing logs.
@@ -22,8 +23,8 @@ type WriterCreator func() Writer
 type generalLogger struct {
 	creator WriterCreator
 	buffer  chan Message
-	access  *signal.Semaphore
-	done    *signal.Done
+	access  *semaphore.Instance
+	done    *done.Instance
 }
 
 // NewLogger returns a generic log handler that can handle all type of messages.
@@ -31,8 +32,8 @@ func NewLogger(logWriterCreator WriterCreator) Handler {
 	return &generalLogger{
 		creator: logWriterCreator,
 		buffer:  make(chan Message, 16),
-		access:  signal.NewSemaphore(1),
-		done:    signal.NewDone(),
+		access:  semaphore.New(1),
+		done:    done.New(),
 	}
 }
 

+ 3 - 3
common/net/connection.go

@@ -7,7 +7,7 @@ import (
 
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/signal/done"
 )
 
 type ConnectionOption func(*connection)
@@ -56,7 +56,7 @@ func ConnectionOnClose(n io.Closer) ConnectionOption {
 
 func NewConnection(opts ...ConnectionOption) net.Conn {
 	c := &connection{
-		done: signal.NewDone(),
+		done: done.New(),
 		local: &net.TCPAddr{
 			IP:   []byte{0, 0, 0, 0},
 			Port: 0,
@@ -77,7 +77,7 @@ func NewConnection(opts ...ConnectionOption) net.Conn {
 type connection struct {
 	reader  *buf.BufferedReader
 	writer  buf.Writer
-	done    *signal.Done
+	done    *done.Instance
 	onClose io.Closer
 	local   Addr
 	remote  Addr

+ 10 - 9
common/protocol/address.go

@@ -3,10 +3,11 @@ package protocol
 import (
 	"io"
 
+	"v2ray.com/core/common/task"
+
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
-	"v2ray.com/core/common/signal"
 )
 
 type AddressOption func(*AddressParser)
@@ -153,9 +154,9 @@ func (p *AddressParser) ReadAddressPort(buffer *buf.Buffer, input io.Reader) (ne
 	var err error
 
 	if p.portFirst {
-		err = signal.Execute(pTask, aTask)
+		err = task.Run(task.Sequential(pTask, aTask))()
 	} else {
-		err = signal.Execute(aTask, pTask)
+		err = task.Run(task.Sequential(aTask, pTask))()
 	}
 
 	if err != nil {
@@ -177,21 +178,21 @@ func (p *AddressParser) writeAddress(writer io.Writer, address net.Address) erro
 
 	switch address.Family() {
 	case net.AddressFamilyIPv4, net.AddressFamilyIPv6:
-		return signal.Execute(func() error {
+		return task.Run(task.Sequential(func() error {
 			return common.Error2(writer.Write([]byte{tb}))
 		}, func() error {
 			return common.Error2(writer.Write(address.IP()))
-		})
+		}))()
 	case net.AddressFamilyDomain:
 		domain := address.Domain()
 		if isDomainTooLong(domain) {
 			return newError("Super long domain is not supported: ", domain)
 		}
-		return signal.Execute(func() error {
+		return task.Run(task.Sequential(func() error {
 			return common.Error2(writer.Write([]byte{tb, byte(len(domain))}))
 		}, func() error {
 			return common.Error2(writer.Write([]byte(domain)))
-		})
+		}))()
 	default:
 		panic("Unknown family type.")
 	}
@@ -207,8 +208,8 @@ func (p *AddressParser) WriteAddressPort(writer io.Writer, addr net.Address, por
 	}
 
 	if p.portFirst {
-		return signal.Execute(pTask, aTask)
+		return task.Run(task.Sequential(pTask, aTask))()
 	}
 
-	return signal.Execute(aTask, pTask)
+	return task.Run(task.Sequential(aTask, pTask))()
 }

+ 9 - 9
common/signal/done.go → common/signal/done/done.go

@@ -1,25 +1,25 @@
-package signal
+package done
 
 import (
 	"sync"
 )
 
-// Done is a utility for notifications of something being done.
-type Done struct {
+// Instance is a utility for notifications of something being done.
+type Instance struct {
 	access sync.Mutex
 	c      chan struct{}
 	closed bool
 }
 
-// NewDone returns a new Done.
-func NewDone() *Done {
-	return &Done{
+// New returns a new Done.
+func New() *Instance {
+	return &Instance{
 		c: make(chan struct{}),
 	}
 }
 
 // Done returns true if Close() is called.
-func (d *Done) Done() bool {
+func (d *Instance) Done() bool {
 	select {
 	case <-d.Wait():
 		return true
@@ -29,12 +29,12 @@ func (d *Done) Done() bool {
 }
 
 // Wait returns a channel for waiting for done.
-func (d *Done) Wait() <-chan struct{} {
+func (d *Instance) Wait() <-chan struct{} {
 	return d.c
 }
 
 // Close marks this Done 'done'. This method may be called multiple times. All calls after first call will have no effect on its status.
-func (d *Done) Close() error {
+func (d *Instance) Close() error {
 	d.access.Lock()
 	defer d.access.Unlock()
 

+ 0 - 47
common/signal/exec.go

@@ -1,47 +0,0 @@
-package signal
-
-import (
-	"context"
-)
-
-// Execute runs a list of tasks sequentially, returns the first error encountered or nil if all tasks pass.
-func Execute(tasks ...func() error) error {
-	for _, task := range tasks {
-		if err := task(); err != nil {
-			return err
-		}
-	}
-	return nil
-}
-
-// ExecuteParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
-func ExecuteParallel(ctx context.Context, tasks ...func() error) error {
-	n := len(tasks)
-	s := NewSemaphore(n)
-	done := make(chan error, 1)
-
-	for _, task := range tasks {
-		<-s.Wait()
-		go func(f func() error) {
-			if err := f(); err != nil {
-				select {
-				case done <- err:
-				default:
-				}
-			}
-			s.Signal()
-		}(task)
-	}
-
-	for i := 0; i < n; i++ {
-		select {
-		case <-ctx.Done():
-			return ctx.Err()
-		case err := <-done:
-			return err
-		case <-s.Wait():
-		}
-	}
-
-	return nil
-}

+ 0 - 43
common/signal/exec_test.go

@@ -1,43 +0,0 @@
-package signal_test
-
-import (
-	"context"
-	"errors"
-	"testing"
-	"time"
-
-	. "v2ray.com/core/common/signal"
-	. "v2ray.com/ext/assert"
-)
-
-func TestExecuteParallel(t *testing.T) {
-	assert := With(t)
-
-	err := ExecuteParallel(context.Background(), func() error {
-		time.Sleep(time.Millisecond * 200)
-		return errors.New("test")
-	}, func() error {
-		time.Sleep(time.Millisecond * 500)
-		return errors.New("test2")
-	})
-
-	assert(err.Error(), Equals, "test")
-}
-
-func TestExecuteParallelContextCancel(t *testing.T) {
-	assert := With(t)
-
-	ctx, cancel := context.WithCancel(context.Background())
-	err := ExecuteParallel(ctx, func() error {
-		time.Sleep(time.Millisecond * 2000)
-		return errors.New("test")
-	}, func() error {
-		time.Sleep(time.Millisecond * 5000)
-		return errors.New("test2")
-	}, func() error {
-		cancel()
-		return nil
-	})
-
-	assert(err.Error(), HasSubstring, "canceled")
-}

+ 0 - 27
common/signal/semaphore.go

@@ -1,27 +0,0 @@
-package signal
-
-// Semaphore is an implementation of semaphore.
-type Semaphore struct {
-	token chan struct{}
-}
-
-// NewSemaphore create a new Semaphore with n permits.
-func NewSemaphore(n int) *Semaphore {
-	s := &Semaphore{
-		token: make(chan struct{}, n),
-	}
-	for i := 0; i < n; i++ {
-		s.token <- struct{}{}
-	}
-	return s
-}
-
-// Wait returns a channel for acquiring a permit.
-func (s *Semaphore) Wait() <-chan struct{} {
-	return s.token
-}
-
-// Signal releases a permit into the Semaphore.
-func (s *Semaphore) Signal() {
-	s.token <- struct{}{}
-}

+ 27 - 0
common/signal/semaphore/semaphore.go

@@ -0,0 +1,27 @@
+package semaphore
+
+// Instance is an implementation of semaphore.
+type Instance struct {
+	token chan struct{}
+}
+
+// New create a new Semaphore with n permits.
+func New(n int) *Instance {
+	s := &Instance{
+		token: make(chan struct{}, n),
+	}
+	for i := 0; i < n; i++ {
+		s.token <- struct{}{}
+	}
+	return s
+}
+
+// Wait returns a channel for acquiring a permit.
+func (s *Instance) Wait() <-chan struct{} {
+	return s.token
+}
+
+// Signal releases a permit into the semaphore.
+func (s *Instance) Signal() {
+	s.token <- struct{}{}
+}

+ 40 - 43
common/signal/timer.go

@@ -2,7 +2,11 @@ package signal
 
 import (
 	"context"
+	"sync"
 	"time"
+
+	"v2ray.com/core/common"
+	"v2ray.com/core/common/task"
 )
 
 type ActivityUpdater interface {
@@ -10,9 +14,10 @@ type ActivityUpdater interface {
 }
 
 type ActivityTimer struct {
-	updated chan struct{}
-	timeout chan time.Duration
-	closing chan struct{}
+	sync.RWMutex
+	updated   chan struct{}
+	checkTask *task.Periodic
+	onTimeout func()
 }
 
 func (t *ActivityTimer) Update() {
@@ -22,60 +27,52 @@ func (t *ActivityTimer) Update() {
 	}
 }
 
-func (t *ActivityTimer) SetTimeout(timeout time.Duration) {
+func (t *ActivityTimer) check() error {
 	select {
-	case <-t.closing:
-	case t.timeout <- timeout:
+	case <-t.updated:
+	default:
+		t.finish()
 	}
+	return nil
 }
 
-func (t *ActivityTimer) run(ctx context.Context, cancel context.CancelFunc) {
-	defer func() {
-		cancel()
-		close(t.closing)
-	}()
+func (t *ActivityTimer) finish() {
+	t.Lock()
+	defer t.Unlock()
 
-	timeout := <-t.timeout
-	if timeout == 0 {
-		return
+	if t.onTimeout != nil {
+		t.onTimeout()
 	}
+	if t.checkTask != nil {
+		t.checkTask.Close()
+		t.checkTask = nil
+	}
+}
 
-	ticker := time.NewTicker(timeout)
-	defer func() {
-		ticker.Stop()
-	}()
-
-	for {
-		select {
-		case <-ticker.C:
-		case <-ctx.Done():
-			return
-		case timeout := <-t.timeout:
-			if timeout == 0 {
-				return
-			}
+func (t *ActivityTimer) SetTimeout(timeout time.Duration) {
+	if timeout == 0 {
+		t.finish()
+	}
 
-			ticker.Stop()
-			ticker = time.NewTicker(timeout)
-			continue
-		}
+	t.Lock()
 
-		select {
-		case <-t.updated:
-		// Updated keep waiting.
-		default:
-			return
-		}
+	if t.checkTask != nil {
+		t.checkTask.Close() // nolint: errcheck
+	}
+	t.checkTask = &task.Periodic{
+		Interval: timeout,
+		Execute:  t.check,
 	}
+	t.Unlock()
+	t.Update()
+	common.Must(t.checkTask.Start())
 }
 
 func CancelAfterInactivity(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) *ActivityTimer {
 	timer := &ActivityTimer{
-		timeout: make(chan time.Duration, 1),
-		updated: make(chan struct{}, 1),
-		closing: make(chan struct{}),
+		updated:   make(chan struct{}, 1),
+		onTimeout: cancel,
 	}
-	timer.timeout <- timeout
-	go timer.run(ctx, cancel)
+	timer.SetTimeout(timeout)
 	return timer
 }

+ 18 - 8
common/task/periodic.go

@@ -14,16 +14,26 @@ type Periodic struct {
 	// OnFailure will be called when Execute returns non-nil error
 	OnError func(error)
 
-	access sync.Mutex
+	access sync.RWMutex
 	timer  *time.Timer
 	closed bool
 }
 
-func (t *Periodic) checkedExecute() error {
+func (t *Periodic) setClosed(f bool) {
 	t.access.Lock()
-	defer t.access.Unlock()
+	t.closed = f
+	t.access.Unlock()
+}
+
+func (t *Periodic) hasClosed() bool {
+	t.access.RLock()
+	defer t.access.RUnlock()
+
+	return t.closed
+}
 
-	if t.closed {
+func (t *Periodic) checkedExecute() error {
+	if t.hasClosed() {
 		return nil
 	}
 
@@ -31,23 +41,23 @@ func (t *Periodic) checkedExecute() error {
 		return err
 	}
 
+	t.access.Lock()
 	t.timer = time.AfterFunc(t.Interval, func() {
 		if err := t.checkedExecute(); err != nil && t.OnError != nil {
 			t.OnError(err)
 		}
 	})
+	t.access.Unlock()
 
 	return nil
 }
 
 // Start implements common.Runnable. Start must not be called multiple times without Close being called.
 func (t *Periodic) Start() error {
-	t.access.Lock()
-	t.closed = false
-	t.access.Unlock()
+	t.setClosed(false)
 
 	if err := t.checkedExecute(); err != nil {
-		t.closed = true
+		t.setClosed(true)
 		return err
 	}
 

+ 2 - 2
common/task/task.go

@@ -3,7 +3,7 @@ package task
 import (
 	"context"
 
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/signal/semaphore"
 )
 
 type Task func() error
@@ -109,7 +109,7 @@ func execute(tasks ...Task) error {
 // executeParallel executes a list of tasks asynchronously, returns the first error encountered or nil if all tasks pass.
 func executeParallel(tasks ...Task) error {
 	n := len(tasks)
-	s := signal.NewSemaphore(n)
+	s := semaphore.New(n)
 	done := make(chan error, 1)
 
 	for _, task := range tasks {

+ 1 - 1
proxy/http/server.go

@@ -308,7 +308,7 @@ func (s *Server) handlePlainHTTP(ctx context.Context, request *http.Request, wri
 		return nil
 	}
 
-	if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
+	if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDone))(); err != nil {
 		pipe.CloseError(link.Reader)
 		pipe.CloseError(link.Writer)
 		return newError("connection ends").Base(err)

+ 2 - 1
proxy/shadowsocks/client.go

@@ -122,7 +122,8 @@ func (c *Client) Process(ctx context.Context, link *core.Link, dialer proxy.Dial
 			return buf.Copy(responseReader, link.Writer, buf.UpdateActivity(timer))
 		}
 
-		if err := signal.ExecuteParallel(ctx, requestDone, responseDone); err != nil {
+		var responseDoneAndCloseWriter = task.Single(responseDone, task.OnSuccess(task.Close(link.Writer)))
+		if err := task.Run(task.WithContext(ctx), task.Parallel(requestDone, responseDoneAndCloseWriter))(); err != nil {
 			return newError("connection ends").Base(err)
 		}
 

+ 3 - 3
transport/internet/http/hub.go

@@ -9,7 +9,7 @@ import (
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/serial"
-	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/signal/done"
 	"v2ray.com/core/transport/internet"
 	"v2ray.com/core/transport/internet/tls"
 )
@@ -31,7 +31,7 @@ func (l *Listener) Close() error {
 
 type flushWriter struct {
 	w io.Writer
-	d *signal.Done
+	d *done.Instance
 }
 
 func (fw flushWriter) Write(p []byte) (n int, err error) {
@@ -75,7 +75,7 @@ func (l *Listener) ServeHTTP(writer http.ResponseWriter, request *http.Request)
 		}
 	}
 
-	done := signal.NewDone()
+	done := done.New()
 	conn := net.NewConnection(
 		net.ConnectionOutput(request.Body),
 		net.ConnectionInput(flushWriter{w: writer, d: done}),

+ 3 - 2
transport/internet/kcp/connection.go

@@ -10,6 +10,7 @@ import (
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/predicate"
 	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/signal/semaphore"
 )
 
 var (
@@ -121,7 +122,7 @@ type Updater struct {
 	shouldContinue  predicate.Predicate
 	shouldTerminate predicate.Predicate
 	updateFunc      func()
-	notifier        *signal.Semaphore
+	notifier        *semaphore.Instance
 }
 
 func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTerminate predicate.Predicate, updateFunc func()) *Updater {
@@ -130,7 +131,7 @@ func NewUpdater(interval uint32, shouldContinue predicate.Predicate, shouldTermi
 		shouldContinue:  shouldContinue,
 		shouldTerminate: shouldTerminate,
 		updateFunc:      updateFunc,
-		notifier:        signal.NewSemaphore(1),
+		notifier:        semaphore.New(1),
 	}
 	return u
 }

+ 4 - 4
transport/pipe/pipe_test.go

@@ -1,13 +1,13 @@
 package pipe_test
 
 import (
-	"context"
 	"io"
 	"testing"
 	"time"
 
+	"v2ray.com/core/common/task"
+
 	"v2ray.com/core/common/buf"
-	"v2ray.com/core/common/signal"
 	. "v2ray.com/core/transport/pipe"
 	. "v2ray.com/ext/assert"
 )
@@ -68,7 +68,7 @@ func TestPipeLimitZero(t *testing.T) {
 	bb.Write([]byte{'a', 'b'})
 	assert(pWriter.WriteMultiBuffer(buf.NewMultiBufferValue(bb)), IsNil)
 
-	err := signal.ExecuteParallel(context.Background(), func() error {
+	err := task.Run(task.Parallel(func() error {
 		b := buf.New()
 		b.Write([]byte{'c', 'd'})
 		return pWriter.WriteMultiBuffer(buf.NewMultiBufferValue(b))
@@ -87,7 +87,7 @@ func TestPipeLimitZero(t *testing.T) {
 		}
 		assert(rb.String(), Equals, "cd")
 		return nil
-	})
+	}))()
 
 	assert(err, IsNil)
 }