Procházet zdrojové kódy

Refactor timed queue

V2Ray před 10 roky
rodič
revize
8f0cb97e89

+ 0 - 111
common/collect/timed_map.go

@@ -1,111 +0,0 @@
-package collect
-
-import (
-	"container/heap"
-	"sync"
-	"time"
-)
-
-type timedQueueEntry struct {
-	timeSec int64
-	value   interface{}
-}
-
-type timedQueue []*timedQueueEntry
-
-func (queue timedQueue) Len() int {
-	return len(queue)
-}
-
-func (queue timedQueue) Less(i, j int) bool {
-	return queue[i].timeSec < queue[j].timeSec
-}
-
-func (queue timedQueue) Swap(i, j int) {
-	tmp := queue[i]
-	queue[i] = queue[j]
-	queue[j] = tmp
-}
-
-func (queue *timedQueue) Push(value interface{}) {
-	entry := value.(*timedQueueEntry)
-	*queue = append(*queue, entry)
-}
-
-func (queue *timedQueue) Pop() interface{} {
-	old := *queue
-	n := len(old)
-	v := old[n-1]
-	*queue = old[:n-1]
-	return v
-}
-
-type TimedStringMap struct {
-	timedQueue
-	queueMutex sync.Mutex
-	dataMutext sync.RWMutex
-	data       map[string]interface{}
-	interval   int
-}
-
-func NewTimedStringMap(updateInterval int) *TimedStringMap {
-	m := &TimedStringMap{
-		timedQueue: make([]*timedQueueEntry, 0, 1024),
-		queueMutex: sync.Mutex{},
-		dataMutext: sync.RWMutex{},
-		data:       make(map[string]interface{}, 1024),
-		interval:   updateInterval,
-	}
-	m.initialize()
-	return m
-}
-
-func (m *TimedStringMap) initialize() {
-	go m.cleanup(time.Tick(time.Duration(m.interval) * time.Second))
-}
-
-func (m *TimedStringMap) cleanup(tick <-chan time.Time) {
-	for {
-		now := <-tick
-		nowSec := now.UTC().Unix()
-		if m.timedQueue.Len() == 0 {
-			continue
-		}
-		for m.timedQueue.Len() > 0 {
-			entry := m.timedQueue[0]
-			if entry.timeSec > nowSec {
-				break
-			}
-			m.queueMutex.Lock()
-			entry = heap.Pop(&m.timedQueue).(*timedQueueEntry)
-			m.queueMutex.Unlock()
-			m.Remove(entry.value.(string))
-		}
-	}
-}
-
-func (m *TimedStringMap) Get(key string) (interface{}, bool) {
-	m.dataMutext.RLock()
-	value, ok := m.data[key]
-	m.dataMutext.RUnlock()
-	return value, ok
-}
-
-func (m *TimedStringMap) Set(key string, value interface{}, time2Delete int64) {
-	m.dataMutext.Lock()
-	m.data[key] = value
-	m.dataMutext.Unlock()
-
-	m.queueMutex.Lock()
-	heap.Push(&m.timedQueue, &timedQueueEntry{
-		timeSec: time2Delete,
-		value:   key,
-	})
-	m.queueMutex.Unlock()
-}
-
-func (m *TimedStringMap) Remove(key string) {
-	m.dataMutext.Lock()
-	delete(m.data, key)
-	m.dataMutext.Unlock()
-}

+ 0 - 48
common/collect/timed_map_test.go

@@ -1,48 +0,0 @@
-package collect
-
-import (
-	"testing"
-	"time"
-
-	"github.com/v2ray/v2ray-core/testing/unit"
-)
-
-func TestTimedStringMap(t *testing.T) {
-	assert := unit.Assert(t)
-
-	nowSec := time.Now().UTC().Unix()
-	m := NewTimedStringMap(2)
-	m.Set("Key1", "Value1", nowSec)
-	m.Set("Key2", "Value2", nowSec+5)
-
-	v1, ok := m.Get("Key1")
-	assert.Bool(ok).IsTrue()
-	assert.String(v1.(string)).Equals("Value1")
-
-	v2, ok := m.Get("Key2")
-	assert.Bool(ok).IsTrue()
-	assert.String(v2.(string)).Equals("Value2")
-
-	tick := time.Tick(4 * time.Second)
-	<-tick
-
-	v1, ok = m.Get("Key1")
-	assert.Bool(ok).IsFalse()
-
-	v2, ok = m.Get("Key2")
-	assert.Bool(ok).IsTrue()
-	assert.String(v2.(string)).Equals("Value2")
-
-	<-tick
-	v2, ok = m.Get("Key2")
-	assert.Bool(ok).IsFalse()
-
-	<-tick
-	v2, ok = m.Get("Key2")
-	assert.Bool(ok).IsFalse()
-
-	m.Set("Key1", "Value1", time.Now().UTC().Unix()+10)
-	v1, ok = m.Get("Key1")
-	assert.Bool(ok).IsTrue()
-	assert.String(v1.(string)).Equals("Value1")
-}

+ 89 - 0
common/collect/timed_queue.go

@@ -0,0 +1,89 @@
+package collect
+
+import (
+	"container/heap"
+	"sync"
+	"time"
+)
+
+type timedQueueEntry struct {
+	timeSec int64
+	value   interface{}
+}
+
+type timedQueueImpl []*timedQueueEntry
+
+func (queue timedQueueImpl) Len() int {
+	return len(queue)
+}
+
+func (queue timedQueueImpl) Less(i, j int) bool {
+	return queue[i].timeSec < queue[j].timeSec
+}
+
+func (queue timedQueueImpl) Swap(i, j int) {
+	tmp := queue[i]
+	queue[i] = queue[j]
+	queue[j] = tmp
+}
+
+func (queue *timedQueueImpl) Push(value interface{}) {
+	entry := value.(*timedQueueEntry)
+	*queue = append(*queue, entry)
+}
+
+func (queue *timedQueueImpl) Pop() interface{} {
+	old := *queue
+	n := len(old)
+	v := old[n-1]
+	*queue = old[:n-1]
+	return v
+}
+
+type TimedQueue struct {
+	queue   timedQueueImpl
+	access  sync.Mutex
+	removed chan interface{}
+}
+
+func NewTimedQueue(updateInterval int) *TimedQueue {
+	queue := &TimedQueue{
+		queue:   make([]*timedQueueEntry, 0, 256),
+		removed: make(chan interface{}, 16),
+		access:  sync.Mutex{},
+	}
+	go queue.cleanup(time.Tick(time.Duration(updateInterval) * time.Second))
+	return queue
+}
+
+func (queue *TimedQueue) Add(value interface{}, time2Remove int64) {
+	queue.access.Lock()
+	heap.Push(&queue.queue, &timedQueueEntry{
+		timeSec: time2Remove,
+		value:   value,
+	})
+	queue.access.Unlock()
+}
+
+func (queue *TimedQueue) RemovedEntries() <-chan interface{} {
+	return queue.removed
+}
+
+func (queue *TimedQueue) cleanup(tick <-chan time.Time) {
+	for {
+		now := <-tick
+		if queue.queue.Len() == 0 {
+			continue
+		}
+		nowSec := now.UTC().Unix()
+		entry := queue.queue[0]
+		if entry.timeSec > nowSec {
+			continue
+		}
+		queue.access.Lock()
+		entry = heap.Pop(&queue.queue).(*timedQueueEntry)
+		queue.access.Unlock()
+
+		queue.removed <- entry.value
+	}
+}

+ 60 - 0
common/collect/timed_queue_test.go

@@ -0,0 +1,60 @@
+package collect
+
+import (
+	"testing"
+	"time"
+
+	"github.com/v2ray/v2ray-core/testing/unit"
+)
+
+func TestTimedQueue(t *testing.T) {
+	assert := unit.Assert(t)
+
+	removed := make(map[string]bool)
+
+	nowSec := time.Now().UTC().Unix()
+	q := NewTimedQueue(2)
+
+	go func() {
+		for {
+			entry := <-q.RemovedEntries()
+			removed[entry.(string)] = true
+		}
+	}()
+
+	q.Add("Value1", nowSec)
+	q.Add("Value2", nowSec+5)
+
+	v1, ok := removed["Value1"]
+	assert.Bool(ok).IsFalse()
+
+	v2, ok := removed["Value2"]
+	assert.Bool(ok).IsFalse()
+
+	tick := time.Tick(4 * time.Second)
+	<-tick
+
+	v1, ok = removed["Value1"]
+	assert.Bool(ok).IsTrue()
+	assert.Bool(v1).IsTrue()
+	removed["Value1"] = false
+
+	v2, ok = removed["Value2"]
+	assert.Bool(ok).IsFalse()
+
+	<-tick
+	v2, ok = removed["Value2"]
+	assert.Bool(ok).IsTrue()
+	assert.Bool(v2).IsTrue()
+	removed["Value2"] = false
+
+	<-tick
+	assert.Bool(removed["Values"]).IsFalse()
+
+	q.Add("Value1", time.Now().UTC().Unix()+10)
+
+	<-tick
+	v1, ok = removed["Value1"]
+	assert.Bool(ok).IsTrue()
+	assert.Bool(v1).IsFalse()
+}

+ 26 - 7
proxy/vmess/protocol/user/userset.go

@@ -1,6 +1,7 @@
 package user
 
 import (
+	"sync"
 	"time"
 
 	"github.com/v2ray/v2ray-core/common/collect"
@@ -18,8 +19,10 @@ type UserSet interface {
 }
 
 type TimedUserSet struct {
-	validUserIds []ID
-	userHash     *collect.TimedStringMap
+	validUserIds        []ID
+	userHash            map[string]indexTimePair
+	userHashDeleteQueue *collect.TimedQueue
+	access              sync.RWMutex
 }
 
 type indexTimePair struct {
@@ -29,19 +32,34 @@ type indexTimePair struct {
 
 func NewTimedUserSet() UserSet {
 	tus := &TimedUserSet{
-		validUserIds: make([]ID, 0, 16),
-		userHash:     collect.NewTimedStringMap(updateIntervalSec),
+		validUserIds:        make([]ID, 0, 16),
+		userHash:            make(map[string]indexTimePair, 512),
+		userHashDeleteQueue: collect.NewTimedQueue(updateIntervalSec),
+		access:              sync.RWMutex{},
 	}
 	go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second))
+	go tus.removeEntries(tus.userHashDeleteQueue.RemovedEntries())
 	return tus
 }
 
+func (us *TimedUserSet) removeEntries(entries <-chan interface{}) {
+	for {
+		entry := <-entries
+		us.access.Lock()
+		delete(us.userHash, entry.(string))
+		us.access.Unlock()
+	}
+}
+
 func (us *TimedUserSet) generateNewHashes(lastSec, nowSec int64, idx int, id ID) {
 	idHash := NewTimeHash(HMACHash{})
 	for lastSec < nowSec+cacheDurationSec {
 		idHash := idHash.Hash(id.Bytes[:], lastSec)
 		log.Debug("Valid User Hash: %v", idHash)
-		us.userHash.Set(string(idHash), indexTimePair{idx, lastSec}, lastSec+2*cacheDurationSec)
+		us.access.Lock()
+		us.userHash[string(idHash)] = indexTimePair{idx, lastSec}
+		us.access.Unlock()
+		us.userHashDeleteQueue.Add(string(idHash), lastSec+2*cacheDurationSec)
 		lastSec++
 	}
 }
@@ -73,9 +91,10 @@ func (us *TimedUserSet) AddUser(user User) error {
 }
 
 func (us TimedUserSet) GetUser(userHash []byte) (*ID, int64, bool) {
-	rawPair, found := us.userHash.Get(string(userHash))
+	defer us.access.RUnlock()
+	us.access.RLock()
+	pair, found := us.userHash[string(userHash)]
 	if found {
-		pair := rawPair.(indexTimePair)
 		return &us.validUserIds[pair.index], pair.timeSec, true
 	}
 	return nil, 0, false