dnscommon.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. // +build !confonly
  2. package dns
  3. import (
  4. "encoding/binary"
  5. "strings"
  6. "time"
  7. "golang.org/x/net/dns/dnsmessage"
  8. "github.com/v2fly/v2ray-core/v4/common"
  9. "github.com/v2fly/v2ray-core/v4/common/errors"
  10. "github.com/v2fly/v2ray-core/v4/common/net"
  11. dns_feature "github.com/v2fly/v2ray-core/v4/features/dns"
  12. )
  // Fqdn returns domain in fully-qualified form, i.e. guaranteed to end with
// a '.'. A domain that already ends with '.' is returned unchanged; otherwise
// a single '.' is appended (so the empty string becomes ".").
func Fqdn(domain string) string {
	if !strings.HasSuffix(domain, ".") {
		domain += "."
	}
	return domain
}
  // record groups an A (IPv4) and an AAAA (IPv6) IPRecord, presumably the
// cached lookup results for one domain.
// NOTE(review): a nil field presumably means nothing is cached yet for that
// query type (getIPs and isNewer both tolerate nil records) — confirm.
type record struct {
	A    *IPRecord
	AAAA *IPRecord
}
  // IPRecord is a cacheable item for a resolved domain. parseResponse fills it
// from a single DNS response and getIPs reads it back.
type IPRecord struct {
	// ReqID is the message ID from the header of the response this record
	// was parsed from.
	ReqID uint16
	// IP holds the addresses collected from the response's A and AAAA answers.
	IP []net.Address
	// Expire is the instant after which getIPs treats the record as missing
	// and reports errRecordNotFound.
	Expire time.Time
	// RCode is the response code from the header; getIPs turns any
	// non-success code into an error instead of returning IP.
	RCode dnsmessage.RCode
}
  31. func (r *IPRecord) getIPs() ([]net.Address, error) {
  32. if r == nil || r.Expire.Before(time.Now()) {
  33. return nil, errRecordNotFound
  34. }
  35. if r.RCode != dnsmessage.RCodeSuccess {
  36. return nil, dns_feature.RCodeError(r.RCode)
  37. }
  38. return r.IP, nil
  39. }
  40. func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
  41. if newRec == nil {
  42. return false
  43. }
  44. if baseRec == nil {
  45. return true
  46. }
  47. return baseRec.Expire.Before(newRec.Expire)
  48. }
  // errRecordNotFound is returned by getIPs when there is no record or the
// record's Expire time has already passed.
var errRecordNotFound = errors.New("record not found")

// dnsRequest is one outgoing DNS query, as produced by buildReqMsgs.
type dnsRequest struct {
	// reqType is the question type; buildReqMsgs sets TypeA or TypeAAAA.
	reqType dnsmessage.Type
	// domain is the queried name exactly as passed to buildReqMsgs.
	domain string
	// start is when the request was built.
	start time.Time
	// expire is not set by buildReqMsgs.
	// NOTE(review): presumably filled in by the caller (e.g. a request
	// deadline) — confirm.
	expire time.Time
	// msg is the complete query message to be packed and sent.
	msg *dnsmessage.Message
}
  57. func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
  58. if len(clientIP) == 0 {
  59. return nil
  60. }
  61. var netmask int
  62. var family uint16
  63. if len(clientIP) == 4 {
  64. family = 1
  65. netmask = 24 // 24 for IPV4, 96 for IPv6
  66. } else {
  67. family = 2
  68. netmask = 96
  69. }
  70. b := make([]byte, 4)
  71. binary.BigEndian.PutUint16(b[0:], family)
  72. b[2] = byte(netmask)
  73. b[3] = 0
  74. switch family {
  75. case 1:
  76. ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
  77. needLength := (netmask + 8 - 1) / 8 // division rounding up
  78. b = append(b, ip[:needLength]...)
  79. case 2:
  80. ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
  81. needLength := (netmask + 8 - 1) / 8 // division rounding up
  82. b = append(b, ip[:needLength]...)
  83. }
  84. const EDNS0SUBNET = 0x08
  85. opt := new(dnsmessage.Resource)
  86. common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
  87. opt.Body = &dnsmessage.OPTResource{
  88. Options: []dnsmessage.Option{
  89. {
  90. Code: EDNS0SUBNET,
  91. Data: b,
  92. },
  93. },
  94. }
  95. return opt
  96. }
  97. func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
  98. qA := dnsmessage.Question{
  99. Name: dnsmessage.MustNewName(domain),
  100. Type: dnsmessage.TypeA,
  101. Class: dnsmessage.ClassINET,
  102. }
  103. qAAAA := dnsmessage.Question{
  104. Name: dnsmessage.MustNewName(domain),
  105. Type: dnsmessage.TypeAAAA,
  106. Class: dnsmessage.ClassINET,
  107. }
  108. var reqs []*dnsRequest
  109. now := time.Now()
  110. if option.IPv4Enable {
  111. msg := new(dnsmessage.Message)
  112. msg.Header.ID = reqIDGen()
  113. msg.Header.RecursionDesired = true
  114. msg.Questions = []dnsmessage.Question{qA}
  115. if reqOpts != nil {
  116. msg.Additionals = append(msg.Additionals, *reqOpts)
  117. }
  118. reqs = append(reqs, &dnsRequest{
  119. reqType: dnsmessage.TypeA,
  120. domain: domain,
  121. start: now,
  122. msg: msg,
  123. })
  124. }
  125. if option.IPv6Enable {
  126. msg := new(dnsmessage.Message)
  127. msg.Header.ID = reqIDGen()
  128. msg.Header.RecursionDesired = true
  129. msg.Questions = []dnsmessage.Question{qAAAA}
  130. if reqOpts != nil {
  131. msg.Additionals = append(msg.Additionals, *reqOpts)
  132. }
  133. reqs = append(reqs, &dnsRequest{
  134. reqType: dnsmessage.TypeAAAA,
  135. domain: domain,
  136. start: now,
  137. msg: msg,
  138. })
  139. }
  140. return reqs
  141. }
  142. // parseResponse parse DNS answers from the returned payload
  143. func parseResponse(payload []byte) (*IPRecord, error) {
  144. var parser dnsmessage.Parser
  145. h, err := parser.Start(payload)
  146. if err != nil {
  147. return nil, newError("failed to parse DNS response").Base(err).AtWarning()
  148. }
  149. if err := parser.SkipAllQuestions(); err != nil {
  150. return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning()
  151. }
  152. now := time.Now()
  153. ipRecord := &IPRecord{
  154. ReqID: h.ID,
  155. RCode: h.RCode,
  156. Expire: now.Add(time.Second * 600),
  157. }
  158. L:
  159. for {
  160. ah, err := parser.AnswerHeader()
  161. if err != nil {
  162. if err != dnsmessage.ErrSectionDone {
  163. newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog()
  164. }
  165. break
  166. }
  167. ttl := ah.TTL
  168. if ttl == 0 {
  169. ttl = 600
  170. }
  171. expire := now.Add(time.Duration(ttl) * time.Second)
  172. if ipRecord.Expire.After(expire) {
  173. ipRecord.Expire = expire
  174. }
  175. switch ah.Type {
  176. case dnsmessage.TypeA:
  177. ans, err := parser.AResource()
  178. if err != nil {
  179. newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
  180. break L
  181. }
  182. ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
  183. case dnsmessage.TypeAAAA:
  184. ans, err := parser.AAAAResource()
  185. if err != nil {
  186. newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
  187. break L
  188. }
  189. ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
  190. default:
  191. if err := parser.SkipAnswer(); err != nil {
  192. newError("failed to skip answer").Base(err).WriteToLog()
  193. break L
  194. }
  195. continue
  196. }
  197. }
  198. return ipRecord, nil
  199. }