浏览代码

Initial version of TimedStringMap

V2Ray 10 年之前
父节点
当前提交
8ce7ee1cda
共有 2 个文件被更改,包括 118 次插入59 次删除
  1. 106 0
      common/collect/timed_map.go
  2. 12 59
      proxy/vmess/protocol/user/userset.go

+ 106 - 0
common/collect/timed_map.go

@@ -0,0 +1,106 @@
+package collect
+
+import (
+	"container/heap"
+	"sync"
+	"time"
+)
+
+type timedQueueEntry struct {
+	timeSec int64
+	value   interface{}
+}
+
+type timedQueue []*timedQueueEntry
+
+func (queue timedQueue) Len() int {
+	return len(queue)
+}
+
+func (queue timedQueue) Less(i, j int) bool {
+	return queue[i].timeSec < queue[j].timeSec
+}
+
+func (queue timedQueue) Swap(i, j int) {
+	tmp := queue[i]
+	queue[i] = queue[j]
+	queue[j] = tmp
+}
+
+func (queue *timedQueue) Push(value interface{}) {
+	entry := value.(*timedQueueEntry)
+	*queue = append(*queue, entry)
+}
+
+func (queue *timedQueue) Pop() interface{} {
+	old := *queue
+	n := len(old)
+	v := old[n-1]
+	*queue = old[:n-1]
+	return v
+}
+
+type TimedStringMap struct {
+	timedQueue
+	access   sync.RWMutex
+	data     map[string]interface{}
+	interval int
+}
+
+func NewTimedStringMap(updateInterval int) *TimedStringMap {
+	m := &TimedStringMap{
+		timedQueue: make([]*timedQueueEntry, 0, 1024),
+		access:     sync.RWMutex{},
+		data:       make(map[string]interface{}, 1024),
+		interval:   updateInterval,
+	}
+	m.initialize()
+	return m
+}
+
+func (m *TimedStringMap) initialize() {
+	go m.cleanup(time.Tick(time.Duration(m.interval) * time.Second))
+}
+
+func (m *TimedStringMap) cleanup(tick <-chan time.Time) {
+	for {
+		now := <-tick
+		nowSec := now.UTC().Unix()
+		if m.timedQueue.Len() == 0 {
+			continue
+		}
+		for m.timedQueue.Len() > 0 {
+			entry := m.timedQueue[0]
+			if entry.timeSec > nowSec {
+				break
+			}
+			m.access.Lock()
+			entry = heap.Pop(&m.timedQueue).(*timedQueueEntry)
+			m.access.Unlock()
+			m.Remove(entry.value.(string))
+		}
+	}
+}
+
+func (m *TimedStringMap) Get(key string) (interface{}, bool) {
+	m.access.RLock()
+	value, ok := m.data[key]
+	m.access.RUnlock()
+	return value, ok
+}
+
+func (m *TimedStringMap) Set(key string, value interface{}, time2Delete int64) {
+	m.access.Lock()
+	m.data[key] = value
+	heap.Push(&m.timedQueue, &timedQueueEntry{
+		timeSec: time2Delete,
+		value:   key,
+	})
+	m.access.Unlock()
+}
+
+func (m *TimedStringMap) Remove(key string) {
+	m.access.Lock()
+	delete(m.data, key)
+	m.access.Unlock()
+}

+ 12 - 59
proxy/vmess/protocol/user/userset.go

@@ -1,9 +1,9 @@
 package user
 
 import (
-	"container/heap"
 	"time"
 
+	"github.com/v2ray/v2ray-core/common/collect"
 	"github.com/v2ray/v2ray-core/common/log"
 )
 
@@ -19,8 +19,7 @@ type UserSet interface {
 
 type TimedUserSet struct {
 	validUserIds []ID
-	userHashes   map[string]indexTimePair
-	hash2Remove  hashEntrySet
+	userHash     *collect.TimedStringMap
 }
 
 type indexTimePair struct {
@@ -28,58 +27,21 @@ type indexTimePair struct {
 	timeSec int64
 }
 
-type hashEntry struct {
-	hash    string
-	timeSec int64
-}
-
-type hashEntrySet []*hashEntry
-
-func (set hashEntrySet) Len() int {
-	return len(set)
-}
-
-func (set hashEntrySet) Less(i, j int) bool {
-	return set[i].timeSec < set[j].timeSec
-}
-
-func (set hashEntrySet) Swap(i, j int) {
-	tmp := set[i]
-	set[i] = set[j]
-	set[j] = tmp
-}
-
-func (set *hashEntrySet) Push(value interface{}) {
-	entry := value.(*hashEntry)
-	*set = append(*set, entry)
-}
-
-func (set *hashEntrySet) Pop() interface{} {
-	old := *set
-	n := len(old)
-	v := old[n-1]
-	*set = old[:n-1]
-	return v
-}
-
 func NewTimedUserSet() UserSet {
-	vuSet := new(TimedUserSet)
-	vuSet.validUserIds = make([]ID, 0, 16)
-	vuSet.userHashes = make(map[string]indexTimePair)
-	vuSet.hash2Remove = make(hashEntrySet, 0, cacheDurationSec*10)
-
-	go vuSet.updateUserHash(time.Tick(updateIntervalSec * time.Second))
-	return vuSet
+	tus := &TimedUserSet{
+		validUserIds: make([]ID, 0, 16),
+		userHash:     collect.NewTimedStringMap(updateIntervalSec),
+	}
+	go tus.updateUserHash(time.Tick(updateIntervalSec * time.Second))
+	return tus
 }
 
 func (us *TimedUserSet) generateNewHashes(lastSec, nowSec int64, idx int, id ID) {
 	idHash := NewTimeHash(HMACHash{})
 	for lastSec < nowSec+cacheDurationSec {
-
 		idHash := idHash.Hash(id.Bytes, lastSec)
 		log.Debug("Valid User Hash: %v", idHash)
-		heap.Push(&us.hash2Remove, &hashEntry{string(idHash), lastSec})
-		us.userHashes[string(idHash)] = indexTimePair{idx, lastSec}
+		us.userHash.Set(string(idHash), indexTimePair{idx, lastSec}, lastSec+2*cacheDurationSec)
 		lastSec++
 	}
 }
@@ -87,24 +49,14 @@ func (us *TimedUserSet) generateNewHashes(lastSec, nowSec int64, idx int, id ID)
 func (us *TimedUserSet) updateUserHash(tick <-chan time.Time) {
 	now := time.Now().UTC()
 	lastSec := now.Unix()
-	lastSec2Remove := now.Unix()
 
 	for {
 		now := <-tick
 		nowSec := now.UTC().Unix()
-
-		remove2Sec := nowSec - cacheDurationSec
-		if remove2Sec > lastSec2Remove {
-			for lastSec2Remove+1 < remove2Sec {
-				front := heap.Pop(&us.hash2Remove)
-				entry := front.(*hashEntry)
-				lastSec2Remove = entry.timeSec
-				delete(us.userHashes, entry.hash)
-			}
-		}
 		for idx, id := range us.validUserIds {
 			us.generateNewHashes(lastSec, nowSec, idx, id)
 		}
+		lastSec = nowSec
 	}
 }
 
@@ -121,8 +73,9 @@ func (us *TimedUserSet) AddUser(user User) error {
 }
 
 func (us TimedUserSet) GetUser(userHash []byte) (*ID, int64, bool) {
-	pair, found := us.userHashes[string(userHash)]
+	rawPair, found := us.userHash.Get(string(userHash))
 	if found {
+		pair := rawPair.(indexTimePair)
 		return &us.validUserIds[pair.index], pair.timeSec, true
 	}
 	return nil, 0, false