Browse Source

Feat: DNS transport over TCP (#983)

* feat: DNS over TCP

* fix: DNS over TCP misbehaving

* fix: add a blank line after +build tag

* style: rename NewTCPLNameServer to NewTCPLocalNameServer

* style: add some comments

* style: format

Co-authored-by: Shelikhoo <xiaokangwang@outlook.com>
秋のかえで 4 years ago
parent
commit
f84a401704
3 changed files with 424 additions and 0 deletions
  1. 4 0
      app/dns/nameserver.go
  2. 360 0
      app/dns/nameserver_tcp.go
  3. 60 0
      app/dns/nameserver_tcp_test.go

+ 4 - 0
app/dns/nameserver.go

@@ -52,6 +52,10 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher) (Server, err
 			return NewDoHLocalNameServer(u), nil
 		case strings.EqualFold(u.Scheme, "quic+local"): // DNS-over-QUIC Local mode
 			return NewQUICNameServer(u)
+		case strings.EqualFold(u.Scheme, "tcp"): // DNS-over-TCP Remote mode
+			return NewTCPNameServer(u, dispatcher)
+		case strings.EqualFold(u.Scheme, "tcp+local"): // DNS-over-TCP Local mode
+			return NewTCPLocalNameServer(u)
 		case strings.EqualFold(u.String(), "fakedns"):
 			return NewFakeDNSServer(), nil
 		}

+ 360 - 0
app/dns/nameserver_tcp.go

@@ -0,0 +1,360 @@
+// +build !confonly
+
+package dns
+
+import (
+	"bytes"
+	"context"
+	"encoding/binary"
+	"net/url"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"golang.org/x/net/dns/dnsmessage"
+
+	"github.com/v2fly/v2ray-core/v4/common"
+	"github.com/v2fly/v2ray-core/v4/common/buf"
+	"github.com/v2fly/v2ray-core/v4/common/net"
+	"github.com/v2fly/v2ray-core/v4/common/protocol/dns"
+	"github.com/v2fly/v2ray-core/v4/common/session"
+	"github.com/v2fly/v2ray-core/v4/common/signal/pubsub"
+	"github.com/v2fly/v2ray-core/v4/common/task"
+	dns_feature "github.com/v2fly/v2ray-core/v4/features/dns"
+	"github.com/v2fly/v2ray-core/v4/features/routing"
+	"github.com/v2fly/v2ray-core/v4/transport/internet"
+)
+
+// TCPNameServer implemented DNS over TCP (RFC7766).
+type TCPNameServer struct {
+	sync.RWMutex
+	name        string
+	destination net.Destination
+	ips         map[string]record
+	pub         *pubsub.Service
+	cleanup     *task.Periodic
+	reqID       uint32
+	dial        func(context.Context) (net.Conn, error)
+}
+
+// NewTCPNameServer creates DNS over TCP server object for remote resolving.
+func NewTCPNameServer(url *url.URL, dispatcher routing.Dispatcher) (*TCPNameServer, error) {
+	s, err := baseTCPNameServer(url, "TCP")
+	if err != nil {
+		return nil, err
+	}
+
+	s.dial = func(ctx context.Context) (net.Conn, error) {
+		link, err := dispatcher.Dispatch(ctx, s.destination)
+		if err != nil {
+			return nil, err
+		}
+
+		return net.NewConnection(
+			net.ConnectionInputMulti(link.Writer),
+			net.ConnectionOutputMulti(link.Reader),
+		), nil
+	}
+
+	return s, nil
+}
+
+// NewTCPLocalNameServer creates DNS over TCP client object for local resolving
+func NewTCPLocalNameServer(url *url.URL) (*TCPNameServer, error) {
+	s, err := baseTCPNameServer(url, "TCPL")
+	if err != nil {
+		return nil, err
+	}
+
+	s.dial = func(ctx context.Context) (net.Conn, error) {
+		return internet.DialSystem(ctx, s.destination, nil)
+	}
+
+	return s, nil
+}
+
+func baseTCPNameServer(url *url.URL, prefix string) (*TCPNameServer, error) {
+	var err error
+	port := net.Port(53)
+	if url.Port() != "" {
+		port, err = net.PortFromString(url.Port())
+		if err != nil {
+			return nil, err
+		}
+	}
+	dest := net.TCPDestination(net.DomainAddress(url.Hostname()), port)
+
+	s := &TCPNameServer{
+		destination: dest,
+		ips:         make(map[string]record),
+		pub:         pubsub.NewService(),
+		name:        prefix + "//" + dest.NetAddr(),
+	}
+	s.cleanup = &task.Periodic{
+		Interval: time.Minute,
+		Execute:  s.Cleanup,
+	}
+
+	return s, nil
+}
+
+// Name implements Server.
+func (s *TCPNameServer) Name() string {
+	return s.name
+}
+
+// Cleanup clears expired items from cache
+func (s *TCPNameServer) Cleanup() error {
+	now := time.Now()
+	s.Lock()
+	defer s.Unlock()
+
+	if len(s.ips) == 0 {
+		return newError("nothing to do. stopping...")
+	}
+
+	for domain, record := range s.ips {
+		if record.A != nil && record.A.Expire.Before(now) {
+			record.A = nil
+		}
+		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
+			record.AAAA = nil
+		}
+
+		if record.A == nil && record.AAAA == nil {
+			newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
+			delete(s.ips, domain)
+		} else {
+			s.ips[domain] = record
+		}
+	}
+
+	if len(s.ips) == 0 {
+		s.ips = make(map[string]record)
+	}
+
+	return nil
+}
+
+func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
+	elapsed := time.Since(req.start)
+
+	s.Lock()
+	rec := s.ips[req.domain]
+	updated := false
+
+	switch req.reqType {
+	case dnsmessage.TypeA:
+		if isNewer(rec.A, ipRec) {
+			rec.A = ipRec
+			updated = true
+		}
+	case dnsmessage.TypeAAAA:
+		addr := make([]net.Address, 0)
+		for _, ip := range ipRec.IP {
+			if len(ip.IP()) == net.IPv6len {
+				addr = append(addr, ip)
+			}
+		}
+		ipRec.IP = addr
+		if isNewer(rec.AAAA, ipRec) {
+			rec.AAAA = ipRec
+			updated = true
+		}
+	}
+	newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
+
+	if updated {
+		s.ips[req.domain] = rec
+	}
+	switch req.reqType {
+	case dnsmessage.TypeA:
+		s.pub.Publish(req.domain+"4", nil)
+	case dnsmessage.TypeAAAA:
+		s.pub.Publish(req.domain+"6", nil)
+	}
+	s.Unlock()
+	common.Must(s.cleanup.Start())
+}
+
+func (s *TCPNameServer) newReqID() uint16 {
+	return uint16(atomic.AddUint32(&s.reqID, 1))
+}
+
+func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
+	newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
+
+	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))
+
+	var deadline time.Time
+	if d, ok := ctx.Deadline(); ok {
+		deadline = d
+	} else {
+		deadline = time.Now().Add(time.Second * 5)
+	}
+
+	for _, req := range reqs {
+		go func(r *dnsRequest) {
+			dnsCtx := ctx
+
+			if inbound := session.InboundFromContext(ctx); inbound != nil {
+				dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
+			}
+
+			dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
+				Protocol:       "dns",
+				SkipDNSResolve: true,
+			})
+
+			var cancel context.CancelFunc
+			dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline)
+			defer cancel()
+
+			b, err := dns.PackMessage(r.msg)
+			if err != nil {
+				newError("failed to pack dns query").Base(err).AtError().WriteToLog()
+				return
+			}
+
+			conn, err := s.dial(dnsCtx)
+			if err != nil {
+				newError("failed to dial namesever").Base(err).AtError().WriteToLog()
+				return
+			}
+			defer conn.Close()
+			dnsReqBuf := buf.New()
+			binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len()))
+			dnsReqBuf.Write(b.Bytes())
+			b.Release()
+
+			_, err = conn.Write(dnsReqBuf.Bytes())
+			if err != nil {
+				newError("failed to send query").Base(err).AtError().WriteToLog()
+				return
+			}
+			dnsReqBuf.Release()
+
+			respBuf := buf.New()
+			defer respBuf.Release()
+			n, err := respBuf.ReadFullFrom(conn, 2)
+			if err != nil && n == 0 {
+				newError("failed to read response length").Base(err).AtError().WriteToLog()
+				return
+			}
+			var length int16
+			err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length)
+			if err != nil {
+				newError("failed to parse response length").Base(err).AtError().WriteToLog()
+				return
+			}
+			respBuf.Clear()
+			n, err = respBuf.ReadFullFrom(conn, int32(length))
+			if err != nil && n == 0 {
+				newError("failed to read response length").Base(err).AtError().WriteToLog()
+				return
+			}
+
+			rec, err := parseResponse(respBuf.Bytes())
+			if err != nil {
+				newError("failed to parse DNS over TCP response").Base(err).AtError().WriteToLog()
+				return
+			}
+
+			s.updateIP(r, rec)
+		}(req)
+	}
+}
+
+func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) {
+	s.RLock()
+	record, found := s.ips[domain]
+	s.RUnlock()
+
+	if !found {
+		return nil, errRecordNotFound
+	}
+
+	var ips []net.Address
+	var lastErr error
+	if option.IPv4Enable {
+		a, err := record.A.getIPs()
+		if err != nil {
+			lastErr = err
+		}
+		ips = append(ips, a...)
+	}
+
+	if option.IPv6Enable {
+		aaaa, err := record.AAAA.getIPs()
+		if err != nil {
+			lastErr = err
+		}
+		ips = append(ips, aaaa...)
+	}
+
+	if len(ips) > 0 {
+		return toNetIP(ips)
+	}
+
+	if lastErr != nil {
+		return nil, lastErr
+	}
+
+	return nil, dns_feature.ErrEmptyResponse
+}
+
+// QueryIP implements Server.
+func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) {
+	fqdn := Fqdn(domain)
+
+	if disableCache {
+		newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
+	} else {
+		ips, err := s.findIPsForDomain(fqdn, option)
+		if err != errRecordNotFound {
+			newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
+			return ips, err
+		}
+	}
+
+	// ipv4 and ipv6 belong to different subscription groups
+	var sub4, sub6 *pubsub.Subscriber
+	if option.IPv4Enable {
+		sub4 = s.pub.Subscribe(fqdn + "4")
+		defer sub4.Close()
+	}
+	if option.IPv6Enable {
+		sub6 = s.pub.Subscribe(fqdn + "6")
+		defer sub6.Close()
+	}
+	done := make(chan interface{})
+	go func() {
+		if sub4 != nil {
+			select {
+			case <-sub4.Wait():
+			case <-ctx.Done():
+			}
+		}
+		if sub6 != nil {
+			select {
+			case <-sub6.Wait():
+			case <-ctx.Done():
+			}
+		}
+		close(done)
+	}()
+	s.sendQuery(ctx, fqdn, clientIP, option)
+
+	for {
+		ips, err := s.findIPsForDomain(fqdn, option)
+		if err != errRecordNotFound {
+			return ips, err
+		}
+
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-done:
+		}
+	}
+}

+ 60 - 0
app/dns/nameserver_tcp_test.go

@@ -0,0 +1,60 @@
+package dns_test
+
+import (
+	"context"
+	"net/url"
+	"testing"
+	"time"
+
+	"github.com/google/go-cmp/cmp"
+
+	. "github.com/v2fly/v2ray-core/v4/app/dns"
+	"github.com/v2fly/v2ray-core/v4/common"
+	"github.com/v2fly/v2ray-core/v4/common/net"
+	dns_feature "github.com/v2fly/v2ray-core/v4/features/dns"
+)
+
+func TestTCPLocalNameServer(t *testing.T) {
+	url, err := url.Parse("tcp+local://8.8.8.8")
+	common.Must(err)
+	s, err := NewTCPLocalNameServer(url)
+	common.Must(err)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
+		IPv4Enable: true,
+		IPv6Enable: true,
+	}, false)
+	cancel()
+	common.Must(err)
+	if len(ips) == 0 {
+		t.Error("expect some ips, but got 0")
+	}
+}
+
+func TestTCPLocalNameServerWithCache(t *testing.T) {
+	url, err := url.Parse("tcp+local://8.8.8.8")
+	common.Must(err)
+	s, err := NewTCPLocalNameServer(url)
+	common.Must(err)
+	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	ips, err := s.QueryIP(ctx, "google.com", net.IP(nil), dns_feature.IPOption{
+		IPv4Enable: true,
+		IPv6Enable: true,
+	}, false)
+	cancel()
+	common.Must(err)
+	if len(ips) == 0 {
+		t.Error("expect some ips, but got 0")
+	}
+
+	ctx2, cancel := context.WithTimeout(context.Background(), time.Second*5)
+	ips2, err := s.QueryIP(ctx2, "google.com", net.IP(nil), dns_feature.IPOption{
+		IPv4Enable: true,
+		IPv6Enable: true,
+	}, true)
+	cancel()
+	common.Must(err)
+	if r := cmp.Diff(ips2, ips); r != "" {
+		t.Fatal(r)
+	}
+}