Browse Source

refactor: using callback instead of defer function

allo 2 years ago
parent
commit
edff5c8e2d
1 changed files with 27 additions and 22 deletions
  1. 27 22
      transport/internet/system_listener.go

+ 27 - 22
transport/internet/system_listener.go

@@ -56,9 +56,14 @@ func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []co
 	}
 }
 
-func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (l net.Listener, err error) {
+func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.Listener, error) {
 	var lc net.ListenConfig
 	var network, address string
+	// callback is called after the Listen function returns
+	// this is used to wrap the listener and do some post processing
+	callback := func(l net.Listener, err error) (net.Listener, error) {
+		return l, err
+	}
 	switch addr := addr.(type) {
 	case *net.TCPAddr:
 		network = addr.Network()
@@ -81,6 +86,7 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 			}
 		} else {
 			// normal unix domain socket
+			var fileMode *os.FileMode
 			// parse file mode from address
 			if s := strings.Split(address, ","); len(s) == 2 {
 				fMode, err := strconv.ParseUint(s[1], 8, 32)
@@ -88,18 +94,8 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 					return nil, newError("failed to parse file mode").Base(err)
 				}
 				address = s[0]
-				// set file mode for unix domain socket when it is created
-				defer func(name string, mode os.FileMode) {
-					if err != nil {
-						return
-					}
-					if cerr := os.Chmod(name, mode); cerr != nil {
-						// failed to set file mode, close the listener
-						l.Close()
-						l = nil
-						err = newError("failed to set file mode for file: ", name).Base(cerr)
-					}
-				}(address, os.FileMode(fMode))
+				fm := os.FileMode(fMode)
+				fileMode = &fm
 			}
 			// normal unix domain socket needs lock
 			locker := &FileLocker{
@@ -108,20 +104,29 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 			if err := locker.Acquire(); err != nil {
 				return nil, err
 			}
-			defer func(locker *FileLocker) {
-				// combine listener and locker
-				if err == nil {
-					l = &combinedListener{Listener: l, locker: locker}
-				} else {
-					// failed to create listener, release the locker
+			// set file mode for unix domain socket when it is created
+			callback = func(l net.Listener, err error) (net.Listener, error) {
+				if err != nil {
 					locker.Release()
+					return nil, err
+				}
+				l = &combinedListener{Listener: l, locker: locker}
+				if fileMode == nil {
+					return l, err
 				}
-			}(locker)
+				if cerr := os.Chmod(address, *fileMode); cerr != nil {
+					// failed to set file mode, close the listener
+					l.Close()
+					return nil, newError("failed to set file mode for file: ", address).Base(cerr)
+				}
+				return l, err
+			}
 		}
 	}
 
-	l, err = lc.Listen(ctx, network, address)
-	if sockopt != nil && sockopt.AcceptProxyProtocol {
+	l, err := lc.Listen(ctx, network, address)
+	l, err = callback(l, err)
+	if err == nil && sockopt != nil && sockopt.AcceptProxyProtocol {
 		policyFunc := func(upstream net.Addr) (proxyproto.Policy, error) { return proxyproto.REQUIRE, nil }
 		l = &proxyproto.Listener{Listener: l, Policy: policyFunc}
 	}