Browse Source

add context arg to LoadImplementationByAlias

Shelikhoo 4 years ago
parent
commit
908dd96f1f
2 changed files with 23 additions and 4 deletions
  1. 11 4
      common/registry/registry.go
  2. 12 0
      common/session/context.go

+ 11 - 4
common/registry/registry.go

@@ -2,9 +2,11 @@ package registry
 
 
 import (
 import (
 	"bytes"
 	"bytes"
+	"context"
 	"github.com/golang/protobuf/jsonpb"
 	"github.com/golang/protobuf/jsonpb"
 	"github.com/golang/protobuf/proto"
 	"github.com/golang/protobuf/proto"
 	"github.com/v2fly/v2ray-core/v4/common/protoext"
 	"github.com/v2fly/v2ray-core/v4/common/protoext"
+	"github.com/v2fly/v2ray-core/v4/common/protofilter"
 	"github.com/v2fly/v2ray-core/v4/common/serial"
 	"github.com/v2fly/v2ray-core/v4/common/serial"
 	protov2 "google.golang.org/protobuf/proto"
 	protov2 "google.golang.org/protobuf/proto"
 	"reflect"
 	"reflect"
@@ -34,7 +36,7 @@ func (i *implementationRegistry) findImplementationByAlias(interfaceType, alias
 	return implSet.findImplementationByAlias(alias)
 	return implSet.findImplementationByAlias(alias)
 }
 }
 
 
-func (i *implementationRegistry) LoadImplementationByAlias(interfaceType, alias string, data []byte) (proto.Message, error) {
+func (i *implementationRegistry) LoadImplementationByAlias(ctx context.Context, interfaceType, alias string, data []byte) (proto.Message, error) {
 	var implementationFullName string
 	var implementationFullName string
 
 
 	if strings.HasPrefix(alias, "#") {
 	if strings.HasPrefix(alias, "#") {
@@ -61,6 +63,11 @@ func (i *implementationRegistry) LoadImplementationByAlias(interfaceType, alias
 		return nil, newError("unable to parse json content").Base(err)
 		return nil, newError("unable to parse json content").Base(err)
 	}
 	}
 
 
+	implementationConfigInstancev2 := proto.MessageV2(implementationConfigInstance)
+	if err := protofilter.FilterProtoConfig(ctx, implementationConfigInstancev2); err != nil {
+		return nil, err
+	}
+
 	return implementationConfigInstance.(proto.Message), nil
 	return implementationConfigInstance.(proto.Message), nil
 
 
 }
 }
@@ -108,14 +115,14 @@ func registerImplementation(proto interface{}, loader CustomLoader) error {
 }
 }
 
 
 type LoadByAlias interface {
 type LoadByAlias interface {
-	LoadImplementationByAlias(interfaceType, alias string, data []byte) (proto.Message, error)
+	LoadImplementationByAlias(ctx context.Context, interfaceType, alias string, data []byte) (proto.Message, error)
 }
 }
 
 
-func LoadImplementationByAlias(interfaceType, alias string, data []byte) (proto.Message, error) {
+func LoadImplementationByAlias(ctx context.Context, interfaceType, alias string, data []byte) (proto.Message, error) {
 	initialized.Do(func() {
 	initialized.Do(func() {
 		for _, v := range registerRequests {
 		for _, v := range registerRequests {
 			registerImplementation(v.proto, v.loader)
 			registerImplementation(v.proto, v.loader)
 		}
 		}
 	})
 	})
-	return globalImplementationRegistry.LoadImplementationByAlias(interfaceType, alias, data)
+	return globalImplementationRegistry.LoadImplementationByAlias(ctx, interfaceType, alias, data)
 }
 }

+ 12 - 0
common/session/context.go

@@ -15,6 +15,7 @@ const (
 	sockoptSessionKey
 	sockoptSessionKey
 	trackedConnectionErrorKey
 	trackedConnectionErrorKey
 	handlerSessionKey
 	handlerSessionKey
+	environmentKey
 )
 )
 
 
 // ContextWithID returns a new context with the given ID.
 // ContextWithID returns a new context with the given ID.
@@ -133,3 +134,14 @@ func SubmitOutboundErrorToOriginator(ctx context.Context, err error) {
 func TrackedConnectionError(ctx context.Context, tracker TrackedRequestErrorFeedback) context.Context {
 func TrackedConnectionError(ctx context.Context, tracker TrackedRequestErrorFeedback) context.Context {
 	return context.WithValue(ctx, trackedConnectionErrorKey, tracker)
 	return context.WithValue(ctx, trackedConnectionErrorKey, tracker)
 }
 }
+
+func ContextWithEnvironment(ctx context.Context, environment interface{}) context.Context {
+	return context.WithValue(ctx, environmentKey, environment)
+}
+
+func EnvironmentFromContext(ctx context.Context) interface{} {
+	if environment, ok := ctx.Value(environmentKey).(interface{}); ok {
+		return environment
+	}
+	return nil
+}