req2packet.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429
  1. package packetconn
import (
	"bytes"
	"context"
	"crypto/rand"
	"io"
	"sync"
	"time"

	"github.com/golang-collections/go-datastructures/queue"

	"github.com/v2fly/v2ray-core/v5/transport/internet/request"
)
  11. func newRequestToPacketConnClient(ctx context.Context, config *ClientConfig) (*requestToPacketConnClient, error) { //nolint: unparam
  12. return &requestToPacketConnClient{ctx: ctx, config: config}, nil
  13. }
  14. type requestToPacketConnClient struct {
  15. assembly request.TransportClientAssembly
  16. ctx context.Context
  17. config *ClientConfig
  18. }
  19. func (r *requestToPacketConnClient) OnTransportClientAssemblyReady(assembly request.TransportClientAssembly) {
  20. r.assembly = assembly
  21. }
  22. func (r *requestToPacketConnClient) Dial() (io.ReadWriteCloser, error) {
  23. sessionID := make([]byte, 16)
  24. _, err := rand.Read(sessionID)
  25. if err != nil {
  26. return nil, err
  27. }
  28. ctxWithCancel, cancel := context.WithCancel(r.ctx)
  29. clientSess := &requestToPacketConnClientSession{
  30. sessionID: sessionID,
  31. currentPollingInterval: int(r.config.PollingIntervalInitial),
  32. maxRequestSize: int(r.config.MaxRequestSize),
  33. maxWriteDelay: int(r.config.MaxWriteDelay),
  34. assembly: r.assembly,
  35. writerChan: make(chan []byte, 256),
  36. readerChan: make(chan []byte, 256),
  37. ctx: ctxWithCancel,
  38. finish: cancel,
  39. }
  40. go clientSess.keepRunning()
  41. return clientSess, nil
  42. }
  43. type requestToPacketConnClientSession struct {
  44. sessionID []byte
  45. currentPollingInterval int
  46. maxRequestSize int
  47. maxWriteDelay int
  48. assembly request.TransportClientAssembly
  49. writerChan chan []byte
  50. readerChan chan []byte
  51. ctx context.Context
  52. finish func()
  53. nextWrite []byte
  54. }
  55. func (r *requestToPacketConnClientSession) keepRunning() {
  56. for r.ctx.Err() == nil {
  57. r.runOnce()
  58. }
  59. }
  60. func (r *requestToPacketConnClientSession) runOnce() {
  61. requestBody := bytes.NewBuffer(nil)
  62. waitTimer := time.NewTimer(time.Duration(r.currentPollingInterval) * time.Millisecond)
  63. var seenPacket bool
  64. packetBundler := NewPacketBundle()
  65. copyFromChan:
  66. for {
  67. select {
  68. case <-r.ctx.Done():
  69. return
  70. case <-waitTimer.C:
  71. break copyFromChan
  72. case packet := <-r.writerChan:
  73. if !seenPacket {
  74. seenPacket = true
  75. waitTimer.Stop()
  76. waitTimer.Reset(time.Duration(r.maxWriteDelay) * time.Millisecond)
  77. }
  78. sizeOffset := packetBundler.Overhead() + len(packet)
  79. if requestBody.Len()+sizeOffset > r.maxRequestSize {
  80. r.nextWrite = packet
  81. break copyFromChan
  82. }
  83. err := packetBundler.WriteToBundle(packet, requestBody)
  84. if err != nil {
  85. newError("failed to write to bundle").Base(err).WriteToLog()
  86. }
  87. }
  88. }
  89. waitTimer.Stop()
  90. go func() {
  91. reader, writer := io.Pipe()
  92. streamingRespOpt := &pipedStreamingRespOption{writer}
  93. go func() {
  94. for {
  95. if packet, err := packetBundler.ReadFromBundle(reader); err == nil {
  96. r.readerChan <- packet
  97. } else {
  98. return
  99. }
  100. }
  101. }()
  102. resp, err := r.assembly.Tripper().RoundTrip(r.ctx, request.Request{Data: requestBody.Bytes(), ConnectionTag: r.sessionID},
  103. streamingRespOpt)
  104. if err != nil {
  105. newError("failed to roundtrip").Base(err).WriteToLog()
  106. if r.ctx.Err() != nil {
  107. return
  108. }
  109. }
  110. if resp.Data != nil && len(resp.Data) != 0 {
  111. respReader := bytes.NewReader(resp.Data)
  112. for respReader.Len() != 0 {
  113. packet, err := packetBundler.ReadFromBundle(respReader)
  114. if err != nil {
  115. newError("failed to read from bundle").Base(err).WriteToLog()
  116. if r.ctx.Err() != nil {
  117. return
  118. }
  119. }
  120. r.readerChan <- packet
  121. }
  122. }
  123. }()
  124. }
  125. type pipedStreamingRespOption struct {
  126. writer *io.PipeWriter
  127. }
  128. func (p *pipedStreamingRespOption) RoundTripperOption() {
  129. }
  130. func (p *pipedStreamingRespOption) GetResponseWriter() io.Writer {
  131. return p.writer
  132. }
  133. func (r *requestToPacketConnClientSession) Write(p []byte) (n int, err error) {
  134. buf := make([]byte, len(p))
  135. copy(buf, p)
  136. select {
  137. case <-r.ctx.Done():
  138. return 0, r.ctx.Err()
  139. case r.writerChan <- buf:
  140. return len(p), nil
  141. }
  142. }
  143. func (r *requestToPacketConnClientSession) Read(p []byte) (n int, err error) {
  144. select {
  145. case <-r.ctx.Done():
  146. return 0, r.ctx.Err()
  147. case buf := <-r.readerChan:
  148. copy(p, buf)
  149. return len(buf), nil
  150. }
  151. }
  152. func (r *requestToPacketConnClientSession) Close() error {
  153. r.finish()
  154. return nil
  155. }
  156. func newRequestToPacketConnServer(ctx context.Context, config *ServerConfig) *requestToPacketConnServer {
  157. return &requestToPacketConnServer{
  158. sessionMap: make(map[string]*requestToPacketConnServerSession),
  159. ctx: ctx,
  160. config: config,
  161. }
  162. }
  163. type requestToPacketConnServer struct {
  164. packetSessionReceiver request.SessionReceiver
  165. sessionMap map[string]*requestToPacketConnServerSession
  166. ctx context.Context
  167. config *ServerConfig
  168. }
  169. func (r *requestToPacketConnServer) onSessionReceiverReady(sessrecv request.SessionReceiver) {
  170. r.packetSessionReceiver = sessrecv
  171. }
  172. func (r *requestToPacketConnServer) OnRoundTrip(ctx context.Context, req request.Request,
  173. opts ...request.RoundTripperOption,
  174. ) (resp request.Response, err error) {
  175. SessionID := req.ConnectionTag
  176. if SessionID == nil {
  177. return request.Response{}, newError("nil session id")
  178. }
  179. sessionID := string(SessionID)
  180. session, found := r.sessionMap[sessionID]
  181. if !found {
  182. ctxWithFinish, finish := context.WithCancel(ctx)
  183. session = &requestToPacketConnServerSession{
  184. SessionID: SessionID,
  185. writingConnectionQueue: queue.New(64),
  186. writerChan: make(chan []byte, int(r.config.PacketWritingBuffer)),
  187. readerChan: make(chan []byte, 256),
  188. ctx: ctxWithFinish,
  189. finish: finish,
  190. server: r,
  191. maxWriteSize: int(r.config.MaxWriteSize),
  192. maxWriteDuration: int(r.config.MaxWriteDurationMs),
  193. maxSimultaneousWriteConnection: int(r.config.MaxSimultaneousWriteConnection),
  194. }
  195. r.sessionMap[sessionID] = session
  196. err = r.packetSessionReceiver.OnNewSession(ctx, session)
  197. }
  198. if err != nil {
  199. return request.Response{}, err
  200. }
  201. return session.OnRoundTrip(ctx, req, opts...)
  202. }
  203. func (r *requestToPacketConnServer) removeSessionID(sessionID []byte) {
  204. delete(r.sessionMap, string(sessionID))
  205. }
  206. type requestToPacketConnServerSession struct {
  207. SessionID []byte
  208. writingConnectionQueue *queue.Queue
  209. writerChan chan []byte
  210. readerChan chan []byte
  211. ctx context.Context
  212. finish func()
  213. server *requestToPacketConnServer
  214. maxWriteSize int
  215. maxWriteDuration int
  216. maxSimultaneousWriteConnection int
  217. }
  218. func (r *requestToPacketConnServerSession) Read(p []byte) (n int, err error) {
  219. select {
  220. case <-r.ctx.Done():
  221. return 0, r.ctx.Err()
  222. case buf := <-r.readerChan:
  223. copy(p, buf)
  224. return len(buf), nil
  225. }
  226. }
  227. var debugStats struct {
  228. packetWritten int
  229. packetDropped int
  230. }
  231. /*
  232. var _ = func() bool {
  233. go func() {
  234. for {
  235. time.Sleep(time.Second)
  236. newError("packet written: ", debugStats.packetWritten, " packet dropped: ", debugStats.packetDropped).WriteToLog()
  237. }
  238. }()
  239. return true
  240. }()*/
  241. func (r *requestToPacketConnServerSession) Write(p []byte) (n int, err error) {
  242. buf := make([]byte, len(p))
  243. copy(buf, p)
  244. select {
  245. case <-r.ctx.Done():
  246. return 0, r.ctx.Err()
  247. case r.writerChan <- buf:
  248. debugStats.packetWritten++
  249. return len(p), nil
  250. default: // This write will be called from global listener's routine, it must not block
  251. debugStats.packetDropped++
  252. return len(p), nil
  253. }
  254. }
  255. func (r *requestToPacketConnServerSession) Close() error {
  256. r.server.removeSessionID(r.SessionID)
  257. r.finish()
  258. return nil
  259. }
  260. type writingConnection struct {
  261. focus func()
  262. finish func()
  263. finishCtx context.Context
  264. }
  265. func (r *requestToPacketConnServerSession) OnRoundTrip(ctx context.Context, req request.Request,
  266. opts ...request.RoundTripperOption,
  267. ) (resp request.Response, err error) {
  268. // TODO: fix connection graceful close
  269. var streamingRespWriter io.Writer
  270. var streamingRespWriterFlusher request.OptionSupportsStreamingResponseExtensionFlusher
  271. for _, opt := range opts {
  272. if streamingRespOpt, ok := opt.(request.OptionSupportsStreamingResponse); ok {
  273. streamingRespWriter = streamingRespOpt.GetResponseWriter()
  274. if streamingRespWriterFlusherOpt, ok := opt.(request.OptionSupportsStreamingResponseExtensionFlusher); ok {
  275. streamingRespWriterFlusher = streamingRespWriterFlusherOpt
  276. }
  277. }
  278. }
  279. packetBundler := NewPacketBundle()
  280. reqReader := bytes.NewReader(req.Data)
  281. for reqReader.Len() != 0 {
  282. packet, err := packetBundler.ReadFromBundle(reqReader)
  283. if err != nil {
  284. err = newError("failed to read from bundle").Base(err)
  285. return request.Response{}, err
  286. }
  287. r.readerChan <- packet
  288. }
  289. onFocusCtx, focus := context.WithCancel(ctx)
  290. onFinishCtx, finish := context.WithCancel(ctx)
  291. r.writingConnectionQueue.Put(&writingConnection{
  292. focus: focus,
  293. finish: finish,
  294. finishCtx: onFinishCtx,
  295. })
  296. amountToEnd := r.writingConnectionQueue.Len() - int64(r.maxSimultaneousWriteConnection)
  297. for amountToEnd > 0 {
  298. {
  299. _, _ = r.writingConnectionQueue.TakeUntil(func(i interface{}) bool {
  300. i.(*writingConnection).finish()
  301. amountToEnd--
  302. return amountToEnd > 0
  303. })
  304. }
  305. }
  306. {
  307. _, _ = r.writingConnectionQueue.TakeUntil(func(i interface{}) bool {
  308. i.(*writingConnection).focus()
  309. return false
  310. })
  311. }
  312. bufferedRespWriter := bytes.NewBuffer(nil)
  313. finishWrite := func() {
  314. resp.Data = bufferedRespWriter.Bytes()
  315. {
  316. _, _ = r.writingConnectionQueue.TakeUntil(func(i interface{}) bool {
  317. i.(*writingConnection).focus()
  318. if i.(*writingConnection).finishCtx.Err() != nil { //nolint: gosimple
  319. return true
  320. }
  321. return false
  322. })
  323. }
  324. }
  325. progressiveSend := streamingRespWriter != nil
  326. var respWriter io.Writer
  327. if progressiveSend {
  328. respWriter = streamingRespWriter
  329. } else {
  330. respWriter = bufferedRespWriter
  331. }
  332. var bytesSent int
  333. onReceivePacket := func(packet []byte) bool {
  334. bytesSent += len(packet) + packetBundler.Overhead()
  335. err := packetBundler.WriteToBundle(packet, respWriter)
  336. if err != nil {
  337. newError("failed to write to bundle").Base(err).WriteToLog()
  338. }
  339. if streamingRespWriterFlusher != nil {
  340. streamingRespWriterFlusher.Flush()
  341. }
  342. if bytesSent >= r.maxWriteSize {
  343. return false
  344. }
  345. return true
  346. }
  347. finishWriteTimer := time.NewTimer(time.Millisecond * time.Duration(r.maxWriteDuration))
  348. if !progressiveSend {
  349. select {
  350. case <-onFocusCtx.Done():
  351. case <-onFinishCtx.Done():
  352. finishWrite()
  353. return resp, nil
  354. }
  355. } else {
  356. select {
  357. case <-onFinishCtx.Done():
  358. finishWrite()
  359. return resp, nil
  360. default:
  361. }
  362. }
  363. firstRead := true
  364. for {
  365. select {
  366. case <-onFinishCtx.Done():
  367. finishWrite()
  368. finishWriteTimer.Stop()
  369. return resp, nil
  370. case packet := <-r.writerChan:
  371. keepSending := onReceivePacket(packet)
  372. if firstRead {
  373. firstRead = false
  374. }
  375. if !keepSending {
  376. finishWrite()
  377. finishWriteTimer.Stop()
  378. return resp, nil
  379. }
  380. case <-finishWriteTimer.C:
  381. finishWrite()
  382. return resp, nil
  383. }
  384. }
  385. }