streams_map.go 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. package quic
  2. import (
  3. "fmt"
  4. "github.com/lucas-clemente/quic-go/internal/flowcontrol"
  5. "github.com/lucas-clemente/quic-go/internal/handshake"
  6. "github.com/lucas-clemente/quic-go/internal/protocol"
  7. "github.com/lucas-clemente/quic-go/internal/wire"
  8. )
  9. type streamType int
  10. const (
  11. streamTypeOutgoingBidi streamType = iota
  12. streamTypeIncomingBidi
  13. streamTypeOutgoingUni
  14. streamTypeIncomingUni
  15. )
  16. type streamsMap struct {
  17. perspective protocol.Perspective
  18. sender streamSender
  19. newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController
  20. outgoingBidiStreams *outgoingBidiStreamsMap
  21. outgoingUniStreams *outgoingUniStreamsMap
  22. incomingBidiStreams *incomingBidiStreamsMap
  23. incomingUniStreams *incomingUniStreamsMap
  24. }
  25. var _ streamManager = &streamsMap{}
  26. func newStreamsMap(
  27. sender streamSender,
  28. newFlowController func(protocol.StreamID) flowcontrol.StreamFlowController,
  29. maxIncomingStreams int,
  30. maxIncomingUniStreams int,
  31. perspective protocol.Perspective,
  32. version protocol.VersionNumber,
  33. ) streamManager {
  34. m := &streamsMap{
  35. perspective: perspective,
  36. newFlowController: newFlowController,
  37. sender: sender,
  38. }
  39. var firstOutgoingBidiStream, firstOutgoingUniStream, firstIncomingBidiStream, firstIncomingUniStream protocol.StreamID
  40. if perspective == protocol.PerspectiveServer {
  41. firstOutgoingBidiStream = 1
  42. firstIncomingBidiStream = 4 // the crypto stream is handled separately
  43. firstOutgoingUniStream = 3
  44. firstIncomingUniStream = 2
  45. } else {
  46. firstOutgoingBidiStream = 4 // the crypto stream is handled separately
  47. firstIncomingBidiStream = 1
  48. firstOutgoingUniStream = 2
  49. firstIncomingUniStream = 3
  50. }
  51. newBidiStream := func(id protocol.StreamID) streamI {
  52. return newStream(id, m.sender, m.newFlowController(id), version)
  53. }
  54. newUniSendStream := func(id protocol.StreamID) sendStreamI {
  55. return newSendStream(id, m.sender, m.newFlowController(id), version)
  56. }
  57. newUniReceiveStream := func(id protocol.StreamID) receiveStreamI {
  58. return newReceiveStream(id, m.sender, m.newFlowController(id), version)
  59. }
  60. m.outgoingBidiStreams = newOutgoingBidiStreamsMap(
  61. firstOutgoingBidiStream,
  62. newBidiStream,
  63. sender.queueControlFrame,
  64. )
  65. m.incomingBidiStreams = newIncomingBidiStreamsMap(
  66. firstIncomingBidiStream,
  67. protocol.MaxBidiStreamID(maxIncomingStreams, perspective),
  68. maxIncomingStreams,
  69. sender.queueControlFrame,
  70. newBidiStream,
  71. )
  72. m.outgoingUniStreams = newOutgoingUniStreamsMap(
  73. firstOutgoingUniStream,
  74. newUniSendStream,
  75. sender.queueControlFrame,
  76. )
  77. m.incomingUniStreams = newIncomingUniStreamsMap(
  78. firstIncomingUniStream,
  79. protocol.MaxUniStreamID(maxIncomingUniStreams, perspective),
  80. maxIncomingUniStreams,
  81. sender.queueControlFrame,
  82. newUniReceiveStream,
  83. )
  84. return m
  85. }
  86. func (m *streamsMap) getStreamType(id protocol.StreamID) streamType {
  87. if m.perspective == protocol.PerspectiveServer {
  88. switch id % 4 {
  89. case 0:
  90. return streamTypeIncomingBidi
  91. case 1:
  92. return streamTypeOutgoingBidi
  93. case 2:
  94. return streamTypeIncomingUni
  95. case 3:
  96. return streamTypeOutgoingUni
  97. }
  98. } else {
  99. switch id % 4 {
  100. case 0:
  101. return streamTypeOutgoingBidi
  102. case 1:
  103. return streamTypeIncomingBidi
  104. case 2:
  105. return streamTypeOutgoingUni
  106. case 3:
  107. return streamTypeIncomingUni
  108. }
  109. }
  110. panic("")
  111. }
  112. func (m *streamsMap) OpenStream() (Stream, error) {
  113. return m.outgoingBidiStreams.OpenStream()
  114. }
  115. func (m *streamsMap) OpenStreamSync() (Stream, error) {
  116. return m.outgoingBidiStreams.OpenStreamSync()
  117. }
  118. func (m *streamsMap) OpenUniStream() (SendStream, error) {
  119. return m.outgoingUniStreams.OpenStream()
  120. }
  121. func (m *streamsMap) OpenUniStreamSync() (SendStream, error) {
  122. return m.outgoingUniStreams.OpenStreamSync()
  123. }
  124. func (m *streamsMap) AcceptStream() (Stream, error) {
  125. return m.incomingBidiStreams.AcceptStream()
  126. }
  127. func (m *streamsMap) AcceptUniStream() (ReceiveStream, error) {
  128. return m.incomingUniStreams.AcceptStream()
  129. }
  130. func (m *streamsMap) DeleteStream(id protocol.StreamID) error {
  131. switch m.getStreamType(id) {
  132. case streamTypeIncomingBidi:
  133. return m.incomingBidiStreams.DeleteStream(id)
  134. case streamTypeOutgoingBidi:
  135. return m.outgoingBidiStreams.DeleteStream(id)
  136. case streamTypeIncomingUni:
  137. return m.incomingUniStreams.DeleteStream(id)
  138. case streamTypeOutgoingUni:
  139. return m.outgoingUniStreams.DeleteStream(id)
  140. default:
  141. panic("invalid stream type")
  142. }
  143. }
  144. func (m *streamsMap) GetOrOpenReceiveStream(id protocol.StreamID) (receiveStreamI, error) {
  145. switch m.getStreamType(id) {
  146. case streamTypeOutgoingBidi:
  147. return m.outgoingBidiStreams.GetStream(id)
  148. case streamTypeIncomingBidi:
  149. return m.incomingBidiStreams.GetOrOpenStream(id)
  150. case streamTypeIncomingUni:
  151. return m.incomingUniStreams.GetOrOpenStream(id)
  152. case streamTypeOutgoingUni:
  153. // an outgoing unidirectional stream is a send stream, not a receive stream
  154. return nil, fmt.Errorf("peer attempted to open receive stream %d", id)
  155. default:
  156. panic("invalid stream type")
  157. }
  158. }
  159. func (m *streamsMap) GetOrOpenSendStream(id protocol.StreamID) (sendStreamI, error) {
  160. switch m.getStreamType(id) {
  161. case streamTypeOutgoingBidi:
  162. return m.outgoingBidiStreams.GetStream(id)
  163. case streamTypeIncomingBidi:
  164. return m.incomingBidiStreams.GetOrOpenStream(id)
  165. case streamTypeOutgoingUni:
  166. return m.outgoingUniStreams.GetStream(id)
  167. case streamTypeIncomingUni:
  168. // an incoming unidirectional stream is a receive stream, not a send stream
  169. return nil, fmt.Errorf("peer attempted to open send stream %d", id)
  170. default:
  171. panic("invalid stream type")
  172. }
  173. }
  174. func (m *streamsMap) HandleMaxStreamIDFrame(f *wire.MaxStreamIDFrame) error {
  175. id := f.StreamID
  176. switch m.getStreamType(id) {
  177. case streamTypeOutgoingBidi:
  178. m.outgoingBidiStreams.SetMaxStream(id)
  179. return nil
  180. case streamTypeOutgoingUni:
  181. m.outgoingUniStreams.SetMaxStream(id)
  182. return nil
  183. default:
  184. return fmt.Errorf("received MAX_STREAM_DATA frame for incoming stream %d", id)
  185. }
  186. }
  187. func (m *streamsMap) UpdateLimits(p *handshake.TransportParameters) {
  188. // Max{Uni,Bidi}StreamID returns the highest stream ID that the peer is allowed to open.
  189. // Invert the perspective to determine the value that we are allowed to open.
  190. peerPers := protocol.PerspectiveServer
  191. if m.perspective == protocol.PerspectiveServer {
  192. peerPers = protocol.PerspectiveClient
  193. }
  194. m.outgoingBidiStreams.SetMaxStream(protocol.MaxBidiStreamID(int(p.MaxBidiStreams), peerPers))
  195. m.outgoingUniStreams.SetMaxStream(protocol.MaxUniStreamID(int(p.MaxUniStreams), peerPers))
  196. }
  197. func (m *streamsMap) CloseWithError(err error) {
  198. m.outgoingBidiStreams.CloseWithError(err)
  199. m.outgoingUniStreams.CloseWithError(err)
  200. m.incomingBidiStreams.CloseWithError(err)
  201. m.incomingUniStreams.CloseWithError(err)
  202. }