|
|
@@ -241,7 +241,7 @@ func (s *ClassicNameServer) addPendingRequest(domain string) uint16 {
|
|
|
return id
|
|
|
}
|
|
|
|
|
|
-func (s *ClassicNameServer) buildMsgs(domain string) []*dnsmessage.Message {
|
|
|
+func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmessage.Message {
|
|
|
qA := dnsmessage.Question{
|
|
|
Name: dnsmessage.MustNewName(domain),
|
|
|
Type: dnsmessage.TypeA,
|
|
|
@@ -256,7 +256,7 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dnsmessage.Message {
|
|
|
|
|
|
var msgs []*dnsmessage.Message
|
|
|
|
|
|
- {
|
|
|
+ if option.IPv4Enable {
|
|
|
msg := new(dnsmessage.Message)
|
|
|
msg.Header.ID = s.addPendingRequest(domain)
|
|
|
msg.Header.RecursionDesired = true
|
|
|
@@ -267,7 +267,7 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dnsmessage.Message {
|
|
|
msgs = append(msgs, msg)
|
|
|
}
|
|
|
|
|
|
- {
|
|
|
+ if option.IPv6Enable {
|
|
|
msg := new(dnsmessage.Message)
|
|
|
msg.Header.ID = s.addPendingRequest(domain)
|
|
|
msg.Header.RecursionDesired = true
|
|
|
@@ -281,7 +281,7 @@ func (s *ClassicNameServer) buildMsgs(domain string) []*dnsmessage.Message {
|
|
|
return msgs
|
|
|
}
|
|
|
|
|
|
-func msgToBuffer(msg *dnsmessage.Message) (*buf.Buffer, error) {
|
|
|
+func msgToBuffer2(msg *dnsmessage.Message) (*buf.Buffer, error) {
|
|
|
buffer := buf.New()
|
|
|
rawBytes := buffer.Extend(buf.Size)
|
|
|
packed, err := msg.AppendPack(rawBytes[:0])
|
|
|
@@ -293,19 +293,19 @@ func msgToBuffer(msg *dnsmessage.Message) (*buf.Buffer, error) {
|
|
|
return buffer, nil
|
|
|
}
|
|
|
|
|
|
-func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string) {
|
|
|
+func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
|
|
|
newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
|
|
|
|
|
|
- msgs := s.buildMsgs(domain)
|
|
|
+ msgs := s.buildMsgs(domain, option)
|
|
|
|
|
|
for _, msg := range msgs {
|
|
|
- b, err := msgToBuffer(msg)
|
|
|
+ b, err := msgToBuffer2(msg)
|
|
|
common.Must(err)
|
|
|
s.udpServer.Dispatch(context.Background(), s.address, b)
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (s *ClassicNameServer) findIPsForDomain(domain string) []net.IP {
|
|
|
+func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []net.IP {
|
|
|
s.RLock()
|
|
|
records, found := s.ips[domain]
|
|
|
s.RUnlock()
|
|
|
@@ -318,7 +318,7 @@ func (s *ClassicNameServer) findIPsForDomain(domain string) []net.IP {
|
|
|
ips = append(ips, rec.IP)
|
|
|
}
|
|
|
}
|
|
|
- return ips
|
|
|
+ return filterIP(ips, option)
|
|
|
}
|
|
|
return nil
|
|
|
}
|
|
|
@@ -330,10 +330,10 @@ func Fqdn(domain string) string {
|
|
|
return domain + "."
|
|
|
}
|
|
|
|
|
|
-func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) {
|
|
|
+func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
|
|
|
fqdn := Fqdn(domain)
|
|
|
|
|
|
- ips := s.findIPsForDomain(fqdn)
|
|
|
+ ips := s.findIPsForDomain(fqdn, option)
|
|
|
if len(ips) > 0 {
|
|
|
return ips, nil
|
|
|
}
|
|
|
@@ -341,10 +341,10 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string) ([]net.I
|
|
|
sub := s.pub.Subscribe(fqdn)
|
|
|
defer sub.Close()
|
|
|
|
|
|
- s.sendQuery(ctx, fqdn)
|
|
|
+ s.sendQuery(ctx, fqdn, option)
|
|
|
|
|
|
for {
|
|
|
- ips := s.findIPsForDomain(fqdn)
|
|
|
+ ips := s.findIPsForDomain(fqdn, option)
|
|
|
if len(ips) > 0 {
|
|
|
return ips, nil
|
|
|
}
|