Browse Source

extract all session context before checking conditions

Darien Raymond 6 years ago
parent
commit
0d31a68694
5 changed files with 101 additions and 111 deletions
  1. 36 68
      app/router/condition.go
  2. 10 11
      app/router/condition_test.go
  3. 2 4
      app/router/config.go
  4. 47 23
      app/router/router.go
  5. 6 5
      app/router/router_test.go

+ 36 - 68
app/router/condition.go

@@ -3,16 +3,14 @@
 package router
 
 import (
-	"context"
 	"strings"
 
 	"v2ray.com/core/common/net"
-	"v2ray.com/core/common/session"
 	"v2ray.com/core/common/strmatcher"
 )
 
 type Condition interface {
-	Apply(ctx context.Context) bool
+	Apply(ctx *Context) bool
 }
 
 type ConditionChan []Condition
@@ -27,7 +25,7 @@ func (v *ConditionChan) Add(cond Condition) *ConditionChan {
 	return v
 }
 
-func (v *ConditionChan) Apply(ctx context.Context) bool {
+func (v *ConditionChan) Apply(ctx *Context) bool {
 	for _, cond := range *v {
 		if !cond.Apply(ctx) {
 			return false
@@ -84,46 +82,36 @@ func (m *DomainMatcher) ApplyDomain(domain string) bool {
 	return m.matchers.Match(domain) > 0
 }
 
-func (m *DomainMatcher) Apply(ctx context.Context) bool {
-	outbound := session.OutboundFromContext(ctx)
-	if outbound == nil || !outbound.Target.IsValid() {
+func (m *DomainMatcher) Apply(ctx *Context) bool {
+	if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
 		return false
 	}
-	dest := outbound.Target
+	dest := ctx.Outbound.Target
 	if !dest.Address.Family().IsDomain() {
 		return false
 	}
 	return m.ApplyDomain(dest.Address.Domain())
 }
 
-func sourceFromContext(ctx context.Context) net.Destination {
-	inbound := session.InboundFromContext(ctx)
-	if inbound == nil {
-		return net.Destination{}
+func getIPsFromSource(ctx *Context) []net.IP {
+	if ctx.Inbound == nil || !ctx.Inbound.Source.IsValid() {
+		return nil
 	}
-	return inbound.Source
-}
-
-func targetFromContent(ctx context.Context) net.Destination {
-	outbound := session.OutboundFromContext(ctx)
-	if outbound == nil {
-		return net.Destination{}
+	dest := ctx.Inbound.Source
+	if dest.Address.Family().IsDomain() {
+		return nil
 	}
-	return outbound.Target
+
+	return []net.IP{dest.Address.IP()}
 }
 
-func resolvedIPFromContext(ctx context.Context) []net.IP {
-	outbound := session.OutboundFromContext(ctx)
-	if outbound == nil {
-		return nil
-	}
-	return outbound.ResolvedIPs
+func getIPsFromTarget(ctx *Context) []net.IP {
+	return ctx.GetTargetIPs()
 }
 
 type MultiGeoIPMatcher struct {
-	matchers       []*GeoIPMatcher
-	destFunc       func(context.Context) net.Destination
-	resolvedIPFunc func(context.Context) []net.IP
+	matchers []*GeoIPMatcher
+	ipFunc   func(*Context) []net.IP
 }
 
 func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, error) {
@@ -141,30 +129,16 @@ func NewMultiGeoIPMatcher(geoips []*GeoIP, onSource bool) (*MultiGeoIPMatcher, e
 	}
 
 	if onSource {
-		matcher.destFunc = sourceFromContext
+		matcher.ipFunc = getIPsFromSource
 	} else {
-		matcher.destFunc = targetFromContent
-		matcher.resolvedIPFunc = resolvedIPFromContext
+		matcher.ipFunc = getIPsFromTarget
 	}
 
 	return matcher, nil
 }
 
-func (m *MultiGeoIPMatcher) Apply(ctx context.Context) bool {
-	ips := make([]net.IP, 0, 4)
-
-	dest := m.destFunc(ctx)
-
-	if dest.IsValid() && dest.Address.Family().IsIP() {
-		ips = append(ips, dest.Address.IP())
-	}
-
-	if m.resolvedIPFunc != nil {
-		rips := m.resolvedIPFunc(ctx)
-		if len(rips) > 0 {
-			ips = append(ips, rips...)
-		}
-	}
+func (m *MultiGeoIPMatcher) Apply(ctx *Context) bool {
+	ips := m.ipFunc(ctx)
 
 	for _, ip := range ips {
 		for _, matcher := range m.matchers {
@@ -186,12 +160,11 @@ func NewPortMatcher(list *net.PortList) *PortMatcher {
 	}
 }
 
-func (v *PortMatcher) Apply(ctx context.Context) bool {
-	outbound := session.OutboundFromContext(ctx)
-	if outbound == nil || !outbound.Target.IsValid() {
+func (v *PortMatcher) Apply(ctx *Context) bool {
+	if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
 		return false
 	}
-	return v.port.Contains(outbound.Target.Port)
+	return v.port.Contains(ctx.Outbound.Target.Port)
 }
 
 type NetworkMatcher struct {
@@ -206,12 +179,11 @@ func NewNetworkMatcher(network []net.Network) NetworkMatcher {
 	return matcher
 }
 
-func (v NetworkMatcher) Apply(ctx context.Context) bool {
-	outbound := session.OutboundFromContext(ctx)
-	if outbound == nil || !outbound.Target.IsValid() {
+func (v NetworkMatcher) Apply(ctx *Context) bool {
+	if ctx.Outbound == nil || !ctx.Outbound.Target.IsValid() {
 		return false
 	}
-	return v.list[int(outbound.Target.Network)]
+	return v.list[int(ctx.Outbound.Target.Network)]
 }
 
 type UserMatcher struct {
@@ -230,13 +202,12 @@ func NewUserMatcher(users []string) *UserMatcher {
 	}
 }
 
-func (v *UserMatcher) Apply(ctx context.Context) bool {
-	inbound := session.InboundFromContext(ctx)
-	if inbound == nil {
+func (v *UserMatcher) Apply(ctx *Context) bool {
+	if ctx.Inbound == nil {
 		return false
 	}
 
-	user := inbound.User
+	user := ctx.Inbound.User
 	if user == nil {
 		return false
 	}
@@ -264,12 +235,11 @@ func NewInboundTagMatcher(tags []string) *InboundTagMatcher {
 	}
 }
 
-func (v *InboundTagMatcher) Apply(ctx context.Context) bool {
-	inbound := session.InboundFromContext(ctx)
-	if inbound == nil || len(inbound.Tag) == 0 {
+func (v *InboundTagMatcher) Apply(ctx *Context) bool {
+	if ctx.Inbound == nil || len(ctx.Inbound.Tag) == 0 {
 		return false
 	}
-	tag := inbound.Tag
+	tag := ctx.Inbound.Tag
 	for _, t := range v.tags {
 		if t == tag {
 			return true
@@ -296,14 +266,12 @@ func NewProtocolMatcher(protocols []string) *ProtocolMatcher {
 	}
 }
 
-func (m *ProtocolMatcher) Apply(ctx context.Context) bool {
-	content := session.ContentFromContext(ctx)
-
-	if content == nil {
+func (m *ProtocolMatcher) Apply(ctx *Context) bool {
+	if ctx.Content == nil {
 		return false
 	}
 
-	protocol := content.Protocol
+	protocol := ctx.Content.Protocol
 	for _, p := range m.protocols {
 		if strings.HasPrefix(protocol, p) {
 			return true

+ 10 - 11
app/router/condition_test.go

@@ -1,7 +1,6 @@
 package router_test
 
 import (
-	"context"
 	"os"
 	"path/filepath"
 	"strconv"
@@ -28,17 +27,17 @@ func init() {
 	common.Must(filesystem.CopyFile(platform.GetAssetLocation("geosite.dat"), filepath.Join(wd, "..", "..", "release", "config", "geosite.dat")))
 }
 
-func withOutbound(outbound *session.Outbound) context.Context {
-	return session.ContextWithOutbound(context.Background(), outbound)
+func withOutbound(outbound *session.Outbound) *Context {
+	return &Context{Outbound: outbound}
 }
 
-func withInbound(inbound *session.Inbound) context.Context {
-	return session.ContextWithInbound(context.Background(), inbound)
+func withInbound(inbound *session.Inbound) *Context {
+	return &Context{Inbound: inbound}
 }
 
 func TestRoutingRule(t *testing.T) {
 	type ruleTest struct {
-		input  context.Context
+		input  *Context
 		output bool
 	}
 
@@ -89,7 +88,7 @@ func TestRoutingRule(t *testing.T) {
 					output: false,
 				},
 				{
-					input:  context.Background(),
+					input:  &Context{},
 					output: false,
 				},
 			},
@@ -125,7 +124,7 @@ func TestRoutingRule(t *testing.T) {
 					output: true,
 				},
 				{
-					input:  context.Background(),
+					input:  &Context{},
 					output: false,
 				},
 			},
@@ -165,7 +164,7 @@ func TestRoutingRule(t *testing.T) {
 					output: true,
 				},
 				{
-					input:  context.Background(),
+					input:  &Context{},
 					output: false,
 				},
 			},
@@ -206,7 +205,7 @@ func TestRoutingRule(t *testing.T) {
 					output: false,
 				},
 				{
-					input:  context.Background(),
+					input:  &Context{},
 					output: false,
 				},
 			},
@@ -217,7 +216,7 @@ func TestRoutingRule(t *testing.T) {
 			},
 			test: []ruleTest{
 				{
-					input:  session.ContextWithContent(context.Background(), &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}),
+					input:  &Context{Content: &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}},
 					output: true,
 				},
 			},

+ 2 - 4
app/router/config.go

@@ -3,9 +3,7 @@
 package router
 
 import (
-	"context"
-
-	net "v2ray.com/core/common/net"
+	"v2ray.com/core/common/net"
 	"v2ray.com/core/features/outbound"
 )
 
@@ -61,7 +59,7 @@ func (r *Rule) GetTag() (string, error) {
 	return r.Tag, nil
 }
 
-func (r *Rule) Apply(ctx context.Context) bool {
+func (r *Rule) Apply(ctx *Context) bool {
 	return r.Condition.Apply(ctx)
 }
 

+ 47 - 23
app/router/router.go

@@ -9,6 +9,7 @@ import (
 
 	"v2ray.com/core"
 	"v2ray.com/core/common"
+	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/session"
 	"v2ray.com/core/features/dns"
 	"v2ray.com/core/features/outbound"
@@ -85,44 +86,33 @@ func isDomainOutbound(outbound *session.Outbound) bool {
 	return outbound != nil && outbound.Target.IsValid() && outbound.Target.Address.Family().IsDomain()
 }
 
-func (r *Router) resolveIP(outbound *session.Outbound) error {
-	domain := outbound.Target.Address.Domain()
-	ips, err := r.dns.LookupIP(domain)
-	if err != nil {
-		return err
-	}
-
-	outbound.ResolvedIPs = ips
-	return nil
-}
-
 // PickRoute implements routing.Router.
 func (r *Router) pickRouteInternal(ctx context.Context) (*Rule, error) {
-	outbound := session.OutboundFromContext(ctx)
-	if r.domainStrategy == Config_IpOnDemand && isDomainOutbound(outbound) {
-		if err := r.resolveIP(outbound); err != nil {
-			newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
-		}
+	sessionContext := &Context{
+		Inbound:  session.InboundFromContext(ctx),
+		Outbound: session.OutboundFromContext(ctx),
+		Content:  session.ContentFromContext(ctx),
+	}
+
+	if r.domainStrategy == Config_IpOnDemand {
+		sessionContext.dnsClient = r.dns
 	}
 
 	for _, rule := range r.rules {
-		if rule.Apply(ctx) {
+		if rule.Apply(sessionContext) {
 			return rule, nil
 		}
 	}
 
-	if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(outbound) {
+	if r.domainStrategy != Config_IpIfNonMatch || !isDomainOutbound(sessionContext.Outbound) {
 		return nil, common.ErrNoClue
 	}
 
-	if err := r.resolveIP(outbound); err != nil {
-		newError("failed to resolve IP for domain").Base(err).WriteToLog(session.ExportIDToError(ctx))
-		return nil, common.ErrNoClue
-	}
+	sessionContext.dnsClient = r.dns
 
 	// Try applying rules again if we have IPs.
 	for _, rule := range r.rules {
-		if rule.Apply(ctx) {
+		if rule.Apply(sessionContext) {
 			return rule, nil
 		}
 	}
@@ -144,3 +134,37 @@ func (*Router) Close() error {
 func (*Router) Type() interface{} {
 	return routing.RouterType()
 }
+
+type Context struct {
+	Inbound  *session.Inbound
+	Outbound *session.Outbound
+	Content  *session.Content
+
+	dnsClient dns.Client
+}
+
+func (c *Context) GetTargetIPs() []net.IP {
+	if c.Outbound == nil || !c.Outbound.Target.IsValid() {
+		return nil
+	}
+
+	if c.Outbound.Target.Address.Family().IsIP() {
+		return []net.IP{c.Outbound.Target.Address.IP()}
+	}
+
+	if len(c.Outbound.ResolvedIPs) > 0 {
+		return c.Outbound.ResolvedIPs
+	}
+
+	if c.dnsClient != nil {
+		domain := c.Outbound.Target.Address.Domain()
+		ips, err := c.dnsClient.LookupIP(domain)
+		if err == nil {
+			c.Outbound.ResolvedIPs = ips
+			return ips
+		}
+		newError("resolve ip for ", domain).Base(err).WriteToLog()
+	}
+
+	return nil
+}

+ 6 - 5
app/router/router_test.go

@@ -1,6 +1,7 @@
 package router_test
 
 import (
+	"context"
 	"testing"
 
 	"github.com/golang/mock/gomock"
@@ -42,7 +43,7 @@ func TestSimpleRouter(t *testing.T) {
 		HandlerSelector: mockHs,
 	}))
 
-	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
+	ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
 	tag, err := r.PickRoute(ctx)
 	common.Must(err)
 	if tag != "test" {
@@ -83,7 +84,7 @@ func TestSimpleBalancer(t *testing.T) {
 		HandlerSelector: mockHs,
 	}))
 
-	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
+	ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
 	tag, err := r.PickRoute(ctx)
 	common.Must(err)
 	if tag != "test" {
@@ -118,7 +119,7 @@ func TestIPOnDemand(t *testing.T) {
 	r := new(Router)
 	common.Must(r.Init(config, mockDns, nil))
 
-	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
+	ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
 	tag, err := r.PickRoute(ctx)
 	common.Must(err)
 	if tag != "test" {
@@ -153,7 +154,7 @@ func TestIPIfNonMatchDomain(t *testing.T) {
 	r := new(Router)
 	common.Must(r.Init(config, mockDns, nil))
 
-	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
+	ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
 	tag, err := r.PickRoute(ctx)
 	common.Must(err)
 	if tag != "test" {
@@ -187,7 +188,7 @@ func TestIPIfNonMatchIP(t *testing.T) {
 	r := new(Router)
 	common.Must(r.Init(config, mockDns, nil))
 
-	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
+	ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{Target: net.TCPDestination(net.LocalHostIP, 80)})
 	tag, err := r.PickRoute(ctx)
 	common.Must(err)
 	if tag != "test" {