Selaa lähdekoodia

global config creator

Darien Raymond 8 vuotta sitten
vanhempi
commit
db1c9131f0

+ 31 - 0
common/type.go

@@ -0,0 +1,31 @@
+package common
+
+import (
+	"context"
+	"errors"
+	"reflect"
+)
+
+type creator func(ctx context.Context, config interface{}) (interface{}, error)
+
+var (
+	typeCreatorRegistry = make(map[reflect.Type]creator)
+)
+
+func RegisterConfig(config interface{}, configCreator creator) error {
+	configType := reflect.TypeOf(config)
+	if _, found := typeCreatorRegistry[configType]; found {
+		return errors.New("Common: " + configType.Name() + " is already registered.")
+	}
+	typeCreatorRegistry[configType] = configCreator
+	return nil
+}
+
+func CreateObject(ctx context.Context, config interface{}) (interface{}, error) {
+	configType := reflect.TypeOf(config)
+	creator, found := typeCreatorRegistry[configType]
+	if !found {
+		return nil, errors.New("Common: " + configType.Name() + " is not registered.")
+	}
+	return creator(ctx, config)
+}

+ 12 - 20
transport/internet/conn_authenticator.go

@@ -1,8 +1,11 @@
 package internet
 
 import (
+	"errors"
 	"net"
 
+	"context"
+
 	"v2ray.com/core/common"
 )
 
@@ -11,26 +14,15 @@ type ConnectionAuthenticator interface {
 	Server(net.Conn) net.Conn
 }
 
-type ConnectionAuthenticatorFactory interface {
-	Create(interface{}) ConnectionAuthenticator
-}
-
-var (
-	connectionAuthenticatorCache = make(map[string]ConnectionAuthenticatorFactory)
-)
-
-func RegisterConnectionAuthenticator(name string, factory ConnectionAuthenticatorFactory) error {
-	if _, found := connectionAuthenticatorCache[name]; found {
-		return common.ErrDuplicatedName
+func CreateConnectionAuthenticator(config interface{}) (ConnectionAuthenticator, error) {
+	auth, err := common.CreateObject(context.Background(), config)
+	if err != nil {
+		return nil, err
 	}
-	connectionAuthenticatorCache[name] = factory
-	return nil
-}
-
-func CreateConnectionAuthenticator(name string, config interface{}) (ConnectionAuthenticator, error) {
-	factory, found := connectionAuthenticatorCache[name]
-	if !found {
-		return nil, common.ErrObjectNotFound
+	switch a := auth.(type) {
+	case ConnectionAuthenticator:
+		return a, nil
+	default:
+		return nil, errors.New("Internet: Not a ConnectionAuthenticator.")
 	}
-	return factory.Create(config), nil
 }

+ 15 - 21
transport/internet/header.go

@@ -1,32 +1,26 @@
 package internet
 
-import "v2ray.com/core/common"
+import (
+	"context"
+	"errors"
+
+	"v2ray.com/core/common"
+)
 
 type PacketHeader interface {
 	Size() int
 	Write([]byte) (int, error)
 }
 
-type PacketHeaderFactory interface {
-	Create(interface{}) PacketHeader
-}
-
-var (
-	headerCache = make(map[string]PacketHeaderFactory)
-)
-
-func RegisterPacketHeader(name string, factory PacketHeaderFactory) error {
-	if _, found := headerCache[name]; found {
-		return common.ErrDuplicatedName
+func CreatePacketHeader(config interface{}) (PacketHeader, error) {
+	header, err := common.CreateObject(context.Background(), config)
+	if err != nil {
+		return nil, err
 	}
-	headerCache[name] = factory
-	return nil
-}
-
-func CreatePacketHeader(name string, config interface{}) (PacketHeader, error) {
-	factory, found := headerCache[name]
-	if !found {
-		return nil, common.ErrObjectNotFound
+	switch h := header.(type) {
+	case PacketHeader:
+		return h, nil
+	default:
+		return nil, errors.New("Internet: Not a packet header.")
 	}
-	return factory.Create(config), nil
 }

+ 3 - 4
transport/internet/header_test.go

@@ -3,7 +3,6 @@ package internet_test
 import (
 	"testing"
 
-	"v2ray.com/core/common/serial"
 	"v2ray.com/core/testing/assert"
 	. "v2ray.com/core/transport/internet"
 	"v2ray.com/core/transport/internet/headers/noop"
@@ -14,15 +13,15 @@ import (
 func TestAllHeadersLoadable(t *testing.T) {
 	assert := assert.On(t)
 
-	noopAuth, err := CreatePacketHeader(serial.GetMessageType(new(noop.Config)), nil)
+	noopAuth, err := CreatePacketHeader((*noop.Config)(nil))
 	assert.Error(err).IsNil()
 	assert.Int(noopAuth.Size()).Equals(0)
 
-	srtp, err := CreatePacketHeader(serial.GetMessageType(new(srtp.Config)), nil)
+	srtp, err := CreatePacketHeader((*srtp.Config)(nil))
 	assert.Error(err).IsNil()
 	assert.Int(srtp.Size()).Equals(4)
 
-	utp, err := CreatePacketHeader(serial.GetMessageType(new(utp.Config)), nil)
+	utp, err := CreatePacketHeader((*utp.Config)(nil))
 	assert.Error(err).IsNil()
 	assert.Int(utp.Size()).Equals(4)
 }

+ 9 - 1
transport/internet/headers/http/http.go

@@ -2,6 +2,7 @@ package http
 
 import (
 	"bytes"
+	"context"
 	"errors"
 	"io"
 	"net"
@@ -9,6 +10,7 @@ import (
 	"strings"
 	"time"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/buf"
 	"v2ray.com/core/common/serial"
 	"v2ray.com/core/transport/internet"
@@ -265,6 +267,12 @@ func (HttpAuthenticatorFactory) Create(config interface{}) internet.ConnectionAu
 	}
 }
 
+func NewHttpAuthenticator(ctx context.Context, config interface{}) (interface{}, error) {
+	return HttpAuthenticator{
+		config: config.(*Config),
+	}, nil
+}
+
 func init() {
-	internet.RegisterConnectionAuthenticator(serial.GetMessageType(new(Config)), HttpAuthenticatorFactory{})
+	common.Must(common.RegisterConfig((*Config)(nil), NewHttpAuthenticator))
 }

+ 21 - 11
transport/internet/headers/noop/config.pb.go

@@ -23,8 +23,17 @@ func (m *Config) String() string            { return proto.CompactTextString(m)
 func (*Config) ProtoMessage()               {}
 func (*Config) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
 
+type ConnectionConfig struct {
+}
+
+func (m *ConnectionConfig) Reset()                    { *m = ConnectionConfig{} }
+func (m *ConnectionConfig) String() string            { return proto.CompactTextString(m) }
+func (*ConnectionConfig) ProtoMessage()               {}
+func (*ConnectionConfig) Descriptor() ([]byte, []int) { return fileDescriptor0, []int{1} }
+
 func init() {
 	proto.RegisterType((*Config)(nil), "v2ray.core.transport.internet.headers.noop.Config")
+	proto.RegisterType((*ConnectionConfig)(nil), "v2ray.core.transport.internet.headers.noop.ConnectionConfig")
 }
 
 func init() {
@@ -32,15 +41,16 @@ func init() {
 }
 
 var fileDescriptor0 = []byte{
-	// 160 bytes of a gzipped FileDescriptorProto
-	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0xe2, 0xb2, 0x2e, 0x33, 0x2a, 0x4a,
-	0xac, 0xd4, 0x4b, 0xce, 0xcf, 0xd5, 0x4f, 0xce, 0x2f, 0x4a, 0xd5, 0x2f, 0x29, 0x4a, 0xcc, 0x2b,
-	0x2e, 0xc8, 0x2f, 0x2a, 0xd1, 0xcf, 0xcc, 0x2b, 0x49, 0x2d, 0xca, 0x4b, 0x2d, 0xd1, 0xcf, 0x48,
-	0x4d, 0x4c, 0x49, 0x2d, 0x2a, 0xd6, 0xcf, 0xcb, 0xcf, 0x2f, 0xd0, 0x4f, 0xce, 0xcf, 0x4b, 0xcb,
-	0x4c, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0xd2, 0x82, 0x69, 0x2e, 0x4a, 0xd5, 0x83, 0x6b,
-	0xd4, 0x83, 0x69, 0xd4, 0x83, 0x6a, 0xd4, 0x03, 0x69, 0x54, 0xe2, 0xe0, 0x62, 0x73, 0x06, 0xeb,
-	0x75, 0x2a, 0xe0, 0x02, 0x59, 0xa7, 0x47, 0xbc, 0x5e, 0x27, 0x6e, 0x88, 0xce, 0x00, 0x90, 0xa5,
-	0x51, 0x2c, 0x20, 0xa1, 0x55, 0x4c, 0x5a, 0x61, 0x46, 0x41, 0x89, 0x95, 0x7a, 0xce, 0x20, 0xfd,
-	0x21, 0x70, 0xfd, 0x9e, 0x30, 0xfd, 0x1e, 0x50, 0xfd, 0x7e, 0xf9, 0xf9, 0x05, 0x49, 0x6c, 0x60,
-	0xe7, 0x1a, 0x03, 0x02, 0x00, 0x00, 0xff, 0xff, 0x89, 0x52, 0x33, 0x81, 0xed, 0x00, 0x00, 0x00,
+	// 174 bytes of a gzipped FileDescriptorProto
+	0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0x8c, 0xcf, 0xb1, 0xea, 0xc2, 0x40,
+	0x0c, 0xc7, 0x71, 0xf8, 0xf3, 0xa7, 0xc8, 0xb9, 0x48, 0x1f, 0xa1, 0x63, 0x87, 0x1c, 0xd4, 0xd1,
+	0xad, 0x5d, 0x74, 0x11, 0x11, 0x71, 0x70, 0x3b, 0xcf, 0xa8, 0x1d, 0x4c, 0x8e, 0x34, 0x08, 0x7d,
+	0x25, 0x9f, 0x52, 0xae, 0xed, 0x75, 0x76, 0xfd, 0xc1, 0x27, 0x7c, 0x63, 0x36, 0xef, 0x4a, 0x5c,
+	0x0f, 0x9e, 0x5f, 0xd6, 0xb3, 0xa0, 0x55, 0x71, 0xd4, 0x05, 0x16, 0xb5, 0x2d, 0x29, 0x0a, 0xa1,
+	0xda, 0x27, 0xba, 0x1b, 0x4a, 0x67, 0x89, 0x39, 0x58, 0xcf, 0x74, 0x6f, 0x1f, 0x10, 0x84, 0x95,
+	0xf3, 0x32, 0x61, 0x41, 0x98, 0x21, 0x24, 0x08, 0x13, 0x84, 0x08, 0x8b, 0x85, 0xc9, 0x9a, 0xc1,
+	0x16, 0xb9, 0x59, 0x35, 0x4c, 0x84, 0x5e, 0x5b, 0xa6, 0x71, 0xab, 0x83, 0x89, 0x09, 0xf0, 0xfb,
+	0xbd, 0x7a, 0x39, 0xca, 0x43, 0x0c, 0xb9, 0xfc, 0xc7, 0xe9, 0xf3, 0x57, 0x9e, 0xab, 0xa3, 0xeb,
+	0xa1, 0x89, 0xfe, 0x34, 0xfb, 0x5d, 0xf2, 0xdb, 0xc9, 0xef, 0x99, 0xc3, 0x35, 0x1b, 0x5e, 0x58,
+	0x7f, 0x03, 0x00, 0x00, 0xff, 0xff, 0x5e, 0x55, 0x14, 0x69, 0x01, 0x01, 0x00, 0x00,
 }

+ 3 - 1
transport/internet/headers/noop/config.proto

@@ -6,4 +6,6 @@ option go_package = "noop";
 option java_package = "com.v2ray.core.transport.internet.headers.noop";
 option java_outer_classname = "ConfigProto";
 
-message Config {}
+message Config {}
+
+message ConnectionConfig{}

+ 8 - 12
transport/internet/headers/noop/noop.go

@@ -1,10 +1,10 @@
 package noop
 
 import (
+	"context"
 	"net"
 
-	"v2ray.com/core/common/serial"
-	"v2ray.com/core/transport/internet"
+	"v2ray.com/core/common"
 )
 
 type NoOpHeader struct{}
@@ -16,10 +16,8 @@ func (v NoOpHeader) Write([]byte) (int, error) {
 	return 0, nil
 }
 
-type NoOpHeaderFactory struct{}
-
-func (v NoOpHeaderFactory) Create(config interface{}) internet.PacketHeader {
-	return NoOpHeader{}
+func NewNoOpHeader(context.Context, interface{}) (interface{}, error) {
+	return NoOpHeader{}, nil
 }
 
 type NoOpConnectionHeader struct{}
@@ -32,13 +30,11 @@ func (NoOpConnectionHeader) Server(conn net.Conn) net.Conn {
 	return conn
 }
 
-type NoOpConnectionHeaderFactory struct{}
-
-func (NoOpConnectionHeaderFactory) Create(config interface{}) internet.ConnectionAuthenticator {
-	return NoOpConnectionHeader{}
+func NewNoOpConnectionHeader(context.Context, interface{}) (interface{}, error) {
+	return NoOpConnectionHeader{}, nil
 }
 
 func init() {
-	internet.RegisterPacketHeader(serial.GetMessageType(new(Config)), NoOpHeaderFactory{})
-	internet.RegisterConnectionAuthenticator(serial.GetMessageType(new(Config)), NoOpConnectionHeaderFactory{})
+	common.Must(common.RegisterConfig((*Config)(nil), NewNoOpHeader))
+	common.Must(common.RegisterConfig((*ConnectionConfig)(nil), NewNoOpConnectionHeader))
 }

+ 5 - 7
transport/internet/headers/srtp/srtp.go

@@ -1,10 +1,11 @@
 package srtp
 
 import (
+	"context"
 	"math/rand"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/serial"
-	"v2ray.com/core/transport/internet"
 )
 
 type SRTP struct {
@@ -23,16 +24,13 @@ func (v *SRTP) Write(b []byte) (int, error) {
 	return 4, nil
 }
 
-type SRTPFactory struct {
-}
-
-func (v SRTPFactory) Create(rawSettings interface{}) internet.PacketHeader {
+func NewSRTP(ctx context.Context, config interface{}) (interface{}, error) {
 	return &SRTP{
 		header: 0xB5E8,
 		number: uint16(rand.Intn(65536)),
-	}
+	}, nil
 }
 
 func init() {
-	internet.RegisterPacketHeader(serial.GetMessageType(new(Config)), SRTPFactory{})
+	common.Must(common.RegisterConfig((*Config)(nil), NewSRTP))
 }

+ 5 - 6
transport/internet/headers/utp/utp.go

@@ -1,10 +1,11 @@
 package utp
 
 import (
+	"context"
 	"math/rand"
 
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/serial"
-	"v2ray.com/core/transport/internet"
 )
 
 type UTP struct {
@@ -24,16 +25,14 @@ func (v *UTP) Write(b []byte) (int, error) {
 	return 4, nil
 }
 
-type UTPFactory struct{}
-
-func (v UTPFactory) Create(rawSettings interface{}) internet.PacketHeader {
+func NewUTP(ctx context.Context, config interface{}) (interface{}, error) {
 	return &UTP{
 		header:       1,
 		extension:    0,
 		connectionId: uint16(rand.Intn(65536)),
-	}
+	}, nil
 }
 
 func init() {
-	internet.RegisterPacketHeader(serial.GetMessageType(new(Config)), UTPFactory{})
+	common.Must(common.RegisterConfig((*Config)(nil), NewUTP))
 }

+ 5 - 6
transport/internet/headers/wechat/wechat.go

@@ -1,10 +1,11 @@
 package wechat
 
 import (
+	"context"
+
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/dice"
 	"v2ray.com/core/common/serial"
-	"v2ray.com/core/transport/internet"
 )
 
 type VideoChat struct {
@@ -23,14 +24,12 @@ func (vc *VideoChat) Write(b []byte) (int, error) {
 	return 13, nil
 }
 
-type VideoChatFactory struct{}
-
-func (VideoChatFactory) Create(rawSettings interface{}) internet.PacketHeader {
+func NewVideoChat(ctx context.Context, config interface{}) (interface{}, error) {
 	return &VideoChat{
 		sn: dice.Roll(65535),
-	}
+	}, nil
 }
 
 func init() {
-	common.Must(internet.RegisterPacketHeader(serial.GetMessageType(new(VideoConfig)), VideoChatFactory{}))
+	common.Must(common.RegisterConfig((*VideoConfig)(nil), NewVideoChat))
 }

+ 1 - 1
transport/internet/kcp/config.go

@@ -67,7 +67,7 @@ func (v *Config) GetPackerHeader() (internet.PacketHeader, error) {
 			return nil, err
 		}
 
-		return internet.CreatePacketHeader(v.HeaderConfig.Type, rawConfig)
+		return internet.CreatePacketHeader(rawConfig)
 	}
 	return nil, nil
 }

+ 1 - 1
transport/internet/tcp/dialer.go

@@ -59,7 +59,7 @@ func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOpti
 			if err != nil {
 				return nil, errors.Base(err).Message("Interent|TCP: Failed to get header settings.")
 			}
-			auth, err := internet.CreateConnectionAuthenticator(tcpSettings.HeaderSettings.Type, headerConfig)
+			auth, err := internet.CreateConnectionAuthenticator(headerConfig)
 			if err != nil {
 				return nil, errors.Base(err).Message("Internet|TCP: Failed to create header authenticator.")
 			}

+ 1 - 1
transport/internet/tcp/hub.go

@@ -70,7 +70,7 @@ func ListenTCP(address v2net.Address, port v2net.Port, options internet.ListenOp
 		if err != nil {
 			return nil, errors.Base(err).Message("Internet|TCP: Invalid header settings.")
 		}
-		auth, err := internet.CreateConnectionAuthenticator(tcpSettings.HeaderSettings.Type, headerConfig)
+		auth, err := internet.CreateConnectionAuthenticator(headerConfig)
 		if err != nil {
 			return nil, errors.Base(err).Message("Internet|TCP: Invalid header settings.")
 		}