Browse Source

add DOH dns client

vcptr 6 years ago
parent
commit
6ef77246ab

+ 235 - 0
app/dns/dnscommon.go

@@ -0,0 +1,235 @@
+// +build !confonly
+
+package dns
+
+import (
+	"encoding/binary"
+	"time"
+
+	"golang.org/x/net/dns/dnsmessage"
+	"v2ray.com/core/common"
+	"v2ray.com/core/common/errors"
+	"v2ray.com/core/common/net"
+	dns_feature "v2ray.com/core/features/dns"
+)
+
+func Fqdn(domain string) string {
+	if len(domain) > 0 && domain[len(domain)-1] == '.' {
+		return domain
+	}
+	return domain + "."
+}
+
+type record struct {
+	A    *IPRecord
+	AAAA *IPRecord
+}
+
+type IPRecord struct {
+	ReqID  uint16
+	IP     []net.Address
+	Expire time.Time
+	RCode  dnsmessage.RCode
+}
+
+func (r *IPRecord) getIPs() ([]net.Address, error) {
+	if r == nil || r.Expire.Before(time.Now()) {
+		return nil, errRecordNotFound
+	}
+	if r.RCode != dnsmessage.RCodeSuccess {
+		return nil, dns_feature.RCodeError(r.RCode)
+	}
+	return r.IP, nil
+}
+
+func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
+	if newRec == nil {
+		return false
+	}
+	if baseRec == nil {
+		return true
+	}
+	return baseRec.Expire.Before(newRec.Expire)
+}
+
+var (
+	errRecordNotFound = errors.New("record not found")
+)
+
+type dnsRequest struct {
+	reqType dnsmessage.Type
+	domain  string
+	start   time.Time
+	expire  time.Time
+	msg     *dnsmessage.Message
+}
+
+func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
+	if len(clientIP) == 0 {
+		return nil
+	}
+
+	var netmask int
+	var family uint16
+
+	if len(clientIP) == 4 {
+		family = 1
+		netmask = 24 // 24 for IPV4, 96 for IPv6
+	} else {
+		family = 2
+		netmask = 96
+	}
+
+	b := make([]byte, 4)
+	binary.BigEndian.PutUint16(b[0:], family)
+	b[2] = byte(netmask)
+	b[3] = 0
+	switch family {
+	case 1:
+		ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
+		needLength := (netmask + 8 - 1) / 8 // division rounding up
+		b = append(b, ip[:needLength]...)
+	case 2:
+		ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
+		needLength := (netmask + 8 - 1) / 8 // division rounding up
+		b = append(b, ip[:needLength]...)
+	}
+
+	const EDNS0SUBNET = 0x08
+
+	opt := new(dnsmessage.Resource)
+	common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
+
+	opt.Body = &dnsmessage.OPTResource{
+		Options: []dnsmessage.Option{
+			{
+				Code: EDNS0SUBNET,
+				Data: b,
+			},
+		},
+	}
+
+	return opt
+}
+
+func buildReqMsgs(domain string, option IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
+	qA := dnsmessage.Question{
+		Name:  dnsmessage.MustNewName(domain),
+		Type:  dnsmessage.TypeA,
+		Class: dnsmessage.ClassINET,
+	}
+
+	qAAAA := dnsmessage.Question{
+		Name:  dnsmessage.MustNewName(domain),
+		Type:  dnsmessage.TypeAAAA,
+		Class: dnsmessage.ClassINET,
+	}
+
+	var reqs []*dnsRequest
+	now := time.Now()
+
+	if option.IPv4Enable {
+		msg := new(dnsmessage.Message)
+		msg.Header.ID = reqIDGen()
+		msg.Header.RecursionDesired = true
+		msg.Questions = []dnsmessage.Question{qA}
+		if reqOpts != nil {
+			msg.Additionals = append(msg.Additionals, *reqOpts)
+		}
+		reqs = append(reqs, &dnsRequest{
+			reqType: dnsmessage.TypeA,
+			domain:  domain,
+			start:   now,
+			msg:     msg,
+		})
+	}
+
+	if option.IPv6Enable {
+		msg := new(dnsmessage.Message)
+		msg.Header.ID = reqIDGen()
+		msg.Header.RecursionDesired = true
+		msg.Questions = []dnsmessage.Question{qAAAA}
+		if reqOpts != nil {
+			msg.Additionals = append(msg.Additionals, *reqOpts)
+		}
+		reqs = append(reqs, &dnsRequest{
+			reqType: dnsmessage.TypeAAAA,
+			domain:  domain,
+			start:   now,
+			msg:     msg,
+		})
+	}
+
+	return reqs
+}
+
+// parseResponse parse DNS answers from the returned payload
+func parseResponse(payload []byte) (*IPRecord, error) {
+	var parser dnsmessage.Parser
+	h, err := parser.Start(payload)
+	if err != nil {
+		return nil, newError("failed to parse DNS response").Base(err).AtWarning()
+	}
+	if err := parser.SkipAllQuestions(); err != nil {
+		return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning()
+	}
+
+	now := time.Now()
+	var ipRecExpire time.Time
+	if h.RCode != dnsmessage.RCodeSuccess {
+		// A default TTL, maybe a negtive cache
+		ipRecExpire = now.Add(time.Second * 120)
+	}
+
+	ipRecord := &IPRecord{
+		ReqID:  h.ID,
+		RCode:  h.RCode,
+		Expire: ipRecExpire,
+	}
+
+L:
+	for {
+		ah, err := parser.AnswerHeader()
+		if err != nil {
+			if err != dnsmessage.ErrSectionDone {
+				newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog()
+			}
+			break
+		}
+
+		switch ah.Type {
+		case dnsmessage.TypeA:
+			ans, err := parser.AResource()
+			if err != nil {
+				newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
+				break L
+			}
+			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
+		case dnsmessage.TypeAAAA:
+			ans, err := parser.AAAAResource()
+			if err != nil {
+				newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
+				break L
+			}
+			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
+		default:
+			if err := parser.SkipAnswer(); err != nil {
+				newError("failed to skip answer").Base(err).WriteToLog()
+				break L
+			}
+			continue
+		}
+
+		if ipRecord.Expire.IsZero() {
+			ttl := ah.TTL
+			if ttl < 600 {
+				// at least 10 mins TTL
+				ipRecord.Expire = now.Add(time.Minute * 10)
+			} else {
+				ipRecord.Expire = now.Add(time.Duration(ttl) * time.Second)
+			}
+		}
+	}
+
+	return ipRecord, nil
+}

+ 166 - 0
app/dns/dnscommon_test.go

@@ -0,0 +1,166 @@
+// +build !confonly
+
+package dns
+
+import (
+	"math/rand"
+	"testing"
+	"time"
+
+	"github.com/google/go-cmp/cmp"
+	"github.com/miekg/dns"
+	"golang.org/x/net/dns/dnsmessage"
+	"v2ray.com/core/common"
+	"v2ray.com/core/common/net"
+	v2net "v2ray.com/core/common/net"
+)
+
+func Test_parseResponse(t *testing.T) {
+	type args struct {
+		payload []byte
+	}
+
+	var p [][]byte
+
+	ans := new(dns.Msg)
+	ans.Id = 0
+	p = append(p, common.Must2(ans.Pack()).([]byte))
+
+	p = append(p, []byte{})
+
+	ans = new(dns.Msg)
+	ans.Id = 1
+	ans.Answer = append(ans.Answer,
+		common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
+		common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
+		common.Must2(dns.NewRR("google.com. IN A 8.8.8.8")).(dns.RR),
+		common.Must2(dns.NewRR("google.com. IN A 8.8.4.4")).(dns.RR),
+	)
+	p = append(p, common.Must2(ans.Pack()).([]byte))
+
+	ans = new(dns.Msg)
+	ans.Id = 2
+	ans.Answer = append(ans.Answer,
+		common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
+		common.Must2(dns.NewRR("google.com. IN CNAME fake.google.com")).(dns.RR),
+		common.Must2(dns.NewRR("google.com. IN CNAME m.test.google.com")).(dns.RR),
+		common.Must2(dns.NewRR("google.com. IN CNAME test.google.com")).(dns.RR),
+		common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8888")).(dns.RR),
+		common.Must2(dns.NewRR("google.com. IN AAAA 2001::123:8844")).(dns.RR),
+	)
+	p = append(p, common.Must2(ans.Pack()).([]byte))
+
+	tests := []struct {
+		name    string
+		want    *IPRecord
+		wantErr bool
+	}{
+		{"empty",
+			&IPRecord{0, []v2net.Address(nil), time.Time{}, dnsmessage.RCodeSuccess},
+			false,
+		},
+		{"error",
+			nil,
+			true,
+		},
+		{"a record",
+			&IPRecord{1, []v2net.Address{v2net.ParseAddress("8.8.8.8"), v2net.ParseAddress("8.8.4.4")},
+				time.Time{}, dnsmessage.RCodeSuccess},
+			false,
+		},
+		{"aaaa record",
+			&IPRecord{2, []v2net.Address{v2net.ParseAddress("2001::123:8888"), v2net.ParseAddress("2001::123:8844")}, time.Time{}, dnsmessage.RCodeSuccess},
+			false,
+		},
+	}
+	for i, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := parseResponse(p[i])
+			if (err != nil) != tt.wantErr {
+				t.Errorf("handleResponse() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			if got != nil {
+				// reset the time
+				got.Expire = time.Time{}
+			}
+			if cmp.Diff(got, tt.want) != "" {
+				t.Errorf(cmp.Diff(got, tt.want))
+				// t.Errorf("handleResponse() = %#v, want %#v", got, tt.want)
+			}
+		})
+	}
+}
+
+func Test_buildReqMsgs(t *testing.T) {
+
+	stubID := func() uint16 {
+		return uint16(rand.Uint32())
+	}
+	type args struct {
+		domain  string
+		option  IPOption
+		reqOpts *dnsmessage.Resource
+	}
+	tests := []struct {
+		name string
+		args args
+		want int
+	}{
+		{"dual stack", args{"test.com", IPOption{true, true}, nil}, 2},
+		{"ipv4 only", args{"test.com", IPOption{true, false}, nil}, 1},
+		{"ipv6 only", args{"test.com", IPOption{false, true}, nil}, 1},
+		{"none/error", args{"test.com", IPOption{false, false}, nil}, 0},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := buildReqMsgs(tt.args.domain, tt.args.option, stubID, tt.args.reqOpts); !(len(got) == tt.want) {
+				t.Errorf("buildReqMsgs() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func Test_genEDNS0Options(t *testing.T) {
+	type args struct {
+		clientIP net.IP
+	}
+	tests := []struct {
+		name string
+		args args
+		want *dnsmessage.Resource
+	}{
+		// TODO: Add test cases.
+		{"ipv4", args{net.ParseIP("4.3.2.1")}, nil},
+		{"ipv6", args{net.ParseIP("2001::4321")}, nil},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := genEDNS0Options(tt.args.clientIP); got == nil {
+				t.Errorf("genEDNS0Options() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}
+
+func TestFqdn(t *testing.T) {
+	type args struct {
+		domain string
+	}
+	tests := []struct {
+		name string
+		args args
+		want string
+	}{
+		{"with fqdn", args{"www.v2ray.com."}, "www.v2ray.com."},
+		{"without fqdn", args{"www.v2ray.com"}, "www.v2ray.com."},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := Fqdn(tt.args.domain); got != tt.want {
+				t.Errorf("Fqdn() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}

+ 315 - 0
app/dns/dohdns.go

@@ -0,0 +1,315 @@
+// +build !confonly
+
+package dns
+
+import (
+	"bytes"
+	"context"
+	"fmt"
+	"io/ioutil"
+	"net/http"
+	"sync"
+	"sync/atomic"
+	"time"
+
+	"golang.org/x/net/dns/dnsmessage"
+	"v2ray.com/core/common"
+	"v2ray.com/core/common/dice"
+	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/protocol/dns"
+	"v2ray.com/core/common/session"
+	"v2ray.com/core/common/signal/pubsub"
+	"v2ray.com/core/common/task"
+	"v2ray.com/core/features/routing"
+)
+
+// DoHNameServer implimented DNS over HTTPS (RFC8484) Wire Format,
+// which is compatiable with traditional dns over udp(RFC1035),
+// thus most of the DOH implimentation is copied from udpns.go
+type DoHNameServer struct {
+	sync.RWMutex
+	dispatcher routing.Dispatcher
+	dohDests   []net.Destination
+	ips        map[string]record
+	pub        *pubsub.Service
+	cleanup    *task.Periodic
+	reqID      uint32
+	clientIP   net.IP
+	httpClient *http.Client
+	dohURL     string
+	name       string
+}
+
+func NewDoHNameServer(dests []net.Destination, dohHost string, dispatcher routing.Dispatcher, clientIP net.IP) *DoHNameServer {
+
+	s := NewDoHLocalNameServer(dohHost, clientIP)
+	s.name = "DOH:" + dohHost
+	s.dispatcher = dispatcher
+	s.dohDests = dests
+
+	// Dispatched connection will be closed (interupted) after each request
+	// This makes DOH inefficient without a keeped-alive connection
+	// See: core/app/proxyman/outbound/handler.go:113
+	// Using mux (https request wrapped in a stream layer) improves the situation.
+	// Recommand to use NewDoHLocalNameServer (DOHL:) if v2ray instance is running on
+	//  a normal network eg. the server side of v2ray
+	tr := &http.Transport{
+		MaxIdleConns:        10,
+		IdleConnTimeout:     90 * time.Second,
+		TLSHandshakeTimeout: 10 * time.Second,
+		DialContext:         s.DialContext,
+	}
+
+	dispatchedClient := &http.Client{
+		Transport: tr,
+		Timeout:   16 * time.Second,
+	}
+
+	s.httpClient = dispatchedClient
+	return s
+}
+
+func NewDoHLocalNameServer(dohHost string, clientIP net.IP) *DoHNameServer {
+	s := &DoHNameServer{
+		httpClient: http.DefaultClient,
+		ips:        make(map[string]record),
+		clientIP:   clientIP,
+		pub:        pubsub.NewService(),
+		name:       "DOHL:" + dohHost,
+		dohURL:     fmt.Sprintf("https://%s/dns-query", dohHost),
+	}
+	s.cleanup = &task.Periodic{
+		Interval: time.Minute,
+		Execute:  s.Cleanup,
+	}
+	return s
+}
+
+func (s *DoHNameServer) Name() string {
+	return s.name
+}
+
+func (s *DoHNameServer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) {
+
+	dest := s.dohDests[dice.Roll(len(s.dohDests))]
+
+	link, err := s.dispatcher.Dispatch(ctx, dest)
+	if err != nil {
+		return nil, err
+	}
+	return net.NewConnection(
+		net.ConnectionInputMulti(link.Writer),
+		net.ConnectionOutputMulti(link.Reader),
+	), nil
+}
+
+func (s *DoHNameServer) Cleanup() error {
+	now := time.Now()
+	s.Lock()
+	defer s.Unlock()
+
+	if len(s.ips) == 0 {
+		return newError("nothing to do. stopping...")
+	}
+
+	for domain, record := range s.ips {
+		if record.A != nil && record.A.Expire.Before(now) {
+			record.A = nil
+		}
+		if record.AAAA != nil && record.AAAA.Expire.Before(now) {
+			record.AAAA = nil
+		}
+
+		if record.A == nil && record.AAAA == nil {
+			newError(s.name, " cleanup ", domain).AtDebug().WriteToLog()
+			delete(s.ips, domain)
+		} else {
+			s.ips[domain] = record
+		}
+	}
+
+	if len(s.ips) == 0 {
+		s.ips = make(map[string]record)
+	}
+
+	return nil
+}
+
+func (s *DoHNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) {
+	elapsed := time.Since(req.start)
+	newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
+
+	s.Lock()
+	rec := s.ips[req.domain]
+	updated := false
+
+	switch req.reqType {
+	case dnsmessage.TypeA:
+		if isNewer(rec.A, ipRec) {
+			rec.A = ipRec
+			updated = true
+		}
+	case dnsmessage.TypeAAAA:
+		if isNewer(rec.AAAA, ipRec) {
+			rec.AAAA = ipRec
+			updated = true
+		}
+	}
+
+	if updated {
+		s.ips[req.domain] = rec
+		s.pub.Publish(req.domain, nil)
+	}
+
+	s.Unlock()
+	common.Must(s.cleanup.Start())
+}
+
+func (s *DoHNameServer) newReqID() uint16 {
+	return uint16(atomic.AddUint32(&s.reqID, 1))
+}
+
+func (s *DoHNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
+	newError(s.name, " querying: ", domain).AtInfo().WriteToLog(session.ExportIDToError(ctx))
+
+	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP))
+
+	var deadline time.Time
+	if d, ok := ctx.Deadline(); ok {
+		deadline = d
+	} else {
+		deadline = time.Now().Add(time.Second * 8)
+	}
+
+	for _, req := range reqs {
+
+		go func(r *dnsRequest) {
+
+			// generate new context for each req, using same context
+			// may cause reqs all aborted if any one encounter an error
+			dnsCtx := context.Background()
+
+			// reserve internal dns server requested Inbound
+			if inbound := session.InboundFromContext(ctx); inbound != nil {
+				dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
+			}
+
+			dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{
+				Protocol: "https",
+			})
+
+			// forced to use mux for DOH
+			dnsCtx = session.ContextWithMuxPrefered(dnsCtx, true)
+
+			dnsCtx, cancel := context.WithDeadline(dnsCtx, deadline)
+			defer cancel()
+
+			b, _ := dns.PackMessage(r.msg)
+			resp, err := s.dohHTTPSContext(dnsCtx, b.Bytes())
+			if err != nil {
+				newError("failed to retrive response").Base(err).AtError().WriteToLog()
+				return
+			}
+			rec, err := parseResponse(resp)
+			if err != nil {
+				newError("failed to handle DOH response").Base(err).AtError().WriteToLog()
+				return
+			}
+			s.updateIP(r, rec)
+		}(req)
+	}
+}
+
+func (s *DoHNameServer) dohHTTPSContext(ctx context.Context, b []byte) ([]byte, error) {
+
+	body := bytes.NewBuffer(b)
+	req, err := http.NewRequest("POST", s.dohURL, body)
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Add("Accept", "application/dns-message")
+	req.Header.Add("Content-Type", "application/dns-message")
+
+	resp, err := s.httpClient.Do(req.WithContext(ctx))
+	if err != nil {
+		return nil, err
+	}
+
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		err = fmt.Errorf("DOH HTTPS server returned with non-OK code %d", resp.StatusCode)
+		return nil, err
+	}
+
+	return ioutil.ReadAll(resp.Body)
+}
+
+func (s *DoHNameServer) findIPsForDomain(domain string, option IPOption) ([]net.IP, error) {
+	s.RLock()
+	record, found := s.ips[domain]
+	s.RUnlock()
+
+	if !found {
+		return nil, errRecordNotFound
+	}
+
+	var ips []net.Address
+	var lastErr error
+	if option.IPv6Enable && record.AAAA != nil && record.AAAA.RCode == dnsmessage.RCodeSuccess {
+		aaaa, err := record.AAAA.getIPs()
+		if err != nil {
+			lastErr = err
+		}
+		ips = append(ips, aaaa...)
+	}
+
+	if option.IPv4Enable && record.A != nil && record.A.RCode == dnsmessage.RCodeSuccess {
+		a, err := record.A.getIPs()
+		if err != nil {
+			lastErr = err
+		}
+		ips = append(ips, a...)
+	}
+
+	if len(ips) > 0 {
+		return toNetIP(ips), nil
+	}
+
+	if lastErr != nil {
+		return nil, lastErr
+	}
+
+	return nil, errRecordNotFound
+}
+
+// QueryIP is called from dns.Server->queryIPTimeout
+func (s *DoHNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
+
+	fqdn := Fqdn(domain)
+
+	ips, err := s.findIPsForDomain(fqdn, option)
+	if err != errRecordNotFound {
+		newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
+		return ips, err
+	}
+
+	sub := s.pub.Subscribe(fqdn)
+	defer sub.Close()
+
+	s.sendQuery(ctx, fqdn, option)
+
+	for {
+		ips, err := s.findIPsForDomain(fqdn, option)
+		if err != errRecordNotFound {
+			return ips, err
+		}
+
+		select {
+		case <-ctx.Done():
+			return nil, ctx.Err()
+		case <-sub.Wait():
+		}
+	}
+}

+ 53 - 1
app/dns/server.go

@@ -6,6 +6,8 @@ package dns
 
 
 import (
 import (
 	"context"
 	"context"
+	"fmt"
+	"strings"
 	"sync"
 	"sync"
 	"time"
 	"time"
 
 
@@ -87,6 +89,49 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 		address := endpoint.Address.AsAddress()
 		address := endpoint.Address.AsAddress()
 		if address.Family().IsDomain() && address.Domain() == "localhost" {
 		if address.Family().IsDomain() && address.Domain() == "localhost" {
 			server.clients = append(server.clients, NewLocalNameServer())
 			server.clients = append(server.clients, NewLocalNameServer())
+			newError("DNS: localhost inited").AtInfo().WriteToLog()
+		} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOHL_") {
+			dohHost := address.Domain()[5:]
+			server.clients = append(server.clients, NewDoHLocalNameServer(dohHost, server.clientIP))
+			newError("DNS: DOH - Local inited for https://", dohHost).AtInfo().WriteToLog()
+		} else if address.Family().IsDomain() && strings.HasPrefix(address.Domain(), "DOH_") {
+			// DOH_ prefix makes net.Address think it's a domain
+			// need to process the real address here.
+			dohHost := address.Domain()[4:]
+			dohAddr := net.ParseAddress(dohHost)
+			dohIP := dohHost
+			var dests []net.Destination
+
+			if dohAddr.Family().IsDomain() {
+				// resolve DOH server in advance
+				ips, err := net.LookupIP(dohAddr.Domain())
+				if err != nil || len(ips) == 0 {
+					return 0
+				}
+				for _, ip := range ips {
+					dohIP := ip.String()
+					if len(ip) == net.IPv6len {
+						dohIP = fmt.Sprintf("[%s]", dohIP)
+					}
+					dohdest, _ := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP))
+					dests = append(dests, dohdest)
+				}
+			} else {
+				// rfc8484, DOH service only use port 443
+				dest, err := net.ParseDestination(fmt.Sprintf("tcp:%s:443", dohIP))
+				if err != nil {
+					return 0
+				}
+				dests = []net.Destination{dest}
+			}
+
+			// need the core dispatcher, register DOHClient at callback
+			idx := len(server.clients)
+			server.clients = append(server.clients, nil)
+			common.Must(core.RequireFeatures(ctx, func(d routing.Dispatcher) {
+				server.clients[idx] = NewDoHNameServer(dests, dohHost, d, server.clientIP)
+				newError("DNS: DOH - Remote client inited for https://", dohHost).AtInfo().WriteToLog()
+			}))
 		} else {
 		} else {
 			dest := endpoint.AsDestination()
 			dest := endpoint.AsDestination()
 			if dest.Network == net.Network_Unknown {
 			if dest.Network == net.Network_Unknown {
@@ -100,6 +145,7 @@ func New(ctx context.Context, config *Config) (*Server, error) {
 					server.clients[idx] = NewClassicNameServer(dest, d, server.clientIP)
 					server.clients[idx] = NewClassicNameServer(dest, d, server.clientIP)
 				}))
 				}))
 			}
 			}
+			newError("DNS: UDP client inited for ", dest.NetAddr()).AtInfo().WriteToLog()
 		}
 		}
 		return len(server.clients) - 1
 		return len(server.clients) - 1
 	}
 	}
@@ -272,10 +318,16 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 		return nil, newError("empty domain name")
 		return nil, newError("empty domain name")
 	}
 	}
 
 
+	// normalize the FQDN form query
 	if domain[len(domain)-1] == '.' {
 	if domain[len(domain)-1] == '.' {
 		domain = domain[:len(domain)-1]
 		domain = domain[:len(domain)-1]
 	}
 	}
 
 
+	// skip domain without any dot
+	if strings.Index(domain, ".") == -1 {
+		return nil, newError("invalid domain name")
+	}
+
 	ips := s.lookupStatic(domain, option, 0)
 	ips := s.lookupStatic(domain, option, 0)
 	if ips != nil && ips[0].Family().IsIP() {
 	if ips != nil && ips[0].Family().IsIP() {
 		newError("returning ", len(ips), " IPs for domain ", domain).WriteToLog()
 		newError("returning ", len(ips), " IPs for domain ", domain).WriteToLog()
@@ -331,7 +383,7 @@ func (s *Server) lookupIPInternal(domain string, option IPOption) ([]net.IP, err
 		}
 		}
 	}
 	}
 
 
-	return nil, newError("returning nil for domain ", domain).Base(lastErr)
+	return nil, dns.ErrEmptyResponse.Base(lastErr)
 }
 }
 
 
 func init() {
 func init() {

+ 37 - 232
app/dns/udpns.go

@@ -4,14 +4,13 @@ package dns
 
 
 import (
 import (
 	"context"
 	"context"
-	"encoding/binary"
+	"strings"
 	"sync"
 	"sync"
 	"sync/atomic"
 	"sync/atomic"
 	"time"
 	"time"
 
 
 	"golang.org/x/net/dns/dnsmessage"
 	"golang.org/x/net/dns/dnsmessage"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
-	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol/dns"
 	"v2ray.com/core/common/protocol/dns"
 	udp_proto "v2ray.com/core/common/protocol/udp"
 	udp_proto "v2ray.com/core/common/protocol/udp"
@@ -23,42 +22,12 @@ import (
 	"v2ray.com/core/transport/internet/udp"
 	"v2ray.com/core/transport/internet/udp"
 )
 )
 
 
-type record struct {
-	A    *IPRecord
-	AAAA *IPRecord
-}
-
-type IPRecord struct {
-	IP     []net.Address
-	Expire time.Time
-	RCode  dnsmessage.RCode
-}
-
-func (r *IPRecord) getIPs() ([]net.Address, error) {
-	if r == nil || r.Expire.Before(time.Now()) {
-		return nil, errRecordNotFound
-	}
-	if r.RCode != dnsmessage.RCodeSuccess {
-		return nil, dns_feature.RCodeError(r.RCode)
-	}
-	return r.IP, nil
-}
-
-type pendingRequest struct {
-	domain  string
-	expire  time.Time
-	recType dnsmessage.Type
-}
-
-var (
-	errRecordNotFound = errors.New("record not found")
-)
-
 type ClassicNameServer struct {
 type ClassicNameServer struct {
 	sync.RWMutex
 	sync.RWMutex
+	name      string
 	address   net.Destination
 	address   net.Destination
 	ips       map[string]record
 	ips       map[string]record
-	requests  map[uint16]pendingRequest
+	requests  map[uint16]dnsRequest
 	pub       *pubsub.Service
 	pub       *pubsub.Service
 	udpServer *udp.Dispatcher
 	udpServer *udp.Dispatcher
 	cleanup   *task.Periodic
 	cleanup   *task.Periodic
@@ -70,9 +39,10 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
 	s := &ClassicNameServer{
 	s := &ClassicNameServer{
 		address:  address,
 		address:  address,
 		ips:      make(map[string]record),
 		ips:      make(map[string]record),
-		requests: make(map[uint16]pendingRequest),
+		requests: make(map[uint16]dnsRequest),
 		clientIP: clientIP,
 		clientIP: clientIP,
 		pub:      pubsub.NewService(),
 		pub:      pubsub.NewService(),
+		name:     strings.ToUpper(address.String()),
 	}
 	}
 	s.cleanup = &task.Periodic{
 	s.cleanup = &task.Periodic{
 		Interval: time.Minute,
 		Interval: time.Minute,
@@ -83,7 +53,7 @@ func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher
 }
 }
 
 
 func (s *ClassicNameServer) Name() string {
 func (s *ClassicNameServer) Name() string {
-	return s.address.String()
+	return s.name
 }
 }
 
 
 func (s *ClassicNameServer) Cleanup() error {
 func (s *ClassicNameServer) Cleanup() error {
@@ -92,7 +62,7 @@ func (s *ClassicNameServer) Cleanup() error {
 	defer s.Unlock()
 	defer s.Unlock()
 
 
 	if len(s.ips) == 0 && len(s.requests) == 0 {
 	if len(s.ips) == 0 && len(s.requests) == 0 {
-		return newError("nothing to do. stopping...")
+		return newError(s.name, " nothing to do. stopping...")
 	}
 	}
 
 
 	for domain, record := range s.ips {
 	for domain, record := range s.ips {
@@ -121,123 +91,52 @@ func (s *ClassicNameServer) Cleanup() error {
 	}
 	}
 
 
 	if len(s.requests) == 0 {
 	if len(s.requests) == 0 {
-		s.requests = make(map[uint16]pendingRequest)
+		s.requests = make(map[uint16]dnsRequest)
 	}
 	}
 
 
 	return nil
 	return nil
 }
 }
 
 
 func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
 func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
-	payload := packet.Payload
 
 
-	var parser dnsmessage.Parser
-	header, err := parser.Start(payload.Bytes())
+	ipRec, err := parseResponse(packet.Payload.Bytes())
 	if err != nil {
 	if err != nil {
-		newError("failed to parse DNS response").Base(err).AtWarning().WriteToLog()
-		return
-	}
-	if err := parser.SkipAllQuestions(); err != nil {
-		newError("failed to skip questions in DNS response").Base(err).AtWarning().WriteToLog()
+		newError(s.name, " fail to parse responsed DNS udp").AtError().WriteToLog()
 		return
 		return
 	}
 	}
 
 
-	id := header.ID
 	s.Lock()
 	s.Lock()
-	req, f := s.requests[id]
-	if f {
+	id := ipRec.ReqID
+	req, ok := s.requests[id]
+	if ok {
+		// remove the pending request
 		delete(s.requests, id)
 		delete(s.requests, id)
 	}
 	}
 	s.Unlock()
 	s.Unlock()
-
-	if !f {
+	if !ok {
+		newError(s.name, " cannot find the pending request").AtError().WriteToLog()
 		return
 		return
 	}
 	}
 
 
-	domain := req.domain
-	recType := req.recType
-
-	now := time.Now()
-	ipRecord := &IPRecord{
-		RCode:  header.RCode,
-		Expire: now.Add(time.Second * 600),
-	}
-
-L:
-	for {
-		header, err := parser.AnswerHeader()
-		if err != nil {
-			if err != dnsmessage.ErrSectionDone {
-				newError("failed to parse answer section for domain: ", domain).Base(err).WriteToLog()
-			}
-			break
-		}
-		ttl := header.TTL
-		if ttl == 0 {
-			ttl = 600
-		}
-		expire := now.Add(time.Duration(ttl) * time.Second)
-		if ipRecord.Expire.After(expire) {
-			ipRecord.Expire = expire
-		}
-
-		if header.Type != recType {
-			if err := parser.SkipAnswer(); err != nil {
-				newError("failed to skip answer").Base(err).WriteToLog()
-				break L
-			}
-			continue
-		}
-
-		switch header.Type {
-		case dnsmessage.TypeA:
-			ans, err := parser.AResource()
-			if err != nil {
-				newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
-				break L
-			}
-			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
-		case dnsmessage.TypeAAAA:
-			ans, err := parser.AAAAResource()
-			if err != nil {
-				newError("failed to parse A record for domain: ", domain).Base(err).WriteToLog()
-				break L
-			}
-			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
-		default:
-			if err := parser.SkipAnswer(); err != nil {
-				newError("failed to skip answer").Base(err).WriteToLog()
-				break L
-			}
-		}
-	}
-
 	var rec record
 	var rec record
-	switch recType {
+	switch req.reqType {
 	case dnsmessage.TypeA:
 	case dnsmessage.TypeA:
-		rec.A = ipRecord
+		rec.A = ipRec
 	case dnsmessage.TypeAAAA:
 	case dnsmessage.TypeAAAA:
-		rec.AAAA = ipRecord
-	}
-
-	if len(domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
-		s.updateIP(domain, rec)
+		rec.AAAA = ipRec
 	}
 	}
-}
 
 
-func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
-	if newRec == nil {
-		return false
+	elapsed := time.Since(req.start)
+	newError(s.name, " got answere: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
+	if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
+		s.updateIP(req.domain, rec)
 	}
 	}
-	if baseRec == nil {
-		return true
-	}
-	return baseRec.Expire.Before(newRec.Expire)
 }
 }
 
 
 func (s *ClassicNameServer) updateIP(domain string, newRec record) {
 func (s *ClassicNameServer) updateIP(domain string, newRec record) {
 	s.Lock()
 	s.Lock()
 
 
-	newError("updating IP records for domain:", domain).AtDebug().WriteToLog()
+	newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog()
 	rec := s.ips[domain]
 	rec := s.ips[domain]
 
 
 	updated := false
 	updated := false
@@ -259,116 +158,27 @@ func (s *ClassicNameServer) updateIP(domain string, newRec record) {
 	common.Must(s.cleanup.Start())
 	common.Must(s.cleanup.Start())
 }
 }
 
 
-func (s *ClassicNameServer) getMsgOptions() *dnsmessage.Resource {
-	if len(s.clientIP) == 0 {
-		return nil
-	}
-
-	var netmask int
-	var family uint16
-
-	if len(s.clientIP) == 4 {
-		family = 1
-		netmask = 24 // 24 for IPV4, 96 for IPv6
-	} else {
-		family = 2
-		netmask = 96
-	}
-
-	b := make([]byte, 4)
-	binary.BigEndian.PutUint16(b[0:], family)
-	b[2] = byte(netmask)
-	b[3] = 0
-	switch family {
-	case 1:
-		ip := s.clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
-		needLength := (netmask + 8 - 1) / 8 // division rounding up
-		b = append(b, ip[:needLength]...)
-	case 2:
-		ip := s.clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
-		needLength := (netmask + 8 - 1) / 8 // division rounding up
-		b = append(b, ip[:needLength]...)
-	}
-
-	const EDNS0SUBNET = 0x08
-
-	opt := new(dnsmessage.Resource)
-	common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
-
-	opt.Body = &dnsmessage.OPTResource{
-		Options: []dnsmessage.Option{
-			{
-				Code: EDNS0SUBNET,
-				Data: b,
-			},
-		},
-	}
-
-	return opt
+func (s *ClassicNameServer) newReqID() uint16 {
+	return uint16(atomic.AddUint32(&s.reqID, 1))
 }
 }
 
 
-func (s *ClassicNameServer) addPendingRequest(domain string, recType dnsmessage.Type) uint16 {
-	id := uint16(atomic.AddUint32(&s.reqID, 1))
+func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {
 	s.Lock()
 	s.Lock()
 	defer s.Unlock()
 	defer s.Unlock()
 
 
-	s.requests[id] = pendingRequest{
-		domain:  domain,
-		expire:  time.Now().Add(time.Second * 8),
-		recType: recType,
-	}
-
-	return id
-}
-
-func (s *ClassicNameServer) buildMsgs(domain string, option IPOption) []*dnsmessage.Message {
-	qA := dnsmessage.Question{
-		Name:  dnsmessage.MustNewName(domain),
-		Type:  dnsmessage.TypeA,
-		Class: dnsmessage.ClassINET,
-	}
-
-	qAAAA := dnsmessage.Question{
-		Name:  dnsmessage.MustNewName(domain),
-		Type:  dnsmessage.TypeAAAA,
-		Class: dnsmessage.ClassINET,
-	}
-
-	var msgs []*dnsmessage.Message
-
-	if option.IPv4Enable {
-		msg := new(dnsmessage.Message)
-		msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeA)
-		msg.Header.RecursionDesired = true
-		msg.Questions = []dnsmessage.Question{qA}
-		if opt := s.getMsgOptions(); opt != nil {
-			msg.Additionals = append(msg.Additionals, *opt)
-		}
-		msgs = append(msgs, msg)
-	}
-
-	if option.IPv6Enable {
-		msg := new(dnsmessage.Message)
-		msg.Header.ID = s.addPendingRequest(domain, dnsmessage.TypeAAAA)
-		msg.Header.RecursionDesired = true
-		msg.Questions = []dnsmessage.Question{qAAAA}
-		if opt := s.getMsgOptions(); opt != nil {
-			msg.Additionals = append(msg.Additionals, *opt)
-		}
-		msgs = append(msgs, msg)
-	}
-
-	return msgs
+	id := req.msg.ID
+	req.expire = time.Now().Add(time.Second * 8)
+	s.requests[id] = *req
 }
 }
 
 
 func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
 func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option IPOption) {
-	newError("querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
-
-	msgs := s.buildMsgs(domain, option)
+	newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))
 
 
-	for _, msg := range msgs {
-		b, _ := dns.PackMessage(msg)
+	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP))
 
 
+	for _, req := range reqs {
+		s.addPendingRequest(req)
+		b, _ := dns.PackMessage(req.msg)
 		udpCtx := context.Background()
 		udpCtx := context.Background()
 		if inbound := session.InboundFromContext(ctx); inbound != nil {
 		if inbound := session.InboundFromContext(ctx); inbound != nil {
 			udpCtx = session.ContextWithInbound(udpCtx, inbound)
 			udpCtx = session.ContextWithInbound(udpCtx, inbound)
@@ -418,18 +228,13 @@ func (s *ClassicNameServer) findIPsForDomain(domain string, option IPOption) ([]
 	return nil, dns_feature.ErrEmptyResponse
 	return nil, dns_feature.ErrEmptyResponse
 }
 }
 
 
-func Fqdn(domain string) string {
-	if len(domain) > 0 && domain[len(domain)-1] == '.' {
-		return domain
-	}
-	return domain + "."
-}
-
 func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
 func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option IPOption) ([]net.IP, error) {
+
 	fqdn := Fqdn(domain)
 	fqdn := Fqdn(domain)
 
 
 	ips, err := s.findIPsForDomain(fqdn, option)
 	ips, err := s.findIPsForDomain(fqdn, option)
 	if err != errRecordNotFound {
 	if err != errRecordNotFound {
+		newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
 		return ips, err
 		return ips, err
 	}
 	}
 
 

+ 3 - 2
app/proxyman/outbound/handler.go

@@ -68,12 +68,13 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou
 		return nil, newError("not an outbound handler")
 		return nil, newError("not an outbound handler")
 	}
 	}
 
 
-	if h.senderSettings != nil && h.senderSettings.MultiplexSettings != nil && h.senderSettings.MultiplexSettings.Enabled {
+	if h.senderSettings != nil && h.senderSettings.MultiplexSettings != nil {
 		config := h.senderSettings.MultiplexSettings
 		config := h.senderSettings.MultiplexSettings
 		if config.Concurrency < 1 || config.Concurrency > 1024 {
 		if config.Concurrency < 1 || config.Concurrency > 1024 {
 			return nil, newError("invalid mux concurrency: ", config.Concurrency).AtWarning()
 			return nil, newError("invalid mux concurrency: ", config.Concurrency).AtWarning()
 		}
 		}
 		h.mux = &mux.ClientManager{
 		h.mux = &mux.ClientManager{
+			Enabled: h.senderSettings.MultiplexSettings.Enabled,
 			Picker: &mux.IncrementalWorkerPicker{
 			Picker: &mux.IncrementalWorkerPicker{
 				Factory: &mux.DialingWorkerFactory{
 				Factory: &mux.DialingWorkerFactory{
 					Proxy:  proxyHandler,
 					Proxy:  proxyHandler,
@@ -98,7 +99,7 @@ func (h *Handler) Tag() string {
 
 
 // Dispatch implements proxy.Outbound.Dispatch.
 // Dispatch implements proxy.Outbound.Dispatch.
 func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
 func (h *Handler) Dispatch(ctx context.Context, link *transport.Link) {
-	if h.mux != nil {
+	if h.mux != nil && (h.mux.Enabled || session.MuxPreferedFromContext(ctx)) {
 		if err := h.mux.Dispatch(ctx, link); err != nil {
 		if err := h.mux.Dispatch(ctx, link); err != nil {
 			newError("failed to process mux outbound traffic").Base(err).WriteToLog(session.ExportIDToError(ctx))
 			newError("failed to process mux outbound traffic").Base(err).WriteToLog(session.ExportIDToError(ctx))
 			common.Interrupt(link.Writer)
 			common.Interrupt(link.Writer)

+ 2 - 1
common/mux/client.go

@@ -21,7 +21,8 @@ import (
 )
 )
 
 
 type ClientManager struct {
 type ClientManager struct {
-	Picker WorkerPicker
+	Enabled bool // wheather mux is enabled from user config
+	Picker  WorkerPicker
 }
 }
 
 
 func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
 func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {

+ 14 - 0
common/session/context.go

@@ -9,6 +9,7 @@ const (
 	inboundSessionKey
 	inboundSessionKey
 	outboundSessionKey
 	outboundSessionKey
 	contentSessionKey
 	contentSessionKey
+	MuxPreferedSessionKey
 )
 )
 
 
 // ContextWithID returns a new context with the given ID.
 // ContextWithID returns a new context with the given ID.
@@ -56,3 +57,16 @@ func ContentFromContext(ctx context.Context) *Content {
 	}
 	}
 	return nil
 	return nil
 }
 }
+
+// ContextWithMuxPrefered returns a new context with the given bool
+func ContextWithMuxPrefered(ctx context.Context, forced bool) context.Context {
+	return context.WithValue(ctx, MuxPreferedSessionKey, forced)
+}
+
+// MuxPreferedFromContext returns value in this context, or false if not contained.
+func MuxPreferedFromContext(ctx context.Context) bool {
+	if val, ok := ctx.Value(MuxPreferedSessionKey).(bool); ok {
+		return val
+	}
+	return false
+}

+ 17 - 11
infra/conf/v2ray.go

@@ -75,15 +75,24 @@ func (c *SniffingConfig) Build() (*proxyman.SniffingConfig, error) {
 }
 }
 
 
 type MuxConfig struct {
 type MuxConfig struct {
-	Enabled     bool   `json:"enabled"`
-	Concurrency uint16 `json:"concurrency"`
+	Enabled     bool  `json:"enabled"`
+	Concurrency int16 `json:"concurrency"`
 }
 }
 
 
-func (c *MuxConfig) GetConcurrency() uint16 {
-	if c.Concurrency == 0 {
-		return 8
+func (m *MuxConfig) Build() *proxyman.MultiplexingConfig {
+	if m.Concurrency < 0 {
+		return nil
+	}
+
+	var con uint32 = 8
+	if m.Concurrency > 0 {
+		con = uint32(m.Concurrency)
+	}
+
+	return &proxyman.MultiplexingConfig{
+		Enabled:     m.Enabled,
+		Concurrency: con,
 	}
 	}
-	return c.Concurrency
 }
 }
 
 
 type InboundDetourAllocationConfig struct {
 type InboundDetourAllocationConfig struct {
@@ -246,11 +255,8 @@ func (c *OutboundDetourConfig) Build() (*core.OutboundHandlerConfig, error) {
 		senderSettings.ProxySettings = ps
 		senderSettings.ProxySettings = ps
 	}
 	}
 
 
-	if c.MuxSettings != nil && c.MuxSettings.Enabled {
-		senderSettings.MultiplexSettings = &proxyman.MultiplexingConfig{
-			Enabled:     true,
-			Concurrency: uint32(c.MuxSettings.GetConcurrency()),
-		}
+	if c.MuxSettings != nil {
+		senderSettings.MultiplexSettings = c.MuxSettings.Build()
 	}
 	}
 
 
 	settings := []byte("{}")
 	settings := []byte("{}")

+ 33 - 1
infra/conf/v2ray_test.go

@@ -2,15 +2,16 @@ package conf_test
 
 
 import (
 import (
 	"encoding/json"
 	"encoding/json"
+	"reflect"
 	"testing"
 	"testing"
 
 
 	"github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/proto"
-
 	"v2ray.com/core"
 	"v2ray.com/core"
 	"v2ray.com/core/app/dispatcher"
 	"v2ray.com/core/app/dispatcher"
 	"v2ray.com/core/app/log"
 	"v2ray.com/core/app/log"
 	"v2ray.com/core/app/proxyman"
 	"v2ray.com/core/app/proxyman"
 	"v2ray.com/core/app/router"
 	"v2ray.com/core/app/router"
+	"v2ray.com/core/common"
 	clog "v2ray.com/core/common/log"
 	clog "v2ray.com/core/common/log"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/protocol"
 	"v2ray.com/core/common/protocol"
@@ -337,3 +338,34 @@ func TestV2RayConfig(t *testing.T) {
 		},
 		},
 	})
 	})
 }
 }
+
+func TestMuxConfig_Build(t *testing.T) {
+	tests := []struct {
+		name   string
+		fields string
+		want   *proxyman.MultiplexingConfig
+	}{
+		{"default", `{"enabled": true, "concurrency": 16}`, &proxyman.MultiplexingConfig{
+			Enabled:     true,
+			Concurrency: 16,
+		}},
+		{"empty def", `{}`, &proxyman.MultiplexingConfig{
+			Enabled:     false,
+			Concurrency: 8,
+		}},
+		{"not enable", `{"enabled": false, "concurrency": 4}`, &proxyman.MultiplexingConfig{
+			Enabled:     false,
+			Concurrency: 4,
+		}},
+		{"forbidden", `{"enabled": false, "concurrency": -1}`, nil},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			m := &MuxConfig{}
+			common.Must(json.Unmarshal([]byte(tt.fields), m))
+			if got := m.Build(); !reflect.DeepEqual(got, tt.want) {
+				t.Errorf("MuxConfig.Build() = %v, want %v", got, tt.want)
+			}
+		})
+	}
+}