瀏覽代碼

fix dns query for CNAME

Darien Raymond 7 年之前
父節點
當前提交
4c18b61e6c
共有 1 個文件被更改,包括 44 次插入7 次删除
  1. 44 7
      app/dns/udpns.go

+ 44 - 7
app/dns/udpns.go

@@ -29,10 +29,16 @@ type IPRecord struct {
 	Expire time.Time
 }
 
+type pendingRequest struct {
+	domain string
+	expire time.Time
+}
+
 type ClassicNameServer struct {
 	sync.RWMutex
 	address   net.Destination
 	ips       map[string][]IPRecord
+	requests  map[uint16]pendingRequest
 	pub       *pubsub.Service
 	udpServer *udp.Dispatcher
 	cleanup   *task.Periodic
@@ -44,6 +50,7 @@ func NewClassicNameServer(address net.Destination, dispatcher core.Dispatcher, c
 	s := &ClassicNameServer{
 		address:   address,
 		ips:       make(map[string][]IPRecord),
+		requests:  make(map[uint16]pendingRequest),
 		udpServer: udp.NewDispatcher(dispatcher),
 		clientIP:  clientIP,
 		pub:       pubsub.NewService(),
@@ -77,6 +84,16 @@ func (s *ClassicNameServer) Cleanup() error {
 		s.ips = make(map[string][]IPRecord)
 	}
 
+	for id, req := range s.requests {
+		if req.expire.Before(now) {
+			delete(s.requests, id)
+		}
+	}
+
+	if len(s.requests) == 0 {
+		s.requests = make(map[uint16]pendingRequest)
+	}
+
 	s.Unlock()
 	return nil
 }
@@ -91,16 +108,24 @@ func (s *ClassicNameServer) HandleResponse(payload *buf.Buffer) {
 		return
 	}
 
-	var domain string
+	id := msg.Id
+	s.Lock()
+	req, f := s.requests[id]
+	if f {
+		delete(s.requests, id)
+	}
+	s.Unlock()
+
+	if !f {
+		return
+	}
+
+	domain := req.domain
 	ips := make([]IPRecord, 0, 16)
 
 	now := time.Now()
 	for _, rr := range msg.Answer {
 		var ip net.IP
-		name := rr.Header().Name
-		if len(name) > 0 {
-			domain = rr.Header().Name
-		}
 		ttl := rr.Header().Ttl
 		switch rr := rr.(type) {
 		case *dns.A:
@@ -164,7 +189,19 @@ func (s *ClassicNameServer) getMsgOptions() *dns.OPT {
 	o.Option = append(o.Option, e)
 
 	return o
+}
+
+func (s *ClassicNameServer) addPendingRequest(domain string) uint16 {
+	id := uint16(atomic.AddUint32(&s.reqID, 1))
+	s.Lock()
+	defer s.Unlock()
+
+	s.requests[id] = pendingRequest{
+		domain: domain,
+		expire: time.Now().Add(time.Second * 8),
+	}
 
+	return id
 }
 
 func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg {
@@ -186,7 +223,7 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg {
 
 	{
 		msg := new(dns.Msg)
-		msg.Id = uint16(atomic.AddUint32(&s.reqID, 1))
+		msg.Id = s.addPendingRequest(domain)
 		msg.RecursionDesired = true
 		msg.Question = []dns.Question{qA}
 		if allowMulti {
@@ -200,7 +237,7 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg {
 
 	if !allowMulti {
 		msg := new(dns.Msg)
-		msg.Id = uint16(atomic.AddUint32(&s.reqID, 1))
+		msg.Id = s.addPendingRequest(domain)
 		msg.RecursionDesired = true
 		msg.Question = []dns.Question{qAAAA}
 		if opt := s.getMsgOptions(); opt != nil {