Browse Source

add delayed and reflective auto registration
Delay is required for all init to finish, otherwise protoreflect() can return nil.

Shelikhoo 4 years ago
parent
commit
a4c66656b1
1 changed files with 33 additions and 3 deletions
  1. 33 3
      common/registry/registry.go

+ 33 - 3
common/registry/registry.go

@@ -6,8 +6,10 @@ import (
 	"github.com/golang/protobuf/proto"
 	"github.com/v2fly/v2ray-core/v4/common/protoext"
 	"github.com/v2fly/v2ray-core/v4/common/serial"
-	"google.golang.org/protobuf/reflect/protoreflect"
+	protov2 "google.golang.org/protobuf/proto"
+	"reflect"
 	"strings"
+	"sync"
 )
 
 type implementationRegistry struct {
@@ -69,10 +71,33 @@ func newImplementationRegistry() *implementationRegistry {
 
 var globalImplementationRegistry = newImplementationRegistry()
 
+var initialized = &sync.Once{}
+
+type registerRequest struct {
+	proto  interface{}
+	loader CustomLoader
+}
+
+var registerRequests []registerRequest
+
 // RegisterImplementation register an implementation of a type of interface
 // loader(CustomLoader) is a private API, its interface is subject to breaking changes
-func RegisterImplementation(proto protoreflect.MessageDescriptor, loader CustomLoader) error {
-	msgDesc := proto
+func RegisterImplementation(proto interface{}, loader CustomLoader) error {
+	registerRequests = append(registerRequests, registerRequest{
+		proto:  proto,
+		loader: loader,
+	})
+	return nil
+}
+
+func registerImplementation(proto interface{}, loader CustomLoader) error {
+	protoReflect := reflect.New(reflect.TypeOf(proto).Elem())
+	var proto2 protov2.Message
+	assignMessage := func(msg protov2.Message) {
+		proto2 = msg
+	}
+	reflect.ValueOf(assignMessage).Call([]reflect.Value{protoReflect})
+	msgDesc := proto2.ProtoReflect().Descriptor()
 	fullName := string(msgDesc.FullName())
 	msgOpts, err := protoext.GetMessageOptions(msgDesc)
 	if err != nil {
@@ -87,5 +112,10 @@ type LoadByAlias interface {
 }
 
 func LoadImplementationByAlias(interfaceType, alias string, data []byte) (proto.Message, error) {
+	initialized.Do(func() {
+		for _, v := range registerRequests {
+			registerImplementation(v.proto, v.loader)
+		}
+	})
 	return globalImplementationRegistry.LoadImplementationByAlias(interfaceType, alias, data)
 }