Browse Source

Refactor strmatcher.MphMatcherGroup (#1364)

* Refactor strmatcher.MphMatcherGroup

* Add test for empty mph matcher group
Ye Zhihao 4 years ago
parent
commit
ed9641dad1
2 changed files with 239 additions and 185 deletions
  1. 135 185
      common/strmatcher/matchergroup_mph.go
  2. 104 0
      common/strmatcher/matchergroup_mph_test.go

+ 135 - 185
common/strmatcher/matchergroup_mph.go

@@ -10,134 +10,187 @@ import (
 // PrimeRK is the prime base used in Rabin-Karp algorithm.
 const PrimeRK = 16777619
 
-// calculate the rolling murmurHash of given string
-func RollingHash(s string) uint32 {
-	h := uint32(0)
-	for i := len(s) - 1; i >= 0; i-- {
-		h = h*PrimeRK + uint32(s[i])
+// RollingHash calculates the rolling murmurHash of given string based on a provided suffix hash.
+func RollingHash(hash uint32, input string) uint32 {
+	for i := len(input) - 1; i >= 0; i-- {
+		hash = hash*PrimeRK + uint32(input[i])
 	}
-	return h
+	return hash
+}
+
+// MemHash is the hash function used by go map, it utilizes available hardware instructions(behaves
+// as aeshash if aes instruction is available).
+// With different seed, each MemHash<seed> performs as distinct hash functions.
+func MemHash(seed uint32, input string) uint32 {
+	return uint32(strhash(unsafe.Pointer(&input), uintptr(seed))) // nosemgrep
+}
+
+const (
+	mphMatchTypeCount = 2 // Full and Domain
+)
+
+type mphRuleInfo struct {
+	rollingHash uint32
+	matchers    [mphMatchTypeCount][]uint32
 }
 
 // MphMatcherGroup is an implementation of MatcherGroup.
 // It implements Rabin-Karp algorithm and minimal perfect hash table for Full and Domain matcher.
 type MphMatcherGroup struct {
-	rules      []string
-	level0     []uint32
-	level0Mask int
-	level1     []uint32
-	level1Mask int
-	ruleMap    *map[string]uint32
+	rules      []string   // RuleIdx -> pattern string, index 0 reserved for failed lookup
+	values     [][]uint32 // RuleIdx -> registered matcher values for the pattern (Full Matcher takes precedence)
+	level0     []uint32   // RollingHash & Mask -> seed for Memhash
+	level0Mask uint32     // Mask restricting RollingHash to 0 ~ len(level0)
+	level1     []uint32   // Memhash<seed> & Mask -> stored index for rules
+	level1Mask uint32     // Mask for restricting Memhash<seed> to 0 ~ len(level1)
+	ruleInfos  *map[string]mphRuleInfo
 }
 
 func NewMphMatcherGroup() *MphMatcherGroup {
 	return &MphMatcherGroup{
-		rules:      nil,
+		rules:      []string{""},
+		values:     [][]uint32{nil},
 		level0:     nil,
 		level0Mask: 0,
 		level1:     nil,
 		level1Mask: 0,
-		ruleMap:    &map[string]uint32{},
+		ruleInfos:  &map[string]mphRuleInfo{}, // Only used for building, destroyed after build complete
 	}
 }
 
 // AddFullMatcher implements MatcherGroupForFull.
-func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, _ uint32) {
+func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) {
 	pattern := strings.ToLower(matcher.Pattern())
-	(*g.ruleMap)[pattern] = RollingHash(pattern)
+	g.addPattern(0, "", pattern, matcher.Type(), value)
 }
 
 // AddDomainMatcher implements MatcherGroupForDomain.
-func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, _ uint32) {
+func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) {
 	pattern := strings.ToLower(matcher.Pattern())
-	h := RollingHash(pattern)
-	(*g.ruleMap)[pattern] = h
-	(*g.ruleMap)["."+pattern] = h*PrimeRK + uint32('.')
+	hash := g.addPattern(0, "", pattern, matcher.Type(), value) // For full domain match
+	g.addPattern(hash, pattern, ".", matcher.Type(), value)     // For partial domain match
 }
 
-// Build builds a minimal perfect hash table for insert rules.
-func (g *MphMatcherGroup) Build() {
-	keyLen := len(*g.ruleMap)
-	if keyLen == 0 {
-		keyLen = 1
-		(*g.ruleMap)["empty___"] = RollingHash("empty___")
-	}
-	g.level0 = make([]uint32, nextPow2(keyLen/4))
-	g.level0Mask = len(g.level0) - 1
-	g.level1 = make([]uint32, nextPow2(keyLen))
-	g.level1Mask = len(g.level1) - 1
-	sparseBuckets := make([][]int, len(g.level0))
-	var ruleIdx int
-	for rule, hash := range *g.ruleMap {
-		n := int(hash) & g.level0Mask
-		g.rules = append(g.rules, rule)
-		sparseBuckets[n] = append(sparseBuckets[n], ruleIdx)
-		ruleIdx++
+func (g *MphMatcherGroup) addPattern(suffixHash uint32, suffixPattern string, pattern string, matcherType Type, value uint32) uint32 {
+	fullPattern := pattern + suffixPattern
+	info, found := (*g.ruleInfos)[fullPattern]
+	if !found {
+		info = mphRuleInfo{rollingHash: RollingHash(suffixHash, pattern)}
+		g.rules = append(g.rules, fullPattern)
+		g.values = append(g.values, nil)
 	}
-	g.ruleMap = nil
-	var buckets []indexBucket
-	for n, vals := range sparseBuckets {
-		if len(vals) > 0 {
-			buckets = append(buckets, indexBucket{n, vals})
-		}
+	info.matchers[matcherType] = append(info.matchers[matcherType], value)
+	(*g.ruleInfos)[fullPattern] = info
+	return info.rollingHash
+}
+
+// Build builds a minimal perfect hash table for insert rules.
+// Algorithm used: Hash, displace, and compress. See http://cmph.sourceforge.net/papers/esa09.pdf
+func (g *MphMatcherGroup) Build() error {
+	ruleCount := len(*g.ruleInfos)
+	g.level0 = make([]uint32, nextPow2(ruleCount/4))
+	g.level0Mask = uint32(len(g.level0) - 1)
+	g.level1 = make([]uint32, nextPow2(ruleCount))
+	g.level1Mask = uint32(len(g.level1) - 1)
+
+	// Create buckets based on all rule's rolling hash
+	buckets := make([][]uint32, len(g.level0))
+	for ruleIdx := 1; ruleIdx < len(g.rules); ruleIdx++ { // Traverse rules starting from index 1 (0 reserved for failed lookup)
+		ruleInfo := (*g.ruleInfos)[g.rules[ruleIdx]]
+		bucketIdx := ruleInfo.rollingHash & g.level0Mask
+		buckets[bucketIdx] = append(buckets[bucketIdx], uint32(ruleIdx))
+		g.values[ruleIdx] = append(ruleInfo.matchers[Full], ruleInfo.matchers[Domain]...) // nolint:gocritic
 	}
-	sort.Sort(bySize(buckets))
+	g.ruleInfos = nil // Set ruleInfos nil to release memory
 
-	occ := make([]bool, len(g.level1))
-	var tmpOcc []int
-	for _, bucket := range buckets {
+	// Sort buckets in descending order with respect to each bucket's size
+	bucketIdxs := make([]int, len(buckets))
+	for bucketIdx := range buckets {
+		bucketIdxs[bucketIdx] = bucketIdx
+	}
+	sort.Slice(bucketIdxs, func(i, j int) bool { return len(buckets[bucketIdxs[i]]) > len(buckets[bucketIdxs[j]]) })
+
+	// Exercise Hash, Displace, and Compress algorithm to construct minimal perfect hash table
+	occupied := make([]bool, len(g.level1)) // Whether a second-level hash has been already used
+	hashedBucket := make([]uint32, 0, 4)    // Second-level hashes for each rule in a specific bucket
+	for _, bucketIdx := range bucketIdxs {
+		bucket := buckets[bucketIdx]
+		hashedBucket = hashedBucket[:0]
 		seed := uint32(0)
-		for {
-			findSeed := true
-			tmpOcc = tmpOcc[:0]
-			for _, i := range bucket.vals {
-				n := int(strhashFallback(unsafe.Pointer(&g.rules[i]), uintptr(seed))) & g.level1Mask // nosemgrep
-				if occ[n] {
-					for _, n := range tmpOcc {
-						occ[n] = false
+		for len(hashedBucket) != len(bucket) {
+			for _, ruleIdx := range bucket {
+				memHash := MemHash(seed, g.rules[ruleIdx]) & g.level1Mask
+				if occupied[memHash] { // Collision occurred with this seed
+					for _, hash := range hashedBucket { // Revert all values in this hashed bucket
+						occupied[hash] = false
+						g.level1[hash] = 0
 					}
-					seed++
-					findSeed = false
+					hashedBucket = hashedBucket[:0]
+					seed++ // Try next seed
 					break
 				}
-				occ[n] = true
-				tmpOcc = append(tmpOcc, n)
-				g.level1[n] = uint32(i)
-			}
-			if findSeed {
-				g.level0[bucket.n] = seed
-				break
+				occupied[memHash] = true
+				g.level1[memHash] = ruleIdx // The final value in the hash table
+				hashedBucket = append(hashedBucket, memHash)
 			}
 		}
+		g.level0[bucketIdx] = seed // Displacement value for this bucket
 	}
+	return nil
 }
 
-// Lookup searches for s in t and returns its index and whether it was found.
-func (g *MphMatcherGroup) Lookup(h uint32, s string) bool {
-	i0 := int(h) & g.level0Mask
+// Lookup searches for input in minimal perfect hash table and returns its index. 0 indicates not found.
+func (g *MphMatcherGroup) Lookup(rollingHash uint32, input string) uint32 {
+	i0 := rollingHash & g.level0Mask
 	seed := g.level0[i0]
-	i1 := int(strhashFallback(unsafe.Pointer(&s), uintptr(seed))) & g.level1Mask // nosemgrep
-	n := g.level1[i1]
-	return s == g.rules[int(n)]
+	i1 := MemHash(seed, input) & g.level1Mask
+	if n := g.level1[i1]; g.rules[n] == input {
+		return n
+	}
+	return 0
 }
 
 // Match implements MatcherGroup.Match.
-func (*MphMatcherGroup) Match(_ string) []uint32 {
-	return nil
+func (g *MphMatcherGroup) Match(input string) []uint32 {
+	matches := [][]uint32{}
+	hash := uint32(0)
+	for i := len(input) - 1; i >= 0; i-- {
+		hash = hash*PrimeRK + uint32(input[i])
+		if input[i] == '.' {
+			if mphIdx := g.Lookup(hash, input[i:]); mphIdx != 0 {
+				matches = append(matches, g.values[mphIdx])
+			}
+		}
+	}
+	if mphIdx := g.Lookup(hash, input); mphIdx != 0 {
+		matches = append(matches, g.values[mphIdx])
+	}
+	switch len(matches) {
+	case 0:
+		return nil
+	case 1:
+		return matches[0]
+	default:
+		result := []uint32{}
+		for i := len(matches) - 1; i >= 0; i-- {
+			result = append(result, matches[i]...)
+		}
+		return result
+	}
 }
 
 // MatchAny implements MatcherGroup.MatchAny.
-func (g *MphMatcherGroup) MatchAny(pattern string) bool {
+func (g *MphMatcherGroup) MatchAny(input string) bool {
 	hash := uint32(0)
-	for i := len(pattern) - 1; i >= 0; i-- {
-		hash = hash*PrimeRK + uint32(pattern[i])
-		if pattern[i] == '.' {
-			if g.Lookup(hash, pattern[i:]) {
+	for i := len(input) - 1; i >= 0; i-- {
+		hash = hash*PrimeRK + uint32(input[i])
+		if input[i] == '.' {
+			if g.Lookup(hash, input[i:]) != 0 {
 				return true
 			}
 		}
 	}
-	return g.Lookup(hash, pattern)
+	return g.Lookup(hash, input) != 0
 }
 
 func nextPow2(v int) int {
@@ -149,109 +202,6 @@ func nextPow2(v int) int {
 	return int(n)
 }
 
-type indexBucket struct {
-	n    int
-	vals []int
-}
-
-type bySize []indexBucket
-
-func (s bySize) Len() int           { return len(s) }
-func (s bySize) Less(i, j int) bool { return len(s[i].vals) > len(s[j].vals) }
-func (s bySize) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }
-
-type stringStruct struct {
-	str unsafe.Pointer
-	len int
-}
-
-func strhashFallback(a unsafe.Pointer, h uintptr) uintptr {
-	x := (*stringStruct)(a)
-	return memhashFallback(x.str, h, uintptr(x.len))
-}
-
-const (
-	// Constants for multiplication: four random odd 64-bit numbers.
-	m1 = 16877499708836156737
-	m2 = 2820277070424839065
-	m3 = 9497967016996688599
-	m4 = 15839092249703872147
-)
-
-var hashkey = [4]uintptr{1, 1, 1, 1}
-
-func memhashFallback(p unsafe.Pointer, seed, s uintptr) uintptr {
-	h := uint64(seed + s*hashkey[0])
-tail:
-	switch {
-	case s == 0:
-	case s < 4:
-		h ^= uint64(*(*byte)(p))
-		h ^= uint64(*(*byte)(add(p, s>>1))) << 8
-		h ^= uint64(*(*byte)(add(p, s-1))) << 16
-		h = rotl31(h*m1) * m2
-	case s <= 8:
-		h ^= uint64(readUnaligned32(p))
-		h ^= uint64(readUnaligned32(add(p, s-4))) << 32
-		h = rotl31(h*m1) * m2
-	case s <= 16:
-		h ^= readUnaligned64(p)
-		h = rotl31(h*m1) * m2
-		h ^= readUnaligned64(add(p, s-8))
-		h = rotl31(h*m1) * m2
-	case s <= 32:
-		h ^= readUnaligned64(p)
-		h = rotl31(h*m1) * m2
-		h ^= readUnaligned64(add(p, 8))
-		h = rotl31(h*m1) * m2
-		h ^= readUnaligned64(add(p, s-16))
-		h = rotl31(h*m1) * m2
-		h ^= readUnaligned64(add(p, s-8))
-		h = rotl31(h*m1) * m2
-	default:
-		v1 := h
-		v2 := uint64(seed * hashkey[1])
-		v3 := uint64(seed * hashkey[2])
-		v4 := uint64(seed * hashkey[3])
-		for s >= 32 {
-			v1 ^= readUnaligned64(p)
-			v1 = rotl31(v1*m1) * m2
-			p = add(p, 8)
-			v2 ^= readUnaligned64(p)
-			v2 = rotl31(v2*m2) * m3
-			p = add(p, 8)
-			v3 ^= readUnaligned64(p)
-			v3 = rotl31(v3*m3) * m4
-			p = add(p, 8)
-			v4 ^= readUnaligned64(p)
-			v4 = rotl31(v4*m4) * m1
-			p = add(p, 8)
-			s -= 32
-		}
-		h = v1 ^ v2 ^ v3 ^ v4
-		goto tail
-	}
-
-	h ^= h >> 29
-	h *= m3
-	h ^= h >> 32
-	return uintptr(h)
-}
-
-func add(p unsafe.Pointer, x uintptr) unsafe.Pointer {
-	return unsafe.Pointer(uintptr(p) + x) // nosemgrep
-}
-
-func readUnaligned32(p unsafe.Pointer) uint32 {
-	q := (*[4]byte)(p)
-	return uint32(q[0]) | uint32(q[1])<<8 | uint32(q[2])<<16 | uint32(q[3])<<24
-}
-
-func rotl31(x uint64) uint64 {
-	return (x << 31) | (x >> (64 - 31))
-}
-
-func readUnaligned64(p unsafe.Pointer) uint64 {
-	q := (*[8]byte)(p)
-	return uint64(q[0]) | uint64(q[1])<<8 | uint64(q[2])<<16 | uint64(q[3])<<24 | uint64(q[4])<<32 | uint64(q[5])<<40 | uint64(q[6])<<48 | uint64(q[7])<<56
-}
+//go:noescape
+//go:linkname strhash runtime.strhash
+func strhash(p unsafe.Pointer, h uintptr) uintptr

+ 104 - 0
common/strmatcher/matchergroup_mph_test.go

@@ -1,6 +1,7 @@
 package strmatcher_test
 
 import (
+	"reflect"
 	"testing"
 
 	"github.com/v2fly/v2ray-core/v4/common"
@@ -172,3 +173,106 @@ func TestMphMatcherGroup(t *testing.T) {
 		}
 	}
 }
+
+// See https://github.com/v2fly/v2ray-core/issues/92#issuecomment-673238489
+func TestMphMatcherGroupAsIndexMatcher(t *testing.T) {
+	rules := []struct {
+		Type   Type
+		Domain string
+	}{
+		// Regex not supported by MphMatcherGroup
+		// {
+		// 	Type:   Regex,
+		// 	Domain: "apis\\.us$",
+		// },
+		// Substr not supported by MphMatcherGroup
+		// {
+		// 	Type:   Substr,
+		// 	Domain: "apis",
+		// },
+		{
+			Type:   Domain,
+			Domain: "googleapis.com",
+		},
+		{
+			Type:   Domain,
+			Domain: "com",
+		},
+		{
+			Type:   Full,
+			Domain: "www.baidu.com",
+		},
+		// Substr not supported by MphMatcherGroup, We add another matcher to preserve index
+		{
+			Type:   Domain,        // Substr,
+			Domain: "example.com", // "apis",
+		},
+		{
+			Type:   Domain,
+			Domain: "googleapis.com",
+		},
+		{
+			Type:   Full,
+			Domain: "fonts.googleapis.com",
+		},
+		{
+			Type:   Full,
+			Domain: "www.baidu.com",
+		},
+		{ // This matcher (index 10) is swapped with matcher (index 6) to test that full matcher takes high priority.
+			Type:   Full,
+			Domain: "example.com",
+		},
+		{
+			Type:   Domain,
+			Domain: "example.com",
+		},
+	}
+	cases := []struct {
+		Input  string
+		Output []uint32
+	}{
+		{
+			Input:  "www.baidu.com",
+			Output: []uint32{5, 9, 4},
+		},
+		{
+			Input:  "fonts.googleapis.com",
+			Output: []uint32{8, 3, 7, 4 /*2, 6*/},
+		},
+		{
+			Input:  "example.googleapis.com",
+			Output: []uint32{3, 7, 4 /*2, 6*/},
+		},
+		{
+			Input: "testapis.us",
+			// Output: []uint32{ /*2, 6*/ /*1,*/ },
+			Output: nil,
+		},
+		{
+			Input:  "example.com",
+			Output: []uint32{10, 6, 11, 4},
+		},
+	}
+	matcherGroup := NewMphMatcherGroup()
+	for i, rule := range rules {
+		matcher, err := rule.Type.New(rule.Domain)
+		common.Must(err)
+		common.Must(AddMatcherToGroup(matcherGroup, matcher, uint32(i+3)))
+	}
+	matcherGroup.Build()
+	for _, test := range cases {
+		if m := matcherGroup.Match(test.Input); !reflect.DeepEqual(m, test.Output) {
+			t.Error("unexpected output: ", m, " for test case ", test)
+		}
+	}
+}
+
+func TestEmptyMphMatcherGroup(t *testing.T) {
+	g := NewMphMatcherGroup()
+	g.Build()
+	r := g.Match("v2fly.org")
+	if len(r) != 0 {
+		t.Error("Expect [], but ", r)
+	}
+}