Browse Source

refactor: move unix domain socket filelocker to internal

Allo 2 years ago
parent
commit
6593b88837

+ 0 - 5
transport/internet/grpc/hub.go

@@ -23,7 +23,6 @@ type Listener struct {
 	handler internet.ConnHandler
 	handler internet.ConnHandler
 	local   net.Addr
 	local   net.Addr
 	config  *Config
 	config  *Config
-	locker  *internet.FileLocker // for unix domain socket
 
 
 	s *grpc.Server
 	s *grpc.Server
 }
 }
@@ -96,10 +95,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, settings *i
 				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
 				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
 				return
 				return
 			}
 			}
-			locker := ctx.Value(address.Domain())
-			if locker != nil {
-				listener.locker = locker.(*internet.FileLocker)
-			}
 		} else { // tcp
 		} else { // tcp
 			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 				IP:   address.IP(),
 				IP:   address.IP(),

+ 0 - 8
transport/internet/http/hub.go

@@ -25,7 +25,6 @@ type Listener struct {
 	handler internet.ConnHandler
 	handler internet.ConnHandler
 	local   net.Addr
 	local   net.Addr
 	config  *Config
 	config  *Config
-	locker  *internet.FileLocker // for unix domain socket
 }
 }
 
 
 func (l *Listener) Addr() net.Addr {
 func (l *Listener) Addr() net.Addr {
@@ -33,9 +32,6 @@ func (l *Listener) Addr() net.Addr {
 }
 }
 
 
 func (l *Listener) Close() error {
 func (l *Listener) Close() error {
-	if l.locker != nil {
-		l.locker.Release()
-	}
 	return l.server.Close()
 	return l.server.Close()
 }
 }
 
 
@@ -171,10 +167,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
 				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
 				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
 				return
 				return
 			}
 			}
-			locker := ctx.Value(address.Domain())
-			if locker != nil {
-				listener.locker = locker.(*internet.FileLocker)
-			}
 		} else { // tcp
 		} else { // tcp
 			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 				IP:   address.IP(),
 				IP:   address.IP(),

+ 27 - 6
transport/internet/system_listener.go

@@ -23,6 +23,19 @@ type DefaultListener struct {
 	controllers []controller
 	controllers []controller
 }
 }
 
 
+type combinedListener struct {
+	net.Listener
+	locker *FileLocker // for unix domain socket
+}
+
+func (l *combinedListener) Close() error {
+	if l.locker != nil {
+		l.locker.Release()
+		l.locker = nil
+	}
+	return l.Listener.Close()
+}
+
 func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []controller) func(network, address string, c syscall.RawConn) error {
 func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []controller) func(network, address string, c syscall.RawConn) error {
 	return func(network, address string, c syscall.RawConn) error {
 	return func(network, address string, c syscall.RawConn) error {
 		return c.Control(func(fd uintptr) {
 		return c.Control(func(fd uintptr) {
@@ -43,10 +56,8 @@ func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []co
 	}
 	}
 }
 }
 
 
-func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.Listener, error) {
+func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (l net.Listener, err error) {
 	var lc net.ListenConfig
 	var lc net.ListenConfig
-	var l net.Listener
-	var err error
 	var network, address string
 	var network, address string
 	switch addr := addr.(type) {
 	switch addr := addr.(type) {
 	case *net.TCPAddr:
 	case *net.TCPAddr:
@@ -83,6 +94,9 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 						return
 						return
 					}
 					}
 					if cerr := os.Chmod(name, mode); cerr != nil {
 					if cerr := os.Chmod(name, mode); cerr != nil {
+						// failed to set file mode, close the listener
+						l.Close()
+						l = nil
 						err = newError("failed to set file mode for file: ", name).Base(cerr)
 						err = newError("failed to set file mode for file: ", name).Base(cerr)
 					}
 					}
 				}(address, os.FileMode(fMode))
 				}(address, os.FileMode(fMode))
@@ -91,11 +105,18 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 			locker := &FileLocker{
 			locker := &FileLocker{
 				path: address + ".lock",
 				path: address + ".lock",
 			}
 			}
-			err := locker.Acquire()
-			if err != nil {
+			if err := locker.Acquire(); err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
-			ctx = context.WithValue(ctx, address, locker) // nolint: revive,staticcheck
+			defer func(locker *FileLocker) {
+				// combine listener and locker
+				if err == nil {
+					l = &combinedListener{Listener: l, locker: locker}
+				} else {
+					// failed to create listener, release the locker
+					locker.Release()
+				}
+			}(locker)
 		}
 		}
 	}
 	}
 
 

+ 0 - 8
transport/internet/tcp/hub.go

@@ -21,7 +21,6 @@ type Listener struct {
 	authConfig internet.ConnectionAuthenticator
 	authConfig internet.ConnectionAuthenticator
 	config     *Config
 	config     *Config
 	addConn    internet.ConnHandler
 	addConn    internet.ConnHandler
-	locker     *internet.FileLocker // for unix domain socket
 }
 }
 
 
 // ListenTCP creates a new Listener based on configurations.
 // ListenTCP creates a new Listener based on configurations.
@@ -48,10 +47,6 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSe
 			return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err)
 			return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err)
 		}
 		}
 		newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx))
 		newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx))
-		locker := ctx.Value(address.Domain())
-		if locker != nil {
-			l.locker = locker.(*internet.FileLocker)
-		}
 	} else {
 	} else {
 		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 			IP:   address.IP(),
 			IP:   address.IP(),
@@ -122,9 +117,6 @@ func (v *Listener) Addr() net.Addr {
 
 
 // Close implements internet.Listener.Close.
 // Close implements internet.Listener.Close.
 func (v *Listener) Close() error {
 func (v *Listener) Close() error {
-	if v.locker != nil {
-		v.locker.Release()
-	}
 	return v.listener.Close()
 	return v.listener.Close()
 }
 }
 
 

+ 4 - 20
transport/internet/transportcommon/listener.go

@@ -10,22 +10,10 @@ import (
 	"github.com/v2fly/v2ray-core/v5/transport/internet"
 	"github.com/v2fly/v2ray-core/v5/transport/internet"
 )
 )
 
 
-type combinedListener struct {
-	net.Listener
-	locker *internet.FileLocker
-}
-
-func (l *combinedListener) Close() error {
-	if l.locker != nil {
-		l.locker.Release()
-	}
-	return l.Listener.Close()
-}
-
 func ListenWithSecuritySettings(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig) (
 func ListenWithSecuritySettings(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig) (
 	net.Listener, error,
 	net.Listener, error,
 ) {
 ) {
-	var l combinedListener
+	var l net.Listener
 
 
 	transportEnvironment := envctx.EnvironmentFromContext(ctx).(environment.TransportEnvironment)
 	transportEnvironment := envctx.EnvironmentFromContext(ctx).(environment.TransportEnvironment)
 	transportListener := transportEnvironment.Listener()
 	transportListener := transportEnvironment.Listener()
@@ -39,11 +27,7 @@ func ListenWithSecuritySettings(ctx context.Context, address net.Address, port n
 			return nil, newError("failed to listen unix domain socket on ", address).Base(err)
 			return nil, newError("failed to listen unix domain socket on ", address).Base(err)
 		}
 		}
 		newError("listening unix domain socket on ", address).WriteToLog(session.ExportIDToError(ctx))
 		newError("listening unix domain socket on ", address).WriteToLog(session.ExportIDToError(ctx))
-		locker := ctx.Value(address.Domain())
-		if locker != nil {
-			l.locker = locker.(*internet.FileLocker)
-		}
-		l.Listener = listener
+		l = listener
 	} else { // tcp
 	} else { // tcp
 		listener, err := transportListener.Listen(ctx, &net.TCPAddr{
 		listener, err := transportListener.Listen(ctx, &net.TCPAddr{
 			IP:   address.IP(),
 			IP:   address.IP(),
@@ -53,11 +37,11 @@ func ListenWithSecuritySettings(ctx context.Context, address net.Address, port n
 			return nil, newError("failed to listen TCP on ", address, ":", port).Base(err)
 			return nil, newError("failed to listen TCP on ", address, ":", port).Base(err)
 		}
 		}
 		newError("listening TCP on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx))
 		newError("listening TCP on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx))
-		l.Listener = listener
+		l = listener
 	}
 	}
 
 
 	if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
 	if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
 		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
 		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
 	}
 	}
-	return &l, nil
+	return l, nil
 }
 }

+ 0 - 8
transport/internet/websocket/hub.go

@@ -93,7 +93,6 @@ type Listener struct {
 	listener net.Listener
 	listener net.Listener
 	config   *Config
 	config   *Config
 	addConn  internet.ConnHandler
 	addConn  internet.ConnHandler
-	locker   *internet.FileLocker // for unix domain socket
 }
 }
 
 
 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
@@ -119,10 +118,6 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet
 			return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err)
 			return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err)
 		}
 		}
 		newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx))
 		newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx))
-		locker := ctx.Value(address.Domain())
-		if locker != nil {
-			l.locker = locker.(*internet.FileLocker)
-		}
 	} else { // tcp
 	} else { // tcp
 		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 			IP:   address.IP(),
 			IP:   address.IP(),
@@ -179,9 +174,6 @@ func (ln *Listener) Addr() net.Addr {
 
 
 // Close implements net.Listener.Close().
 // Close implements net.Listener.Close().
 func (ln *Listener) Close() error {
 func (ln *Listener) Close() error {
-	if ln.locker != nil {
-		ln.locker.Release()
-	}
 	return ln.listener.Close()
 	return ln.listener.Close()
 }
 }