Browse Source

validity map

Darien Raymond 10 years ago
parent
commit
9e84519134
2 changed files with 78 additions and 42 deletions
  1. 7 42
      app/dns/dns.go
  2. 71 0
      common/collect/validity_map.go

+ 7 - 42
app/dns/dns.go

@@ -2,10 +2,11 @@ package dns
 
 import (
 	"net"
-	"sync"
 	"time"
 
 	"github.com/v2ray/v2ray-core/app"
+	"github.com/v2ray/v2ray-core/common/collect"
+	"github.com/v2ray/v2ray-core/common/serial"
 )
 
 type entry struct {
@@ -32,66 +33,30 @@ func (this *entry) Extend() {
 }
 
 type DnsCache struct {
-	sync.RWMutex
-	cache  map[string]*entry
+	cache  *collect.ValidityMap
 	config CacheConfig
 }
 
 func NewCache(config CacheConfig) *DnsCache {
 	cache := &DnsCache{
-		cache:  make(map[string]*entry),
+		cache:  collect.NewValidityMap(3600),
 		config: config,
 	}
-	go cache.cleanup()
 	return cache
 }
 
-func (this *DnsCache) cleanup() {
-	for range time.Tick(60 * time.Second) {
-		entry2Remove := make([]*entry, 0, 128)
-		this.RLock()
-		for _, entry := range this.cache {
-			if !entry.IsValid() {
-				entry2Remove = append(entry2Remove, entry)
-			}
-		}
-		this.RUnlock()
-
-		for _, entry := range entry2Remove {
-			if !entry.IsValid() {
-				this.Lock()
-				delete(this.cache, entry.domain)
-				this.Unlock()
-			}
-		}
-	}
-}
-
 func (this *DnsCache) Add(context app.Context, domain string, ip net.IP) {
 	callerTag := context.CallerTag()
 	if !this.config.IsTrustedSource(callerTag) {
 		return
 	}
 
-	this.RLock()
-	entry, found := this.cache[domain]
-	this.RUnlock()
-	if found {
-		entry.ip = ip
-		entry.Extend()
-	} else {
-		this.Lock()
-		this.cache[domain] = newEntry(domain, ip)
-		this.Unlock()
-	}
+	this.cache.Set(serial.StringLiteral(domain), newEntry(domain, ip))
 }
 
 func (this *DnsCache) Get(context app.Context, domain string) net.IP {
-	this.RLock()
-	entry, found := this.cache[domain]
-	this.RUnlock()
-	if found {
-		return entry.ip
+	if value := this.cache.Get(serial.StringLiteral(domain)); value != nil {
+		return value.(*entry).ip
 	}
 	return nil
 }

+ 71 - 0
common/collect/validity_map.go

@@ -0,0 +1,71 @@
+package collect
+
+import (
+	"sync"
+	"time"
+
+	"github.com/v2ray/v2ray-core/common/serial"
+)
+
+type Validity interface {
+	IsValid() bool
+}
+
+type entry struct {
+	key   string
+	value Validity
+}
+
+type ValidityMap struct {
+	sync.RWMutex
+	cache              map[string]Validity
+	cleanupIntervalSec int
+}
+
+func NewValidityMap(cleanupIntervalSec int) *ValidityMap {
+	instance := &ValidityMap{
+		cache:              make(map[string]Validity),
+		cleanupIntervalSec: cleanupIntervalSec,
+	}
+	go instance.cleanup()
+	return instance
+}
+
+func (this *ValidityMap) cleanup() {
+	for range time.Tick(time.Duration(this.cleanupIntervalSec) * time.Second) {
+		entry2Remove := make([]entry, 0, 128)
+		this.RLock()
+		for key, value := range this.cache {
+			if !value.IsValid() {
+				entry2Remove = append(entry2Remove, entry{
+					key:   key,
+					value: value,
+				})
+			}
+		}
+		this.RUnlock()
+
+		for _, entry := range entry2Remove {
+			if !entry.value.IsValid() {
+				this.Lock()
+				delete(this.cache, entry.key)
+				this.Unlock()
+			}
+		}
+	}
+}
+
+func (this *ValidityMap) Set(key serial.String, value Validity) {
+	this.Lock()
+	this.cache[key.String()] = value
+	this.Unlock()
+}
+
+func (this *ValidityMap) Get(key serial.String) Validity {
+	this.RLock()
+	defer this.RUnlock()
+	if value, found := this.cache[key.String()]; found {
+		return value
+	}
+	return nil
+}