Darien Raymond 9 роки тому
батько
коміт
22379e5a6b

+ 121 - 0
transport/internet/internal/pool.go

@@ -0,0 +1,121 @@
+package internal
+
+import (
+	"net"
+	"sync"
+	"time"
+	v2net "v2ray.com/core/common/net"
+	"v2ray.com/core/common/signal"
+)
+
+type ConnectionId struct {
+	Local      v2net.Address
+	Remote     v2net.Address
+	RemotePort v2net.Port
+}
+
+func NewConnectionId(source v2net.Address, dest v2net.Destination) ConnectionId {
+	return ConnectionId{
+		Local:      source,
+		Remote:     dest.Address,
+		RemotePort: dest.Port,
+	}
+}
+
+type ExpiringConnection struct {
+	conn   net.Conn
+	expire time.Time
+}
+
+func (o *ExpiringConnection) Expired() bool {
+	return o.expire.Before(time.Now())
+}
+
+type Pool struct {
+	sync.Mutex
+	connsByDest map[ConnectionId][]*ExpiringConnection
+	cleanupOnce signal.Once
+}
+
+func NewConnectionPool() *Pool {
+	return &Pool{
+		connsByDest: make(map[ConnectionId][]*ExpiringConnection),
+	}
+}
+
+func (o *Pool) Get(id ConnectionId) net.Conn {
+	o.Lock()
+	defer o.Unlock()
+
+	list, found := o.connsByDest[id]
+	if !found {
+		return nil
+	}
+	connIdx := -1
+	for idx, conn := range list {
+		if !conn.Expired() {
+			connIdx = idx
+			break
+		}
+	}
+	if connIdx == -1 {
+		return nil
+	}
+	listLen := len(list)
+	conn := list[connIdx]
+	if connIdx != listLen-1 {
+		list[connIdx] = list[listLen-1]
+	}
+	list = list[:listLen-1]
+	o.connsByDest[id] = list
+	return conn.conn
+}
+
+func (o *Pool) Cleanup() {
+	defer o.cleanupOnce.Reset()
+
+	for len(o.connsByDest) > 0 {
+		time.Sleep(time.Second * 5)
+		expiredConns := make([]net.Conn, 0, 16)
+		o.Lock()
+		for dest, list := range o.connsByDest {
+			validConns := make([]*ExpiringConnection, 0, len(list))
+			for _, conn := range list {
+				if conn.Expired() {
+					expiredConns = append(expiredConns, conn.conn)
+				} else {
+					validConns = append(validConns, conn)
+				}
+			}
+			if len(validConns) != len(list) {
+				o.connsByDest[dest] = validConns
+			}
+		}
+		o.Unlock()
+		for _, conn := range expiredConns {
+			conn.Close()
+		}
+	}
+}
+
+func (o *Pool) Put(id ConnectionId, conn net.Conn) {
+	expiringConn := &ExpiringConnection{
+		conn:   conn,
+		expire: time.Now().Add(time.Second * 4),
+	}
+
+	o.Lock()
+	defer o.Unlock()
+
+	list, found := o.connsByDest[id]
+	if !found {
+		list = []*ExpiringConnection{expiringConn}
+	} else {
+		list = append(list, expiringConn)
+	}
+	o.connsByDest[id] = list
+
+	o.cleanupOnce.Do(func() {
+		go o.Cleanup()
+	})
+}

+ 72 - 0
transport/internet/internal/pool_test.go

@@ -0,0 +1,72 @@
+package internal_test
+
+import (
+	"net"
+	"testing"
+	"time"
+	v2net "v2ray.com/core/common/net"
+	"v2ray.com/core/testing/assert"
+	. "v2ray.com/core/transport/internet/internal"
+)
+
+type TestConnection struct {
+	id     string
+	closed bool
+}
+
+func (o *TestConnection) Read([]byte) (int, error) {
+	return 0, nil
+}
+
+func (o *TestConnection) Write([]byte) (int, error) {
+	return 0, nil
+}
+
+func (o *TestConnection) Close() error {
+	o.closed = true
+	return nil
+}
+
+func (o *TestConnection) LocalAddr() net.Addr {
+	return nil
+}
+
+func (o *TestConnection) RemoteAddr() net.Addr {
+	return nil
+}
+
+func (o *TestConnection) SetDeadline(t time.Time) error {
+	return nil
+}
+
+func (o *TestConnection) SetReadDeadline(t time.Time) error {
+	return nil
+}
+
+func (o *TestConnection) SetWriteDeadline(t time.Time) error {
+	return nil
+}
+
+func TestConnectionCache(t *testing.T) {
+	assert := assert.On(t)
+
+	pool := NewConnectionPool()
+	conn := pool.Get(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80))))
+	assert.Pointer(conn).IsNil()
+
+	pool.Put(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80))), &TestConnection{id: "test"})
+	conn = pool.Get(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80))))
+	assert.String(conn.(*TestConnection).id).Equals("test")
+}
+
+func TestConnectionRecycle(t *testing.T) {
+	assert := assert.On(t)
+
+	pool := NewConnectionPool()
+	c := &TestConnection{id: "test"}
+	pool.Put(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80))), c)
+	time.Sleep(6 * time.Second)
+	assert.Bool(c.closed).IsTrue()
+	conn := pool.Get(NewConnectionId(v2net.LocalHostIP, v2net.TCPDestination(v2net.LocalHostIP, v2net.Port(80))))
+	assert.Pointer(conn).IsNil()
+}

+ 5 - 5
transport/internet/tcp/connection.go

@@ -9,7 +9,7 @@ import (
 )
 
 type ConnectionManager interface {
-	Recycle(string, net.Conn)
+	Put(internal.ConnectionId, net.Conn)
 }
 
 type RawConnection struct {
@@ -27,16 +27,16 @@ func (this *RawConnection) SysFd() (int, error) {
 }
 
 type Connection struct {
-	dest     string
+	id       internal.ConnectionId
 	conn     net.Conn
 	listener ConnectionManager
 	reusable bool
 	config   *Config
 }
 
-func NewConnection(dest string, conn net.Conn, manager ConnectionManager, config *Config) *Connection {
+func NewConnection(id internal.ConnectionId, conn net.Conn, manager ConnectionManager, config *Config) *Connection {
 	return &Connection{
-		dest:     dest,
+		id:       id,
 		conn:     conn,
 		listener: manager,
 		reusable: config.ConnectionReuse.IsEnabled(),
@@ -64,7 +64,7 @@ func (this *Connection) Close() error {
 		return io.ErrClosedPipe
 	}
 	if this.Reusable() {
-		this.listener.Recycle(this.dest, this.conn)
+		this.listener.Put(this.id, this.conn)
 		return nil
 	}
 	err := this.conn.Close()

+ 0 - 112
transport/internet/tcp/connection_cache.go

@@ -1,112 +0,0 @@
-package tcp
-
-import (
-	"net"
-	"sync"
-	"time"
-
-	"v2ray.com/core/common/signal"
-)
-
-type AwaitingConnection struct {
-	conn   net.Conn
-	expire time.Time
-}
-
-func (this *AwaitingConnection) Expired() bool {
-	return this.expire.Before(time.Now())
-}
-
-type ConnectionCache struct {
-	sync.Mutex
-	cache       map[string][]*AwaitingConnection
-	cleanupOnce signal.Once
-}
-
-func NewConnectionCache() *ConnectionCache {
-	return &ConnectionCache{
-		cache: make(map[string][]*AwaitingConnection),
-	}
-}
-
-func (this *ConnectionCache) Cleanup() {
-	defer this.cleanupOnce.Reset()
-
-	for len(this.cache) > 0 {
-		time.Sleep(time.Second * 4)
-		this.Lock()
-		for key, value := range this.cache {
-			size := len(value)
-			changed := false
-			for i := 0; i < size; {
-				if value[i].Expired() {
-					value[i].conn.Close()
-					value[i] = value[size-1]
-					size--
-					changed = true
-				} else {
-					i++
-				}
-			}
-			if changed {
-				for i := size; i < len(value); i++ {
-					value[i] = nil
-				}
-				value = value[:size]
-				this.cache[key] = value
-			}
-		}
-		this.Unlock()
-	}
-}
-
-func (this *ConnectionCache) Recycle(dest string, conn net.Conn) {
-	this.Lock()
-	defer this.Unlock()
-
-	aconn := &AwaitingConnection{
-		conn:   conn,
-		expire: time.Now().Add(time.Second * 4),
-	}
-
-	var list []*AwaitingConnection
-	if v, found := this.cache[dest]; found {
-		v = append(v, aconn)
-		list = v
-	} else {
-		list = []*AwaitingConnection{aconn}
-	}
-	this.cache[dest] = list
-
-	go this.cleanupOnce.Do(this.Cleanup)
-}
-
-func FindFirstValid(list []*AwaitingConnection) int {
-	for idx, conn := range list {
-		if !conn.Expired() {
-			return idx
-		}
-		go conn.conn.Close()
-	}
-	return -1
-}
-
-func (this *ConnectionCache) Get(dest string) net.Conn {
-	this.Lock()
-	defer this.Unlock()
-
-	list, found := this.cache[dest]
-	if !found {
-		return nil
-	}
-
-	firstValid := FindFirstValid(list)
-	if firstValid == -1 {
-		delete(this.cache, dest)
-		return nil
-	}
-	res := list[firstValid].conn
-	list = list[firstValid+1:]
-	this.cache[dest] = list
-	return res
-}

+ 3 - 2
transport/internet/tcp/dialer.go

@@ -8,11 +8,12 @@ import (
 	"v2ray.com/core/common/log"
 	v2net "v2ray.com/core/common/net"
 	"v2ray.com/core/transport/internet"
+	"v2ray.com/core/transport/internet/internal"
 	v2tls "v2ray.com/core/transport/internet/tls"
 )
 
 var (
-	globalCache = NewConnectionCache()
+	globalCache = internal.NewConnectionPool()
 )
 
 func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOptions) (internet.Connection, error) {
@@ -26,7 +27,7 @@ func Dial(src v2net.Address, dest v2net.Destination, options internet.DialerOpti
 	}
 	tcpSettings := networkSettings.(*Config)
 
-	id := src.String() + "-" + dest.NetAddr()
+	id := internal.NewConnectionId(src, dest)
 	var conn net.Conn
 	if dest.Network == v2net.Network_TCP && tcpSettings.ConnectionReuse.IsEnabled() {
 		conn = globalCache.Get(id)

+ 3 - 2
transport/internet/tcp/hub.go

@@ -10,6 +10,7 @@ import (
 	"v2ray.com/core/common/log"
 	v2net "v2ray.com/core/common/net"
 	"v2ray.com/core/transport/internet"
+	"v2ray.com/core/transport/internet/internal"
 	v2tls "v2ray.com/core/transport/internet/tls"
 )
 
@@ -89,7 +90,7 @@ func (this *TCPListener) Accept() (internet.Connection, error) {
 				return nil, connErr.err
 			}
 			conn := connErr.conn
-			return NewConnection("", conn, this, this.config), nil
+			return NewConnection(internal.ConnectionId{}, conn, this, this.config), nil
 		case <-time.After(time.Second * 2):
 		}
 	}
@@ -125,7 +126,7 @@ func (this *TCPListener) KeepAccepting() {
 	}
 }
 
-func (this *TCPListener) Recycle(dest string, conn net.Conn) {
+func (this *TCPListener) Put(id internal.ConnectionId, conn net.Conn) {
 	this.Lock()
 	defer this.Unlock()
 	if !this.acccepting {