Ver código fonte

refactor mux

Darien Raymond 8 anos atrás
pai
commit
4d682c01e0
1 arquivos alterados com 49 adições e 33 exclusões
  1. 49 33
      app/proxyman/mux/mux.go

+ 49 - 33
app/proxyman/mux/mux.go

@@ -42,7 +42,7 @@ func (s *session) closeUplink() {
 	allDone = s.uplinkClosed && s.downlinkClosed
 	s.Unlock()
 	if allDone {
-		go s.parent.remove(s.id)
+		s.parent.remove(s.id)
 	}
 }
 
@@ -53,7 +53,7 @@ func (s *session) closeDownlink() {
 	allDone = s.uplinkClosed && s.downlinkClosed
 	s.Unlock()
 	if allDone {
-		go s.parent.remove(s.id)
+		s.parent.remove(s.id)
 	}
 }
 
@@ -109,13 +109,14 @@ func (m *ClientManager) onClientFinish() {
 }
 
 type Client struct {
-	access     sync.RWMutex
-	count      uint16
-	sessions   map[uint16]*session
-	inboundRay ray.InboundRay
-	ctx        context.Context
-	cancel     context.CancelFunc
-	manager    *ClientManager
+	access         sync.RWMutex
+	count          uint16
+	sessions       map[uint16]*session
+	inboundRay     ray.InboundRay
+	ctx            context.Context
+	cancel         context.CancelFunc
+	manager        *ClientManager
+	session2Remove chan uint16
 }
 
 var muxCoolDestination = net.TCPDestination(net.DomainAddress("v1.mux.cool"), net.Port(9527))
@@ -126,27 +127,24 @@ func NewClient(p proxy.Outbound, dialer proxy.Dialer, m *ClientManager) (*Client
 	pipe := ray.NewRay(ctx)
 	go p.Process(ctx, pipe, dialer)
 	c := &Client{
-		sessions:   make(map[uint16]*session, 256),
-		inboundRay: pipe,
-		ctx:        ctx,
-		cancel:     cancel,
-		manager:    m,
-		count:      0,
+		sessions:       make(map[uint16]*session, 256),
+		inboundRay:     pipe,
+		ctx:            ctx,
+		cancel:         cancel,
+		manager:        m,
+		count:          0,
+		session2Remove: make(chan uint16, 16),
 	}
 	go c.fetchOutput()
+	go c.monitor()
 	return c, nil
 }
 
 func (m *Client) remove(id uint16) {
-	m.access.Lock()
-	defer m.access.Unlock()
-
-	delete(m.sessions, id)
-
-	if len(m.sessions) == 0 {
-		m.cancel()
-		m.inboundRay.InboundInput().Close()
-		go m.manager.onClientFinish()
+	select {
+	case m.session2Remove <- id:
+	default:
+		// Probably not gonna happen.
 	}
 }
 
@@ -159,6 +157,31 @@ func (m *Client) Closed() bool {
 	}
 }
 
+func (m *Client) monitor() {
+	for {
+		select {
+		case <-m.ctx.Done():
+			m.cleanup()
+			return
+		case id := <-m.session2Remove:
+			m.access.Lock()
+			delete(m.sessions, id)
+			m.access.Unlock()
+		}
+	}
+}
+
+func (m *Client) cleanup() {
+	m.access.Lock()
+	defer m.access.Unlock()
+
+	for _, s := range m.sessions {
+		s.closeUplink()
+		s.closeDownlink()
+		s.output.CloseError()
+	}
+}
+
 func fetchInput(ctx context.Context, s *session, output buf.Writer) {
 	dest, _ := proxy.TargetFromContext(ctx)
 	writer := &Writer{
@@ -242,6 +265,8 @@ func pipe(reader *Reader, writer buf.Writer) error {
 }
 
 func (m *Client) fetchOutput() {
+	defer m.cancel()
+
 	reader := NewReader(m.inboundRay.InboundOutput())
 	for {
 		meta, err := reader.ReadMetadata()
@@ -271,15 +296,6 @@ func (m *Client) fetchOutput() {
 			break
 		}
 	}
-
-	// Close all downlinks
-	m.access.RLock()
-	for _, s := range m.sessions {
-		s.closeUplink()
-		s.closeDownlink()
-		s.output.CloseError()
-	}
-	m.access.RUnlock()
 }
 
 type Server struct {