Jelajahi Sumber

dns client implementation

v2ray 9 tahun lalu
induk
melakukan
3b545abe02

+ 11 - 13
app/dns/dns.go

@@ -11,33 +11,31 @@ const (
 )
 
 // A DnsCache is an internal cache of DNS resolutions.
-type DnsCache interface {
-	Get(domain string) net.IP
-	Add(domain string, ip net.IP)
+type Server interface {
+	Get(domain string) []net.IP
 }
 
-type dnsCacheWithContext interface {
-	Get(context app.Context, domain string) net.IP
-	Add(contaxt app.Context, domain string, ip net.IP)
+type dnsServerWithContext interface {
+	Get(context app.Context, domain string) []net.IP
 }
 
-type contextedDnsCache struct {
+type contextedDnsServer struct {
 	context  app.Context
-	dnsCache dnsCacheWithContext
+	dnsCache dnsServerWithContext
 }
 
-func (this *contextedDnsCache) Get(domain string) net.IP {
+func (this *contextedDnsServer) Get(domain string) []net.IP {
 	return this.dnsCache.Get(this.context, domain)
 }
 
-func (this *contextedDnsCache) Add(domain string, ip net.IP) {
-	this.dnsCache.Add(this.context, domain, ip)
+func CreateDNSServer(rawConfig interface{}) (Server, error) {
+	return nil, nil
 }
 
 func init() {
 	app.Register(APP_ID, func(context app.Context, obj interface{}) interface{} {
-		dcContext := obj.(dnsCacheWithContext)
-		return &contextedDnsCache{
+		dcContext := obj.(dnsServerWithContext)
+		return &contextedDnsServer{
 			context:  context,
 			dnsCache: dcContext,
 		}

+ 3 - 8
app/dns/internal/config.go

@@ -1,14 +1,9 @@
 package internal
 
 import (
-	"github.com/v2ray/v2ray-core/common/serial"
+	v2net "github.com/v2ray/v2ray-core/common/net"
 )
 
-type CacheConfig struct {
-	TrustedTags map[serial.StringLiteral]bool
-}
-
-func (this *CacheConfig) IsTrustedSource(tag serial.StringLiteral) bool {
-	_, found := this.TrustedTags[tag]
-	return found
+type Config struct {
+	NameServers []v2net.Destination
 }

+ 11 - 9
app/dns/internal/config_json.go

@@ -5,19 +5,21 @@ package internal
 import (
 	"encoding/json"
 
-	"github.com/v2ray/v2ray-core/common/serial"
+	v2net "github.com/v2ray/v2ray-core/common/net"
 )
 
-func (this *CacheConfig) UnmarshalJSON(data []byte) error {
-	var strlist serial.StringLiteralList
-	if err := json.Unmarshal(data, strlist); err != nil {
-		return err
+func (this *Config) UnmarshalJSON(data []byte) error {
+	type JsonConfig struct {
+		Servers []v2net.Address `json:"servers"`
 	}
-	config := &CacheConfig{
-		TrustedTags: make(map[serial.StringLiteral]bool, strlist.Len()),
+	jsonConfig := new(JsonConfig)
+	if err := json.Unmarshal(data, jsonConfig); err != nil {
+		return err
 	}
-	for _, str := range strlist {
-		config.TrustedTags[str.TrimSpace()] = true
+	this.NameServers = make([]v2net.Destination, len(jsonConfig.Servers))
+	for idx, server := range jsonConfig.Servers {
+		this.NameServers[idx] = v2net.UDPDestination(server, v2net.Port(53))
 	}
+
 	return nil
 }

+ 48 - 37
app/dns/internal/dns.go

@@ -2,61 +2,72 @@ package internal
 
 import (
 	"net"
+	"sync"
 	"time"
 
 	"github.com/v2ray/v2ray-core/app"
-	"github.com/v2ray/v2ray-core/common/collect"
-	"github.com/v2ray/v2ray-core/common/serial"
+	"github.com/v2ray/v2ray-core/app/dispatcher"
+
+	"github.com/miekg/dns"
 )
 
-type entry struct {
-	domain     string
-	ip         net.IP
-	validUntil time.Time
-}
+const (
+	QueryTimeout = time.Second * 2
+)
 
-func newEntry(domain string, ip net.IP) *entry {
-	this := &entry{
-		domain: domain,
-		ip:     ip,
-	}
-	this.Extend()
-	return this
+type DomainRecord struct {
+	A *ARecord
 }
 
-func (this *entry) IsValid() bool {
-	return this.validUntil.After(time.Now())
+type Server struct {
+	sync.RWMutex
+	records map[string]*DomainRecord
+	servers []NameServer
 }
 
-func (this *entry) Extend() {
-	this.validUntil = time.Now().Add(time.Hour)
+func NewServer(space app.Space, config *Config) *Server {
+	server := &Server{
+		records: make(map[string]*DomainRecord),
+		servers: make([]NameServer, len(config.NameServers)),
+	}
+	dispatcher := space.GetApp(dispatcher.APP_ID).(dispatcher.PacketDispatcher)
+	for idx, ns := range config.NameServers {
+		server.servers[idx] = NewUDPNameServer(ns, dispatcher)
+	}
+	return server
 }
 
-type DnsCache struct {
-	cache  *collect.ValidityMap
-	config *CacheConfig
-}
+//@Private
+func (this *Server) GetCached(domain string) []net.IP {
+	this.RLock()
+	defer this.RUnlock()
 
-func NewCache(config *CacheConfig) *DnsCache {
-	cache := &DnsCache{
-		cache:  collect.NewValidityMap(3600),
-		config: config,
+	if record, found := this.records[domain]; found && record.A.Expire.After(time.Now()) {
+		return record.A.IPs
 	}
-	return cache
+	return nil
 }
 
-func (this *DnsCache) Add(context app.Context, domain string, ip net.IP) {
-	callerTag := context.CallerTag()
-	if !this.config.IsTrustedSource(serial.StringLiteral(callerTag)) {
-		return
+func (this *Server) Get(context app.Context, domain string) []net.IP {
+	domain = dns.Fqdn(domain)
+	ips := this.GetCached(domain)
+	if ips != nil {
+		return ips
 	}
 
-	this.cache.Set(serial.StringLiteral(domain), newEntry(domain, ip))
-}
-
-func (this *DnsCache) Get(context app.Context, domain string) net.IP {
-	if value := this.cache.Get(serial.StringLiteral(domain)); value != nil {
-		return value.(*entry).ip
+	for _, server := range this.servers {
+		response := server.QueryA(domain)
+		select {
+		case a := <-response:
+			this.Lock()
+			this.records[domain] = &DomainRecord{
+				A: a,
+			}
+			this.Unlock()
+			return a.IPs
+		case <-time.Tick(QueryTimeout):
+		}
 	}
+
 	return nil
 }

+ 40 - 14
app/dns/internal/dns_test.go

@@ -4,30 +4,56 @@ import (
 	"net"
 	"testing"
 
+	"github.com/v2ray/v2ray-core/app"
+	"github.com/v2ray/v2ray-core/app/dispatcher"
 	. "github.com/v2ray/v2ray-core/app/dns/internal"
 	apptesting "github.com/v2ray/v2ray-core/app/testing"
+	v2net "github.com/v2ray/v2ray-core/common/net"
 	netassert "github.com/v2ray/v2ray-core/common/net/testing/assert"
-	"github.com/v2ray/v2ray-core/common/serial"
+	"github.com/v2ray/v2ray-core/proxy/freedom"
 	v2testing "github.com/v2ray/v2ray-core/testing"
+	"github.com/v2ray/v2ray-core/testing/assert"
+	"github.com/v2ray/v2ray-core/transport/ray"
 )
 
+type TestDispatcher struct {
+	freedom *freedom.FreedomConnection
+}
+
+func (this *TestDispatcher) DispatchToOutbound(context app.Context, dest v2net.Destination) ray.InboundRay {
+	direct := ray.NewRay()
+
+	go func() {
+		payload, err := direct.OutboundInput().Read()
+		if err != nil {
+			direct.OutboundInput().Release()
+			direct.OutboundOutput().Release()
+			return
+		}
+		this.freedom.Dispatch(dest, payload, direct)
+	}()
+	return direct
+}
+
 func TestDnsAdd(t *testing.T) {
 	v2testing.Current(t)
 
+	d := &TestDispatcher{
+		freedom: &freedom.FreedomConnection{},
+	}
+	spaceController := app.NewController()
+	spaceController.Bind(dispatcher.APP_ID, d)
+	space := spaceController.ForContext("test")
+
 	domain := "v2ray.com"
-	cache := NewCache(&CacheConfig{
-		TrustedTags: map[serial.StringLiteral]bool{
-			serial.StringLiteral("testtag"): true,
+	server := NewServer(space, &Config{
+		NameServers: []v2net.Destination{
+			v2net.UDPDestination(v2net.IPAddress([]byte{8, 8, 8, 8}), v2net.Port(53)),
 		},
 	})
-	ip := cache.Get(&apptesting.Context{}, domain)
-	netassert.IP(ip).IsNil()
-
-	cache.Add(&apptesting.Context{CallerTagValue: "notvalidtag"}, domain, []byte{1, 2, 3, 4})
-	ip = cache.Get(&apptesting.Context{}, domain)
-	netassert.IP(ip).IsNil()
-
-	cache.Add(&apptesting.Context{CallerTagValue: "testtag"}, domain, []byte{1, 2, 3, 4})
-	ip = cache.Get(&apptesting.Context{}, domain)
-	netassert.IP(ip).Equals(net.IP([]byte{1, 2, 3, 4}))
+	ips := server.Get(&apptesting.Context{
+		CallerTagValue: "a",
+	}, domain)
+	assert.Int(len(ips)).Equals(2)
+	netassert.IP(ips[0].To4()).Equals(net.IP([]byte{104, 27, 154, 107}))
 }

+ 158 - 0
app/dns/internal/nameserver.go

@@ -0,0 +1,158 @@
+package internal
+
+import (
+	"math/rand"
+	"net"
+	"sync"
+	"time"
+
+	"github.com/v2ray/v2ray-core/app/dispatcher"
+	"github.com/v2ray/v2ray-core/common/alloc"
+	"github.com/v2ray/v2ray-core/common/log"
+	v2net "github.com/v2ray/v2ray-core/common/net"
+	"github.com/v2ray/v2ray-core/transport/hub"
+
+	"github.com/miekg/dns"
+)
+
+const (
+	DefaultTTL = uint32(3600)
+)
+
+type ARecord struct {
+	IPs    []net.IP
+	Expire time.Time
+}
+
+type NameServer interface {
+	QueryA(domain string) <-chan *ARecord
+}
+
+type PendingRequest struct {
+	expire   time.Time
+	response chan<- *ARecord
+}
+
+type UDPNameServer struct {
+	sync.Mutex
+	address   v2net.Destination
+	requests  map[uint16]*PendingRequest
+	udpServer *hub.UDPServer
+}
+
+func NewUDPNameServer(address v2net.Destination, dispatcher dispatcher.PacketDispatcher) *UDPNameServer {
+	s := &UDPNameServer{
+		address:   address,
+		requests:  make(map[uint16]*PendingRequest),
+		udpServer: hub.NewUDPServer(dispatcher),
+	}
+	go s.Cleanup()
+	return s
+}
+
+// @Private
+func (this *UDPNameServer) Cleanup() {
+	for {
+		time.Sleep(time.Second * 60)
+		expiredRequests := make([]uint16, 0, 16)
+		now := time.Now()
+		this.Lock()
+		for id, r := range this.requests {
+			if r.expire.Before(now) {
+				expiredRequests = append(expiredRequests, id)
+				close(r.response)
+			}
+		}
+		for _, id := range expiredRequests {
+			delete(this.requests, id)
+		}
+		this.Unlock()
+		expiredRequests = nil
+	}
+}
+
+// @Private
+func (this *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
+	var id uint16
+	this.Lock()
+	for {
+		id = uint16(rand.Intn(65536))
+		if _, found := this.requests[id]; found {
+			continue
+		}
+		log.Debug("DNS: Add pending request id ", id)
+		this.requests[id] = &PendingRequest{
+			expire:   time.Now().Add(time.Second * 16),
+			response: response,
+		}
+		break
+	}
+	this.Unlock()
+	return id
+}
+
+// @Private
+func (this *UDPNameServer) HandleResponse(dest v2net.Destination, payload *alloc.Buffer) {
+	msg := new(dns.Msg)
+	err := msg.Unpack(payload.Value)
+	if err != nil {
+		log.Warning("DNS: Failed to parse DNS response: ", err)
+		return
+	}
+	record := &ARecord{
+		IPs: make([]net.IP, 0, 16),
+	}
+	id := msg.Id
+	ttl := DefaultTTL
+
+	this.Lock()
+	request, found := this.requests[id]
+	if !found {
+		this.Unlock()
+		return
+	}
+	delete(this.requests, id)
+	this.Unlock()
+
+	for _, rr := range msg.Answer {
+		if a, ok := rr.(*dns.A); ok {
+			record.IPs = append(record.IPs, a.A)
+			if a.Hdr.Ttl < ttl {
+				ttl = a.Hdr.Ttl
+			}
+		}
+	}
+	record.Expire = time.Now().Add(time.Second * time.Duration(ttl))
+
+	request.response <- record
+	close(request.response)
+}
+
+func (this *UDPNameServer) QueryA(domain string) <-chan *ARecord {
+	response := make(chan *ARecord)
+
+	buffer := alloc.NewBuffer()
+	msg := new(dns.Msg)
+	msg.Id = this.AssignUnusedID(response)
+	msg.RecursionDesired = true
+	msg.Question = []dns.Question{
+		dns.Question{
+			Name:   dns.Fqdn(domain),
+			Qtype:  dns.TypeA,
+			Qclass: dns.ClassINET,
+		},
+		dns.Question{
+			Name:   dns.Fqdn(domain),
+			Qtype:  dns.TypeAAAA,
+			Qclass: dns.ClassINET,
+		},
+	}
+
+	writtenBuffer, _ := msg.PackBuffer(buffer.Value)
+	buffer.Slice(0, len(writtenBuffer))
+
+	fakeDestination := v2net.UDPDestination(v2net.LocalHostIP, v2net.Port(53))
+	this.udpServer.Dispatch(fakeDestination, this.address, buffer, this.HandleResponse)
+
+	return response
+}

+ 4 - 0
common/net/address.go

@@ -7,6 +7,10 @@ import (
 	"github.com/v2ray/v2ray-core/common/serial"
 )
 
+var (
+	LocalHostIP = IPAddress([]byte{127, 0, 0, 1})
+)
+
 // Address represents a network address to be communicated with. It may be an IP address or domain
 // address, not both. This interface doesn't resolve IP address for a given domain.
 type Address interface {

+ 11 - 9
proxy/dokodemo/dokodemo.go

@@ -95,15 +95,17 @@ func (this *DokodemoDoor) ListenUDP(port v2net.Port) error {
 }
 
 func (this *DokodemoDoor) handleUDPPackets(payload *alloc.Buffer, dest v2net.Destination) {
-	this.udpServer.Dispatch(dest, v2net.UDPDestination(this.address, this.port), payload, func(destination v2net.Destination, payload *alloc.Buffer) {
-		defer payload.Release()
-		this.udpMutex.RLock()
-		defer this.udpMutex.RUnlock()
-		if !this.accepting {
-			return
-		}
-		this.udpHub.WriteTo(payload.Value, destination)
-	})
+	this.udpServer.Dispatch(dest, v2net.UDPDestination(this.address, this.port), payload, this.handleUDPResponse)
+}
+
+func (this *DokodemoDoor) handleUDPResponse(dest v2net.Destination, payload *alloc.Buffer) {
+	defer payload.Release()
+	this.udpMutex.RLock()
+	defer this.udpMutex.RUnlock()
+	if !this.accepting {
+		return
+	}
+	this.udpHub.WriteTo(payload.Value, dest)
 }
 
 func (this *DokodemoDoor) ListenTCP(port v2net.Port) error {