Browse Source

Refactor common/antireplay, unexport unnecessary public fields. (#422)

* rename AuthIDDecoderHolder private fields

* ignore unused return value more clear

* change PoolSwap to private fields

* refactor Unlock to defer

* use const capacity, reorder code layout

* replace mismatch field name poolX with type Filter

* change AntiReplayTime to private fileds, protect to modify

* rename lastSwapTime to lastSwap

* merge duplicate time.Now.
Since the current unit is in seconds, there is no need to repeat the call

* refine negate expression

* rename antiReplayTime to interval

* add docs

* fix lint stutter issue, rename antireplay.AntiReplayWindow to antireplay.ReplayFilter

* rename fileds m,n to poolA,poolB

* rename antireplay.go to replayfilter.go

* fix build issue

Co-authored-by: Chinsyo <chinsyo@sina.cn>
Chinsyo 5 years ago
parent
commit
dc78733196
3 changed files with 67 additions and 62 deletions
  1. 0 51
      common/antireplay/antireplay.go
  2. 58 0
      common/antireplay/replayfilter.go
  3. 9 11
      proxy/vmess/aead/authid.go

+ 0 - 51
common/antireplay/antireplay.go

@@ -1,51 +0,0 @@
-package antireplay
-
-import (
-	"sync"
-	"time"
-
-	cuckoo "github.com/seiflotfy/cuckoofilter"
-)
-
-func NewAntiReplayWindow(antiReplayTime int64) *AntiReplayWindow {
-	arw := &AntiReplayWindow{}
-	arw.AntiReplayTime = antiReplayTime
-	return arw
-}
-
-type AntiReplayWindow struct {
-	lock           sync.Mutex
-	poolA          *cuckoo.Filter
-	poolB          *cuckoo.Filter
-	lastSwapTime   int64
-	PoolSwap       bool
-	AntiReplayTime int64
-}
-
-func (aw *AntiReplayWindow) Check(sum []byte) bool {
-	aw.lock.Lock()
-
-	if aw.lastSwapTime == 0 {
-		aw.lastSwapTime = time.Now().Unix()
-		aw.poolA = cuckoo.NewFilter(100000)
-		aw.poolB = cuckoo.NewFilter(100000)
-	}
-
-	tnow := time.Now().Unix()
-	timediff := tnow - aw.lastSwapTime
-
-	if timediff >= aw.AntiReplayTime {
-		if aw.PoolSwap {
-			aw.PoolSwap = false
-			aw.poolA.Reset()
-		} else {
-			aw.PoolSwap = true
-			aw.poolB.Reset()
-		}
-		aw.lastSwapTime = tnow
-	}
-
-	ret := aw.poolA.InsertUnique(sum) && aw.poolB.InsertUnique(sum)
-	aw.lock.Unlock()
-	return ret
-}

+ 58 - 0
common/antireplay/replayfilter.go

@@ -0,0 +1,58 @@
+package antireplay
+
+import (
+	"sync"
+	"time"
+
+	cuckoo "github.com/seiflotfy/cuckoofilter"
+)
+
+const replayFilterCapacity = 100000
+
+// ReplayFilter check for replay attacks.
+type ReplayFilter struct {
+	lock     sync.Mutex
+	poolA    *cuckoo.Filter
+	poolB    *cuckoo.Filter
+	poolSwap bool
+	lastSwap int64
+	interval int64
+}
+
+// NewReplayFilter create a new filter with specifying the expiration time interval in seconds.
+func NewReplayFilter(interval int64) *ReplayFilter {
+	filter := &ReplayFilter{}
+	filter.interval = interval
+	return filter
+}
+
+// Interval in second for expiration time for duplicate records.
+func (filter *ReplayFilter) Interval() int64 {
+	return filter.interval
+}
+
+// Check determine if there are duplicate records.
+func (filter *ReplayFilter) Check(sum []byte) bool {
+	filter.lock.Lock()
+	defer filter.lock.Unlock()
+
+	now := time.Now().Unix()
+	if filter.lastSwap == 0 {
+		filter.lastSwap = now
+		filter.poolA = cuckoo.NewFilter(replayFilterCapacity)
+		filter.poolB = cuckoo.NewFilter(replayFilterCapacity)
+	}
+
+	elapsed := now - filter.lastSwap
+	if elapsed >= filter.Interval() {
+		if filter.poolSwap {
+			filter.poolA.Reset()
+		} else {
+			filter.poolB.Reset()
+		}
+		filter.poolSwap = !filter.poolSwap
+		filter.lastSwap = now
+	}
+
+	return filter.poolA.InsertUnique(sum) && filter.poolB.InsertUnique(sum)
+}

+ 9 - 11
proxy/vmess/aead/authid.go

@@ -13,7 +13,7 @@ import (
 	"time"
 	"time"
 
 
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
-	antiReplayWindow "v2ray.com/core/common/antireplay"
+	"v2ray.com/core/common/antireplay"
 )
 )
 
 
 var (
 var (
@@ -66,12 +66,12 @@ func (aidd *AuthIDDecoder) Decode(data [16]byte) (int64, uint32, int32, []byte)
 }
 }
 
 
 func NewAuthIDDecoderHolder() *AuthIDDecoderHolder {
 func NewAuthIDDecoderHolder() *AuthIDDecoderHolder {
-	return &AuthIDDecoderHolder{make(map[string]*AuthIDDecoderItem), antiReplayWindow.NewAntiReplayWindow(120)}
+	return &AuthIDDecoderHolder{make(map[string]*AuthIDDecoderItem), antireplay.NewReplayFilter(120)}
 }
 }
 
 
 type AuthIDDecoderHolder struct {
 type AuthIDDecoderHolder struct {
-	aidhi map[string]*AuthIDDecoderItem
-	apw   *antiReplayWindow.AntiReplayWindow
+	decoders map[string]*AuthIDDecoderItem
+	filter   *antireplay.ReplayFilter
 }
 }
 
 
 type AuthIDDecoderItem struct {
 type AuthIDDecoderItem struct {
@@ -87,16 +87,16 @@ func NewAuthIDDecoderItem(key [16]byte, ticket interface{}) *AuthIDDecoderItem {
 }
 }
 
 
 func (a *AuthIDDecoderHolder) AddUser(key [16]byte, ticket interface{}) {
 func (a *AuthIDDecoderHolder) AddUser(key [16]byte, ticket interface{}) {
-	a.aidhi[string(key[:])] = NewAuthIDDecoderItem(key, ticket)
+	a.decoders[string(key[:])] = NewAuthIDDecoderItem(key, ticket)
 }
 }
 
 
 func (a *AuthIDDecoderHolder) RemoveUser(key [16]byte) {
 func (a *AuthIDDecoderHolder) RemoveUser(key [16]byte) {
-	delete(a.aidhi, string(key[:]))
+	delete(a.decoders, string(key[:]))
 }
 }
 
 
 func (a *AuthIDDecoderHolder) Match(authID [16]byte) (interface{}, error) {
 func (a *AuthIDDecoderHolder) Match(authID [16]byte) (interface{}, error) {
-	for _, v := range a.aidhi {
-		t, z, r, d := v.dec.Decode(authID)
+	for _, v := range a.decoders {
+		t, z, _, d := v.dec.Decode(authID)
 		if z != crc32.ChecksumIEEE(d[:12]) {
 		if z != crc32.ChecksumIEEE(d[:12]) {
 			continue
 			continue
 		}
 		}
@@ -109,12 +109,10 @@ func (a *AuthIDDecoderHolder) Match(authID [16]byte) (interface{}, error) {
 			continue
 			continue
 		}
 		}
 
 
-		if !a.apw.Check(authID[:]) {
+		if !a.filter.Check(authID[:]) {
 			return nil, ErrReplay
 			return nil, ErrReplay
 		}
 		}
 
 
-		_ = r
-
 		return v.ticket, nil
 		return v.ticket, nil
 	}
 	}
 	return nil, ErrNotFound
 	return nil, ErrNotFound