udpns.go 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. // +build !confonly
  2. package dns
  3. import (
  4. "context"
  5. "encoding/binary"
  6. fmt "fmt"
  7. "sync"
  8. "sync/atomic"
  9. "time"
  10. "golang.org/x/net/dns/dnsmessage"
  11. "v2ray.com/core/common"
  12. "v2ray.com/core/common/errors"
  13. "v2ray.com/core/common/net"
  14. "v2ray.com/core/common/protocol/dns"
  15. udp_proto "v2ray.com/core/common/protocol/udp"
  16. "v2ray.com/core/common/session"
  17. "v2ray.com/core/common/signal/pubsub"
  18. "v2ray.com/core/common/task"
  19. dns_feature "v2ray.com/core/features/dns"
  20. "v2ray.com/core/features/routing"
  21. "v2ray.com/core/transport/internet/udp"
  22. )
  23. type record struct {
  24. A *IPRecord
  25. AAAA *IPRecord
  26. }
  27. type IPRecord struct {
  28. IP []net.Address
  29. Expire time.Time
  30. RCode dnsmessage.RCode
  31. }
  // getIPs returns the cached addresses. It is safe to call on a nil record:
  // a nil or expired record yields errRecordNotFound, and a record whose
  // response code is not RCodeSuccess yields dns_feature.RCodeError.
  func (r *IPRecord) getIPs() ([]net.Address, error) {
  	switch {
  	case r == nil, time.Now().After(r.Expire):
  		return nil, errRecordNotFound
  	case r.RCode != dnsmessage.RCodeSuccess:
  		return nil, dns_feature.RCodeError(r.RCode)
  	default:
  		return r.IP, nil
  	}
  }
  41. type pendingRequest struct {
  42. domain string
  43. expire time.Time
  44. recType dnsmessage.Type
  45. }
  46. var (
  47. errRecordNotFound = errors.New("record not found")
  48. )
  49. type ClassicNameServer struct {
  50. sync.RWMutex
  51. address net.Destination
  52. ips map[string]record
  53. requests map[uint16]pendingRequest
  54. pub *pubsub.Service
  55. udpServer *udp.Dispatcher
  56. cleanup *task.Periodic
  57. reqID uint32
  58. clientIP net.IP
  59. }
  // NewClassicNameServer returns a ClassicNameServer that queries address,
  // dispatching its UDP traffic through dispatcher. A non-empty clientIP is
  // advertised upstream via the EDNS Client Subnet option. The periodic
  // cleanup task is created here but only started by updateIP.
  func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer {
  	ns := new(ClassicNameServer)
  	ns.address = address
  	ns.ips = make(map[string]record)
  	ns.requests = make(map[uint16]pendingRequest)
  	ns.clientIP = clientIP
  	ns.pub = pubsub.NewService()
  	ns.cleanup = &task.Periodic{Interval: time.Minute, Execute: ns.Cleanup}
  	ns.udpServer = udp.NewDispatcher(dispatcher, ns.HandleResponse)
  	return ns
  }
  75. func (s *ClassicNameServer) Name() string {
  76. return s.address.String()
  77. }
  // Cleanup drops expired A/AAAA records and expired pending requests. It is
  // installed as the Execute hook of s.cleanup. When there is nothing left to
  // track it returns an error, which presumably stops the periodic task until
  // updateIP starts it again (see task.Periodic).
  func (s *ClassicNameServer) Cleanup() error {
  	now := time.Now()

  	s.Lock()
  	defer s.Unlock()

  	if len(s.ips) == 0 && len(s.requests) == 0 {
  		return newError("nothing to do. stopping...")
  	}

  	for domain, rec := range s.ips {
  		if rec.A != nil && rec.A.Expire.Before(now) {
  			rec.A = nil
  		}
  		if rec.AAAA != nil && rec.AAAA.Expire.Before(now) {
  			rec.AAAA = nil
  		}
  		if rec.A != nil || rec.AAAA != nil {
  			s.ips[domain] = rec
  			continue
  		}
  		delete(s.ips, domain)
  	}
  	// Go maps never shrink; swapping in a fresh map lets the old buckets be
  	// garbage-collected.
  	if len(s.ips) == 0 {
  		s.ips = make(map[string]record)
  	}

  	for id, req := range s.requests {
  		if req.expire.Before(now) {
  			delete(s.requests, id)
  		}
  	}
  	if len(s.requests) == 0 {
  		s.requests = make(map[uint16]pendingRequest)
  	}
  	return nil
  }
  // HandleResponse is the udp.Dispatcher callback for replies from the
  // upstream server. It matches the reply to a pending request by message ID,
  // collects the addresses of the requested record type from the answer
  // section, and merges them into the cache via updateIP. Replies that fail to
  // parse or match no pending request are dropped.
  //
  // The cache lifetime is the smallest TTL among all answers (a TTL of 0 counts
  // as 600s), capped at 600 seconds.
  func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
  	var parser dnsmessage.Parser
  	header, err := parser.Start(packet.Payload.Bytes())
  	if err != nil {
  		newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog()
  		return
  	}
  	if err := parser.SkipAllQuestions(); err != nil {
  		newError("failed to skip questions in DNS response").Base(err).AtWarning().WriteToLog()
  		return
  	}

  	id := header.ID
  	s.Lock()
  	req, found := s.requests[id]
  	if found {
  		delete(s.requests, id)
  	}
  	s.Unlock()
  	if !found {
  		return
  	}

  	domain := req.domain
  	recType := req.recType
  	now := time.Now()
  	ipRecord := &IPRecord{
  		RCode:  header.RCode,
  		Expire: now.Add(time.Second * 600),
  	}

  	// Every answer body must be consumed (parsed or skipped) before the next
  	// AnswerHeader call: dnsmessage.Parser keeps returning the same header
  	// until the body is consumed, so an unconsumed answer would spin this loop
  	// forever. For the same reason a body that fails to parse ends the section.
  answers:
  	for {
  		rh, err := parser.AnswerHeader()
  		if err != nil {
  			if err != dnsmessage.ErrSectionDone {
  				newError("failed to parse answer section for domain: ", domain).Base(err).WriteToLog()
  			}
  			break
  		}

  		ttl := rh.TTL
  		if ttl == 0 {
  			ttl = 600
  		}
  		if expire := now.Add(time.Duration(ttl) * time.Second); ipRecord.Expire.After(expire) {
  			ipRecord.Expire = expire
  		}

  		// Answers of other types (e.g. CNAME chains) are skipped, not parsed.
  		if rh.Type != recType {
  			if err := parser.SkipAnswer(); err != nil {
  				newError("failed to skip answer").Base(err).WriteToLog()
  				break answers
  			}
  			continue
  		}

  		switch rh.Type {
  		case dnsmessage.TypeA:
  			ans, err := parser.AResource()
  			if err != nil {
  				newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
  				break answers
  			}
  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
  		case dnsmessage.TypeAAAA:
  			ans, err := parser.AAAAResource()
  			if err != nil {
  				newError("failed to parse AAAA record for domain: ", domain).Base(err).WriteToLog()
  				break answers
  			}
  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
  		default:
  			if err := parser.SkipAnswer(); err != nil {
  				newError("failed to skip answer").Base(err).WriteToLog()
  				break answers
  			}
  		}
  	}

  	var rec record
  	switch recType {
  	case dnsmessage.TypeA:
  		rec.A = ipRecord
  	case dnsmessage.TypeAAAA:
  		rec.AAAA = ipRecord
  	}
  	if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
  		s.updateIP(domain, rec)
  	}
  }
  191. func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
  192. if newRec == nil {
  193. return false
  194. }
  195. if baseRec == nil {
  196. return true
  197. }
  198. return baseRec.Expire.Before(newRec.Expire)
  199. }
  200. func (s *ClassicNameServer) updateIP(domain string, newRec record) {
  201. s.Lock()
  202. newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
  203. rec := s.ips[domain]
  204. updated := false
  205. if isNewer(rec.A, newRec.A) {
  206. rec.A = newRec.A
  207. updated = true
  208. }
  209. if isNewer(rec.AAAA, newRec.AAAA) {
  210. rec.AAAA = newRec.AAAA
  211. updated = true
  212. }
  213. if updated {
  214. s.ips[domain] = rec
  215. s.pub.Publish(domain, nil)
  216. }
  217. s.Unlock()
  218. common.Must(s.cleanup.Start())
  219. }
  // getMsgOptions builds the EDNS0 OPT pseudo-record attached to outgoing
  // queries. It carries an EDNS Client Subnet (ECS) option derived from
  // s.clientIP, truncated to /24 for IPv4 and /96 for IPv6. Only the address
  // bytes covered by that prefix are sent.
  //
  // It returns nil when no client IP is configured, or when s.clientIP is
  // neither a valid IPv4 address (4-byte or 16-byte v4-mapped form) nor a
  // 16-byte IPv6 address.
  func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
  	if len(s.clientIP) == 0 {
  		return nil
  	}

  	// ECS FAMILY values: 1 = IPv4, 2 = IPv6.
  	var (
  		family  uint16
  		netmask int
  		addr    net.IP
  	)
  	// To4 also recognises IPv4 addresses held in 16-byte v4-mapped form,
  	// which a plain length check would misclassify as IPv6.
  	switch ip4 := s.clientIP.To4(); {
  	case ip4 != nil:
  		family, netmask, addr = 1, 24, ip4
  	case len(s.clientIP) == net.IPv6len:
  		family, netmask, addr = 2, 96, s.clientIP
  	default:
  		// Any other length would make Mask return nil and the slicing below panic.
  		newError("invalid client IP for EDNS client subnet: ", s.clientIP.String()).AtWarning().WriteToLog()
  		return nil
  	}
  	addr = addr.Mask(net.CIDRMask(netmask, len(addr)*8))

  	prefixBytes := (netmask + 7) / 8 // bytes needed to hold the prefix, rounded up
  	data := make([]byte, 4, 4+prefixBytes)
  	binary.BigEndian.PutUint16(data[0:], family)
  	data[2] = byte(netmask) // SOURCE PREFIX-LENGTH
  	data[3] = 0             // SCOPE PREFIX-LENGTH, zero in queries
  	data = append(data, addr[:prefixBytes]...)

  	const edns0Subnet = 0x08 // EDNS option code for Client Subnet
  	opt := new(dnsmessage.Resource)
  	// NOTE(review): the 0xfe00 extended-RCODE argument and the DNSSEC OK flag
  	// are carried over unchanged; confirm both are intended for queries.
  	common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
  	opt.Body = &dnsmessage.OPTResource{
  		Options: []dnsmessage.Option{{Code: edns0Subnet, Data: data}},
  	}
  	return opt
  }
  // addPendingRequest allocates the next message ID and registers a query for
  // domain of type recType that expires 8 seconds from now. The 32-bit counter
  // is truncated to 16 bits, so IDs wrap around after 65536 requests.
  func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
  	next := atomic.AddUint32(&s.reqID, 1)
  	id := uint16(next)

  	s.Lock()
  	s.requests[id] = pendingRequest{
  		domain:  domain,
  		recType: recType,
  		expire:  time.Now().Add(8 * time.Second),
  	}
  	s.Unlock()

  	return id
  }
  271. func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmessage.Message {
  272. qA := dnsmessage.Question{
  273. Name: dnsmessage.MustNewName(domain),
  274. Type: dnsmessage.TypeA,
  275. Class: dnsmessage.ClassINET,
  276. }
  277. qAAAA := dnsmessage.Question{
  278. Name: dnsmessage.MustNewName(domain),
  279. Type: dnsmessage.TypeAAAA,
  280. Class: dnsmessage.ClassINET,
  281. }
  282. var msgs []*dnsmessage.Message
  283. if option.IPv4Enable {
  284. msg := new(dnsmessage.Message)
  285. msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
  286. msg.Header.RecursionDesired = true
  287. msg.Questions = []dnsmessage.Question{qA}
  288. if opt := s.getMsgOptions(); opt != nil {
  289. msg.Additionals = append(msg.Additionals, *opt)
  290. }
  291. msgs = append(msgs, msg)
  292. }
  293. if option.IPv6Enable {
  294. msg := new(dnsmessage.Message)
  295. msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
  296. msg.Header.RecursionDesired = true
  297. msg.Questions = []dnsmessage.Question{qAAAA}
  298. if opt := s.getMsgOptions(); opt != nil {
  299. msg.Additionals = append(msg.Additionals, *opt)
  300. }
  301. msgs = append(msgs, msg)
  302. }
  303. return msgs
  304. }
  // sendQuery builds the query messages for domain (see buildMsgs) and hands
  // each packed message to the UDP dispatcher for s.address. Only the inbound
  // session from ctx, if any, is carried over to a fresh background context;
  // ctx itself is not used for dispatch.
  //
  // A message that fails to pack (e.g. a label longer than 63 bytes) is logged
  // and skipped rather than panicking. Its pending entry stays in s.requests
  // until Cleanup expires it.
  func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
  	newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))

  	// Loop-invariant: the dispatch context is the same for every message.
  	udpCtx := context.Background()
  	if inbound := session.InboundFromContext(ctx); inbound != nil {
  		udpCtx = session.ContextWithInbound(udpCtx, inbound)
  	}

  	for _, msg := range s.buildMsgs(domain, option) {
  		b, err := dns.PackMessage(msg)
  		if err != nil {
  			newError("failed to pack DNS query for: ", domain).Base(err).AtWarning().WriteToLog(session.ExportIDToError(ctx))
  			continue
  		}
  		s.udpServer.Dispatch(udpCtx, s.address, b)
  	}
  }
  // findIPsForDomain returns the cached addresses for domain, limited to the
  // families enabled in option.
  //
  // If the domain has no cache entry, it returns errRecordNotFound. If no
  // addresses were found, it returns the error from the last enabled family
  // that failed: errRecordNotFound for a nil/expired record, or an RCodeError
  // for a failed answer. If every enabled family holds a valid but empty
  // record, it returns dns_feature.ErrEmptyResponse.
  func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
  	s.RLock()
  	rec, found := s.ips[domain]
  	s.RUnlock()
  	if !found {
  		return nil, errRecordNotFound
  	}

  	var ips []net.Address
  	var lastErr error
  	if option.IPv4Enable {
  		a, err := rec.A.getIPs()
  		if err != nil {
  			lastErr = err
  		}
  		ips = append(ips, a...)
  	}
  	if option.IPv6Enable {
  		aaaa, err := rec.AAAA.getIPs()
  		if err != nil {
  			lastErr = err
  		}
  		ips = append(ips, aaaa...)
  	}

  	// Logged at debug level so the output honours the configured log level
  	// instead of being written unconditionally to stdout.
  	newError("IPs for ", domain, ": ", fmt.Sprint(ips)).AtDebug().WriteToLog()

  	if len(ips) > 0 {
  		return toNetIP(ips), nil
  	}
  	if lastErr != nil {
  		return nil, lastErr
  	}
  	return nil, dns_feature.ErrEmptyResponse
  }
  // Fqdn returns domain in fully-qualified form, i.e. ending with a dot. Names
  // that already end in "." are returned unchanged, and the empty string
  // becomes ".".
  func Fqdn(domain string) string {
  	if n := len(domain); n == 0 || domain[n-1] != '.' {
  		return domain + "."
  	}
  	return domain
  }
  // QueryIP resolves domain for the address families enabled in option. A
  // cache lookup that yields anything other than a miss is returned at once;
  // this includes a cached failure. Otherwise a query is sent upstream and
  // QueryIP waits for cache updates until a lookup stops missing or ctx is done.
  func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
  	name := Fqdn(domain)

  	if cached, err := s.findIPsForDomain(name, option); err != errRecordNotFound {
  		return cached, err
  	}

  	// Subscribe before sending so an update published for a fast reply is not missed.
  	sub := s.pub.Subscribe(name)
  	defer sub.Close()
  	s.sendQuery(ctx, name, option)

  	for {
  		found, err := s.findIPsForDomain(name, option)
  		if err != errRecordNotFound {
  			return found, err
  		}
  		select {
  		case <-sub.Wait():
  		case <-ctx.Done():
  			return nil, ctx.Err()
  		}
  	}
  }