|
|
@@ -29,10 +29,16 @@ type IPRecord struct {
|
|
|
Expire time.Time
|
|
|
}
|
|
|
|
|
|
+type pendingRequest struct {
|
|
|
+ domain string
|
|
|
+ expire time.Time
|
|
|
+}
|
|
|
+
|
|
|
type ClassicNameServer struct {
|
|
|
sync.RWMutex
|
|
|
address net.Destination
|
|
|
ips map[string][]IPRecord
|
|
|
+ requests map[uint16]pendingRequest
|
|
|
pub *pubsub.Service
|
|
|
udpServer *udp.Dispatcher
|
|
|
cleanup *task.Periodic
|
|
|
@@ -44,6 +50,7 @@ func NewClassicNameServer(address net.Destination, dispatcher core.Dispatcher, c
|
|
|
s := &ClassicNameServer{
|
|
|
address: address,
|
|
|
ips: make(map[string][]IPRecord),
|
|
|
+ requests: make(map[uint16]pendingRequest),
|
|
|
udpServer: udp.NewDispatcher(dispatcher),
|
|
|
clientIP: clientIP,
|
|
|
pub: pubsub.NewService(),
|
|
|
@@ -77,6 +84,16 @@ func (s *ClassicNameServer) Cleanup() error {
|
|
|
s.ips = make(map[string][]IPRecord)
|
|
|
}
|
|
|
|
|
|
+ for id, req := range s.requests {
|
|
|
+ if req.expire.Before(now) {
|
|
|
+ delete(s.requests, id)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(s.requests) == 0 {
|
|
|
+ s.requests = make(map[uint16]pendingRequest)
|
|
|
+ }
|
|
|
+
|
|
|
s.Unlock()
|
|
|
return nil
|
|
|
}
|
|
|
@@ -91,16 +108,24 @@ func (s *ClassicNameServer) HandleResponse(payload *buf.Buffer) {
|
|
|
return
|
|
|
}
|
|
|
|
|
|
- var domain string
|
|
|
+ id := msg.Id
|
|
|
+ s.Lock()
|
|
|
+ req, f := s.requests[id]
|
|
|
+ if f {
|
|
|
+ delete(s.requests, id)
|
|
|
+ }
|
|
|
+ s.Unlock()
|
|
|
+
|
|
|
+ if !f {
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ domain := req.domain
|
|
|
ips := make([]IPRecord, 0, 16)
|
|
|
|
|
|
now := time.Now()
|
|
|
for _, rr := range msg.Answer {
|
|
|
var ip net.IP
|
|
|
- name := rr.Header().Name
|
|
|
- if len(name) > 0 {
|
|
|
- domain = rr.Header().Name
|
|
|
- }
|
|
|
ttl := rr.Header().Ttl
|
|
|
switch rr := rr.(type) {
|
|
|
case *dns.A:
|
|
|
@@ -164,7 +189,19 @@ func (s *ClassicNameServer) getMsgOptions() *dns.OPT {
|
|
|
o.Option = append(o.Option, e)
|
|
|
|
|
|
return o
|
|
|
+}
|
|
|
+
|
|
|
+func (s *ClassicNameServer) addPendingRequest(domain string) uint16 {
|
|
|
+ id := uint16(atomic.AddUint32(&s.reqID, 1))
|
|
|
+ s.Lock()
|
|
|
+ defer s.Unlock()
|
|
|
+
|
|
|
+ s.requests[id] = pendingRequest{
|
|
|
+ domain: domain,
|
|
|
+ expire: time.Now().Add(time.Second * 8),
|
|
|
+ }
|
|
|
|
|
|
+ return id
|
|
|
}
|
|
|
|
|
|
func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg {
|
|
|
@@ -186,7 +223,7 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg {
|
|
|
|
|
|
{
|
|
|
msg := new(dns.Msg)
|
|
|
- msg.Id = uint16(atomic.AddUint32(&s.reqID, 1))
|
|
|
+ msg.Id = s.addPendingRequest(domain)
|
|
|
msg.RecursionDesired = true
|
|
|
msg.Question = []dns.Question{qA}
|
|
|
if allowMulti {
|
|
|
@@ -200,7 +237,7 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg {
|
|
|
|
|
|
if !allowMulti {
|
|
|
msg := new(dns.Msg)
|
|
|
- msg.Id = uint16(atomic.AddUint32(&s.reqID, 1))
|
|
|
+ msg.Id = s.addPendingRequest(domain)
|
|
|
msg.RecursionDesired = true
|
|
|
msg.Question = []dns.Question{qAAAA}
|
|
|
if opt := s.getMsgOptions(); opt != nil {
|