// server_test.go (17 KB)
  1. package quic
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "errors"
  6. "net"
  7. "reflect"
  8. "time"
  9. "github.com/golang/mock/gomock"
  10. "github.com/lucas-clemente/quic-go/internal/handshake"
  11. "github.com/lucas-clemente/quic-go/internal/protocol"
  12. "github.com/lucas-clemente/quic-go/internal/testdata"
  13. "github.com/lucas-clemente/quic-go/internal/utils"
  14. "github.com/lucas-clemente/quic-go/internal/wire"
  15. . "github.com/onsi/ginkgo"
  16. . "github.com/onsi/gomega"
  17. )
  18. type mockSession struct {
  19. *MockQuicSession
  20. connID protocol.ConnectionID
  21. runner sessionRunner
  22. }
  23. func (s *mockSession) GetPerspective() protocol.Perspective { panic("not implemented") }
  24. var _ = Describe("Server", func() {
  25. var (
  26. conn *mockPacketConn
  27. config *Config
  28. udpAddr = &net.UDPAddr{IP: net.IPv4(192, 168, 100, 200), Port: 1337}
  29. )
  30. BeforeEach(func() {
  31. conn = newMockPacketConn()
  32. conn.addr = &net.UDPAddr{}
  33. config = &Config{Versions: protocol.SupportedVersions}
  34. })
  35. Context("quic.Config", func() {
  36. It("setups with the right values", func() {
  37. config := &Config{
  38. HandshakeTimeout: 1337 * time.Minute,
  39. IdleTimeout: 42 * time.Hour,
  40. RequestConnectionIDOmission: true,
  41. MaxIncomingStreams: 1234,
  42. MaxIncomingUniStreams: 4321,
  43. ConnectionIDLength: 12,
  44. Versions: []protocol.VersionNumber{VersionGQUIC43},
  45. }
  46. c := populateServerConfig(config)
  47. Expect(c.HandshakeTimeout).To(Equal(1337 * time.Minute))
  48. Expect(c.IdleTimeout).To(Equal(42 * time.Hour))
  49. Expect(c.RequestConnectionIDOmission).To(BeFalse())
  50. Expect(c.MaxIncomingStreams).To(Equal(1234))
  51. Expect(c.MaxIncomingUniStreams).To(Equal(4321))
  52. Expect(c.ConnectionIDLength).To(Equal(12))
  53. Expect(c.Versions).To(Equal([]protocol.VersionNumber{VersionGQUIC43}))
  54. })
  55. It("uses 8 byte connection IDs if gQUIC 44 is supported", func() {
  56. config := &Config{
  57. Versions: []protocol.VersionNumber{protocol.Version43, protocol.Version44},
  58. ConnectionIDLength: 13,
  59. }
  60. c := populateServerConfig(config)
  61. Expect(c.Versions).To(Equal([]protocol.VersionNumber{protocol.Version43, protocol.Version44}))
  62. Expect(c.ConnectionIDLength).To(Equal(8))
  63. })
  64. It("uses 4 byte connection IDs by default, if gQUIC 44 is not supported", func() {
  65. config := &Config{
  66. Versions: []protocol.VersionNumber{protocol.Version39},
  67. }
  68. c := populateServerConfig(config)
  69. Expect(c.ConnectionIDLength).To(Equal(protocol.DefaultConnectionIDLength))
  70. })
  71. It("disables bidirectional streams", func() {
  72. config := &Config{
  73. MaxIncomingStreams: -1,
  74. MaxIncomingUniStreams: 4321,
  75. }
  76. c := populateServerConfig(config)
  77. Expect(c.MaxIncomingStreams).To(BeZero())
  78. Expect(c.MaxIncomingUniStreams).To(Equal(4321))
  79. })
  80. It("disables unidirectional streams", func() {
  81. config := &Config{
  82. MaxIncomingStreams: 1234,
  83. MaxIncomingUniStreams: -1,
  84. }
  85. c := populateServerConfig(config)
  86. Expect(c.MaxIncomingStreams).To(Equal(1234))
  87. Expect(c.MaxIncomingUniStreams).To(BeZero())
  88. })
  89. })
  90. Context("with mock session", func() {
  91. var (
  92. serv *server
  93. firstPacket *receivedPacket
  94. connID = protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6}
  95. sessions = make([]*MockQuicSession, 0)
  96. sessionHandler *MockPacketHandlerManager
  97. )
  98. BeforeEach(func() {
  99. sessionHandler = NewMockPacketHandlerManager(mockCtrl)
  100. newMockSession := func(
  101. _ connection,
  102. runner sessionRunner,
  103. _ protocol.VersionNumber,
  104. connID protocol.ConnectionID,
  105. _ protocol.ConnectionID,
  106. _ *handshake.ServerConfig,
  107. _ *tls.Config,
  108. _ *Config,
  109. _ utils.Logger,
  110. ) (quicSession, error) {
  111. ExpectWithOffset(0, sessions).ToNot(BeEmpty())
  112. s := &mockSession{MockQuicSession: sessions[0]}
  113. s.connID = connID
  114. s.runner = runner
  115. sessions = sessions[1:]
  116. return s, nil
  117. }
  118. serv = &server{
  119. sessionHandler: sessionHandler,
  120. newSession: newMockSession,
  121. conn: conn,
  122. config: config,
  123. sessionQueue: make(chan Session, 5),
  124. errorChan: make(chan struct{}),
  125. logger: utils.DefaultLogger,
  126. }
  127. serv.setup()
  128. b := &bytes.Buffer{}
  129. utils.BigEndian.WriteUint32(b, uint32(protocol.SupportedVersions[0]))
  130. firstPacket = &receivedPacket{
  131. header: &wire.Header{
  132. VersionFlag: true,
  133. Version: serv.config.Versions[0],
  134. DestConnectionID: protocol.ConnectionID{0x4c, 0xfa, 0x9f, 0x9b, 0x66, 0x86, 0x19, 0xf6},
  135. PacketNumber: 1,
  136. },
  137. data: bytes.Repeat([]byte{0}, protocol.MinClientHelloSize),
  138. rcvTime: time.Now(),
  139. }
  140. })
  141. AfterEach(func() {
  142. Expect(sessions).To(BeEmpty())
  143. })
  144. It("returns the address", func() {
  145. conn.addr = &net.UDPAddr{
  146. IP: net.IPv4(192, 168, 13, 37),
  147. Port: 1234,
  148. }
  149. Expect(serv.Addr().String()).To(Equal("192.168.13.37:1234"))
  150. })
  151. It("creates new sessions", func() {
  152. s := NewMockQuicSession(mockCtrl)
  153. s.EXPECT().handlePacket(gomock.Any())
  154. run := make(chan struct{})
  155. s.EXPECT().run().Do(func() { close(run) })
  156. sessions = append(sessions, s)
  157. sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(cid protocol.ConnectionID, _ packetHandler) {
  158. Expect(cid).To(Equal(connID))
  159. })
  160. Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
  161. Eventually(run).Should(BeClosed())
  162. })
  163. It("accepts new TLS sessions", func() {
  164. connID := protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
  165. sess := NewMockQuicSession(mockCtrl)
  166. err := serv.setupTLS()
  167. Expect(err).ToNot(HaveOccurred())
  168. added := make(chan struct{})
  169. sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, ph packetHandler) {
  170. Expect(ph.GetPerspective()).To(Equal(protocol.PerspectiveServer))
  171. close(added)
  172. })
  173. serv.serverTLS.sessionChan <- tlsSession{
  174. connID: connID,
  175. sess: sess,
  176. }
  177. Eventually(added).Should(BeClosed())
  178. })
  179. It("accepts a session once the connection it is forward secure", func() {
  180. s := NewMockQuicSession(mockCtrl)
  181. s.EXPECT().handlePacket(gomock.Any())
  182. run := make(chan struct{})
  183. s.EXPECT().run().Do(func() { close(run) })
  184. sessions = append(sessions, s)
  185. done := make(chan struct{})
  186. go func() {
  187. defer GinkgoRecover()
  188. _, err := serv.Accept()
  189. Expect(err).ToNot(HaveOccurred())
  190. close(done)
  191. }()
  192. sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(_ protocol.ConnectionID, sess packetHandler) {
  193. Consistently(done).ShouldNot(BeClosed())
  194. sess.(*serverSession).quicSession.(*mockSession).runner.onHandshakeComplete(sess.(Session))
  195. })
  196. err := serv.handlePacketImpl(firstPacket)
  197. Expect(err).ToNot(HaveOccurred())
  198. Eventually(done).Should(BeClosed())
  199. Eventually(run).Should(BeClosed())
  200. })
  201. It("doesn't accept sessions that error during the handshake", func() {
  202. run := make(chan error, 1)
  203. sess := NewMockQuicSession(mockCtrl)
  204. sess.EXPECT().handlePacket(gomock.Any())
  205. sess.EXPECT().run().DoAndReturn(func() error { return <-run })
  206. sessions = append(sessions, sess)
  207. done := make(chan struct{})
  208. go func() {
  209. defer GinkgoRecover()
  210. serv.Accept()
  211. close(done)
  212. }()
  213. sessionHandler.EXPECT().Add(connID, gomock.Any()).Do(func(protocol.ConnectionID, packetHandler) {
  214. run <- errors.New("handshake error")
  215. })
  216. Expect(serv.handlePacketImpl(firstPacket)).To(Succeed())
  217. Consistently(done).ShouldNot(BeClosed())
  218. // make the go routine return
  219. close(serv.errorChan)
  220. Eventually(done).Should(BeClosed())
  221. })
  222. It("closes the sessionHandler when Close is called", func() {
  223. sessionHandler.EXPECT().CloseServer()
  224. Expect(serv.Close()).To(Succeed())
  225. })
  226. It("closes twice", func() {
  227. sessionHandler.EXPECT().CloseServer()
  228. Expect(serv.Close()).To(Succeed())
  229. Expect(serv.Close()).To(Succeed())
  230. })
  231. It("works if no quic.Config is given", func(done Done) {
  232. ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), nil)
  233. Expect(err).ToNot(HaveOccurred())
  234. Expect(ln.Close()).To(Succeed())
  235. close(done)
  236. }, 1)
  237. It("closes properly", func() {
  238. ln, err := ListenAddr("127.0.0.1:0", testdata.GetTLSConfig(), config)
  239. Expect(err).ToNot(HaveOccurred())
  240. done := make(chan struct{})
  241. go func() {
  242. defer GinkgoRecover()
  243. ln.Accept()
  244. close(done)
  245. }()
  246. ln.Close()
  247. Eventually(done).Should(BeClosed())
  248. })
  249. It("closes the connection when it was created with ListenAddr", func() {
  250. addr, err := net.ResolveUDPAddr("udp", "localhost:12345")
  251. Expect(err).ToNot(HaveOccurred())
  252. serv, err := ListenAddr("localhost:0", nil, nil)
  253. Expect(err).ToNot(HaveOccurred())
  254. // test that we can write on the packet conn
  255. _, err = serv.(*server).conn.WriteTo([]byte("foobar"), addr)
  256. Expect(err).ToNot(HaveOccurred())
  257. Expect(serv.Close()).To(Succeed())
  258. // test that we can't write any more on the packet conn
  259. _, err = serv.(*server).conn.WriteTo([]byte("foobar"), addr)
  260. Expect(err.Error()).To(ContainSubstring("use of closed network connection"))
  261. })
  262. It("returns Accept when it is closed", func() {
  263. done := make(chan struct{})
  264. go func() {
  265. defer GinkgoRecover()
  266. _, err := serv.Accept()
  267. Expect(err).To(MatchError("server closed"))
  268. close(done)
  269. }()
  270. sessionHandler.EXPECT().CloseServer()
  271. Expect(serv.Close()).To(Succeed())
  272. Eventually(done).Should(BeClosed())
  273. })
  274. It("returns Accept with the right error when closeWithError is called", func() {
  275. testErr := errors.New("connection error")
  276. done := make(chan struct{})
  277. go func() {
  278. defer GinkgoRecover()
  279. _, err := serv.Accept()
  280. Expect(err).To(MatchError(testErr))
  281. close(done)
  282. }()
  283. sessionHandler.EXPECT().CloseServer()
  284. serv.closeWithError(testErr)
  285. Eventually(done).Should(BeClosed())
  286. })
  287. It("doesn't try to process a packet after sending a gQUIC Version Negotiation Packet", func() {
  288. config.Versions = []protocol.VersionNumber{99}
  289. p := &receivedPacket{
  290. header: &wire.Header{
  291. VersionFlag: true,
  292. DestConnectionID: connID,
  293. PacketNumber: 1,
  294. PacketNumberLen: protocol.PacketNumberLen2,
  295. },
  296. data: make([]byte, protocol.MinClientHelloSize),
  297. }
  298. Expect(serv.handlePacketImpl(p)).To(Succeed())
  299. Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
  300. })
  301. It("sends a PUBLIC_RESET for new connections that don't have the VersionFlag set", func() {
  302. err := serv.handlePacketImpl(&receivedPacket{
  303. remoteAddr: udpAddr,
  304. header: &wire.Header{
  305. IsPublicHeader: true,
  306. Version: versionGQUICFrames,
  307. },
  308. })
  309. Expect(err).ToNot(HaveOccurred())
  310. Expect(conn.dataWritten.Len()).ToNot(BeZero())
  311. Expect(conn.dataWrittenTo).To(Equal(udpAddr))
  312. Expect(conn.dataWritten.Bytes()[0] & 0x02).ToNot(BeZero()) // check that the ResetFlag is set
  313. })
  314. It("sends a gQUIC Version Negotaion Packet, if the client sent a gQUIC Public Header", func() {
  315. connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
  316. err := serv.handlePacketImpl(&receivedPacket{
  317. remoteAddr: udpAddr,
  318. header: &wire.Header{
  319. IsPublicHeader: true,
  320. VersionFlag: true,
  321. DestConnectionID: connID,
  322. PacketNumber: 1,
  323. PacketNumberLen: protocol.PacketNumberLen2,
  324. Version: protocol.Version39 - 1,
  325. },
  326. })
  327. Expect(err).ToNot(HaveOccurred())
  328. Expect(conn.dataWritten.Len()).ToNot(BeZero())
  329. Expect(conn.dataWrittenTo).To(Equal(udpAddr))
  330. r := bytes.NewReader(conn.dataWritten.Bytes())
  331. iHdr, err := wire.ParseInvariantHeader(r, 0)
  332. Expect(err).ToNot(HaveOccurred())
  333. Expect(iHdr.IsLongHeader).To(BeFalse())
  334. replyHdr, err := iHdr.Parse(r, protocol.PerspectiveServer, versionIETFFrames)
  335. Expect(err).ToNot(HaveOccurred())
  336. Expect(replyHdr.IsVersionNegotiation).To(BeTrue())
  337. Expect(replyHdr.DestConnectionID).To(Equal(connID))
  338. Expect(r.Len()).To(BeZero())
  339. })
  340. It("sends an IETF draft style Version Negotaion Packet, if the client sent a IETF draft style header", func() {
  341. connID := protocol.ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
  342. err := serv.handlePacketImpl(&receivedPacket{
  343. remoteAddr: udpAddr,
  344. header: &wire.Header{
  345. Type: protocol.PacketTypeInitial,
  346. IsLongHeader: true,
  347. DestConnectionID: connID,
  348. SrcConnectionID: connID,
  349. PacketNumber: 0x55,
  350. PacketNumberLen: protocol.PacketNumberLen1,
  351. Version: 0x1234,
  352. PayloadLen: protocol.MinInitialPacketSize,
  353. },
  354. })
  355. Expect(err).ToNot(HaveOccurred())
  356. Expect(conn.dataWritten.Len()).ToNot(BeZero())
  357. Expect(conn.dataWrittenTo).To(Equal(udpAddr))
  358. r := bytes.NewReader(conn.dataWritten.Bytes())
  359. iHdr, err := wire.ParseInvariantHeader(r, 0)
  360. Expect(err).ToNot(HaveOccurred())
  361. replyHdr, err := iHdr.Parse(r, protocol.PerspectiveServer, versionIETFFrames)
  362. Expect(err).ToNot(HaveOccurred())
  363. Expect(replyHdr.IsVersionNegotiation).To(BeTrue())
  364. Expect(replyHdr.DestConnectionID).To(Equal(connID))
  365. Expect(replyHdr.SrcConnectionID).To(Equal(connID))
  366. Expect(r.Len()).To(BeZero())
  367. })
  368. })
  369. It("setups with the right values", func() {
  370. supportedVersions := []protocol.VersionNumber{protocol.VersionTLS, protocol.Version39}
  371. acceptCookie := func(_ net.Addr, _ *Cookie) bool { return true }
  372. config := Config{
  373. Versions: supportedVersions,
  374. AcceptCookie: acceptCookie,
  375. HandshakeTimeout: 1337 * time.Hour,
  376. IdleTimeout: 42 * time.Minute,
  377. KeepAlive: true,
  378. }
  379. ln, err := Listen(conn, &tls.Config{}, &config)
  380. Expect(err).ToNot(HaveOccurred())
  381. server := ln.(*server)
  382. Expect(server.sessionHandler).ToNot(BeNil())
  383. Expect(server.scfg).ToNot(BeNil())
  384. Expect(server.config.Versions).To(Equal(supportedVersions))
  385. Expect(server.config.HandshakeTimeout).To(Equal(1337 * time.Hour))
  386. Expect(server.config.IdleTimeout).To(Equal(42 * time.Minute))
  387. Expect(reflect.ValueOf(server.config.AcceptCookie)).To(Equal(reflect.ValueOf(acceptCookie)))
  388. Expect(server.config.KeepAlive).To(BeTrue())
  389. })
  390. It("errors when the Config contains an invalid version", func() {
  391. version := protocol.VersionNumber(0x1234)
  392. _, err := Listen(conn, &tls.Config{}, &Config{Versions: []protocol.VersionNumber{version}})
  393. Expect(err).To(MatchError("0x1234 is not a valid QUIC version"))
  394. })
  395. It("fills in default values if options are not set in the Config", func() {
  396. ln, err := Listen(conn, &tls.Config{}, &Config{})
  397. Expect(err).ToNot(HaveOccurred())
  398. server := ln.(*server)
  399. Expect(server.config.Versions).To(Equal(protocol.SupportedVersions))
  400. Expect(server.config.HandshakeTimeout).To(Equal(protocol.DefaultHandshakeTimeout))
  401. Expect(server.config.IdleTimeout).To(Equal(protocol.DefaultIdleTimeout))
  402. Expect(reflect.ValueOf(server.config.AcceptCookie)).To(Equal(reflect.ValueOf(defaultAcceptCookie)))
  403. Expect(server.config.KeepAlive).To(BeFalse())
  404. })
  405. It("listens on a given address", func() {
  406. addr := "127.0.0.1:13579"
  407. ln, err := ListenAddr(addr, nil, config)
  408. Expect(err).ToNot(HaveOccurred())
  409. serv := ln.(*server)
  410. Expect(serv.Addr().String()).To(Equal(addr))
  411. })
  412. It("errors if given an invalid address", func() {
  413. addr := "127.0.0.1"
  414. _, err := ListenAddr(addr, nil, config)
  415. Expect(err).To(BeAssignableToTypeOf(&net.AddrError{}))
  416. })
  417. It("errors if given an invalid address", func() {
  418. addr := "1.1.1.1:1111"
  419. _, err := ListenAddr(addr, nil, config)
  420. Expect(err).To(BeAssignableToTypeOf(&net.OpError{}))
  421. })
  422. })
  423. var _ = Describe("default source address verification", func() {
  424. It("accepts a token", func() {
  425. remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
  426. cookie := &Cookie{
  427. RemoteAddr: "192.168.0.1",
  428. SentTime: time.Now().Add(-protocol.CookieExpiryTime).Add(time.Second), // will expire in 1 second
  429. }
  430. Expect(defaultAcceptCookie(remoteAddr, cookie)).To(BeTrue())
  431. })
  432. It("requests verification if no token is provided", func() {
  433. remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
  434. Expect(defaultAcceptCookie(remoteAddr, nil)).To(BeFalse())
  435. })
  436. It("rejects a token if the address doesn't match", func() {
  437. remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
  438. cookie := &Cookie{
  439. RemoteAddr: "127.0.0.1",
  440. SentTime: time.Now(),
  441. }
  442. Expect(defaultAcceptCookie(remoteAddr, cookie)).To(BeFalse())
  443. })
  444. It("accepts a token for a remote address is not a UDP address", func() {
  445. remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
  446. cookie := &Cookie{
  447. RemoteAddr: "192.168.0.1:1337",
  448. SentTime: time.Now(),
  449. }
  450. Expect(defaultAcceptCookie(remoteAddr, cookie)).To(BeTrue())
  451. })
  452. It("rejects an invalid token for a remote address is not a UDP address", func() {
  453. remoteAddr := &net.TCPAddr{IP: net.IPv4(192, 168, 0, 1), Port: 1337}
  454. cookie := &Cookie{
  455. RemoteAddr: "192.168.0.1:7331", // mismatching port
  456. SentTime: time.Now(),
  457. }
  458. Expect(defaultAcceptCookie(remoteAddr, cookie)).To(BeFalse())
  459. })
  460. It("rejects an expired token", func() {
  461. remoteAddr := &net.UDPAddr{IP: net.IPv4(192, 168, 0, 1)}
  462. cookie := &Cookie{
  463. RemoteAddr: "192.168.0.1",
  464. SentTime: time.Now().Add(-protocol.CookieExpiryTime).Add(-time.Second), // expired 1 second ago
  465. }
  466. Expect(defaultAcceptCookie(remoteAddr, cookie)).To(BeFalse())
  467. })
  468. })