|  | @@ -1,6 +1,7 @@
 | 
	
		
			
				|  |  |  package encoding
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import (
 | 
	
		
			
				|  |  | +	"context"
 | 
	
		
			
				|  |  |  	"crypto/aes"
 | 
	
		
			
				|  |  |  	"crypto/cipher"
 | 
	
		
			
				|  |  |  	"crypto/md5"
 | 
	
	
		
			
				|  | @@ -25,26 +26,26 @@ type sessionId struct {
 | 
	
		
			
				|  |  |  	nonce [16]byte
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -type sessionHistory struct {
 | 
	
		
			
				|  |  | +type SessionHistory struct {
 | 
	
		
			
				|  |  |  	sync.RWMutex
 | 
	
		
			
				|  |  |  	cache map[sessionId]time.Time
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -func newSessionHistory() *sessionHistory {
 | 
	
		
			
				|  |  | -	h := &sessionHistory{
 | 
	
		
			
				|  |  | +func NewSessionHistory(ctx context.Context) *SessionHistory {
 | 
	
		
			
				|  |  | +	h := &SessionHistory{
 | 
	
		
			
				|  |  |  		cache: make(map[sessionId]time.Time, 128),
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	go h.run()
 | 
	
		
			
				|  |  | +	go h.run(ctx)
 | 
	
		
			
				|  |  |  	return h
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -func (h *sessionHistory) Add(session sessionId) {
 | 
	
		
			
				|  |  | +func (h *SessionHistory) add(session sessionId) {
 | 
	
		
			
				|  |  |  	h.Lock()
 | 
	
		
			
				|  |  |  	h.cache[session] = time.Now().Add(time.Minute * 3)
 | 
	
		
			
				|  |  |  	h.Unlock()
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -func (h *sessionHistory) Has(session sessionId) bool {
 | 
	
		
			
				|  |  | +func (h *SessionHistory) has(session sessionId) bool {
 | 
	
		
			
				|  |  |  	h.RLock()
 | 
	
		
			
				|  |  |  	defer h.RUnlock()
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -54,9 +55,13 @@ func (h *sessionHistory) Has(session sessionId) bool {
 | 
	
		
			
				|  |  |  	return false
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -func (h *sessionHistory) run() {
 | 
	
		
			
				|  |  | +func (h *SessionHistory) run(ctx context.Context) {
 | 
	
		
			
				|  |  |  	for {
 | 
	
		
			
				|  |  | -		time.Sleep(time.Second * 30)
 | 
	
		
			
				|  |  | +		select {
 | 
	
		
			
				|  |  | +		case <-ctx.Done():
 | 
	
		
			
				|  |  | +			return
 | 
	
		
			
				|  |  | +		case <-time.After(time.Second * 30):
 | 
	
		
			
				|  |  | +		}
 | 
	
		
			
				|  |  |  		session2Remove := make([]sessionId, 0, 16)
 | 
	
		
			
				|  |  |  		now := time.Now()
 | 
	
		
			
				|  |  |  		h.Lock()
 | 
	
	
		
			
				|  | @@ -72,12 +77,9 @@ func (h *sessionHistory) run() {
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -var (
 | 
	
		
			
				|  |  | -	globalSessionHistory = newSessionHistory()
 | 
	
		
			
				|  |  | -)
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  type ServerSession struct {
 | 
	
		
			
				|  |  |  	userValidator   protocol.UserValidator
 | 
	
		
			
				|  |  | +	sessionHistory  *SessionHistory
 | 
	
		
			
				|  |  |  	requestBodyKey  []byte
 | 
	
		
			
				|  |  |  	requestBodyIV   []byte
 | 
	
		
			
				|  |  |  	responseBodyKey []byte
 | 
	
	
		
			
				|  | @@ -88,9 +90,10 @@ type ServerSession struct {
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  // NewServerSession creates a new ServerSession, using the given UserValidator.
 | 
	
		
			
				|  |  |  // The ServerSession instance doesn't take ownership of the validator.
 | 
	
		
			
				|  |  | -func NewServerSession(validator protocol.UserValidator) *ServerSession {
 | 
	
		
			
				|  |  | +func NewServerSession(validator protocol.UserValidator, sessionHistory *SessionHistory) *ServerSession {
 | 
	
		
			
				|  |  |  	return &ServerSession{
 | 
	
		
			
				|  |  | -		userValidator: validator,
 | 
	
		
			
				|  |  | +		userValidator:  validator,
 | 
	
		
			
				|  |  | +		sessionHistory: sessionHistory,
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
	
		
			
				|  | @@ -140,10 +143,10 @@ func (v *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 | 
	
		
			
				|  |  |  	copy(sid.user[:], vmessAccount.ID.Bytes())
 | 
	
		
			
				|  |  |  	copy(sid.key[:], v.requestBodyKey)
 | 
	
		
			
				|  |  |  	copy(sid.nonce[:], v.requestBodyIV)
 | 
	
		
			
				|  |  | -	if globalSessionHistory.Has(sid) {
 | 
	
		
			
				|  |  | +	if v.sessionHistory.has(sid) {
 | 
	
		
			
				|  |  |  		return nil, errors.New("VMess|Server: Duplicated session id. Possibly under reply attack.")
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	globalSessionHistory.Add(sid)
 | 
	
		
			
				|  |  | +	v.sessionHistory.add(sid)
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	v.responseHeader = buffer[33]                       // 1 byte
 | 
	
		
			
				|  |  |  	request.Option = protocol.RequestOption(buffer[34]) // 1 byte
 |