Sfoglia il codice sorgente

cleanup session history

Darien Raymond 7 anni fa
parent
commit
a1ae4aa515

+ 1 - 1
proxy/vmess/encoding/encoding_test.go

@@ -46,7 +46,7 @@ func TestRequestSerialization(t *testing.T) {
 	buffer2.Append(buffer.Bytes())
 
 	ctx, cancel := context.WithCancel(context.Background())
-	sessionHistory := NewSessionHistory(ctx)
+	sessionHistory := NewSessionHistory()
 
 	userValidator := vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash)
 	userValidator.Add(user)

+ 17 - 29
proxy/vmess/encoding/server.go

@@ -1,7 +1,6 @@
 package encoding
 
 import (
-	"context"
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/md5"
@@ -32,26 +31,25 @@ type SessionHistory struct {
 	sync.RWMutex
 	cache map[sessionId]time.Time
 	token *signal.Semaphore
-	ctx   context.Context
+	timer *time.Timer
 }
 
-func NewSessionHistory(ctx context.Context) *SessionHistory {
+func NewSessionHistory() *SessionHistory {
 	h := &SessionHistory{
 		cache: make(map[sessionId]time.Time, 128),
 		token: signal.NewSemaphore(1),
-		ctx:   ctx,
 	}
 	return h
 }
 
 func (h *SessionHistory) add(session sessionId) {
 	h.Lock()
-	h.cache[session] = time.Now().Add(time.Minute * 3)
-	h.Unlock()
+	defer h.Unlock()
 
+	h.cache[session] = time.Now().Add(time.Minute * 3)
 	select {
 	case <-h.token.Wait():
-		go h.run()
+		h.timer = time.AfterFunc(time.Minute*3, h.removeExpiredEntries)
 	default:
 	}
 }
@@ -66,31 +64,21 @@ func (h *SessionHistory) has(session sessionId) bool {
 	return false
 }
 
-func (h *SessionHistory) run() {
-	defer h.token.Signal()
+func (h *SessionHistory) removeExpiredEntries() {
+	now := time.Now()
 
-	for {
-		select {
-		case <-h.ctx.Done():
-			return
-		case <-time.After(time.Second * 30):
-		}
-		session2Remove := make([]sessionId, 0, 16)
-		now := time.Now()
-		h.Lock()
-		if len(h.cache) == 0 {
-			h.Unlock()
-			return
-		}
-		for session, expire := range h.cache {
-			if expire.Before(now) {
-				session2Remove = append(session2Remove, session)
-			}
-		}
-		for _, session := range session2Remove {
+	h.Lock()
+	defer h.Unlock()
+
+	for session, expire := range h.cache {
+		if expire.Before(now) {
 			delete(h.cache, session)
 		}
-		h.Unlock()
+	}
+
+	if h.timer != nil {
+		h.timer.Stop()
+		h.timer = nil
 	}
 }
 

+ 1 - 1
proxy/vmess/inbound/inbound.go

@@ -93,7 +93,7 @@ func New(ctx context.Context, config *Config) (*Handler, error) {
 		clients:               vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash),
 		detours:               config.Detour,
 		usersByEmail:          newUserByEmail(config.User, config.GetDefaultValue()),
-		sessionHistory:        encoding.NewSessionHistory(ctx),
+		sessionHistory:        encoding.NewSessionHistory(),
 	}
 
 	for _, user := range config.User {