Browse Source

Fix: context of reverse (#1502)

秋のかえで 3 years ago
parent
commit
36cfbed180
3 changed files with 19 additions and 17 deletions
  1. 8 7
      app/reverse/bridge.go
  2. 7 6
      app/reverse/portal.go
  3. 4 4
      app/reverse/reverse.go

+ 8 - 7
app/reverse/bridge.go

@@ -17,6 +17,7 @@ import (
 
 
 // Bridge is a component in reverse proxy, that relays connections from Portal to local address.
 // Bridge is a component in reverse proxy, that relays connections from Portal to local address.
 type Bridge struct {
 type Bridge struct {
+	ctx         context.Context
 	dispatcher  routing.Dispatcher
 	dispatcher  routing.Dispatcher
 	tag         string
 	tag         string
 	domain      string
 	domain      string
@@ -25,7 +26,7 @@ type Bridge struct {
 }
 }
 
 
 // NewBridge creates a new Bridge instance.
 // NewBridge creates a new Bridge instance.
-func NewBridge(config *BridgeConfig, dispatcher routing.Dispatcher) (*Bridge, error) {
+func NewBridge(ctx context.Context, config *BridgeConfig, dispatcher routing.Dispatcher) (*Bridge, error) {
 	if config.Tag == "" {
 	if config.Tag == "" {
 		return nil, newError("bridge tag is empty")
 		return nil, newError("bridge tag is empty")
 	}
 	}
@@ -34,6 +35,7 @@ func NewBridge(config *BridgeConfig, dispatcher routing.Dispatcher) (*Bridge, er
 	}
 	}
 
 
 	b := &Bridge{
 	b := &Bridge{
+		ctx:        ctx,
 		dispatcher: dispatcher,
 		dispatcher: dispatcher,
 		tag:        config.Tag,
 		tag:        config.Tag,
 		domain:     config.Domain,
 		domain:     config.Domain,
@@ -73,7 +75,7 @@ func (b *Bridge) monitor() error {
 	}
 	}
 
 
 	if numWorker == 0 || numConnections/numWorker > 16 {
 	if numWorker == 0 || numConnections/numWorker > 16 {
-		worker, err := NewBridgeWorker(b.domain, b.tag, b.dispatcher)
+		worker, err := NewBridgeWorker(b.ctx, b.domain, b.tag, b.dispatcher)
 		if err != nil {
 		if err != nil {
 			newError("failed to create bridge worker").Base(err).AtWarning().WriteToLog()
 			newError("failed to create bridge worker").Base(err).AtWarning().WriteToLog()
 			return nil
 			return nil
@@ -99,12 +101,11 @@ type BridgeWorker struct {
 	state      Control_State
 	state      Control_State
 }
 }
 
 
-func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) {
-	ctx := context.Background()
-	ctx = session.ContextWithInbound(ctx, &session.Inbound{
+func NewBridgeWorker(ctx context.Context, domain string, tag string, d routing.Dispatcher) (*BridgeWorker, error) {
+	bridgeCtx := session.ContextWithInbound(ctx, &session.Inbound{
 		Tag: tag,
 		Tag: tag,
 	})
 	})
-	link, err := d.Dispatch(ctx, net.Destination{
+	link, err := d.Dispatch(bridgeCtx, net.Destination{
 		Network: net.Network_TCP,
 		Network: net.Network_TCP,
 		Address: net.DomainAddress(domain),
 		Address: net.DomainAddress(domain),
 		Port:    0,
 		Port:    0,
@@ -118,7 +119,7 @@ func NewBridgeWorker(domain string, tag string, d routing.Dispatcher) (*BridgeWo
 		tag:        tag,
 		tag:        tag,
 	}
 	}
 
 
-	worker, err := mux.NewServerWorker(context.Background(), w, link)
+	worker, err := mux.NewServerWorker(ctx, w, link)
 	if err != nil {
 	if err != nil {
 		return nil, err
 		return nil, err
 	}
 	}

+ 7 - 6
app/reverse/portal.go

@@ -19,6 +19,7 @@ import (
 )
 )
 
 
 type Portal struct {
 type Portal struct {
+	ctx    context.Context
 	ohm    outbound.Manager
 	ohm    outbound.Manager
 	tag    string
 	tag    string
 	domain string
 	domain string
@@ -26,7 +27,7 @@ type Portal struct {
 	client *mux.ClientManager
 	client *mux.ClientManager
 }
 }
 
 
-func NewPortal(config *PortalConfig, ohm outbound.Manager) (*Portal, error) {
+func NewPortal(ctx context.Context, config *PortalConfig, ohm outbound.Manager) (*Portal, error) {
 	if config.Tag == "" {
 	if config.Tag == "" {
 		return nil, newError("portal tag is empty")
 		return nil, newError("portal tag is empty")
 	}
 	}
@@ -41,6 +42,7 @@ func NewPortal(config *PortalConfig, ohm outbound.Manager) (*Portal, error) {
 	}
 	}
 
 
 	return &Portal{
 	return &Portal{
+		ctx:    ctx,
 		ohm:    ohm,
 		ohm:    ohm,
 		tag:    config.Tag,
 		tag:    config.Tag,
 		domain: config.Domain,
 		domain: config.Domain,
@@ -52,14 +54,14 @@ func NewPortal(config *PortalConfig, ohm outbound.Manager) (*Portal, error) {
 }
 }
 
 
 func (p *Portal) Start() error {
 func (p *Portal) Start() error {
-	return p.ohm.AddHandler(context.Background(), &Outbound{
+	return p.ohm.AddHandler(p.ctx, &Outbound{
 		portal: p,
 		portal: p,
 		tag:    p.tag,
 		tag:    p.tag,
 	})
 	})
 }
 }
 
 
 func (p *Portal) Close() error {
 func (p *Portal) Close() error {
-	return p.ohm.RemoveHandler(context.Background(), p.tag)
+	return p.ohm.RemoveHandler(p.ctx, p.tag)
 }
 }
 
 
 func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error {
 func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) error {
@@ -74,7 +76,7 @@ func (p *Portal) HandleConnection(ctx context.Context, link *transport.Link) err
 			return newError("failed to create mux client worker").Base(err).AtWarning()
 			return newError("failed to create mux client worker").Base(err).AtWarning()
 		}
 		}
 
 
-		worker, err := NewPortalWorker(muxClient)
+		worker, err := NewPortalWorker(ctx, muxClient)
 		if err != nil {
 		if err != nil {
 			return newError("failed to create portal worker").Base(err)
 			return newError("failed to create portal worker").Base(err)
 		}
 		}
@@ -198,12 +200,11 @@ type PortalWorker struct {
 	draining bool
 	draining bool
 }
 }
 
 
-func NewPortalWorker(client *mux.ClientWorker) (*PortalWorker, error) {
+func NewPortalWorker(ctx context.Context, client *mux.ClientWorker) (*PortalWorker, error) {
 	opt := []pipe.Option{pipe.WithSizeLimit(16 * 1024)}
 	opt := []pipe.Option{pipe.WithSizeLimit(16 * 1024)}
 	uplinkReader, uplinkWriter := pipe.New(opt...)
 	uplinkReader, uplinkWriter := pipe.New(opt...)
 	downlinkReader, downlinkWriter := pipe.New(opt...)
 	downlinkReader, downlinkWriter := pipe.New(opt...)
 
 
-	ctx := context.Background()
 	ctx = session.ContextWithOutbound(ctx, &session.Outbound{
 	ctx = session.ContextWithOutbound(ctx, &session.Outbound{
 		Target: net.UDPDestination(net.DomainAddress(internalDomain), 0),
 		Target: net.UDPDestination(net.DomainAddress(internalDomain), 0),
 	})
 	})

+ 4 - 4
app/reverse/reverse.go

@@ -29,7 +29,7 @@ func init() {
 	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
 	common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) {
 		r := new(Reverse)
 		r := new(Reverse)
 		if err := core.RequireFeatures(ctx, func(d routing.Dispatcher, om outbound.Manager) error {
 		if err := core.RequireFeatures(ctx, func(d routing.Dispatcher, om outbound.Manager) error {
-			return r.Init(config.(*Config), d, om)
+			return r.Init(ctx, config.(*Config), d, om)
 		}); err != nil {
 		}); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}
@@ -42,9 +42,9 @@ type Reverse struct {
 	portals []*Portal
 	portals []*Portal
 }
 }
 
 
-func (r *Reverse) Init(config *Config, d routing.Dispatcher, ohm outbound.Manager) error {
+func (r *Reverse) Init(ctx context.Context, config *Config, d routing.Dispatcher, ohm outbound.Manager) error {
 	for _, bConfig := range config.BridgeConfig {
 	for _, bConfig := range config.BridgeConfig {
-		b, err := NewBridge(bConfig, d)
+		b, err := NewBridge(ctx, bConfig, d)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -52,7 +52,7 @@ func (r *Reverse) Init(config *Config, d routing.Dispatcher, ohm outbound.Manage
 	}
 	}
 
 
 	for _, pConfig := range config.PortalConfig {
 	for _, pConfig := range config.PortalConfig {
-		p, err := NewPortal(pConfig, ohm)
+		p, err := NewPortal(ctx, pConfig, ohm)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}