Переглянути джерело

improve performance of domain matcher

Darien Raymond 7 роки тому
батько
коміт
edcf564dd7

+ 25 - 87
app/router/condition.go

@@ -3,8 +3,6 @@ package router
 import (
 	"context"
 	"strings"
-	"sync"
-	"time"
 
 	"v2ray.com/core/app/dispatcher"
 	"v2ray.com/core/common/net"
@@ -67,116 +65,56 @@ func (v *AnyCondition) Len() int {
 	return len(*v)
 }
 
-type timedResult struct {
-	timestamp time.Time
-	result    bool
-}
-
-type CachableDomainMatcher struct {
-	sync.Mutex
-	matchers *strmatcher.MatcherGroup
-	cache    map[string]timedResult
-	lastScan time.Time
-}
-
-func NewCachableDomainMatcher() *CachableDomainMatcher {
-	return &CachableDomainMatcher{
-		matchers: strmatcher.NewMatcherGroup(),
-		cache:    make(map[string]timedResult, 512),
-	}
-}
-
 var matcherTypeMap = map[Domain_Type]strmatcher.Type{
 	Domain_Plain:  strmatcher.Substr,
 	Domain_Regex:  strmatcher.Regex,
 	Domain_Domain: strmatcher.Domain,
 }
 
-func (m *CachableDomainMatcher) Add(domain *Domain) error {
+func domainToMatcher(domain *Domain) (strmatcher.Matcher, error) {
 	matcherType, f := matcherTypeMap[domain.Type]
 	if !f {
-		return newError("unsupported domain type", domain.Type)
+		return nil, newError("unsupported domain type", domain.Type)
 	}
 
 	matcher, err := matcherType.New(domain.Value)
 	if err != nil {
-		return newError("failed to create domain matcher").Base(err)
+		return nil, newError("failed to create domain matcher").Base(err)
 	}
 
-	m.matchers.Add(matcher)
-	return nil
+	return matcher, nil
 }
 
-func (m *CachableDomainMatcher) applyInternal(domain string) bool {
-	return m.matchers.Match(domain) > 0
-}
-
-type cacheResult int
-
-const (
-	cacheMiss cacheResult = iota
-	cacheHitTrue
-	cacheHitFalse
-)
-
-func (m *CachableDomainMatcher) findInCache(domain string) cacheResult {
-	m.Lock()
-	defer m.Unlock()
-
-	r, f := m.cache[domain]
-	if !f {
-		return cacheMiss
-	}
-	r.timestamp = time.Now()
-	m.cache[domain] = r
-
-	if r.result {
-		return cacheHitTrue
-	}
-	return cacheHitFalse
+type DomainMatcher struct {
+	matchers strmatcher.IndexMatcher
 }
 
-func (m *CachableDomainMatcher) ApplyDomain(domain string) bool {
-	if m.matchers.Size() < 64 {
-		return m.applyInternal(domain)
-	}
-
-	cr := m.findInCache(domain)
-
-	if cr == cacheHitTrue {
-		return true
-	}
-
-	if cr == cacheHitFalse {
-		return false
+func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) {
+	g := strmatcher.NewMatcherGroup()
+	for _, d := range domains {
+		m, err := domainToMatcher(d)
+		if err != nil {
+			return nil, err
+		}
+		g.Add(m)
 	}
 
-	r := m.applyInternal(domain)
-	m.Lock()
-	defer m.Unlock()
-
-	m.cache[domain] = timedResult{
-		result:    r,
-		timestamp: time.Now(),
+	if len(domains) < 64 {
+		return &DomainMatcher{
+			matchers: g,
+		}, nil
 	}
 
-	now := time.Now()
-	if len(m.cache) > 256 && now.Sub(m.lastScan)/time.Second > 5 {
-		now := time.Now()
-
-		for k, v := range m.cache {
-			if now.Sub(v.timestamp)/time.Second > 60 {
-				delete(m.cache, k)
-			}
-		}
-
-		m.lastScan = now
-	}
+	return &DomainMatcher{
+		matchers: strmatcher.NewCachedMatcherGroup(g),
+	}, nil
+}
 
-	return r
+func (m *DomainMatcher) ApplyDomain(domain string) bool {
+	return m.matchers.Match(domain) > 0
 }
 
-func (m *CachableDomainMatcher) Apply(ctx context.Context) bool {
+func (m *DomainMatcher) Apply(ctx context.Context) bool {
 	dest, ok := proxy.TargetFromContext(ctx)
 	if !ok {
 		return false

+ 2 - 4
app/router/condition_test.go

@@ -189,10 +189,8 @@ func TestChinaSites(t *testing.T) {
 	domains, err := loadGeoSite("CN")
 	assert(err, IsNil)
 
-	matcher := NewCachableDomainMatcher()
-	for _, d := range domains {
-		assert(matcher.Add(d), IsNil)
-	}
+	matcher, err := NewCachableDomainMatcher(domains)
+	common.Must(err)
 
 	assert(matcher.ApplyDomain("163.com"), IsTrue)
 	assert(matcher.ApplyDomain("163.com"), IsTrue)

+ 3 - 5
app/router/config.go

@@ -52,11 +52,9 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 	conds := NewConditionChan()
 
 	if len(rr.Domain) > 0 {
-		matcher := NewCachableDomainMatcher()
-		for _, domain := range rr.Domain {
-			if err := matcher.Add(domain); err != nil {
-				return nil, newError("failed to build domain condition").Base(err)
-			}
+		matcher, err := NewCachableDomainMatcher(rr.Domain)
+		if err != nil {
+			return nil, newError("failed to build domain condition").Base(err)
 		}
 		conds.Add(matcher)
 	}

+ 36 - 0
common/strmatcher/benchmark_test.go

@@ -0,0 +1,36 @@
+package strmatcher_test
+
+import (
+	"strconv"
+	"testing"
+
+	"v2ray.com/core/common"
+	. "v2ray.com/core/common/strmatcher"
+)
+
+func BenchmarkDomainMatcherGroup(b *testing.B) {
+	g := new(DomainMatcherGroup)
+
+	for i := 1; i <= 1024; i++ {
+		g.Add(strconv.Itoa(i)+".v2ray.com", uint32(i))
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_ = g.Match("0.v2ray.com")
+	}
+}
+
+func BenchmarkMarchGroup(b *testing.B) {
+	g := NewMatcherGroup()
+	for i := 1; i <= 1024; i++ {
+		m, err := Domain.New(strconv.Itoa(i) + ".v2ray.com")
+		common.Must(err)
+		g.Add(m)
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_ = g.Match("0.v2ray.com")
+	}
+}

+ 52 - 0
common/strmatcher/domain_matcher.go

@@ -0,0 +1,52 @@
+package strmatcher
+
+import "strings"
+
+func breakDomain(domain string) []string {
+	return strings.Split(domain, ".")
+}
+
+type node struct {
+	value uint32
+	sub   map[string]*node
+}
+
+type DomainMatcherGroup struct {
+	root *node
+}
+
+func (g *DomainMatcherGroup) Add(domain string, value uint32) {
+	if g.root == nil {
+		g.root = &node{
+			sub: make(map[string]*node),
+		}
+	}
+
+	current := g.root
+	parts := breakDomain(domain)
+	for i := len(parts) - 1; i >= 0; i-- {
+		part := parts[i]
+		next := current.sub[part]
+		if next == nil {
+			next = &node{sub: make(map[string]*node)}
+			current.sub[part] = next
+		}
+		current = next
+	}
+
+	current.value = value
+}
+
+func (g *DomainMatcherGroup) Match(domain string) uint32 {
+	current := g.root
+	parts := breakDomain(domain)
+	for i := len(parts) - 1; i >= 0; i-- {
+		part := parts[i]
+		next := current.sub[part]
+		if next == nil {
+			break
+		}
+		current = next
+	}
+	return current.value
+}

+ 35 - 0
common/strmatcher/domain_matcher_test.go

@@ -0,0 +1,35 @@
+package strmatcher_test
+
+import (
+	"testing"
+
+	. "v2ray.com/core/common/strmatcher"
+)
+
+func TestDomainMatcherGroup(t *testing.T) {
+	g := new(DomainMatcherGroup)
+	g.Add("v2ray.com", 1)
+	g.Add("google.com", 2)
+	g.Add("x.a.com", 3)
+
+	testCases := []struct {
+		Domain string
+		Result uint32
+	}{
+		{
+			Domain: "x.v2ray.com",
+			Result: 1,
+		},
+		{
+			Domain: "y.com",
+			Result: 0,
+		},
+	}
+
+	for _, testCase := range testCases {
+		r := g.Match(testCase.Domain)
+		if r != testCase.Result {
+			t.Error("Failed to match domain: ", testCase.Domain, ", expect ", testCase.Result, ", but got ", r)
+		}
+	}
+}

+ 75 - 4
common/strmatcher/strmatcher.go

@@ -1,6 +1,12 @@
 package strmatcher
 
-import "regexp"
+import (
+	"regexp"
+	"sync"
+	"time"
+
+	"v2ray.com/core/common/task"
+)
 
 type Matcher interface {
 	Match(string) bool
@@ -36,6 +42,10 @@ func (t Type) New(pattern string) (Matcher, error) {
 	}
 }
 
+type IndexMatcher interface {
+	Match(pattern string) uint32
+}
+
 type matcherEntry struct {
 	m  Matcher
 	id uint32
@@ -44,6 +54,7 @@ type matcherEntry struct {
 type MatcherGroup struct {
 	count         uint32
 	fullMatchers  map[string]uint32
+	domainMatcher DomainMatcherGroup
 	otherMatchers []matcherEntry
 }
 
@@ -58,9 +69,12 @@ func (g *MatcherGroup) Add(m Matcher) uint32 {
 	c := g.count
 	g.count++
 
-	if fm, ok := m.(fullMatcher); ok {
-		g.fullMatchers[string(fm)] = c
-	} else {
+	switch tm := m.(type) {
+	case fullMatcher:
+		g.fullMatchers[string(tm)] = c
+	case domainMatcher:
+		g.domainMatcher.Add(string(tm), c)
+	default:
 		g.otherMatchers = append(g.otherMatchers, matcherEntry{
 			m:  m,
 			id: c,
@@ -87,3 +101,60 @@ func (g *MatcherGroup) Match(pattern string) uint32 {
 func (g *MatcherGroup) Size() uint32 {
 	return g.count
 }
+
+type cacheEntry struct {
+	timestamp time.Time
+	result    uint32
+}
+
+type CachedMatcherGroup struct {
+	sync.Mutex
+	group   *MatcherGroup
+	cache   map[string]cacheEntry
+	cleanup *task.Periodic
+}
+
+func NewCachedMatcherGroup(g *MatcherGroup) *CachedMatcherGroup {
+	r := &CachedMatcherGroup{
+		group: g,
+		cache: make(map[string]cacheEntry),
+	}
+	r.cleanup = &task.Periodic{
+		Interval: time.Second * 30,
+		Execute: func() error {
+			r.Lock()
+			defer r.Unlock()
+
+			expire := time.Now().Add(-1 * time.Second * 60)
+			for p, e := range r.cache {
+				if e.timestamp.Before(expire) {
+					delete(r.cache, p)
+				}
+			}
+
+			return nil
+		},
+	}
+	return r
+}
+
+func (g *CachedMatcherGroup) Match(pattern string) uint32 {
+	g.Lock()
+	defer g.Unlock()
+
+	r, f := g.cache[pattern]
+	if f {
+		r.timestamp = time.Now()
+		g.cache[pattern] = r
+		return r.result
+	}
+
+	mr := g.group.Match(pattern)
+
+	g.cache[pattern] = cacheEntry{
+		result:    mr,
+		timestamp: time.Now(),
+	}
+
+	return mr
+}