Browse Source

use session.Outbound.ResolvedIPs

Darien Raymond 7 years ago
parent
commit
82d562d1f0
3 changed files with 122 additions and 81 deletions
  1. 25 13
      app/router/condition.go
  2. 28 68
      app/router/router.go
  3. 69 0
      app/router/router_test.go

+ 25 - 13
app/router/condition.go

@@ -111,9 +111,18 @@ func targetFromContent(ctx context.Context) net.Destination {
 	return outbound.Target
 }
 
+func resolvedIPFromContext(ctx context.Context) []net.IP {
+	outbound := session.OutboundFromContext(ctx)
+	if outbound == nil {
+		return nil
+	}
+	return outbound.ResolvedIPs
+}
+
 type MultiGeoIPMatcher struct {
-	matchers []*GeoIPMatcher
-	destFunc func(context.Context) net.Destination
+	matchers       []*GeoIPMatcher
+	destFunc       func(context.Context) net.Destination
+	resolvedIPFunc func(context.Context) []net.IP
 }
 
 func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
@@ -126,17 +135,18 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e
 		matchers = append(matchers, matcher)
 	}
 
-	var destFunc func(context.Context) net.Destination
+	matcher := &MultiGeoIPMatcher{
+		matchers: matchers,
+	}
+
 	if onSource {
-		destFunc = sourceFromContext
+		matcher.destFunc = sourceFromContext
 	} else {
-		destFunc = targetFromContent
+		matcher.destFunc = targetFromContent
+		matcher.resolvedIPFunc = resolvedIPFromContext
 	}
 
-	return &MultiGeoIPMatcher{
-		matchers: matchers,
-		destFunc: destFunc,
-	}, nil
+	return matcher, nil
 }
 
 func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
@@ -146,10 +156,12 @@ func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
 
 	if dest.IsValid() && dest.Address.Family().IsIP() {
 		ips = append(ips, dest.Address.IP())
-	} else if resolver, ok := ResolvedIPsFromContext(ctx); ok {
-		resolvedIPs := resolver.Resolve()
-		for _, rip := range resolvedIPs {
-			ips = append(ips, rip.IP())
+	}
+
+	if m.resolvedIPFunc != nil {
+		rips := m.resolvedIPFunc(ctx)
+		if len(rips) > 0 {
+			ips = append(ips, rips...)
 		}
 	}
 

+ 28 - 68
app/router/router.go

@@ -7,32 +7,12 @@ import (
 
 	"v2ray.com/core"
 	"v2ray.com/core/common"
-	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/session"
 	"v2ray.com/core/features/dns"
 	"v2ray.com/core/features/outbound"
 	"v2ray.com/core/features/routing"
 )
 
-type key uint32
-
-const (
-	resolvedIPsKey key = iota
-)
-
-type IPResolver interface {
-	Resolve() []net.Address
-}
-
-func ContextWithResolveIPs(ctx context.Context, f IPResolver) context.Context {
-	return context.WithValue(ctx, resolvedIPsKey, f)
-}
-
-func ResolvedIPsFromContext(ctx context.Context) (IPResolver, bool) {
-	ips, ok := ctx.Value(resolvedIPsKey).(IPResolver)
-	return ips, ok
-}
-
 func init() {
 	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
 		r := new(Router)
@@ -91,34 +71,6 @@ func (r *Router) Init(config *Config, d dns.Client, ohm outbound.Manager) error
 	return nil
 }
 
-type ipResolver struct {
-	dns      dns.Client
-	ip       []net.Address
-	domain   string
-	resolved bool
-}
-
-func (r *ipResolver) Resolve() []net.Address {
-	if r.resolved {
-		return r.ip
-	}
-
-	newError("looking for IP for domain: ", r.domain).WriteToLog()
-	r.resolved = true
-	ips, err := r.dns.LookupIP(r.domain)
-	if err != nil {
-		newError("failed to get IP address").Base(err).WriteToLog()
-	}
-	if len(ips) == 0 {
-		return nil
-	}
-	r.ip = make([]net.Address, len(ips))
-	for i, ip := range ips {
-		r.ip[i] = net.IPAddress(ip)
-	}
-	return r.ip
-}
-
 func (r *Router) PickRoute(ctx context.Context) (string, error) {
 	rule, err := r.pickRouteInternal(ctx)
 	if err != nil {
@@ -127,17 +79,27 @@ func (r *Router) PickRoute(ctx context.Context) (string, error) {
 	return rule.GetTag()
 }
 
-// PickRoute implements routing.Router.
-func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
-	resolver := &ipResolver{
-		dns: r.dns,
+func isDomainOutbound(outbound *session.Outbound) bool {
+	return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain()
+}
+
+func (r *Router) resolveIP(outbound *session.Outbound) error {
+	domain := outbound.Target.Address.Domain()
+	ips, err := r.dns.LookupIP(domain)
+	if err != nil {
+		return err
 	}
 
+	outbound.ResolvedIPs = ips
+	return nil
+}
+
+// PickRoute implements routing.Router.
+func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
 	outbound := session.OutboundFromContext(ctx)
-	if r.domainStrategy == Config_IpOnDemand {
-		if outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain() {
-			resolver.domain = outbound.Target.Address.Domain()
-			ctx = ContextWithResolveIPs(ctx, resolver)
+	if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) {
+		if err := r.resolveIP(outbound); err != nil {
+			newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
 		}
 	}
 
@@ -147,21 +109,19 @@ func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
 		}
 	}
 
-	if outbound == nil || !outbound.Target.IsValid() {
+	if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) {
 		return nil, common.ErrNoClue
 	}
 
-	dest := outbound.Target
-	if r.domainStrategy == Config_IpIfNonMatch && dest.Address.Family().IsDomain() {
-		resolver.domain = dest.Address.Domain()
-		ips := resolver.Resolve()
-		if len(ips) > 0 {
-			ctx = ContextWithResolveIPs(ctx, resolver)
-			for _, rule := range r.rules {
-				if rule.Apply(ctx) {
-					return rule, nil
-				}
-			}
+	if err := r.resolveIP(outbound); err != nil {
+		newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
+		return nil, common.ErrNoClue
+	}
+
+	// Try applying rules again if we have IPs.
+	for _, rule := range r.rules {
+		if rule.Apply(ctx) {
+			return rule, nil
 		}
 	}
 

+ 69 - 0
app/router/router_test.go

@@ -125,3 +125,72 @@ func TestIPOnDemand(t *testing.T) {
 		t.Error("expect tag 'test', bug actually ", tag)
 	}
 }
+
+func TestIPIfNonMatchDomain(t *testing.T) {
+	config := &Config{
+		DomainStrategy: Config_IpIfNonMatch,
+		Rule: []*RoutingRule{
+			{
+				TargetTag: &RoutingRule_Tag{
+					Tag: "test",
+				},
+				Cidr: []*CIDR{
+					{
+						Ip:     []byte{192, 168, 0, 0},
+						Prefix: 16,
+					},
+				},
+			},
+		},
+	}
+
+	mockCtl := gomock.NewController(t)
+	defer mockCtl.Finish()
+
+	mockDns := mocks.NewDNSClient(mockCtl)
+	mockDns.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes()
+
+	r := new(Router)
+	common.Must(r.Init(config, mockDns, nil))
+
+	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
+	tag, err := r.PickRoute(ctx)
+	common.Must(err)
+	if tag != "test" {
+		t.Error("expect tag 'test', bug actually ", tag)
+	}
+}
+
+func TestIPIfNonMatchIP(t *testing.T) {
+	config := &Config{
+		DomainStrategy: Config_IpIfNonMatch,
+		Rule: []*RoutingRule{
+			{
+				TargetTag: &RoutingRule_Tag{
+					Tag: "test",
+				},
+				Cidr: []*CIDR{
+					{
+						Ip:     []byte{127, 0, 0, 0},
+						Prefix: 8,
+					},
+				},
+			},
+		},
+	}
+
+	mockCtl := gomock.NewController(t)
+	defer mockCtl.Finish()
+
+	mockDns := mocks.NewDNSClient(mockCtl)
+
+	r := new(Router)
+	common.Must(r.Init(config, mockDns, nil))
+
+	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
+	tag, err := r.PickRoute(ctx)
+	common.Must(err)
+	if tag != "test" {
+		t.Error("expect tag 'test', bug actually ", tag)
+	}
+}