Browse Source

Add minimal perfect hash domain matcher (#743)

* rename to HybridDomainMatcher & convert domain to lowercase

* refactor code & add open hashing for rolling hash map

* fix lint errors

* update app/dns/dns.go

* convert domain to lowercase in `strmatcher.go`

* keep the original matcher behavior

* add mph domain matcher & conver domain names to loweercase when matching

* fix lint errors

* fix lint errors
DarthVader 4 years ago
parent
commit
ac1e5cd925

+ 3 - 3
app/router/condition.go

@@ -68,8 +68,8 @@ type DomainMatcher struct {
 	matchers strmatcher.IndexMatcher
 	matchers strmatcher.IndexMatcher
 }
 }
 
 
-func NewACAutomatonDomainMatcher(domains []*Domain) (*DomainMatcher, error) {
-	g := strmatcher.NewACAutomatonMatcherGroup()
+func NewMphMatcherGroup(domains []*Domain) (*DomainMatcher, error) {
+	g := strmatcher.NewMphMatcherGroup()
 	for _, d := range domains {
 	for _, d := range domains {
 		matcherType, f := matcherTypeMap[d.Type]
 		matcherType, f := matcherTypeMap[d.Type]
 		if !f {
 		if !f {
@@ -102,7 +102,7 @@ func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) {
 }
 }
 
 
 func (m *DomainMatcher) ApplyDomain(domain string) bool {
 func (m *DomainMatcher) ApplyDomain(domain string) bool {
-	return len(m.matchers.Match(domain)) > 0
+	return len(m.matchers.Match(strings.ToLower(domain))) > 0
 }
 }
 
 
 // Apply implements Condition.
 // Apply implements Condition.

+ 3 - 3
app/router/condition_test.go

@@ -358,7 +358,7 @@ func TestChinaSites(t *testing.T) {
 
 
 	matcher, err := NewDomainMatcher(domains)
 	matcher, err := NewDomainMatcher(domains)
 	common.Must(err)
 	common.Must(err)
-	acMatcher, err := NewACAutomatonDomainMatcher(domains)
+	acMatcher, err := NewMphMatcherGroup(domains)
 	common.Must(err)
 	common.Must(err)
 
 
 	type TestCase struct {
 	type TestCase struct {
@@ -399,11 +399,11 @@ func TestChinaSites(t *testing.T) {
 	}
 	}
 }
 }
 
 
-func BenchmarkHybridDomainMatcher(b *testing.B) {
+func BenchmarkMphDomainMatcher(b *testing.B) {
 	domains, err := loadGeoSite("CN")
 	domains, err := loadGeoSite("CN")
 	common.Must(err)
 	common.Must(err)
 
 
-	matcher, err := NewACAutomatonDomainMatcher(domains)
+	matcher, err := NewMphMatcherGroup(domains)
 	common.Must(err)
 	common.Must(err)
 
 
 	type TestCase struct {
 	type TestCase struct {

+ 4 - 4
app/router/config.go

@@ -70,12 +70,12 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 
 
 	if len(rr.Domain) > 0 {
 	if len(rr.Domain) > 0 {
 		switch rr.DomainMatcher {
 		switch rr.DomainMatcher {
-		case "hybrid":
-			matcher, err := NewACAutomatonDomainMatcher(rr.Domain)
+		case "mph":
+			matcher, err := NewMphMatcherGroup(rr.Domain)
 			if err != nil {
 			if err != nil {
-				return nil, newError("failed to build domain condition with ACAutomatonDomainMatcher").Base(err)
+				return nil, newError("failed to build domain condition with MphDomainMatcher").Base(err)
 			}
 			}
-			newError("ACAutomatonDomainMatcher is enabled for ", len(rr.Domain), "domain rules(s)").AtDebug().WriteToLog()
+			newError("MphDomainMatcher is enabled for ", len(rr.Domain), "domain rules(s)").AtDebug().WriteToLog()
 			conds.Add(matcher)
 			conds.Add(matcher)
 		case "linear":
 		case "linear":
 			fallthrough
 			fallthrough

+ 297 - 0
common/strmatcher/mph_matcher.go

@@ -0,0 +1,297 @@
+package strmatcher
+
+import (
+	"math/bits"
+	"regexp"
+	"sort"
+	"strings"
+	"unsafe"
+)
+
+// PrimeRK is the prime base used in Rabin-Karp algorithm.
+const PrimeRK = 16777619
+
+// calculate the rolling murmurHash of given string
+func RollingHash(s string) uint32 {
+	h := uint32(0)
+	for i := len(s) - 1; i >= 0; i-- {
+		h = h*PrimeRK + uint32(s[i])
+	}
+	return h
+}
+
+// A MphMatcherGroup is divided into three parts:
+// 1. `full` and `domain` patterns are matched by Rabin-Karp algorithm and minimal perfect hash table;
+// 2. `substr` patterns are matched by ac automaton;
+// 3. `regex` patterns are matched with the regex library.
+type MphMatcherGroup struct {
+	ac            *ACAutomaton
+	otherMatchers []matcherEntry
+	rules         []string
+	level0        []uint32
+	level0Mask    int
+	level1        []uint32
+	level1Mask    int
+	count         uint32
+	ruleMap       *map[string]uint32
+}
+
+func (g *MphMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) {
+	h := RollingHash(pattern)
+	switch t {
+	case Domain:
+		(*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.')
+		fallthrough
+	case Full:
+		(*g.ruleMap)[pattern] = h
+	default:
+	}
+}
+
+func NewMphMatcherGroup() *MphMatcherGroup {
+	return &MphMatcherGroup{
+		ac:            nil,
+		otherMatchers: nil,
+		rules:         nil,
+		level0:        nil,
+		level0Mask:    0,
+		level1:        nil,
+		level1Mask:    0,
+		count:         1,
+		ruleMap:       &map[string]uint32{},
+	}
+}
+
+// AddPattern adds a pattern to MphMatcherGroup
+func (g *MphMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) {
+	switch t {
+	case Substr:
+		if g.ac == nil {
+			g.ac = NewACAutomaton()
+		}
+		g.ac.Add(pattern, t)
+	case Full, Domain:
+		pattern = strings.ToLower(pattern)
+		g.AddFullOrDomainPattern(pattern, t)
+	case Regex:
+		r, err := regexp.Compile(pattern)
+		if err != nil {
+			return 0, err
+		}
+		g.otherMatchers = append(g.otherMatchers, matcherEntry{
+			m:  &regexMatcher{pattern: r},
+			id: g.count,
+		})
+	default:
+		panic("Unknown type")
+	}
+	return g.count, nil
+}
+
+// Build builds a minimal perfect hash table and ac automaton from insert rules
+func (g *MphMatcherGroup) Build() {
+	if g.ac != nil {
+		g.ac.Build()
+	}
+	keyLen := len(*g.ruleMap)
+	g.level0 = make([]uint32, nextPow2(keyLen/4))
+	g.level0Mask = len(g.level0) - 1
+	g.level1 = make([]uint32, nextPow2(keyLen))
+	g.level1Mask = len(g.level1) - 1
+	var sparseBuckets = make([][]int, len(g.level0))
+	var ruleIdx int
+	for rule, hash := range *g.ruleMap {
+		n := int(hash) & g.level0Mask
+		g.rules = append(g.rules, rule)
+		sparseBuckets[n] = append(sparseBuckets[n], ruleIdx)
+		ruleIdx++
+	}
+	g.ruleMap = nil
+	var buckets []indexBucket
+	for n, vals := range sparseBuckets {
+		if len(vals) > 0 {
+			buckets = append(buckets, indexBucket{n, vals})
+		}
+	}
+	sort.Sort(bySize(buckets))
+
+	occ := make([]bool, len(g.level1))
+	var tmpOcc []int
+	for _, bucket := range buckets {
+		var seed = uint32(0)
+		for {
+			findSeed := true
+			tmpOcc = tmpOcc[:0]
+			for _, i := range bucket.vals {
+				n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask
+				if occ[n] {
+					for _, n := range tmpOcc {
+						occ[n] = false
+					}
+					seed++
+					findSeed = false
+					break
+				}
+				occ[n] = true
+				tmpOcc = append(tmpOcc, n)
+				g.level1[n] = uint32(i)
+			}
+			if findSeed {
+				g.level0[bucket.n] = seed
+				break
+			}
+		}
+	}
+}
+
+func nextPow2(v int) int {
+	if v <= 1 {
+		return 1
+	}
+	const MaxUInt = ^uint(0)
+	n := (MaxUInt >> bits.LeadingZeros(uint(v))) + 1
+	return int(n)
+}
+
+// Lookup searches for s in t and returns its index and whether it was found.
+func (g *MphMatcherGroup) Lookup(h uint32, s string) bool {
+	i0 := int(h) & g.level0Mask
+	seed := g.level0[i0]
+	i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask
+	n := g.level1[i1]
+	return s == g.rules[int(n)]
+}
+
+// Match implements IndexMatcher.Match.
+func (g *MphMatcherGroup) Match(pattern string) []uint32 {
+	result := []uint32{}
+	hash := uint32(0)
+	for i := len(pattern) - 1; i >= 0; i-- {
+		hash = hash*PrimeRK + uint32(pattern[i])
+		if pattern[i] == '.' {
+			if g.Lookup(hash, pattern[i:]) {
+				result = append(result, 1)
+				return result
+			}
+		}
+	}
+	if g.Lookup(hash, pattern) {
+		result = append(result, 1)
+		return result
+	}
+	if g.ac != nil && g.ac.Match(pattern) {
+		result = append(result, 1)
+		return result
+	}
+	for _, e := range g.otherMatchers {
+		if e.m.Match(pattern) {
+			result = append(result, e.id)
+			return result
+		}
+	}
+	return nil
+}
+
+type indexBucket struct {
+	n    int
+	vals []int
+}
+
+type bySize []indexBucket
+
+func (s bySize) Len() int           { return len(s) }
+func (s bySize) Less(i, j int) bool { return len(s[i].vals) > len(s[j].vals) }
+func (s bySize) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
+
+type stringStruct struct {
+	str unsafe.Pointer
+	len int
+}
+
+func strhashFallback(a unsafe.Pointer, h uintptr) uintptr {
+	x := (*stringStruct)(a)
+	return memhashFallback(x.str, h, uintptr(x.len))
+}
+
+const (
+	// Constants for multiplication: four random odd 64-bit numbers.
+	m1 = 16877499708836156737
+	m2 = 2820277070424839065
+	m3 = 9497967016996688599
+	m4 = 15839092249703872147
+)
+
+var hashkey = [4]uintptr{1, 1, 1, 1}
+
+func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr {
+	h := uint64(seed + s*hashkey[0])
+tail:
+	switch {
+	case s == 0:
+	case s < 4:
+		h ^= uint64(*(*byte)(p))
+		h ^= uint64(*(*byte)(add(p, s>>1))) << 8
+		h ^= uint64(*(*byte)(add(p, s-1))) << 16
+		h = rotl31(h*m1) * m2
+	case s <= 8:
+		h ^= uint64(readUnaligned32(p))
+		h ^= uint64(readUnaligned32(add(p, s-4))) << 32
+		h = rotl31(h*m1) * m2
+	case s <= 16:
+		h ^= readUnaligned64(p)
+		h = rotl31(h*m1) * m2
+		h ^= readUnaligned64(add(p, s-8))
+		h = rotl31(h*m1) * m2
+	case s <= 32:
+		h ^= readUnaligned64(p)
+		h = rotl31(h*m1) * m2
+		h ^= readUnaligned64(add(p, 8))
+		h = rotl31(h*m1) * m2
+		h ^= readUnaligned64(add(p, s-16))
+		h = rotl31(h*m1) * m2
+		h ^= readUnaligned64(add(p, s-8))
+		h = rotl31(h*m1) * m2
+	default:
+		v1 := h
+		v2 := uint64(seed * hashkey[1])
+		v3 := uint64(seed * hashkey[2])
+		v4 := uint64(seed * hashkey[3])
+		for s >= 32 {
+			v1 ^= readUnaligned64(p)
+			v1 = rotl31(v1*m1) * m2
+			p = add(p, 8)
+			v2 ^= readUnaligned64(p)
+			v2 = rotl31(v2*m2) * m3
+			p = add(p, 8)
+			v3 ^= readUnaligned64(p)
+			v3 = rotl31(v3*m3) * m4
+			p = add(p, 8)
+			v4 ^= readUnaligned64(p)
+			v4 = rotl31(v4*m4) * m1
+			p = add(p, 8)
+			s -= 32
+		}
+		h = v1 ^ v2 ^ v3 ^ v4
+		goto tail
+	}
+
+	h ^= h >> 29
+	h *= m3
+	h ^= h >> 32
+	return uintptr(h)
+}
+func add(p unsafe.Pointer, x uintptr) unsafe.Pointer {
+	return unsafe.Pointer(uintptr(p) + x)
+}
+func readUnaligned32(p unsafe.Pointer) uint32 {
+	q := (*[4]byte)(p)
+	return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24
+}
+
+func rotl31(x uint64) uint64 {
+	return (x << 31) | (x >> (64 - 31))
+}
+func readUnaligned64(p unsafe.Pointer) uint64 {
+	q := (*[8]byte)(p)
+	return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56
+}

+ 1 - 94
common/strmatcher/strmatcher.go

@@ -4,9 +4,6 @@ import (
 	"regexp"
 	"regexp"
 )
 )
 
 
-// PrimeRK is the prime base used in Rabin-Karp algorithm.
-const PrimeRK = 16777619
-
 // Matcher is the interface to determine a string matches a pattern.
 // Matcher is the interface to determine a string matches a pattern.
 type Matcher interface {
 type Matcher interface {
 	// Match returns true if the given string matches a predefined pattern.
 	// Match returns true if the given string matches a predefined pattern.
@@ -30,6 +27,7 @@ const (
 
 
 // New creates a new Matcher based on the given pattern.
 // New creates a new Matcher based on the given pattern.
 func (t Type) New(pattern string) (Matcher, error) {
 func (t Type) New(pattern string) (Matcher, error) {
+	// 1. regex matching is case-sensitive
 	switch t {
 	switch t {
 	case Full:
 	case Full:
 		return fullMatcher(pattern), nil
 		return fullMatcher(pattern), nil
@@ -61,97 +59,6 @@ type matcherEntry struct {
 	id uint32
 	id uint32
 }
 }
 
 
-type ACAutomatonMatcherGroup struct {
-	count         uint32
-	ac            *ACAutomaton
-	nonSubstrMap  map[uint32]string
-	otherMatchers []matcherEntry
-}
-
-func NewACAutomatonMatcherGroup() *ACAutomatonMatcherGroup {
-	var g = new(ACAutomatonMatcherGroup)
-	g.count = 1
-	g.nonSubstrMap = map[uint32]string{}
-	return g
-}
-
-// Add `full` or `domain` pattern to hashmap
-func (g *ACAutomatonMatcherGroup) AddFullOrDomainPattern(pattern string, t Type) {
-	h := uint32(0)
-	for i := len(pattern) - 1; i >= 0; i-- {
-		h = h*PrimeRK + uint32(pattern[i])
-	}
-	switch t {
-	case Full:
-		g.nonSubstrMap[h] = pattern
-	case Domain:
-		g.nonSubstrMap[h] = pattern
-		g.nonSubstrMap[h*PrimeRK+uint32('.')] = "." + pattern
-	default:
-	}
-}
-
-func (g *ACAutomatonMatcherGroup) AddPattern(pattern string, t Type) (uint32, error) {
-	switch t {
-	case Substr:
-		if g.ac == nil {
-			g.ac = NewACAutomaton()
-		}
-		g.ac.Add(pattern, t)
-	case Full, Domain:
-		g.AddFullOrDomainPattern(pattern, t)
-	case Regex:
-		g.count++
-		r, err := regexp.Compile(pattern)
-		if err != nil {
-			return 0, err
-		}
-		g.otherMatchers = append(g.otherMatchers, matcherEntry{
-			m:  &regexMatcher{pattern: r},
-			id: g.count,
-		})
-	default:
-		panic("Unknown type")
-	}
-	return g.count, nil
-}
-
-func (g *ACAutomatonMatcherGroup) Build() {
-	if g.ac != nil {
-		g.ac.Build()
-	}
-}
-
-// Match implements IndexMatcher.Match.
-func (g *ACAutomatonMatcherGroup) Match(pattern string) []uint32 {
-	result := []uint32{}
-	hash := uint32(0)
-	for i := len(pattern) - 1; i >= 0; i-- {
-		hash = hash*PrimeRK + uint32(pattern[i])
-		if pattern[i] == '.' {
-			if v, ok := g.nonSubstrMap[hash]; ok && v == pattern[i:] {
-				result = append(result, 1)
-				return result
-			}
-		}
-	}
-	if v, ok := g.nonSubstrMap[hash]; ok && v == pattern {
-		result = append(result, 1)
-		return result
-	}
-	if g.ac != nil && g.ac.Match(pattern) {
-		result = append(result, 1)
-		return result
-	}
-	for _, e := range g.otherMatchers {
-		if e.m.Match(pattern) {
-			result = append(result, e.id)
-			return result
-		}
-	}
-	return result
-}
-
 // MatcherGroup is an implementation of IndexMatcher.
 // MatcherGroup is an implementation of IndexMatcher.
 // Empty initialization works.
 // Empty initialization works.
 type MatcherGroup struct {
 type MatcherGroup struct {