Browse Source

fix dns parsing for unknown header types

Darien Raymond 6 years ago
parent
commit
bb8465e1d6
1 changed files with 8 additions and 5 deletions
  1. 8 5
      app/dns/udpns.go

+ 8 - 5
app/dns/udpns.go

@@ -5,7 +5,6 @@ package dns
 import (
 import (
 	"context"
 	"context"
 	"encoding/binary"
 	"encoding/binary"
-	fmt "fmt"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
@@ -163,6 +162,7 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 		Expire: now.Add(time.Second * 600),
 		Expire: now.Add(time.Second * 600),
 	}
 	}
 
 
+L:
 	for {
 	for {
 		header, err := parser.AnswerHeader()
 		header, err := parser.AnswerHeader()
 		if err != nil {
 		if err != nil {
@@ -181,6 +181,10 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 		}
 		}
 
 
 		if header.Type != recType {
 		if header.Type != recType {
+			if err := parser.SkipAnswer(); err != nil {
+				newError("failed to skip answer").Base(err).WriteToLog()
+				break L
+			}
 			continue
 			continue
 		}
 		}
 
 
@@ -189,19 +193,20 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
 			ans, err := parser.AResource()
 			ans, err := parser.AResource()
 			if err != nil {
 			if err != nil {
 				newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
 				newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
-				break
+				break L
 			}
 			}
 			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
 			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
 		case dnsmessage.TypeAAAA:
 		case dnsmessage.TypeAAAA:
 			ans, err := parser.AAAAResource()
 			ans, err := parser.AAAAResource()
 			if err != nil {
 			if err != nil {
 				newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
 				newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
-				break
+				break L
 			}
 			}
 			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
 			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
 		default:
 		default:
 			if err := parser.SkipAnswer(); err != nil {
 			if err := parser.SkipAnswer(); err != nil {
 				newError("failed to skip answer").Base(err).WriteToLog()
 				newError("failed to skip answer").Base(err).WriteToLog()
+				break L
 			}
 			}
 		}
 		}
 	}
 	}
@@ -399,8 +404,6 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]
 		ips = append(ips, aaaa...)
 		ips = append(ips, aaaa...)
 	}
 	}
 
 
-	fmt.Println("IPs for ", domain, ": ", ips)
-
 	if len(ips) > 0 {
 	if len(ips) > 0 {
 		return toNetIP(ips), nil
 		return toNetIP(ips), nil
 	}
 	}