req2packet.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441
  1. package packetconn
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/rand"
  6. "io"
  7. "sync"
  8. "time"
  9. "github.com/golang-collections/go-datastructures/queue"
  10. "github.com/v2fly/v2ray-core/v5/transport/internet/request"
  11. )
  12. func newRequestToPacketConnClient(ctx context.Context, config *ClientConfig) (*requestToPacketConnClient, error) { //nolint: unparam
  13. return &requestToPacketConnClient{ctx: ctx, config: config}, nil
  14. }
  15. type requestToPacketConnClient struct {
  16. assembly request.TransportClientAssembly
  17. ctx context.Context
  18. config *ClientConfig
  19. }
  20. func (r *requestToPacketConnClient) OnTransportClientAssemblyReady(assembly request.TransportClientAssembly) {
  21. r.assembly = assembly
  22. }
  23. func (r *requestToPacketConnClient) Dial() (io.ReadWriteCloser, error) {
  24. sessionID := make([]byte, 16)
  25. _, err := rand.Read(sessionID)
  26. if err != nil {
  27. return nil, err
  28. }
  29. ctxWithCancel, cancel := context.WithCancel(r.ctx)
  30. clientSess := &requestToPacketConnClientSession{
  31. sessionID: sessionID,
  32. currentPollingInterval: int(r.config.PollingIntervalInitial),
  33. maxRequestSize: int(r.config.MaxRequestSize),
  34. maxWriteDelay: int(r.config.MaxWriteDelay),
  35. assembly: r.assembly,
  36. writerChan: make(chan []byte, 256),
  37. readerChan: make(chan []byte, 256),
  38. ctx: ctxWithCancel,
  39. finish: cancel,
  40. }
  41. go clientSess.keepRunning()
  42. return clientSess, nil
  43. }
  44. type requestToPacketConnClientSession struct {
  45. sessionID []byte
  46. currentPollingInterval int
  47. maxRequestSize int
  48. maxWriteDelay int
  49. assembly request.TransportClientAssembly
  50. writerChan chan []byte
  51. readerChan chan []byte
  52. ctx context.Context
  53. finish func()
  54. nextWrite []byte
  55. }
  56. func (r *requestToPacketConnClientSession) keepRunning() {
  57. for r.ctx.Err() == nil {
  58. r.runOnce()
  59. }
  60. }
  61. func (r *requestToPacketConnClientSession) runOnce() {
  62. requestBody := bytes.NewBuffer(nil)
  63. waitTimer := time.NewTimer(time.Duration(r.currentPollingInterval) * time.Millisecond)
  64. var seenPacket bool
  65. packetBundler := NewPacketBundle()
  66. copyFromChan:
  67. for {
  68. select {
  69. case <-r.ctx.Done():
  70. return
  71. case <-waitTimer.C:
  72. break copyFromChan
  73. case packet := <-r.writerChan:
  74. if !seenPacket {
  75. seenPacket = true
  76. waitTimer.Stop()
  77. waitTimer.Reset(time.Duration(r.maxWriteDelay) * time.Millisecond)
  78. }
  79. sizeOffset := packetBundler.Overhead() + len(packet)
  80. if requestBody.Len()+sizeOffset > r.maxRequestSize {
  81. r.nextWrite = packet
  82. break copyFromChan
  83. }
  84. err := packetBundler.WriteToBundle(packet, requestBody)
  85. if err != nil {
  86. newError("failed to write to bundle").Base(err).WriteToLog()
  87. }
  88. }
  89. }
  90. waitTimer.Stop()
  91. go func() {
  92. reader, writer := io.Pipe()
  93. defer writer.Close()
  94. streamingRespOpt := &pipedStreamingRespOption{writer}
  95. go func() {
  96. for {
  97. if packet, err := packetBundler.ReadFromBundle(reader); err == nil {
  98. r.readerChan <- packet
  99. } else {
  100. return
  101. }
  102. }
  103. }()
  104. resp, err := r.assembly.Tripper().RoundTrip(r.ctx, request.Request{Data: requestBody.Bytes(), ConnectionTag: r.sessionID},
  105. streamingRespOpt)
  106. if err != nil {
  107. newError("failed to roundtrip").Base(err).WriteToLog()
  108. if r.ctx.Err() != nil {
  109. return
  110. }
  111. }
  112. if resp.Data != nil && len(resp.Data) != 0 {
  113. respReader := bytes.NewReader(resp.Data)
  114. for respReader.Len() != 0 {
  115. packet, err := packetBundler.ReadFromBundle(respReader)
  116. if err != nil {
  117. newError("failed to read from bundle").Base(err).WriteToLog()
  118. if r.ctx.Err() != nil {
  119. return
  120. }
  121. }
  122. r.readerChan <- packet
  123. }
  124. }
  125. }()
  126. }
  127. type pipedStreamingRespOption struct {
  128. writer *io.PipeWriter
  129. }
  130. func (p *pipedStreamingRespOption) RoundTripperOption() {
  131. }
  132. func (p *pipedStreamingRespOption) GetResponseWriter() io.Writer {
  133. return p.writer
  134. }
  135. func (r *requestToPacketConnClientSession) Write(p []byte) (n int, err error) {
  136. buf := make([]byte, len(p))
  137. copy(buf, p)
  138. select {
  139. case <-r.ctx.Done():
  140. return 0, r.ctx.Err()
  141. case r.writerChan <- buf:
  142. return len(p), nil
  143. }
  144. }
  145. func (r *requestToPacketConnClientSession) Read(p []byte) (n int, err error) {
  146. select {
  147. case <-r.ctx.Done():
  148. return 0, r.ctx.Err()
  149. case buf := <-r.readerChan:
  150. copy(p, buf)
  151. return len(buf), nil
  152. }
  153. }
  154. func (r *requestToPacketConnClientSession) Close() error {
  155. r.finish()
  156. return nil
  157. }
  158. func newRequestToPacketConnServer(ctx context.Context, config *ServerConfig) *requestToPacketConnServer {
  159. return &requestToPacketConnServer{
  160. sessionMap: sync.Map{},
  161. ctx: ctx,
  162. config: config,
  163. }
  164. }
  165. type requestToPacketConnServer struct {
  166. packetSessionReceiver request.SessionReceiver
  167. sessionMap sync.Map
  168. ctx context.Context
  169. config *ServerConfig
  170. }
  171. func (r *requestToPacketConnServer) onSessionReceiverReady(sessrecv request.SessionReceiver) {
  172. r.packetSessionReceiver = sessrecv
  173. }
  174. func (r *requestToPacketConnServer) OnRoundTrip(ctx context.Context, req request.Request,
  175. opts ...request.RoundTripperOption,
  176. ) (resp request.Response, err error) {
  177. SessionID := req.ConnectionTag
  178. if SessionID == nil {
  179. return request.Response{}, newError("nil session id")
  180. }
  181. sessionID := string(SessionID)
  182. var session *requestToPacketConnServerSession
  183. sessionAny, found := r.sessionMap.Load(sessionID)
  184. if found {
  185. var ok bool
  186. session, ok = sessionAny.(*requestToPacketConnServerSession)
  187. if !ok {
  188. return request.Response{}, newError("failed to cast session")
  189. }
  190. }
  191. if !found {
  192. ctxWithFinish, finish := context.WithCancel(ctx)
  193. session = &requestToPacketConnServerSession{
  194. SessionID: SessionID,
  195. writingConnectionQueue: queue.New(64),
  196. writerChan: make(chan []byte, int(r.config.PacketWritingBuffer)),
  197. readerChan: make(chan []byte, 256),
  198. ctx: ctxWithFinish,
  199. finish: finish,
  200. server: r,
  201. maxWriteSize: int(r.config.MaxWriteSize),
  202. maxWriteDuration: int(r.config.MaxWriteDurationMs),
  203. maxSimultaneousWriteConnection: int(r.config.MaxSimultaneousWriteConnection),
  204. }
  205. _, loaded := r.sessionMap.LoadOrStore(sessionID, session)
  206. if !loaded {
  207. err = r.packetSessionReceiver.OnNewSession(ctx, session)
  208. }
  209. }
  210. if err != nil {
  211. return request.Response{}, err
  212. }
  213. return session.OnRoundTrip(ctx, req, opts...)
  214. }
  215. func (r *requestToPacketConnServer) removeSessionID(sessionID []byte) {
  216. r.sessionMap.Delete(string(sessionID))
  217. }
  218. type requestToPacketConnServerSession struct {
  219. SessionID []byte
  220. writingConnectionQueue *queue.Queue
  221. writerChan chan []byte
  222. readerChan chan []byte
  223. ctx context.Context
  224. finish func()
  225. server *requestToPacketConnServer
  226. maxWriteSize int
  227. maxWriteDuration int
  228. maxSimultaneousWriteConnection int
  229. }
  230. func (r *requestToPacketConnServerSession) Read(p []byte) (n int, err error) {
  231. select {
  232. case <-r.ctx.Done():
  233. return 0, r.ctx.Err()
  234. case buf := <-r.readerChan:
  235. copy(p, buf)
  236. return len(buf), nil
  237. }
  238. }
  239. var debugStats struct {
  240. packetWritten int
  241. packetDropped int
  242. }
  243. /*
  244. var _ = func() bool {
  245. go func() {
  246. for {
  247. time.Sleep(time.Second)
  248. newError("packet written: ", debugStats.packetWritten, " packet dropped: ", debugStats.packetDropped).WriteToLog()
  249. }
  250. }()
  251. return true
  252. }()*/
  253. func (r *requestToPacketConnServerSession) Write(p []byte) (n int, err error) {
  254. buf := make([]byte, len(p))
  255. copy(buf, p)
  256. select {
  257. case <-r.ctx.Done():
  258. return 0, r.ctx.Err()
  259. case r.writerChan <- buf:
  260. debugStats.packetWritten++
  261. return len(p), nil
  262. default: // This write will be called from global listener's routine, it must not block
  263. debugStats.packetDropped++
  264. return len(p), nil
  265. }
  266. }
  267. func (r *requestToPacketConnServerSession) Close() error {
  268. r.server.removeSessionID(r.SessionID)
  269. r.finish()
  270. return nil
  271. }
  272. type writingConnection struct {
  273. focus func()
  274. finish func()
  275. finishCtx context.Context
  276. }
  277. func (r *requestToPacketConnServerSession) OnRoundTrip(ctx context.Context, req request.Request,
  278. opts ...request.RoundTripperOption,
  279. ) (resp request.Response, err error) {
  280. // TODO: fix connection graceful close
  281. var streamingRespWriter io.Writer
  282. var streamingRespWriterFlusher request.OptionSupportsStreamingResponseExtensionFlusher
  283. for _, opt := range opts {
  284. if streamingRespOpt, ok := opt.(request.OptionSupportsStreamingResponse); ok {
  285. streamingRespWriter = streamingRespOpt.GetResponseWriter()
  286. if streamingRespWriterFlusherOpt, ok := opt.(request.OptionSupportsStreamingResponseExtensionFlusher); ok {
  287. streamingRespWriterFlusher = streamingRespWriterFlusherOpt
  288. }
  289. }
  290. }
  291. packetBundler := NewPacketBundle()
  292. reqReader := bytes.NewReader(req.Data)
  293. for reqReader.Len() != 0 {
  294. packet, err := packetBundler.ReadFromBundle(reqReader)
  295. if err != nil {
  296. err = newError("failed to read from bundle").Base(err)
  297. return request.Response{}, err
  298. }
  299. r.readerChan <- packet
  300. }
  301. onFocusCtx, focus := context.WithCancel(ctx)
  302. onFinishCtx, finish := context.WithCancel(ctx)
  303. r.writingConnectionQueue.Put(&writingConnection{
  304. focus: focus,
  305. finish: finish,
  306. finishCtx: onFinishCtx,
  307. })
  308. amountToEnd := r.writingConnectionQueue.Len() - int64(r.maxSimultaneousWriteConnection)
  309. for amountToEnd > 0 {
  310. {
  311. _, _ = r.writingConnectionQueue.TakeUntil(func(i interface{}) bool {
  312. i.(*writingConnection).finish()
  313. amountToEnd--
  314. return amountToEnd > 0
  315. })
  316. }
  317. }
  318. {
  319. _, _ = r.writingConnectionQueue.TakeUntil(func(i interface{}) bool {
  320. i.(*writingConnection).focus()
  321. return false
  322. })
  323. }
  324. bufferedRespWriter := bytes.NewBuffer(nil)
  325. finishWrite := func() {
  326. resp.Data = bufferedRespWriter.Bytes()
  327. {
  328. _, _ = r.writingConnectionQueue.TakeUntil(func(i interface{}) bool {
  329. i.(*writingConnection).focus()
  330. if i.(*writingConnection).finishCtx.Err() != nil { //nolint: gosimple
  331. return true
  332. }
  333. return false
  334. })
  335. }
  336. }
  337. progressiveSend := streamingRespWriter != nil
  338. var respWriter io.Writer
  339. if progressiveSend {
  340. respWriter = streamingRespWriter
  341. } else {
  342. respWriter = bufferedRespWriter
  343. }
  344. var bytesSent int
  345. onReceivePacket := func(packet []byte) bool {
  346. bytesSent += len(packet) + packetBundler.Overhead()
  347. err := packetBundler.WriteToBundle(packet, respWriter)
  348. if err != nil {
  349. newError("failed to write to bundle").Base(err).WriteToLog()
  350. }
  351. if streamingRespWriterFlusher != nil {
  352. streamingRespWriterFlusher.Flush()
  353. }
  354. if bytesSent >= r.maxWriteSize {
  355. return false
  356. }
  357. return true
  358. }
  359. finishWriteTimer := time.NewTimer(time.Millisecond * time.Duration(r.maxWriteDuration))
  360. if !progressiveSend {
  361. select {
  362. case <-onFocusCtx.Done():
  363. case <-onFinishCtx.Done():
  364. finishWrite()
  365. return resp, nil
  366. }
  367. } else {
  368. select {
  369. case <-onFinishCtx.Done():
  370. finishWrite()
  371. return resp, nil
  372. default:
  373. }
  374. }
  375. firstRead := true
  376. for {
  377. select {
  378. case <-onFinishCtx.Done():
  379. finishWrite()
  380. finishWriteTimer.Stop()
  381. return resp, nil
  382. case packet := <-r.writerChan:
  383. keepSending := onReceivePacket(packet)
  384. if firstRead {
  385. firstRead = false
  386. }
  387. if !keepSending {
  388. finishWrite()
  389. finishWriteTimer.Stop()
  390. return resp, nil
  391. }
  392. case <-finishWriteTimer.C:
  393. finishWrite()
  394. return resp, nil
  395. }
  396. }
  397. }