Browse Source

style: refine code style

AkinoKaede 2 years ago
parent
commit
225d638338
4 changed files with 49 additions and 28 deletions
  1. 6 0
      app/tun/handler.go
  2. 31 16
      app/tun/handler_tcp.go
  3. 11 1
      app/tun/stack.go
  4. 1 11
      app/tun/tun.go

+ 6 - 0
app/tun/handler.go

@@ -1 +1,7 @@
 package tun
+
+import "github.com/v2fly/v2ray-core/v5/common/net"
+
+type Handler interface {
+	Handle(conn net.Conn) error
+}

+ 31 - 16
app/tun/handler_tcp.go

@@ -31,25 +31,40 @@ type TCPHandler struct {
 	stack *stack.Stack
 }
 
-func (h *TCPHandler) SetHandler() {
-	tcpForwarder := tcp.NewForwarder(h.stack, rcvWnd, maxInFlight, func(r *tcp.ForwarderRequest) {
-		wg := new(waiter.Queue)
-		linkedEndpoint, err := r.CreateEndpoint(wg)
-		if err != nil {
-			r.Complete(true)
-			return
-		}
-		defer r.Complete(false)
-
-		// TODO: set sockopt
-
-		h.handle(gonet.NewTCPConn(wg, linkedEndpoint))
+func SetTCPHandler(ctx context.Context, dispatcher routing.Dispatcher, policyManager policy.Manager, config *Config) func(*stack.Stack) error {
+	return func(s *stack.Stack) error {
+		tcpForwarder := tcp.NewForwarder(s, rcvWnd, maxInFlight, func(r *tcp.ForwarderRequest) {
+			wg := new(waiter.Queue)
+			linkedEndpoint, err := r.CreateEndpoint(wg)
+			if err != nil {
+				r.Complete(true)
+				return
+			}
+			defer r.Complete(false)
+
+			// TODO: set sockopt
+
+			tcpHandler := &TCPHandler{
+				ctx:           ctx,
+				dispatcher:    dispatcher,
+				policyManager: policyManager,
+				config:        config,
+				stack:         s,
+			}
+
+			if err := tcpHandler.Handle(gonet.NewTCPConn(wg, linkedEndpoint)); err != nil {
+				// TODO: log
+				// return newError("failed to handle tcp connection").Base(err)
+			}
+
+		})
+		s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
 
-	})
-	h.stack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
+		return nil
+	}
 }
 
-func (h *TCPHandler) handle(conn *gonet.TCPConn) error {
+func (h *TCPHandler) Handle(conn net.Conn) error {
 	sessionPolicy := h.policyManager.ForLevel(h.config.UserLevel)
 
 	addr := conn.RemoteAddr()

+ 11 - 1
app/tun/stack.go

@@ -11,7 +11,7 @@ import (
 
 type StackOption func(*stack.Stack) error
 
-func CreateStack(_ stack.LinkEndpoint) (*stack.Stack, error) {
+func (t *TUN) CreateStack(_ stack.LinkEndpoint) (*stack.Stack, error) {
 	s := stack.New(stack.Options{
 		NetworkProtocols: []stack.NetworkProtocolFactory{
 			ipv4.NewProtocol,
@@ -25,6 +25,16 @@ func CreateStack(_ stack.LinkEndpoint) (*stack.Stack, error) {
 		},
 	})
 
+	opts := []StackOption{
+		SetTCPHandler(t.ctx, t.dispatcher, t.policyManager, t.config),
+	}
+
+	for _, opt := range opts {
+		if err := opt(s); err != nil {
+			return nil, err
+		}
+	}
+
 	// nicID := tcpip.NICID(s.UniqueID())
 
 	return s, nil

+ 1 - 11
app/tun/tun.go

@@ -40,22 +40,12 @@ func (t *TUN) Start() error {
 		return newError("failed to create device").Base(err).AtError()
 	}
 
-	stack, err := CreateStack(device)
+	stack, err := t.CreateStack(device)
 	if err != nil {
 		return newError("failed to create stack").Base(err).AtError()
 	}
 	t.stack = stack
 
-	tcpHandler := &TCPHandler{
-		ctx:           t.ctx,
-		dispatcher:    t.dispatcher,
-		policyManager: t.policyManager,
-		config:        t.config,
-		stack:         stack,
-	}
-
-	tcpHandler.SetHandler()
-
 	return nil
 }