matchergroup_mph.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. package strmatcher
  2. import (
  3. "math/bits"
  4. "sort"
  5. "strings"
  6. "unsafe"
  7. )
  8. // PrimeRK is the prime base used in Rabin-Karp algorithm.
  9. const PrimeRK = 16777619
  10. // RollingHash calculates the rolling murmurHash of given string based on a provided suffix hash.
  11. func RollingHash(hash uint32, input string) uint32 {
  12. for i := len(input) - 1; i >= 0; i-- {
  13. hash = hash*PrimeRK + uint32(input[i])
  14. }
  15. return hash
  16. }
  17. // MemHash is the hash function used by go map, it utilizes available hardware instructions(behaves
  18. // as aeshash if aes instruction is available).
  19. // With different seed, each MemHash<seed> performs as distinct hash functions.
  20. func MemHash(seed uint32, input string) uint32 {
  21. return uint32(strhash(unsafe.Pointer(&input), uintptr(seed))) // nosemgrep
  22. }
  23. const (
  24. mphMatchTypeCount = 2 // Full and Domain
  25. )
  26. type mphRuleInfo struct {
  27. rollingHash uint32
  28. matchers [mphMatchTypeCount][]uint32
  29. }
  30. // MphMatcherGroup is an implementation of MatcherGroup.
  31. // It implements Rabin-Karp algorithm and minimal perfect hash table for Full and Domain matcher.
  32. type MphMatcherGroup struct {
  33. rules []string // RuleIdx -> pattern string, index 0 reserved for failed lookup
  34. values [][]uint32 // RuleIdx -> registered matcher values for the pattern (Full Matcher takes precedence)
  35. level0 []uint32 // RollingHash & Mask -> seed for Memhash
  36. level0Mask uint32 // Mask restricting RollingHash to 0 ~ len(level0)
  37. level1 []uint32 // Memhash<seed> & Mask -> stored index for rules
  38. level1Mask uint32 // Mask for restricting Memhash<seed> to 0 ~ len(level1)
  39. ruleInfos *map[string]mphRuleInfo
  40. }
  41. func NewMphMatcherGroup() *MphMatcherGroup {
  42. return &MphMatcherGroup{
  43. rules: []string{""},
  44. values: [][]uint32{nil},
  45. level0: nil,
  46. level0Mask: 0,
  47. level1: nil,
  48. level1Mask: 0,
  49. ruleInfos: &map[string]mphRuleInfo{}, // Only used for building, destroyed after build complete
  50. }
  51. }
  52. // AddFullMatcher implements MatcherGroupForFull.
  53. func (g *MphMatcherGroup) AddFullMatcher(matcher FullMatcher, value uint32) {
  54. pattern := strings.ToLower(matcher.Pattern())
  55. g.addPattern(0, "", pattern, matcher.Type(), value)
  56. }
  57. // AddDomainMatcher implements MatcherGroupForDomain.
  58. func (g *MphMatcherGroup) AddDomainMatcher(matcher DomainMatcher, value uint32) {
  59. pattern := strings.ToLower(matcher.Pattern())
  60. hash := g.addPattern(0, "", pattern, matcher.Type(), value) // For full domain match
  61. g.addPattern(hash, pattern, ".", matcher.Type(), value) // For partial domain match
  62. }
  63. func (g *MphMatcherGroup) addPattern(suffixHash uint32, suffixPattern string, pattern string, matcherType Type, value uint32) uint32 {
  64. fullPattern := pattern + suffixPattern
  65. info, found := (*g.ruleInfos)[fullPattern]
  66. if !found {
  67. info = mphRuleInfo{rollingHash: RollingHash(suffixHash, pattern)}
  68. g.rules = append(g.rules, fullPattern)
  69. g.values = append(g.values, nil)
  70. }
  71. info.matchers[matcherType] = append(info.matchers[matcherType], value)
  72. (*g.ruleInfos)[fullPattern] = info
  73. return info.rollingHash
  74. }
  75. // Build builds a minimal perfect hash table for insert rules.
  76. // Algorithm used: Hash, displace, and compress. See http://cmph.sourceforge.net/papers/esa09.pdf
  77. func (g *MphMatcherGroup) Build() error {
  78. ruleCount := len(*g.ruleInfos)
  79. g.level0 = make([]uint32, nextPow2(ruleCount/4))
  80. g.level0Mask = uint32(len(g.level0) - 1)
  81. g.level1 = make([]uint32, nextPow2(ruleCount))
  82. g.level1Mask = uint32(len(g.level1) - 1)
  83. // Create buckets based on all rule's rolling hash
  84. buckets := make([][]uint32, len(g.level0))
  85. for ruleIdx := 1; ruleIdx < len(g.rules); ruleIdx++ { // Traverse rules starting from index 1 (0 reserved for failed lookup)
  86. ruleInfo := (*g.ruleInfos)[g.rules[ruleIdx]]
  87. bucketIdx := ruleInfo.rollingHash & g.level0Mask
  88. buckets[bucketIdx] = append(buckets[bucketIdx], uint32(ruleIdx))
  89. g.values[ruleIdx] = append(ruleInfo.matchers[Full], ruleInfo.matchers[Domain]...) // nolint:gocritic
  90. }
  91. g.ruleInfos = nil // Set ruleInfos nil to release memory
  92. // Sort buckets in descending order with respect to each bucket's size
  93. bucketIdxs := make([]int, len(buckets))
  94. for bucketIdx := range buckets {
  95. bucketIdxs[bucketIdx] = bucketIdx
  96. }
  97. sort.Slice(bucketIdxs, func(i, j int) bool { return len(buckets[bucketIdxs[i]]) > len(buckets[bucketIdxs[j]]) })
  98. // Exercise Hash, Displace, and Compress algorithm to construct minimal perfect hash table
  99. occupied := make([]bool, len(g.level1)) // Whether a second-level hash has been already used
  100. hashedBucket := make([]uint32, 0, 4) // Second-level hashes for each rule in a specific bucket
  101. for _, bucketIdx := range bucketIdxs {
  102. bucket := buckets[bucketIdx]
  103. hashedBucket = hashedBucket[:0]
  104. seed := uint32(0)
  105. for len(hashedBucket) != len(bucket) {
  106. for _, ruleIdx := range bucket {
  107. memHash := MemHash(seed, g.rules[ruleIdx]) & g.level1Mask
  108. if occupied[memHash] { // Collision occurred with this seed
  109. for _, hash := range hashedBucket { // Revert all values in this hashed bucket
  110. occupied[hash] = false
  111. g.level1[hash] = 0
  112. }
  113. hashedBucket = hashedBucket[:0]
  114. seed++ // Try next seed
  115. break
  116. }
  117. occupied[memHash] = true
  118. g.level1[memHash] = ruleIdx // The final value in the hash table
  119. hashedBucket = append(hashedBucket, memHash)
  120. }
  121. }
  122. g.level0[bucketIdx] = seed // Displacement value for this bucket
  123. }
  124. return nil
  125. }
  126. // Lookup searches for input in minimal perfect hash table and returns its index. 0 indicates not found.
  127. func (g *MphMatcherGroup) Lookup(rollingHash uint32, input string) uint32 {
  128. i0 := rollingHash & g.level0Mask
  129. seed := g.level0[i0]
  130. i1 := MemHash(seed, input) & g.level1Mask
  131. if n := g.level1[i1]; g.rules[n] == input {
  132. return n
  133. }
  134. return 0
  135. }
  136. // Match implements MatcherGroup.Match.
  137. func (g *MphMatcherGroup) Match(input string) []uint32 {
  138. matches := [][]uint32{}
  139. hash := uint32(0)
  140. for i := len(input) - 1; i >= 0; i-- {
  141. hash = hash*PrimeRK + uint32(input[i])
  142. if input[i] == '.' {
  143. if mphIdx := g.Lookup(hash, input[i:]); mphIdx != 0 {
  144. matches = append(matches, g.values[mphIdx])
  145. }
  146. }
  147. }
  148. if mphIdx := g.Lookup(hash, input); mphIdx != 0 {
  149. matches = append(matches, g.values[mphIdx])
  150. }
  151. switch len(matches) {
  152. case 0:
  153. return nil
  154. case 1:
  155. return matches[0]
  156. default:
  157. result := []uint32{}
  158. for i := len(matches) - 1; i >= 0; i-- {
  159. result = append(result, matches[i]...)
  160. }
  161. return result
  162. }
  163. }
  164. // MatchAny implements MatcherGroup.MatchAny.
  165. func (g *MphMatcherGroup) MatchAny(input string) bool {
  166. hash := uint32(0)
  167. for i := len(input) - 1; i >= 0; i-- {
  168. hash = hash*PrimeRK + uint32(input[i])
  169. if input[i] == '.' {
  170. if g.Lookup(hash, input[i:]) != 0 {
  171. return true
  172. }
  173. }
  174. }
  175. return g.Lookup(hash, input) != 0
  176. }
  177. func nextPow2(v int) int {
  178. if v <= 1 {
  179. return 1
  180. }
  181. const MaxUInt = ^uint(0)
  182. n := (MaxUInt >> bits.LeadingZeros(uint(v))) + 1
  183. return int(n)
  184. }
// strhash is the Go runtime's string hash function, bound to runtime.strhash
// through the go:linkname directive below; it has no body in this file.
// As used by MemHash, p points to the string value to hash and h is the seed.
//
//go:noescape
//go:linkname strhash runtime.strhash
func strhash(p unsafe.Pointer, h uintptr) uintptr