Browse Source

refactor: move unix domain socket filelocker to internal

Allo 2 years ago
parent
commit
6593b88837

+ 0 - 5
transport/internet/grpc/hub.go

@@ -23,7 +23,6 @@ type Listener struct {
 	handler internet.ConnHandler
 	local   net.Addr
 	config  *Config
-	locker  *internet.FileLocker // for unix domain socket
 
 	s *grpc.Server
 }
@@ -96,10 +95,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, settings *i
 				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
 				return
 			}
-			locker := ctx.Value(address.Domain())
-			if locker != nil {
-				listener.locker = locker.(*internet.FileLocker)
-			}
 		} else { // tcp
 			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 				IP:   address.IP(),

+ 0 - 8
transport/internet/http/hub.go

@@ -25,7 +25,6 @@ type Listener struct {
 	handler internet.ConnHandler
 	local   net.Addr
 	config  *Config
-	locker  *internet.FileLocker // for unix domain socket
 }
 
 func (l *Listener) Addr() net.Addr {
@@ -33,9 +32,6 @@ func (l *Listener) Addr() net.Addr {
 }
 
 func (l *Listener) Close() error {
-	if l.locker != nil {
-		l.locker.Release()
-	}
 	return l.server.Close()
 }
 
@@ -171,10 +167,6 @@ func Listen(ctx context.Context, address net.Address, port net.Port, streamSetti
 				newError("failed to listen on ", address).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
 				return
 			}
-			locker := ctx.Value(address.Domain())
-			if locker != nil {
-				listener.locker = locker.(*internet.FileLocker)
-			}
 		} else { // tcp
 			streamListener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 				IP:   address.IP(),

+ 27 - 6
transport/internet/system_listener.go

@@ -23,6 +23,19 @@ type DefaultListener struct {
 	controllers []controller
 }
 
+type combinedListener struct {
+	net.Listener
+	locker *FileLocker // for unix domain socket
+}
+
+func (l *combinedListener) Close() error {
+	if l.locker != nil {
+		l.locker.Release()
+		l.locker = nil
+	}
+	return l.Listener.Close()
+}
+
 func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []controller) func(network, address string, c syscall.RawConn) error {
 	return func(network, address string, c syscall.RawConn) error {
 		return c.Control(func(fd uintptr) {
@@ -43,10 +56,8 @@ func getControlFunc(ctx context.Context, sockopt *SocketConfig, controllers []co
 	}
 }
 
-func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (net.Listener, error) {
+func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *SocketConfig) (l net.Listener, err error) {
 	var lc net.ListenConfig
-	var l net.Listener
-	var err error
 	var network, address string
 	switch addr := addr.(type) {
 	case *net.TCPAddr:
@@ -83,6 +94,9 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 						return
 					}
 					if cerr := os.Chmod(name, mode); cerr != nil {
+						// failed to set file mode, close the listener
+						l.Close()
+						l = nil
 						err = newError("failed to set file mode for file: ", name).Base(cerr)
 					}
 				}(address, os.FileMode(fMode))
@@ -91,11 +105,18 @@ func (dl *DefaultListener) Listen(ctx context.Context, addr net.Addr, sockopt *S
 			locker := &FileLocker{
 				path: address + ".lock",
 			}
-			err := locker.Acquire()
-			if err != nil {
+			if err := locker.Acquire(); err != nil {
 				return nil, err
 			}
-			ctx = context.WithValue(ctx, address, locker) // nolint: revive,staticcheck
+			defer func(locker *FileLocker) {
+				// combine listener and locker
+				if err == nil {
+					l = &combinedListener{Listener: l, locker: locker}
+				} else {
+					// failed to create listener, release the locker
+					locker.Release()
+				}
+			}(locker)
 		}
 	}
 

+ 0 - 8
transport/internet/tcp/hub.go

@@ -21,7 +21,6 @@ type Listener struct {
 	authConfig internet.ConnectionAuthenticator
 	config     *Config
 	addConn    internet.ConnHandler
-	locker     *internet.FileLocker // for unix domain socket
 }
 
 // ListenTCP creates a new Listener based on configurations.
@@ -48,10 +47,6 @@ func ListenTCP(ctx context.Context, address net.Address, port net.Port, streamSe
 			return nil, newError("failed to listen Unix Domain Socket on ", address).Base(err)
 		}
 		newError("listening Unix Domain Socket on ", address).WriteToLog(session.ExportIDToError(ctx))
-		locker := ctx.Value(address.Domain())
-		if locker != nil {
-			l.locker = locker.(*internet.FileLocker)
-		}
 	} else {
 		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 			IP:   address.IP(),
@@ -122,9 +117,6 @@ func (v *Listener) Addr() net.Addr {
 
 // Close implements internet.Listener.Close.
 func (v *Listener) Close() error {
-	if v.locker != nil {
-		v.locker.Release()
-	}
 	return v.listener.Close()
 }
 

+ 4 - 20
transport/internet/transportcommon/listener.go

@@ -10,22 +10,10 @@ import (
 	"github.com/v2fly/v2ray-core/v5/transport/internet"
 )
 
-type combinedListener struct {
-	net.Listener
-	locker *internet.FileLocker
-}
-
-func (l *combinedListener) Close() error {
-	if l.locker != nil {
-		l.locker.Release()
-	}
-	return l.Listener.Close()
-}
-
 func ListenWithSecuritySettings(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig) (
 	net.Listener, error,
 ) {
-	var l combinedListener
+	var l net.Listener
 
 	transportEnvironment := envctx.EnvironmentFromContext(ctx).(environment.TransportEnvironment)
 	transportListener := transportEnvironment.Listener()
@@ -39,11 +27,7 @@ func ListenWithSecuritySettings(ctx context.Context, address net.Address, port n
 			return nil, newError("failed to listen unix domain socket on ", address).Base(err)
 		}
 		newError("listening unix domain socket on ", address).WriteToLog(session.ExportIDToError(ctx))
-		locker := ctx.Value(address.Domain())
-		if locker != nil {
-			l.locker = locker.(*internet.FileLocker)
-		}
-		l.Listener = listener
+		l = listener
 	} else { // tcp
 		listener, err := transportListener.Listen(ctx, &net.TCPAddr{
 			IP:   address.IP(),
@@ -53,11 +37,11 @@ func ListenWithSecuritySettings(ctx context.Context, address net.Address, port n
 			return nil, newError("failed to listen TCP on ", address, ":", port).Base(err)
 		}
 		newError("listening TCP on ", address, ":", port).WriteToLog(session.ExportIDToError(ctx))
-		l.Listener = listener
+		l = listener
 	}
 
 	if streamSettings.SocketSettings != nil && streamSettings.SocketSettings.AcceptProxyProtocol {
 		newError("accepting PROXY protocol").AtWarning().WriteToLog(session.ExportIDToError(ctx))
 	}
-	return &l, nil
+	return l, nil
 }

+ 0 - 8
transport/internet/websocket/hub.go

@@ -93,7 +93,6 @@ type Listener struct {
 	listener net.Listener
 	config   *Config
 	addConn  internet.ConnHandler
-	locker   *internet.FileLocker // for unix domain socket
 }
 
 func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
@@ -119,10 +118,6 @@ func ListenWS(ctx context.Context, address net.Address, port net.Port, streamSet
 			return nil, newError("failed to listen unix domain socket(for WS) on ", address).Base(err)
 		}
 		newError("listening unix domain socket(for WS) on ", address).WriteToLog(session.ExportIDToError(ctx))
-		locker := ctx.Value(address.Domain())
-		if locker != nil {
-			l.locker = locker.(*internet.FileLocker)
-		}
 	} else { // tcp
 		listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
 			IP:   address.IP(),
@@ -179,9 +174,6 @@ func (ln *Listener) Addr() net.Addr {
 
 // Close implements net.Listener.Close().
 func (ln *Listener) Close() error {
-	if ln.locker != nil {
-		ln.locker.Release()
-	}
 	return ln.listener.Close()
 }