Browse Source

migrate to the new geoip matcher

Darien Raymond 7 years ago
parent
commit
41956e92a5
4 changed files with 2 additions and 344 deletions
  1. 0 90
      app/router/condition.go
  2. 2 38
      app/router/config.go
  3. 0 83
      common/net/ipnet.go
  4. 0 133
      common/net/ipnet_test.go

+ 0 - 90
app/router/condition.go

@@ -120,22 +120,6 @@ func (m *DomainMatcher) Apply(ctx context.Context) bool {
 	return m.ApplyDomain(dest.Address.Domain())
 }
 
-type CIDRMatcher struct {
-	cidr     *net.IPNet
-	onSource bool
-}
-
-func NewCIDRMatcher(ip []byte, mask uint32, onSource bool) (*CIDRMatcher, error) {
-	cidr := &net.IPNet{
-		IP:   net.IP(ip),
-		Mask: net.CIDRMask(int(mask), len(ip)*8),
-	}
-	return &CIDRMatcher{
-		cidr:     cidr,
-		onSource: onSource,
-	}, nil
-}
-
 func sourceFromContext(ctx context.Context) net.Destination {
 	inbound := session.InboundFromContext(ctx)
 	if inbound == nil {
@@ -152,80 +136,6 @@ func targetFromContent(ctx context.Context) net.Destination {
 	return outbound.Target
 }
 
-func (v *CIDRMatcher) Apply(ctx context.Context) bool {
-	ips := make([]net.IP, 0, 4)
-	if resolver, ok := ResolvedIPsFromContext(ctx); ok {
-		resolvedIPs := resolver.Resolve()
-		for _, rip := range resolvedIPs {
-			if !rip.Family().IsIPv6() {
-				continue
-			}
-			ips = append(ips, rip.IP())
-		}
-	}
-
-	var dest net.Destination
-	if v.onSource {
-		dest = sourceFromContext(ctx)
-	} else {
-		dest = targetFromContent(ctx)
-	}
-
-	if dest.IsValid() && dest.Address.Family().IsIPv6() {
-		ips = append(ips, dest.Address.IP())
-	}
-
-	for _, ip := range ips {
-		if v.cidr.Contains(ip) {
-			return true
-		}
-	}
-	return false
-}
-
-type IPv4Matcher struct {
-	ipv4net  *net.IPNetTable
-	onSource bool
-}
-
-func NewIPv4Matcher(ipnet *net.IPNetTable, onSource bool) *IPv4Matcher {
-	return &IPv4Matcher{
-		ipv4net:  ipnet,
-		onSource: onSource,
-	}
-}
-
-func (v *IPv4Matcher) Apply(ctx context.Context) bool {
-	ips := make([]net.IP, 0, 4)
-	if resolver, ok := ResolvedIPsFromContext(ctx); ok {
-		resolvedIPs := resolver.Resolve()
-		for _, rip := range resolvedIPs {
-			if !rip.Family().IsIPv4() {
-				continue
-			}
-			ips = append(ips, rip.IP())
-		}
-	}
-
-	var dest net.Destination
-	if v.onSource {
-		dest = sourceFromContext(ctx)
-	} else {
-		dest = targetFromContent(ctx)
-	}
-
-	if dest.IsValid() && dest.Address.Family().IsIPv4() {
-		ips = append(ips, dest.Address.IP())
-	}
-
-	for _, ip := range ips {
-		if v.ipv4net.Contains(ip) {
-			return true
-		}
-	}
-	return false
-}
-
 type MultiGeoIPMatcher struct {
 	matchers []*GeoIPMatcher
 	onSource bool

+ 2 - 38
app/router/config.go

@@ -2,8 +2,6 @@ package router
 
 import (
 	"context"
-
-	"v2ray.com/core/common/net"
 )
 
 // CIDRList is an alias of []*CIDR to provide sort.Interface.
@@ -54,40 +52,6 @@ func (r *Rule) Apply(ctx context.Context) bool {
 	return r.Condition.Apply(ctx)
 }
 
-func cidrToCondition(cidr []*CIDR, source bool) (Condition, error) {
-	ipv4Net := net.NewIPNetTable()
-	ipv6Cond := NewAnyCondition()
-	hasIpv6 := false
-
-	for _, ip := range cidr {
-		switch len(ip.Ip) {
-		case net.IPv4len:
-			ipv4Net.AddIP(ip.Ip, byte(ip.Prefix))
-		case net.IPv6len:
-			hasIpv6 = true
-			matcher, err := NewCIDRMatcher(ip.Ip, ip.Prefix, source)
-			if err != nil {
-				return nil, err
-			}
-			ipv6Cond.Add(matcher)
-		default:
-			return nil, newError("invalid IP length").AtWarning()
-		}
-	}
-
-	switch {
-	case !ipv4Net.IsEmpty() && hasIpv6:
-		cond := NewAnyCondition()
-		cond.Add(NewIPv4Matcher(ipv4Net, source))
-		cond.Add(ipv6Cond)
-		return cond, nil
-	case !ipv4Net.IsEmpty():
-		return NewIPv4Matcher(ipv4Net, source), nil
-	default:
-		return ipv6Cond, nil
-	}
-}
-
 func (rr *RoutingRule) BuildCondition() (Condition, error) {
 	conds := NewConditionChan()
 
@@ -122,7 +86,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 		}
 		conds.Add(cond)
 	} else if len(rr.Cidr) > 0 {
-		cond, err := cidrToCondition(rr.Cidr, false)
+		cond, err := NewMultiGeoIPMatcher([]*GeoIP{{Cidr: rr.Cidr}}, false)
 		if err != nil {
 			return nil, err
 		}
@@ -136,7 +100,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
 		}
 		conds.Add(cond)
 	} else if len(rr.SourceCidr) > 0 {
-		cond, err := cidrToCondition(rr.SourceCidr, true)
+		cond, err := NewMultiGeoIPMatcher([]*GeoIP{{Cidr: rr.SourceCidr}}, true)
 		if err != nil {
 			return nil, err
 		}

+ 0 - 83
common/net/ipnet.go

@@ -1,83 +0,0 @@
-package net
-
-import (
-	"math/bits"
-	"net"
-)
-
-type IPNetTable struct {
-	cache map[uint32]byte
-}
-
-func NewIPNetTable() *IPNetTable {
-	return &IPNetTable{
-		cache: make(map[uint32]byte, 1024),
-	}
-}
-
-func ipToUint32(ip IP) uint32 {
-	value := uint32(0)
-	for _, b := range []byte(ip) {
-		value <<= 8
-		value += uint32(b)
-	}
-	return value
-}
-
-func ipMaskToByte(mask net.IPMask) byte {
-	value := byte(0)
-	for _, b := range []byte(mask) {
-		value += byte(bits.OnesCount8(b))
-	}
-	return value
-}
-
-func (n *IPNetTable) Add(ipNet *net.IPNet) {
-	ipv4 := ipNet.IP.To4()
-	if ipv4 == nil {
-		// For now, we don't support IPv6
-		return
-	}
-	mask := ipMaskToByte(ipNet.Mask)
-	n.AddIP(ipv4, mask)
-}
-
-func (n *IPNetTable) AddIP(ip []byte, mask byte) {
-	k := ipToUint32(ip)
-	k = (k >> (32 - mask)) << (32 - mask) // normalize ip
-	existing, found := n.cache[k]
-	if !found || existing > mask {
-		n.cache[k] = mask
-	}
-}
-
-func (n *IPNetTable) Contains(ip net.IP) bool {
-	ipv4 := ip.To4()
-	if ipv4 == nil {
-		return false
-	}
-	originalValue := ipToUint32(ipv4)
-
-	if entry, found := n.cache[originalValue]; found {
-		if entry == 32 {
-			return true
-		}
-	}
-
-	mask := uint32(0)
-	for maskbit := byte(1); maskbit <= 32; maskbit++ {
-		mask += 1 << uint32(32-maskbit)
-
-		maskedValue := originalValue & mask
-		if entry, found := n.cache[maskedValue]; found {
-			if entry == maskbit {
-				return true
-			}
-		}
-	}
-	return false
-}
-
-func (n *IPNetTable) IsEmpty() bool {
-	return len(n.cache) == 0
-}

+ 0 - 133
common/net/ipnet_test.go

@@ -1,133 +0,0 @@
-package net_test
-
-import (
-	"net"
-	"os"
-	"path/filepath"
-	"testing"
-
-	proto "github.com/golang/protobuf/proto"
-	"v2ray.com/core/app/router"
-	"v2ray.com/core/common/platform"
-
-	"v2ray.com/ext/sysio"
-
-	"v2ray.com/core/common"
-	. "v2ray.com/core/common/net"
-	. "v2ray.com/ext/assert"
-)
-
-func parseCIDR(str string) *net.IPNet {
-	_, ipNet, err := net.ParseCIDR(str)
-	common.Must(err)
-	return ipNet
-}
-
-func TestIPNet(t *testing.T) {
-	assert := With(t)
-
-	ipNet := NewIPNetTable()
-	ipNet.Add(parseCIDR(("0.0.0.0/8")))
-	ipNet.Add(parseCIDR(("10.0.0.0/8")))
-	ipNet.Add(parseCIDR(("100.64.0.0/10")))
-	ipNet.Add(parseCIDR(("127.0.0.0/8")))
-	ipNet.Add(parseCIDR(("169.254.0.0/16")))
-	ipNet.Add(parseCIDR(("172.16.0.0/12")))
-	ipNet.Add(parseCIDR(("192.0.0.0/24")))
-	ipNet.Add(parseCIDR(("192.0.2.0/24")))
-	ipNet.Add(parseCIDR(("192.168.0.0/16")))
-	ipNet.Add(parseCIDR(("198.18.0.0/15")))
-	ipNet.Add(parseCIDR(("198.51.100.0/24")))
-	ipNet.Add(parseCIDR(("203.0.113.0/24")))
-	ipNet.Add(parseCIDR(("8.8.8.8/32")))
-	ipNet.AddIP(net.ParseIP("91.108.4.0"), 16)
-	assert(ipNet.Contains(ParseIP("192.168.1.1")), IsTrue)
-	assert(ipNet.Contains(ParseIP("192.0.0.0")), IsTrue)
-	assert(ipNet.Contains(ParseIP("192.0.1.0")), IsFalse)
-	assert(ipNet.Contains(ParseIP("0.1.0.0")), IsTrue)
-	assert(ipNet.Contains(ParseIP("1.0.0.1")), IsFalse)
-	assert(ipNet.Contains(ParseIP("8.8.8.7")), IsFalse)
-	assert(ipNet.Contains(ParseIP("8.8.8.8")), IsTrue)
-	assert(ipNet.Contains(ParseIP("2001:cdba::3257:9652")), IsFalse)
-	assert(ipNet.Contains(ParseIP("91.108.255.254")), IsTrue)
-}
-
-func TestGeoIPCN(t *testing.T) {
-	assert := With(t)
-	common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
-
-	ips, err := loadGeoIP("CN")
-	common.Must(err)
-
-	ipNet := NewIPNetTable()
-	for _, ip := range ips {
-		ipNet.AddIP(ip.Ip, byte(ip.Prefix))
-	}
-
-	assert(ipNet.Contains([]byte{8, 8, 8, 8}), IsFalse)
-}
-
-func loadGeoIP(country string) ([]*router.CIDR, error) {
-	geoipBytes, err := sysio.ReadAsset("geoip.dat")
-	if err != nil {
-		return nil, err
-	}
-	var geoipList router.GeoIPList
-	if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
-		return nil, err
-	}
-
-	for _, geoip := range geoipList.Entry {
-		if geoip.CountryCode == country {
-			return geoip.Cidr, nil
-		}
-	}
-
-	panic("country not found: " + country)
-}
-
-func BenchmarkIPNetQuery(b *testing.B) {
-	common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
-
-	ips, err := loadGeoIP("CN")
-	common.Must(err)
-
-	ipNet := NewIPNetTable()
-	for _, ip := range ips {
-		ipNet.AddIP(ip.Ip, byte(ip.Prefix))
-	}
-
-	b.ResetTimer()
-
-	for i := 0; i < b.N; i++ {
-		ipNet.Contains([]byte{8, 8, 8, 8})
-	}
-}
-
-func BenchmarkCIDRQuery(b *testing.B) {
-	common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "release", "config", "geoip.dat")))
-
-	ips, err := loadGeoIP("CN")
-	common.Must(err)
-
-	ipNet := make([]*net.IPNet, 0, 1024)
-	for _, ip := range ips {
-		if len(ip.Ip) != 4 {
-			continue
-		}
-		ipNet = append(ipNet, &net.IPNet{
-			IP:   net.IP(ip.Ip),
-			Mask: net.CIDRMask(int(ip.Prefix), 32),
-		})
-	}
-
-	b.ResetTimer()
-
-	for i := 0; i < b.N; i++ {
-		for _, n := range ipNet {
-			if n.Contains([]byte{8, 8, 8, 8}) {
-				break
-			}
-		}
-	}
-}