streams_map_test.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. package quic
  2. import (
  3. "errors"
  4. "fmt"
  5. "math"
  6. "github.com/golang/mock/gomock"
  7. "github.com/lucas-clemente/quic-go/internal/flowcontrol"
  8. "github.com/lucas-clemente/quic-go/internal/handshake"
  9. "github.com/lucas-clemente/quic-go/internal/mocks"
  10. "github.com/lucas-clemente/quic-go/internal/protocol"
  11. "github.com/lucas-clemente/quic-go/internal/wire"
  12. "github.com/lucas-clemente/quic-go/qerr"
  13. . "github.com/onsi/ginkgo"
  14. . "github.com/onsi/gomega"
  15. )
  16. type streamMapping struct {
  17. firstIncomingBidiStream protocol.StreamID
  18. firstIncomingUniStream protocol.StreamID
  19. firstOutgoingBidiStream protocol.StreamID
  20. firstOutgoingUniStream protocol.StreamID
  21. }
  22. var _ = Describe("Streams Map (for IETF QUIC)", func() {
  23. newFlowController := func(protocol.StreamID) flowcontrol.StreamFlowController {
  24. return mocks.NewMockStreamFlowController(mockCtrl)
  25. }
  26. serverStreamMapping := streamMapping{
  27. firstIncomingBidiStream: 4,
  28. firstOutgoingBidiStream: 1,
  29. firstIncomingUniStream: 2,
  30. firstOutgoingUniStream: 3,
  31. }
  32. clientStreamMapping := streamMapping{
  33. firstIncomingBidiStream: 1,
  34. firstOutgoingBidiStream: 4,
  35. firstIncomingUniStream: 3,
  36. firstOutgoingUniStream: 2,
  37. }
  38. for _, p := range []protocol.Perspective{protocol.PerspectiveServer, protocol.PerspectiveClient} {
  39. perspective := p
  40. var ids streamMapping
  41. if perspective == protocol.PerspectiveClient {
  42. ids = clientStreamMapping
  43. } else {
  44. ids = serverStreamMapping
  45. }
  46. Context(perspective.String(), func() {
  47. var (
  48. m *streamsMap
  49. mockSender *MockStreamSender
  50. )
  51. const (
  52. maxBidiStreams = 111
  53. maxUniStreams = 222
  54. )
  55. allowUnlimitedStreams := func() {
  56. m.UpdateLimits(&handshake.TransportParameters{
  57. MaxBidiStreams: math.MaxUint16,
  58. MaxUniStreams: math.MaxUint16,
  59. })
  60. }
  61. BeforeEach(func() {
  62. mockSender = NewMockStreamSender(mockCtrl)
  63. m = newStreamsMap(mockSender, newFlowController, maxBidiStreams, maxUniStreams, perspective, versionIETFFrames).(*streamsMap)
  64. })
  65. Context("opening", func() {
  66. It("opens bidirectional streams", func() {
  67. allowUnlimitedStreams()
  68. str, err := m.OpenStream()
  69. Expect(err).ToNot(HaveOccurred())
  70. Expect(str).To(BeAssignableToTypeOf(&stream{}))
  71. Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
  72. str, err = m.OpenStream()
  73. Expect(err).ToNot(HaveOccurred())
  74. Expect(str).To(BeAssignableToTypeOf(&stream{}))
  75. Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream + 4))
  76. })
  77. It("opens unidirectional streams", func() {
  78. allowUnlimitedStreams()
  79. str, err := m.OpenUniStream()
  80. Expect(err).ToNot(HaveOccurred())
  81. Expect(str).To(BeAssignableToTypeOf(&sendStream{}))
  82. Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
  83. str, err = m.OpenUniStream()
  84. Expect(err).ToNot(HaveOccurred())
  85. Expect(str).To(BeAssignableToTypeOf(&sendStream{}))
  86. Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream + 4))
  87. })
  88. })
  89. Context("accepting", func() {
  90. It("accepts bidirectional streams", func() {
  91. _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream)
  92. Expect(err).ToNot(HaveOccurred())
  93. str, err := m.AcceptStream()
  94. Expect(err).ToNot(HaveOccurred())
  95. Expect(str).To(BeAssignableToTypeOf(&stream{}))
  96. Expect(str.StreamID()).To(Equal(ids.firstIncomingBidiStream))
  97. })
  98. It("accepts unidirectional streams", func() {
  99. _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream)
  100. Expect(err).ToNot(HaveOccurred())
  101. str, err := m.AcceptUniStream()
  102. Expect(err).ToNot(HaveOccurred())
  103. Expect(str).To(BeAssignableToTypeOf(&receiveStream{}))
  104. Expect(str.StreamID()).To(Equal(ids.firstIncomingUniStream))
  105. })
  106. })
  107. Context("deleting", func() {
  108. BeforeEach(func() {
  109. mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
  110. allowUnlimitedStreams()
  111. })
  112. It("deletes outgoing bidirectional streams", func() {
  113. id := ids.firstOutgoingBidiStream
  114. str, err := m.OpenStream()
  115. Expect(err).ToNot(HaveOccurred())
  116. Expect(str.StreamID()).To(Equal(id))
  117. Expect(m.DeleteStream(id)).To(Succeed())
  118. dstr, err := m.GetOrOpenSendStream(id)
  119. Expect(err).ToNot(HaveOccurred())
  120. Expect(dstr).To(BeNil())
  121. })
  122. It("deletes incoming bidirectional streams", func() {
  123. id := ids.firstIncomingBidiStream
  124. str, err := m.GetOrOpenReceiveStream(id)
  125. Expect(err).ToNot(HaveOccurred())
  126. Expect(str.StreamID()).To(Equal(id))
  127. Expect(m.DeleteStream(id)).To(Succeed())
  128. dstr, err := m.GetOrOpenReceiveStream(id)
  129. Expect(err).ToNot(HaveOccurred())
  130. Expect(dstr).To(BeNil())
  131. })
  132. It("deletes outgoing unidirectional streams", func() {
  133. id := ids.firstOutgoingUniStream
  134. str, err := m.OpenUniStream()
  135. Expect(err).ToNot(HaveOccurred())
  136. Expect(str.StreamID()).To(Equal(id))
  137. Expect(m.DeleteStream(id)).To(Succeed())
  138. dstr, err := m.GetOrOpenSendStream(id)
  139. Expect(err).ToNot(HaveOccurred())
  140. Expect(dstr).To(BeNil())
  141. })
  142. It("deletes incoming unidirectional streams", func() {
  143. id := ids.firstIncomingUniStream
  144. str, err := m.GetOrOpenReceiveStream(id)
  145. Expect(err).ToNot(HaveOccurred())
  146. Expect(str.StreamID()).To(Equal(id))
  147. Expect(m.DeleteStream(id)).To(Succeed())
  148. dstr, err := m.GetOrOpenReceiveStream(id)
  149. Expect(err).ToNot(HaveOccurred())
  150. Expect(dstr).To(BeNil())
  151. })
  152. })
  153. Context("getting streams", func() {
  154. BeforeEach(func() {
  155. allowUnlimitedStreams()
  156. })
  157. Context("send streams", func() {
  158. It("gets an outgoing bidirectional stream", func() {
  159. // need to open the stream ourselves first
  160. // the peer is not allowed to create a stream initiated by us
  161. _, err := m.OpenStream()
  162. Expect(err).ToNot(HaveOccurred())
  163. str, err := m.GetOrOpenSendStream(ids.firstOutgoingBidiStream)
  164. Expect(err).ToNot(HaveOccurred())
  165. Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
  166. })
  167. It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
  168. id := ids.firstOutgoingBidiStream + 5*4
  169. _, err := m.GetOrOpenSendStream(id)
  170. Expect(err).To(MatchError(qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))))
  171. })
  172. It("gets an outgoing unidirectional stream", func() {
  173. // need to open the stream ourselves first
  174. // the peer is not allowed to create a stream initiated by us
  175. _, err := m.OpenUniStream()
  176. Expect(err).ToNot(HaveOccurred())
  177. str, err := m.GetOrOpenSendStream(ids.firstOutgoingUniStream)
  178. Expect(err).ToNot(HaveOccurred())
  179. Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
  180. })
  181. It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
  182. id := ids.firstOutgoingUniStream + 5*4
  183. _, err := m.GetOrOpenSendStream(id)
  184. Expect(err).To(MatchError(qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))))
  185. })
  186. It("gets an incoming bidirectional stream", func() {
  187. id := ids.firstIncomingBidiStream + 4*7
  188. str, err := m.GetOrOpenSendStream(id)
  189. Expect(err).ToNot(HaveOccurred())
  190. Expect(str.StreamID()).To(Equal(id))
  191. })
  192. It("errors when trying to get an incoming unidirectional stream", func() {
  193. id := ids.firstIncomingUniStream
  194. _, err := m.GetOrOpenSendStream(id)
  195. Expect(err).To(MatchError(fmt.Errorf("peer attempted to open send stream %d", id)))
  196. })
  197. })
  198. Context("receive streams", func() {
  199. It("gets an outgoing bidirectional stream", func() {
  200. // need to open the stream ourselves first
  201. // the peer is not allowed to create a stream initiated by us
  202. _, err := m.OpenStream()
  203. Expect(err).ToNot(HaveOccurred())
  204. str, err := m.GetOrOpenReceiveStream(ids.firstOutgoingBidiStream)
  205. Expect(err).ToNot(HaveOccurred())
  206. Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
  207. })
  208. It("errors when the peer tries to open a higher outgoing bidirectional stream", func() {
  209. id := ids.firstOutgoingBidiStream + 5*4
  210. _, err := m.GetOrOpenReceiveStream(id)
  211. Expect(err).To(MatchError(qerr.Error(qerr.InvalidStreamID, fmt.Sprintf("peer attempted to open stream %d", id))))
  212. })
  213. It("gets an incoming bidirectional stream", func() {
  214. id := ids.firstIncomingBidiStream + 4*7
  215. str, err := m.GetOrOpenReceiveStream(id)
  216. Expect(err).ToNot(HaveOccurred())
  217. Expect(str.StreamID()).To(Equal(id))
  218. })
  219. It("gets an incoming unidirectional stream", func() {
  220. id := ids.firstIncomingUniStream + 4*10
  221. str, err := m.GetOrOpenReceiveStream(id)
  222. Expect(err).ToNot(HaveOccurred())
  223. Expect(str.StreamID()).To(Equal(id))
  224. })
  225. It("errors when trying to get an outgoing unidirectional stream", func() {
  226. id := ids.firstOutgoingUniStream
  227. _, err := m.GetOrOpenReceiveStream(id)
  228. Expect(err).To(MatchError(fmt.Errorf("peer attempted to open receive stream %d", id)))
  229. })
  230. })
  231. })
  232. Context("updating stream ID limits", func() {
  233. BeforeEach(func() {
  234. mockSender.EXPECT().queueControlFrame(gomock.Any())
  235. })
  236. It("processes the parameter for outgoing streams, as a server", func() {
  237. m.perspective = protocol.PerspectiveServer
  238. _, err := m.OpenStream()
  239. Expect(err).To(MatchError(qerr.TooManyOpenStreams))
  240. m.UpdateLimits(&handshake.TransportParameters{
  241. MaxBidiStreams: 5,
  242. MaxUniStreams: 5,
  243. })
  244. Expect(m.outgoingBidiStreams.maxStream).To(Equal(protocol.StreamID(17)))
  245. Expect(m.outgoingUniStreams.maxStream).To(Equal(protocol.StreamID(19)))
  246. })
  247. It("processes the parameter for outgoing streams, as a client", func() {
  248. m.perspective = protocol.PerspectiveClient
  249. _, err := m.OpenUniStream()
  250. Expect(err).To(MatchError(qerr.TooManyOpenStreams))
  251. m.UpdateLimits(&handshake.TransportParameters{
  252. MaxBidiStreams: 5,
  253. MaxUniStreams: 5,
  254. })
  255. Expect(m.outgoingBidiStreams.maxStream).To(Equal(protocol.StreamID(20)))
  256. Expect(m.outgoingUniStreams.maxStream).To(Equal(protocol.StreamID(18)))
  257. })
  258. })
  259. Context("handling MAX_STREAM_ID frames", func() {
  260. BeforeEach(func() {
  261. mockSender.EXPECT().queueControlFrame(gomock.Any()).AnyTimes()
  262. })
  263. It("processes IDs for outgoing bidirectional streams", func() {
  264. _, err := m.OpenStream()
  265. Expect(err).To(MatchError(qerr.TooManyOpenStreams))
  266. err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingBidiStream})
  267. Expect(err).ToNot(HaveOccurred())
  268. str, err := m.OpenStream()
  269. Expect(err).ToNot(HaveOccurred())
  270. Expect(str.StreamID()).To(Equal(ids.firstOutgoingBidiStream))
  271. })
  272. It("processes IDs for outgoing bidirectional streams", func() {
  273. _, err := m.OpenUniStream()
  274. Expect(err).To(MatchError(qerr.TooManyOpenStreams))
  275. err = m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstOutgoingUniStream})
  276. Expect(err).ToNot(HaveOccurred())
  277. str, err := m.OpenUniStream()
  278. Expect(err).ToNot(HaveOccurred())
  279. Expect(str.StreamID()).To(Equal(ids.firstOutgoingUniStream))
  280. })
  281. It("rejects IDs for incoming bidirectional streams", func() {
  282. err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingBidiStream})
  283. Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingBidiStream)))
  284. })
  285. It("rejects IDs for incoming unidirectional streams", func() {
  286. err := m.HandleMaxStreamIDFrame(&wire.MaxStreamIDFrame{StreamID: ids.firstIncomingUniStream})
  287. Expect(err).To(MatchError(fmt.Sprintf("received MAX_STREAM_DATA frame for incoming stream %d", ids.firstIncomingUniStream)))
  288. })
  289. })
  290. Context("sending MAX_STREAM_ID frames", func() {
  291. It("sends MAX_STREAM_ID frames for bidirectional streams", func() {
  292. _, err := m.GetOrOpenReceiveStream(ids.firstIncomingBidiStream + 4*10)
  293. Expect(err).ToNot(HaveOccurred())
  294. mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{
  295. StreamID: protocol.MaxBidiStreamID(maxBidiStreams, perspective) + 4,
  296. })
  297. Expect(m.DeleteStream(ids.firstIncomingBidiStream)).To(Succeed())
  298. })
  299. It("sends MAX_STREAM_ID frames for unidirectional streams", func() {
  300. _, err := m.GetOrOpenReceiveStream(ids.firstIncomingUniStream + 4*10)
  301. Expect(err).ToNot(HaveOccurred())
  302. mockSender.EXPECT().queueControlFrame(&wire.MaxStreamIDFrame{
  303. StreamID: protocol.MaxUniStreamID(maxUniStreams, perspective) + 4,
  304. })
  305. Expect(m.DeleteStream(ids.firstIncomingUniStream)).To(Succeed())
  306. })
  307. })
  308. It("closes", func() {
  309. testErr := errors.New("test error")
  310. m.CloseWithError(testErr)
  311. _, err := m.OpenStream()
  312. Expect(err).To(MatchError(testErr))
  313. _, err = m.OpenUniStream()
  314. Expect(err).To(MatchError(testErr))
  315. _, err = m.AcceptStream()
  316. Expect(err).To(MatchError(testErr))
  317. _, err = m.AcceptUniStream()
  318. Expect(err).To(MatchError(testErr))
  319. })
  320. })
  321. }
  322. })