dnscommon.go 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. // +build !confonly
  2. package dns
  3. import (
  4. "encoding/binary"
  5. "time"
  6. "golang.org/x/net/dns/dnsmessage"
  7. "v2ray.com/core/common"
  8. "v2ray.com/core/common/errors"
  9. "v2ray.com/core/common/net"
  10. dns_feature "v2ray.com/core/features/dns"
  11. )
  // Fqdn returns domain in fully-qualified form, i.e. ending with a dot.
  // A name that already ends in '.' is returned unchanged; otherwise a single
  // '.' is appended, so the empty string becomes ".".
  func Fqdn(domain string) string {
  	if n := len(domain); n == 0 || domain[n-1] != '.' {
  		return domain + "."
  	}
  	return domain
  }
  19. type record struct {
  20. A *IPRecord
  21. AAAA *IPRecord
  22. }
  23. // IPRecord is a cacheable item for a resolved domain
  24. type IPRecord struct {
  25. ReqID uint16
  26. IP []net.Address
  27. Expire time.Time
  28. RCode dnsmessage.RCode
  29. }
  30. func (r *IPRecord) getIPs() ([]net.Address, error) {
  31. if r == nil || r.Expire.Before(time.Now()) {
  32. return nil, errRecordNotFound
  33. }
  34. if r.RCode != dnsmessage.RCodeSuccess {
  35. return nil, dns_feature.RCodeError(r.RCode)
  36. }
  37. return r.IP, nil
  38. }
  39. func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
  40. if newRec == nil {
  41. return false
  42. }
  43. if baseRec == nil {
  44. return true
  45. }
  46. return baseRec.Expire.Before(newRec.Expire)
  47. }
  // errRecordNotFound signals that no usable record is available: getIPs
  // returns it for a nil or expired IPRecord.
  var errRecordNotFound = errors.New("record not found")
  51. type dnsRequest struct {
  52. reqType dnsmessage.Type
  53. domain string
  54. start time.Time
  55. expire time.Time
  56. msg *dnsmessage.Message
  57. }
  58. func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
  59. if len(clientIP) == 0 {
  60. return nil
  61. }
  62. var netmask int
  63. var family uint16
  64. if len(clientIP) == 4 {
  65. family = 1
  66. netmask = 24 // 24 for IPV4, 96 for IPv6
  67. } else {
  68. family = 2
  69. netmask = 96
  70. }
  71. b := make([]byte, 4)
  72. binary.BigEndian.PutUint16(b[0:], family)
  73. b[2] = byte(netmask)
  74. b[3] = 0
  75. switch family {
  76. case 1:
  77. ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
  78. needLength := (netmask + 8 - 1) / 8 // division rounding up
  79. b = append(b, ip[:needLength]...)
  80. case 2:
  81. ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
  82. needLength := (netmask + 8 - 1) / 8 // division rounding up
  83. b = append(b, ip[:needLength]...)
  84. }
  85. const EDNS0SUBNET = 0x08
  86. opt := new(dnsmessage.Resource)
  87. common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
  88. opt.Body = &dnsmessage.OPTResource{
  89. Options: []dnsmessage.Option{
  90. {
  91. Code: EDNS0SUBNET,
  92. Data: b,
  93. },
  94. },
  95. }
  96. return opt
  97. }
  98. func buildReqMsgs(domain string, option IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
  99. qA := dnsmessage.Question{
  100. Name: dnsmessage.MustNewName(domain),
  101. Type: dnsmessage.TypeA,
  102. Class: dnsmessage.ClassINET,
  103. }
  104. qAAAA := dnsmessage.Question{
  105. Name: dnsmessage.MustNewName(domain),
  106. Type: dnsmessage.TypeAAAA,
  107. Class: dnsmessage.ClassINET,
  108. }
  109. var reqs []*dnsRequest
  110. now := time.Now()
  111. if option.IPv4Enable {
  112. msg := new(dnsmessage.Message)
  113. msg.Header.ID = reqIDGen()
  114. msg.Header.RecursionDesired = true
  115. msg.Questions = []dnsmessage.Question{qA}
  116. if reqOpts != nil {
  117. msg.Additionals = append(msg.Additionals, *reqOpts)
  118. }
  119. reqs = append(reqs, &dnsRequest{
  120. reqType: dnsmessage.TypeA,
  121. domain: domain,
  122. start: now,
  123. msg: msg,
  124. })
  125. }
  126. if option.IPv6Enable {
  127. msg := new(dnsmessage.Message)
  128. msg.Header.ID = reqIDGen()
  129. msg.Header.RecursionDesired = true
  130. msg.Questions = []dnsmessage.Question{qAAAA}
  131. if reqOpts != nil {
  132. msg.Additionals = append(msg.Additionals, *reqOpts)
  133. }
  134. reqs = append(reqs, &dnsRequest{
  135. reqType: dnsmessage.TypeAAAA,
  136. domain: domain,
  137. start: now,
  138. msg: msg,
  139. })
  140. }
  141. return reqs
  142. }
  143. // parseResponse parse DNS answers from the returned payload
  144. func parseResponse(payload []byte) (*IPRecord, error) {
  145. var parser dnsmessage.Parser
  146. h, err := parser.Start(payload)
  147. if err != nil {
  148. return nil, newError("failed to parse DNS response").Base(err).AtWarning()
  149. }
  150. if err := parser.SkipAllQuestions(); err != nil {
  151. return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning()
  152. }
  153. now := time.Now()
  154. var ipRecExpire time.Time
  155. if h.RCode != dnsmessage.RCodeSuccess {
  156. // A default TTL, maybe a negtive cache
  157. ipRecExpire = now.Add(time.Second * 120)
  158. }
  159. ipRecord := &IPRecord{
  160. ReqID: h.ID,
  161. RCode: h.RCode,
  162. Expire: ipRecExpire,
  163. }
  164. L:
  165. for {
  166. ah, err := parser.AnswerHeader()
  167. if err != nil {
  168. if err != dnsmessage.ErrSectionDone {
  169. newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog()
  170. }
  171. break
  172. }
  173. switch ah.Type {
  174. case dnsmessage.TypeA:
  175. ans, err := parser.AResource()
  176. if err != nil {
  177. newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
  178. break L
  179. }
  180. ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
  181. case dnsmessage.TypeAAAA:
  182. ans, err := parser.AAAAResource()
  183. if err != nil {
  184. newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
  185. break L
  186. }
  187. ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
  188. default:
  189. if err := parser.SkipAnswer(); err != nil {
  190. newError("failed to skip answer").Base(err).WriteToLog()
  191. break L
  192. }
  193. continue
  194. }
  195. if ipRecord.Expire.IsZero() {
  196. ttl := ah.TTL
  197. if ttl < 600 {
  198. // at least 10 mins TTL
  199. ipRecord.Expire = now.Add(time.Minute * 10)
  200. } else {
  201. ipRecord.Expire = now.Add(time.Duration(ttl) * time.Second)
  202. }
  203. }
  204. }
  205. return ipRecord, nil
  206. }