Browse Source

move sniffing result to session

Darien Raymond 6 years ago
parent
commit
7e5e080488

+ 6 - 1
app/dispatcher/default.go

@@ -207,7 +207,12 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 			outbound.Reader = cReader
 			outbound.Reader = cReader
 			result, err := sniffer(ctx, cReader)
 			result, err := sniffer(ctx, cReader)
 			if err == nil {
 			if err == nil {
-				ctx = ContextWithSniffingResult(ctx, result)
+				content := session.ContentFromContext(ctx)
+				if content == nil {
+					content = new(session.Content)
+				}
+				content.Protocol = result.Protocol()
+				ctx = session.ContextWithContent(ctx, content)
 			}
 			}
 			if err == nil && shouldOverride(result, sniffingConfig.DestinationOverride) {
 			if err == nil && shouldOverride(result, sniffingConfig.DestinationOverride) {
 				domain := result.Domain()
 				domain := result.Domain()

+ 0 - 19
app/dispatcher/dispatcher.go

@@ -2,23 +2,4 @@
 
 
 package dispatcher
 package dispatcher
 
 
-import "context"
-
 //go:generate errorgen
 //go:generate errorgen
-
-type key int
-
-const (
-	sniffing key = iota
-)
-
-func ContextWithSniffingResult(ctx context.Context, r SniffResult) context.Context {
-	return context.WithValue(ctx, sniffing, r)
-}
-
-func SniffingResultFromContext(ctx context.Context) SniffResult {
-	if c, ok := ctx.Value(sniffing).(SniffResult); ok {
-		return c
-	}
-	return nil
-}

+ 3 - 0
app/dns/udpns.go

@@ -373,6 +373,9 @@ func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option
 		if inbound := session.InboundFromContext(ctx); inbound != nil {
 		if inbound := session.InboundFromContext(ctx); inbound != nil {
 			udpCtx = session.ContextWithInbound(udpCtx, inbound)
 			udpCtx = session.ContextWithInbound(udpCtx, inbound)
 		}
 		}
+		udpCtx = session.ContextWithContent(udpCtx, &session.Content{
+			Protocol: "dns",
+		})
 		s.udpServer.Dispatch(udpCtx, s.address, b)
 		s.udpServer.Dispatch(udpCtx, s.address, b)
 	}
 	}
 }
 }

+ 3 - 4
app/router/condition.go

@@ -6,7 +6,6 @@ import (
 	"context"
 	"context"
 	"strings"
 	"strings"
 
 
-	"v2ray.com/core/app/dispatcher"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/net"
 	"v2ray.com/core/common/session"
 	"v2ray.com/core/common/session"
 	"v2ray.com/core/common/strmatcher"
 	"v2ray.com/core/common/strmatcher"
@@ -298,13 +297,13 @@ func NewProtocolMatcher(protocols []string) *ProtocolMatcher {
 }
 }
 
 
 func (m *ProtocolMatcher) Apply(ctx context.Context) bool {
 func (m *ProtocolMatcher) Apply(ctx context.Context) bool {
-	result := dispatcher.SniffingResultFromContext(ctx)
+	content := session.ContentFromContext(ctx)
 
 
-	if result == nil {
+	if content == nil {
 		return false
 		return false
 	}
 	}
 
 
-	protocol := result.Protocol()
+	protocol := content.Protocol
 	for _, p := range m.protocols {
 	for _, p := range m.protocols {
 		if strings.HasPrefix(protocol, p) {
 		if strings.HasPrefix(protocol, p) {
 			return true
 			return true

+ 1 - 2
app/router/condition_test.go

@@ -9,7 +9,6 @@ import (
 
 
 	proto "github.com/golang/protobuf/proto"
 	proto "github.com/golang/protobuf/proto"
 
 
-	"v2ray.com/core/app/dispatcher"
 	. "v2ray.com/core/app/router"
 	. "v2ray.com/core/app/router"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/errors"
@@ -218,7 +217,7 @@ func TestRoutingRule(t *testing.T) {
 			},
 			},
 			test: []ruleTest{
 			test: []ruleTest{
 				{
 				{
-					input:  dispatcher.ContextWithSniffingResult(context.Background(), &http.SniffHeader{}),
+					input:  session.ContextWithContent(context.Background(), &session.Content{Protocol: (&http.SniffHeader{}).Protocol()}),
 					output: true,
 					output: true,
 				},
 				},
 			},
 			},

+ 12 - 0
common/session/context.go

@@ -8,6 +8,7 @@ const (
 	idSessionKey sessionKey = iota
 	idSessionKey sessionKey = iota
 	inboundSessionKey
 	inboundSessionKey
 	outboundSessionKey
 	outboundSessionKey
+	contentSessionKey
 )
 )
 
 
 // ContextWithID returns a new context with the given ID.
 // ContextWithID returns a new context with the given ID.
@@ -44,3 +45,14 @@ func OutboundFromContext(ctx context.Context) *Outbound {
 	}
 	}
 	return nil
 	return nil
 }
 }
+
+func ContextWithContent(ctx context.Context, content *Content) context.Context {
+	return context.WithValue(ctx, contentSessionKey, content)
+}
+
+func ContentFromContext(ctx context.Context) *Content {
+	if content, ok := ctx.Value(contentSessionKey).(*Content); ok {
+		return content
+	}
+	return nil
+}

+ 6 - 0
common/session/session.go

@@ -54,3 +54,9 @@ type Outbound struct {
 	// ResolvedIPs is the resolved IP addresses, if the Targe is a domain address.
 	// ResolvedIPs is the resolved IP addresses, if the Targe is a domain address.
 	ResolvedIPs []net.IP
 	ResolvedIPs []net.IP
 }
 }
+
+// Content is the metadata of the connection content.
+type Content struct {
+	// Protocol of current content.
+	Protocol string
+}