|  | @@ -2,10 +2,11 @@ package dns
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  import (
 | 
	
		
			
				|  |  |  	"net"
 | 
	
		
			
				|  |  | -	"sync"
 | 
	
		
			
				|  |  |  	"time"
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  	"github.com/v2ray/v2ray-core/app"
 | 
	
		
			
				|  |  | +	"github.com/v2ray/v2ray-core/common/collect"
 | 
	
		
			
				|  |  | +	"github.com/v2ray/v2ray-core/common/serial"
 | 
	
		
			
				|  |  |  )
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  type entry struct {
 | 
	
	
		
			
				|  | @@ -32,66 +33,30 @@ func (this *entry) Extend() {
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  type DnsCache struct {
 | 
	
		
			
				|  |  | -	sync.RWMutex
 | 
	
		
			
				|  |  | -	cache  map[string]*entry
 | 
	
		
			
				|  |  | +	cache  *collect.ValidityMap
 | 
	
		
			
				|  |  |  	config CacheConfig
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func NewCache(config CacheConfig) *DnsCache {
 | 
	
		
			
				|  |  |  	cache := &DnsCache{
 | 
	
		
			
				|  |  | -		cache:  make(map[string]*entry),
 | 
	
		
			
				|  |  | +		cache:  collect.NewValidityMap(3600),
 | 
	
		
			
				|  |  |  		config: config,
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  | -	go cache.cleanup()
 | 
	
		
			
				|  |  |  	return cache
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -func (this *DnsCache) cleanup() {
 | 
	
		
			
				|  |  | -	for range time.Tick(60 * time.Second) {
 | 
	
		
			
				|  |  | -		entry2Remove := make([]*entry, 0, 128)
 | 
	
		
			
				|  |  | -		this.RLock()
 | 
	
		
			
				|  |  | -		for _, entry := range this.cache {
 | 
	
		
			
				|  |  | -			if !entry.IsValid() {
 | 
	
		
			
				|  |  | -				entry2Remove = append(entry2Remove, entry)
 | 
	
		
			
				|  |  | -			}
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | -		this.RUnlock()
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  | -		for _, entry := range entry2Remove {
 | 
	
		
			
				|  |  | -			if !entry.IsValid() {
 | 
	
		
			
				|  |  | -				this.Lock()
 | 
	
		
			
				|  |  | -				delete(this.cache, entry.domain)
 | 
	
		
			
				|  |  | -				this.Unlock()
 | 
	
		
			
				|  |  | -			}
 | 
	
		
			
				|  |  | -		}
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  | -}
 | 
	
		
			
				|  |  | -
 | 
	
		
			
				|  |  |  func (this *DnsCache) Add(context app.Context, domain string, ip net.IP) {
 | 
	
		
			
				|  |  |  	callerTag := context.CallerTag()
 | 
	
		
			
				|  |  |  	if !this.config.IsTrustedSource(callerTag) {
 | 
	
		
			
				|  |  |  		return
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  | -	this.RLock()
 | 
	
		
			
				|  |  | -	entry, found := this.cache[domain]
 | 
	
		
			
				|  |  | -	this.RUnlock()
 | 
	
		
			
				|  |  | -	if found {
 | 
	
		
			
				|  |  | -		entry.ip = ip
 | 
	
		
			
				|  |  | -		entry.Extend()
 | 
	
		
			
				|  |  | -	} else {
 | 
	
		
			
				|  |  | -		this.Lock()
 | 
	
		
			
				|  |  | -		this.cache[domain] = newEntry(domain, ip)
 | 
	
		
			
				|  |  | -		this.Unlock()
 | 
	
		
			
				|  |  | -	}
 | 
	
		
			
				|  |  | +	this.cache.Set(serial.StringLiteral(domain), newEntry(domain, ip))
 | 
	
		
			
				|  |  |  }
 | 
	
		
			
				|  |  |  
 | 
	
		
			
				|  |  |  func (this *DnsCache) Get(context app.Context, domain string) net.IP {
 | 
	
		
			
				|  |  | -	this.RLock()
 | 
	
		
			
				|  |  | -	entry, found := this.cache[domain]
 | 
	
		
			
				|  |  | -	this.RUnlock()
 | 
	
		
			
				|  |  | -	if found {
 | 
	
		
			
				|  |  | -		return entry.ip
 | 
	
		
			
				|  |  | +	if value := this.cache.Get(serial.StringLiteral(domain)); value != nil {
 | 
	
		
			
				|  |  | +		return value.(*entry).ip
 | 
	
		
			
				|  |  |  	}
 | 
	
		
			
				|  |  |  	return nil
 | 
	
		
			
				|  |  |  }
 |