Browse Source

Merge pull request #95 from Vigilans/vigilans/dns-subdomain-multimatch

Amend domain matcher with returning values of all matched subdomains
Kslr 5 years ago
parent
commit
aa800355c3

+ 41 - 3
app/dns/server_test.go

@@ -53,6 +53,9 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 		} else if q.Name == "api.google.com." && q.Qtype == dns.TypeA {
 			rr, _ := dns.NewRR("api.google.com. IN A 8.8.7.7")
 			ans.Answer = append(ans.Answer, rr)
+		} else if q.Name == "v2.api.google.com." && q.Qtype == dns.TypeA {
+			rr, _ := dns.NewRR("v2.api.google.com. IN A 8.8.7.8")
+			ans.Answer = append(ans.Answer, rr)
 		} else if q.Name == "facebook.com." && q.Qtype == dns.TypeA {
 			rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
 			ans.Answer = append(ans.Answer, rr)
@@ -847,14 +850,38 @@ func TestMultiMatchPrioritizedDomain(t *testing.T) {
 						},
 						PrioritizedDomain: []*NameServer_PriorityDomain{
 							{
-								Type:   DomainMatchingType_Full,
+								Type:   DomainMatchingType_Subdomain,
 								Domain: "api.google.com",
 							},
 						},
 						Geoip: []*router.GeoIP{
 							{ // Will only match 8.8.7.7 (api.google.com)
 								Cidr: []*router.CIDR{
-									{Ip: []byte{8, 8, 7, 7}, Prefix: 0},
+									{Ip: []byte{8, 8, 7, 7}, Prefix: 32},
+								},
+							},
+						},
+					},
+					{
+						Address: &net.Endpoint{
+							Network: net.Network_UDP,
+							Address: &net.IPOrDomain{
+								Address: &net.IPOrDomain_Ip{
+									Ip: []byte{127, 0, 0, 1},
+								},
+							},
+							Port: uint32(port),
+						},
+						PrioritizedDomain: []*NameServer_PriorityDomain{
+							{
+								Type:   DomainMatchingType_Full,
+								Domain: "v2.api.google.com",
+							},
+						},
+						Geoip: []*router.GeoIP{
+							{ // Will only match 8.8.7.8 (v2.api.google.com)
+								Cidr: []*router.CIDR{
+									{Ip: []byte{8, 8, 7, 8}, Prefix: 32},
 								},
 							},
 						},
@@ -902,7 +929,7 @@ func TestMultiMatchPrioritizedDomain(t *testing.T) {
 		}
 	}
 
-	{ // Will match server 1,2,3 and server 1,2 returns unexpected ip, then server 3 returns expected one
+	{ // Will match server 3,1,2 and server 3 returns expected one
 		ips, err := client.LookupIP("api.google.com")
 		if err != nil {
 			t.Fatal("unexpected error: ", err)
@@ -913,6 +940,17 @@ func TestMultiMatchPrioritizedDomain(t *testing.T) {
 		}
 	}
 
+	{ // Will match server 4,3,1,2 and server 4 returns expected one
+		ips, err := client.LookupIP("v2.api.google.com")
+		if err != nil {
+			t.Fatal("unexpected error: ", err)
+		}
+
+		if r := cmp.Diff(ips, []net.IP{{8, 8, 7, 8}}); r != "" {
+			t.Fatal(r)
+		}
+	}
+
 	endTime := time.Now()
 	if startTime.After(endTime.Add(time.Second * 2)) {
 		t.Error("DNS query doesn't finish in 2 seconds.")

+ 17 - 7
common/strmatcher/domain_matcher.go

@@ -25,11 +25,6 @@ func (g *DomainMatcherGroup) Add(domain string, value uint32) {
 	current := g.root
 	parts := breakDomain(domain)
 	for i := len(parts) - 1; i >= 0; i-- {
-		if len(current.values) > 0 {
-			// if current node is already a match, it is not necessary to match further.
-			return
-		}
-
 		part := parts[i]
 		if current.sub == nil {
 			current.sub = make(map[string]*node)
@@ -43,7 +38,6 @@ func (g *DomainMatcherGroup) Add(domain string, value uint32) {
 	}
 
 	current.values = append(current.values, value)
-	current.sub = nil // shortcut sub nodes as current node is a match.
 }
 
 func (g *DomainMatcherGroup) addMatcher(m domainMatcher, value uint32) {
@@ -69,6 +63,7 @@ func (g *DomainMatcherGroup) Match(domain string) []uint32 {
 		return -1
 	}
 
+	matches := [][]uint32{}
 	idx := len(domain)
 	for {
 		if idx == -1 || current.sub == nil {
@@ -83,6 +78,21 @@ func (g *DomainMatcherGroup) Match(domain string) []uint32 {
 		}
 		current = next
 		idx = nidx
+		if len(current.values) > 0 {
+			matches = append(matches, current.values)
+		}
+	}
+	switch len(matches) {
+	case 0:
+		return nil
+	case 1:
+		return matches[0]
+	default:
+		result := []uint32{}
+		for idx := range matches {
+			// Insert reversely, the subdomain that matches further ranks higher
+			result = append(result, matches[len(matches)-1-idx]...)
+		}
+		return result
 	}
-	return current.values
 }

+ 2 - 2
common/strmatcher/domain_matcher_test.go

@@ -33,9 +33,9 @@ func TestDomainMatcherGroup(t *testing.T) {
 			Domain: "a.b.com",
 			Result: []uint32{4},
 		},
-		{
+		{ // Matches [c.a.b.com, a.b.com]
 			Domain: "c.a.b.com",
-			Result: []uint32{4},
+			Result: []uint32{5, 4},
 		},
 		{
 			Domain: "c.a..b.com",