Browse Source

move sniffing request to session

Darien Raymond 6 years ago
parent
commit
3828a463ea
4 changed files with 29 additions and 22 deletions
  1. 8 9
      app/dispatcher/default.go
  2. 4 1
      app/proxyman/inbound/worker.go
  3. 10 12
      app/proxyman/proxyman.go
  4. 7 0
      common/session/session.go

+ 8 - 9
app/dispatcher/default.go

@@ -11,7 +11,6 @@ import (
 	"time"
 
 	"v2ray.com/core"
-	"v2ray.com/core/app/proxyman"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/net"
@@ -196,8 +195,13 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 	ctx = session.ContextWithOutbound(ctx, ob)
 
 	inbound, outbound := d.getLink(ctx)
-	sniffingConfig := proxyman.SniffingConfigFromContext(ctx)
-	if destination.Network != net.Network_TCP || sniffingConfig == nil || !sniffingConfig.Enabled {
+	content := session.ContentFromContext(ctx)
+	if content == nil {
+		content = new(session.Content)
+		ctx = session.ContextWithContent(ctx, content)
+	}
+	sniffingRequest := content.SniffingRequest
+	if destination.Network != net.Network_TCP || !sniffingRequest.Enabled {
 		go d.routedDispatch(ctx, outbound, destination)
 	} else {
 		go func() {
@@ -207,14 +211,9 @@ func (d *DefaultDispatcher) Dispatch(ctx context.Context, destination net.Destin
 			outbound.Reader = cReader
 			result, err := sniffer(ctx, cReader)
 			if err == nil {
-				content := session.ContentFromContext(ctx)
-				if content == nil {
-					content = new(session.Content)
-				}
 				content.Protocol = result.Protocol()
-				ctx = session.ContextWithContent(ctx, content)
 			}
-			if err == nil && shouldOverride(result, sniffingConfig.DestinationOverride) {
+			if err == nil && shouldOverride(result, sniffingRequest.OverrideDestinationForProtocol) {
 				domain := result.Domain()
 				newError("sniffed domain: ", domain).WriteToLog(session.ExportIDToError(ctx))
 				destination.Address = net.ParseAddress(domain)

+ 4 - 1
app/proxyman/inbound/worker.go

@@ -81,9 +81,12 @@ func (w *tcpWorker) callback(conn internet.Connection) {
 		Gateway: net.TCPDestination(w.address, w.port),
 		Tag:     w.tag,
 	})
+	content := new(session.Content)
 	if w.sniffingConfig != nil {
-		ctx = proxyman.ContextWithSniffingConfig(ctx, w.sniffingConfig)
+		content.SniffingRequest.Enabled = w.sniffingConfig.Enabled
+		content.SniffingRequest.OverrideDestinationForProtocol = w.sniffingConfig.DestinationOverride
 	}
+	ctx = session.ContextWithContent(ctx, content)
 	if w.uplinkCounter != nil || w.downlinkCounter != nil {
 		conn = &internet.StatCouterConnection{
 			Connection: conn,

+ 10 - 12
app/proxyman/proxyman.go

@@ -3,21 +3,19 @@ package proxyman
 
 import (
 	"context"
-)
-
-type key int
 
-const (
-	sniffing key = iota
+	"v2ray.com/core/common/session"
 )
 
+// ContextWithSniffingConfig is a wrapper of session.ContextWithContent.
+// Deprecated. Use session.ContextWithContent directly.
 func ContextWithSniffingConfig(ctx context.Context, c *SniffingConfig) context.Context {
-	return context.WithValue(ctx, sniffing, c)
-}
-
-func SniffingConfigFromContext(ctx context.Context) *SniffingConfig {
-	if c, ok := ctx.Value(sniffing).(*SniffingConfig); ok {
-		return c
+	content := session.ContentFromContext(ctx)
+	if content == nil {
+		content = new(session.Content)
+		ctx = session.ContextWithContent(ctx, content)
 	}
-	return nil
+	content.SniffingRequest.Enabled = c.Enabled
+	content.SniffingRequest.OverrideDestinationForProtocol = c.DestinationOverride
+	return ctx
 }

+ 7 - 0
common/session/session.go

@@ -55,8 +55,15 @@ type Outbound struct {
 	ResolvedIPs []net.IP
 }
 
+type SniffingRequest struct {
+	OverrideDestinationForProtocol []string
+	Enabled                        bool
+}
+
 // Content is the metadata of the connection content.
 type Content struct {
 	// Protocol of current content.
 	Protocol string
+
+	SniffingRequest SniffingRequest
 }