ソースを参照

refactor UDPNameServer clean up task

Darien Raymond 7 年 前
コミット
0a3b3d0b6d
1 ファイル変更14 行追加19 行削除
  1. 14 19
      app/dns/nameserver.go

+ 14 - 19
app/dns/nameserver.go

@@ -10,14 +10,10 @@ import (
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/dice"
 	"v2ray.com/core/common/net"
+	"v2ray.com/core/common/signal"
 	"v2ray.com/core/transport/internet/udp"
 )
 
-const (
-	CleanupInterval  = time.Second * 120
-	CleanupThreshold = 512
-)
-
 var (
 	multiQuestionDNS = map[net.Address]bool{
 		net.IPAddress([]byte{8, 8, 8, 8}): true,
@@ -42,10 +38,10 @@ type PendingRequest struct {
 
 type UDPNameServer struct {
 	sync.Mutex
-	address     net.Destination
-	requests    map[uint16]*PendingRequest
-	udpServer   *udp.Dispatcher
-	nextCleanup time.Time
+	address   net.Destination
+	requests  map[uint16]*PendingRequest
+	udpServer *udp.Dispatcher
+	cleanup   *signal.PeriodicTask
 }
 
 func NewUDPNameServer(address net.Destination, dispatcher core.Dispatcher) *UDPNameServer {
@@ -54,36 +50,35 @@ func NewUDPNameServer(address net.Destination, dispatcher core.Dispatcher) *UDPN
 		requests:  make(map[uint16]*PendingRequest),
 		udpServer: udp.NewDispatcher(dispatcher),
 	}
+	s.cleanup = &signal.PeriodicTask{
+		Interval: time.Minute,
+		Execute:  s.Cleanup,
+	}
+	s.cleanup.Start()
 	return s
 }
 
-func (s *UDPNameServer) Cleanup() {
-	expiredRequests := make([]uint16, 0, 16)
+func (s *UDPNameServer) Cleanup() error {
 	now := time.Now()
 	s.Lock()
 	for id, r := range s.requests {
 		if r.expire.Before(now) {
-			expiredRequests = append(expiredRequests, id)
 			close(r.response)
+			delete(s.requests, id)
 		}
 	}
-	for _, id := range expiredRequests {
-		delete(s.requests, id)
-	}
 	s.Unlock()
+	return nil
 }
 
 func (s *UDPNameServer) AssignUnusedID(response chan<- *ARecord) uint16 {
 	var id uint16
 	s.Lock()
-	if len(s.requests) > CleanupThreshold && s.nextCleanup.Before(time.Now()) {
-		s.nextCleanup = time.Now().Add(CleanupInterval)
-		go s.Cleanup()
-	}
 
 	for {
 		id = dice.RollUint16()
 		if _, found := s.requests[id]; found {
+			time.Sleep(time.Millisecond * 500)
 			continue
 		}
 		newError("add pending request id ", id).AtDebug().WriteToLog()