瀏覽代碼

split session manager out of mux client and server

Darien Raymond 8 年之前
父節點
當前提交
ad083989aa
共有 4 個文件被更改,包括 164 次插入126 次删除
  1. 34 126
      app/proxyman/mux/mux.go
  2. 99 0
      app/proxyman/mux/session.go
  3. 28 0
      app/proxyman/mux/session_test.go
  4. 3 0
      app/proxyman/mux/status.go

+ 34 - 126
app/proxyman/mux/mux.go

@@ -22,42 +22,6 @@ const (
 	maxTotal = 128
 )
 
-type manager interface {
-	remove(id uint16)
-}
-
-type session struct {
-	sync.Mutex
-	input          ray.InputStream
-	output         ray.OutputStream
-	parent         manager
-	id             uint16
-	uplinkClosed   bool
-	downlinkClosed bool
-}
-
-func (s *session) closeUplink() {
-	var allDone bool
-	s.Lock()
-	s.uplinkClosed = true
-	allDone = s.uplinkClosed && s.downlinkClosed
-	s.Unlock()
-	if allDone {
-		s.parent.remove(s.id)
-	}
-}
-
-func (s *session) closeDownlink() {
-	var allDone bool
-	s.Lock()
-	s.downlinkClosed = true
-	allDone = s.uplinkClosed && s.downlinkClosed
-	s.Unlock()
-	if allDone {
-		s.parent.remove(s.id)
-	}
-}
-
 type ClientManager struct {
 	access  sync.Mutex
 	clients []*Client
@@ -112,9 +76,7 @@ func (m *ClientManager) onClientFinish() {
 }
 
 type Client struct {
-	access         sync.RWMutex
-	count          uint16
-	sessions       map[uint16]*session
+	sessionManager *SessionManager
 	inboundRay     ray.InboundRay
 	ctx            context.Context
 	cancel         context.CancelFunc
@@ -131,12 +93,11 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client
 	pipe := ray.NewRay(ctx)
 	go p.Process(ctx, pipe, dialer)
 	c := &Client{
-		sessions:       make(map[uint16]*session, 256),
+		sessionManager: NewSessionManager(),
 		inboundRay:     pipe,
 		ctx:            ctx,
 		cancel:         cancel,
 		manager:        m,
-		count:          0,
 		session2Remove: make(chan uint16, 16),
 		concurrency:    m.config.Concurrency,
 	}
@@ -145,14 +106,6 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client
 	return c, nil
 }
 
-func (m *Client) remove(id uint16) {
-	select {
-	case m.session2Remove <- id:
-	default:
-		// Probably not gonna happen.
-	}
-}
-
 func (m *Client) Closed() bool {
 	select {
 	case <-m.ctx.Done():
@@ -168,42 +121,28 @@ func (m *Client) monitor() {
 	for {
 		select {
 		case <-m.ctx.Done():
-			m.cleanup()
+			m.sessionManager.Close()
+			m.inboundRay.InboundInput().Close()
+			m.inboundRay.InboundOutput().CloseError()
 			return
-		case id := <-m.session2Remove:
-			m.access.Lock()
-			delete(m.sessions, id)
-			if len(m.sessions) == 0 {
+		case <-time.After(time.Second * 6):
+			size := m.sessionManager.Size()
+			if size == 0 {
 				m.cancel()
 			}
-			m.access.Unlock()
 		}
 	}
 }
 
-func (m *Client) cleanup() {
-	m.access.Lock()
-	defer m.access.Unlock()
-
-	m.inboundRay.InboundInput().Close()
-	m.inboundRay.InboundOutput().CloseError()
-
-	for _, s := range m.sessions {
-		s.closeUplink()
-		s.closeDownlink()
-		s.output.CloseError()
-	}
-}
-
-func fetchInput(ctx context.Context, s *session, output buf.Writer) {
+func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
 	dest, _ := proxy.TargetFromContext(ctx)
 	writer := &Writer{
 		dest:   dest,
-		id:     s.id,
+		id:     s.ID,
 		writer: output,
 	}
 	defer writer.Close()
-	defer s.closeUplink()
+	defer s.CloseUplink()
 
 	log.Trace(newError("dispatching request to ", dest))
 	data, _ := s.input.ReadTimeout(time.Millisecond * 500)
@@ -218,22 +157,9 @@ func fetchInput(ctx context.Context, s *session, output buf.Writer) {
 	}
 }
 
-func waitForDone(ctx context.Context, s *session) {
-	<-ctx.Done()
-	s.closeUplink()
-	s.closeDownlink()
-	s.output.Close()
-}
-
 func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool {
-	m.access.Lock()
-	defer m.access.Unlock()
-
-	if len(m.sessions) >= int(m.concurrency) {
-		return false
-	}
-
-	if m.count >= maxTotal {
+	numSession := m.sessionManager.Size()
+	if numSession >= int(m.concurrency) || numSession >= maxTotal {
 		return false
 	}
 
@@ -243,17 +169,13 @@ func (m *Client) Dispatch(ctx context.Context, outboundRay ray.OutboundRay) bool
 	default:
 	}
 
-	m.count++
-	id := m.count
-	s := &session{
+	s := &Session{
 		input:  outboundRay.OutboundInput(),
 		output: outboundRay.OutboundOutput(),
-		parent: m,
-		id:     id,
+		parent: m.sessionManager,
 	}
-	m.sessions[id] = s
+	m.sessionManager.Allocate(s)
 	go fetchInput(ctx, s, m.inboundRay.InboundInput())
-	go waitForDone(ctx, s)
 	return true
 }
 
@@ -305,11 +227,9 @@ func (m *Client) fetchOutput() {
 			continue
 		}
 
-		m.access.RLock()
-		s, found := m.sessions[meta.SessionID]
-		m.access.RUnlock()
+		s, found := m.sessionManager.Get(meta.SessionID)
 		if found && meta.SessionStatus == SessionStatusEnd {
-			s.closeDownlink()
+			s.CloseDownlink()
 			s.output.Close()
 		}
 		if !meta.Option.Has(OptionData) {
@@ -354,34 +274,27 @@ func (s *Server) Dispatch(ctx context.Context, dest net.Destination) (ray.Inboun
 
 	ray := ray.NewRay(ctx)
 	worker := &ServerWorker{
-		dispatcher:  s.dispatcher,
-		outboundRay: ray,
-		sessions:    make(map[uint16]*session),
+		dispatcher:     s.dispatcher,
+		outboundRay:    ray,
+		sessionManager: NewSessionManager(),
 	}
 	go worker.run(ctx)
 	return ray, nil
 }
 
 type ServerWorker struct {
-	dispatcher  dispatcher.Interface
-	outboundRay ray.OutboundRay
-	sessions    map[uint16]*session
-	access      sync.RWMutex
-}
-
-func (w *ServerWorker) remove(id uint16) {
-	w.access.Lock()
-	delete(w.sessions, id)
-	w.access.Unlock()
+	dispatcher     dispatcher.Interface
+	outboundRay    ray.OutboundRay
+	sessionManager *SessionManager
 }
 
-func handle(ctx context.Context, s *session, output buf.Writer) {
-	writer := NewResponseWriter(s.id, output)
+func handle(ctx context.Context, s *Session, output buf.Writer) {
+	writer := NewResponseWriter(s.ID, output)
 	if err := buf.PipeUntilEOF(signal.BackgroundTimer(), s.input, writer); err != nil {
-		log.Trace(newError("session ", s.id, " ends: ").Base(err))
+		log.Trace(newError("session ", s.ID, " ends: ").Base(err))
 	}
 	writer.Close()
-	s.closeDownlink()
+	s.CloseDownlink()
 }
 
 func (w *ServerWorker) run(ctx context.Context) {
@@ -410,12 +323,9 @@ func (w *ServerWorker) run(ctx context.Context) {
 			continue
 		}
 
-		w.access.RLock()
-		s, found := w.sessions[meta.SessionID]
-		w.access.RUnlock()
-
+		s, found := w.sessionManager.Get(meta.SessionID)
 		if found && meta.SessionStatus == SessionStatusEnd {
-			s.closeUplink()
+			s.CloseUplink()
 			s.output.Close()
 		}
 
@@ -426,15 +336,13 @@ func (w *ServerWorker) run(ctx context.Context) {
 				log.Trace(newError("failed to dispatch request.").Base(err))
 				continue
 			}
-			s = &session{
+			s = &Session{
 				input:  inboundRay.InboundOutput(),
 				output: inboundRay.InboundInput(),
-				parent: w,
-				id:     meta.SessionID,
+				parent: w.sessionManager,
+				ID:     meta.SessionID,
 			}
-			w.access.Lock()
-			w.sessions[meta.SessionID] = s
-			w.access.Unlock()
+			w.sessionManager.Add(s)
 			go handle(ctx, s, w.outboundRay.OutboundOutput())
 		}
 

+ 99 - 0
app/proxyman/mux/session.go

@@ -0,0 +1,99 @@
+package mux
+
+import (
+	"sync"
+
+	"v2ray.com/core/transport/ray"
+)
+
+type SessionManager struct {
+	sync.RWMutex
+	count    uint16
+	sessions map[uint16]*Session
+}
+
+func NewSessionManager() *SessionManager {
+	return &SessionManager{
+		count:    0,
+		sessions: make(map[uint16]*Session, 32),
+	}
+}
+
+func (m *SessionManager) Size() int {
+	m.RLock()
+	defer m.RUnlock()
+
+	return len(m.sessions)
+}
+
+func (m *SessionManager) Allocate(s *Session) {
+	m.Lock()
+	defer m.Unlock()
+
+	m.count++
+	s.ID = m.count
+	m.sessions[s.ID] = s
+}
+
+func (m *SessionManager) Add(s *Session) {
+	m.Lock()
+	defer m.Unlock()
+
+	m.sessions[s.ID] = s
+}
+
+func (m *SessionManager) Remove(id uint16) {
+	m.Lock()
+	defer m.Unlock()
+
+	delete(m.sessions, id)
+}
+
+func (m *SessionManager) Get(id uint16) (*Session, bool) {
+	m.RLock()
+	defer m.RUnlock()
+
+	s, found := m.sessions[id]
+	return s, found
+}
+
+func (m *SessionManager) Close() {
+	m.RLock()
+	defer m.RUnlock()
+
+	for _, s := range m.sessions {
+		s.output.CloseError()
+	}
+}
+
+type Session struct {
+	sync.Mutex
+	input          ray.InputStream
+	output         ray.OutputStream
+	parent         *SessionManager
+	ID             uint16
+	uplinkClosed   bool
+	downlinkClosed bool
+}
+
+func (s *Session) CloseUplink() {
+	var allDone bool
+	s.Lock()
+	s.uplinkClosed = true
+	allDone = s.uplinkClosed && s.downlinkClosed
+	s.Unlock()
+	if allDone {
+		s.parent.Remove(s.ID)
+	}
+}
+
+func (s *Session) CloseDownlink() {
+	var allDone bool
+	s.Lock()
+	s.downlinkClosed = true
+	allDone = s.uplinkClosed && s.downlinkClosed
+	s.Unlock()
+	if allDone {
+		s.parent.Remove(s.ID)
+	}
+}

+ 28 - 0
app/proxyman/mux/session_test.go

@@ -0,0 +1,28 @@
+package mux_test
+
+import (
+	"testing"
+
+	. "v2ray.com/core/app/proxyman/mux"
+	"v2ray.com/core/testing/assert"
+)
+
+func TestSessionManagerAdd(t *testing.T) {
+	assert := assert.On(t)
+
+	m := NewSessionManager()
+
+	s := &Session{}
+	m.Allocate(s)
+	assert.Uint16(s.ID).Equals(1)
+
+	s = &Session{}
+	m.Allocate(s)
+	assert.Uint16(s.ID).Equals(2)
+
+	s = &Session{
+		ID: 4,
+	}
+	m.Add(s)
+	assert.Uint16(s.ID).Equals(4)
+}

+ 3 - 0
app/proxyman/mux/status.go

@@ -0,0 +1,3 @@
+package mux
+
+type statusHandler func(meta *FrameMetadata) error