Browse Source

fix: handler logic

AkinoKaede 2 years ago
parent
commit
54b605ba4c
5 changed files with 118 additions and 28 deletions
  1. 27 3
      app/tun/handler.go
  2. 41 13
      app/tun/handler_tcp.go
  3. 30 10
      app/tun/handler_udp.go
  4. 2 2
      app/tun/stack.go
  5. 18 0
      app/tun/tun.go

+ 27 - 3
app/tun/handler.go

@@ -1,7 +1,31 @@
 package tun
 
-import "github.com/v2fly/v2ray-core/v5/common/net"
+import (
+	"github.com/v2fly/v2ray-core/v5/common/net"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+)
 
-type Handler interface {
-	Handle(conn net.Conn) error
+var (
+	tcpQueue = make(chan TCPConn)
+	udpQueue = make(chan UDPConn)
+)
+
+type TCPConn interface {
+	net.Conn
+
+	ID() *stack.TransportEndpointID
+}
+
+type UDPConn interface {
+	net.Conn
+
+	ID() *stack.TransportEndpointID
+}
+
+func handleTCP(conn TCPConn) {
+	tcpQueue <- conn
+}
+
+func handleUDP(conn UDPConn) {
+	udpQueue <- conn
 }

+ 41 - 13
app/tun/handler_tcp.go

@@ -23,6 +23,15 @@ const (
 	maxInFlight = 2 << 10
 )
 
+type tcpConn struct {
+	*gonet.TCPConn
+	id stack.TransportEndpointID
+}
+
+func (c *tcpConn) ID() *stack.TransportEndpointID {
+	return &c.id
+}
+
 type TCPHandler struct {
 	ctx           context.Context
 	dispatcher    routing.Dispatcher
@@ -32,7 +41,7 @@ type TCPHandler struct {
 	stack *stack.Stack
 }
 
-func SetTCPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyManager policy.Manager, config *Config) StackOption {
+func HandleTCP(handle func(TCPConn)) StackOption {
 	return func(s *stack.Stack) error {
 		tcpForwarder := tcp.NewForwarder(s, rcvWnd, maxInFlight, func(r *tcp.ForwarderRequest) {
 			wg := new(waiter.Queue)
@@ -45,19 +54,25 @@ func SetTCPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyMan
 
 			// TODO: set sockopt
 
-			tcpHandler := &TCPHandler{
-				ctx:           ctx,
-				dispatcher:    dispatcher,
-				policyManager: policyManager,
-				config:        config,
-				stack:         s,
-			}
-
-			if err := tcpHandler.Handle(gonet.NewTCPConn(wg, linkedEndpoint)); err != nil {
-				// TODO: log
-				// return newError("failed to handle tcp connection").Base(err)
+			// tcpHandler := &TCPHandler{
+			// 	ctx:           ctx,
+			// 	dispatcher:    dispatcher,
+			// 	policyManager: policyManager,
+			// 	config:        config,
+			// 	stack:         s,
+			// }
+
+			// if err := tcpHandler.Handle(gonet.NewTCPConn(wg, linkedEndpoint)); err != nil {
+			// 	// TODO: log
+			// 	// return newError("failed to handle tcp connection").Base(err)
+			// }
+
+			tcpConn := &tcpConn{
+				TCPConn: gonet.NewTCPConn(wg, linkedEndpoint),
+				id:      r.ID(),
 			}
 
+			tcpQueue <- tcpConn
 		})
 		s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
 
@@ -65,7 +80,20 @@ func SetTCPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyMan
 	}
 }
 
-func (h *TCPHandler) Handle(conn net.Conn) error {
+func (h *TCPHandler) HandleQueue(ch chan TCPConn) {
+	for {
+		select {
+		case conn := <-ch:
+			if err := h.Handle(conn); err != nil {
+				newError(err).AtError().WriteToLog(session.ExportIDToError(h.ctx))
+			}
+		case <-h.ctx.Done():
+			return
+		}
+	}
+}
+
+func (h *TCPHandler) Handle(conn TCPConn) error {
 	ctx := session.ContextWithInbound(h.ctx, &session.Inbound{Tag: h.config.Tag})
 	sessionPolicy := h.policyManager.ForLevel(h.config.UserLevel)
 

+ 30 - 10
app/tun/handler_udp.go

@@ -26,7 +26,16 @@ type UDPHandler struct {
 	stack *stack.Stack
 }
 
-func SetUDPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyManager policy.Manager, config *Config) StackOption {
+type udpConn struct {
+	*gonet.UDPConn
+	id stack.TransportEndpointID
+}
+
+func (c *udpConn) ID() *stack.TransportEndpointID {
+	return &c.id
+}
+
+func HandleUDP(handle func(UDPConn)) StackOption {
 	return func(s *stack.Stack) error {
 		udpForwarder := gvisor_udp.NewForwarder(s, func(r *gvisor_udp.ForwarderRequest) {
 			wg := new(waiter.Queue)
@@ -36,21 +45,32 @@ func SetUDPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyMan
 				return
 			}
 
-			udpConn := gonet.NewUDPConn(s, wg, linkedEndpoint)
-			udpHandler := &UDPHandler{
-				ctx:           ctx,
-				dispatcher:    dispatcher,
-				policyManager: policyManager,
-				config:        config,
-				stack:         s,
+			udpConn := &udpConn{
+				UDPConn: gonet.NewUDPConn(s, wg, linkedEndpoint),
+				id:      r.ID(),
 			}
-			udpHandler.Handle(udpConn)
+
+			handle(udpConn)
 		})
 		s.SetTransportProtocolHandler(gvisor_udp.ProtocolNumber, udpForwarder.HandlePacket)
 		return nil
 	}
 }
-func (h *UDPHandler) Handle(conn net.Conn) error {
+
+func (h *UDPHandler) HandleQueue(ch chan UDPConn) {
+	for {
+		select {
+		case <-h.ctx.Done():
+			return
+		case conn := <-ch:
+			if err := h.Handle(conn); err != nil {
+				newError(err).AtError().WriteToLog(session.ExportIDToError(h.ctx))
+			}
+		}
+	}
+}
+
+func (h *UDPHandler) Handle(conn UDPConn) error {
 	ctx := session.ContextWithInbound(h.ctx, &session.Inbound{Tag: h.config.Tag})
 	packetConn := conn.(net.PacketConn)
 

+ 2 - 2
app/tun/stack.go

@@ -29,8 +29,8 @@ func (t *TUN) CreateStack(linkedEndpoint stack.LinkEndpoint) (*stack.Stack, erro
 	nicID := tcpip.NICID(s.UniqueID())
 
 	opts := []StackOption{
-		SetTCPHandler(t.ctx, t.dispatcher, t.policyManager, t.config),
-		SetUDPHandler(t.ctx, t.dispatcher, t.policyManager, t.config),
+		HandleTCP(handleTCP),
+		HandleUDP(handleUDP),
 
 		CreateNIC(nicID, linkedEndpoint),
 		AddProtocolAddress(nicID, t.config.Ips),

+ 18 - 0
app/tun/tun.go

@@ -46,6 +46,24 @@ func (t *TUN) Start() error {
 	}
 	t.stack = stack
 
+	tcpHandler := &TCPHandler{
+		ctx:           t.ctx,
+		dispatcher:    t.dispatcher,
+		policyManager: t.policyManager,
+		config:        t.config,
+		stack:         stack,
+	}
+	go tcpHandler.Handle(<-tcpQueue)
+
+	udpHander := &UDPHandler{
+		ctx:           t.ctx,
+		dispatcher:    t.dispatcher,
+		policyManager: t.policyManager,
+		config:        t.config,
+		stack:         stack,
+	}
+	go udpHander.Handle(<-udpQueue)
+
 	return nil
 }