Jelajahi Sumber

simplify stringlist

v2ray 9 tahun lalu
induk
melakukan
2d233295e6

+ 11 - 2
app/dns/config.go

@@ -1,5 +1,14 @@
 package dns
 
-type CacheConfig interface {
-	IsTrustedSource(tag string) bool
+import (
+	"github.com/v2ray/v2ray-core/common/serial"
+)
+
+type CacheConfig struct {
+	TrustedTags map[serial.StringLiteral]bool
+}
+
+func (this *CacheConfig) IsTrustedSource(tag serial.StringLiteral) bool {
+	_, found := this.TrustedTags[tag]
+	return found
 }

+ 23 - 0
app/dns/config_json.go

@@ -0,0 +1,23 @@
+// +build json
+
+package dns
+
+import (
+	"encoding/json"
+
+	"github.com/v2ray/v2ray-core/common/serial"
+)
+
+func (this *CacheConfig) UnmarshalJSON(data []byte) error {
+	var strlist serial.StringLiteralList
+	if err := json.Unmarshal(data, strlist); err != nil {
+		return err
+	}
+	config := &CacheConfig{
+		TrustedTags: make(map[serial.StringLiteral]bool, strlist.Len()),
+	}
+	for _, str := range strlist {
+		config.TrustedTags[str.TrimSpace()] = true
+	}
+	return nil
+}

+ 3 - 3
app/dns/dns.go

@@ -34,10 +34,10 @@ func (this *entry) Extend() {
 
 type DnsCache struct {
 	cache  *collect.ValidityMap
-	config CacheConfig
+	config *CacheConfig
 }
 
-func NewCache(config CacheConfig) *DnsCache {
+func NewCache(config *CacheConfig) *DnsCache {
 	cache := &DnsCache{
 		cache:  collect.NewValidityMap(3600),
 		config: config,
@@ -47,7 +47,7 @@ func NewCache(config CacheConfig) *DnsCache {
 
 func (this *DnsCache) Add(context app.Context, domain string, ip net.IP) {
 	callerTag := context.CallerTag()
-	if !this.config.IsTrustedSource(callerTag) {
+	if !this.config.IsTrustedSource(serial.StringLiteral(callerTag)) {
 		return
 	}
 

+ 4 - 4
app/dns/dns_test.go

@@ -5,9 +5,9 @@ import (
 	"testing"
 
 	"github.com/v2ray/v2ray-core/app/dns"
-	dnstesting "github.com/v2ray/v2ray-core/app/dns/testing"
 	apptesting "github.com/v2ray/v2ray-core/app/testing"
 	netassert "github.com/v2ray/v2ray-core/common/net/testing/assert"
+	"github.com/v2ray/v2ray-core/common/serial"
 	v2testing "github.com/v2ray/v2ray-core/testing"
 )
 
@@ -15,9 +15,9 @@ func TestDnsAdd(t *testing.T) {
 	v2testing.Current(t)
 
 	domain := "v2ray.com"
-	cache := dns.NewCache(&dnstesting.CacheConfig{
-		TrustedTags: map[string]bool{
-			"testtag": true,
+	cache := dns.NewCache(&dns.CacheConfig{
+		TrustedTags: map[serial.StringLiteral]bool{
+			serial.StringLiteral("testtag"): true,
 		},
 	})
 	ip := cache.Get(&apptesting.Context{}, domain)

+ 0 - 35
app/dns/json/config.go

@@ -1,35 +0,0 @@
-package json
-
-import (
-	"strings"
-
-	serialjson "github.com/v2ray/v2ray-core/common/serial/json"
-)
-
-type TagList map[string]bool
-
-func NewTagList(tags []string) TagList {
-	list := TagList(make(map[string]bool))
-	for _, tag := range tags {
-		list[strings.TrimSpace(tag)] = true
-	}
-	return list
-}
-
-func (this *TagList) UnmarshalJSON(data []byte) error {
-	tags, err := serialjson.UnmarshalStringList(data)
-	if err != nil {
-		return err
-	}
-	*this = NewTagList(tags)
-	return nil
-}
-
-type CacheConfig struct {
-	TrustedTags TagList `json:"trustedTags"`
-}
-
-func (this *CacheConfig) IsTrustedSource(tag string) bool {
-	_, found := this.TrustedTags[tag]
-	return found
-}

+ 0 - 10
app/dns/testing/config.go

@@ -1,10 +0,0 @@
-package testing
-
-type CacheConfig struct {
-	TrustedTags map[string]bool
-}
-
-func (this *CacheConfig) IsTrustedSource(tag string) bool {
-	_, found := this.TrustedTags[tag]
-	return found
-}

+ 9 - 39
app/router/rules/json/fieldrule.go

@@ -8,39 +8,9 @@ import (
 	"strings"
 
 	v2net "github.com/v2ray/v2ray-core/common/net"
+	"github.com/v2ray/v2ray-core/common/serial"
 )
 
-type StringList []string
-
-func NewStringList(str ...string) *StringList {
-	list := StringList(str)
-	return &list
-}
-
-func (this *StringList) UnmarshalJSON(data []byte) error {
-	var strList []string
-	err := json.Unmarshal(data, &strList)
-	if err == nil {
-		*this = make([]string, len(strList))
-		copy(*this, strList)
-		return nil
-	}
-
-	var str string
-	err = json.Unmarshal(data, &str)
-	if err == nil {
-		*this = make([]string, 0, 1)
-		*this = append(*this, str)
-		return nil
-	}
-
-	return errors.New("Failed to unmarshal string list: " + string(data))
-}
-
-func (this *StringList) Len() int {
-	return len([]string(*this))
-}
-
 type DomainMatcher interface {
 	Match(domain string) bool
 }
@@ -138,10 +108,10 @@ func (this *FieldRule) Apply(dest v2net.Destination) bool {
 func (this *FieldRule) UnmarshalJSON(data []byte) error {
 	type RawFieldRule struct {
 		Rule
-		Domain  *StringList        `json:"domain"`
-		IP      *StringList        `json:"ip"`
-		Port    *v2net.PortRange   `json:"port"`
-		Network *v2net.NetworkList `json:"network"`
+		Domain  *serial.StringLiteralList `json:"domain"`
+		IP      *serial.StringLiteralList `json:"ip"`
+		Port    *v2net.PortRange          `json:"port"`
+		Network *v2net.NetworkList        `json:"network"`
 	}
 	rawFieldRule := RawFieldRule{}
 	err := json.Unmarshal(data, &rawFieldRule)
@@ -156,14 +126,14 @@ func (this *FieldRule) UnmarshalJSON(data []byte) error {
 		this.Domain = make([]DomainMatcher, rawFieldRule.Domain.Len())
 		for idx, rawDomain := range *(rawFieldRule.Domain) {
 			var matcher DomainMatcher
-			if strings.HasPrefix(rawDomain, "regexp:") {
-				rawMatcher, err := NewRegexpDomainMatcher(rawDomain[7:])
+			if strings.HasPrefix(rawDomain.String(), "regexp:") {
+				rawMatcher, err := NewRegexpDomainMatcher(rawDomain.String()[7:])
 				if err != nil {
 					return err
 				}
 				matcher = rawMatcher
 			} else {
-				matcher = NewPlainDomainMatcher(rawDomain)
+				matcher = NewPlainDomainMatcher(rawDomain.String())
 			}
 			this.Domain[idx] = matcher
 		}
@@ -173,7 +143,7 @@ func (this *FieldRule) UnmarshalJSON(data []byte) error {
 	if rawFieldRule.IP != nil && rawFieldRule.IP.Len() > 0 {
 		this.IP = make([]*net.IPNet, 0, rawFieldRule.IP.Len())
 		for _, ipStr := range *(rawFieldRule.IP) {
-			_, ipNet, err := net.ParseCIDR(ipStr)
+			_, ipNet, err := net.ParseCIDR(ipStr.String())
 			if err != nil {
 				return errors.New("Invalid IP range in router rule: " + err.Error())
 			}

+ 0 - 21
app/router/rules/json/fieldrule_test.go

@@ -1,7 +1,6 @@
 package json
 
 import (
-	"encoding/json"
 	"testing"
 
 	v2net "github.com/v2ray/v2ray-core/common/net"
@@ -9,26 +8,6 @@ import (
 	"github.com/v2ray/v2ray-core/testing/assert"
 )
 
-func TestStringListParsingList(t *testing.T) {
-	v2testing.Current(t)
-
-	rawJson := `["a", "b", "c", "d"]`
-	var strList StringList
-	err := json.Unmarshal([]byte(rawJson), &strList)
-	assert.Error(err).IsNil()
-	assert.Int(strList.Len()).Equals(4)
-}
-
-func TestStringListParsingString(t *testing.T) {
-	v2testing.Current(t)
-
-	rawJson := `"abcd"`
-	var strList StringList
-	err := json.Unmarshal([]byte(rawJson), &strList)
-	assert.Error(err).IsNil()
-	assert.Int(strList.Len()).Equals(1)
-}
-
 func TestDomainMatching(t *testing.T) {
 	v2testing.Current(t)
 

+ 5 - 5
common/net/network.go

@@ -1,7 +1,7 @@
 package net
 
 import (
-	"strings"
+	"github.com/v2ray/v2ray-core/common/serial"
 )
 
 const (
@@ -9,14 +9,14 @@ const (
 	UDPNetwork = Network("udp")
 )
 
-type Network string
+type Network serial.StringLiteral
 
 type NetworkList []Network
 
-func NewNetworkList(networks []string) NetworkList {
-	list := NetworkList(make([]Network, len(networks)))
+func NewNetworkList(networks serial.StringLiteralList) NetworkList {
+	list := NetworkList(make([]Network, networks.Len()))
 	for idx, network := range networks {
-		list[idx] = Network(strings.ToLower(strings.TrimSpace(network)))
+		list[idx] = Network(network.TrimSpace().ToLower())
 	}
 	return list
 }

+ 5 - 3
common/net/network_json.go

@@ -3,12 +3,14 @@
 package net
 
 import (
-	serialjson "github.com/v2ray/v2ray-core/common/serial/json"
+	"encoding/json"
+
+	"github.com/v2ray/v2ray-core/common/serial"
 )
 
 func (this *NetworkList) UnmarshalJSON(data []byte) error {
-	strlist, err := serialjson.UnmarshalStringList(data)
-	if err != nil {
+	var strlist serial.StringLiteralList
+	if err := json.Unmarshal(data, &strlist); err != nil {
 		return err
 	}
 	*this = NewNetworkList(strlist)

+ 0 - 21
common/serial/json/string_list.go

@@ -1,21 +0,0 @@
-package json
-
-import (
-	"encoding/json"
-	"errors"
-	"strings"
-)
-
-func UnmarshalStringList(data []byte) ([]string, error) {
-	var strarray []string
-	if err := json.Unmarshal(data, &strarray); err == nil {
-		return strarray, nil
-	}
-
-	var rawstr string
-	if err := json.Unmarshal(data, &rawstr); err == nil {
-		strlist := strings.Split(rawstr, ",")
-		return strlist, nil
-	}
-	return nil, errors.New("Unknown format of a string list: " + string(data))
-}

+ 20 - 0
common/serial/string.go

@@ -1,5 +1,9 @@
 package serial
 
+import (
+	"strings"
+)
+
 // An interface for any objects that has string presentation.
 type String interface {
 	String() string
@@ -7,6 +11,22 @@ type String interface {
 
 type StringLiteral string
 
+func NewStringLiteral(str String) StringLiteral {
+	return StringLiteral(str.String())
+}
+
 func (this StringLiteral) String() string {
 	return string(this)
 }
+
+func (this StringLiteral) ToLower() StringLiteral {
+	return StringLiteral(strings.ToLower(string(this)))
+}
+
+func (this StringLiteral) ToUpper() StringLiteral {
+	return StringLiteral(strings.ToUpper(string(this)))
+}
+
+func (this StringLiteral) TrimSpace() StringLiteral {
+	return StringLiteral(strings.TrimSpace(string(this)))
+}

+ 15 - 0
common/serial/string_list.go

@@ -0,0 +1,15 @@
+package serial
+
+type StringLiteralList []StringLiteral
+
+func NewStringLiteralList(raw []string) *StringLiteralList {
+	list := StringLiteralList(make([]StringLiteral, len(raw)))
+	for idx, str := range raw {
+		list[idx] = StringLiteral(str)
+	}
+	return &list
+}
+
+func (this *StringLiteralList) Len() int {
+	return len(*this)
+}

+ 25 - 0
common/serial/string_list_json.go

@@ -0,0 +1,25 @@
+// +build json
+
+package serial
+
+import (
+	"encoding/json"
+	"errors"
+	"strings"
+)
+
+func (this *StringLiteralList) UnmarshalJSON(data []byte) error {
+	var strarray []string
+	if err := json.Unmarshal(data, &strarray); err == nil {
+		*this = *NewStringLiteralList(strarray)
+		return nil
+	}
+
+	var rawstr string
+	if err := json.Unmarshal(data, &rawstr); err == nil {
+		strlist := strings.Split(rawstr, ",")
+		*this = *NewStringLiteralList(strlist)
+		return nil
+	}
+	return errors.New("Unknown format of a string list: " + string(data))
+}