|  | @@ -2,8 +2,16 @@ package net_test
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  |  import (
 |  |  import (
 | 
											
												
													
														|  |  	"net"
 |  |  	"net"
 | 
											
												
													
														|  | 
 |  | +	"os"
 | 
											
												
													
														|  | 
 |  | +	"path/filepath"
 | 
											
												
													
														|  |  	"testing"
 |  |  	"testing"
 | 
											
												
													
														|  |  
 |  |  
 | 
											
												
													
														|  | 
 |  | +	proto "github.com/golang/protobuf/proto"
 | 
											
												
													
														|  | 
 |  | +	"v2ray.com/core/app/router"
 | 
											
												
													
														|  | 
 |  | +	"v2ray.com/core/common/platform"
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	"v2ray.com/ext/sysio"
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  |  	"v2ray.com/core/common"
 |  |  	"v2ray.com/core/common"
 | 
											
												
													
														|  |  	. "v2ray.com/core/common/net"
 |  |  	. "v2ray.com/core/common/net"
 | 
											
												
													
														|  |  	. "v2ray.com/ext/assert"
 |  |  	. "v2ray.com/ext/assert"
 | 
											
										
											
												
													
														|  | @@ -43,3 +51,68 @@ func TestIPNet(t *testing.T) {
 | 
											
												
													
														|  |  	assert(ipNet.Contains(ParseIP("2001:cdba::3257:9652")), IsFalse)
 |  |  	assert(ipNet.Contains(ParseIP("2001:cdba::3257:9652")), IsFalse)
 | 
											
												
													
														|  |  	assert(ipNet.Contains(ParseIP("91.108.255.254")), IsTrue)
 |  |  	assert(ipNet.Contains(ParseIP("91.108.255.254")), IsTrue)
 | 
											
												
													
														|  |  }
 |  |  }
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +func loadGeoIP(country string) ([]*router.CIDR, error) {
 | 
											
												
													
														|  | 
 |  | +	geoipBytes, err := sysio.ReadAsset("geoip.dat")
 | 
											
												
													
														|  | 
 |  | +	if err != nil {
 | 
											
												
													
														|  | 
 |  | +		return nil, err
 | 
											
												
													
														|  | 
 |  | +	}
 | 
											
												
													
														|  | 
 |  | +	var geoipList router.GeoIPList
 | 
											
												
													
														|  | 
 |  | +	if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
 | 
											
												
													
														|  | 
 |  | +		return nil, err
 | 
											
												
													
														|  | 
 |  | +	}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	for _, geoip := range geoipList.Entry {
 | 
											
												
													
														|  | 
 |  | +		if geoip.CountryCode == country {
 | 
											
												
													
														|  | 
 |  | +			return geoip.Cidr, nil
 | 
											
												
													
														|  | 
 |  | +		}
 | 
											
												
													
														|  | 
 |  | +	}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	panic("country not found: " + country)
 | 
											
												
													
														|  | 
 |  | +}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +func BenchmarkIPNetQuery(b *testing.B) {
 | 
											
												
													
														|  | 
 |  | +	common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geoip.dat")))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	ips, err := loadGeoIP("CN")
 | 
											
												
													
														|  | 
 |  | +	common.Must(err)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	ipNet := NewIPNetTable()
 | 
											
												
													
														|  | 
 |  | +	for _, ip := range ips {
 | 
											
												
													
														|  | 
 |  | +		ipNet.AddIP(ip.Ip, byte(ip.Prefix))
 | 
											
												
													
														|  | 
 |  | +	}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	b.ResetTimer()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	for i := 0; i < b.N; i++ {
 | 
											
												
													
														|  | 
 |  | +		ipNet.Contains([]byte{8, 8, 8, 8})
 | 
											
												
													
														|  | 
 |  | +	}
 | 
											
												
													
														|  | 
 |  | +}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +func BenchmarkCIDRQuery(b *testing.B) {
 | 
											
												
													
														|  | 
 |  | +	common.Must(sysio.CopyFile(platform.GetAssetLocation("geoip.dat"), filepath.Join(os.Getenv("GOPATH"), "src", "v2ray.com", "core", "tools", "release", "config", "geoip.dat")))
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	ips, err := loadGeoIP("CN")
 | 
											
												
													
														|  | 
 |  | +	common.Must(err)
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	ipNet := make([]*net.IPNet, 0, 1024)
 | 
											
												
													
														|  | 
 |  | +	for _, ip := range ips {
 | 
											
												
													
														|  | 
 |  | +		if len(ip.Ip) != 4 {
 | 
											
												
													
														|  | 
 |  | +			continue
 | 
											
												
													
														|  | 
 |  | +		}
 | 
											
												
													
														|  | 
 |  | +		ipNet = append(ipNet, &net.IPNet{
 | 
											
												
													
														|  | 
 |  | +			IP:   net.IP(ip.Ip),
 | 
											
												
													
														|  | 
 |  | +			Mask: net.CIDRMask(int(ip.Prefix), 32),
 | 
											
												
													
														|  | 
 |  | +		})
 | 
											
												
													
														|  | 
 |  | +	}
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	b.ResetTimer()
 | 
											
												
													
														|  | 
 |  | +
 | 
											
												
													
														|  | 
 |  | +	for i := 0; i < b.N; i++ {
 | 
											
												
													
														|  | 
 |  | +		for _, n := range ipNet {
 | 
											
												
													
														|  | 
 |  | +			if n.Contains([]byte{8, 8, 8, 8}) {
 | 
											
												
													
														|  | 
 |  | +				break
 | 
											
												
													
														|  | 
 |  | +			}
 | 
											
												
													
														|  | 
 |  | +		}
 | 
											
												
													
														|  | 
 |  | +	}
 | 
											
												
													
														|  | 
 |  | +}
 |