req2packet.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. package packetconn
  2. import (
  3. "bytes"
  4. "context"
  5. "crypto/rand"
  6. "io"
  7. "time"
  8. "github.com/golang-collections/go-datastructures/queue"
  9. "github.com/v2fly/v2ray-core/v5/transport/internet/request"
  10. )
  11. func newRequestToPacketConnClient(ctx context.Context, config *ClientConfig) (*requestToPacketConnClient, error) { //nolint: unparam
  12. return &requestToPacketConnClient{ctx: ctx, config: config}, nil
  13. }
  14. type requestToPacketConnClient struct {
  15. assembly request.TransportClientAssembly
  16. ctx context.Context
  17. config *ClientConfig
  18. }
  19. func (r *requestToPacketConnClient) OnTransportClientAssemblyReady(assembly request.TransportClientAssembly) {
  20. r.assembly = assembly
  21. }
  22. func (r *requestToPacketConnClient) Dial() (io.ReadWriteCloser, error) {
  23. sessionID := make([]byte, 16)
  24. _, err := rand.Read(sessionID)
  25. if err != nil {
  26. return nil, err
  27. }
  28. ctxWithCancel, cancel := context.WithCancel(r.ctx)
  29. clientSess := &requestToPacketConnClientSession{
  30. sessionID: sessionID,
  31. currentPollingInterval: int(r.config.PollingIntervalInitial),
  32. maxRequestSize: int(r.config.MaxRequestSize),
  33. maxWriteDelay: int(r.config.MaxWriteDelay),
  34. assembly: r.assembly,
  35. writerChan: make(chan []byte, 256),
  36. readerChan: make(chan []byte, 256),
  37. ctx: ctxWithCancel,
  38. finish: cancel,
  39. }
  40. go clientSess.keepRunning()
  41. return clientSess, nil
  42. }
  43. type requestToPacketConnClientSession struct {
  44. sessionID []byte
  45. currentPollingInterval int
  46. maxRequestSize int
  47. maxWriteDelay int
  48. assembly request.TransportClientAssembly
  49. writerChan chan []byte
  50. readerChan chan []byte
  51. ctx context.Context
  52. finish func()
  53. nextWrite []byte
  54. }
  55. func (r *requestToPacketConnClientSession) keepRunning() {
  56. for r.ctx.Err() == nil {
  57. r.runOnce()
  58. }
  59. }
  60. func (r *requestToPacketConnClientSession) runOnce() {
  61. requestBody := bytes.NewBuffer(nil)
  62. waitTimer := time.NewTimer(time.Duration(r.currentPollingInterval) * time.Millisecond)
  63. var seenPacket bool
  64. packetBundler := NewPacketBundle()
  65. copyFromChan:
  66. for {
  67. select {
  68. case <-r.ctx.Done():
  69. return
  70. case <-waitTimer.C:
  71. break copyFromChan
  72. case packet := <-r.writerChan:
  73. if !seenPacket {
  74. seenPacket = true
  75. waitTimer.Stop()
  76. waitTimer.Reset(time.Duration(r.maxWriteDelay) * time.Millisecond)
  77. }
  78. sizeOffset := packetBundler.Overhead() + len(packet)
  79. if requestBody.Len()+sizeOffset > r.maxRequestSize {
  80. r.nextWrite = packet
  81. break copyFromChan
  82. }
  83. err := packetBundler.WriteToBundle(packet, requestBody)
  84. if err != nil {
  85. newError("failed to write to bundle").Base(err).WriteToLog()
  86. }
  87. }
  88. }
  89. waitTimer.Stop()
  90. go func() {
  91. reader, writer := io.Pipe()
  92. defer writer.Close()
  93. streamingRespOpt := &pipedStreamingRespOption{writer}
  94. go func() {
  95. for {
  96. if packet, err := packetBundler.ReadFromBundle(reader); err == nil {
  97. r.readerChan <- packet
  98. } else {
  99. return
  100. }
  101. }
  102. }()
  103. resp, err := r.assembly.Tripper().RoundTrip(r.ctx, request.Request{Data: requestBody.Bytes(), ConnectionTag: r.sessionID},
  104. streamingRespOpt)
  105. if err != nil {
  106. newError("failed to roundtrip").Base(err).WriteToLog()
  107. if r.ctx.Err() != nil {
  108. return
  109. }
  110. }
  111. if resp.Data != nil && len(resp.Data) != 0 {
  112. respReader := bytes.NewReader(resp.Data)
  113. for respReader.Len() != 0 {
  114. packet, err := packetBundler.ReadFromBundle(respReader)
  115. if err != nil {
  116. newError("failed to read from bundle").Base(err).WriteToLog()
  117. if r.ctx.Err() != nil {
  118. return
  119. }
  120. }
  121. r.readerChan <- packet
  122. }
  123. }
  124. }()
  125. }
  126. type pipedStreamingRespOption struct {
  127. writer *io.PipeWriter
  128. }
  129. func (p *pipedStreamingRespOption) RoundTripperOption() {
  130. }
  131. func (p *pipedStreamingRespOption) GetResponseWriter() io.Writer {
  132. return p.writer
  133. }
  134. func (r *requestToPacketConnClientSession) Write(p []byte) (n int, err error) {
  135. buf := make([]byte, len(p))
  136. copy(buf, p)
  137. select {
  138. case <-r.ctx.Done():
  139. return 0, r.ctx.Err()
  140. case r.writerChan <- buf:
  141. return len(p), nil
  142. }
  143. }
  144. func (r *requestToPacketConnClientSession) Read(p []byte) (n int, err error) {
  145. select {
  146. case <-r.ctx.Done():
  147. return 0, r.ctx.Err()
  148. case buf := <-r.readerChan:
  149. copy(p, buf)
  150. return len(buf), nil
  151. }
  152. }
  153. func (r *requestToPacketConnClientSession) Close() error {
  154. r.finish()
  155. return nil
  156. }
  157. func newRequestToPacketConnServer(ctx context.Context, config *ServerConfig) *requestToPacketConnServer {
  158. return &requestToPacketConnServer{
  159. sessionMap: make(map[string]*requestToPacketConnServerSession),
  160. ctx: ctx,
  161. config: config,
  162. }
  163. }
  164. type requestToPacketConnServer struct {
  165. packetSessionReceiver request.SessionReceiver
  166. sessionMap map[string]*requestToPacketConnServerSession
  167. ctx context.Context
  168. config *ServerConfig
  169. }
  170. func (r *requestToPacketConnServer) onSessionReceiverReady(sessrecv request.SessionReceiver) {
  171. r.packetSessionReceiver = sessrecv
  172. }
  173. func (r *requestToPacketConnServer) OnRoundTrip(ctx context.Context, req request.Request,
  174. opts ...request.RoundTripperOption,
  175. ) (resp request.Response, err error) {
  176. SessionID := req.ConnectionTag
  177. if SessionID == nil {
  178. return request.Response{}, newError("nil session id")
  179. }
  180. sessionID := string(SessionID)
  181. session, found := r.sessionMap[sessionID]
  182. if !found {
  183. ctxWithFinish, finish := context.WithCancel(ctx)
  184. session = &requestToPacketConnServerSession{
  185. SessionID: SessionID,
  186. writingConnectionQueue: queue.New(64),
  187. writerChan: make(chan []byte, int(r.config.PacketWritingBuffer)),
  188. readerChan: make(chan []byte, 256),
  189. ctx: ctxWithFinish,
  190. finish: finish,
  191. server: r,
  192. maxWriteSize: int(r.config.MaxWriteSize),
  193. maxWriteDuration: int(r.config.MaxWriteDurationMs),
  194. maxSimultaneousWriteConnection: int(r.config.MaxSimultaneousWriteConnection),
  195. }
  196. r.sessionMap[sessionID] = session
  197. err = r.packetSessionReceiver.OnNewSession(ctx, session)
  198. }
  199. if err != nil {
  200. return request.Response{}, err
  201. }
  202. return session.OnRoundTrip(ctx, req, opts...)
  203. }
  204. func (r *requestToPacketConnServer) removeSessionID(sessionID []byte) {
  205. delete(r.sessionMap, string(sessionID))
  206. }
  207. type requestToPacketConnServerSession struct {
  208. SessionID []byte
  209. writingConnectionQueue *queue.Queue
  210. writerChan chan []byte
  211. readerChan chan []byte
  212. ctx context.Context
  213. finish func()
  214. server *requestToPacketConnServer
  215. maxWriteSize int
  216. maxWriteDuration int
  217. maxSimultaneousWriteConnection int
  218. }
  219. func (r *requestToPacketConnServerSession) Read(p []byte) (n int, err error) {
  220. select {
  221. case <-r.ctx.Done():
  222. return 0, r.ctx.Err()
  223. case buf := <-r.readerChan:
  224. copy(p, buf)
  225. return len(buf), nil
  226. }
  227. }
  228. var debugStats struct {
  229. packetWritten int
  230. packetDropped int
  231. }
  232. /*
  233. var _ = func() bool {
  234. go func() {
  235. for {
  236. time.Sleep(time.Second)
  237. newError("packet written: ", debugStats.packetWritten, " packet dropped: ", debugStats.packetDropped).WriteToLog()
  238. }
  239. }()
  240. return true
  241. }()*/
  242. func (r *requestToPacketConnServerSession) Write(p []byte) (n int, err error) {
  243. buf := make([]byte, len(p))
  244. copy(buf, p)
  245. select {
  246. case <-r.ctx.Done():
  247. return 0, r.ctx.Err()
  248. case r.writerChan <- buf:
  249. debugStats.packetWritten++
  250. return len(p), nil
  251. default: // This write will be called from global listener's routine, it must not block
  252. debugStats.packetDropped++
  253. return len(p), nil
  254. }
  255. }
  256. func (r *requestToPacketConnServerSession) Close() error {
  257. r.server.removeSessionID(r.SessionID)
  258. r.finish()
  259. return nil
  260. }
  261. type writingConnection struct {
  262. focus func()
  263. finish func()
  264. finishCtx context.Context
  265. }
  266. func (r *requestToPacketConnServerSession) OnRoundTrip(ctx context.Context, req request.Request,
  267. opts ...request.RoundTripperOption,
  268. ) (resp request.Response, err error) {
  269. // TODO: fix connection graceful close
  270. var streamingRespWriter io.Writer
  271. var streamingRespWriterFlusher request.OptionSupportsStreamingResponseExtensionFlusher
  272. for _, opt := range opts {
  273. if streamingRespOpt, ok := opt.(request.OptionSupportsStreamingResponse); ok {
  274. streamingRespWriter = streamingRespOpt.GetResponseWriter()
  275. if streamingRespWriterFlusherOpt, ok := opt.(request.OptionSupportsStreamingResponseExtensionFlusher); ok {
  276. streamingRespWriterFlusher = streamingRespWriterFlusherOpt
  277. }
  278. }
  279. }
  280. packetBundler := NewPacketBundle()
  281. reqReader := bytes.NewReader(req.Data)
  282. for reqReader.Len() != 0 {
  283. packet, err := packetBundler.ReadFromBundle(reqReader)
  284. if err != nil {
  285. err = newError("failed to read from bundle").Base(err)
  286. return request.Response{}, err
  287. }
  288. r.readerChan <- packet
  289. }
  290. onFocusCtx, focus := context.WithCancel(ctx)
  291. onFinishCtx, finish := context.WithCancel(ctx)
  292. r.writingConnectionQueue.Put(&writingConnection{
  293. focus: focus,
  294. finish: finish,
  295. finishCtx: onFinishCtx,
  296. })
  297. amountToEnd := r.writingConnectionQueue.Len() - int64(r.maxSimultaneousWriteConnection)
  298. for amountToEnd > 0 {
  299. {
  300. _, _ = r.writingConnectionQueue.TakeUntil(func(i interface{}) bool {
  301. i.(*writingConnection).finish()
  302. amountToEnd--
  303. return amountToEnd > 0
  304. })
  305. }
  306. }
  307. {
  308. _, _ = r.writingConnectionQueue.TakeUntil(func(i interface{}) bool {
  309. i.(*writingConnection).focus()
  310. return false
  311. })
  312. }
  313. bufferedRespWriter := bytes.NewBuffer(nil)
  314. finishWrite := func() {
  315. resp.Data = bufferedRespWriter.Bytes()
  316. {
  317. _, _ = r.writingConnectionQueue.TakeUntil(func(i interface{}) bool {
  318. i.(*writingConnection).focus()
  319. if i.(*writingConnection).finishCtx.Err() != nil { //nolint: gosimple
  320. return true
  321. }
  322. return false
  323. })
  324. }
  325. }
  326. progressiveSend := streamingRespWriter != nil
  327. var respWriter io.Writer
  328. if progressiveSend {
  329. respWriter = streamingRespWriter
  330. } else {
  331. respWriter = bufferedRespWriter
  332. }
  333. var bytesSent int
  334. onReceivePacket := func(packet []byte) bool {
  335. bytesSent += len(packet) + packetBundler.Overhead()
  336. err := packetBundler.WriteToBundle(packet, respWriter)
  337. if err != nil {
  338. newError("failed to write to bundle").Base(err).WriteToLog()
  339. }
  340. if streamingRespWriterFlusher != nil {
  341. streamingRespWriterFlusher.Flush()
  342. }
  343. if bytesSent >= r.maxWriteSize {
  344. return false
  345. }
  346. return true
  347. }
  348. finishWriteTimer := time.NewTimer(time.Millisecond * time.Duration(r.maxWriteDuration))
  349. if !progressiveSend {
  350. select {
  351. case <-onFocusCtx.Done():
  352. case <-onFinishCtx.Done():
  353. finishWrite()
  354. return resp, nil
  355. }
  356. } else {
  357. select {
  358. case <-onFinishCtx.Done():
  359. finishWrite()
  360. return resp, nil
  361. default:
  362. }
  363. }
  364. firstRead := true
  365. for {
  366. select {
  367. case <-onFinishCtx.Done():
  368. finishWrite()
  369. finishWriteTimer.Stop()
  370. return resp, nil
  371. case packet := <-r.writerChan:
  372. keepSending := onReceivePacket(packet)
  373. if firstRead {
  374. firstRead = false
  375. }
  376. if !keepSending {
  377. finishWrite()
  378. finishWriteTimer.Stop()
  379. return resp, nil
  380. }
  381. case <-finishWriteTimer.C:
  382. finishWrite()
  383. return resp, nil
  384. }
  385. }
  386. }