Explorar o código

simplify classic dns server

Darien Raymond %!s(int64=7) %!d(string=hai) anos
pai
achega
9cfb2bfd51
Modificáronse 5 ficheiros con 282 adicións e 284 borrados
  1. 3 202
      app/dns/nameserver.go
  2. 11 74
      app/dns/server.go
  3. 229 0
      app/dns/udpns.go
  4. 16 8
      common/signal/notifier.go
  5. 23 0
      common/signal/notifier_test.go

+ 3 - 202
app/dns/nameserver.go

@@ -2,17 +2,9 @@ package dns
 
 import (
 	"context"
-	"sync"
 	"time"
 
-	"github.com/miekg/dns"
-	"v2ray.com/core"
-	"v2ray.com/core/common"
-	"v2ray.com/core/common/buf"
-	"v2ray.com/core/common/dice"
 	"v2ray.com/core/common/net"
-	"v2ray.com/core/common/task"
-	"v2ray.com/core/transport/internet/udp"
 )
 
 var (
@@ -29,203 +21,12 @@ type ARecord struct {
 }
 
 type NameServer interface {
-	QueryA(domain string) <-chan *ARecord
-}
-
-type PendingRequest struct {
-	expire   time.Time
-	response chan<- *ARecord
-}
-
-type UDPNameServer struct {
-	sync.Mutex
-	address   net.Destination
-	requests  map[uint16]*PendingRequest
-	udpServer *udp.Dispatcher
-	cleanup   *task.Periodic
-}
-
-func NewUDPNameServer(address net.Destination, dispatcher core.Dispatcher) *UDPNameServer {
-	s := &UDPNameServer{
-		address:   address,
-		requests:  make(map[uint16]*PendingRequest),
-		udpServer: udp.NewDispatcher(dispatcher),
-	}
-	s.cleanup = &task.Periodic{
-		Interval: time.Minute,
-		Execute:  s.Cleanup,
-	}
-	common.Must(s.cleanup.Start())
-	return s
-}
-
-func (s *UDPNameServer) Cleanup() error {
-	now := time.Now()
-	s.Lock()
-	for id, r := range s.requests {
-		if r.expire.Before(now) {
-			close(r.response)
-			delete(s.requests, id)
-		}
-	}
-	s.Unlock()
-	return nil
-}
-
-func (s *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
-	var id uint16
-	s.Lock()
-
-	for {
-		id = dice.RollUint16()
-		if _, found := s.requests[id]; found {
-			time.Sleep(time.Millisecond * 500)
-			continue
-		}
-		newError("add pending request id ", id).AtDebug().WriteToLog()
-		s.requests[id] = &PendingRequest{
-			expire:   time.Now().Add(time.Second * 8),
-			response: response,
-		}
-		break
-	}
-	s.Unlock()
-	return id
-}
-
-func (s *UDPNameServer) HandleResponse(payload *buf.Buffer) {
-	msg := new(dns.Msg)
-	err := msg.Unpack(payload.Bytes())
-	if err == dns.ErrTruncated {
-		newError("truncated message received. DNS server should still work. If you see anything abnormal, please submit an issue to v2ray-core.").AtWarning().WriteToLog()
-	} else if err != nil {
-		newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog()
-		return
-	}
-	record := &ARecord{
-		IPs: make([]net.IP, 0, 16),
-	}
-	id := msg.Id
-	ttl := uint32(3600) // an hour
-	newError("handling response for id ", id, " content: ", msg).AtDebug().WriteToLog()
-
-	s.Lock()
-	request, found := s.requests[id]
-	if !found {
-		s.Unlock()
-		return
-	}
-	delete(s.requests, id)
-	s.Unlock()
-
-	for _, rr := range msg.Answer {
-		switch rr := rr.(type) {
-		case *dns.A:
-			record.IPs = append(record.IPs, rr.A)
-			if rr.Hdr.Ttl < ttl {
-				ttl = rr.Hdr.Ttl
-			}
-		case *dns.AAAA:
-			record.IPs = append(record.IPs, rr.AAAA)
-			if rr.Hdr.Ttl < ttl {
-				ttl = rr.Hdr.Ttl
-			}
-		}
-	}
-	record.Expire = time.Now().Add(time.Second * time.Duration(ttl))
-
-	request.response <- record
-	close(request.response)
-}
-
-func (s *UDPNameServer) buildAMsg(domain string, id uint16) *dns.Msg {
-	msg := new(dns.Msg)
-	msg.Id = id
-	msg.RecursionDesired = true
-	msg.Question = []dns.Question{
-		{
-			Name:   dns.Fqdn(domain),
-			Qtype:  dns.TypeA,
-			Qclass: dns.ClassINET,
-		}}
-	if multiQuestionDNS[s.address.Address] {
-		msg.Question = append(msg.Question, dns.Question{
-			Name:   dns.Fqdn(domain),
-			Qtype:  dns.TypeAAAA,
-			Qclass: dns.ClassINET,
-		})
-	}
-
-	return msg
-}
-
-func msgToBuffer(msg *dns.Msg) (*buf.Buffer, error) {
-	buffer := buf.New()
-	if err := buffer.Reset(func(b []byte) (int, error) {
-		writtenBuffer, err := msg.PackBuffer(b)
-		return len(writtenBuffer), err
-	}); err != nil {
-		return nil, err
-	}
-	return buffer, nil
-}
-
-func (s *UDPNameServer) QueryA(domain string) <-chan *ARecord {
-	response := make(chan *ARecord, 1)
-	id := s.AssignUnusedID(response)
-
-	msg := s.buildAMsg(domain, id)
-	b, err := msgToBuffer(msg)
-	if err != nil {
-		newError("failed to build A query for domain ", domain).Base(err).WriteToLog()
-		s.Lock()
-		delete(s.requests, id)
-		s.Unlock()
-		close(response)
-		return response
-	}
-
-	ctx, cancel := context.WithCancel(context.Background())
-	s.udpServer.Dispatch(ctx, s.address, b, s.HandleResponse)
-
-	go func() {
-		for i := 0; i < 2; i++ {
-			time.Sleep(time.Second)
-			s.Lock()
-			_, found := s.requests[id]
-			s.Unlock()
-			if !found {
-				break
-			}
-			b, _ := msgToBuffer(msg)
-			s.udpServer.Dispatch(ctx, s.address, b, s.HandleResponse)
-		}
-		cancel()
-	}()
-
-	return response
+	QueryIP(ctx context.Context, domain string) ([]net.IP, error)
 }
 
 type LocalNameServer struct {
 }
 
-func (*LocalNameServer) QueryA(domain string) <-chan *ARecord {
-	response := make(chan *ARecord, 1)
-
-	go func() {
-		defer close(response)
-
-		ips, err := net.LookupIP(domain)
-		if err != nil {
-			newError("failed to lookup IPs for domain ", domain).Base(err).AtWarning().WriteToLog()
-			return
-		}
-
-		response <- &ARecord{
-			IPs:    ips,
-			Expire: time.Now().Add(time.Hour),
-		}
-	}()
-
-	return response
+func (*LocalNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) {
+	return net.LookupIP(domain)
 }

+ 11 - 74
app/dns/server.go

@@ -7,48 +7,24 @@ import (
 	"sync"
 	"time"
 
-	dnsmsg "github.com/miekg/dns"
 	"v2ray.com/core"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/task"
 )
 
-const (
-	QueryTimeout = time.Second * 8
-)
-
-type DomainRecord struct {
-	IP         []net.IP
-	Expire     time.Time
-	LastAccess time.Time
-}
-
-func (r *DomainRecord) Expired() bool {
-	return r.Expire.Before(time.Now())
-}
-
 type Server struct {
 	sync.Mutex
 	hosts   map[string]net.IP
-	records map[string]*DomainRecord
 	servers []NameServer
 	task    *task.Periodic
 }
 
 func New(ctx context.Context, config *Config) (*Server, error) {
 	server := &Server{
-		records: make(map[string]*DomainRecord),
 		servers: make([]NameServer, len(config.NameServers)),
 		hosts:   config.GetInternalHosts(),
 	}
-	server.task = &task.Periodic{
-		Interval: time.Minute * 10,
-		Execute: func() error {
-			server.cleanup()
-			return nil
-		},
-	}
 	v := core.MustFromContext(ctx)
 	if err := v.RegisterFeature((*core.DNSClient)(nil), server); err != nil {
 		return nil, newError("unable to register DNSClient.").Base(err)
@@ -64,7 +40,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 				dest.Network = net.Network_UDP
 			}
 			if dest.Network == net.Network_UDP {
-				server.servers[idx] = NewUDPNameServer(dest, v.Dispatcher())
+				server.servers[idx] = NewClassicNameServer(dest, v.Dispatcher())
 			}
 		}
 	}
@@ -85,64 +61,25 @@ func (s *Server) Close() error {
 	return s.task.Close()
 }
 
-func (s *Server) GetCached(domain string) []net.IP {
-	s.Lock()
-	defer s.Unlock()
-
-	if record, found := s.records[domain]; found && !record.Expired() {
-		record.LastAccess = time.Now()
-		return record.IP
-	}
-	return nil
-}
-
-func (s *Server) cleanup() {
-	s.Lock()
-	defer s.Unlock()
-
-	for d, r := range s.records {
-		if r.Expired() {
-			delete(s.records, d)
-		}
-	}
-
-	if len(s.records) == 0 {
-		s.records = make(map[string]*DomainRecord)
-	}
-}
-
 func (s *Server) LookupIP(domain string) ([]net.IP, error) {
 	if ip, found := s.hosts[domain]; found {
 		return []net.IP{ip}, nil
 	}
 
-	domain = dnsmsg.Fqdn(domain)
-	ips := s.GetCached(domain)
-	if ips != nil {
-		return ips, nil
-	}
-
+	var lastErr error
 	for _, server := range s.servers {
-		response := server.QueryA(domain)
-		select {
-		case a, open := <-response:
-			if !open || a == nil {
-				continue
-			}
-			s.Lock()
-			s.records[domain] = &DomainRecord{
-				IP:         a.IPs,
-				Expire:     a.Expire,
-				LastAccess: time.Now(),
-			}
-			s.Unlock()
-			newError("returning ", len(a.IPs), " IPs for domain ", domain).AtDebug().WriteToLog()
-			return a.IPs, nil
-		case <-time.After(QueryTimeout):
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second*4)
+		ips, err := server.QueryIP(ctx, domain)
+		cancel()
+		if err != nil {
+			lastErr = err
+		}
+		if len(ips) > 0 {
+			return ips, nil
 		}
 	}
 
-	return nil, newError("returning nil for domain ", domain)
+	return nil, newError("returning nil for domain ", domain).Base(lastErr)
 }
 
 func init() {

+ 229 - 0
app/dns/udpns.go

@@ -0,0 +1,229 @@
+package dns
+
+import (
+	"context"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/miekg/dns"
+	"v2ray.com/core"
+	"v2ray.com/core/common"
+	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/signal"
+	"v2ray.com/core/common/task"
+	"v2ray.com/core/transport/internet/udp"
+)
+
+type IPRecord struct {
+	IP     net.IP
+	Expire time.Time
+}
+
+type ClassicNameServer struct {
+	sync.RWMutex
+	address   net.Destination
+	ips       map[string][]IPRecord
+	updated   signal.Notifier
+	udpServer *udp.Dispatcher
+	cleanup   *task.Periodic
+	reqID     uint32
+}
+
+func NewClassicNameServer(address net.Destination, dispatcher core.Dispatcher) *ClassicNameServer {
+	s := &ClassicNameServer{
+		address:   address,
+		ips:       make(map[string][]IPRecord),
+		udpServer: udp.NewDispatcher(dispatcher),
+	}
+	s.cleanup = &task.Periodic{
+		Interval: time.Minute,
+		Execute:  s.Cleanup,
+	}
+	common.Must(s.cleanup.Start())
+	return s
+}
+
+func (s *ClassicNameServer) Cleanup() error {
+	now := time.Now()
+	s.Lock()
+	for domain, ips := range s.ips {
+		newIPs := make([]IPRecord, 0, len(ips))
+		for _, ip := range ips {
+			if ip.Expire.After(now) {
+				newIPs = append(newIPs, ip)
+			}
+		}
+		if len(newIPs) == 0 {
+			delete(s.ips, domain)
+		} else if len(newIPs) < len(ips) {
+			s.ips[domain] = newIPs
+		}
+	}
+
+	if len(s.ips) == 0 {
+		s.ips = make(map[string][]IPRecord)
+	}
+
+	s.Unlock()
+	return nil
+}
+
+func (s *ClassicNameServer) HandleResponse(payload *buf.Buffer) {
+	msg := new(dns.Msg)
+	err := msg.Unpack(payload.Bytes())
+	if err == dns.ErrTruncated {
+		newError("truncated message received. DNS server should still work. If you see anything abnormal, please submit an issue to v2ray-core.").AtWarning().WriteToLog()
+	} else if err != nil {
+		newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog()
+		return
+	}
+
+	var domain string
+	ips := make([]IPRecord, 0, 16)
+
+	now := time.Now()
+	for _, rr := range msg.Answer {
+		var ip net.IP
+		domain = rr.Header().Name
+		ttl := rr.Header().Ttl
+		switch rr := rr.(type) {
+		case *dns.A:
+			ip = rr.A
+		case *dns.AAAA:
+			ip = rr.AAAA
+		}
+		if ttl == 0 {
+			ttl = 300
+		}
+		if len(ip) > 0 {
+			ips = append(ips, IPRecord{
+				IP:     ip,
+				Expire: now.Add(time.Second * time.Duration(ttl)),
+			})
+		}
+	}
+
+	if len(domain) > 0 && len(ips) > 0 {
+		s.updateIP(domain, ips)
+	}
+}
+
+func (s *ClassicNameServer) updateIP(domain string, ips []IPRecord) {
+	s.Lock()
+	defer s.Unlock()
+
+	newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
+	now := time.Now()
+	eips := s.ips[domain]
+	for _, ip := range eips {
+		if ip.Expire.After(now) {
+			ips = append(ips, ip)
+		}
+	}
+	s.ips[domain] = ips
+	s.updated.Signal()
+}
+
+func (s *ClassicNameServer) buildMsgs(domain string) []*dns.Msg {
+	allowMulti := multiQuestionDNS[s.address.Address]
+
+	var msgs []*dns.Msg
+
+	{
+		msg := new(dns.Msg)
+		msg.Id = uint16(atomic.AddUint32(&s.reqID, 1))
+		msg.RecursionDesired = true
+		msg.Question = []dns.Question{
+			{
+				Name:   domain,
+				Qtype:  dns.TypeA,
+				Qclass: dns.ClassINET,
+			}}
+		if allowMulti {
+			msg.Question = append(msg.Question, dns.Question{
+				Name:   domain,
+				Qtype:  dns.TypeAAAA,
+				Qclass: dns.ClassINET,
+			})
+		}
+		msgs = append(msgs, msg)
+	}
+
+	if !allowMulti {
+		msg := new(dns.Msg)
+		msg.Id = uint16(atomic.AddUint32(&s.reqID, 1))
+		msg.RecursionDesired = true
+		msg.Question = []dns.Question{
+			{
+				Name:   domain,
+				Qtype:  dns.TypeAAAA,
+				Qclass: dns.ClassINET,
+			},
+		}
+		msgs = append(msgs, msg)
+	}
+
+	return msgs
+}
+
+func msgToBuffer(msg *dns.Msg) (*buf.Buffer, error) {
+	buffer := buf.New()
+	if err := buffer.Reset(func(b []byte) (int, error) {
+		writtenBuffer, err := msg.PackBuffer(b)
+		return len(writtenBuffer), err
+	}); err != nil {
+		return nil, err
+	}
+	return buffer, nil
+}
+
+func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string) {
+	msgs := s.buildMsgs(domain)
+
+	for _, msg := range msgs {
+		b, err := msgToBuffer(msg)
+		common.Must(err)
+		s.udpServer.Dispatch(ctx, s.address, b, s.HandleResponse)
+	}
+}
+
+func (s *ClassicNameServer) findIPsForDomain(domain string) []net.IP {
+	records, found := s.ips[domain]
+	if found && len(records) > 0 {
+		var ips []net.IP
+		now := time.Now()
+		for _, rec := range records {
+			if rec.Expire.After(now) {
+				ips = append(ips, rec.IP)
+			}
+		}
+		return ips
+	}
+	return nil
+}
+
+func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string) ([]net.IP, error) {
+	fqdn := dns.Fqdn(domain)
+
+	ips := s.findIPsForDomain(fqdn)
+	if len(ips) > 0 {
+		return ips, nil
+	}
+
+	s.sendQuery(ctx, fqdn)
+
+	for {
+		ips := s.findIPsForDomain(fqdn)
+		if len(ips) > 0 {
+			return ips, nil
+		}
+
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-s.updated.Wait():
+		}
+	}
+}

+ 16 - 8
common/signal/notifier.go

@@ -1,26 +1,34 @@
 package signal
 
+import "sync"
+
 // Notifier is a utility for notifying changes. The change producer may notify changes multiple time, and the consumer may get notified asynchronously.
 type Notifier struct {
-	c chan struct{}
+	sync.Mutex
+	waiters []chan struct{}
 }
 
 // NewNotifier creates a new Notifier.
 func NewNotifier() *Notifier {
-	return &Notifier{
-		c: make(chan struct{}, 1),
-	}
+	return &Notifier{}
 }
 
 // Signal signals a change, usually by producer. This method never blocks.
 func (n *Notifier) Signal() {
-	select {
-	case n.c <- struct{}{}:
-	default:
+	n.Lock()
+	for _, w := range n.waiters {
+		close(w)
 	}
+	n.waiters = make([]chan struct{}, 0, 8)
+	n.Unlock()
 }
 
 // Wait returns a channel for waiting for changes. The returned channel never gets closed.
 func (n *Notifier) Wait() <-chan struct{} {
-	return n.c
+	n.Lock()
+	defer n.Unlock()
+
+	w := make(chan struct{})
+	n.waiters = append(n.waiters, w)
+	return w
 }

+ 23 - 0
common/signal/notifier_test.go

@@ -0,0 +1,23 @@
+package signal_test
+
+import (
+	"testing"
+
+	. "v2ray.com/core/common/signal"
+	//. "v2ray.com/ext/assert"
+)
+
+func TestNotifierSignal(t *testing.T) {
+	//assert := With(t)
+
+	var n Notifier
+
+	w := n.Wait()
+	n.Signal()
+
+	select {
+	case <-w:
+	default:
+		t.Fail()
+	}
+}