Browse Source

create session content in the context if do not exist yet

Shelikhoo 4 years ago
parent
commit
867bbb429e

+ 1 - 1
app/dispatcher/default.go

@@ -295,7 +295,7 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *transport.
 	var handler outbound.Handler
 	var handler outbound.Handler
 
 
 	if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
 	if forcedOutboundTag := session.GetForcedOutboundTagFromContext(ctx); forcedOutboundTag != "" {
-		session.SetForcedOutboundTagToContext(ctx, "")
+		ctx = session.SetForcedOutboundTagToContext(ctx, "")
 		if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
 		if h := d.ohm.GetHandler(forcedOutboundTag); h != nil {
 			newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
 			newError("taking platform initialized detour [", forcedOutboundTag, "] for [", destination, "]").WriteToLog(session.ExportIDToError(ctx))
 			handler = h
 			handler = h

+ 13 - 3
common/session/context.go

@@ -1,6 +1,8 @@
 package session
 package session
 
 
-import "context"
+import (
+	"context"
+)
 
 
 type sessionKey int
 type sessionKey int
 
 
@@ -92,8 +94,12 @@ func GetTransportLayerProxyTagFromContext(ctx context.Context) string {
 	return ContentFromContext(ctx).Attribute("transportLayerOutgoingTag")
 	return ContentFromContext(ctx).Attribute("transportLayerOutgoingTag")
 }
 }
 
 
-func SetTransportLayerProxyTagToContext(ctx context.Context, tag string) {
+func SetTransportLayerProxyTagToContext(ctx context.Context, tag string) context.Context {
+	if contentFromContext := ContentFromContext(ctx); contentFromContext == nil {
+		ctx = ContextWithContent(ctx, &Content{})
+	}
 	ContentFromContext(ctx).SetAttribute("transportLayerOutgoingTag", tag)
 	ContentFromContext(ctx).SetAttribute("transportLayerOutgoingTag", tag)
+	return ctx
 }
 }
 
 
 func GetForcedOutboundTagFromContext(ctx context.Context) string {
 func GetForcedOutboundTagFromContext(ctx context.Context) string {
@@ -103,6 +109,10 @@ func GetForcedOutboundTagFromContext(ctx context.Context) string {
 	return ContentFromContext(ctx).Attribute("forcedOutboundTag")
 	return ContentFromContext(ctx).Attribute("forcedOutboundTag")
 }
 }
 
 
-func SetForcedOutboundTagToContext(ctx context.Context, tag string) {
+func SetForcedOutboundTagToContext(ctx context.Context, tag string) context.Context {
+	if contentFromContext := ContentFromContext(ctx); contentFromContext == nil {
+		ctx = ContextWithContent(ctx, &Content{})
+	}
 	ContentFromContext(ctx).SetAttribute("forcedOutboundTag", tag)
 	ContentFromContext(ctx).SetAttribute("forcedOutboundTag", tag)
+	return ctx
 }
 }

+ 1 - 1
transport/internet/tagged/taggedimpl/impl.go

@@ -26,7 +26,7 @@ func DialTaggedOutbound(ctx context.Context, dest net.Destination, tag string) (
 	content.SkipDNSResolve = true
 	content.SkipDNSResolve = true
 
 
 	ctx = session.ContextWithContent(ctx, content)
 	ctx = session.ContextWithContent(ctx, content)
-	session.SetForcedOutboundTagToContext(ctx, tag)
+	ctx = session.SetForcedOutboundTagToContext(ctx, tag)
 
 
 	r, err := dispatcher.Dispatch(ctx, dest)
 	r, err := dispatcher.Dispatch(ctx, dest)
 	if err != nil {
 	if err != nil {