Browse Source

DNS: add DNS over QUIC support (#534)

maskedeken 4 years ago
parent
commit
8c5b392f27
3 changed files with 415 additions and 0 deletions
  1. 2 0
      app/dns/nameserver.go
  2. 384 0
      app/dns/nameserver_quic.go
  3. 29 0
      app/dns/nameserver_quic_test.go

+ 2 - 0
app/dns/nameserver.go

@@ -53,6 +53,8 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher) (Server, err
 			return NewDoHNameServer(u, dispatcher)
 			return NewDoHNameServer(u, dispatcher)
 		case u.Scheme == "https+local": // DOH Local mode
 		case u.Scheme == "https+local": // DOH Local mode
 			return NewDoHLocalNameServer(u), nil
 			return NewDoHLocalNameServer(u), nil
+		case u.Scheme == "quic+local": // DNS-over-QUIC Local mode
+			return NewQUICNameServer(u)
 		}
 		}
 	}
 	}
 	if dest.Network == net.Network_Unknown {
 	if dest.Network == net.Network_Unknown {

+ 384 - 0
app/dns/nameserver_quic.go

@@ -0,0 +1,384 @@
+// +build !confonly
+
+package dns
+
+import (
+	"context"
+	"net/url"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"github.com/lucas-clemente/quic-go"
+	"golang.org/x/net/dns/dnsmessage"
+	"golang.org/x/net/http2"
+	"v2ray.com/core/common"
+	"v2ray.com/core/common/buf"
+	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/protocol/dns"
+	"v2ray.com/core/common/session"
+	"v2ray.com/core/common/signal/pubsub"
+	"v2ray.com/core/common/task"
+	dns_feature "v2ray.com/core/features/dns"
+	"v2ray.com/core/transport/internet/tls"
+)
+
+// NextProtoDQ - During connection establishment, DNS/QUIC support is indicated
+// by selecting the ALPN token "dq" in the crypto handshake.
+const NextProtoDQ = "doq-i00"
+
+const handshakeTimeout = time.Second * 8
+
+// QUICNameServer implemented DNS over QUIC
+type QUICNameServer struct {
+	sync.RWMutex
+	ips         map[string]record
+	pub         *pubsub.Service
+	cleanup     *task.Periodic
+	reqID       uint32
+	name        string
+	destination net.Destination
+	session     quic.Session
+}
+
+// NewQUICNameServer creates DNS-over-QUIC client object for local resolving
+func NewQUICNameServer(url *url.URL) (*QUICNameServer, error) {
+	newError("DNS: created Local DNS-over-QUIC client for ", url.String()).AtInfo().WriteToLog()
+
+	var err error
+	port := net.Port(784)
+	if url.Port() != "" {
+		port, err = net.PortFromString(url.Port())
+		if err != nil {
+			return nil, err
+		}
+	}
+	dest := net.UDPDestination(net.DomainAddress(url.Hostname()), port)
+
+	s := &QUICNameServer{
+		ips:         make(map[string]record),
+		pub:         pubsub.NewService(),
+		name:        url.String(),
+		destination: dest,
+	}
+	s.cleanup = &task.Periodic{
+		Interval: time.Minute,
+		Execute:  s.Cleanup,
+	}
+
+	return s, nil
+}
+
+// Name returns client name
+func (s *QUICNameServer) Name() string {
+	return s.name
+}
+
+// Cleanup clears expired items from cache
+func (s *QUICNameServer) Cleanup() error {
+	now := time.Now()
+	s.Lock()
+	defer s.Unlock()
+
+	if len(s.ips) == 0 {
+		return newError("nothing to do. stopping...")
+	}
+
+	for domain, record := range s.ips {
+		if record.A != nil && record.A.Expire.Before(now) {
+			record.A = nil
+		}
+		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
+			record.AAAA = nil
+		}
+
+		if record.A == nil && record.AAAA == nil {
+			newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
+			delete(s.ips, domain)
+		} else {
+			s.ips[domain] = record
+		}
+	}
+
+	if len(s.ips) == 0 {
+		s.ips = make(map[string]record)
+	}
+
+	return nil
+}
+
+func (s *QUICNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
+	elapsed := time.Since(req.start)
+
+	s.Lock()
+	rec := s.ips[req.domain]
+	updated := false
+
+	switch req.reqType {
+	case dnsmessage.TypeA:
+		if isNewer(rec.A, ipRec) {
+			rec.A = ipRec
+			updated = true
+		}
+	case dnsmessage.TypeAAAA:
+		addr := make([]net.Address, 0)
+		for _, ip := range ipRec.IP {
+			if len(ip.IP()) == net.IPv6len {
+				addr = append(addr, ip)
+			}
+		}
+		ipRec.IP = addr
+		if isNewer(rec.AAAA, ipRec) {
+			rec.AAAA = ipRec
+			updated = true
+		}
+	}
+	newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
+
+	if updated {
+		s.ips[req.domain] = rec
+	}
+	switch req.reqType {
+	case dnsmessage.TypeA:
+		s.pub.Publish(req.domain+"4", nil)
+	case dnsmessage.TypeAAAA:
+		s.pub.Publish(req.domain+"6", nil)
+	}
+	s.Unlock()
+	common.Must(s.cleanup.Start())
+}
+
+func (s *QUICNameServer) newReqID() uint16 {
+	return uint16(atomic.AddUint32(&s.reqID, 1))
+}
+
+func (s *QUICNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option IPOption) {
+	newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx))
+
+	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))
+
+	var deadline time.Time
+	if d, ok := ctx.Deadline(); ok {
+		deadline = d
+	} else {
+		deadline = time.Now().Add(time.Second * 5)
+	}
+
+	for _, req := range reqs {
+		go func(r *dnsRequest) {
+			// generate new context for each req, using same context
+			// may cause reqs all aborted if any one encounter an error
+			dnsCtx := context.Background()
+
+			// reserve internal dns server requested Inbound
+			if inbound := session.InboundFromContext(ctx); inbound != nil {
+				dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
+			}
+
+			dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
+				Protocol:      "quic",
+				SkipRoutePick: true,
+			})
+
+			var cancel context.CancelFunc
+			dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline)
+			defer cancel()
+
+			b, err := dns.PackMessage(r.msg)
+			if err != nil {
+				newError("failed to pack dns query").Base(err).AtError().WriteToLog()
+				return
+			}
+
+			conn, err := s.openStream(dnsCtx)
+			if err != nil {
+				newError("failed to open quic session").Base(err).AtError().WriteToLog()
+				return
+			}
+
+			_, err = conn.Write(b.Bytes())
+			if err != nil {
+				newError("failed to send query").Base(err).AtError().WriteToLog()
+				return
+			}
+
+			_ = conn.Close()
+
+			respBuf := buf.New()
+			defer respBuf.Release()
+			n, err := respBuf.ReadFrom(conn)
+			if err != nil && n == 0 {
+				newError("failed to read response").Base(err).AtError().WriteToLog()
+				return
+			}
+
+			rec, err := parseResponse(respBuf.Bytes())
+			if err != nil {
+				newError("failed to handle response").Base(err).AtError().WriteToLog()
+				return
+			}
+			s.updateIP(r, rec)
+		}(req)
+	}
+}
+
+func (s *QUICNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
+	s.RLock()
+	record, found := s.ips[domain]
+	s.RUnlock()
+
+	if !found {
+		return nil, errRecordNotFound
+	}
+
+	var ips []net.Address
+	var lastErr error
+	if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess {
+		aaaa, err := record.AAAA.getIPs()
+		if err != nil {
+			lastErr = err
+		}
+		ips = append(ips, aaaa...)
+	}
+
+	if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess {
+		a, err := record.A.getIPs()
+		if err != nil {
+			lastErr = err
+		}
+		ips = append(ips, a...)
+	}
+
+	if len(ips) > 0 {
+		return toNetIP(ips)
+	}
+
+	if lastErr != nil {
+		return nil, lastErr
+	}
+
+	if (option.IPv4Enable && record.A != nil) || (option.IPv6Enable && record.AAAA != nil) {
+		return nil, dns_feature.ErrEmptyResponse
+	}
+
+	return nil, errRecordNotFound
+}
+
+// QueryIP is called from dns.Server->queryIPTimeout
+func (s *QUICNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option IPOption) ([]net.IP, error) {
+	fqdn := Fqdn(domain)
+
+	ips, err := s.findIPsForDomain(fqdn, option)
+	if err != errRecordNotFound {
+		newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
+		return ips, err
+	}
+
+	// ipv4 and ipv6 belong to different subscription groups
+	var sub4, sub6 *pubsub.Subscriber
+	if option.IPv4Enable {
+		sub4 = s.pub.Subscribe(fqdn + "4")
+		defer sub4.Close()
+	}
+	if option.IPv6Enable {
+		sub6 = s.pub.Subscribe(fqdn + "6")
+		defer sub6.Close()
+	}
+	done := make(chan interface{})
+	go func() {
+		if sub4 != nil {
+			select {
+			case <-sub4.Wait():
+			case <-ctx.Done():
+			}
+		}
+		if sub6 != nil {
+			select {
+			case <-sub6.Wait():
+			case <-ctx.Done():
+			}
+		}
+		close(done)
+	}()
+	s.sendQuery(ctx, fqdn, clientIP, option)
+
+	for {
+		ips, err := s.findIPsForDomain(fqdn, option)
+		if err != errRecordNotFound {
+			return ips, err
+		}
+
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-done:
+		}
+	}
+}
+
+func isActive(s quic.Session) bool {
+	select {
+	case <-s.Context().Done():
+		return false
+	default:
+		return true
+	}
+}
+
+func (s *QUICNameServer) getSession() (quic.Session, error) {
+	var session quic.Session
+	s.RLock()
+	session = s.session
+	if session != nil && isActive(session) {
+		s.RUnlock()
+		return session, nil
+	}
+	if session != nil {
+		// we're recreating the session, let's create a new one
+		_ = session.CloseWithError(0, "")
+	}
+	s.RUnlock()
+
+	s.Lock()
+	defer s.Unlock()
+
+	var err error
+	session, err = s.openSession()
+	if err != nil {
+		// This does not look too nice, but QUIC (or maybe quic-go)
+		// doesn't seem stable enough.
+		// Maybe retransmissions aren't fully implemented in quic-go?
+		// Anyways, the simple solution is to make a second try when
+		// it fails to open the QUIC session.
+		session, err = s.openSession()
+		if err != nil {
+			return nil, err
+		}
+	}
+	s.session = session
+	return session, nil
+}
+
+func (s *QUICNameServer) openSession() (quic.Session, error) {
+	tlsConfig := tls.Config{}
+	quicConfig := &quic.Config{
+		HandshakeTimeout: handshakeTimeout,
+	}
+
+	session, err := quic.DialAddrContext(context.Background(), s.destination.NetAddr(), tlsConfig.GetTLSConfig(tls.WithNextProto("http/1.1", http2.NextProtoTLS, NextProtoDQ)), quicConfig)
+	if err != nil {
+		return nil, err
+	}
+
+	return session, nil
+}
+
+func (s *QUICNameServer) openStream(ctx context.Context) (quic.Stream, error) {
+	session, err := s.getSession()
+	if err != nil {
+		return nil, err
+	}
+
+	// open a new stream
+	return session.OpenStreamSync(ctx)
+}

+ 29 - 0
app/dns/nameserver_quic_test.go

@@ -0,0 +1,29 @@
+package dns_test
+
+import (
+	"context"
+	"net/url"
+	"testing"
+	"time"
+
+	. "v2ray.com/core/app/dns"
+	"v2ray.com/core/common"
+	"v2ray.com/core/common/net"
+)
+
+func TestQUICNameServer(t *testing.T) {
+	url, err := url.Parse("quic://dns.adguard.com")
+	common.Must(err)
+	s, err := NewQUICNameServer(url)
+	common.Must(err)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*2)
+	ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), IPOption{
+		IPv4Enable: true,
+		IPv6Enable: true,
+	})
+	cancel()
+	common.Must(err)
+	if len(ips) == 0 {
+		t.Error("expect some ips, but got 0")
+	}
+}