Procházet zdrojové kódy

locker protected connection

Darien Raymond před 8 roky
rodič
revize
19e0cb40e9
1 změnil soubory, kde provedl 64 přidání a 17 odebrání
  1. 64 17
      transport/internet/tcp/connection.go

+ 64 - 17
transport/internet/tcp/connection.go

@@ -3,15 +3,12 @@ package tcp
 import (
 	"io"
 	"net"
+	"sync"
 	"time"
 
 	"v2ray.com/core/transport/internet/internal"
 )
 
-type ConnectionManager interface {
-	Put(internal.ConnectionID, net.Conn)
-}
-
 type RawConnection struct {
 	net.TCPConn
 }
@@ -27,14 +24,15 @@ func (v *RawConnection) SysFd() (int, error) {
 }
 
 type Connection struct {
+	sync.RWMutex
 	id       internal.ConnectionID
 	reusable bool
 	conn     net.Conn
-	listener ConnectionManager
+	listener internal.ConnectionRecyler
 	config   *Config
 }
 
-func NewConnection(id internal.ConnectionID, conn net.Conn, manager ConnectionManager, config *Config) *Connection {
+func NewConnection(id internal.ConnectionID, conn net.Conn, manager internal.ConnectionRecyler, config *Config) *Connection {
 	return &Connection{
 		id:       id,
 		conn:     conn,
@@ -45,22 +43,30 @@ func NewConnection(id internal.ConnectionID, conn net.Conn, manager ConnectionMa
 }
 
 func (v *Connection) Read(b []byte) (int, error) {
-	if v == nil || v.conn == nil {
+	conn := v.underlyingConn()
+	if conn == nil {
 		return 0, io.EOF
 	}
 
-	return v.conn.Read(b)
+	return conn.Read(b)
 }
 
 func (v *Connection) Write(b []byte) (int, error) {
-	if v == nil || v.conn == nil {
+	conn := v.underlyingConn()
+	if conn == nil {
 		return 0, io.ErrClosedPipe
 	}
-	return v.conn.Write(b)
+	return conn.Write(b)
 }
 
 func (v *Connection) Close() error {
-	if v == nil || v.conn == nil {
+	if v == nil {
+		return io.ErrClosedPipe
+	}
+
+	v.Lock()
+	defer v.Unlock()
+	if v.conn == nil {
 		return io.ErrClosedPipe
 	}
 	if v.Reusable() {
@@ -73,33 +79,74 @@ func (v *Connection) Close() error {
 }
 
 func (v *Connection) LocalAddr() net.Addr {
-	return v.conn.LocalAddr()
+	conn := v.underlyingConn()
+	if conn == nil {
+		return nil
+	}
+	return conn.LocalAddr()
 }
 
 func (v *Connection) RemoteAddr() net.Addr {
-	return v.conn.RemoteAddr()
+	conn := v.underlyingConn()
+	if conn == nil {
+		return nil
+	}
+	return conn.RemoteAddr()
 }
 
 func (v *Connection) SetDeadline(t time.Time) error {
-	return v.conn.SetDeadline(t)
+	conn := v.underlyingConn()
+	if conn == nil {
+		return nil
+	}
+	return conn.SetDeadline(t)
 }
 
 func (v *Connection) SetReadDeadline(t time.Time) error {
-	return v.conn.SetReadDeadline(t)
+	conn := v.underlyingConn()
+	if conn == nil {
+		return nil
+	}
+	return conn.SetReadDeadline(t)
 }
 
 func (v *Connection) SetWriteDeadline(t time.Time) error {
-	return v.conn.SetWriteDeadline(t)
+	conn := v.underlyingConn()
+	if conn == nil {
+		return nil
+	}
+	return conn.SetWriteDeadline(t)
 }
 
 func (v *Connection) SetReusable(reusable bool) {
+	if v == nil {
+		return
+	}
 	v.reusable = reusable
 }
 
 func (v *Connection) Reusable() bool {
+	if v == nil {
+		return false
+	}
 	return v.config.IsConnectionReuse() && v.reusable
 }
 
 func (v *Connection) SysFd() (int, error) {
-	return internal.GetSysFd(v.conn)
+	conn := v.underlyingConn()
+	if conn == nil {
+		return 0, io.ErrClosedPipe
+	}
+	return internal.GetSysFd(conn)
+}
+
+func (v *Connection) underlyingConn() net.Conn {
+	if v == nil {
+		return nil
+	}
+
+	v.RLock()
+	defer v.RUnlock()
+
+	return v.conn
 }