Quellcode durchsuchen

fix data race when caching session id

Darien Raymond vor 7 Jahren
Ursprung
Commit
6a3abf3147
1 geänderte Dateien mit 7 neuen und 13 gelöschten Zeilen
  1. 7 13
      proxy/vmess/encoding/server.go

+ 7 - 13
proxy/vmess/encoding/server.go

@@ -55,21 +55,16 @@ func (h *SessionHistory) Close() error {
 	return h.task.Close()
 }
 
-func (h *SessionHistory) add(session sessionId) {
+func (h *SessionHistory) addIfNotExits(session sessionId) bool {
 	h.Lock()
 	defer h.Unlock()
 
-	h.cache[session] = time.Now().Add(time.Minute * 3)
-}
-
-func (h *SessionHistory) has(session sessionId) bool {
-	h.RLock()
-	defer h.RUnlock()
-
-	if expire, found := h.cache[session]; found {
-		return expire.After(time.Now())
+	if expire, found := h.cache[session]; found && expire.After(time.Now()) {
+		return false
 	}
-	return false
+
+	h.cache[session] = time.Now().Add(time.Minute * 3)
+	return true
 }
 
 func (h *SessionHistory) removeExpiredEntries() {
@@ -152,10 +147,9 @@ func (s *ServerSession) DecodeRequestHeader(reader io.Reader) (*protocol.Request
 	copy(sid.user[:], vmessAccount.ID.Bytes())
 	copy(sid.key[:], s.requestBodyKey)
 	copy(sid.nonce[:], s.requestBodyIV)
-	if s.sessionHistory.has(sid) {
+	if !s.sessionHistory.addIfNotExits(sid) {
 		return nil, newError("duplicated session id, possibly under replay attack")
 	}
-	s.sessionHistory.add(sid)
 
 	s.responseHeader = buffer.Byte(33)             // 1 byte
 	request.Option = bitmask.Byte(buffer.Byte(34)) // 1 byte