Browse Source

DNS: fix typo & refine code (#1183)

Co-authored-by: loyalsoldier <10487845+Loyalsoldier@users.noreply.github.com>
rurirei 4 years ago
parent
commit
73470e8dd8
5 changed files with 111 additions and 106 deletions
  1. 1 1
      app/dns/dnscommon.go
  2. 23 24
      app/dns/nameserver_doh.go
  3. 24 25
      app/dns/nameserver_quic.go
  4. 28 25
      app/dns/nameserver_tcp.go
  5. 35 31
      app/dns/nameserver_udp.go

+ 1 - 1
app/dns/dnscommon.go

@@ -213,7 +213,7 @@ L:
 		case dnsmessage.TypeAAAA:
 		case dnsmessage.TypeAAAA:
 			ans, err := parser.AAAAResource()
 			ans, err := parser.AAAAResource()
 			if err != nil {
 			if err != nil {
-				newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
+				newError("failed to parse AAAA record for domain: ", ah.Name).Base(err).WriteToLog()
 				break L
 				break L
 			}
 			}
 			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
 			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))

+ 23 - 24
app/dns/nameserver_doh.go

@@ -32,7 +32,7 @@ import (
 // thus most of the DOH implementation is copied from udpns.go
 // thus most of the DOH implementation is copied from udpns.go
 type DoHNameServer struct {
 type DoHNameServer struct {
 	sync.RWMutex
 	sync.RWMutex
-	ips        map[string]record
+	ips        map[string]*record
 	pub        *pubsub.Service
 	pub        *pubsub.Service
 	cleanup    *task.Periodic
 	cleanup    *task.Periodic
 	reqID      uint32
 	reqID      uint32
@@ -112,7 +112,7 @@ func NewDoHLocalNameServer(url *url.URL) *DoHNameServer {
 
 
 func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer {
 func baseDOHNameServer(url *url.URL, prefix string) *DoHNameServer {
 	s := &DoHNameServer{
 	s := &DoHNameServer{
-		ips:    make(map[string]record),
+		ips:    make(map[string]*record),
 		pub:    pubsub.NewService(),
 		pub:    pubsub.NewService(),
 		name:   prefix + "//" + url.Host,
 		name:   prefix + "//" + url.Host,
 		dohURL: url.String(),
 		dohURL: url.String(),
@@ -156,7 +156,7 @@ func (s *DoHNameServer) Cleanup() error {
 	}
 	}
 
 
 	if len(s.ips) == 0 {
 	if len(s.ips) == 0 {
-		s.ips = make(map[string]record)
+		s.ips = make(map[string]*record)
 	}
 	}
 
 
 	return nil
 	return nil
@@ -166,7 +166,10 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
 	elapsed := time.Since(req.start)
 	elapsed := time.Since(req.start)
 
 
 	s.Lock()
 	s.Lock()
-	rec := s.ips[req.domain]
+	rec, found := s.ips[req.domain]
+	if !found {
+		rec = &record{}
+	}
 	updated := false
 	updated := false
 
 
 	switch req.reqType {
 	switch req.reqType {
@@ -176,7 +179,7 @@ func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
 			updated = true
 			updated = true
 		}
 		}
 	case dnsmessage.TypeAAAA:
 	case dnsmessage.TypeAAAA:
-		addr := make([]net.Address, 0)
+		addr := make([]net.Address, 0, len(ipRec.IP))
 		for _, ip := range ipRec.IP {
 		for _, ip := range ipRec.IP {
 			if len(ip.IP()) == net.IPv6len {
 			if len(ip.IP()) == net.IPv6len {
 				addr = append(addr, ip)
 				addr = append(addr, ip)
@@ -295,34 +298,30 @@ func (s *DoHNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt
 		return nil, errRecordNotFound
 		return nil, errRecordNotFound
 	}
 	}
 
 
+	var err4 error
+	var err6 error
 	var ips []net.Address
 	var ips []net.Address
-	var lastErr error
-	if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess {
-		aaaa, err := record.AAAA.getIPs()
-		if err != nil {
-			lastErr = err
-		}
-		ips = append(ips, aaaa...)
-	}
-
-	if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess {
-		a, err := record.A.getIPs()
-		if err != nil {
-			lastErr = err
-		}
-		ips = append(ips, a...)
+	var ip6 []net.Address
+
+	switch {
+	case option.IPv4Enable:
+		ips, err4 = record.A.getIPs()
+		fallthrough
+	case option.IPv6Enable:
+		ip6, err6 = record.AAAA.getIPs()
+		ips = append(ips, ip6...)
 	}
 	}
 
 
 	if len(ips) > 0 {
 	if len(ips) > 0 {
 		return toNetIP(ips)
 		return toNetIP(ips)
 	}
 	}
 
 
-	if lastErr != nil {
-		return nil, lastErr
+	if err4 != nil {
+		return nil, err4
 	}
 	}
 
 
-	if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
-		return nil, dns_feature.ErrEmptyResponse
+	if err6 != nil {
+		return nil, err6
 	}
 	}
 
 
 	return nil, errRecordNotFound
 	return nil, errRecordNotFound

+ 24 - 25
app/dns/nameserver_quic.go

@@ -33,12 +33,12 @@ const handshakeIdleTimeout = time.Second * 8
 // QUICNameServer implemented DNS over QUIC
 // QUICNameServer implemented DNS over QUIC
 type QUICNameServer struct {
 type QUICNameServer struct {
 	sync.RWMutex
 	sync.RWMutex
-	ips         map[string]record
+	ips         map[string]*record
 	pub         *pubsub.Service
 	pub         *pubsub.Service
 	cleanup     *task.Periodic
 	cleanup     *task.Periodic
 	reqID       uint32
 	reqID       uint32
 	name        string
 	name        string
-	destination net.Destination
+	destination *net.Destination
 	session     quic.Session
 	session     quic.Session
 }
 }
 
 
@@ -57,10 +57,10 @@ func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) {
 	dest := net.UDPDestination(net.DomainAddress(url.Hostname()), port)
 	dest := net.UDPDestination(net.DomainAddress(url.Hostname()), port)
 
 
 	s := &QUICNameServer{
 	s := &QUICNameServer{
-		ips:         make(map[string]record),
+		ips:         make(map[string]*record),
 		pub:         pubsub.NewService(),
 		pub:         pubsub.NewService(),
 		name:        url.String(),
 		name:        url.String(),
-		destination: dest,
+		destination: &dest,
 	}
 	}
 	s.cleanup = &task.Periodic{
 	s.cleanup = &task.Periodic{
 		Interval: time.Minute,
 		Interval: time.Minute,
@@ -102,7 +102,7 @@ func (s *QUICNameServer) Cleanup() error {
 	}
 	}
 
 
 	if len(s.ips) == 0 {
 	if len(s.ips) == 0 {
-		s.ips = make(map[string]record)
+		s.ips = make(map[string]*record)
 	}
 	}
 
 
 	return nil
 	return nil
@@ -112,7 +112,10 @@ func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
 	elapsed := time.Since(req.start)
 	elapsed := time.Since(req.start)
 
 
 	s.Lock()
 	s.Lock()
-	rec := s.ips[req.domain]
+	rec, found := s.ips[req.domain]
+	if !found {
+		rec = &record{}
+	}
 	updated := false
 	updated := false
 
 
 	switch req.reqType {
 	switch req.reqType {
@@ -232,34 +235,30 @@ func (s *QUICNameServer) findIPsForDomain(domain string, option dns_feature.IPOp
 		return nil, errRecordNotFound
 		return nil, errRecordNotFound
 	}
 	}
 
 
+	var err4 error
+	var err6 error
 	var ips []net.Address
 	var ips []net.Address
-	var lastErr error
-	if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess {
-		aaaa, err := record.AAAA.getIPs()
-		if err != nil {
-			lastErr = err
-		}
-		ips = append(ips, aaaa...)
-	}
-
-	if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess {
-		a, err := record.A.getIPs()
-		if err != nil {
-			lastErr = err
-		}
-		ips = append(ips, a...)
+	var ip6 []net.Address
+
+	switch {
+	case option.IPv4Enable:
+		ips, err4 = record.A.getIPs()
+		fallthrough
+	case option.IPv6Enable:
+		ip6, err6 = record.AAAA.getIPs()
+		ips = append(ips, ip6...)
 	}
 	}
 
 
 	if len(ips) > 0 {
 	if len(ips) > 0 {
 		return toNetIP(ips)
 		return toNetIP(ips)
 	}
 	}
 
 
-	if lastErr != nil {
-		return nil, lastErr
+	if err4 != nil {
+		return nil, err4
 	}
 	}
 
 
-	if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
-		return nil, dns_feature.ErrEmptyResponse
+	if err6 != nil {
+		return nil, err6
 	}
 	}
 
 
 	return nil, errRecordNotFound
 	return nil, errRecordNotFound

+ 28 - 25
app/dns/nameserver_tcp.go

@@ -29,8 +29,8 @@ import (
 type TCPNameServer struct {
 type TCPNameServer struct {
 	sync.RWMutex
 	sync.RWMutex
 	name        string
 	name        string
-	destination net.Destination
-	ips         map[string]record
+	destination *net.Destination
+	ips         map[string]*record
 	pub         *pubsub.Service
 	pub         *pubsub.Service
 	cleanup     *task.Periodic
 	cleanup     *task.Periodic
 	reqID       uint32
 	reqID       uint32
@@ -45,7 +45,7 @@ func NewTCPNameServer(url *url.URL, dispatcher routing.Dispatcher) (*TCPNameServ
 	}
 	}
 
 
 	s.dial = func(ctx context.Context) (net.Conn, error) {
 	s.dial = func(ctx context.Context) (net.Conn, error) {
-		link, err := dispatcher.Dispatch(ctx, s.destination)
+		link, err := dispatcher.Dispatch(ctx, *s.destination)
 		if err != nil {
 		if err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -67,7 +67,7 @@ func NewTCPLocalNameServer(url *url.URL) (*TCPNameServer, error) {
 	}
 	}
 
 
 	s.dial = func(ctx context.Context) (net.Conn, error) {
 	s.dial = func(ctx context.Context) (net.Conn, error) {
-		return internet.DialSystem(ctx, s.destination, nil)
+		return internet.DialSystem(ctx, *s.destination, nil)
 	}
 	}
 
 
 	return s, nil
 	return s, nil
@@ -85,8 +85,8 @@ func baseTCPNameServer(url *url.URL, prefix string) (*TCPNameServer, error) {
 	dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port)
 	dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port)
 
 
 	s := &TCPNameServer{
 	s := &TCPNameServer{
-		destination: dest,
-		ips:         make(map[string]record),
+		destination: &dest,
+		ips:         make(map[string]*record),
 		pub:         pubsub.NewService(),
 		pub:         pubsub.NewService(),
 		name:        prefix + "//" + dest.NetAddr(),
 		name:        prefix + "//" + dest.NetAddr(),
 	}
 	}
@@ -130,7 +130,7 @@ func (s *TCPNameServer) Cleanup() error {
 	}
 	}
 
 
 	if len(s.ips) == 0 {
 	if len(s.ips) == 0 {
-		s.ips = make(map[string]record)
+		s.ips = make(map[string]*record)
 	}
 	}
 
 
 	return nil
 	return nil
@@ -140,7 +140,10 @@ func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
 	elapsed := time.Since(req.start)
 	elapsed := time.Since(req.start)
 
 
 	s.Lock()
 	s.Lock()
-	rec := s.ips[req.domain]
+	rec, found := s.ips[req.domain]
+	if !found {
+		rec = &record{}
+	}
 	updated := false
 	updated := false
 
 
 	switch req.reqType {
 	switch req.reqType {
@@ -274,30 +277,30 @@ func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOpt
 		return nil, errRecordNotFound
 		return nil, errRecordNotFound
 	}
 	}
 
 
+	var err4 error
+	var err6 error
 	var ips []net.Address
 	var ips []net.Address
-	var lastErr error
-	if option.IPv4Enable {
-		a, err := record.A.getIPs()
-		if err != nil {
-			lastErr = err
-		}
-		ips = append(ips, a...)
-	}
-
-	if option.IPv6Enable {
-		aaaa, err := record.AAAA.getIPs()
-		if err != nil {
-			lastErr = err
-		}
-		ips = append(ips, aaaa...)
+	var ip6 []net.Address
+
+	switch {
+	case option.IPv4Enable:
+		ips, err4 = record.A.getIPs()
+		fallthrough
+	case option.IPv6Enable:
+		ip6, err6 = record.AAAA.getIPs()
+		ips = append(ips, ip6...)
 	}
 	}
 
 
 	if len(ips) > 0 {
 	if len(ips) > 0 {
 		return toNetIP(ips)
 		return toNetIP(ips)
 	}
 	}
 
 
-	if lastErr != nil {
-		return nil, lastErr
+	if err4 != nil {
+		return nil, err4
+	}
+
+	if err6 != nil {
+		return nil, err6
 	}
 	}
 
 
 	return nil, dns_feature.ErrEmptyResponse
 	return nil, dns_feature.ErrEmptyResponse

+ 35 - 31
app/dns/nameserver_udp.go

@@ -28,9 +28,9 @@ import (
 type ClassicNameServer struct {
 type ClassicNameServer struct {
 	sync.RWMutex
 	sync.RWMutex
 	name      string
 	name      string
-	address   net.Destination
-	ips       map[string]record
-	requests  map[uint16]dnsRequest
+	address   *net.Destination
+	ips       map[string]*record
+	requests  map[uint16]*dnsRequest
 	pub       *pubsub.Service
 	pub       *pubsub.Service
 	udpServer *udp.Dispatcher
 	udpServer *udp.Dispatcher
 	cleanup   *task.Periodic
 	cleanup   *task.Periodic
@@ -45,9 +45,9 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
 	}
 	}
 
 
 	s := &ClassicNameServer{
 	s := &ClassicNameServer{
-		address:  address,
-		ips:      make(map[string]record),
-		requests: make(map[uint16]dnsRequest),
+		address:  &address,
+		ips:      make(map[string]*record),
+		requests: make(map[uint16]*dnsRequest),
 		pub:      pubsub.NewService(),
 		pub:      pubsub.NewService(),
 		name:     strings.ToUpper(address.String()),
 		name:     strings.ToUpper(address.String()),
 	}
 	}
@@ -84,6 +84,7 @@ func (s *ClassicNameServer) Cleanup() error {
 		}
 		}
 
 
 		if record.A == nil && record.AAAA == nil {
 		if record.A == nil && record.AAAA == nil {
+			newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
 			delete(s.ips, domain)
 			delete(s.ips, domain)
 		} else {
 		} else {
 			s.ips[domain] = record
 			s.ips[domain] = record
@@ -91,7 +92,7 @@ func (s *ClassicNameServer) Cleanup() error {
 	}
 	}
 
 
 	if len(s.ips) == 0 {
 	if len(s.ips) == 0 {
-		s.ips = make(map[string]record)
+		s.ips = make(map[string]*record)
 	}
 	}
 
 
 	for id, req := range s.requests {
 	for id, req := range s.requests {
@@ -101,7 +102,7 @@ func (s *ClassicNameServer) Cleanup() error {
 	}
 	}
 
 
 	if len(s.requests) == 0 {
 	if len(s.requests) == 0 {
-		s.requests = make(map[uint16]dnsRequest)
+		s.requests = make(map[uint16]*dnsRequest)
 	}
 	}
 
 
 	return nil
 	return nil
@@ -139,15 +140,17 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 	elapsed := time.Since(req.start)
 	elapsed := time.Since(req.start)
 	newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
 	newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
 	if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
 	if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
-		s.updateIP(req.domain, rec)
+		s.updateIP(req.domain, &rec)
 	}
 	}
 }
 }
 
 
-func (s *ClassicNameServer) updateIP(domain string, newRec record) {
+func (s *ClassicNameServer) updateIP(domain string, newRec *record) {
 	s.Lock()
 	s.Lock()
 
 
-	newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog()
-	rec := s.ips[domain]
+	rec, found := s.ips[domain]
+	if !found {
+		rec = &record{}
+	}
 
 
 	updated := false
 	updated := false
 	if isNewer(rec.A, newRec.A) {
 	if isNewer(rec.A, newRec.A) {
@@ -160,6 +163,7 @@ func (s *ClassicNameServer) updateIP(domain string, newRec record) {
 	}
 	}
 
 
 	if updated {
 	if updated {
+		newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog()
 		s.ips[domain] = rec
 		s.ips[domain] = rec
 	}
 	}
 	if newRec.A != nil {
 	if newRec.A != nil {
@@ -182,7 +186,7 @@ func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {
 
 
 	id := req.msg.ID
 	id := req.msg.ID
 	req.expire = time.Now().Add(time.Second * 8)
 	req.expire = time.Now().Add(time.Second * 8)
-	s.requests[id] = *req
+	s.requests[id] = req
 }
 }
 
 
 func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
 func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
@@ -200,7 +204,7 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, client
 		udpCtx = session.ContextWithContent(udpCtx, &session.Content{
 		udpCtx = session.ContextWithContent(udpCtx, &session.Content{
 			Protocol: "dns",
 			Protocol: "dns",
 		})
 		})
-		s.udpServer.Dispatch(udpCtx, s.address, b)
+		s.udpServer.Dispatch(udpCtx, *s.address, b)
 	}
 	}
 }
 }
 
 
@@ -213,30 +217,30 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.I
 		return nil, errRecordNotFound
 		return nil, errRecordNotFound
 	}
 	}
 
 
+	var err4 error
+	var err6 error
 	var ips []net.Address
 	var ips []net.Address
-	var lastErr error
-	if option.IPv4Enable {
-		a, err := record.A.getIPs()
-		if err != nil {
-			lastErr = err
-		}
-		ips = append(ips, a...)
-	}
-
-	if option.IPv6Enable {
-		aaaa, err := record.AAAA.getIPs()
-		if err != nil {
-			lastErr = err
-		}
-		ips = append(ips, aaaa...)
+	var ip6 []net.Address
+
+	switch {
+	case option.IPv4Enable:
+		ips, err4 = record.A.getIPs()
+		fallthrough
+	case option.IPv6Enable:
+		ip6, err6 = record.AAAA.getIPs()
+		ips = append(ips, ip6...)
 	}
 	}
 
 
 	if len(ips) > 0 {
 	if len(ips) > 0 {
 		return toNetIP(ips)
 		return toNetIP(ips)
 	}
 	}
 
 
-	if lastErr != nil {
-		return nil, lastErr
+	if err4 != nil {
+		return nil, err4
+	}
+
+	if err6 != nil {
+		return nil, err6
 	}
 	}
 
 
 	return nil, dns_feature.ErrEmptyResponse
 	return nil, dns_feature.ErrEmptyResponse