Browse Source

Amending domain matcher with returning array of all matches

Vigilans 5 years ago
parent
commit
c74a33f827

+ 6 - 3
app/dns/hosts.go

@@ -106,11 +106,14 @@ func filterIP(ips []net.Address, option IPOption) []net.Address {
 
 // LookupIP returns IP address for the given domain, if exists in this StaticHosts.
 func (h *StaticHosts) LookupIP(domain string, option IPOption) []net.Address {
-	id := h.matchers.Match(domain)
-	if id == 0 {
+	indices := h.matchers.Match(domain)
+	if len(indices) == 0 {
 		return nil
 	}
-	ips := h.ips[id]
+	ips := []net.Address{}
+	for _, id := range indices {
+		ips = append(ips, h.ips[id]...)
+	}
 	if len(ips) == 1 && ips[0].Family().IsDomain() {
 		return ips
 	}

+ 2 - 2
app/dns/server.go

@@ -330,8 +330,8 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 	var lastErr error
 	var matchedClient Client
 	if s.domainMatcher != nil {
-		idx := s.domainMatcher.Match(domain)
-		if idx > 0 {
+		indices := s.domainMatcher.Match(domain)
+		for _, idx := range indices {
 			matchedClient = s.clients[s.domainIndexMap[idx]]
 			ips, err := s.queryIPTimeout(s.domainIndexMap[idx], matchedClient, domain, option)
 			if len(ips) > 0 {

+ 164 - 0
app/dns/server_test.go

@@ -50,6 +50,9 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 				rr, _ := dns.NewRR("google.com. IN A 8.8.4.4")
 				ans.Answer = append(ans.Answer, rr)
 			}
+		} else if q.Name == "api.google.com." && q.Qtype == dns.TypeA {
+			rr, _ := dns.NewRR("api.google.com. IN A 8.8.7.7")
+			ans.Answer = append(ans.Answer, rr)
 		} else if q.Name == "facebook.com." && q.Qtype == dns.TypeA {
 			rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
 			ans.Answer = append(ans.Answer, rr)
@@ -754,3 +757,164 @@ func TestLocalDomain(t *testing.T) {
 		t.Error("DNS query doesn't finish in 2 seconds.")
 	}
 }
+
+func TestMultiMatchPrioritizedDomain(t *testing.T) {
+	port := udp.PickPort()
+
+	dnsServer := dns.Server{
+		Addr:    "127.0.0.1:" + port.String(),
+		Net:     "udp",
+		Handler: &staticHandler{},
+		UDPSize: 1200,
+	}
+
+	go dnsServer.ListenAndServe()
+	time.Sleep(time.Second)
+
+	config := &core.Config{
+		App: []*serial.TypedMessage{
+			serial.ToTypedMessage(&Config{
+				NameServers: []*net.Endpoint{
+					{
+						Network: net.Network_UDP,
+						Address: &net.IPOrDomain{
+							Address: &net.IPOrDomain_Ip{
+								Ip: []byte{127, 0, 0, 1},
+							},
+						},
+						Port: 9999, /* unreachable */
+					},
+				},
+				NameServer: []*NameServer{
+					{
+						Address: &net.Endpoint{
+							Network: net.Network_UDP,
+							Address: &net.IPOrDomain{
+								Address: &net.IPOrDomain_Ip{
+									Ip: []byte{127, 0, 0, 1},
+								},
+							},
+							Port: uint32(port),
+						},
+						PrioritizedDomain: []*NameServer_PriorityDomain{
+							{
+								Type:   DomainMatchingType_Subdomain,
+								Domain: "google.com",
+							},
+						},
+						Geoip: []*router.GeoIP{
+							{ // Will only match 8.8.8.8 and 8.8.4.4
+								Cidr: []*router.CIDR{
+									{Ip: []byte{8, 8, 8, 8}, Prefix: 32},
+									{Ip: []byte{8, 8, 4, 4}, Prefix: 32},
+								},
+							},
+						},
+					},
+					{
+						Address: &net.Endpoint{
+							Network: net.Network_UDP,
+							Address: &net.IPOrDomain{
+								Address: &net.IPOrDomain_Ip{
+									Ip: []byte{127, 0, 0, 1},
+								},
+							},
+							Port: uint32(port),
+						},
+						PrioritizedDomain: []*NameServer_PriorityDomain{
+							{
+								Type:   DomainMatchingType_Subdomain,
+								Domain: "google.com",
+							},
+						},
+						Geoip: []*router.GeoIP{
+							{ // Will match 8.8.8.8 and 8.8.8.7, etc
+								Cidr: []*router.CIDR{
+									{Ip: []byte{8, 8, 8, 7}, Prefix: 24},
+								},
+							},
+						},
+					},
+					{
+						Address: &net.Endpoint{
+							Network: net.Network_UDP,
+							Address: &net.IPOrDomain{
+								Address: &net.IPOrDomain_Ip{
+									Ip: []byte{127, 0, 0, 1},
+								},
+							},
+							Port: uint32(port),
+						},
+						PrioritizedDomain: []*NameServer_PriorityDomain{
+							{
+								Type:   DomainMatchingType_Full,
+								Domain: "api.google.com",
+							},
+						},
+						Geoip: []*router.GeoIP{
+							{ // Will only match 8.8.7.7 (api.google.com)
+								Cidr: []*router.CIDR{
+									{Ip: []byte{8, 8, 7, 7}, Prefix: 0},
+								},
+							},
+						},
+					},
+				},
+			}),
+			serial.ToTypedMessage(&dispatcher.Config{}),
+			serial.ToTypedMessage(&proxyman.OutboundConfig{}),
+			serial.ToTypedMessage(&policy.Config{}),
+		},
+		Outbound: []*core.OutboundHandlerConfig{
+			{
+				ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
+			},
+		},
+	}
+
+	v, err := core.New(config)
+	common.Must(err)
+
+	client := v.GetFeature(feature_dns.ClientType()).(feature_dns.Client)
+
+	startTime := time.Now()
+
+	{ // Will match server 1,2 and server 1 returns expected ip
+		ips, err := client.LookupIP("google.com")
+		if err != nil {
+			t.Fatal("unexpected error: ", err)
+		}
+
+		if r := cmp.Diff(ips, []net.IP{{8, 8, 8, 8}}); r != "" {
+			t.Fatal(r)
+		}
+	}
+
+	{ // Will match server 1,2 and server 1 returns unexpected ip, then server 2 returns expected one
+		clientv4 := client.(feature_dns.IPv4Lookup)
+		ips, err := clientv4.LookupIPv4("ipv6.google.com")
+		if err != nil {
+			t.Fatal("unexpected error: ", err)
+		}
+
+		if r := cmp.Diff(ips, []net.IP{{8, 8, 8, 7}}); r != "" {
+			t.Fatal(r)
+		}
+	}
+
+	{ // Will match server 1,2,3 and server 1,2 returns unexpected ip, then server 3 returns expected one
+		ips, err := client.LookupIP("api.google.com")
+		if err != nil {
+			t.Fatal("unexpected error: ", err)
+		}
+
+		if r := cmp.Diff(ips, []net.IP{{8, 8, 7, 7}}); r != "" {
+			t.Fatal(r)
+		}
+	}
+
+	endTime := time.Now()
+	if startTime.After(endTime.Add(time.Second * 2)) {
+		t.Error("DNS query doesn't finish in 2 seconds.")
+	}
+}

+ 1 - 1
app/router/condition.go

@@ -82,7 +82,7 @@ func NewDomainMatcher(domains []*Domain) (*DomainMatcher, error) {
 }
 
 func (m *DomainMatcher) ApplyDomain(domain string) bool {
-	return m.matchers.Match(domain) > 0
+	return len(m.matchers.Match(domain)) > 0
 }
 
 func (m *DomainMatcher) Apply(ctx *Context) bool {

+ 8 - 8
common/strmatcher/domain_matcher.go

@@ -7,8 +7,8 @@ func breakDomain(domain string) []string {
 }
 
 type node struct {
-	value uint32
-	sub   map[string]*node
+	values []uint32
+	sub    map[string]*node
 }
 
 // DomainMatcherGroup is a IndexMatcher for a large set of Domain matchers.
@@ -25,7 +25,7 @@ func (g *DomainMatcherGroup) Add(domain string, value uint32) {
 	current := g.root
 	parts := breakDomain(domain)
 	for i := len(parts) - 1; i >= 0; i-- {
-		if current.value > 0 {
+		if len(current.values) > 0 {
 			// if current node is already a match, it is not necessary to match further.
 			return
 		}
@@ -42,7 +42,7 @@ func (g *DomainMatcherGroup) Add(domain string, value uint32) {
 		current = next
 	}
 
-	current.value = value
+	current.values = append(current.values, value)
 	current.sub = nil // shortcut sub nodes as current node is a match.
 }
 
@@ -50,14 +50,14 @@ func (g *DomainMatcherGroup) addMatcher(m domainMatcher, value uint32) {
 	g.Add(string(m), value)
 }
 
-func (g *DomainMatcherGroup) Match(domain string) uint32 {
+func (g *DomainMatcherGroup) Match(domain string) []uint32 {
 	if domain == "" {
-		return 0
+		return nil
 	}
 
 	current := g.root
 	if current == nil {
-		return 0
+		return nil
 	}
 
 	nextPart := func(idx int) int {
@@ -84,5 +84,5 @@ func (g *DomainMatcherGroup) Match(domain string) uint32 {
 		current = next
 		idx = nidx
 	}
-	return current.value
+	return current.values
 }

+ 19 - 12
common/strmatcher/domain_matcher_test.go

@@ -1,6 +1,7 @@
 package strmatcher_test
 
 import (
+	"reflect"
 	"testing"
 
 	. "v2ray.com/core/common/strmatcher"
@@ -13,48 +14,54 @@ func TestDomainMatcherGroup(t *testing.T) {
 	g.Add("x.a.com", 3)
 	g.Add("a.b.com", 4)
 	g.Add("c.a.b.com", 5)
+	g.Add("x.y.com", 4)
+	g.Add("x.y.com", 6)
 
 	testCases := []struct {
 		Domain string
-		Result uint32
+		Result []uint32
 	}{
 		{
 			Domain: "x.v2ray.com",
-			Result: 1,
+			Result: []uint32{1},
 		},
 		{
 			Domain: "y.com",
-			Result: 0,
+			Result: nil,
 		},
 		{
 			Domain: "a.b.com",
-			Result: 4,
+			Result: []uint32{4},
 		},
 		{
 			Domain: "c.a.b.com",
-			Result: 4,
+			Result: []uint32{4},
 		},
 		{
 			Domain: "c.a..b.com",
-			Result: 0,
+			Result: nil,
 		},
 		{
 			Domain: ".com",
-			Result: 0,
+			Result: nil,
 		},
 		{
 			Domain: "com",
-			Result: 0,
+			Result: nil,
 		},
 		{
 			Domain: "",
-			Result: 0,
+			Result: nil,
+		},
+		{
+			Domain: "x.y.com",
+			Result: []uint32{4, 6},
 		},
 	}
 
 	for _, testCase := range testCases {
 		r := g.Match(testCase.Domain)
-		if r != testCase.Result {
+		if !reflect.DeepEqual(r, testCase.Result) {
 			t.Error("Failed to match domain: ", testCase.Domain, ", expect ", testCase.Result, ", but got ", r)
 		}
 	}
@@ -63,7 +70,7 @@ func TestDomainMatcherGroup(t *testing.T) {
 func TestEmptyDomainMatcherGroup(t *testing.T) {
 	g := new(DomainMatcherGroup)
 	r := g.Match("v2ray.com")
-	if r != 0 {
-		t.Error("Expect 0, but ", r)
+	if len(r) != 0 {
+		t.Error("Expect [], but ", r)
 	}
 }

+ 5 - 5
common/strmatcher/full_matcher.go

@@ -1,24 +1,24 @@
 package strmatcher
 
 type FullMatcherGroup struct {
-	matchers map[string]uint32
+	matchers map[string][]uint32
 }
 
 func (g *FullMatcherGroup) Add(domain string, value uint32) {
 	if g.matchers == nil {
-		g.matchers = make(map[string]uint32)
+		g.matchers = make(map[string][]uint32)
 	}
 
-	g.matchers[domain] = value
+	g.matchers[domain] = append(g.matchers[domain], value)
 }
 
 func (g *FullMatcherGroup) addMatcher(m fullMatcher, value uint32) {
 	g.Add(string(m), value)
 }
 
-func (g *FullMatcherGroup) Match(str string) uint32 {
+func (g *FullMatcherGroup) Match(str string) []uint32 {
 	if g.matchers == nil {
-		return 0
+		return nil
 	}
 
 	return g.matchers[str]

+ 13 - 6
common/strmatcher/full_matcher_test.go

@@ -1,6 +1,7 @@
 package strmatcher_test
 
 import (
+	"reflect"
 	"testing"
 
 	. "v2ray.com/core/common/strmatcher"
@@ -11,24 +12,30 @@ func TestFullMatcherGroup(t *testing.T) {
 	g.Add("v2ray.com", 1)
 	g.Add("google.com", 2)
 	g.Add("x.a.com", 3)
+	g.Add("x.y.com", 4)
+	g.Add("x.y.com", 6)
 
 	testCases := []struct {
 		Domain string
-		Result uint32
+		Result []uint32
 	}{
 		{
 			Domain: "v2ray.com",
-			Result: 1,
+			Result: []uint32{1},
 		},
 		{
 			Domain: "y.com",
-			Result: 0,
+			Result: nil,
+		},
+		{
+			Domain: "x.y.com",
+			Result: []uint32{4, 6},
 		},
 	}
 
 	for _, testCase := range testCases {
 		r := g.Match(testCase.Domain)
-		if r != testCase.Result {
+		if !reflect.DeepEqual(r, testCase.Result) {
 			t.Error("Failed to match domain: ", testCase.Domain, ", expect ", testCase.Result, ", but got ", r)
 		}
 	}
@@ -37,7 +44,7 @@ func TestFullMatcherGroup(t *testing.T) {
 func TestEmptyFullMatcherGroup(t *testing.T) {
 	g := new(FullMatcherGroup)
 	r := g.Match("v2ray.com")
-	if r != 0 {
-		t.Error("Expect 0, but ", r)
+	if len(r) != 0 {
+		t.Error("Expect [], but ", r)
 	}
 }

+ 7 - 13
common/strmatcher/strmatcher.go

@@ -49,7 +49,7 @@ func (t Type) New(pattern string) (Matcher, error) {
 // IndexMatcher is the interface for matching with a group of matchers.
 type IndexMatcher interface {
 	// Match returns the the index of a matcher that matches the input. It returns 0 if no such matcher exists.
-	Match(input string) uint32
+	Match(input string) []uint32
 }
 
 type matcherEntry struct {
@@ -87,22 +87,16 @@ func (g *MatcherGroup) Add(m Matcher) uint32 {
 }
 
 // Match implements IndexMatcher.Match.
-func (g *MatcherGroup) Match(pattern string) uint32 {
-	if c := g.fullMatcher.Match(pattern); c > 0 {
-		return c
-	}
-
-	if c := g.domainMatcher.Match(pattern); c > 0 {
-		return c
-	}
-
+func (g *MatcherGroup) Match(pattern string) []uint32 {
+	result := []uint32{}
+	result = append(result, g.fullMatcher.Match(pattern)...)
+	result = append(result, g.domainMatcher.Match(pattern)...)
 	for _, e := range g.otherMatchers {
 		if e.m.Match(pattern) {
-			return e.id
+			result = append(result, e.id)
 		}
 	}
-
-	return 0
+	return result
 }
 
 // Size returns the number of matchers in the MatcherGroup.