소스 검색

only add to ipIndexMap if GeoIP is configured

clearer logging for expectIPs

refactor dns init code

optimal to default port logic

clear message exit if doh met error
vcptr 6 년 전
부모
커밋
ceb77ac8f5
5개의 변경된 파일93개의 추가작업 그리고 66개의 파일을 삭제
  1. 60 10
      app/dns/dohdns.go
  2. 1 0
      app/dns/nameserver.go
  3. 25 55
      app/dns/server.go
  4. 7 0
      app/dns/udpns.go
  5. 0 1
      infra/conf/dns.go

+ 60 - 10
app/dns/dohdns.go

@@ -40,10 +40,41 @@ type DoHNameServer struct {
 	name       string
 }
 
-func NewDoHNameServer(dests []net.Destination, dohHost string, dispatcher routing.Dispatcher, clientIP net.IP) *DoHNameServer {
+func NewDoHNameServer(dohHost string, dohPort uint32, dispatcher routing.Dispatcher, clientIP net.IP) (*DoHNameServer, error) {
 
-	s := NewDoHLocalNameServer(dohHost, clientIP)
-	s.name = "DOH:" + dohHost
+	dohAddr := net.ParseAddress(dohHost)
+	var dests []net.Destination
+
+	if dohPort == 0 {
+		dohPort = 443
+	}
+
+	parseIPDest := func(ip net.IP, port uint32) net.Destination {
+		strIP := ip.String()
+		if len(ip) == net.IPv6len {
+			strIP = fmt.Sprintf("[%s]", strIP)
+		}
+		dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:%d", strIP, port))
+		common.Must(err)
+		return dest
+	}
+
+	if dohAddr.Family().IsDomain() {
+		// resolve DOH server in advance
+		ips, err := net.LookupIP(dohAddr.Domain())
+		if err != nil || len(ips) == 0 {
+			return nil, err
+		}
+		for _, ip := range ips {
+			dests = append(dests, parseIPDest(ip, dohPort))
+		}
+	} else {
+		ip := dohAddr.IP()
+		dests = append(dests, parseIPDest(ip, dohPort))
+	}
+
+	newError("DNS: created remote DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
+	s := baseDOHNameServer(dohHost, dohPort, "DOH", clientIP)
 	s.dispatcher = dispatcher
 	s.dohDests = dests
 
@@ -66,22 +97,41 @@ func NewDoHNameServer(dests []net.Destination, dohHost string, dispatcher routin
 	}
 
 	s.httpClient = dispatchedClient
+	return s, nil
+}
+
+func NewDoHLocalNameServer(dohHost string, dohPort uint32, clientIP net.IP) *DoHNameServer {
+
+	if dohPort == 0 {
+		dohPort = 443
+	}
+
+	s := baseDOHNameServer(dohHost, dohPort, "DOHL", clientIP)
+	s.httpClient = &http.Client{
+		Timeout: time.Second * 180,
+	}
+	newError("DNS: created local DOH client for https://", dohHost, ":", dohPort).AtInfo().WriteToLog()
 	return s
 }
 
-func NewDoHLocalNameServer(dohHost string, clientIP net.IP) *DoHNameServer {
+func baseDOHNameServer(dohHost string, dohPort uint32, prefix string, clientIP net.IP) *DoHNameServer {
+
+	if dohPort == 0 {
+		dohPort = 443
+	}
+
 	s := &DoHNameServer{
-		httpClient: http.DefaultClient,
-		ips:        make(map[string]record),
-		clientIP:   clientIP,
-		pub:        pubsub.NewService(),
-		name:       "DOHL:" + dohHost,
-		dohURL:     fmt.Sprintf("https://%s/dns-query", dohHost),
+		ips:      make(map[string]record),
+		clientIP: clientIP,
+		pub:      pubsub.NewService(),
+		name:     fmt.Sprintf("%s:%s:%d", prefix, dohHost, dohPort),
+		dohURL:   fmt.Sprintf("https://%s:%d/dns-query", dohHost, dohPort),
 	}
 	s.cleanup = &task.Periodic{
 		Interval: time.Minute,
 		Execute:  s.Cleanup,
 	}
+
 	return s
 }
 

+ 1 - 0
app/dns/nameserver.go

@@ -49,6 +49,7 @@ func (s *localNameServer) Name() string {
 }
 
 func NewLocalNameServer() *localNameServer {
+	newError("DNS: created localhost client").AtInfo().WriteToLog()
 	return &localNameServer{
 		client: localdns.New(),
 	}

+ 25 - 55
app/dns/server.go

@@ -6,7 +6,7 @@ package dns
 
 import (
 	"context"
-	"fmt"
+	"log"
 	"strings"
 	"sync"
 	"time"
@@ -41,7 +41,7 @@ type MultiGeoIPMatcher struct {
 	matchers []*router.GeoIPMatcher
 }
 
-var errExpectedIPNonMatch = errors.New("expected ip not match")
+var errExpectedIPNonMatch = errors.New("expectIPs not match")
 
 // Match check ip match
 func (c *MultiGeoIPMatcher) Match(ip net.IP) bool {
@@ -73,7 +73,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 		server.tag = generateRandomTag()
 	}
 	if len(config.ClientIp) > 0 {
-		if len(config.ClientIp) != 4 && len(config.ClientIp) != 16 {
+		if len(config.ClientIp) != net.IPv4len && len(config.ClientIp) != net.IPv6len {
 			return nil, newError("unexpected IP length", len(config.ClientIp))
 		}
 		server.clientIP = net.IP(config.ClientIp)
@@ -89,48 +89,22 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 		address := endpoint.Address.AsAddress()
 		if address.Family().IsDomain() && address.Domain() == "localhost" {
 			server.clients = append(server.clients, NewLocalNameServer())
-			newError("DNS: localhost inited").AtInfo().WriteToLog()
 		} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOHL_") {
 			dohHost := address.Domain()[5:]
-			server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, server.clientIP))
-			newError("DNS: DOH - Local inited for https://", dohHost).AtInfo().WriteToLog()
+			server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, endpoint.Port, server.clientIP))
 		} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOH_") {
 			// DOH_ prefix makes net.Address think it's a domain
-			// need to process the real address here.
 			dohHost := address.Domain()[4:]
-			dohAddr := net.ParseAddress(dohHost)
-			dohIP := dohHost
-			var dests []net.Destination
-
-			if dohAddr.Family().IsDomain() {
-				// resolve DOH server in advance
-				ips, err := net.LookupIP(dohAddr.Domain())
-				if err != nil || len(ips) == 0 {
-					return 0
-				}
-				for _, ip := range ips {
-					dohIP := ip.String()
-					if len(ip) == net.IPv6len {
-						dohIP = fmt.Sprintf("[%s]", dohIP)
-					}
-					dohdest, _ := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP))
-					dests = append(dests, dohdest)
-				}
-			} else {
-				// rfc8484, DOH service only use port 443
-				dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP))
-				if err != nil {
-					return 0
-				}
-				dests = []net.Destination{dest}
-			}
-
-			// need the core dispatcher, register DOHClient at callback
 			idx := len(server.clients)
 			server.clients = append(server.clients, nil)
+
+			// need the core dispatcher, register DOHClient at callback
 			common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) {
-				server.clients[idx] = NewDoHNameServer(dests, dohHost, d, server.clientIP)
-				newError("DNS: DOH - Remote client inited for https://", dohHost).AtInfo().WriteToLog()
+				c, err := NewDoHNameServer(dohHost, endpoint.Port, d, server.clientIP)
+				if err != nil {
+					log.Fatalln(newError("DNS config error").Base(err))
+				}
+				server.clients[idx] = c
 			}))
 		} else {
 			dest := endpoint.AsDestination()
@@ -145,7 +119,6 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 					server.clients[idx] = NewClassicNameServer(dest, d, server.clientIP)
 				}))
 			}
-			newError("DNS: UDP client inited for ", dest.NetAddr()).AtInfo().WriteToLog()
 		}
 		return len(server.clients) - 1
 	}
@@ -175,16 +148,19 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 				domainIndexMap[midx] = uint32(idx)
 			}
 
-			var matchers []*router.GeoIPMatcher
-			for _, geoip := range ns.Geoip {
-				matcher, err := geoIPMatcherContainer.Add(geoip)
-				if err != nil {
-					return nil, newError("failed to create ip matcher").Base(err).AtWarning()
+			// only add to ipIndexMap if GeoIP is configured
+			if len(ns.Geoip) > 0 {
+				var matchers []*router.GeoIPMatcher
+				for _, geoip := range ns.Geoip {
+					matcher, err := geoIPMatcherContainer.Add(geoip)
+					if err != nil {
+						return nil, newError("failed to create ip matcher").Base(err).AtWarning()
+					}
+					matchers = append(matchers, matcher)
 				}
-				matchers = append(matchers, matcher)
+				matcher := &MultiGeoIPMatcher{matchers: matchers}
+				ipIndexMap[uint32(idx)] = matcher
 			}
-			matcher := &MultiGeoIPMatcher{matchers: matchers}
-			ipIndexMap[uint32(idx)] = matcher
 		}
 
 		server.domainMatcher = domainMatcher
@@ -223,12 +199,11 @@ func (s *Server) IsOwnLink(ctx context.Context) bool {
 func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) ([]net.IP, error) {
 	matcher, exist := s.ipIndexMap[idx]
 	if !exist {
-		newError("domain ", domain, " server not in ipIndexMap: ", client.Name(), " idx:", idx, " just return").AtDebug().WriteToLog()
 		return ips, nil
 	}
 
 	if !matcher.HasMatcher() {
-		newError("domain ", domain, " server has not valid matcher: ", client.Name(), " idx:", idx, " just return").AtDebug().WriteToLog()
+		newError("domain ", domain, " server has no valid matcher: ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 		return ips, nil
 	}
 
@@ -236,14 +211,12 @@ func (s *Server) Match(idx uint32, client Client, domain string, ips []net.IP) (
 	for _, ip := range ips {
 		if matcher.Match(ip) {
 			newIps = append(newIps, ip)
-			newError("domain ", domain, " ip ", ip, " is match at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
-		} else {
-			newError("domain ", domain, " ip ", ip, " is not match at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 		}
 	}
 	if len(newIps) == 0 {
 		return nil, errExpectedIPNonMatch
 	}
+	newError("domain ", domain, " expectIPs ", newIps, " matched at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 	return newIps, nil
 }
 
@@ -325,7 +298,7 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 
 	// skip domain without any dot
 	if strings.Index(domain, ".") == -1 {
-		return nil, newError("invalid domain name")
+		return nil, newError("invalid domain name").AtWarning()
 	}
 
 	ips := s.lookupStatic(domain, option, 0)
@@ -346,7 +319,6 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 		idx := s.domainMatcher.Match(domain)
 		if idx > 0 {
 			matchedClient = s.clients[s.domainIndexMap[idx]]
-			newError("domain matched, direct lookup ip for domain ", domain, " at ", matchedClient.Name()).WriteToLog()
 			ips, err := s.queryIPTimeout(s.domainIndexMap[idx], matchedClient, domain, option)
 			if len(ips) > 0 {
 				return ips, nil
@@ -367,10 +339,8 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 			continue
 		}
 
-		newError("try to lookup ip for domain ", domain, " at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 		ips, err := s.queryIPTimeout(uint32(idx), client, domain, option)
 		if len(ips) > 0 {
-			newError("lookup ip for domain ", domain, " success: ", ips, " at server ", client.Name(), " idx:", idx).AtDebug().WriteToLog()
 			return ips, nil
 		}
 

+ 7 - 0
app/dns/udpns.go

@@ -36,6 +36,12 @@ type ClassicNameServer struct {
 }
 
 func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
+
+	// default to 53 if unspecific
+	if address.Port == 0 {
+		address.Port = net.Port(53)
+	}
+
 	s := &ClassicNameServer{
 		address:  address,
 		ips:      make(map[string]record),
@@ -49,6 +55,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
 		Execute:  s.Cleanup,
 	}
 	s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse)
+	newError("DNS: created udp client inited for ", address.NetAddr()).AtInfo().WriteToLog()
 	return s
 }
 

+ 0 - 1
infra/conf/dns.go

@@ -21,7 +21,6 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error {
 	var address Address
 	if err := json.Unmarshal(data, &address); err == nil {
 		c.Address = &address
-		c.Port = 53
 		return nil
 	}