dnscommon.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. // +build !confonly
  2. package dns
  3. import (
  4. "encoding/binary"
  5. "time"
  6. "golang.org/x/net/dns/dnsmessage"
  7. "v2ray.com/core/common"
  8. "v2ray.com/core/common/errors"
  9. "v2ray.com/core/common/net"
  10. dns_feature "v2ray.com/core/features/dns"
  11. )
  // Fqdn returns domain in fully-qualified form, i.e. ending with a trailing
  // dot. A domain that already ends with '.' is returned unchanged; otherwise
  // a '.' is appended (so the empty string becomes ".").
  func Fqdn(domain string) string {
  	if n := len(domain); n == 0 || domain[n-1] != '.' {
  		return domain + "."
  	}
  	return domain
  }
  18. type record struct {
  19. A *IPRecord
  20. AAAA *IPRecord
  21. }
  22. type IPRecord struct {
  23. ReqID uint16
  24. IP []net.Address
  25. Expire time.Time
  26. RCode dnsmessage.RCode
  27. }
  // getIPs returns the addresses stored in r.
  //
  // It returns errRecordNotFound when r is nil or its Expire time has passed,
  // and dns_feature.RCodeError(r.RCode) when the stored response code is not
  // RCodeSuccess. Otherwise it returns r.IP with a nil error.
  func (r *IPRecord) getIPs() ([]net.Address, error) {
  	switch {
  	case r == nil, r.Expire.Before(time.Now()):
  		// Comma cases are evaluated left to right and stop at the first
  		// match, so r.Expire is never read when r is nil.
  		return nil, errRecordNotFound
  	case r.RCode != dnsmessage.RCodeSuccess:
  		return nil, dns_feature.RCodeError(r.RCode)
  	default:
  		return r.IP, nil
  	}
  }
  37. func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
  38. if newRec == nil {
  39. return false
  40. }
  41. if baseRec == nil {
  42. return true
  43. }
  44. return baseRec.Expire.Before(newRec.Expire)
  45. }
  46. var (
  47. errRecordNotFound = errors.New("record not found")
  48. )
  49. type dnsRequest struct {
  50. reqType dnsmessage.Type
  51. domain string
  52. start time.Time
  53. expire time.Time
  54. msg *dnsmessage.Message
  55. }
  // genEDNS0Options builds an EDNS0 OPT pseudo-resource carrying an EDNS
  // Client Subnet option (option code 0x08) that describes clientIP.
  //
  // IPv4 addresses are announced with a /24 source prefix and IPv6 addresses
  // with a /96 source prefix. Only the bytes covering the prefix of the masked
  // address are sent. The OPT header advertises a 1350-byte UDP payload size
  // and sets the DNSSEC OK flag.
  //
  // It returns nil when clientIP is empty, or when it is neither an IPv4
  // address (in 4- or 16-byte form) nor a 16-byte IPv6 address. Callers can
  // then simply skip attaching the option.
  func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
  	if len(clientIP) == 0 {
  		return nil
  	}
  
  	var (
  		family  uint16 // ECS address family: 1 = IPv4, 2 = IPv6
  		netmask int    // source prefix length in bits
  		addr    net.IP // clientIP with host bits cleared
  	)
  	// Classify by content, not slice length. An IPv4 address held in its
  	// 16-byte IPv4-mapped form must still be sent as family 1, not as the
  	// meaningless IPv6 prefix ::ffff:0:0/96.
  	if ip4 := clientIP.To4(); ip4 != nil {
  		family, netmask = 1, 24
  		addr = ip4.Mask(net.CIDRMask(netmask, net.IPv4len*8))
  	} else if len(clientIP) == net.IPv6len {
  		family, netmask = 2, 96
  		addr = clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
  	} else {
  		// Malformed length: Mask would return nil and slicing it below
  		// would panic.
  		return nil
  	}
  
  	// Option payload: FAMILY (2 bytes), SOURCE PREFIX-LENGTH,
  	// SCOPE PREFIX-LENGTH (0 in queries), then the address truncated to
  	// ceil(netmask/8) bytes.
  	needLength := (netmask + 8 - 1) / 8 // division rounding up
  	b := make([]byte, 4, 4+needLength)
  	binary.BigEndian.PutUint16(b[0:], family)
  	b[2] = byte(netmask)
  	b[3] = 0
  	b = append(b, addr[:needLength]...)
  
  	const EDNS0SUBNET = 0x08
  	opt := new(dnsmessage.Resource)
  	// NOTE(review): 0xfe00 is wider than the 12-bit extended RCODE that
  	// SetEDNS0 expects, and queries normally carry EXTENDED-RCODE 0. It is
  	// kept unchanged to preserve the wire format — confirm the intent.
  	common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
  	opt.Body = &dnsmessage.OPTResource{
  		Options: []dnsmessage.Option{
  			{
  				Code: EDNS0SUBNET,
  				Data: b,
  			},
  		},
  	}
  	return opt
  }
  // buildReqMsgs prepares one DNS query per enabled address family for domain:
  // an A query when option.IPv4Enable is set, then an AAAA query when
  // option.IPv6Enable is set.
  //
  // Each query:
  //   - gets a fresh ID from reqIDGen (A first);
  //   - has RecursionDesired set;
  //   - carries a copy of *reqOpts in its additional section when reqOpts is
  //     non-nil.
  //
  // All returned requests share a single start time. dnsmessage.MustNewName
  // is applied to domain up front and panics on an invalid name.
  func buildReqMsgs(domain string, option IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
  	qName := dnsmessage.MustNewName(domain)
  	startedAt := time.Now()
  
  	// newRequest builds the request for a single question type.
  	newRequest := func(qType dnsmessage.Type) *dnsRequest {
  		m := &dnsmessage.Message{
  			Header: dnsmessage.Header{
  				ID:               reqIDGen(),
  				RecursionDesired: true,
  			},
  			Questions: []dnsmessage.Question{{
  				Name:  qName,
  				Type:  qType,
  				Class: dnsmessage.ClassINET,
  			}},
  		}
  		if reqOpts != nil {
  			m.Additionals = append(m.Additionals, *reqOpts)
  		}
  		return &dnsRequest{
  			reqType: qType,
  			domain:  domain,
  			start:   startedAt,
  			msg:     m,
  		}
  	}
  
  	var reqs []*dnsRequest
  	if option.IPv4Enable {
  		reqs = append(reqs, newRequest(dnsmessage.TypeA))
  	}
  	if option.IPv6Enable {
  		reqs = append(reqs, newRequest(dnsmessage.TypeAAAA))
  	}
  	return reqs
  }
  // parseResponse parses a DNS response payload into an IPRecord.
  //
  // The header ID and RCODE are copied into the record. Every A/AAAA answer is
  // appended to IP, and other answer types (e.g. CNAME) are skipped.
  //
  // Expire is set as follows:
  //   - non-success RCODE: 120 seconds from now (presumably a negative-cache
  //     lifetime); answers do not change it;
  //   - otherwise: from the TTL of the first A/AAAA answer, raised to at least
  //     10 minutes;
  //   - no A/AAAA answer at all: Expire stays the zero time, which getIPs
  //     treats as expired.
  //
  // A payload whose header or question section cannot be parsed yields an
  // error. A malformed answer is logged and ends parsing; the addresses
  // collected so far are still returned.
  func parseResponse(payload []byte) (*IPRecord, error) {
  	var parser dnsmessage.Parser
  	h, err := parser.Start(payload)
  	if err != nil {
  		return nil, newError("failed to parse DNS response").Base(err).AtWarning()
  	}
  	if err := parser.SkipAllQuestions(); err != nil {
  		return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning()
  	}
  
  	now := time.Now()
  	var ipRecExpire time.Time
  	if h.RCode != dnsmessage.RCodeSuccess {
  		// A default TTL, maybe a negative cache
  		ipRecExpire = now.Add(time.Second * 120)
  	}
  	ipRecord := &IPRecord{
  		ReqID:  h.ID,
  		RCode:  h.RCode,
  		Expire: ipRecExpire,
  	}
  
  	// Minimum lifetime applied to successful answers, whatever their TTL.
  	const minAnswerTTL = 10 * time.Minute
  
  L:
  	for {
  		ah, err := parser.AnswerHeader()
  		if err != nil {
  			if err != dnsmessage.ErrSectionDone {
  				// On error ah is the zero ResourceHeader, so it has no
  				// domain name worth reporting.
  				newError("failed to parse answer section").Base(err).WriteToLog()
  			}
  			break
  		}
  
  		switch ah.Type {
  		case dnsmessage.TypeA:
  			ans, err := parser.AResource()
  			if err != nil {
  				newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
  				break L
  			}
  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
  		case dnsmessage.TypeAAAA:
  			ans, err := parser.AAAAResource()
  			if err != nil {
  				newError("failed to parse AAAA record for domain: ", ah.Name).Base(err).WriteToLog()
  				break L
  			}
  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
  		default:
  			if err := parser.SkipAnswer(); err != nil {
  				newError("failed to skip answer").Base(err).WriteToLog()
  				break L
  			}
  			continue
  		}
  
  		// Only the first A/AAAA answer (or the negative-cache default
  		// above) determines the expiry.
  		if ipRecord.Expire.IsZero() {
  			ttl := time.Duration(ah.TTL) * time.Second
  			if ttl < minAnswerTTL {
  				ttl = minAnswerTTL
  			}
  			ipRecord.Expire = now.Add(ttl)
  		}
  	}
  	return ipRecord, nil
  }