|
@@ -4,14 +4,13 @@ package dns
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
"context"
|
|
"context"
|
|
|
- "encoding/binary"
|
|
|
|
|
|
|
+ "strings"
|
|
|
"sync"
|
|
"sync"
|
|
|
"sync/atomic"
|
|
"sync/atomic"
|
|
|
"time"
|
|
"time"
|
|
|
|
|
|
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
"golang.org/x/net/dns/dnsmessage"
|
|
|
"v2ray.com/core/common"
|
|
"v2ray.com/core/common"
|
|
|
- "v2ray.com/core/common/errors"
|
|
|
|
|
"v2ray.com/core/common/net"
|
|
"v2ray.com/core/common/net"
|
|
|
"v2ray.com/core/common/protocol/dns"
|
|
"v2ray.com/core/common/protocol/dns"
|
|
|
udp_proto "v2ray.com/core/common/protocol/udp"
|
|
udp_proto "v2ray.com/core/common/protocol/udp"
|
|
@@ -23,42 +22,12 @@ import (
|
|
|
"v2ray.com/core/transport/internet/udp"
|
|
"v2ray.com/core/transport/internet/udp"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
-type record struct {
|
|
|
|
|
- A *IPRecord
|
|
|
|
|
- AAAA *IPRecord
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-type IPRecord struct {
|
|
|
|
|
- IP []net.Address
|
|
|
|
|
- Expire time.Time
|
|
|
|
|
- RCode dnsmessage.RCode
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-func (r *IPRecord) getIPs() ([]net.Address, error) {
|
|
|
|
|
- if r == nil || r.Expire.Before(time.Now()) {
|
|
|
|
|
- return nil, errRecordNotFound
|
|
|
|
|
- }
|
|
|
|
|
- if r.RCode != dnsmessage.RCodeSuccess {
|
|
|
|
|
- return nil, dns_feature.RCodeError(r.RCode)
|
|
|
|
|
- }
|
|
|
|
|
- return r.IP, nil
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-type pendingRequest struct {
|
|
|
|
|
- domain string
|
|
|
|
|
- expire time.Time
|
|
|
|
|
- recType dnsmessage.Type
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-var (
|
|
|
|
|
- errRecordNotFound = errors.New("record not found")
|
|
|
|
|
-)
|
|
|
|
|
-
|
|
|
|
|
type ClassicNameServer struct {
|
|
type ClassicNameServer struct {
|
|
|
sync.RWMutex
|
|
sync.RWMutex
|
|
|
|
|
+ name string
|
|
|
address net.Destination
|
|
address net.Destination
|
|
|
ips map[string]record
|
|
ips map[string]record
|
|
|
- requests map[uint16]pendingRequest
|
|
|
|
|
|
|
+ requests map[uint16]dnsRequest
|
|
|
pub *pubsub.Service
|
|
pub *pubsub.Service
|
|
|
udpServer *udp.Dispatcher
|
|
udpServer *udp.Dispatcher
|
|
|
cleanup *task.Periodic
|
|
cleanup *task.Periodic
|
|
@@ -70,9 +39,10 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
|
|
|
s := &ClassicNameServer{
|
|
s := &ClassicNameServer{
|
|
|
address: address,
|
|
address: address,
|
|
|
ips: make(map[string]record),
|
|
ips: make(map[string]record),
|
|
|
- requests: make(map[uint16]pendingRequest),
|
|
|
|
|
|
|
+ requests: make(map[uint16]dnsRequest),
|
|
|
clientIP: clientIP,
|
|
clientIP: clientIP,
|
|
|
pub: pubsub.NewService(),
|
|
pub: pubsub.NewService(),
|
|
|
|
|
+ name: strings.ToUpper(address.String()),
|
|
|
}
|
|
}
|
|
|
s.cleanup = &task.Periodic{
|
|
s.cleanup = &task.Periodic{
|
|
|
Interval: time.Minute,
|
|
Interval: time.Minute,
|
|
@@ -83,7 +53,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (s *ClassicNameServer) Name() string {
|
|
func (s *ClassicNameServer) Name() string {
|
|
|
- return s.address.String()
|
|
|
|
|
|
|
+ return s.name
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (s *ClassicNameServer) Cleanup() error {
|
|
func (s *ClassicNameServer) Cleanup() error {
|
|
@@ -92,7 +62,7 @@ func (s *ClassicNameServer) Cleanup() error {
|
|
|
defer s.Unlock()
|
|
defer s.Unlock()
|
|
|
|
|
|
|
|
if len(s.ips) == 0 && len(s.requests) == 0 {
|
|
if len(s.ips) == 0 && len(s.requests) == 0 {
|
|
|
- return newError("nothing to do. stopping...")
|
|
|
|
|
|
|
+ return newError(s.name, " nothing to do. stopping...")
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
for domain, record := range s.ips {
|
|
for domain, record := range s.ips {
|
|
@@ -121,123 +91,52 @@ func (s *ClassicNameServer) Cleanup() error {
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if len(s.requests) == 0 {
|
|
if len(s.requests) == 0 {
|
|
|
- s.requests = make(map[uint16]pendingRequest)
|
|
|
|
|
|
|
+ s.requests = make(map[uint16]dnsRequest)
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
|
|
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
|
|
|
- payload := packet.Payload
|
|
|
|
|
|
|
|
|
|
- var parser dnsmessage.Parser
|
|
|
|
|
- header, err := parser.Start(payload.Bytes())
|
|
|
|
|
|
|
+ ipRec, err := parseResponse(packet.Payload.Bytes())
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
- newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog()
|
|
|
|
|
- return
|
|
|
|
|
- }
|
|
|
|
|
- if err := parser.SkipAllQuestions(); err != nil {
|
|
|
|
|
- newError("failed to skip questions in DNS response").Base(err).AtWarning().WriteToLog()
|
|
|
|
|
|
|
+ newError(s.name, " fail to parse responsed DNS udp").AtError().WriteToLog()
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- id := header.ID
|
|
|
|
|
s.Lock()
|
|
s.Lock()
|
|
|
- req, f := s.requests[id]
|
|
|
|
|
- if f {
|
|
|
|
|
|
|
+ id := ipRec.ReqID
|
|
|
|
|
+ req, ok := s.requests[id]
|
|
|
|
|
+ if ok {
|
|
|
|
|
+ // remove the pending request
|
|
|
delete(s.requests, id)
|
|
delete(s.requests, id)
|
|
|
}
|
|
}
|
|
|
s.Unlock()
|
|
s.Unlock()
|
|
|
-
|
|
|
|
|
- if !f {
|
|
|
|
|
|
|
+ if !ok {
|
|
|
|
|
+ newError(s.name, " cannot find the pending request").AtError().WriteToLog()
|
|
|
return
|
|
return
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- domain := req.domain
|
|
|
|
|
- recType := req.recType
|
|
|
|
|
-
|
|
|
|
|
- now := time.Now()
|
|
|
|
|
- ipRecord := &IPRecord{
|
|
|
|
|
- RCode: header.RCode,
|
|
|
|
|
- Expire: now.Add(time.Second * 600),
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
-L:
|
|
|
|
|
- for {
|
|
|
|
|
- header, err := parser.AnswerHeader()
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- if err != dnsmessage.ErrSectionDone {
|
|
|
|
|
- newError("failed to parse answer section for domain: ", domain).Base(err).WriteToLog()
|
|
|
|
|
- }
|
|
|
|
|
- break
|
|
|
|
|
- }
|
|
|
|
|
- ttl := header.TTL
|
|
|
|
|
- if ttl == 0 {
|
|
|
|
|
- ttl = 600
|
|
|
|
|
- }
|
|
|
|
|
- expire := now.Add(time.Duration(ttl) * time.Second)
|
|
|
|
|
- if ipRecord.Expire.After(expire) {
|
|
|
|
|
- ipRecord.Expire = expire
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if header.Type != recType {
|
|
|
|
|
- if err := parser.SkipAnswer(); err != nil {
|
|
|
|
|
- newError("failed to skip answer").Base(err).WriteToLog()
|
|
|
|
|
- break L
|
|
|
|
|
- }
|
|
|
|
|
- continue
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- switch header.Type {
|
|
|
|
|
- case dnsmessage.TypeA:
|
|
|
|
|
- ans, err := parser.AResource()
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
|
|
|
|
- break L
|
|
|
|
|
- }
|
|
|
|
|
- ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
|
|
|
|
|
- case dnsmessage.TypeAAAA:
|
|
|
|
|
- ans, err := parser.AAAAResource()
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
|
|
|
|
|
- break L
|
|
|
|
|
- }
|
|
|
|
|
- ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
|
|
|
|
|
- default:
|
|
|
|
|
- if err := parser.SkipAnswer(); err != nil {
|
|
|
|
|
- newError("failed to skip answer").Base(err).WriteToLog()
|
|
|
|
|
- break L
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
var rec record
|
|
var rec record
|
|
|
- switch recType {
|
|
|
|
|
|
|
+ switch req.reqType {
|
|
|
case dnsmessage.TypeA:
|
|
case dnsmessage.TypeA:
|
|
|
- rec.A = ipRecord
|
|
|
|
|
|
|
+ rec.A = ipRec
|
|
|
case dnsmessage.TypeAAAA:
|
|
case dnsmessage.TypeAAAA:
|
|
|
- rec.AAAA = ipRecord
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
|
|
|
|
|
- s.updateIP(domain, rec)
|
|
|
|
|
|
|
+ rec.AAAA = ipRec
|
|
|
}
|
|
}
|
|
|
-}
|
|
|
|
|
|
|
|
|
|
-func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
|
|
|
|
|
- if newRec == nil {
|
|
|
|
|
- return false
|
|
|
|
|
|
|
+ elapsed := time.Since(req.start)
|
|
|
|
|
+ newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
|
|
|
|
|
+ if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
|
|
|
|
|
+ s.updateIP(req.domain, rec)
|
|
|
}
|
|
}
|
|
|
- if baseRec == nil {
|
|
|
|
|
- return true
|
|
|
|
|
- }
|
|
|
|
|
- return baseRec.Expire.Before(newRec.Expire)
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (s *ClassicNameServer) updateIP(domain string, newRec record) {
|
|
func (s *ClassicNameServer) updateIP(domain string, newRec record) {
|
|
|
s.Lock()
|
|
s.Lock()
|
|
|
|
|
|
|
|
- newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
|
|
|
|
|
|
|
+ newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog()
|
|
|
rec := s.ips[domain]
|
|
rec := s.ips[domain]
|
|
|
|
|
|
|
|
updated := false
|
|
updated := false
|
|
@@ -259,116 +158,27 @@ func (s *ClassicNameServer) updateIP(domain string, newRec record) {
|
|
|
common.Must(s.cleanup.Start())
|
|
common.Must(s.cleanup.Start())
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
|
|
|
|
|
- if len(s.clientIP) == 0 {
|
|
|
|
|
- return nil
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- var netmask int
|
|
|
|
|
- var family uint16
|
|
|
|
|
-
|
|
|
|
|
- if len(s.clientIP) == 4 {
|
|
|
|
|
- family = 1
|
|
|
|
|
- netmask = 24 // 24 for IPV4, 96 for IPv6
|
|
|
|
|
- } else {
|
|
|
|
|
- family = 2
|
|
|
|
|
- netmask = 96
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- b := make([]byte, 4)
|
|
|
|
|
- binary.BigEndian.PutUint16(b[0:], family)
|
|
|
|
|
- b[2] = byte(netmask)
|
|
|
|
|
- b[3] = 0
|
|
|
|
|
- switch family {
|
|
|
|
|
- case 1:
|
|
|
|
|
- ip := s.clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
|
|
|
|
|
- needLength := (netmask + 8 - 1) / 8 // division rounding up
|
|
|
|
|
- b = append(b, ip[:needLength]...)
|
|
|
|
|
- case 2:
|
|
|
|
|
- ip := s.clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
|
|
|
|
|
- needLength := (netmask + 8 - 1) / 8 // division rounding up
|
|
|
|
|
- b = append(b, ip[:needLength]...)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- const EDNS0SUBNET = 0x08
|
|
|
|
|
-
|
|
|
|
|
- opt := new(dnsmessage.Resource)
|
|
|
|
|
- common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
|
|
|
|
|
-
|
|
|
|
|
- opt.Body = &dnsmessage.OPTResource{
|
|
|
|
|
- Options: []dnsmessage.Option{
|
|
|
|
|
- {
|
|
|
|
|
- Code: EDNS0SUBNET,
|
|
|
|
|
- Data: b,
|
|
|
|
|
- },
|
|
|
|
|
- },
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- return opt
|
|
|
|
|
|
|
+func (s *ClassicNameServer) newReqID() uint16 {
|
|
|
|
|
+ return uint16(atomic.AddUint32(&s.reqID, 1))
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
|
|
|
|
|
- id := uint16(atomic.AddUint32(&s.reqID, 1))
|
|
|
|
|
|
|
+func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {
|
|
|
s.Lock()
|
|
s.Lock()
|
|
|
defer s.Unlock()
|
|
defer s.Unlock()
|
|
|
|
|
|
|
|
- s.requests[id] = pendingRequest{
|
|
|
|
|
- domain: domain,
|
|
|
|
|
- expire: time.Now().Add(time.Second * 8),
|
|
|
|
|
- recType: recType,
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- return id
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmessage.Message {
|
|
|
|
|
- qA := dnsmessage.Question{
|
|
|
|
|
- Name: dnsmessage.MustNewName(domain),
|
|
|
|
|
- Type: dnsmessage.TypeA,
|
|
|
|
|
- Class: dnsmessage.ClassINET,
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- qAAAA := dnsmessage.Question{
|
|
|
|
|
- Name: dnsmessage.MustNewName(domain),
|
|
|
|
|
- Type: dnsmessage.TypeAAAA,
|
|
|
|
|
- Class: dnsmessage.ClassINET,
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- var msgs []*dnsmessage.Message
|
|
|
|
|
-
|
|
|
|
|
- if option.IPv4Enable {
|
|
|
|
|
- msg := new(dnsmessage.Message)
|
|
|
|
|
- msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
|
|
|
|
|
- msg.Header.RecursionDesired = true
|
|
|
|
|
- msg.Questions = []dnsmessage.Question{qA}
|
|
|
|
|
- if opt := s.getMsgOptions(); opt != nil {
|
|
|
|
|
- msg.Additionals = append(msg.Additionals, *opt)
|
|
|
|
|
- }
|
|
|
|
|
- msgs = append(msgs, msg)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- if option.IPv6Enable {
|
|
|
|
|
- msg := new(dnsmessage.Message)
|
|
|
|
|
- msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
|
|
|
|
|
- msg.Header.RecursionDesired = true
|
|
|
|
|
- msg.Questions = []dnsmessage.Question{qAAAA}
|
|
|
|
|
- if opt := s.getMsgOptions(); opt != nil {
|
|
|
|
|
- msg.Additionals = append(msg.Additionals, *opt)
|
|
|
|
|
- }
|
|
|
|
|
- msgs = append(msgs, msg)
|
|
|
|
|
- }
|
|
|
|
|
-
|
|
|
|
|
- return msgs
|
|
|
|
|
|
|
+ id := req.msg.ID
|
|
|
|
|
+ req.expire = time.Now().Add(time.Second * 8)
|
|
|
|
|
+ s.requests[id] = *req
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
|
|
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
|
|
|
- newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
|
|
|
|
|
-
|
|
|
|
|
- msgs := s.buildMsgs(domain, option)
|
|
|
|
|
|
|
+ newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
|
|
|
|
|
|
|
|
- for _, msg := range msgs {
|
|
|
|
|
- b, _ := dns.PackMessage(msg)
|
|
|
|
|
|
|
+ reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP))
|
|
|
|
|
|
|
|
|
|
+ for _, req := range reqs {
|
|
|
|
|
+ s.addPendingRequest(req)
|
|
|
|
|
+ b, _ := dns.PackMessage(req.msg)
|
|
|
udpCtx := context.Background()
|
|
udpCtx := context.Background()
|
|
|
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
|
if inbound := session.InboundFromContext(ctx); inbound != nil {
|
|
|
udpCtx = session.ContextWithInbound(udpCtx, inbound)
|
|
udpCtx = session.ContextWithInbound(udpCtx, inbound)
|
|
@@ -418,18 +228,13 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]
|
|
|
return nil, dns_feature.ErrEmptyResponse
|
|
return nil, dns_feature.ErrEmptyResponse
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func Fqdn(domain string) string {
|
|
|
|
|
- if len(domain) > 0 && domain[len(domain)-1] == '.' {
|
|
|
|
|
- return domain
|
|
|
|
|
- }
|
|
|
|
|
- return domain + "."
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
|
|
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
|
|
|
|
|
+
|
|
|
fqdn := Fqdn(domain)
|
|
fqdn := Fqdn(domain)
|
|
|
|
|
|
|
|
ips, err := s.findIPsForDomain(fqdn, option)
|
|
ips, err := s.findIPsForDomain(fqdn, option)
|
|
|
if err != errRecordNotFound {
|
|
if err != errRecordNotFound {
|
|
|
|
|
+ newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
|
|
|
return ips, err
|
|
return ips, err
|
|
|
}
|
|
}
|
|
|
|
|
|