server.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. package dns
  2. //go:generate errorgen
  3. import (
  4. "context"
  5. "sync"
  6. "time"
  7. "v2ray.com/core"
  8. "v2ray.com/core/common"
  9. "v2ray.com/core/common/net"
  10. "v2ray.com/core/common/session"
  11. "v2ray.com/core/common/strmatcher"
  12. "v2ray.com/core/features"
  13. "v2ray.com/core/features/dns"
  14. "v2ray.com/core/features/routing"
  15. )
  16. // Server is a DNS rely server.
  17. type Server struct {
  18. sync.Mutex
  19. hosts *StaticHosts
  20. clients []Client
  21. clientIP net.IP
  22. domainMatcher strmatcher.IndexMatcher
  23. domainIndexMap map[uint32]uint32
  24. tag string
  25. }
  26. // New creates a new DNS server with given configuration.
  27. func New(ctx context.Context, config *Config) (*Server, error) {
  28. server := &Server{
  29. clients: make([]Client, 0, len(config.NameServers)+len(config.NameServer)),
  30. tag: config.Tag,
  31. }
  32. if len(config.ClientIp) > 0 {
  33. if len(config.ClientIp) != 4 && len(config.ClientIp) != 16 {
  34. return nil, newError("unexpected IP length", len(config.ClientIp))
  35. }
  36. server.clientIP = net.IP(config.ClientIp)
  37. }
  38. hosts, err := NewStaticHosts(config.StaticHosts, config.Hosts)
  39. if err != nil {
  40. return nil, newError("failed to create hosts").Base(err)
  41. }
  42. server.hosts = hosts
  43. addNameServer := func(endpoint *net.Endpoint) int {
  44. address := endpoint.Address.AsAddress()
  45. if address.Family().IsDomain() && address.Domain() == "localhost" {
  46. server.clients = append(server.clients, NewLocalNameServer())
  47. } else {
  48. dest := endpoint.AsDestination()
  49. if dest.Network == net.Network_Unknown {
  50. dest.Network = net.Network_UDP
  51. }
  52. if dest.Network == net.Network_UDP {
  53. idx := len(server.clients)
  54. server.clients = append(server.clients, nil)
  55. common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) {
  56. server.clients[idx] = NewClassicNameServer(dest, d, server.clientIP)
  57. }))
  58. }
  59. }
  60. return len(server.clients) - 1
  61. }
  62. if len(config.NameServers) > 0 {
  63. features.PrintDeprecatedFeatureWarning("simple DNS server")
  64. }
  65. for _, destPB := range config.NameServers {
  66. addNameServer(destPB)
  67. }
  68. if len(config.NameServer) > 0 {
  69. domainMatcher := &strmatcher.MatcherGroup{}
  70. domainIndexMap := make(map[uint32]uint32)
  71. for _, ns := range config.NameServer {
  72. idx := addNameServer(ns.Address)
  73. for _, domain := range ns.PrioritizedDomain {
  74. matcher, err := toStrMatcher(domain.Type, domain.Domain)
  75. if err != nil {
  76. return nil, newError("failed to create prioritized domain").Base(err).AtWarning()
  77. }
  78. midx := domainMatcher.Add(matcher)
  79. domainIndexMap[midx] = uint32(idx)
  80. }
  81. }
  82. server.domainMatcher = domainMatcher
  83. server.domainIndexMap = domainIndexMap
  84. }
  85. if len(server.clients) == 0 {
  86. server.clients = append(server.clients, NewLocalNameServer())
  87. }
  88. return server, nil
  89. }
  90. // Type implements common.HasType.
  91. func (*Server) Type() interface{} {
  92. return dns.ClientType()
  93. }
  94. // Start implements common.Runnable.
  95. func (s *Server) Start() error {
  96. return nil
  97. }
  98. // Close implements common.Closable.
  99. func (s *Server) Close() error {
  100. return nil
  101. }
  102. func (s *Server) queryIPTimeout(client Client, domain string, option IPOption) ([]net.IP, error) {
  103. ctx, cancel := context.WithTimeout(context.Background(), time.Second*4)
  104. if len(s.tag) > 0 {
  105. ctx = session.ContextWithInbound(ctx, &session.Inbound{
  106. Tag: s.tag,
  107. })
  108. }
  109. ips, err := client.QueryIP(ctx, domain, option)
  110. cancel()
  111. return ips, err
  112. }
  113. // LookupIP implements dns.Client.
  114. func (s *Server) LookupIP(domain string) ([]net.IP, error) {
  115. return s.lookupIPInternal(domain, IPOption{
  116. IPv4Enable: true,
  117. IPv6Enable: true,
  118. })
  119. }
  120. // LookupIPv4 implements dns.IPv4Lookup.
  121. func (s *Server) LookupIPv4(domain string) ([]net.IP, error) {
  122. return s.lookupIPInternal(domain, IPOption{
  123. IPv4Enable: true,
  124. IPv6Enable: false,
  125. })
  126. }
  127. // LookupIPv6 implements dns.IPv6Lookup.
  128. func (s *Server) LookupIPv6(domain string) ([]net.IP, error) {
  129. return s.lookupIPInternal(domain, IPOption{
  130. IPv4Enable: false,
  131. IPv6Enable: true,
  132. })
  133. }
  134. func (s *Server) lookupStatic(domain string, option IPOption, depth int32) []net.Address {
  135. ips := s.hosts.LookupIP(domain, option)
  136. if ips == nil {
  137. return nil
  138. }
  139. if ips[0].Family().IsDomain() && depth < 5 {
  140. if newIPs := s.lookupStatic(ips[0].Domain(), option, depth+1); newIPs != nil {
  141. return newIPs
  142. }
  143. }
  144. return ips
  145. }
  146. func toNetIP(ips []net.Address) []net.IP {
  147. if len(ips) == 0 {
  148. return nil
  149. }
  150. netips := make([]net.IP, 0, len(ips))
  151. for _, ip := range ips {
  152. netips = append(netips, ip.IP())
  153. }
  154. return netips
  155. }
  156. func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, error) {
  157. ips := s.lookupStatic(domain, option, 0)
  158. if ips != nil && ips[0].Family().IsIP() {
  159. return toNetIP(ips), nil
  160. }
  161. if ips != nil && ips[0].Family().IsDomain() {
  162. newdomain := ips[0].Domain()
  163. newError("domain replaced: ", domain, " -> ", newdomain).WriteToLog()
  164. domain = newdomain
  165. }
  166. var lastErr error
  167. if s.domainMatcher != nil {
  168. idx := s.domainMatcher.Match(domain)
  169. if idx > 0 {
  170. ns := s.clients[s.domainIndexMap[idx]]
  171. ips, err := s.queryIPTimeout(ns, domain, option)
  172. if len(ips) > 0 {
  173. return ips, nil
  174. }
  175. if err != nil {
  176. newError("failed to lookup ip for domain ", domain, " at server ", ns.Name()).Base(err).WriteToLog()
  177. lastErr = err
  178. }
  179. }
  180. }
  181. for _, client := range s.clients {
  182. ips, err := s.queryIPTimeout(client, domain, option)
  183. if len(ips) > 0 {
  184. return ips, nil
  185. }
  186. if err != nil {
  187. newError("failed to lookup ip for domain ", domain, " at server ", client.Name()).Base(err).WriteToLog()
  188. lastErr = err
  189. }
  190. }
  191. return nil, newError("returning nil for domain ", domain).Base(lastErr)
  192. }