Browse Source

refactor mux client worker

Darien Raymond 7 years ago
parent
commit
53870f1ea7
5 changed files with 141 additions and 35 deletions
  1. 5 3
      app/proxyman/outbound/handler.go
  2. 59 29
      common/mux/client.go
  3. 29 3
      common/mux/client_test.go
  4. 1 0
      mocks.go
  5. 47 0
      testing/mocks/mux.go

+ 5 - 3
app/proxyman/outbound/handler.go

@@ -74,11 +74,13 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou
 		}
 		h.mux = &mux.ClientManager{
 			Picker: &mux.IncrementalWorkerPicker{
-				New: func() (*mux.ClientWorker, error) {
-					return mux.NewClientWorker(proxyHandler, h, mux.ClientStrategy{
+				Factory: &mux.DialingWorkerFactory{
+					Proxy:  proxyHandler,
+					Dialer: h,
+					Strategy: mux.ClientStrategy{
 						MaxConcurrency: config.Concurrency,
 						MaxConnection:  128,
-					})
+					},
 				},
 			},
 		}

+ 59 - 29
common/mux/client.go

@@ -41,7 +41,7 @@ type WorkerPicker interface {
 }
 
 type IncrementalWorkerPicker struct {
-	New func() (*ClientWorker, error)
+	Factory ClientWorkerFactory
 
 	access      sync.Mutex
 	workers     []*ClientWorker
@@ -82,7 +82,7 @@ func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, error, bool) {
 
 	p.cleanup()
 
-	worker, err := p.New()
+	worker, err := p.Factory.Create()
 	if err != nil {
 		return nil, err, false
 	}
@@ -107,6 +107,46 @@ func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
 	return worker, err
 }
 
+type ClientWorkerFactory interface {
+	Create() (*ClientWorker, error)
+}
+
+type DialingWorkerFactory struct {
+	Proxy    proxy.Outbound
+	Dialer   internet.Dialer
+	Strategy ClientStrategy
+}
+
+func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
+	opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
+	uplinkReader, upLinkWriter := pipe.New(opts...)
+	downlinkReader, downlinkWriter := pipe.New(opts...)
+
+	c, err := NewClientWorker(vio.Link{
+		Reader: downlinkReader,
+		Writer: upLinkWriter,
+	}, f.Strategy)
+
+	if err != nil {
+		return nil, err
+	}
+
+	go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
+		ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{
+			Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
+		})
+		ctx, cancel := context.WithCancel(ctx)
+
+		if err := p.Process(ctx, &vio.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
+			errors.New("failed to handler mux client connection").Base(err).WriteToLog()
+		}
+		common.Must(c.Close())
+		cancel()
+	}(f.Proxy, f.Dialer, c.done)
+
+	return c, nil
+}
+
 type ClientStrategy struct {
 	MaxConcurrency uint32
 	MaxConnection  uint32
@@ -123,36 +163,17 @@ var muxCoolAddress = net.DomainAddress("v1.mux.cool")
 var muxCoolPort = net.Port(9527)
 
 // NewClientWorker creates a new mux.Client.
-func NewClientWorker(p proxy.Outbound, dialer internet.Dialer, s ClientStrategy) (*ClientWorker, error) {
-	ctx := session.ContextWithOutbound(context.Background(), &session.Outbound{
-		Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
-	})
-	ctx, cancel := context.WithCancel(ctx)
-
-	opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
-	uplinkReader, upLinkWriter := pipe.New(opts...)
-	downlinkReader, downlinkWriter := pipe.New(opts...)
-
+func NewClientWorker(stream vio.Link, s ClientStrategy) (*ClientWorker, error) {
 	c := &ClientWorker{
 		sessionManager: NewSessionManager(),
-		link: vio.Link{
-			Reader: downlinkReader,
-			Writer: upLinkWriter,
-		},
-		done:     done.New(),
-		strategy: s,
+		link:           stream,
+		done:           done.New(),
+		strategy:       s,
 	}
 
-	go func() {
-		if err := p.Process(ctx, &vio.Link{Reader: uplinkReader, Writer: downlinkWriter}, dialer); err != nil {
-			errors.New("failed to handler mux client connection").Base(err).WriteToLog()
-		}
-		common.Must(c.done.Close())
-		cancel()
-	}()
-
 	go c.fetchOutput()
 	go c.monitor()
+
 	return c, nil
 }
 
@@ -221,12 +242,21 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
 	}
 }
 
-func (m *ClientWorker) IsFull() bool {
+func (m *ClientWorker) IsClosing() bool {
 	sm := m.sessionManager
-	if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) {
+	if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) {
 		return true
 	}
-	if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) {
+	return false
+}
+
+func (m *ClientWorker) IsFull() bool {
+	if m.IsClosing() {
+		return true
+	}
+
+	sm := m.sessionManager
+	if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) {
 		return true
 	}
 	return false

+ 29 - 3
common/mux/client_test.go

@@ -1,17 +1,28 @@
 package mux_test
 
 import (
+	"context"
 	"testing"
+	"time"
 
+	"github.com/golang/mock/gomock"
+	"v2ray.com/core/common"
 	"v2ray.com/core/common/errors"
 	"v2ray.com/core/common/mux"
+	"v2ray.com/core/common/vio"
+	"v2ray.com/core/testing/mocks"
+	"v2ray.com/core/transport/pipe"
 )
 
 func TestIncrementalPickerFailure(t *testing.T) {
+	mockCtl := gomock.NewController(t)
+	defer mockCtl.Finish()
+
+	mockWorkerFactory := mocks.NewMuxClientWorkerFactory(mockCtl)
+	mockWorkerFactory.EXPECT().Create().Return(nil, errors.New("test"))
+
 	picker := mux.IncrementalWorkerPicker{
-		New: func() (*mux.ClientWorker, error) {
-			return nil, errors.New("test")
-		},
+		Factory: mockWorkerFactory,
 	}
 
 	_, err := picker.PickAvailable()
@@ -19,3 +30,18 @@ func TestIncrementalPickerFailure(t *testing.T) {
 		t.Error("expected error, but nil")
 	}
 }
+
+func TestClientWorkerEOF(t *testing.T) {
+	reader, writer := pipe.New(pipe.WithoutSizeLimit())
+	common.Must(writer.Close())
+
+	worker, err := mux.NewClientWorker(vio.Link{Reader: reader, Writer: writer}, mux.ClientStrategy{})
+	common.Must(err)
+
+	time.Sleep(time.Millisecond * 500)
+
+	f := worker.Dispatch(context.Background(), nil)
+	if f {
+		t.Error("expected failed dispatching, but actually not")
+	}
+}

+ 1 - 0
mocks.go

@@ -4,5 +4,6 @@ package core
 //go:generate go install github.com/golang/mock/mockgen
 
 //go:generate mockgen -package mocks -destination testing/mocks/io.go -mock_names Reader=Reader,Writer=Writer io Reader,Writer
+//go:generate mockgen -package mocks -destination testing/mocks/mux.go -mock_names ClientWorkerFactory=MuxClientWorkerFactory v2ray.com/core/common/mux ClientWorkerFactory
 //go:generate mockgen -package mocks -destination testing/mocks/dns.go -mock_names Client=DNSClient v2ray.com/core/features/dns Client
 //go:generate mockgen -package mocks -destination testing/mocks/proxy.go -mock_names Inbound=ProxyInbound,Outbound=ProxyOutbound v2ray.com/core/proxy Inbound,Outbound

+ 47 - 0
testing/mocks/mux.go

@@ -0,0 +1,47 @@
+// Code generated by MockGen. DO NOT EDIT.
+// Source: v2ray.com/core/common/mux (interfaces: ClientWorkerFactory)
+
+// Package mocks is a generated GoMock package.
+package mocks
+
+import (
+	gomock "github.com/golang/mock/gomock"
+	reflect "reflect"
+	mux "v2ray.com/core/common/mux"
+)
+
+// MuxClientWorkerFactory is a mock of ClientWorkerFactory interface
+type MuxClientWorkerFactory struct {
+	ctrl     *gomock.Controller
+	recorder *MuxClientWorkerFactoryMockRecorder
+}
+
+// MuxClientWorkerFactoryMockRecorder is the mock recorder for MuxClientWorkerFactory
+type MuxClientWorkerFactoryMockRecorder struct {
+	mock *MuxClientWorkerFactory
+}
+
+// NewMuxClientWorkerFactory creates a new mock instance
+func NewMuxClientWorkerFactory(ctrl *gomock.Controller) *MuxClientWorkerFactory {
+	mock := &MuxClientWorkerFactory{ctrl: ctrl}
+	mock.recorder = &MuxClientWorkerFactoryMockRecorder{mock}
+	return mock
+}
+
+// EXPECT returns an object that allows the caller to indicate expected use
+func (m *MuxClientWorkerFactory) EXPECT() *MuxClientWorkerFactoryMockRecorder {
+	return m.recorder
+}
+
+// Create mocks base method
+func (m *MuxClientWorkerFactory) Create() (*mux.ClientWorker, error) {
+	ret := m.ctrl.Call(m, "Create")
+	ret0, _ := ret[0].(*mux.ClientWorker)
+	ret1, _ := ret[1].(error)
+	return ret0, ret1
+}
+
+// Create indicates an expected call of Create
+func (mr *MuxClientWorkerFactoryMockRecorder) Create() *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MuxClientWorkerFactory)(nil).Create))
+}