Browse Source

correctly propagate dns errors all the way through.
the internal dns system can correctly handle the cases where:
1) domain has no A or AAAA records
2) domain doesn't exist
fixes #1565

Darien Raymond 6 years ago
parent
commit
9957c64b4a
7 changed files with 252 additions and 78 deletions
  1. 6 0
      app/dns/server.go
  2. 23 0
      app/dns/server_test.go
  3. 138 57
      app/dns/udpns.go
  4. 22 0
      features/dns/client.go
  5. 9 0
      features/dns/localdns/client.go
  6. 3 6
      proxy/dns/dns.go
  7. 51 15
      proxy/dns/dns_test.go

+ 6 - 0
app/dns/server.go

@@ -226,6 +226,9 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 			if len(ips) > 0 {
 			if len(ips) > 0 {
 				return ips, nil
 				return ips, nil
 			}
 			}
+			if err == dns.ErrEmptyResponse {
+				return nil, err
+			}
 			if err != nil {
 			if err != nil {
 				newError("failed to lookup ip for domain ", domain, " at server ", ns.Name()).Base(err).WriteToLog()
 				newError("failed to lookup ip for domain ", domain, " at server ", ns.Name()).Base(err).WriteToLog()
 				lastErr = err
 				lastErr = err
@@ -238,6 +241,9 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 		if len(ips) > 0 {
 		if len(ips) > 0 {
 			return ips, nil
 			return ips, nil
 		}
 		}
+		if err == dns.ErrEmptyResponse {
+			return nil, err
+		}
 		if err != nil {
 		if err != nil {
 			newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
 			newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
 			lastErr = err
 			lastErr = err

+ 23 - 0
app/dns/server_test.go

@@ -60,6 +60,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 			rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
 			rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
 			common.Must(err)
 			common.Must(err)
 			ans.Answer = append(ans.Answer, rr)
 			ans.Answer = append(ans.Answer, rr)
+		} else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA {
+			ans.MsgHdr.Rcode = dns.RcodeNameError
 		}
 		}
 	}
 	}
 	w.WriteMsg(ans)
 	w.WriteMsg(ans)
@@ -186,6 +188,27 @@ func TestUDPServer(t *testing.T) {
 		}
 		}
 	}
 	}
 
 
+	{
+		_, err := client.LookupIP("notexist.google.com")
+		if err == nil {
+			t.Fatal("nil error")
+		}
+		if r := feature_dns.RCodeFromError(err); r != uint16(dns.RcodeNameError) {
+			t.Fatal("expected NameError, but got ", r)
+		}
+	}
+
+	{
+		clientv6 := client.(feature_dns.IPv6Lookup)
+		ips, err := clientv6.LookupIPv6("ipv4only.google.com")
+		if err != feature_dns.ErrEmptyResponse {
+			t.Fatal("error: ", err)
+		}
+		if len(ips) != 0 {
+			t.Fatal("ips: ", ips)
+		}
+	}
+
 	dnsServer.Shutdown()
 	dnsServer.Shutdown()
 
 
 	{
 	{

+ 138 - 57
app/dns/udpns.go

@@ -5,36 +5,60 @@ package dns
 import (
 import (
 	"context"
 	"context"
 	"encoding/binary"
 	"encoding/binary"
+	fmt "fmt"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
 
 
 	"golang.org/x/net/dns/dnsmessage"
 	"golang.org/x/net/dns/dnsmessage"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
+	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol/dns"
 	"v2ray.com/core/common/protocol/dns"
 	udp_proto "v2ray.com/core/common/protocol/udp"
 	udp_proto "v2ray.com/core/common/protocol/udp"
 	"v2ray.com/core/common/session"
 	"v2ray.com/core/common/session"
 	"v2ray.com/core/common/signal/pubsub"
 	"v2ray.com/core/common/signal/pubsub"
 	"v2ray.com/core/common/task"
 	"v2ray.com/core/common/task"
+	dns_feature "v2ray.com/core/features/dns"
 	"v2ray.com/core/features/routing"
 	"v2ray.com/core/features/routing"
 	"v2ray.com/core/transport/internet/udp"
 	"v2ray.com/core/transport/internet/udp"
 )
 )
 
 
+type record struct {
+	A    *IPRecord
+	AAAA *IPRecord
+}
+
 type IPRecord struct {
 type IPRecord struct {
-	IP     net.Address
+	IP     []net.Address
 	Expire time.Time
 	Expire time.Time
+	RCode  dnsmessage.RCode
+}
+
+func (r *IPRecord) getIPs() ([]net.Address, error) {
+	if r == nil || r.Expire.Before(time.Now()) {
+		return nil, errRecordNotFound
+	}
+	if r.RCode != dnsmessage.RCodeSuccess {
+		return nil, dns_feature.RCodeError(r.RCode)
+	}
+	return r.IP, nil
 }
 }
 
 
 type pendingRequest struct {
 type pendingRequest struct {
-	domain string
-	expire time.Time
+	domain  string
+	expire  time.Time
+	recType dnsmessage.Type
 }
 }
 
 
+var (
+	errRecordNotFound = errors.New("record not found")
+)
+
 type ClassicNameServer struct {
 type ClassicNameServer struct {
 	sync.RWMutex
 	sync.RWMutex
 	address   net.Destination
 	address   net.Destination
-	ips       map[string][]IPRecord
+	ips       map[string]record
 	requests  map[uint16]pendingRequest
 	requests  map[uint16]pendingRequest
 	pub       *pubsub.Service
 	pub       *pubsub.Service
 	udpServer *udp.Dispatcher
 	udpServer *udp.Dispatcher
@@ -46,7 +70,7 @@ type ClassicNameServer struct {
 func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
 func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
 	s := &ClassicNameServer{
 	s := &ClassicNameServer{
 		address:  address,
 		address:  address,
-		ips:      make(map[string][]IPRecord),
+		ips:      make(map[string]record),
 		requests: make(map[uint16]pendingRequest),
 		requests: make(map[uint16]pendingRequest),
 		clientIP: clientIP,
 		clientIP: clientIP,
 		pub:      pubsub.NewService(),
 		pub:      pubsub.NewService(),
@@ -72,22 +96,23 @@ func (s *ClassicNameServer) Cleanup() error {
 		return newError("nothing to do. stopping...")
 		return newError("nothing to do. stopping...")
 	}
 	}
 
 
-	for domain, ips := range s.ips {
-		newIPs := make([]IPRecord, 0, len(ips))
-		for _, ip := range ips {
-			if ip.Expire.After(now) {
-				newIPs = append(newIPs, ip)
-			}
+	for domain, record := range s.ips {
+		if record.A != nil && record.A.Expire.Before(now) {
+			record.A = nil
 		}
 		}
-		if len(newIPs) == 0 {
+		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
+			record.AAAA = nil
+		}
+
+		if record.A == nil && record.AAAA == nil {
 			delete(s.ips, domain)
 			delete(s.ips, domain)
-		} else if len(newIPs) < len(ips) {
-			s.ips[domain] = newIPs
+		} else {
+			s.ips[domain] = record
 		}
 		}
 	}
 	}
 
 
 	if len(s.ips) == 0 {
 	if len(s.ips) == 0 {
-		s.ips = make(map[string][]IPRecord)
+		s.ips = make(map[string]record)
 	}
 	}
 
 
 	for id, req := range s.requests {
 	for id, req := range s.requests {
@@ -130,9 +155,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 	}
 	}
 
 
 	domain := req.domain
 	domain := req.domain
-	ips := make([]IPRecord, 0, 16)
+	recType := req.recType
 
 
 	now := time.Now()
 	now := time.Now()
+	ipRecord := &IPRecord{
+		RCode:  header.RCode,
+		Expire: now.Add(time.Second * 600),
+	}
+
 	for {
 	for {
 		header, err := parser.AnswerHeader()
 		header, err := parser.AnswerHeader()
 		if err != nil {
 		if err != nil {
@@ -145,6 +175,15 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 		if ttl == 0 {
 		if ttl == 0 {
 			ttl = 600
 			ttl = 600
 		}
 		}
+		expire := now.Add(time.Duration(ttl) * time.Second)
+		if ipRecord.Expire.After(expire) {
+			ipRecord.Expire = expire
+		}
+
+		if header.Type != recType {
+			continue
+		}
+
 		switch header.Type {
 		switch header.Type {
 		case dnsmessage.TypeA:
 		case dnsmessage.TypeA:
 			ans, err := parser.AResource()
 			ans, err := parser.AResource()
@@ -152,20 +191,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 				newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
 				newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
 				break
 				break
 			}
 			}
-			ips = append(ips, IPRecord{
-				IP:     net.IPAddress(ans.A[:]),
-				Expire: now.Add(time.Duration(ttl) * time.Second),
-			})
+			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
 		case dnsmessage.TypeAAAA:
 		case dnsmessage.TypeAAAA:
 			ans, err := parser.AAAAResource()
 			ans, err := parser.AAAAResource()
 			if err != nil {
 			if err != nil {
 				newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
 				newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
 				break
 				break
 			}
 			}
-			ips = append(ips, IPRecord{
-				IP:     net.IPAddress(ans.AAAA[:]),
-				Expire: now.Add(time.Duration(ttl) * time.Second),
-			})
+			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
 		default:
 		default:
 			if err := parser.SkipAnswer(); err != nil {
 			if err := parser.SkipAnswer(); err != nil {
 				newError("failed to skip answer").Base(err).WriteToLog()
 				newError("failed to skip answer").Base(err).WriteToLog()
@@ -173,24 +206,49 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 		}
 		}
 	}
 	}
 
 
-	if len(domain) > 0 && len(ips) > 0 {
-		s.updateIP(domain, ips)
+	var rec record
+	switch recType {
+	case dnsmessage.TypeA:
+		rec.A = ipRecord
+	case dnsmessage.TypeAAAA:
+		rec.AAAA = ipRecord
+	}
+
+	if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
+		s.updateIP(domain, rec)
+	}
+}
+
+func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
+	if newRec == nil {
+		return false
+	}
+	if baseRec == nil {
+		return true
 	}
 	}
+	return baseRec.Expire.Before(newRec.Expire)
 }
 }
 
 
-func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) {
+func (s *ClassicNameServer) updateIP(domain string, newRec record) {
 	s.Lock()
 	s.Lock()
 
 
 	newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
 	newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
-	now := time.Now()
-	eips := s.ips[domain]
-	for _, ip := range eips {
-		if ip.Expire.After(now) {
-			ips = append(ips, ip)
-		}
+	rec := s.ips[domain]
+
+	updated := false
+	if isNewer(rec.A, newRec.A) {
+		rec.A = newRec.A
+		updated = true
+	}
+	if isNewer(rec.AAAA, newRec.AAAA) {
+		rec.AAAA = newRec.AAAA
+		updated = true
+	}
+
+	if updated {
+		s.ips[domain] = rec
+		s.pub.Publish(domain, nil)
 	}
 	}
-	s.ips[domain] = ips
-	s.pub.Publish(domain, nil)
 
 
 	s.Unlock()
 	s.Unlock()
 	common.Must(s.cleanup.Start())
 	common.Must(s.cleanup.Start())
@@ -244,14 +302,15 @@ func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
 	return opt
 	return opt
 }
 }
 
 
-func (s *ClassicNameServer) addPendingRequest(domain string) uint16 {
+func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
 	id := uint16(atomic.AddUint32(&s.reqID, 1))
 	id := uint16(atomic.AddUint32(&s.reqID, 1))
 	s.Lock()
 	s.Lock()
 	defer s.Unlock()
 	defer s.Unlock()
 
 
 	s.requests[id] = pendingRequest{
 	s.requests[id] = pendingRequest{
-		domain: domain,
-		expire: time.Now().Add(time.Second * 8),
+		domain:  domain,
+		expire:  time.Now().Add(time.Second * 8),
+		recType: recType,
 	}
 	}
 
 
 	return id
 	return id
@@ -274,7 +333,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
 
 
 	if option.IPv4Enable {
 	if option.IPv4Enable {
 		msg := new(dnsmessage.Message)
 		msg := new(dnsmessage.Message)
-		msg.Header.ID = s.addPendingRequest(domain)
+		msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
 		msg.Header.RecursionDesired = true
 		msg.Header.RecursionDesired = true
 		msg.Questions = []dnsmessage.Question{qA}
 		msg.Questions = []dnsmessage.Question{qA}
 		if opt := s.getMsgOptions(); opt != nil {
 		if opt := s.getMsgOptions(); opt != nil {
@@ -285,7 +344,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
 
 
 	if option.IPv6Enable {
 	if option.IPv6Enable {
 		msg := new(dnsmessage.Message)
 		msg := new(dnsmessage.Message)
-		msg.Header.ID = s.addPendingRequest(domain)
+		msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
 		msg.Header.RecursionDesired = true
 		msg.Header.RecursionDesired = true
 		msg.Questions = []dnsmessage.Question{qAAAA}
 		msg.Questions = []dnsmessage.Question{qAAAA}
 		if opt := s.getMsgOptions(); opt != nil {
 		if opt := s.getMsgOptions(); opt != nil {
@@ -313,22 +372,44 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option
 	}
 	}
 }
 }
 
 
-func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []net.IP {
+func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
 	s.RLock()
 	s.RLock()
-	records, found := s.ips[domain]
+	record, found := s.ips[domain]
 	s.RUnlock()
 	s.RUnlock()
 
 
-	if found && len(records) > 0 {
-		var ips []net.Address
-		now := time.Now()
-		for _, rec := range records {
-			if rec.Expire.After(now) {
-				ips = append(ips, rec.IP)
-			}
+	if !found {
+		return nil, errRecordNotFound
+	}
+
+	var ips []net.Address
+	var lastErr error
+	if option.IPv4Enable {
+		a, err := record.A.getIPs()
+		if err != nil {
+			lastErr = err
 		}
 		}
-		return toNetIP(filterIP(ips, option))
+		ips = append(ips, a...)
 	}
 	}
-	return nil
+
+	if option.IPv6Enable {
+		aaaa, err := record.AAAA.getIPs()
+		if err != nil {
+			lastErr = err
+		}
+		ips = append(ips, aaaa...)
+	}
+
+	fmt.Println("IPs for ", domain, ": ", ips)
+
+	if len(ips) > 0 {
+		return toNetIP(ips), nil
+	}
+
+	if lastErr != nil {
+		return nil, lastErr
+	}
+
+	return nil, dns_feature.ErrEmptyResponse
 }
 }
 
 
 func Fqdn(domain string) string {
 func Fqdn(domain string) string {
@@ -341,9 +422,9 @@ func Fqdn(domain string) string {
 func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
 func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
 	fqdn := Fqdn(domain)
 	fqdn := Fqdn(domain)
 
 
-	ips := s.findIPsForDomain(fqdn, option)
-	if len(ips) > 0 {
-		return ips, nil
+	ips, err := s.findIPsForDomain(fqdn, option)
+	if err != errRecordNotFound {
+		return ips, err
 	}
 	}
 
 
 	sub := s.pub.Subscribe(fqdn)
 	sub := s.pub.Subscribe(fqdn)
@@ -352,9 +433,9 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option I
 	s.sendQuery(ctx, fqdn, option)
 	s.sendQuery(ctx, fqdn, option)
 
 
 	for {
 	for {
-		ips := s.findIPsForDomain(fqdn, option)
-		if len(ips) > 0 {
-			return ips, nil
+		ips, err := s.findIPsForDomain(fqdn, option)
+		if err != errRecordNotFound {
+			return ips, err
 		}
 		}
 
 
 		select {
 		select {

+ 22 - 0
features/dns/client.go

@@ -1,7 +1,9 @@
 package dns
 package dns
 
 
 import (
 import (
+	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/serial"
 	"v2ray.com/core/features"
 	"v2ray.com/core/features"
 )
 )
 
 
@@ -35,3 +37,23 @@ type IPv6Lookup interface {
 func ClientType() interface{} {
 func ClientType() interface{} {
 	return (*Client)(nil)
 	return (*Client)(nil)
 }
 }
+
+// ErrEmptyResponse indicates that DNS query succeeded but no answer was returned.
+var ErrEmptyResponse = errors.New("empty response")
+
+type RCodeError uint16
+
+func (e RCodeError) Error() string {
+	return serial.Concat("rcode: ", uint16(e))
+}
+
+func RCodeFromError(err error) uint16 {
+	if err == nil {
+		return 0
+	}
+	cause := errors.Cause(err)
+	if r, ok := cause.(RCodeError); ok {
+		return uint16(r)
+	}
+	return 0
+}

+ 9 - 0
features/dns/localdns/client.go

@@ -32,6 +32,9 @@ func (*Client) LookupIP(host string) ([]net.IP, error) {
 			parsedIPs = append(parsedIPs, parsed.IP())
 			parsedIPs = append(parsedIPs, parsed.IP())
 		}
 		}
 	}
 	}
+	if len(parsedIPs) == 0 {
+		return nil, dns.ErrEmptyResponse
+	}
 	return parsedIPs, nil
 	return parsedIPs, nil
 }
 }
 
 
@@ -47,6 +50,9 @@ func (c *Client) LookupIPv4(host string) ([]net.IP, error) {
 			ipv4 = append(ipv4, ip)
 			ipv4 = append(ipv4, ip)
 		}
 		}
 	}
 	}
+	if len(ipv4) == 0 {
+		return nil, dns.ErrEmptyResponse
+	}
 	return ipv4, nil
 	return ipv4, nil
 }
 }
 
 
@@ -62,6 +68,9 @@ func (c *Client) LookupIPv6(host string) ([]net.IP, error) {
 			ipv6 = append(ipv6, ip)
 			ipv6 = append(ipv6, ip)
 		}
 		}
 	}
 	}
+	if len(ipv6) == 0 {
+		return nil, dns.ErrEmptyResponse
+	}
 	return ipv6, nil
 	return ipv6, nil
 }
 }
 
 

+ 3 - 6
proxy/dns/dns.go

@@ -218,20 +218,17 @@ func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string,
 		ips, err = h.ipv6Lookup.LookupIPv6(domain)
 		ips, err = h.ipv6Lookup.LookupIPv6(domain)
 	}
 	}
 
 
-	if err != nil {
+	rcode := dns.RCodeFromError(err)
+	if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse {
 		newError("ip query").Base(err).WriteToLog()
 		newError("ip query").Base(err).WriteToLog()
 		return
 		return
 	}
 	}
 
 
-	if len(ips) == 0 {
-		return
-	}
-
 	b := buf.New()
 	b := buf.New()
 	rawBytes := b.Extend(buf.Size)
 	rawBytes := b.Extend(buf.Size)
 	builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
 	builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{
 		ID:                 id,
 		ID:                 id,
-		RCode:              dnsmessage.RCodeSuccess,
+		RCode:              dnsmessage.RCode(rcode),
 		RecursionAvailable: true,
 		RecursionAvailable: true,
 		RecursionDesired:   true,
 		RecursionDesired:   true,
 		Response:           true,
 		Response:           true,

+ 51 - 15
proxy/dns/dns_test.go

@@ -63,6 +63,8 @@ func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
 			rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
 			rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
 			common.Must(err)
 			common.Must(err)
 			ans.Answer = append(ans.Answer, rr)
 			ans.Answer = append(ans.Answer, rr)
+		} else if q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA {
+			ans.MsgHdr.Rcode = dns.RcodeNameError
 		}
 		}
 	}
 	}
 	w.WriteMsg(ans)
 	w.WriteMsg(ans)
@@ -128,26 +130,60 @@ func TestUDPDNSTunnel(t *testing.T) {
 	common.Must(v.Start())
 	common.Must(v.Start())
 	defer v.Close()
 	defer v.Close()
 
 
-	m1 := new(dns.Msg)
-	m1.Id = dns.Id()
-	m1.RecursionDesired = true
-	m1.Question = make([]dns.Question, 1)
-	m1.Question[0] = dns.Question{"google.com.", dns.TypeA, dns.ClassINET}
+	{
+		m1 := new(dns.Msg)
+		m1.Id = dns.Id()
+		m1.RecursionDesired = true
+		m1.Question = make([]dns.Question, 1)
+		m1.Question[0] = dns.Question{"google.com.", dns.TypeA, dns.ClassINET}
 
 
-	c := new(dns.Client)
-	in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
-	common.Must(err)
+		c := new(dns.Client)
+		in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
+		common.Must(err)
 
 
-	if len(in.Answer) != 1 {
-		t.Fatal("len(answer): ", len(in.Answer))
+		if len(in.Answer) != 1 {
+			t.Fatal("len(answer): ", len(in.Answer))
+		}
+
+		rr, ok := in.Answer[0].(*dns.A)
+		if !ok {
+			t.Fatal("not A record")
+		}
+		if r := cmp.Diff(rr.A[:], net.IP{8, 8, 8, 8}); r != "" {
+			t.Error(r)
+		}
 	}
 	}
 
 
-	rr, ok := in.Answer[0].(*dns.A)
-	if !ok {
-		t.Fatal("not A record")
+	{
+		m1 := new(dns.Msg)
+		m1.Id = dns.Id()
+		m1.RecursionDesired = true
+		m1.Question = make([]dns.Question, 1)
+		m1.Question[0] = dns.Question{"ipv4only.google.com.", dns.TypeAAAA, dns.ClassINET}
+
+		c := new(dns.Client)
+		in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
+		common.Must(err)
+
+		if len(in.Answer) != 0 {
+			t.Fatal("len(answer): ", len(in.Answer))
+		}
 	}
 	}
-	if r := cmp.Diff(rr.A[:], net.IP{8, 8, 8, 8}); r != "" {
-		t.Error(r)
+
+	{
+		m1 := new(dns.Msg)
+		m1.Id = dns.Id()
+		m1.RecursionDesired = true
+		m1.Question = make([]dns.Question, 1)
+		m1.Question[0] = dns.Question{"notexist.google.com.", dns.TypeAAAA, dns.ClassINET}
+
+		c := new(dns.Client)
+		in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
+		common.Must(err)
+
+		if in.Rcode != dns.RcodeNameError {
+			t.Error("expected NameError, but got ", in.Rcode)
+		}
 	}
 	}
 }
 }