Browse Source

only add to ipIndexMap if GeoIP is configured

clearer logging for expectIPs

refactor dns init code

optimal to default port logic

clear message exit if doh met error
vcptr 6 years ago
parent
commit
ceb77ac8f5
5 changed files with 93 additions and 66 deletions
  1. 60 10
      app/dns/dohdns.go
  2. 1 0
      app/dns/nameserver.go
  3. 25 55
      app/dns/server.go
  4. 7 0
      app/dns/udpns.go
  5. 0 1
      infra/conf/dns.go

+ 60 - 10
app/dns/dohdns.go

@@ -40,10 +40,41 @@ type DoHNameServer struct {
 	name       string
 	name       string
 }
 }
 
 
-func NewDoHNameServer(dests []net.Destination, dohHost string, dispatcher routing.Dispatcher, clientIP net.IP) *DoHNameServer {
+func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) {
 
 
-	s := NewDoHLocalNameServer(dohHost, clientIP)
-	s.name = "DOH:" + dohHost
+	dohAddr := net.ParseAddress(dohHost)
+	var dests []net.Destination
+
+	if dohPort == 0 {
+		dohPort = 443
+	}
+
+	parseIPDest := func(ip net.IP, port uint32) net.Destination {
+		strIP := ip.String()
+		if len(ip) == net.IPv6len {
+			strIP = fmt.Sprintf("[%s]", strIP)
+		}
+		dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%d", strIP, port))
+		common.Must(err)
+		return dest
+	}
+
+	if dohAddr.Family().IsDomain() {
+		// resolve DOH server in advance
+		ips, err := net.LookupIP(dohAddr.Domain())
+		if err != nil || len(ips) == 0 {
+			return nil, err
+		}
+		for _, ip := range ips {
+			dests = append(dests, parseIPDest(ip, dohPort))
+		}
+	} else {
+		ip := dohAddr.IP()
+		dests = append(dests, parseIPDest(ip, dohPort))
+	}
+
+	newError("DNS: created remote DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
+	s := baseDOHNameServer(dohHost, dohPort, "DOH", clientIP)
 	s.dispatcher = dispatcher
 	s.dispatcher = dispatcher
 	s.dohDests = dests
 	s.dohDests = dests
 
 
@@ -66,22 +97,41 @@ func NewDoHNameServer(dests []net.Destination, dohHost string, dispatcher routin
 	}
 	}
 
 
 	s.httpClient = dispatchedClient
 	s.httpClient = dispatchedClient
+	return s, nil
+}
+
+func NewDoHLocalNameServer(dohHost string, dohPort uint32, clientIP net.IP) *DoHNameServer {
+
+	if dohPort == 0 {
+		dohPort = 443
+	}
+
+	s := baseDOHNameServer(dohHost, dohPort, "DOHL", clientIP)
+	s.httpClient = &http.Client{
+		Timeout: time.Second * 180,
+	}
+	newError("DNS: created local DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
 	return s
 	return s
 }
 }
 
 
-func NewDoHLocalNameServer(dohHost string, clientIP net.IP) *DoHNameServer {
+func baseDOHNameServer(dohHost string, dohPort uint32, prefix string, clientIP net.IP) *DoHNameServer {
+
+	if dohPort == 0 {
+		dohPort = 443
+	}
+
 	s := &DoHNameServer{
 	s := &DoHNameServer{
-		httpClient: http.DefaultClient,
-		ips:        make(map[string]record),
-		clientIP:   clientIP,
-		pub:        pubsub.NewService(),
-		name:       "DOHL:" + dohHost,
-		dohURL:     fmt.Sprintf("https://%s/dns-query", dohHost),
+		ips:      make(map[string]record),
+		clientIP: clientIP,
+		pub:      pubsub.NewService(),
+		name:     fmt.Sprintf("%s:%s:%d", prefix, dohHost, dohPort),
+		dohURL:   fmt.Sprintf("https://%s:%d/dns-query", dohHost, dohPort),
 	}
 	}
 	s.cleanup = &task.Periodic{
 	s.cleanup = &task.Periodic{
 		Interval: time.Minute,
 		Interval: time.Minute,
 		Execute:  s.Cleanup,
 		Execute:  s.Cleanup,
 	}
 	}
+
 	return s
 	return s
 }
 }
 
 

+ 1 - 0
app/dns/nameserver.go

@@ -49,6 +49,7 @@ func (s *localNameServer) Name() string {
 }
 }
 
 
 func NewLocalNameServer() *localNameServer {
 func NewLocalNameServer() *localNameServer {
+	newError("DNS: created localhost client").AtInfo().WriteToLog()
 	return &localNameServer{
 	return &localNameServer{
 		client: localdns.New(),
 		client: localdns.New(),
 	}
 	}

+ 25 - 55
app/dns/server.go

@@ -6,7 +6,7 @@ package dns
 
 
 import (
 import (
 	"context"
 	"context"
-	"fmt"
+	"log"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
@@ -41,7 +41,7 @@ type MultiGeoIPMatcher struct {
 	matchers []*router.GeoIPMatcher
 	matchers []*router.GeoIPMatcher
 }
 }
 
 
-var errExpectedIPNonMatch = errors.New("expected ip not match")
+var errExpectedIPNonMatch = errors.New("expectIPs not match")
 
 
 // Match check ip match
 // Match check ip match
 func (c *MultiGeoIPMatcher) Match(ip net.IP) bool {
 func (c *MultiGeoIPMatcher) Match(ip net.IP) bool {
@@ -73,7 +73,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 		server.tag = generateRandomTag()
 		server.tag = generateRandomTag()
 	}
 	}
 	if len(config.ClientIp) > 0 {
 	if len(config.ClientIp) > 0 {
-		if len(config.ClientIp) != 4 && len(config.ClientIp) != 16 {
+		if len(config.ClientIp) != net.IPv4len && len(config.ClientIp) != net.IPv6len {
 			return nil, newError("unexpected IP length", len(config.ClientIp))
 			return nil, newError("unexpected IP length", len(config.ClientIp))
 		}
 		}
 		server.clientIP = net.IP(config.ClientIp)
 		server.clientIP = net.IP(config.ClientIp)
@@ -89,48 +89,22 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 		address := endpoint.Address.AsAddress()
 		address := endpoint.Address.AsAddress()
 		if address.Family().IsDomain() && address.Domain() == "localhost" {
 		if address.Family().IsDomain() && address.Domain() == "localhost" {
 			server.clients = append(server.clients, NewLocalNameServer())
 			server.clients = append(server.clients, NewLocalNameServer())
-			newError("DNS: localhost inited").AtInfo().WriteToLog()
 		} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOHL_") {
 		} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOHL_") {
 			dohHost := address.Domain()[5:]
 			dohHost := address.Domain()[5:]
-			server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, server.clientIP))
-			newError("DNS: DOH - Local inited for https://", dohHost).AtInfo().WriteToLog()
+			server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, endpoint.Port, server.clientIP))
 		} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOH_") {
 		} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOH_") {
 			// DOH_ prefix makes net.Address think it's a domain
 			// DOH_ prefix makes net.Address think it's a domain
-			// need to process the real address here.
 			dohHost := address.Domain()[4:]
 			dohHost := address.Domain()[4:]
-			dohAddr := net.ParseAddress(dohHost)
-			dohIP := dohHost
-			var dests []net.Destination
-
-			if dohAddr.Family().IsDomain() {
-				// resolve DOH server in advance
-				ips, err := net.LookupIP(dohAddr.Domain())
-				if err != nil || len(ips) == 0 {
-					return 0
-				}
-				for _, ip := range ips {
-					dohIP := ip.String()
-					if len(ip) == net.IPv6len {
-						dohIP = fmt.Sprintf("[%s]", dohIP)
-					}
-					dohdest, _ := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP))
-					dests = append(dests, dohdest)
-				}
-			} else {
-				// rfc8484, DOH service only use port 443
-				dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP))
-				if err != nil {
-					return 0
-				}
-				dests = []net.Destination{dest}
-			}
-
-			// need the core dispatcher, register DOHClient at callback
 			idx := len(server.clients)
 			idx := len(server.clients)
 			server.clients = append(server.clients, nil)
 			server.clients = append(server.clients, nil)
+
+			// need the core dispatcher, register DOHClient at callback
 			common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) {
 			common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) {
-				server.clients[idx] = NewDoHNameServer(dests, dohHost, d, server.clientIP)
-				newError("DNS: DOH - Remote client inited for https://", dohHost).AtInfo().WriteToLog()
+				c, err := NewDoHNameServer(dohHost, endpoint.Port, d, server.clientIP)
+				if err != nil {
+					log.Fatalln(newError("DNS config error").Base(err))
+				}
+				server.clients[idx] = c
 			}))
 			}))
 		} else {
 		} else {
 			dest := endpoint.AsDestination()
 			dest := endpoint.AsDestination()
@@ -145,7 +119,6 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 					server.clients[idx] = NewClassicNameServer(dest, d, server.clientIP)
 					server.clients[idx] = NewClassicNameServer(dest, d, server.clientIP)
 				}))
 				}))
 			}
 			}
-			newError("DNS: UDP client inited for ", dest.NetAddr()).AtInfo().WriteToLog()
 		}
 		}
 		return len(server.clients) - 1
 		return len(server.clients) - 1
 	}
 	}
@@ -175,16 +148,19 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 				domainIndexMap[midx] = uint32(idx)
 				domainIndexMap[midx] = uint32(idx)
 			}
 			}
 
 
-			var matchers []*router.GeoIPMatcher
-			for _, geoip := range ns.Geoip {
-				matcher, err := geoIPMatcherContainer.Add(geoip)
-				if err != nil {
-					return nil, newError("failed to create ip matcher").Base(err).AtWarning()
+			// only add to ipIndexMap if GeoIP is configured
+			if len(ns.Geoip) > 0 {
+				var matchers []*router.GeoIPMatcher
+				for _, geoip := range ns.Geoip {
+					matcher, err := geoIPMatcherContainer.Add(geoip)
+					if err != nil {
+						return nil, newError("failed to create ip matcher").Base(err).AtWarning()
+					}
+					matchers = append(matchers, matcher)
 				}
 				}
-				matchers = append(matchers, matcher)
+				matcher := &MultiGeoIPMatcher{matchers: matchers}
+				ipIndexMap[uint32(idx)] = matcher
 			}
 			}
-			matcher := &MultiGeoIPMatcher{matchers: matchers}
-			ipIndexMap[uint32(idx)] = matcher
 		}
 		}
 
 
 		server.domainMatcher = domainMatcher
 		server.domainMatcher = domainMatcher
@@ -223,12 +199,11 @@ func (s *Server) IsOwnLink(ctx context.Context) bool {
 func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) ([]net.IP, error) {
 func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) ([]net.IP, error) {
 	matcher, exist := s.ipIndexMap[idx]
 	matcher, exist := s.ipIndexMap[idx]
 	if !exist {
 	if !exist {
-		newError("domain ", domain, " server not in ipIndexMap: ", client.Name(), " idx:", idx, " just return").AtDebug().WriteToLog()
 		return ips, nil
 		return ips, nil
 	}
 	}
 
 
 	if !matcher.HasMatcher() {
 	if !matcher.HasMatcher() {
-		newError("domain ", domain, " server has not valid matcher: ", client.Name(), " idx:", idx, " just return").AtDebug().WriteToLog()
+		newError("domain ", domain, " server has no valid matcher: ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 		return ips, nil
 		return ips, nil
 	}
 	}
 
 
@@ -236,14 +211,12 @@ func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) (
 	for _, ip := range ips {
 	for _, ip := range ips {
 		if matcher.Match(ip) {
 		if matcher.Match(ip) {
 			newIps = append(newIps, ip)
 			newIps = append(newIps, ip)
-			newError("domain ", domain, " ip ", ip, " is match at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
-		} else {
-			newError("domain ", domain, " ip ", ip, " is not match at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 		}
 		}
 	}
 	}
 	if len(newIps) == 0 {
 	if len(newIps) == 0 {
 		return nil, errExpectedIPNonMatch
 		return nil, errExpectedIPNonMatch
 	}
 	}
+	newError("domain ", domain, " expectIPs ", newIps, " matched at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 	return newIps, nil
 	return newIps, nil
 }
 }
 
 
@@ -325,7 +298,7 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 
 
 	// skip domain without any dot
 	// skip domain without any dot
 	if strings.Index(domain, ".") == -1 {
 	if strings.Index(domain, ".") == -1 {
-		return nil, newError("invalid domain name")
+		return nil, newError("invalid domain name").AtWarning()
 	}
 	}
 
 
 	ips := s.lookupStatic(domain, option, 0)
 	ips := s.lookupStatic(domain, option, 0)
@@ -346,7 +319,6 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 		idx := s.domainMatcher.Match(domain)
 		idx := s.domainMatcher.Match(domain)
 		if idx > 0 {
 		if idx > 0 {
 			matchedClient = s.clients[s.domainIndexMap[idx]]
 			matchedClient = s.clients[s.domainIndexMap[idx]]
-			newError("domain matched, direct lookup ip for domain ", domain, " at ", matchedClient.Name()).WriteToLog()
 			ips, err := s.queryIPTimeout(s.domainIndexMap[idx], matchedClient, domain, option)
 			ips, err := s.queryIPTimeout(s.domainIndexMap[idx], matchedClient, domain, option)
 			if len(ips) > 0 {
 			if len(ips) > 0 {
 				return ips, nil
 				return ips, nil
@@ -367,10 +339,8 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 			continue
 			continue
 		}
 		}
 
 
-		newError("try to lookup ip for domain ", domain, " at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 		ips, err := s.queryIPTimeout(uint32(idx), client, domain, option)
 		ips, err := s.queryIPTimeout(uint32(idx), client, domain, option)
 		if len(ips) > 0 {
 		if len(ips) > 0 {
-			newError("lookup ip for domain ", domain, " success: ", ips, " at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 			return ips, nil
 			return ips, nil
 		}
 		}
 
 

+ 7 - 0
app/dns/udpns.go

@@ -36,6 +36,12 @@ type ClassicNameServer struct {
 }
 }
 
 
 func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
 func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
+
+	// default to 53 if unspecific
+	if address.Port == 0 {
+		address.Port = net.Port(53)
+	}
+
 	s := &ClassicNameServer{
 	s := &ClassicNameServer{
 		address:  address,
 		address:  address,
 		ips:      make(map[string]record),
 		ips:      make(map[string]record),
@@ -49,6 +55,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
 		Execute:  s.Cleanup,
 		Execute:  s.Cleanup,
 	}
 	}
 	s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
 	s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
+	newError("DNS: created udp client inited for ", address.NetAddr()).AtInfo().WriteToLog()
 	return s
 	return s
 }
 }
 
 

+ 0 - 1
infra/conf/dns.go

@@ -21,7 +21,6 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error {
 	var address Address
 	var address Address
 	if err := json.Unmarshal(data, &address); err == nil {
 	if err := json.Unmarshal(data, &address); err == nil {
 		c.Address = &address
 		c.Address = &address
-		c.Port = 53
 		return nil
 		return nil
 	}
 	}