Browse Source

only add to ipIndexMap if GeoIP is configured

clearer logging for expectIPs

refactor dns init code

optimal to default port logic

clear message exit if doh met error
vcptr 6 years ago
parent
commit
ceb77ac8f5
5 changed files with 93 additions and 66 deletions
  1. 60 10
      app/dns/dohdns.go
  2. 1 0
      app/dns/nameserver.go
  3. 25 55
      app/dns/server.go
  4. 7 0
      app/dns/udpns.go
  5. 0 1
      infra/conf/dns.go

+ 60 - 10
app/dns/dohdns.go

@@ -40,10 +40,41 @@ type DoHNameServer struct {
 	name       string
 }
 
-func NewDoHNameServer(dests []net.Destination, dohHost string, dispatcher routing.Dispatcher, clientIP net.IP) *DoHNameServer {
+func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) {
 
-	s := NewDoHLocalNameServer(dohHost, clientIP)
-	s.name = "DOH:" + dohHost
+	dohAddr := net.ParseAddress(dohHost)
+	var dests []net.Destination
+
+	if dohPort == 0 {
+		dohPort = 443
+	}
+
+	parseIPDest := func(ip net.IP, port uint32) net.Destination {
+		strIP := ip.String()
+		if len(ip) == net.IPv6len {
+			strIP = fmt.Sprintf("[%s]", strIP)
+		}
+		dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%d", strIP, port))
+		common.Must(err)
+		return dest
+	}
+
+	if dohAddr.Family().IsDomain() {
+		// resolve DOH server in advance
+		ips, err := net.LookupIP(dohAddr.Domain())
+		if err != nil || len(ips) == 0 {
+			return nil, err
+		}
+		for _, ip := range ips {
+			dests = append(dests, parseIPDest(ip, dohPort))
+		}
+	} else {
+		ip := dohAddr.IP()
+		dests = append(dests, parseIPDest(ip, dohPort))
+	}
+
+	newError("DNS: created remote DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
+	s := baseDOHNameServer(dohHost, dohPort, "DOH", clientIP)
 	s.dispatcher = dispatcher
 	s.dohDests = dests
 
@@ -66,22 +97,41 @@ func NewDoHNameServer(dests []net.Destination, dohHost string, dispatcher routin
 	}
 
 	s.httpClient = dispatchedClient
+	return s, nil
+}
+
+func NewDoHLocalNameServer(dohHost string, dohPort uint32, clientIP net.IP) *DoHNameServer {
+
+	if dohPort == 0 {
+		dohPort = 443
+	}
+
+	s := baseDOHNameServer(dohHost, dohPort, "DOHL", clientIP)
+	s.httpClient = &http.Client{
+		Timeout: time.Second * 180,
+	}
+	newError("DNS: created local DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
 	return s
 }
 
-func NewDoHLocalNameServer(dohHost string, clientIP net.IP) *DoHNameServer {
+func baseDOHNameServer(dohHost string, dohPort uint32, prefix string, clientIP net.IP) *DoHNameServer {
+
+	if dohPort == 0 {
+		dohPort = 443
+	}
+
 	s := &DoHNameServer{
-		httpClient: http.DefaultClient,
-		ips:        make(map[string]record),
-		clientIP:   clientIP,
-		pub:        pubsub.NewService(),
-		name:       "DOHL:" + dohHost,
-		dohURL:     fmt.Sprintf("https://%s/dns-query", dohHost),
+		ips:      make(map[string]record),
+		clientIP: clientIP,
+		pub:      pubsub.NewService(),
+		name:     fmt.Sprintf("%s:%s:%d", prefix, dohHost, dohPort),
+		dohURL:   fmt.Sprintf("https://%s:%d/dns-query", dohHost, dohPort),
 	}
 	s.cleanup = &task.Periodic{
 		Interval: time.Minute,
 		Execute:  s.Cleanup,
 	}
+
 	return s
 }
 

+ 1 - 0
app/dns/nameserver.go

@@ -49,6 +49,7 @@ func (s *localNameServer) Name() string {
 }
 
 func NewLocalNameServer() *localNameServer {
+	newError("DNS: created localhost client").AtInfo().WriteToLog()
 	return &localNameServer{
 		client: localdns.New(),
 	}

+ 25 - 55
app/dns/server.go

@@ -6,7 +6,7 @@ package dns
 
 import (
 	"context"
-	"fmt"
+	"log"
 	"strings"
 	"sync"
 	"time"
@@ -41,7 +41,7 @@ type MultiGeoIPMatcher struct {
 	matchers []*router.GeoIPMatcher
 }
 
-var errExpectedIPNonMatch = errors.New("expected ip not match")
+var errExpectedIPNonMatch = errors.New("expectIPs not match")
 
 // Match check ip match
 func (c *MultiGeoIPMatcher) Match(ip net.IP) bool {
@@ -73,7 +73,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 		server.tag = generateRandomTag()
 	}
 	if len(config.ClientIp) > 0 {
-		if len(config.ClientIp) != 4 && len(config.ClientIp) != 16 {
+		if len(config.ClientIp) != net.IPv4len && len(config.ClientIp) != net.IPv6len {
 			return nil, newError("unexpected IP length", len(config.ClientIp))
 		}
 		server.clientIP = net.IP(config.ClientIp)
@@ -89,48 +89,22 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 		address := endpoint.Address.AsAddress()
 		if address.Family().IsDomain() && address.Domain() == "localhost" {
 			server.clients = append(server.clients, NewLocalNameServer())
-			newError("DNS: localhost inited").AtInfo().WriteToLog()
 		} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOHL_") {
 			dohHost := address.Domain()[5:]
-			server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, server.clientIP))
-			newError("DNS: DOH - Local inited for https://", dohHost).AtInfo().WriteToLog()
+			server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, endpoint.Port, server.clientIP))
 		} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOH_") {
 			// DOH_ prefix makes net.Address think it's a domain
-			// need to process the real address here.
 			dohHost := address.Domain()[4:]
-			dohAddr := net.ParseAddress(dohHost)
-			dohIP := dohHost
-			var dests []net.Destination
-
-			if dohAddr.Family().IsDomain() {
-				// resolve DOH server in advance
-				ips, err := net.LookupIP(dohAddr.Domain())
-				if err != nil || len(ips) == 0 {
-					return 0
-				}
-				for _, ip := range ips {
-					dohIP := ip.String()
-					if len(ip) == net.IPv6len {
-						dohIP = fmt.Sprintf("[%s]", dohIP)
-					}
-					dohdest, _ := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP))
-					dests = append(dests, dohdest)
-				}
-			} else {
-				// rfc8484, DOH service only use port 443
-				dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP))
-				if err != nil {
-					return 0
-				}
-				dests = []net.Destination{dest}
-			}
-
-			// need the core dispatcher, register DOHClient at callback
 			idx := len(server.clients)
 			server.clients = append(server.clients, nil)
+
+			// need the core dispatcher, register DOHClient at callback
 			common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) {
-				server.clients[idx] = NewDoHNameServer(dests, dohHost, d, server.clientIP)
-				newError("DNS: DOH - Remote client inited for https://", dohHost).AtInfo().WriteToLog()
+				c, err := NewDoHNameServer(dohHost, endpoint.Port, d, server.clientIP)
+				if err != nil {
+					log.Fatalln(newError("DNS config error").Base(err))
+				}
+				server.clients[idx] = c
 			}))
 		} else {
 			dest := endpoint.AsDestination()
@@ -145,7 +119,6 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 					server.clients[idx] = NewClassicNameServer(dest, d, server.clientIP)
 				}))
 			}
-			newError("DNS: UDP client inited for ", dest.NetAddr()).AtInfo().WriteToLog()
 		}
 		return len(server.clients) - 1
 	}
@@ -175,16 +148,19 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 				domainIndexMap[midx] = uint32(idx)
 			}
 
-			var matchers []*router.GeoIPMatcher
-			for _, geoip := range ns.Geoip {
-				matcher, err := geoIPMatcherContainer.Add(geoip)
-				if err != nil {
-					return nil, newError("failed to create ip matcher").Base(err).AtWarning()
+			// only add to ipIndexMap if GeoIP is configured
+			if len(ns.Geoip) > 0 {
+				var matchers []*router.GeoIPMatcher
+				for _, geoip := range ns.Geoip {
+					matcher, err := geoIPMatcherContainer.Add(geoip)
+					if err != nil {
+						return nil, newError("failed to create ip matcher").Base(err).AtWarning()
+					}
+					matchers = append(matchers, matcher)
 				}
-				matchers = append(matchers, matcher)
+				matcher := &MultiGeoIPMatcher{matchers: matchers}
+				ipIndexMap[uint32(idx)] = matcher
 			}
-			matcher := &MultiGeoIPMatcher{matchers: matchers}
-			ipIndexMap[uint32(idx)] = matcher
 		}
 
 		server.domainMatcher = domainMatcher
@@ -223,12 +199,11 @@ func (s *Server) IsOwnLink(ctx context.Context) bool {
 func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) ([]net.IP, error) {
 	matcher, exist := s.ipIndexMap[idx]
 	if !exist {
-		newError("domain ", domain, " server not in ipIndexMap: ", client.Name(), " idx:", idx, " just return").AtDebug().WriteToLog()
 		return ips, nil
 	}
 
 	if !matcher.HasMatcher() {
-		newError("domain ", domain, " server has not valid matcher: ", client.Name(), " idx:", idx, " just return").AtDebug().WriteToLog()
+		newError("domain ", domain, " server has no valid matcher: ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 		return ips, nil
 	}
 
@@ -236,14 +211,12 @@ func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) (
 	for _, ip := range ips {
 		if matcher.Match(ip) {
 			newIps = append(newIps, ip)
-			newError("domain ", domain, " ip ", ip, " is match at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
-		} else {
-			newError("domain ", domain, " ip ", ip, " is not match at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 		}
 	}
 	if len(newIps) == 0 {
 		return nil, errExpectedIPNonMatch
 	}
+	newError("domain ", domain, " expectIPs ", newIps, " matched at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 	return newIps, nil
 }
 
@@ -325,7 +298,7 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 
 	// skip domain without any dot
 	if strings.Index(domain, ".") == -1 {
-		return nil, newError("invalid domain name")
+		return nil, newError("invalid domain name").AtWarning()
 	}
 
 	ips := s.lookupStatic(domain, option, 0)
@@ -346,7 +319,6 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 		idx := s.domainMatcher.Match(domain)
 		if idx > 0 {
 			matchedClient = s.clients[s.domainIndexMap[idx]]
-			newError("domain matched, direct lookup ip for domain ", domain, " at ", matchedClient.Name()).WriteToLog()
 			ips, err := s.queryIPTimeout(s.domainIndexMap[idx], matchedClient, domain, option)
 			if len(ips) > 0 {
 				return ips, nil
@@ -367,10 +339,8 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 			continue
 		}
 
-		newError("try to lookup ip for domain ", domain, " at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 		ips, err := s.queryIPTimeout(uint32(idx), client, domain, option)
 		if len(ips) > 0 {
-			newError("lookup ip for domain ", domain, " success: ", ips, " at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 			return ips, nil
 		}
 

+ 7 - 0
app/dns/udpns.go

@@ -36,6 +36,12 @@ type ClassicNameServer struct {
 }
 
 func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
+
+	// default to 53 if unspecific
+	if address.Port == 0 {
+		address.Port = net.Port(53)
+	}
+
 	s := &ClassicNameServer{
 		address:  address,
 		ips:      make(map[string]record),
@@ -49,6 +55,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
 		Execute:  s.Cleanup,
 	}
 	s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
+	newError("DNS: created udp client inited for ", address.NetAddr()).AtInfo().WriteToLog()
 	return s
 }
 

+ 0 - 1
infra/conf/dns.go

@@ -21,7 +21,6 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error {
 	var address Address
 	if err := json.Unmarshal(data, &address); err == nil {
 		c.Address = &address
-		c.Port = 53
 		return nil
 	}