Browse Source

Feat: add special handling for /dev/fd address

to support socket activation
Misaki Kasumi 1 year ago
parent
commit
90edd2e9db
2 changed files with 22 additions and 4 deletions
  1. 1 0
      common/net/system.go
  2. 21 4
      transport/internet/system_listener.go

+ 1 - 0
common/net/system.go

@@ -14,6 +14,7 @@ var (
 	DialUDP         = net.DialUDP
 	DialUDP         = net.DialUDP
 	DialUnix        = net.DialUnix
 	DialUnix        = net.DialUnix
 	FileConn        = net.FileConn
 	FileConn        = net.FileConn
+	FileListener    = net.FileListener
 	Listen          = net.Listen
 	Listen          = net.Listen
 	ListenTCP       = net.ListenTCP
 	ListenTCP       = net.ListenTCP
 	ListenUDP       = net.ListenUDP
 	ListenUDP       = net.ListenUDP

+ 21 - 4
transport/internet/system_listener.go

@@ -59,6 +59,8 @@ func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []co
 func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.Listener, error) {
 func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.Listener, error) {
 	var lc net.ListenConfig
 	var lc net.ListenConfig
 	var network, address string
 	var network, address string
+	var l net.Listener
+	var err error
 	// callback is called after the Listen function returns
 	// callback is called after the Listen function returns
 	// this is used to wrap the listener and do some post processing
 	// this is used to wrap the listener and do some post processing
 	callback := func(l net.Listener, err error) (net.Listener, error) {
 	callback := func(l net.Listener, err error) (net.Listener, error) {
@@ -93,6 +95,16 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 				copy(fullAddr, address[1:])
 				copy(fullAddr, address[1:])
 				address = string(fullAddr)
 				address = string(fullAddr)
 			}
 			}
+		} else if strings.HasPrefix(address, "/dev/fd/") {
+			fd, err := strconv.Atoi(address[8:])
+			if err != nil {
+				return nil, err
+			}
+			_ = syscall.SetNonblock(fd, true)
+			l, err = net.FileListener(os.NewFile(uintptr(fd), address))
+			if err != nil {
+				return nil, err
+			}
 		} else {
 		} else {
 			// normal unix domain socket
 			// normal unix domain socket
 			var fileMode *os.FileMode
 			var fileMode *os.FileMode
@@ -133,13 +145,18 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 		}
 		}
 	}
 	}
 
 
-	l, err := lc.Listen(ctx, network, address)
-	l, err = callback(l, err)
-	if err == nil && sockopt != nil && sockopt.AcceptProxyProtocol {
+	if l == nil {
+		l, err = lc.Listen(ctx, network, address)
+		l, err = callback(l, err)
+		if err != nil {
+			return nil, err
+		}
+	}
+	if sockopt != nil && sockopt.AcceptProxyProtocol {
 		policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil }
 		policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil }
 		l = &proxyproto.Listener{Listener: l, Policy: policyFunc}
 		l = &proxyproto.Listener{Listener: l, Policy: policyFunc}
 	}
 	}
-	return l, err
+	return l, nil
 }
 }
 
 
 func (dl *DefaultListener) ListenPacket(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.PacketConn, error) {
 func (dl *DefaultListener) ListenPacket(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.PacketConn, error) {