Przeglądaj źródła

testing with mock

Darien Raymond 7 lat temu
rodzic
commit
f78cf6cfc2
8 zmienionych plików z 248 dodań i 119 usunięć
  1. 12 15
      app/dispatcher/default.go
  2. 19 18
      app/router/router.go
  3. 49 24
      app/router/router_test.go
  4. 83 0
      features/mocks/dns.go
  5. 5 0
      mocks.go
  6. 18 16
      proxy/dokodemo/dokodemo.go
  7. 18 17
      proxy/freedom/freedom.go
  8. 44 29
      v2ray.go

+ 12 - 15
app/dispatcher/default.go

@@ -89,22 +89,25 @@ type DefaultDispatcher struct {
 	stats  stats.Manager
 }
 
-// NewDefaultDispatcher create a new DefaultDispatcher.
-func NewDefaultDispatcher(ctx context.Context, config *Config) (*DefaultDispatcher, error) {
-	d := &DefaultDispatcher{}
-
-	core.RequireFeatures(ctx, d.Init)
-
-	return d, nil
+func init() {
+	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
+		d := new(DefaultDispatcher)
+		if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error {
+			return d.Init(config.(*Config), om, router, pm, sm)
+		}); err != nil {
+			return nil, err
+		}
+		return d, nil
+	}))
 }
 
 // Init initializes DefaultDispatcher.
-// This method is visible for testing purpose.
-func (d *DefaultDispatcher) Init(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) {
+func (d *DefaultDispatcher) Init(config *Config, om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager) error {
 	d.ohm = om
 	d.router = router
 	d.policy = pm
 	d.stats = sm
+	return nil
 }
 
 // Type implements common.HasType.
@@ -257,9 +260,3 @@ func (d *DefaultDispatcher) routedDispatch(ctx context.Context, link *vio.Link,
 	}
 	dispatcher.Dispatch(ctx, link)
 }
-
-func init() {
-	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
-		return NewDefaultDispatcher(ctx, config.(*Config))
-	}))
-}

+ 19 - 18
app/router/router.go

@@ -14,6 +14,18 @@ import (
 	"v2ray.com/core/proxy"
 )
 
+func init() {
+	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
+		r := new(Router)
+		if err := core.RequireFeatures(ctx, func(d dns.Client) error {
+			return r.Init(config.(*Config), d)
+		}); err != nil {
+			return nil, err
+		}
+		return r, nil
+	}))
+}
+
 // Router is an implementation of routing.Router.
 type Router struct {
 	domainStrategy Config_DomainStrategy
@@ -21,27 +33,22 @@ type Router struct {
 	dns            dns.Client
 }
 
-// NewRouter creates a new Router based on the given config.
-func NewRouter(ctx context.Context, config *Config) (*Router, error) {
-	r := &Router{
-		domainStrategy: config.DomainStrategy,
-		rules:          make([]Rule, len(config.Rule)),
-	}
+// Init initializes the Router.
+func (r *Router) Init(config *Config, d dns.Client) error {
+	r.domainStrategy = config.DomainStrategy
+	r.rules = make([]Rule, len(config.Rule))
+	r.dns = d
 
 	for idx, rule := range config.Rule {
 		r.rules[idx].Tag = rule.Tag
 		cond, err := rule.BuildCondition()
 		if err != nil {
-			return nil, err
+			return err
 		}
 		r.rules[idx].Condition = cond
 	}
 
-	core.RequireFeatures(ctx, func(d dns.Client) {
-		r.dns = d
-	})
-
-	return r, nil
+	return nil
 }
 
 type ipResolver struct {
@@ -127,9 +134,3 @@ func (*Router) Close() error {
 func (*Router) Type() interface{} {
 	return routing.RouterType()
 }
-
-func init() {
-	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
-		return NewRouter(ctx, config.(*Config))
-	}))
-}

+ 49 - 24
app/router/router_test.go

@@ -3,46 +3,71 @@ package router_test
 import (
 	"testing"
 
-	"v2ray.com/core"
-	"v2ray.com/core/app/dispatcher"
-	"v2ray.com/core/app/proxyman"
-	_ "v2ray.com/core/app/proxyman/outbound"
+	"github.com/golang/mock/gomock"
 	. "v2ray.com/core/app/router"
 	"v2ray.com/core/common"
 	"v2ray.com/core/common/net"
-	"v2ray.com/core/common/serial"
 	"v2ray.com/core/common/session"
-	"v2ray.com/core/features/routing"
-	. "v2ray.com/ext/assert"
+	"v2ray.com/core/features/mocks"
 )
 
 func TestSimpleRouter(t *testing.T) {
-	assert := With(t)
+	config := &Config{
+		Rule: []*RoutingRule{
+			{
+				Tag: "test",
+				NetworkList: &net.NetworkList{
+					Network: []net.Network{net.Network_TCP},
+				},
+			},
+		},
+	}
+
+	mockCtl := gomock.NewController(t)
+	defer mockCtl.Finish()
+
+	mockDns := mocks.NewMockDNSClient(mockCtl)
+
+	r := new(Router)
+	common.Must(r.Init(config, mockDns))
+
+	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
+	tag, err := r.PickRoute(ctx)
+	common.Must(err)
+	if tag != "test" {
+		t.Error("expect tag 'test', bug actually ", tag)
+	}
+}
 
-	config := &core.Config{
-		App: []*serial.TypedMessage{
-			serial.ToTypedMessage(&Config{
-				Rule: []*RoutingRule{
+func TestIPOnDemand(t *testing.T) {
+	config := &Config{
+		DomainStrategy: Config_IpOnDemand,
+		Rule: []*RoutingRule{
+			{
+				Tag: "test",
+				Cidr: []*CIDR{
 					{
-						Tag: "test",
-						NetworkList: &net.NetworkList{
-							Network: []net.Network{net.Network_TCP},
-						},
+						Ip:     []byte{192, 168, 0, 0},
+						Prefix: 16,
 					},
 				},
-			}),
-			serial.ToTypedMessage(&dispatcher.Config{}),
-			serial.ToTypedMessage(&proxyman.OutboundConfig{}),
+			},
 		},
 	}
 
-	v, err := core.New(config)
-	common.Must(err)
+	mockCtl := gomock.NewController(t)
+	defer mockCtl.Finish()
+
+	mockDns := mocks.NewMockDNSClient(mockCtl)
+	mockDns.EXPECT().LookupIP(gomock.Eq("v2ray.com")).Return([]net.IP{{192, 168, 0, 1}}, nil).AnyTimes()
 
-	r := v.GetFeature(routing.RouterType()).(routing.Router)
+	r := new(Router)
+	common.Must(r.Init(config, mockDns))
 
 	ctx := withOutbound(&session.Outbound{Target: net.TCPDestination(net.DomainAddress("v2ray.com"), 80)})
 	tag, err := r.PickRoute(ctx)
-	assert(err, IsNil)
-	assert(tag, Equals, "test")
+	common.Must(err)
+	if tag != "test" {
+		t.Error("expect tag 'test', bug actually ", tag)
+	}
 }

+ 83 - 0
features/mocks/dns.go

@@ -0,0 +1,83 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: v2ray.com/core/features/dns (interfaces: Client)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+	gomock "github.com/golang/mock/gomock"
+	net "net"
+	reflect "reflect"
+)
+
+// MockDNSClient is a mock of Client interface
+type MockDNSClient struct {
+	ctrl     *gomock.Controller
+	recorder *MockDNSClientMockRecorder
+}
+
+// MockDNSClientMockRecorder is the mock recorder for MockDNSClient
+type MockDNSClientMockRecorder struct {
+	mock *MockDNSClient
+}
+
+// NewMockDNSClient creates a new mock instance
+func NewMockDNSClient(ctrl *gomock.Controller) *MockDNSClient {
+	mock := &MockDNSClient{ctrl: ctrl}
+	mock.recorder = &MockDNSClientMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MockDNSClient) EXPECT() *MockDNSClientMockRecorder {
+	return m.recorder
+}
+
+// Close mocks base method
+func (m *MockDNSClient) Close() error {
+	ret := m.ctrl.Call(m, "Close")
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// Close indicates an expected call of Close
+func (mr *MockDNSClientMockRecorder) Close() *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockDNSClient)(nil).Close))
+}
+
+// LookupIP mocks base method
+func (m *MockDNSClient) LookupIP(arg0 string) ([]net.IP, error) {
+	ret := m.ctrl.Call(m, "LookupIP", arg0)
+	ret0, _ := ret[0].([]net.IP)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// LookupIP indicates an expected call of LookupIP
+func (mr *MockDNSClientMockRecorder) LookupIP(arg0 interface{}) *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LookupIP", reflect.TypeOf((*MockDNSClient)(nil).LookupIP), arg0)
+}
+
+// Start mocks base method
+func (m *MockDNSClient) Start() error {
+	ret := m.ctrl.Call(m, "Start")
+	ret0, _ := ret[0].(error)
+	return ret0
+}
+
+// Start indicates an expected call of Start
+func (mr *MockDNSClientMockRecorder) Start() *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockDNSClient)(nil).Start))
+}
+
+// Type mocks base method
+func (m *MockDNSClient) Type() interface{} {
+	ret := m.ctrl.Call(m, "Type")
+	ret0, _ := ret[0].(interface{})
+	return ret0
+}
+
+// Type indicates an expected call of Type
+func (mr *MockDNSClientMockRecorder) Type() *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Type", reflect.TypeOf((*MockDNSClient)(nil).Type))
+}

+ 5 - 0
mocks.go

@@ -0,0 +1,5 @@
+package core
+
+//go:generate go get -u github.com/golang/mock/gomock
+//go:generate go install github.com/golang/mock/mockgen
+//go:generate mockgen -package mocks -destination v2ray.com/core/features/mocks/dns.go v2ray.com/core/features/dns Client

+ 18 - 16
proxy/dokodemo/dokodemo.go

@@ -19,6 +19,16 @@ import (
 	"v2ray.com/core/transport/pipe"
 )
 
+func init() {
+	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
+		d := new(DokodemoDoor)
+		err := core.RequireFeatures(ctx, func(pm policy.Manager) error {
+			return d.Init(config.(*Config), pm)
+		})
+		return d, err
+	}))
+}
+
 type DokodemoDoor struct {
 	policyManager policy.Manager
 	config        *Config
@@ -26,19 +36,17 @@ type DokodemoDoor struct {
 	port          net.Port
 }
 
-func New(ctx context.Context, config *Config) (*DokodemoDoor, error) {
+// Init initializes the DokodemoDoor instance with necessary parameters.
+func (d *DokodemoDoor) Init(config *Config, pm policy.Manager) error {
 	if config.NetworkList == nil || config.NetworkList.Size() == 0 {
-		return nil, newError("no network specified")
-	}
-	v := core.MustFromContext(ctx)
-	d := &DokodemoDoor{
-		config:        config,
-		address:       config.GetPredefinedAddress(),
-		port:          net.Port(config.Port),
-		policyManager: v.GetFeature(policy.ManagerType()).(policy.Manager),
+		return newError("no network specified")
 	}
+	d.config = config
+	d.address = config.GetPredefinedAddress()
+	d.port = net.Port(config.Port)
+	d.policyManager = pm
 
-	return d, nil
+	return nil
 }
 
 func (d *DokodemoDoor) Network() net.NetworkList {
@@ -144,9 +152,3 @@ func (d *DokodemoDoor) Process(ctx context.Context, network net.Network, conn in
 
 	return nil
 }
-
-func init() {
-	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
-		return New(ctx, config.(*Config))
-	}))
-}

+ 18 - 17
proxy/freedom/freedom.go

@@ -22,6 +22,18 @@ import (
 	"v2ray.com/core/transport/internet"
 )
 
+func init() {
+	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
+		h := new(Handler)
+		if err := core.RequireFeatures(ctx, func(pm policy.Manager, d dns.Client) error {
+			return h.Init(config.(*Config), pm, d)
+		}); err != nil {
+			return nil, err
+		}
+		return h, nil
+	}))
+}
+
 // Handler handles Freedom connections.
 type Handler struct {
 	policyManager policy.Manager
@@ -29,18 +41,13 @@ type Handler struct {
 	config        Config
 }
 
-// New creates a new Freedom handler.
-func New(ctx context.Context, config *Config) (*Handler, error) {
-	f := &Handler{
-		config: *config,
-	}
+// Init initializes the Handler with necessary parameters.
+func (h *Handler) Init(config *Config, pm policy.Manager, d dns.Client) error {
+	h.config = *config
+	h.policyManager = pm
+	h.dns = d
 
-	core.RequireFeatures(ctx, func(pm policy.Manager, d dns.Client) {
-		f.policyManager = pm
-		f.dns = d
-	})
-
-	return f, nil
+	return nil
 }
 
 func (h *Handler) policy() policy.Session {
@@ -163,9 +170,3 @@ func (h *Handler) Process(ctx context.Context, link *vio.Link, dialer proxy.Dial
 
 	return nil
 }
-
-func init() {
-	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
-		return New(ctx, config.(*Config))
-	}))
-}

+ 44 - 29
v2ray.go

@@ -40,12 +40,12 @@ func getFeature(allFeatures []features.Feature, t reflect.Type) features.Feature
 	return nil
 }
 
-func (r *resolution) resolve(allFeatures []features.Feature) bool {
+func (r *resolution) resolve(allFeatures []features.Feature) (bool, error) {
 	var fs []features.Feature
 	for _, d := range r.deps {
 		f := getFeature(allFeatures, d)
 		if f == nil {
-			return false
+			return false, nil
 		}
 		fs = append(fs, f)
 	}
@@ -67,9 +67,16 @@ func (r *resolution) resolve(allFeatures []features.Feature) bool {
 		panic("Can't get all input parameters")
 	}
 
-	callback.Call(input)
+	var err error
+	ret := callback.Call(input)
+	errInterface := reflect.TypeOf((*error)(nil)).Elem()
+	for i := len(ret) - 1; i >= 0; i-- {
+		if ret[i].Type().Implements(errInterface) {
+			err = ret[i].Interface().(error)
+		}
+	}
 
-	return true
+	return true, err
 }
 
 // Instance combines all functionalities in V2Ray.
@@ -134,9 +141,9 @@ func addOutboundHandlers(server *Instance, configs []*OutboundHandlerConfig) err
 
 // RequireFeatures is a helper function to require features from Instance in context.
 // See Instance.RequireFeatures for more information.
-func RequireFeatures(ctx context.Context, callback interface{}) {
+func RequireFeatures(ctx context.Context, callback interface{}) error {
 	v := MustFromContext(ctx)
-	v.RequireFeatures(callback)
+	return v.RequireFeatures(callback)
 }
 
 // New returns a new V2Ray instance based on given configuration.
@@ -162,29 +169,30 @@ func New(config *Config) (*Instance, error) {
 			return nil, err
 		}
 		if feature, ok := obj.(features.Feature); ok {
-			server.AddFeature(feature)
+			if err := server.AddFeature(feature); err != nil {
+				return nil, err
+			}
 		}
 	}
 
-	if server.GetFeature(dns.ClientType()) == nil {
-		server.AddFeature(dns.LocalClient{})
-	}
-
-	if server.GetFeature(policy.ManagerType()) == nil {
-		server.AddFeature(policy.DefaultManager{})
-	}
-
-	if server.GetFeature(routing.RouterType()) == nil {
-		server.AddFeature(routing.DefaultRouter{})
+	essentialFeatures := []struct {
+		Type     interface{}
+		Instance features.Feature
+	}{
+		{dns.ClientType(), dns.LocalClient{}},
+		{policy.ManagerType(), policy.DefaultManager{}},
+		{routing.RouterType(), routing.DefaultRouter{}},
+		{stats.ManagerType(), stats.NoopManager{}},
 	}
 
-	if server.GetFeature(stats.ManagerType()) == nil {
-		server.AddFeature(stats.NoopManager{})
+	for _, f := range essentialFeatures {
+		if server.GetFeature(f.Type) == nil {
+			if err := server.AddFeature(f.Instance); err != nil {
+				return nil, err
+			}
+		}
 	}
 
-	// Add an empty instance at the end, for optional feature requirement.
-	server.AddFeature(&Instance{})
-
 	if server.featureResolutions != nil {
 		return nil, newError("not all dependency are resolved.")
 	}
@@ -227,7 +235,7 @@ func (s *Instance) Close() error {
 
 // RequireFeatures registers a callback, which will be called when all dependent features are registered.
 // The callback must be a func(). All its parameters must be features.Feature.
-func (s *Instance) RequireFeatures(callback interface{}) {
+func (s *Instance) RequireFeatures(callback interface{}) error {
 	callbackType := reflect.TypeOf(callback)
 	if callbackType.Kind() != reflect.Func {
 		panic("not a function")
@@ -242,30 +250,35 @@ func (s *Instance) RequireFeatures(callback interface{}) {
 		deps:     featureTypes,
 		callback: callback,
 	}
-	if r.resolve(s.features) {
-		return
+	if finished, err := r.resolve(s.features); finished {
+		return err
 	}
 	s.featureResolutions = append(s.featureResolutions, r)
+	return nil
 }
 
 // AddFeature registers a feature into current Instance.
-func (s *Instance) AddFeature(feature features.Feature) {
+func (s *Instance) AddFeature(feature features.Feature) error {
 	s.features = append(s.features, feature)
 
 	if s.running {
 		if err := feature.Start(); err != nil {
 			newError("failed to start feature").Base(err).WriteToLog()
 		}
-		return
+		return nil
 	}
 
 	if s.featureResolutions == nil {
-		return
+		return nil
 	}
 
 	var pendingResolutions []resolution
 	for _, r := range s.featureResolutions {
-		if !r.resolve(s.features) {
+		finished, err := r.resolve(s.features)
+		if finished && err != nil {
+			return err
+		}
+		if !finished {
 			pendingResolutions = append(pendingResolutions, r)
 		}
 	}
@@ -274,6 +287,8 @@ func (s *Instance) AddFeature(feature features.Feature) {
 	} else if len(pendingResolutions) < len(s.featureResolutions) {
 		s.featureResolutions = pendingResolutions
 	}
+
+	return nil
 }
 
 // GetFeature returns a feature of the given type, or nil if such feature is not registered.