pool.go 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. package internal
  2. import (
  3. "net"
  4. "sync"
  5. "time"
  6. v2net "v2ray.com/core/common/net"
  7. "v2ray.com/core/common/signal"
  8. )
  9. type ConnectionId struct {
  10. Local v2net.Address
  11. Remote v2net.Address
  12. RemotePort v2net.Port
  13. }
  14. func NewConnectionId(source v2net.Address, dest v2net.Destination) ConnectionId {
  15. return ConnectionId{
  16. Local: source,
  17. Remote: dest.Address,
  18. RemotePort: dest.Port,
  19. }
  20. }
  21. type ExpiringConnection struct {
  22. conn net.Conn
  23. expire time.Time
  24. }
  25. func (o *ExpiringConnection) Expired() bool {
  26. return o.expire.Before(time.Now())
  27. }
  28. type Pool struct {
  29. sync.Mutex
  30. connsByDest map[ConnectionId][]*ExpiringConnection
  31. cleanupOnce signal.Once
  32. }
  33. func NewConnectionPool() *Pool {
  34. return &Pool{
  35. connsByDest: make(map[ConnectionId][]*ExpiringConnection),
  36. }
  37. }
  38. func (o *Pool) Get(id ConnectionId) net.Conn {
  39. o.Lock()
  40. defer o.Unlock()
  41. list, found := o.connsByDest[id]
  42. if !found {
  43. return nil
  44. }
  45. connIdx := -1
  46. for idx, conn := range list {
  47. if !conn.Expired() {
  48. connIdx = idx
  49. break
  50. }
  51. }
  52. if connIdx == -1 {
  53. return nil
  54. }
  55. listLen := len(list)
  56. conn := list[connIdx]
  57. if connIdx != listLen-1 {
  58. list[connIdx] = list[listLen-1]
  59. }
  60. list = list[:listLen-1]
  61. o.connsByDest[id] = list
  62. return conn.conn
  63. }
  64. func (o *Pool) Cleanup() {
  65. defer o.cleanupOnce.Reset()
  66. for len(o.connsByDest) > 0 {
  67. time.Sleep(time.Second * 5)
  68. expiredConns := make([]net.Conn, 0, 16)
  69. o.Lock()
  70. for dest, list := range o.connsByDest {
  71. validConns := make([]*ExpiringConnection, 0, len(list))
  72. for _, conn := range list {
  73. if conn.Expired() {
  74. expiredConns = append(expiredConns, conn.conn)
  75. } else {
  76. validConns = append(validConns, conn)
  77. }
  78. }
  79. if len(validConns) != len(list) {
  80. o.connsByDest[dest] = validConns
  81. }
  82. }
  83. o.Unlock()
  84. for _, conn := range expiredConns {
  85. conn.Close()
  86. }
  87. }
  88. }
  89. func (o *Pool) Put(id ConnectionId, conn net.Conn) {
  90. expiringConn := &ExpiringConnection{
  91. conn: conn,
  92. expire: time.Now().Add(time.Second * 4),
  93. }
  94. o.Lock()
  95. defer o.Unlock()
  96. list, found := o.connsByDest[id]
  97. if !found {
  98. list = []*ExpiringConnection{expiringConn}
  99. } else {
  100. list = append(list, expiringConn)
  101. }
  102. o.connsByDest[id] = list
  103. o.cleanupOnce.Do(func() {
  104. go o.Cleanup()
  105. })
  106. }