Browse Source

Refactor: A faster DomainMatcher implementation (#587)

* a faster DomainMatcher implementation

* rename benchmark name

* fix linting errors
DarthVader 4 years ago
parent
commit
de618121ad

+ 18 - 0
app/router/condition.go

@@ -67,6 +67,24 @@ type DomainMatcher struct {
 	matchers strmatcher.IndexMatcher
 	matchers strmatcher.IndexMatcher
 }
 }
 
 
+// NewACAutomatonDomainMatcher builds a DomainMatcher whose matchers field is an
+// ACAutomatonMatcherGroup. Each domain rule's Type is translated through
+// matcherTypeMap and its Value is registered with the group; Build is called
+// on the group once all rules are added.
+//
+// It returns an error when a rule's Type has no entry in matcherTypeMap or when
+// the group rejects the pattern.
+func NewACAutomatonDomainMatcher(domains []*Domain) (*DomainMatcher, error) {
+	group := strmatcher.NewACAutomatonMatcherGroup()
+	for _, rule := range domains {
+		kind, known := matcherTypeMap[rule.Type]
+		if !known {
+			return nil, newError("unsupported domain type", rule.Type)
+		}
+		if _, err := group.AddPattern(rule.Value, kind); err != nil {
+			return nil, err
+		}
+	}
+	group.Build()
+
+	matcher := &DomainMatcher{matchers: group}
+	return matcher, nil
+}
+
 func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) {
 func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) {
 	g := new(strmatcher.MatcherGroup)
 	g := new(strmatcher.MatcherGroup)
 	for _, d := range domains {
 	for _, d := range domains {

+ 92 - 3
app/router/condition_test.go

@@ -358,6 +358,8 @@ func TestChinaSites(t *testing.T) {
 
 
 	matcher, err := NewDomainMatcher(domains)
 	matcher, err := NewDomainMatcher(domains)
 	common.Must(err)
 	common.Must(err)
+	acMatcher, err := NewACAutomatonDomainMatcher(domains)
+	common.Must(err)
 
 
 	type TestCase struct {
 	type TestCase struct {
 		Domain string
 		Domain string
@@ -387,9 +389,96 @@ func TestChinaSites(t *testing.T) {
 	}
 	}
 
 
 	for _, testCase := range testCases {
 	for _, testCase := range testCases {
-		r := matcher.ApplyDomain(testCase.Domain)
-		if r != testCase.Output {
-			t.Error("expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r)
+		r1 := matcher.ApplyDomain(testCase.Domain)
+		r2 := acMatcher.ApplyDomain(testCase.Domain)
+		if r1 != testCase.Output {
+			t.Error("DomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r1)
+		} else if r2 != testCase.Output {
+			t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r2)
+		}
+	}
+}
+
+// BenchmarkACDomainMatcher times ApplyDomain on a matcher built by
+// NewACAutomatonDomainMatcher from the "CN" geosite list. Every iteration runs
+// four fixed domains plus 1024 generated "<n>.not-exists.com" names.
+func BenchmarkACDomainMatcher(b *testing.B) {
+	domains, err := loadGeoSite("CN")
+	common.Must(err)
+
+	matcher, err := NewACAutomatonDomainMatcher(domains)
+	common.Must(err)
+
+	type TestCase struct {
+		Domain string
+		Output bool
+	}
+	testCases := []TestCase{
+		{Domain: "163.com", Output: true},
+		{Domain: "163.com", Output: true},
+		{Domain: "164.com", Output: false},
+		{Domain: "164.com", Output: false},
+	}
+	for n := 0; n < 1024; n++ {
+		testCases = append(testCases, TestCase{
+			Domain: strconv.Itoa(n) + ".not-exists.com",
+			Output: false,
+		})
+	}
+
+	// Setup above (geosite load, automaton build) is excluded from the timing.
+	b.ResetTimer()
+	for iter := 0; iter < b.N; iter++ {
+		for idx := range testCases {
+			_ = matcher.ApplyDomain(testCases[idx].Domain)
+		}
+	}
+}
+
+// BenchmarkDomainMatcher times ApplyDomain on a matcher built by
+// NewDomainMatcher from the "CN" geosite list, using the same inputs as
+// BenchmarkACDomainMatcher: four fixed domains plus 1024 generated
+// "<n>.not-exists.com" names per iteration.
+func BenchmarkDomainMatcher(b *testing.B) {
+	domains, err := loadGeoSite("CN")
+	common.Must(err)
+
+	matcher, err := NewDomainMatcher(domains)
+	common.Must(err)
+
+	type TestCase struct {
+		Domain string
+		Output bool
+	}
+	testCases := []TestCase{
+		{Domain: "163.com", Output: true},
+		{Domain: "163.com", Output: true},
+		{Domain: "164.com", Output: false},
+		{Domain: "164.com", Output: false},
+	}
+	for n := 0; n < 1024; n++ {
+		testCases = append(testCases, TestCase{
+			Domain: strconv.Itoa(n) + ".not-exists.com",
+			Output: false,
+		})
+	}
+
+	// Setup above (geosite load, matcher build) is excluded from the timing.
+	b.ResetTimer()
+	for iter := 0; iter < b.N; iter++ {
+		for idx := range testCases {
+			_ = matcher.ApplyDomain(testCases[idx].Domain)
 		}
 		}
 	}
 	}
 }
 }

+ 1 - 1
app/router/config.go

@@ -69,7 +69,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 	conds := NewConditionChan()
 	conds := NewConditionChan()
 
 
 	if len(rr.Domain) > 0 {
 	if len(rr.Domain) > 0 {
-		matcher, err := NewDomainMatcher(rr.Domain)
+		matcher, err := NewACAutomatonDomainMatcher(rr.Domain)
 		if err != nil {
 		if err != nil {
 			return nil, newError("failed to build domain condition").Base(err)
 			return nil, newError("failed to build domain condition").Base(err)
 		}
 		}

+ 243 - 0
common/strmatcher/ac_automaton_matcher.go

@@ -0,0 +1,243 @@
+package strmatcher
+
+import (
+	"container/list"
+)
+
+// validCharCount is the number of edge slots per automaton node: one for every
+// index value produced by char2Index (0 through 52).
+const validCharCount = 53
+
+// Kinds of edge stored in Edge.edgeType.
+const (
+	// TrieEdge marks an edge inserted by Add while storing a pattern.
+	TrieEdge bool = true
+	// FailEdge marks an edge that is not part of a stored pattern. It is the
+	// initial state of every slot and what Build writes into empty slots.
+	FailEdge bool = false
+)
+
+type (
+	// MatchType is the per-node marker kept in ACAutomaton.exists. exist is
+	// true when Add ended a pattern on the node, and matchType is the Type
+	// recorded for that pattern.
+	MatchType struct {
+		matchType Type
+		exist     bool
+	}
+
+	// Edge is one transition out of a node: the kind of edge (TrieEdge or
+	// FailEdge) and the index of the node it leads to.
+	Edge struct {
+		edgeType bool
+		nextNode int
+	}
+
+	// ACAutomaton stores patterns as a trie over reversed strings, with
+	// fallback transitions filled in by Build.
+	//
+	// trie, fail and exists are parallel slices indexed by node number; node 0
+	// is the root. count is the number of the most recently allocated node.
+	ACAutomaton struct {
+		trie   [][validCharCount]Edge
+		fail   []int
+		exists []MatchType
+		count  int
+	}
+)
+
+// newNode returns the edge table of a fresh node: every one of the
+// validCharCount slots is a FailEdge leading to node 0.
+func newNode() [validCharCount]Edge {
+	var edges [validCharCount]Edge
+	blank := Edge{edgeType: FailEdge, nextNode: 0}
+	for slot := 0; slot < validCharCount; slot++ {
+		edges[slot] = blank
+	}
+	return edges
+}
+
+// char2Index maps a byte to its edge slot in an automaton node. Letters are
+// case-folded to 0-25, the punctuation characters "!$&'()*+,;=:%-._~" take
+// 26-42 in that order, and the digits '0'-'9' take 43-52.
+//
+// The table has 127 entries ('~' is the highest mapped byte). Bytes below 127
+// that are not listed keep the zero value and therefore share slot 0 with 'a'.
+var char2Index = func() []int {
+	table := make([]int, '~'+1)
+	for i := 0; i < 26; i++ {
+		table['A'+i] = i
+		table['a'+i] = i
+	}
+	for i, c := range "!$&'()*+,;=:%-._~" {
+		table[c] = 26 + i
+	}
+	for i := 0; i < 10; i++ {
+		table['0'+i] = 43 + i
+	}
+	return table
+}()
+
+func NewACAutomaton() *ACAutomaton {
+	var ac = new(ACAutomaton)
+	ac.trie = append(ac.trie, newNode())
+	ac.fail = append(ac.fail, 0)
+	ac.exists = append(ac.exists, MatchType{
+		matchType: Full,
+		exist:     false,
+	})
+	return ac
+}
+
+func (ac *ACAutomaton) Add(domain string, t Type) {
+	var node = 0
+	for i := len(domain) - 1; i >= 0; i-- {
+		var idx = char2Index[domain[i]]
+		if ac.trie[node][idx].nextNode == 0 {
+			ac.count++
+			if len(ac.trie) < ac.count+1 {
+				ac.trie = append(ac.trie, newNode())
+				ac.fail = append(ac.fail, 0)
+				ac.exists = append(ac.exists, MatchType{
+					matchType: Full,
+					exist:     false,
+				})
+			}
+			ac.trie[node][idx] = Edge{
+				edgeType: TrieEdge,
+				nextNode: ac.count,
+			}
+		}
+		node = ac.trie[node][idx].nextNode
+	}
+	ac.exists[node] = MatchType{
+		matchType: t,
+		exist:     true,
+	}
+	switch t {
+	case Domain:
+		ac.exists[node] = MatchType{
+			matchType: Full,
+			exist:     true,
+		}
+		var idx = char2Index['.']
+		if ac.trie[node][idx].nextNode == 0 {
+			ac.count++
+			if len(ac.trie) < ac.count+1 {
+				ac.trie = append(ac.trie, newNode())
+				ac.fail = append(ac.fail, 0)
+				ac.exists = append(ac.exists, MatchType{
+					matchType: Full,
+					exist:     false,
+				})
+			}
+			ac.trie[node][idx] = Edge{
+				edgeType: TrieEdge,
+				nextNode: ac.count,
+			}
+		}
+		node = ac.trie[node][idx].nextNode
+		ac.exists[node] = MatchType{
+			matchType: t,
+			exist:     true,
+		}
+	default:
+		break
+	}
+}
+
+func (ac *ACAutomaton) Build() {
+	var queue = list.New()
+	for i := 0; i < validCharCount; i++ {
+		if ac.trie[0][i].nextNode != 0 {
+			queue.PushBack(ac.trie[0][i])
+		}
+	}
+	for {
+		var front = queue.Front()
+		if front == nil {
+			break
+		} else {
+			var node = front.Value.(Edge).nextNode
+			queue.Remove(front)
+			for i := 0; i < validCharCount; i++ {
+				if ac.trie[node][i].nextNode != 0 {
+					ac.fail[ac.trie[node][i].nextNode] = ac.trie[ac.fail[node]][i].nextNode
+					queue.PushBack(ac.trie[node][i])
+				} else {
+					ac.trie[node][i] = Edge{
+						edgeType: FailEdge,
+						nextNode: ac.trie[ac.fail[node]][i].nextNode,
+					}
+				}
+			}
+		}
+	}
+}
+
+func (ac *ACAutomaton) Match(s string) bool {
+	var node = 0
+	var fullMatch = true
+	// 1. the match string is all through trie edge. FULL MATCH or DOMAIN
+	// 2. the match string is through a fail edge. NOT FULL MATCH
+	// 2.1 Through a fail edge, but there exists a valid node. SUBSTR
+	for i := len(s) - 1; i >= 0; i-- {
+		var idx = char2Index[s[i]]
+		fullMatch = fullMatch && ac.trie[node][idx].edgeType
+		node = ac.trie[node][idx].nextNode
+		switch ac.exists[node].matchType {
+		case Substr:
+			return true
+		case Domain:
+			if fullMatch {
+				return true
+			}
+		default:
+			break
+		}
+	}
+	return fullMatch && ac.exists[node].exist
+}

+ 13 - 0
common/strmatcher/benchmark_test.go

@@ -8,6 +8,19 @@ import (
 	. "v2ray.com/core/common/strmatcher"
 	. "v2ray.com/core/common/strmatcher"
 )
 )
 
 
+// BenchmarkACAutomaton times a single Match("0.v2ray.com") against an
+// automaton holding the Domain patterns "1.v2ray.com" through "1024.v2ray.com".
+func BenchmarkACAutomaton(b *testing.B) {
+	automaton := NewACAutomaton()
+	for n := 1; n <= 1024; n++ {
+		pattern := strconv.Itoa(n) + ".v2ray.com"
+		automaton.Add(pattern, Domain)
+	}
+	automaton.Build()
+
+	b.ResetTimer()
+	for iter := 0; iter < b.N; iter++ {
+		_ = automaton.Match("0.v2ray.com")
+	}
+}
+
 func BenchmarkDomainMatcherGroup(b *testing.B) {
 func BenchmarkDomainMatcherGroup(b *testing.B) {
 	g := new(DomainMatcherGroup)
 	g := new(DomainMatcherGroup)
 
 

+ 168 - 0
common/strmatcher/matchers_test.go

@@ -71,3 +71,171 @@ func TestMatcher(t *testing.T) {
 		}
 		}
 	}
 	}
 }
 }
+// TestACAutomaton checks ACAutomaton.Match in three scenarios:
+//  1. a single Domain or Full pattern probed with one input each;
+//  2. a mix of Domain, Full and Substr patterns sharing suffixes;
+//  3. two Domain patterns where one is a suffix of the other's label.
+func TestACAutomaton(t *testing.T) {
+	cases1 := []struct {
+		pattern string
+		mType   Type
+		input   string
+		output  bool
+	}{
+		{pattern: "v2ray.com", mType: Domain, input: "www.v2ray.com", output: true},
+		{pattern: "v2ray.com", mType: Domain, input: "v2ray.com", output: true},
+		{pattern: "v2ray.com", mType: Domain, input: "www.v3ray.com", output: false},
+		{pattern: "v2ray.com", mType: Domain, input: "2ray.com", output: false},
+		{pattern: "v2ray.com", mType: Domain, input: "xv2ray.com", output: false},
+		{pattern: "v2ray.com", mType: Full, input: "v2ray.com", output: true},
+		{pattern: "v2ray.com", mType: Full, input: "xv2ray.com", output: false},
+	}
+	for _, test := range cases1 {
+		ac := NewACAutomaton()
+		ac.Add(test.pattern, test.mType)
+		ac.Build()
+		got := ac.Match(test.input)
+		if got != test.output {
+			t.Error("unexpected output: ", got, " for test case ", test)
+		}
+	}
+
+	type rule struct {
+		pattern string
+		mType   Type
+	}
+	type probe struct {
+		pattern string
+		res     bool
+	}
+	// check builds one automaton from rules and runs every probe against it.
+	check := func(rules []rule, probes []probe) {
+		ac := NewACAutomaton()
+		for _, r := range rules {
+			ac.Add(r.pattern, r.mType)
+		}
+		ac.Build()
+		for _, test := range probes {
+			if m := ac.Match(test.pattern); m != test.res {
+				t.Error("unexpected output: ", m, " for test case ", test)
+			}
+		}
+	}
+
+	check(
+		[]rule{
+			{pattern: "163.com", mType: Domain},
+			{pattern: "m.126.com", mType: Full},
+			{pattern: "3.com", mType: Full},
+			{pattern: "google.com", mType: Substr},
+			{pattern: "vgoogle.com", mType: Substr},
+		},
+		[]probe{
+			{pattern: "126.com", res: false},
+			{pattern: "m.163.com", res: true},
+			{pattern: "mm163.com", res: false},
+			{pattern: "m.126.com", res: true},
+			{pattern: "163.com", res: true},
+			{pattern: "63.com", res: false},
+			{pattern: "oogle.com", res: false},
+			{pattern: "vvgoogle.com", res: true},
+		},
+	)
+
+	check(
+		[]rule{
+			{pattern: "video.google.com", mType: Domain},
+			{pattern: "gle.com", mType: Domain},
+		},
+		[]probe{
+			{pattern: "google.com", res: false},
+		},
+	)
+}

+ 51 - 0
common/strmatcher/strmatcher.go

@@ -58,6 +58,57 @@ type matcherEntry struct {
 	id uint32
 	id uint32
 }
 }
 
 
+// ACAutomatonMatcherGroup combines a single ACAutomaton, which holds every
+// Full, Substr and Domain pattern, with a list of individually evaluated
+// regex matchers.
+type ACAutomatonMatcherGroup struct {
+	// count is the id counter: NewACAutomatonMatcherGroup starts it at 1 and
+	// AddPattern advances it for each regex it registers.
+	count         uint32
+	// ac holds all Full, Substr and Domain patterns.
+	ac            *ACAutomaton
+	// otherMatchers holds the regex matchers together with their ids.
+	otherMatchers []matcherEntry
+}
+
+func NewACAutomatonMatcherGroup() *ACAutomatonMatcherGroup {
+	var g = new(ACAutomatonMatcherGroup)
+	g.count = 1
+	g.ac = NewACAutomaton()
+	return g
+}
+
+func (g *ACAutomatonMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) {
+	switch t {
+	case Full, Substr, Domain:
+		g.ac.Add(pattern, t)
+	case Regex:
+		g.count++
+		r, err := regexp.Compile(pattern)
+		if err != nil {
+			return 0, err
+		}
+		g.otherMatchers = append(g.otherMatchers, matcherEntry{
+			m:  &regexMatcher{pattern: r},
+			id: g.count,
+		})
+	default:
+		panic("Unknown type")
+	}
+	return g.count, nil
+}
+
+// Build finalizes the underlying automaton by calling its Build method.
+// NewACAutomatonDomainMatcher calls it once, after adding every pattern.
+func (g *ACAutomatonMatcherGroup) Build() {
+	g.ac.Build()
+}
+
+// Match implements IndexMatcher.Match.
+//
+// The returned slice is never nil. It contains 1 when the automaton matches
+// pattern, followed by the id of every regex matcher that matches, in the
+// order the regexes were added.
+func (g *ACAutomatonMatcherGroup) Match(pattern string) []uint32 {
+	ids := []uint32{}
+	if hit := g.ac.Match(pattern); hit {
+		ids = append(ids, 1)
+	}
+	for i := range g.otherMatchers {
+		entry := g.otherMatchers[i]
+		if !entry.m.Match(pattern) {
+			continue
+		}
+		ids = append(ids, entry.id)
+	}
+	return ids
+}
+
 // MatcherGroup is an implementation of IndexMatcher.
 // MatcherGroup is an implementation of IndexMatcher.
 // Empty initialization works.
 // Empty initialization works.
 type MatcherGroup struct {
 type MatcherGroup struct {