Ver código fonte

clean up dns package

Darien Raymond 8 anos atrás
pai
commit
0dbfb66126

+ 16 - 14
app/dns/server/nameserver.go

@@ -2,12 +2,14 @@ package server
 
 import (
 	"context"
+	"fmt"
 	"sync"
 	"time"
 
 	"github.com/miekg/dns"
 	"v2ray.com/core/app/dispatcher"
 	"v2ray.com/core/app/log"
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/dice"
 	"v2ray.com/core/common/net"
@@ -15,7 +17,6 @@ import (
 )
 
 const (
-	DefaultTTL       = uint32(3600)
 	CleanupInterval  = time.Second * 120
 	CleanupThreshold = 512
 )
@@ -55,7 +56,6 @@ func NewUDPNameServer(address net.Destination, dispatcher dispatcher.Interface)
 	return s
 }
 
-// Private: Visible for testing.
 func (v *UDPNameServer) Cleanup() {
 	expiredRequests := make([]uint16, 0, 16)
 	now := time.Now()
@@ -70,10 +70,8 @@ func (v *UDPNameServer) Cleanup() {
 		delete(v.requests, id)
 	}
 	v.Unlock()
-	expiredRequests = nil
 }
 
-// Private: Visible for testing.
 func (v *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
 	var id uint16
 	v.Lock()
@@ -98,7 +96,6 @@ func (v *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
 	return id
 }
 
-// Private: Visible for testing.
 func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
 	msg := new(dns.Msg)
 	err := msg.Unpack(payload.Bytes())
@@ -110,8 +107,8 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
 		IPs: make([]net.IP, 0, 16),
 	}
 	id := msg.Id
-	ttl := DefaultTTL
-	log.Trace(newError("handling response for id ", id, " content: ", msg.String()).AtDebug())
+	ttl := uint32(3600) // an hour
+	log.Trace(newError("handling response for id ", id, " content: ", msg).AtDebug())
 
 	v.Lock()
 	request, found := v.requests[id]
@@ -126,6 +123,7 @@ func (v *UDPNameServer) HandleResponse(payload *buf.Buffer) {
 		switch rr := rr.(type) {
 		case *dns.A:
 			record.IPs = append(record.IPs, rr.A)
+			fmt.Println("Adding ans:", rr.A)
 			if rr.Hdr.Ttl < ttl {
 				ttl = rr.Hdr.Ttl
 			}
@@ -152,13 +150,18 @@ func (v *UDPNameServer) BuildQueryA(domain string, id uint16) *buf.Buffer {
 			Name:   dns.Fqdn(domain),
 			Qtype:  dns.TypeA,
 			Qclass: dns.ClassINET,
+		},
+		{
+			Name:   dns.Fqdn(domain),
+			Qtype:  dns.TypeAAAA,
+			Qclass: dns.ClassINET,
 		}}
 
 	buffer := buf.New()
-	buffer.AppendSupplier(func(b []byte) (int, error) {
+	common.Must(buffer.Reset(func(b []byte) (int, error) {
 		writtenBuffer, err := msg.PackBuffer(b)
 		return len(writtenBuffer), err
-	})
+	}))
 
 	return buffer
 }
@@ -167,7 +170,7 @@ func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord {
 	response := make(chan *ARecord, 1)
 	id := v.AssignUnusedID(response)
 
-	ctx, cancel := context.WithTimeout(context.Background(), time.Second*8)
+	ctx, cancel := context.WithCancel(context.Background())
 	v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
 
 	go func() {
@@ -176,11 +179,10 @@ func (v *UDPNameServer) QueryA(domain string) <-chan *ARecord {
 			v.Lock()
 			_, found := v.requests[id]
 			v.Unlock()
-			if found {
-				v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
-			} else {
+			if !found {
 				break
 			}
+			v.udpServer.Dispatch(ctx, v.address, v.BuildQueryA(domain, id), v.HandleResponse)
 		}
 		cancel()
 	}()
@@ -205,7 +207,7 @@ func (v *LocalNameServer) QueryA(domain string) <-chan *ARecord {
 
 		response <- &ARecord{
 			IPs:    ips,
-			Expire: time.Now().Add(time.Second * time.Duration(DefaultTTL)),
+			Expire: time.Now().Add(time.Hour),
 		}
 	}()
 

+ 0 - 20
app/dns/server/querier.go

@@ -1,20 +0,0 @@
-package server
-
-import (
-	"time"
-
-	"v2ray.com/core/common/net"
-)
-
-type IPResult struct {
-	IP  []net.IP
-	TTL time.Duration
-}
-
-type Querier interface {
-	QueryDomain(domain string) <-chan *IPResult
-}
-
-type UDPQuerier struct {
-	server net.Destination
-}

+ 34 - 12
app/dns/server/server.go

@@ -21,22 +21,22 @@ const (
 )
 
 type DomainRecord struct {
-	A *ARecord
-}
-
-type Record struct {
 	IP         []net.IP
 	Expire     time.Time
 	LastAccess time.Time
 }
 
-func (r *Record) Expired() bool {
+func (r *DomainRecord) Expired() bool {
+	return r.Expire.Before(time.Now())
+}
+
+func (r *DomainRecord) Inactive() bool {
 	now := time.Now()
-	return r.Expire.Before(now) || r.LastAccess.Add(time.Hour).Before(now)
+	return r.Expire.Before(now) || r.LastAccess.Add(time.Minute*5).Before(now)
 }
 
 type CacheServer struct {
-	sync.RWMutex
+	sync.Mutex
 	hosts   map[string]net.IP
 	records map[string]*DomainRecord
 	servers []NameServer
@@ -90,15 +90,33 @@ func (*CacheServer) Start() error {
 func (*CacheServer) Close() {}
 
 func (s *CacheServer) GetCached(domain string) []net.IP {
-	s.RLock()
-	defer s.RUnlock()
+	s.Lock()
+	defer s.Unlock()
 
-	if record, found := s.records[domain]; found && record.A.Expire.After(time.Now()) {
-		return record.A.IPs
+	if record, found := s.records[domain]; found && !record.Expired() {
+		record.LastAccess = time.Now()
+		return record.IP
 	}
 	return nil
 }
 
+func (s *CacheServer) tryCleanup() {
+	s.Lock()
+	defer s.Unlock()
+
+	if len(s.records) > 256 {
+		domains := make([]string, 0, 256)
+		for d, r := range s.records {
+			if r.Expired() {
+				domains = append(domains, d)
+			}
+		}
+		for _, d := range domains {
+			delete(s.records, d)
+		}
+	}
+}
+
 func (s *CacheServer) Get(domain string) []net.IP {
 	if ip, found := s.hosts[domain]; found {
 		return []net.IP{ip}
@@ -110,6 +128,8 @@ func (s *CacheServer) Get(domain string) []net.IP {
 		return ips
 	}
 
+	s.tryCleanup()
+
 	for _, server := range s.servers {
 		response := server.QueryA(domain)
 		select {
@@ -119,7 +139,9 @@ func (s *CacheServer) Get(domain string) []net.IP {
 			}
 			s.Lock()
 			s.records[domain] = &DomainRecord{
-				A: a,
+				IP:         a.IPs,
+				Expire:     a.Expire,
+				LastAccess: time.Now(),
 			}
 			s.Unlock()
 			log.Trace(newError("returning ", len(a.IPs), " IPs for domain ", domain).AtDebug())

+ 102 - 0
app/dns/server/server_test.go

@@ -0,0 +1,102 @@
+package server_test
+
+import (
+	"context"
+	"testing"
+
+	"v2ray.com/core/app"
+	"v2ray.com/core/app/dispatcher"
+	_ "v2ray.com/core/app/dispatcher/impl"
+	. "v2ray.com/core/app/dns"
+	_ "v2ray.com/core/app/dns/server"
+	"v2ray.com/core/app/proxyman"
+	_ "v2ray.com/core/app/proxyman/outbound"
+	"v2ray.com/core/common"
+	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/serial"
+	"v2ray.com/core/proxy/freedom"
+	"v2ray.com/core/testing/servers/udp"
+	. "v2ray.com/ext/assert"
+
+	"github.com/miekg/dns"
+)
+
+type staticHandler struct {
+}
+
+func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
+	ans := new(dns.Msg)
+	ans.Id = r.Id
+	for _, q := range r.Question {
+		if q.Name == "google.com." && q.Qtype == dns.TypeA {
+			rr, _ := dns.NewRR("google.com. IN A 8.8.8.8")
+			ans.Answer = append(ans.Answer, rr)
+		} else if q.Name == "facebook.com." && q.Qtype == dns.TypeA {
+			rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
+			ans.Answer = append(ans.Answer, rr)
+		}
+	}
+	w.WriteMsg(ans)
+}
+
+func TestUDPServer(t *testing.T) {
+	assert := With(t)
+
+	port := udp.PickPort()
+
+	dnsServer := dns.Server{
+		Addr:    "127.0.0.1:" + port.String(),
+		Net:     "udp",
+		Handler: &staticHandler{},
+		UDPSize: 1200,
+	}
+
+	go dnsServer.ListenAndServe()
+
+	config := &Config{
+		NameServers: []*net.Endpoint{
+			{
+				Network: net.Network_UDP,
+				Address: &net.IPOrDomain{
+					Address: &net.IPOrDomain_Ip{
+						Ip: []byte{127, 0, 0, 1},
+					},
+				},
+				Port: uint32(port),
+			},
+		},
+	}
+
+	ctx := context.Background()
+	space := app.NewSpace()
+
+	ctx = app.ContextWithSpace(ctx, space)
+	common.Must(app.AddApplicationToSpace(ctx, config))
+	common.Must(app.AddApplicationToSpace(ctx, &dispatcher.Config{}))
+	common.Must(app.AddApplicationToSpace(ctx, &proxyman.OutboundConfig{}))
+
+	om := proxyman.OutboundHandlerManagerFromSpace(space)
+	om.AddHandler(ctx, &proxyman.OutboundHandlerConfig{
+		ProxySettings: serial.ToTypedMessage(&freedom.Config{}),
+	})
+
+	common.Must(space.Initialize())
+	common.Must(space.Start())
+
+	server := FromSpace(space)
+	assert(server, IsNotNil)
+
+	ips := server.Get("google.com")
+	assert(len(ips), Equals, 1)
+	assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
+
+	ips = server.Get("facebook.com")
+	assert(len(ips), Equals, 1)
+	assert([]byte(ips[0]), Equals, []byte{9, 9, 9, 9})
+
+	dnsServer.Shutdown()
+
+	ips = server.Get("google.com")
+	assert(len(ips), Equals, 1)
+	assert([]byte(ips[0]), Equals, []byte{8, 8, 8, 8})
+}

+ 2 - 3
common/signal/timer.go

@@ -53,8 +53,7 @@ func (t *ActivityTimer) run() {
 	}
 }
 
-func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.Context, *ActivityTimer) {
-	ctx, cancel := context.WithCancel(ctx)
+func CancelAfterInactivity(ctx context.Context, cancel context.CancelFunc, timeout time.Duration) *ActivityTimer {
 	timer := &ActivityTimer{
 		ctx:     ctx,
 		cancel:  cancel,
@@ -63,5 +62,5 @@ func CancelAfterInactivity(ctx context.Context, timeout time.Duration) (context.
 	}
 	timer.timeout <- timeout
 	go timer.run()
-	return ctx, timer
+	return timer
 }

+ 4 - 2
common/signal/timer_test.go

@@ -13,7 +13,8 @@ import (
 func TestActivityTimer(t *testing.T) {
 	assert := With(t)
 
-	ctx, timer := CancelAfterInactivity(context.Background(), time.Second*5)
+	ctx, cancel := context.WithCancel(context.Background())
+	timer := CancelAfterInactivity(ctx, cancel, time.Second*5)
 	time.Sleep(time.Second * 6)
 	assert(ctx.Err(), IsNotNil)
 	runtime.KeepAlive(timer)
@@ -22,7 +23,8 @@ func TestActivityTimer(t *testing.T) {
 func TestActivityTimerUpdate(t *testing.T) {
 	assert := With(t)
 
-	ctx, timer := CancelAfterInactivity(context.Background(), time.Second*10)
+	ctx, cancel := context.WithCancel(context.Background())
+	timer := CancelAfterInactivity(ctx, cancel, time.Second*10)
 	time.Sleep(time.Second * 3)
 	assert(ctx.Err(), IsNil)
 	timer.SetTimeout(time.Second * 1)

+ 3 - 1
proxy/dokodemo/dokodemo.go

@@ -64,7 +64,9 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 	if timeout == 0 {
 		timeout = time.Minute * 5
 	}
-	ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
+
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
 
 	inboundRay, err := dispatcher.Dispatch(ctx, dest)
 	if err != nil {

+ 2 - 1
proxy/freedom/freedom.go

@@ -107,7 +107,8 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 	if timeout == 0 {
 		timeout = time.Minute * 5
 	}
-	ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
 
 	requestDone := signal.ExecuteAsync(func() error {
 		var writer buf.Writer

+ 2 - 1
proxy/http/server.go

@@ -153,7 +153,8 @@ func (s *Server) handleConnect(ctx context.Context, request *http.Request, reade
 	if timeout == 0 {
 		timeout = time.Minute * 5
 	}
-	ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
 	ray, err := dispatcher.Dispatch(ctx, dest)
 	if err != nil {
 		return err

+ 2 - 1
proxy/shadowsocks/client.go

@@ -90,7 +90,8 @@ func (v *Client) Process(ctx context.Context, outboundRay ray.OutboundRay, diale
 		request.Option |= RequestOptionOneTimeAuth
 	}
 
-	ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*5)
 
 	if request.Command == protocol.RequestCommandTCP {
 		bufferedWriter := buf.NewBufferedWriter(buf.NewWriter(conn))

+ 2 - 1
proxy/shadowsocks/server.go

@@ -146,7 +146,8 @@ func (s *Server) handleConnection(ctx context.Context, conn internet.Connection,
 	ctx = protocol.ContextWithUser(ctx, request.User)
 
 	userSettings := s.user.GetSettings()
-	ctx, timer := signal.CancelAfterInactivity(ctx, userSettings.PayloadTimeout)
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, userSettings.PayloadTimeout)
 	ray, err := dispatcher.Dispatch(ctx, dest)
 	if err != nil {
 		return err

+ 2 - 1
proxy/socks/client.go

@@ -83,7 +83,8 @@ func (c *Client) Process(ctx context.Context, ray ray.OutboundRay, dialer proxy.
 		return newError("failed to establish connection to server").AtWarning().Base(err)
 	}
 
-	ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*5)
 
 	var requestFunc func() error
 	var responseFunc func() error

+ 2 - 1
proxy/socks/server.go

@@ -107,7 +107,8 @@ func (v *Server) transport(ctx context.Context, reader io.Reader, writer io.Writ
 	if timeout == 0 {
 		timeout = time.Minute * 5
 	}
-	ctx, timer := signal.CancelAfterInactivity(ctx, timeout)
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, timeout)
 
 	ray, err := dispatcher.Dispatch(ctx, dest)
 	if err != nil {

+ 2 - 1
proxy/vmess/inbound/inbound.go

@@ -204,7 +204,8 @@ func (v *Handler) Process(ctx context.Context, network net.Network, connection i
 
 	ctx = protocol.ContextWithUser(ctx, request.User)
 
-	ctx, timer := signal.CancelAfterInactivity(ctx, userSettings.PayloadTimeout)
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, userSettings.PayloadTimeout)
 	ray, err := dispatcher.Dispatch(ctx, request.Destination())
 	if err != nil {
 		return newError("failed to dispatch request to ", request.Destination()).Base(err)

+ 2 - 1
proxy/vmess/outbound/outbound.go

@@ -103,7 +103,8 @@ func (v *Handler) Process(ctx context.Context, outboundRay ray.OutboundRay, dial
 
 	session := encoding.NewClientSession(protocol.DefaultIDHash)
 
-	ctx, timer := signal.CancelAfterInactivity(ctx, time.Minute*5)
+	ctx, cancel := context.WithCancel(ctx)
+	timer := signal.CancelAfterInactivity(ctx, cancel, time.Minute*5)
 
 	requestDone := signal.ExecuteAsync(func() error {
 		writer := buf.NewBufferedWriter(buf.NewWriter(conn))

+ 48 - 19
transport/internet/udp/dispatcher.go

@@ -3,25 +3,33 @@ package udp
 import (
 	"context"
 	"sync"
+	"time"
 
 	"v2ray.com/core/app/dispatcher"
 	"v2ray.com/core/app/log"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/signal"
 	"v2ray.com/core/transport/ray"
 )
 
 type ResponseCallback func(payload *buf.Buffer)
 
+type connEntry struct {
+	inbound ray.InboundRay
+	timer   signal.ActivityUpdater
+	cancel  context.CancelFunc
+}
+
 type Dispatcher struct {
 	sync.RWMutex
-	conns      map[net.Destination]ray.InboundRay
+	conns      map[net.Destination]*connEntry
 	dispatcher dispatcher.Interface
 }
 
 func NewDispatcher(dispatcher dispatcher.Interface) *Dispatcher {
 	return &Dispatcher{
-		conns:      make(map[net.Destination]ray.InboundRay),
+		conns:      make(map[net.Destination]*connEntry),
 		dispatcher: dispatcher,
 	}
 }
@@ -30,51 +38,72 @@ func (v *Dispatcher) RemoveRay(dest net.Destination) {
 	v.Lock()
 	defer v.Unlock()
 	if conn, found := v.conns[dest]; found {
-		conn.InboundInput().Close()
-		conn.InboundOutput().Close()
+		conn.inbound.InboundInput().Close()
+		conn.inbound.InboundOutput().Close()
 		delete(v.conns, dest)
 	}
 }
 
-func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (ray.InboundRay, bool) {
+func (v *Dispatcher) getInboundRay(dest net.Destination, callback ResponseCallback) *connEntry {
 	v.Lock()
 	defer v.Unlock()
 
 	if entry, found := v.conns[dest]; found {
-		return entry, true
+		return entry
 	}
 
 	log.Trace(newError("establishing new connection for ", dest))
+
+	ctx, cancel := context.WithCancel(context.Background())
+	removeRay := func() {
+		cancel()
+		v.RemoveRay(dest)
+	}
+	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Second*4)
 	inboundRay, _ := v.dispatcher.Dispatch(ctx, dest)
-	v.conns[dest] = inboundRay
-	return inboundRay, false
+	entry := &connEntry{
+		inbound: inboundRay,
+		timer:   timer,
+		cancel:  removeRay,
+	}
+	v.conns[dest] = entry
+	go handleInput(ctx, entry, callback)
+	return entry
 }
 
 func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer, callback ResponseCallback) {
 	// TODO: Add user to destString
 	log.Trace(newError("dispatch request to: ", destination).AtDebug())
 
-	inboundRay, existing := v.getInboundRay(ctx, destination)
-	outputStream := inboundRay.InboundInput()
+	conn := v.getInboundRay(destination, callback)
+	outputStream := conn.inbound.InboundInput()
 	if outputStream != nil {
 		if err := outputStream.WriteMultiBuffer(buf.NewMultiBufferValue(payload)); err != nil {
-			v.RemoveRay(destination)
+			log.Trace(newError("failed to write first UDP payload").Base(err))
+			conn.cancel()
+			return
 		}
 	}
-	if !existing {
-		go func() {
-			handleInput(inboundRay.InboundOutput(), callback)
-			v.RemoveRay(destination)
-		}()
-	}
 }
 
-func handleInput(input ray.InputStream, callback ResponseCallback) {
+func handleInput(ctx context.Context, conn *connEntry, callback ResponseCallback) {
+	input := conn.inbound.InboundOutput()
+	timer := conn.timer
+
 	for {
+		select {
+		case <-ctx.Done():
+			return
+		default:
+		}
+
 		mb, err := input.ReadMultiBuffer()
 		if err != nil {
-			break
+			log.Trace(newError("failed to handl UDP input").Base(err))
+			conn.cancel()
+			return
 		}
+		timer.Update()
 		for _, b := range mb {
 			callback(b)
 		}