|
|
@@ -5,36 +5,60 @@ package dns
|
|
|
import (
|
|
|
"context"
|
|
|
"encoding/binary"
|
|
|
+ fmt "fmt"
|
|
|
"sync"
|
|
|
"sync/atomic"
|
|
|
"time"
|
|
|
|
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
|
"v2ray.com/core/common"
|
|
|
+ "v2ray.com/core/common/errors"
|
|
|
"v2ray.com/core/common/net"
|
|
|
"v2ray.com/core/common/protocol/dns"
|
|
|
udp_proto "v2ray.com/core/common/protocol/udp"
|
|
|
"v2ray.com/core/common/session"
|
|
|
"v2ray.com/core/common/signal/pubsub"
|
|
|
"v2ray.com/core/common/task"
|
|
|
+ dns_feature "v2ray.com/core/features/dns"
|
|
|
"v2ray.com/core/features/routing"
|
|
|
"v2ray.com/core/transport/internet/udp"
|
|
|
)
|
|
|
|
|
|
+type record struct {
|
|
|
+ A *IPRecord
|
|
|
+ AAAA *IPRecord
|
|
|
+}
|
|
|
+
|
|
|
type IPRecord struct {
|
|
|
- IP net.Address
|
|
|
+ IP []net.Address
|
|
|
Expire time.Time
|
|
|
+ RCode dnsmessage.RCode
|
|
|
+}
|
|
|
+
|
|
|
+func (r *IPRecord) getIPs() ([]net.Address, error) {
|
|
|
+ if r == nil || r.Expire.Before(time.Now()) {
|
|
|
+ return nil, errRecordNotFound
|
|
|
+ }
|
|
|
+ if r.RCode != dnsmessage.RCodeSuccess {
|
|
|
+ return nil, dns_feature.RCodeError(r.RCode)
|
|
|
+ }
|
|
|
+ return r.IP, nil
|
|
|
}
|
|
|
|
|
|
type pendingRequest struct {
|
|
|
- domain string
|
|
|
- expire time.Time
|
|
|
+ domain string
|
|
|
+ expire time.Time
|
|
|
+ recType dnsmessage.Type
|
|
|
}
|
|
|
|
|
|
+var (
|
|
|
+ errRecordNotFound = errors.New("record not found")
|
|
|
+)
|
|
|
+
|
|
|
type ClassicNameServer struct {
|
|
|
sync.RWMutex
|
|
|
address net.Destination
|
|
|
- ips map[string][]IPRecord
|
|
|
+ ips map[string]record
|
|
|
requests map[uint16]pendingRequest
|
|
|
pub *pubsub.Service
|
|
|
udpServer *udp.Dispatcher
|
|
|
@@ -46,7 +70,7 @@ type ClassicNameServer struct {
|
|
|
func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
|
|
|
s := &ClassicNameServer{
|
|
|
address: address,
|
|
|
- ips: make(map[string][]IPRecord),
|
|
|
+ ips: make(map[string]record),
|
|
|
requests: make(map[uint16]pendingRequest),
|
|
|
clientIP: clientIP,
|
|
|
pub: pubsub.NewService(),
|
|
|
@@ -72,22 +96,23 @@ func (s *ClassicNameServer) Cleanup() error {
|
|
|
return newError("nothing to do. stopping...")
|
|
|
}
|
|
|
|
|
|
- for domain, ips := range s.ips {
|
|
|
- newIPs := make([]IPRecord, 0, len(ips))
|
|
|
- for _, ip := range ips {
|
|
|
- if ip.Expire.After(now) {
|
|
|
- newIPs = append(newIPs, ip)
|
|
|
- }
|
|
|
+ for domain, record := range s.ips {
|
|
|
+ if record.A != nil && record.A.Expire.Before(now) {
|
|
|
+ record.A = nil
|
|
|
}
|
|
|
- if len(newIPs) == 0 {
|
|
|
+ if record.AAAA != nil && record.AAAA.Expire.Before(now) {
|
|
|
+ record.AAAA = nil
|
|
|
+ }
|
|
|
+
|
|
|
+ if record.A == nil && record.AAAA == nil {
|
|
|
delete(s.ips, domain)
|
|
|
- } else if len(newIPs) < len(ips) {
|
|
|
- s.ips[domain] = newIPs
|
|
|
+ } else {
|
|
|
+ s.ips[domain] = record
|
|
|
}
|
|
|
}
|
|
|
|
|
|
if len(s.ips) == 0 {
|
|
|
- s.ips = make(map[string][]IPRecord)
|
|
|
+ s.ips = make(map[string]record)
|
|
|
}
|
|
|
|
|
|
for id, req := range s.requests {
|
|
|
@@ -130,9 +155,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
|
|
}
|
|
|
|
|
|
domain := req.domain
|
|
|
- ips := make([]IPRecord, 0, 16)
|
|
|
+ recType := req.recType
|
|
|
|
|
|
now := time.Now()
|
|
|
+ ipRecord := &IPRecord{
|
|
|
+ RCode: header.RCode,
|
|
|
+ Expire: now.Add(time.Second * 600),
|
|
|
+ }
|
|
|
+
|
|
|
for {
|
|
|
header, err := parser.AnswerHeader()
|
|
|
if err != nil {
|
|
|
@@ -145,6 +175,15 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
|
|
if ttl == 0 {
|
|
|
ttl = 600
|
|
|
}
|
|
|
+ expire := now.Add(time.Duration(ttl) * time.Second)
|
|
|
+ if ipRecord.Expire.After(expire) {
|
|
|
+ ipRecord.Expire = expire
|
|
|
+ }
|
|
|
+
|
|
|
+ if header.Type != recType {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
switch header.Type {
|
|
|
case dnsmessage.TypeA:
|
|
|
ans, err := parser.AResource()
|
|
|
@@ -152,20 +191,14 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
|
|
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
|
|
break
|
|
|
}
|
|
|
- ips = append(ips, IPRecord{
|
|
|
- IP: net.IPAddress(ans.A[:]),
|
|
|
- Expire: now.Add(time.Duration(ttl) * time.Second),
|
|
|
- })
|
|
|
+ ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
|
|
|
case dnsmessage.TypeAAAA:
|
|
|
ans, err := parser.AAAAResource()
|
|
|
if err != nil {
|
|
|
newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
|
|
break
|
|
|
}
|
|
|
- ips = append(ips, IPRecord{
|
|
|
- IP: net.IPAddress(ans.AAAA[:]),
|
|
|
- Expire: now.Add(time.Duration(ttl) * time.Second),
|
|
|
- })
|
|
|
+ ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
|
|
|
default:
|
|
|
if err := parser.SkipAnswer(); err != nil {
|
|
|
newError("failed to skip answer").Base(err).WriteToLog()
|
|
|
@@ -173,24 +206,49 @@ func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_prot
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- if len(domain) > 0 && len(ips) > 0 {
|
|
|
- s.updateIP(domain, ips)
|
|
|
+ var rec record
|
|
|
+ switch recType {
|
|
|
+ case dnsmessage.TypeA:
|
|
|
+ rec.A = ipRecord
|
|
|
+ case dnsmessage.TypeAAAA:
|
|
|
+ rec.AAAA = ipRecord
|
|
|
+ }
|
|
|
+
|
|
|
+ if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
|
|
|
+ s.updateIP(domain, rec)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
|
|
|
+ if newRec == nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+ if baseRec == nil {
|
|
|
+ return true
|
|
|
}
|
|
|
+ return baseRec.Expire.Before(newRec.Expire)
|
|
|
}
|
|
|
|
|
|
-func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) {
|
|
|
+func (s *ClassicNameServer) updateIP(domain string, newRec record) {
|
|
|
s.Lock()
|
|
|
|
|
|
newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
|
|
|
- now := time.Now()
|
|
|
- eips := s.ips[domain]
|
|
|
- for _, ip := range eips {
|
|
|
- if ip.Expire.After(now) {
|
|
|
- ips = append(ips, ip)
|
|
|
- }
|
|
|
+ rec := s.ips[domain]
|
|
|
+
|
|
|
+ updated := false
|
|
|
+ if isNewer(rec.A, newRec.A) {
|
|
|
+ rec.A = newRec.A
|
|
|
+ updated = true
|
|
|
+ }
|
|
|
+ if isNewer(rec.AAAA, newRec.AAAA) {
|
|
|
+ rec.AAAA = newRec.AAAA
|
|
|
+ updated = true
|
|
|
+ }
|
|
|
+
|
|
|
+ if updated {
|
|
|
+ s.ips[domain] = rec
|
|
|
+ s.pub.Publish(domain, nil)
|
|
|
}
|
|
|
- s.ips[domain] = ips
|
|
|
- s.pub.Publish(domain, nil)
|
|
|
|
|
|
s.Unlock()
|
|
|
common.Must(s.cleanup.Start())
|
|
|
@@ -244,14 +302,15 @@ func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
|
|
|
return opt
|
|
|
}
|
|
|
|
|
|
-func (s *ClassicNameServer) addPendingRequest(domain string) uint16 {
|
|
|
+func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
|
|
|
id := uint16(atomic.AddUint32(&s.reqID, 1))
|
|
|
s.Lock()
|
|
|
defer s.Unlock()
|
|
|
|
|
|
s.requests[id] = pendingRequest{
|
|
|
- domain: domain,
|
|
|
- expire: time.Now().Add(time.Second * 8),
|
|
|
+ domain: domain,
|
|
|
+ expire: time.Now().Add(time.Second * 8),
|
|
|
+ recType: recType,
|
|
|
}
|
|
|
|
|
|
return id
|
|
|
@@ -274,7 +333,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
|
|
|
|
|
|
if option.IPv4Enable {
|
|
|
msg := new(dnsmessage.Message)
|
|
|
- msg.Header.ID = s.addPendingRequest(domain)
|
|
|
+ msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
|
|
|
msg.Header.RecursionDesired = true
|
|
|
msg.Questions = []dnsmessage.Question{qA}
|
|
|
if opt := s.getMsgOptions(); opt != nil {
|
|
|
@@ -285,7 +344,7 @@ func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmess
|
|
|
|
|
|
if option.IPv6Enable {
|
|
|
msg := new(dnsmessage.Message)
|
|
|
- msg.Header.ID = s.addPendingRequest(domain)
|
|
|
+ msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
|
|
|
msg.Header.RecursionDesired = true
|
|
|
msg.Questions = []dnsmessage.Question{qAAAA}
|
|
|
if opt := s.getMsgOptions(); opt != nil {
|
|
|
@@ -313,22 +372,44 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) []net.IP {
|
|
|
+func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
|
|
|
s.RLock()
|
|
|
- records, found := s.ips[domain]
|
|
|
+ record, found := s.ips[domain]
|
|
|
s.RUnlock()
|
|
|
|
|
|
- if found && len(records) > 0 {
|
|
|
- var ips []net.Address
|
|
|
- now := time.Now()
|
|
|
- for _, rec := range records {
|
|
|
- if rec.Expire.After(now) {
|
|
|
- ips = append(ips, rec.IP)
|
|
|
- }
|
|
|
+ if !found {
|
|
|
+ return nil, errRecordNotFound
|
|
|
+ }
|
|
|
+
|
|
|
+ var ips []net.Address
|
|
|
+ var lastErr error
|
|
|
+ if option.IPv4Enable {
|
|
|
+ a, err := record.A.getIPs()
|
|
|
+ if err != nil {
|
|
|
+ lastErr = err
|
|
|
}
|
|
|
- return toNetIP(filterIP(ips, option))
|
|
|
+ ips = append(ips, a...)
|
|
|
}
|
|
|
- return nil
|
|
|
+
|
|
|
+ if option.IPv6Enable {
|
|
|
+ aaaa, err := record.AAAA.getIPs()
|
|
|
+ if err != nil {
|
|
|
+ lastErr = err
|
|
|
+ }
|
|
|
+ ips = append(ips, aaaa...)
|
|
|
+ }
|
|
|
+
|
|
|
+ fmt.Println("IPs for ", domain, ": ", ips)
|
|
|
+
|
|
|
+ if len(ips) > 0 {
|
|
|
+ return toNetIP(ips), nil
|
|
|
+ }
|
|
|
+
|
|
|
+ if lastErr != nil {
|
|
|
+ return nil, lastErr
|
|
|
+ }
|
|
|
+
|
|
|
+ return nil, dns_feature.ErrEmptyResponse
|
|
|
}
|
|
|
|
|
|
func Fqdn(domain string) string {
|
|
|
@@ -341,9 +422,9 @@ func Fqdn(domain string) string {
|
|
|
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
|
|
|
fqdn := Fqdn(domain)
|
|
|
|
|
|
- ips := s.findIPsForDomain(fqdn, option)
|
|
|
- if len(ips) > 0 {
|
|
|
- return ips, nil
|
|
|
+ ips, err := s.findIPsForDomain(fqdn, option)
|
|
|
+ if err != errRecordNotFound {
|
|
|
+ return ips, err
|
|
|
}
|
|
|
|
|
|
sub := s.pub.Subscribe(fqdn)
|
|
|
@@ -352,9 +433,9 @@ func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option I
|
|
|
s.sendQuery(ctx, fqdn, option)
|
|
|
|
|
|
for {
|
|
|
- ips := s.findIPsForDomain(fqdn, option)
|
|
|
- if len(ips) > 0 {
|
|
|
- return ips, nil
|
|
|
+ ips, err := s.findIPsForDomain(fqdn, option)
|
|
|
+ if err != errRecordNotFound {
|
|
|
+ return ips, err
|
|
|
}
|
|
|
|
|
|
select {
|