瀏覽代碼

fix goroutine leak in dynamic port

Darien Raymond 8 年之前
父節點
當前提交
bfcc75f5d1

+ 28 - 7
app/proxyman/inbound/dynamic.go

@@ -12,6 +12,17 @@ import (
 	"v2ray.com/core/proxy"
 )
 
+type workerWithContext struct {
+	ctx    context.Context
+	cancel context.CancelFunc
+	worker worker
+}
+
+func (w *workerWithContext) Close() {
+	w.cancel()
+	w.worker.Close()
+}
+
 type DynamicInboundHandler struct {
 	sync.Mutex
 	tag            string
@@ -20,8 +31,8 @@ type DynamicInboundHandler struct {
 	proxyConfig    interface{}
 	receiverConfig *proxyman.ReceiverConfig
 	portsInUse     map[v2net.Port]bool
-	worker         []worker
-	worker2Recycle []worker
+	worker         []*workerWithContext
+	worker2Recycle []*workerWithContext
 	lastRefresh    time.Time
 }
 
@@ -62,7 +73,7 @@ func (h *DynamicInboundHandler) refresh() error {
 	ports2Del := make([]v2net.Port, 0, 16)
 	for _, worker := range h.worker2Recycle {
 		worker.Close()
-		ports2Del = append(ports2Del, worker.Port())
+		ports2Del = append(ports2Del, worker.worker.Port())
 	}
 
 	h.Lock()
@@ -78,8 +89,10 @@ func (h *DynamicInboundHandler) refresh() error {
 		address = v2net.AnyIP
 	}
 	for i := uint32(0); i < h.receiverConfig.AllocationStrategy.GetConcurrencyValue(); i++ {
+		ctx, cancel := context.WithCancel(h.ctx)
+
 		port := h.allocatePort()
-		p, err := proxy.CreateInboundHandler(h.ctx, h.proxyConfig)
+		p, err := proxy.CreateInboundHandler(ctx, h.proxyConfig)
 		if err != nil {
 			log.Warning("Proxyman|DefaultInboundHandler: Failed to create proxy instance: ", err)
 			continue
@@ -98,7 +111,11 @@ func (h *DynamicInboundHandler) refresh() error {
 			if err := worker.Start(); err != nil {
 				return err
 			}
-			h.worker = append(h.worker, worker)
+			h.worker = append(h.worker, &workerWithContext{
+				ctx:    ctx,
+				cancel: cancel,
+				worker: worker,
+			})
 		}
 
 		if nl.HasNetwork(v2net.Network_UDP) {
@@ -112,7 +129,11 @@ func (h *DynamicInboundHandler) refresh() error {
 			if err := worker.Start(); err != nil {
 				return err
 			}
-			h.worker = append(h.worker, worker)
+			h.worker = append(h.worker, &workerWithContext{
+				ctx:    ctx,
+				cancel: cancel,
+				worker: worker,
+			})
 		}
 	}
 
@@ -143,5 +164,5 @@ func (h *DynamicInboundHandler) Close() {
 func (h *DynamicInboundHandler) GetRandomInboundProxy() (proxy.Inbound, v2net.Port, int) {
 	w := h.worker[dice.Roll(len(h.worker))]
 	expire := h.receiverConfig.AllocationStrategy.GetRefreshValue() - uint32(time.Since(h.lastRefresh)/time.Minute)
-	return w.Proxy(), w.Port(), int(expire)
+	return w.worker.Proxy(), w.worker.Port(), int(expire)
 }

+ 0 - 1
common/protocol/user_validator.go

@@ -3,5 +3,4 @@ package protocol
 type UserValidator interface {
 	Add(user *User) error
 	Get(timeHash []byte) (*User, Timestamp, bool)
-	Release()
 }

+ 4 - 1
proxy/vmess/encoding/encoding_test.go

@@ -1,6 +1,7 @@
 package encoding_test
 
 import (
+	"context"
 	"testing"
 
 	"v2ray.com/core/common/buf"
@@ -40,7 +41,8 @@ func TestRequestSerialization(t *testing.T) {
 	client := NewClientSession(protocol.DefaultIDHash)
 	client.EncodeRequestHeader(expectedRequest, buffer)
 
-	userValidator := vmess.NewTimedUserValidator(protocol.DefaultIDHash)
+	ctx, cancel := context.WithCancel(context.Background())
+	userValidator := vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash)
 	userValidator.Add(user)
 
 	server := NewServerSession(userValidator)
@@ -53,4 +55,5 @@ func TestRequestSerialization(t *testing.T) {
 	assert.Address(expectedRequest.Address).Equals(actualRequest.Address)
 	assert.Port(expectedRequest.Port).Equals(actualRequest.Port)
 	assert.Byte(byte(expectedRequest.Security)).Equals(byte(actualRequest.Security))
+	cancel()
 }

+ 1 - 1
proxy/vmess/inbound/inbound.go

@@ -86,7 +86,7 @@ func New(ctx context.Context, config *Config) (*VMessInboundHandler, error) {
 		return nil, errors.New("VMess|Inbound: No space in context.")
 	}
 
-	allowedClients := vmess.NewTimedUserValidator(protocol.DefaultIDHash)
+	allowedClients := vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash)
 	for _, user := range config.User {
 		allowedClients.Add(user)
 	}

+ 5 - 31
proxy/vmess/vmess.go

@@ -6,11 +6,11 @@
 package vmess
 
 import (
+	"context"
 	"sync"
 	"time"
 
 	"v2ray.com/core/common/protocol"
-	"v2ray.com/core/common/signal"
 )
 
 const (
@@ -27,12 +27,11 @@ type idEntry struct {
 
 type TimedUserValidator struct {
 	sync.RWMutex
-	running    bool
+	ctx        context.Context
 	validUsers []*protocol.User
 	userHash   map[[16]byte]*indexTimePair
 	ids        []*idEntry
 	hasher     protocol.IDHash
-	cancel     *signal.CancelSignal
 }
 
 type indexTimePair struct {
@@ -40,37 +39,18 @@ type indexTimePair struct {
 	timeSec protocol.Timestamp
 }
 
-func NewTimedUserValidator(hasher protocol.IDHash) protocol.UserValidator {
+func NewTimedUserValidator(ctx context.Context, hasher protocol.IDHash) protocol.UserValidator {
 	tus := &TimedUserValidator{
+		ctx:        ctx,
 		validUsers: make([]*protocol.User, 0, 16),
 		userHash:   make(map[[16]byte]*indexTimePair, 512),
 		ids:        make([]*idEntry, 0, 512),
 		hasher:     hasher,
-		running:    true,
-		cancel:     signal.NewCloseSignal(),
 	}
 	go tus.updateUserHash(updateIntervalSec * time.Second)
 	return tus
 }
 
-func (v *TimedUserValidator) Release() {
-	if !v.running {
-		return
-	}
-
-	v.cancel.Cancel()
-	v.cancel.WaitForDone()
-
-	v.Lock()
-	defer v.Unlock()
-
-	if !v.running {
-		return
-	}
-
-	v.running = false
-}
-
 func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, idx int, entry *idEntry) {
 	var hashValue [16]byte
 	var hashValueRemoval [16]byte
@@ -93,9 +73,6 @@ func (v *TimedUserValidator) generateNewHashes(nowSec protocol.Timestamp, idx in
 }
 
 func (v *TimedUserValidator) updateUserHash(interval time.Duration) {
-	v.cancel.WaitThread()
-	defer v.cancel.FinishThread()
-
 	for {
 		select {
 		case now := <-time.After(interval):
@@ -105,7 +82,7 @@ func (v *TimedUserValidator) updateUserHash(interval time.Duration) {
 				v.generateNewHashes(nowSec, entry.userIdx, entry)
 			}
 			v.Unlock()
-		case <-v.cancel.WaitForCancel():
+		case <-v.ctx.Done():
 			return
 		}
 	}
@@ -151,9 +128,6 @@ func (v *TimedUserValidator) Get(userHash []byte) (*protocol.User, protocol.Time
 	defer v.RUnlock()
 	v.RLock()
 
-	if !v.running {
-		return nil, 0, false
-	}
 	var fixedSizeHash [16]byte
 	copy(fixedSizeHash[:], userHash)
 	pair, found := v.userHash[fixedSizeHash]