client_session.go 7.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292
  1. package shadowsocks2022
  2. import (
  3. "context"
  4. "crypto/rand"
  5. "io"
  6. gonet "net"
  7. "sync"
  8. "time"
  9. "github.com/v2fly/v2ray-core/v5/common/buf"
  10. "github.com/v2fly/v2ray-core/v5/common/net"
  11. "github.com/v2fly/v2ray-core/v5/transport/internet"
  12. "github.com/pion/transport/v2/replaydetector"
  13. )
  14. func NewClientUDPSession(ctx context.Context, conn io.ReadWriteCloser, packetProcessor UDPClientPacketProcessor) *ClientUDPSession {
  15. session := &ClientUDPSession{
  16. locker: &sync.RWMutex{},
  17. conn: conn,
  18. packetProcessor: packetProcessor,
  19. sessionMap: make(map[string]*ClientUDPSessionConn),
  20. sessionMapAlias: make(map[string]string),
  21. }
  22. session.ctx, session.finish = context.WithCancel(ctx)
  23. go session.KeepReading()
  24. return session
  25. }
  26. type ClientUDPSession struct {
  27. locker *sync.RWMutex
  28. conn io.ReadWriteCloser
  29. packetProcessor UDPClientPacketProcessor
  30. sessionMap map[string]*ClientUDPSessionConn
  31. sessionMapAlias map[string]string
  32. ctx context.Context
  33. finish func()
  34. }
  35. func (c *ClientUDPSession) GetCachedState(sessionID string) UDPClientPacketProcessorCachedState {
  36. c.locker.RLock()
  37. defer c.locker.RUnlock()
  38. state, ok := c.sessionMap[sessionID]
  39. if !ok {
  40. return nil
  41. }
  42. return state.cachedProcessorState
  43. }
  44. func (c *ClientUDPSession) GetCachedServerState(serverSessionID string) UDPClientPacketProcessorCachedState {
  45. c.locker.RLock()
  46. defer c.locker.RUnlock()
  47. clientSessionID := c.getCachedStateAlias(serverSessionID)
  48. if clientSessionID == "" {
  49. return nil
  50. }
  51. state, ok := c.sessionMap[clientSessionID]
  52. if !ok {
  53. return nil
  54. }
  55. if serverState, ok := state.trackedServerSessionID[serverSessionID]; !ok {
  56. return nil
  57. } else {
  58. return serverState.cachedRecvProcessorState
  59. }
  60. }
  61. func (c *ClientUDPSession) getCachedStateAlias(serverSessionID string) string {
  62. state, ok := c.sessionMapAlias[serverSessionID]
  63. if !ok {
  64. return ""
  65. }
  66. return state
  67. }
  68. func (c *ClientUDPSession) PutCachedState(sessionID string, cache UDPClientPacketProcessorCachedState) {
  69. c.locker.RLock()
  70. defer c.locker.RUnlock()
  71. state, ok := c.sessionMap[sessionID]
  72. if !ok {
  73. return
  74. }
  75. state.cachedProcessorState = cache
  76. }
  77. func (c *ClientUDPSession) PutCachedServerState(serverSessionID string, cache UDPClientPacketProcessorCachedState) {
  78. c.locker.RLock()
  79. defer c.locker.RUnlock()
  80. clientSessionID := c.getCachedStateAlias(serverSessionID)
  81. if clientSessionID == "" {
  82. return
  83. }
  84. state, ok := c.sessionMap[clientSessionID]
  85. if !ok {
  86. return
  87. }
  88. if serverState, ok := state.trackedServerSessionID[serverSessionID]; ok {
  89. serverState.cachedRecvProcessorState = cache
  90. return
  91. }
  92. }
  93. func (c *ClientUDPSession) Close() error {
  94. c.finish()
  95. return c.conn.Close()
  96. }
  97. func (c *ClientUDPSession) WriteUDPRequest(request *UDPRequest) error {
  98. buffer := buf.New()
  99. defer buffer.Release()
  100. err := c.packetProcessor.EncodeUDPRequest(request, buffer, c)
  101. if request.Payload != nil {
  102. request.Payload.Release()
  103. }
  104. if err != nil {
  105. return newError("unable to encode udp request").Base(err)
  106. }
  107. _, err = c.conn.Write(buffer.Bytes())
  108. if err != nil {
  109. return newError("unable to write to conn").Base(err)
  110. }
  111. return nil
  112. }
  113. func (c *ClientUDPSession) KeepReading() {
  114. for c.ctx.Err() == nil {
  115. udpResp := &UDPResponse{}
  116. buffer := make([]byte, 1600)
  117. n, err := c.conn.Read(buffer)
  118. if err != nil {
  119. newError("unable to read from conn").Base(err).WriteToLog()
  120. return
  121. }
  122. if n != 0 {
  123. err := c.packetProcessor.DecodeUDPResp(buffer[:n], udpResp, c)
  124. if err != nil {
  125. newError("unable to decode udp response").Base(err).WriteToLog()
  126. continue
  127. }
  128. {
  129. timeDifference := int64(udpResp.TimeStamp) - time.Now().Unix()
  130. if timeDifference < -30 || timeDifference > 30 {
  131. newError("udp packet timestamp difference too large, packet discarded, time diff = ", timeDifference).WriteToLog()
  132. continue
  133. }
  134. }
  135. c.locker.RLock()
  136. session, ok := c.sessionMap[string(udpResp.ClientSessionID[:])]
  137. c.locker.RUnlock()
  138. if ok {
  139. select {
  140. case session.readChan <- udpResp:
  141. default:
  142. }
  143. } else {
  144. newError("misbehaving server: unknown client session ID").Base(err).WriteToLog()
  145. }
  146. }
  147. }
  148. }
  149. func (c *ClientUDPSession) NewSessionConn() (internet.AbstractPacketConn, error) {
  150. sessionID := make([]byte, 8)
  151. _, err := rand.Read(sessionID)
  152. if err != nil {
  153. return nil, newError("unable to generate session id").Base(err)
  154. }
  155. connctx, connfinish := context.WithCancel(c.ctx)
  156. sessionConn := &ClientUDPSessionConn{
  157. sessionID: string(sessionID),
  158. readChan: make(chan *UDPResponse, 128),
  159. parent: c,
  160. ctx: connctx,
  161. finish: connfinish,
  162. nextWritePacketID: 0,
  163. trackedServerSessionID: make(map[string]*ClientUDPSessionServerTracker),
  164. }
  165. c.locker.Lock()
  166. c.sessionMap[sessionConn.sessionID] = sessionConn
  167. c.locker.Unlock()
  168. return sessionConn, nil
  169. }
  170. type ClientUDPSessionServerTracker struct {
  171. cachedRecvProcessorState UDPClientPacketProcessorCachedState
  172. rxReplayDetector replaydetector.ReplayDetector
  173. lastSeen time.Time
  174. }
  175. type ClientUDPSessionConn struct {
  176. sessionID string
  177. readChan chan *UDPResponse
  178. parent *ClientUDPSession
  179. nextWritePacketID uint64
  180. trackedServerSessionID map[string]*ClientUDPSessionServerTracker
  181. cachedProcessorState UDPClientPacketProcessorCachedState
  182. ctx context.Context
  183. finish func()
  184. }
  185. func (c *ClientUDPSessionConn) Close() error {
  186. c.parent.locker.Lock()
  187. delete(c.parent.sessionMap, c.sessionID)
  188. for k := range c.trackedServerSessionID {
  189. delete(c.parent.sessionMapAlias, k)
  190. }
  191. c.parent.locker.Unlock()
  192. c.finish()
  193. return nil
  194. }
  195. func (c *ClientUDPSessionConn) WriteTo(p []byte, addr gonet.Addr) (n int, err error) {
  196. thisPacketID := c.nextWritePacketID
  197. c.nextWritePacketID += 1
  198. req := &UDPRequest{
  199. SessionID: [8]byte{},
  200. PacketID: thisPacketID,
  201. TimeStamp: uint64(time.Now().Unix()),
  202. Address: net.IPAddress(addr.(*gonet.UDPAddr).IP),
  203. Port: addr.(*net.UDPAddr).Port,
  204. Payload: nil,
  205. }
  206. copy(req.SessionID[:], c.sessionID)
  207. req.Payload = buf.New()
  208. req.Payload.Write(p)
  209. err = c.parent.WriteUDPRequest(req)
  210. if err != nil {
  211. return 0, newError("unable to write to parent session").Base(err)
  212. }
  213. return len(p), nil
  214. }
  215. func (c *ClientUDPSessionConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
  216. for {
  217. select {
  218. case <-c.ctx.Done():
  219. return 0, nil, io.EOF
  220. case resp := <-c.readChan:
  221. n = copy(p, resp.Payload.Bytes())
  222. resp.Payload.Release()
  223. var trackedState *ClientUDPSessionServerTracker
  224. if trackedStateReceived, ok := c.trackedServerSessionID[string(resp.SessionID[:])]; !ok {
  225. for key, value := range c.trackedServerSessionID {
  226. if time.Since(value.lastSeen) > 65*time.Second {
  227. delete(c.trackedServerSessionID, key)
  228. }
  229. }
  230. state := &ClientUDPSessionServerTracker{
  231. rxReplayDetector: replaydetector.New(1024, ^uint64(0)),
  232. }
  233. c.trackedServerSessionID[string(resp.SessionID[:])] = state
  234. c.parent.locker.RLock()
  235. c.parent.sessionMapAlias[string(resp.SessionID[:])] = string(resp.ClientSessionID[:])
  236. c.parent.locker.RUnlock()
  237. trackedState = state
  238. } else {
  239. trackedState = trackedStateReceived
  240. }
  241. if accept, ok := trackedState.rxReplayDetector.Check(resp.PacketID); ok {
  242. accept()
  243. } else {
  244. newError("misbehaving server: replayed packet").Base(err).WriteToLog()
  245. continue
  246. }
  247. trackedState.lastSeen = time.Now()
  248. addr = &net.UDPAddr{IP: resp.Address.IP(), Port: resp.Port}
  249. }
  250. return n, addr, nil
  251. }
  252. }