소스 검색

per inbound session history

Darien Raymond 8 년 전
부모
커밋
ec95caa946
3개의 변경된 파일28개의 추가작업 그리고 21개의 파일을 삭제
  1. 3 1
      proxy/vmess/encoding/encoding_test.go
  2. 19 16
      proxy/vmess/encoding/server.go
  3. 6 4
      proxy/vmess/inbound/inbound.go

+ 3 - 1
proxy/vmess/encoding/encoding_test.go

@@ -45,10 +45,12 @@ func TestRequestSerialization(t *testing.T) {
 	buffer2.Append(buffer.Bytes())
 
 	ctx, cancel := context.WithCancel(context.Background())
+	sessionHistory := NewSessionHistory(ctx)
+
 	userValidator := vmess.NewTimedUserValidator(ctx, protocol.DefaultIDHash)
 	userValidator.Add(user)
 
-	server := NewServerSession(userValidator)
+	server := NewServerSession(userValidator, sessionHistory)
 	actualRequest, err := server.DecodeRequestHeader(buffer)
 	assert.Error(err).IsNil()
 

+ 19 - 16
proxy/vmess/encoding/server.go

@@ -1,6 +1,7 @@
 package encoding
 
 import (
+	"context"
 	"crypto/aes"
 	"crypto/cipher"
 	"crypto/md5"
@@ -25,26 +26,26 @@ type sessionId struct {
 	nonce [16]byte
 }
 
-type sessionHistory struct {
+type SessionHistory struct {
 	sync.RWMutex
 	cache map[sessionId]time.Time
 }
 
-func newSessionHistory() *sessionHistory {
-	h := &sessionHistory{
+func NewSessionHistory(ctx context.Context) *SessionHistory {
+	h := &SessionHistory{
 		cache: make(map[sessionId]time.Time, 128),
 	}
-	go h.run()
+	go h.run(ctx)
 	return h
 }
 
-func (h *sessionHistory) Add(session sessionId) {
+func (h *SessionHistory) add(session sessionId) {
 	h.Lock()
 	h.cache[session] = time.Now().Add(time.Minute * 3)
 	h.Unlock()
 }
 
-func (h *sessionHistory) Has(session sessionId) bool {
+func (h *SessionHistory) has(session sessionId) bool {
 	h.RLock()
 	defer h.RUnlock()
 
@@ -54,9 +55,13 @@ func (h *sessionHistory) Has(session sessionId) bool {
 	return false
 }
 
-func (h *sessionHistory) run() {
+func (h *SessionHistory) run(ctx context.Context) {
 	for {
-		time.Sleep(time.Second * 30)
+		select {
+		case <-ctx.Done():
+			return
+		case <-time.After(time.Second * 30):
+		}
 		session2Remove := make([]sessionId, 0, 16)
 		now := time.Now()
 		h.Lock()
@@ -72,12 +77,9 @@ func (h *sessionHistory) run() {
 	}
 }
 
-var (
-	globalSessionHistory = newSessionHistory()
-)
-
 type ServerSession struct {
 	userValidator   protocol.UserValidator
+	sessionHistory  *SessionHistory
 	requestBodyKey  []byte
 	requestBodyIV   []byte
 	responseBodyKey []byte
@@ -88,9 +90,10 @@ type ServerSession struct {
 
 // NewServerSession creates a new ServerSession, using the given UserValidator.
 // The ServerSession instance doesn't take ownership of the validator.
-func NewServerSession(validator protocol.UserValidator) *ServerSession {
+func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionHistory) *ServerSession {
 	return &ServerSession{
-		userValidator: validator,
+		userValidator:  validator,
+		sessionHistory: sessionHistory,
 	}
 }
 
@@ -140,10 +143,10 @@ func (v *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 	copy(sid.user[:], vmessAccount.ID.Bytes())
 	copy(sid.key[:], v.requestBodyKey)
 	copy(sid.nonce[:], v.requestBodyIV)
-	if globalSessionHistory.Has(sid) {
+	if v.sessionHistory.has(sid) {
 		return nil, errors.New("VMess|Server: Duplicated session id. Possibly under reply attack.")
 	}
-	globalSessionHistory.Add(sid)
+	v.sessionHistory.add(sid)
 
 	v.responseHeader = buffer[33]                       // 1 byte
 	request.Option = protocol.RequestOption(buffer[34]) // 1 byte

+ 6 - 4
proxy/vmess/inbound/inbound.go

@@ -78,6 +78,7 @@ type VMessInboundHandler struct {
 	clients               protocol.UserValidator
 	usersByEmail          *userByEmail
 	detours               *DetourConfig
+	sessionHistory        *encoding.SessionHistory
 }
 
 func New(ctx context.Context, config *Config) (*VMessInboundHandler, error) {
@@ -92,9 +93,10 @@ func New(ctx context.Context, config *Config) (*VMessInboundHandler, error) {
 	}
 
 	handler := &VMessInboundHandler{
-		clients:      allowedClients,
-		detours:      config.Detour,
-		usersByEmail: NewUserByEmail(config.User, config.GetDefaultValue()),
+		clients:        allowedClients,
+		detours:        config.Detour,
+		usersByEmail:   NewUserByEmail(config.User, config.GetDefaultValue()),
+		sessionHistory: encoding.NewSessionHistory(ctx),
 	}
 
 	space.OnInitialize(func() error {
@@ -171,7 +173,7 @@ func (v *VMessInboundHandler) Process(ctx context.Context, network net.Network,
 	connection.SetReadDeadline(time.Now().Add(time.Second * 8))
 	reader := bufio.NewReader(connection)
 
-	session := encoding.NewServerSession(v.clients)
+	session := encoding.NewServerSession(v.clients, v.sessionHistory)
 	request, err := session.DecodeRequestHeader(reader)
 
 	if err != nil {