client.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. package quic
  2. import (
  3. "context"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "net"
  8. "sync"
  9. "github.com/lucas-clemente/quic-go/internal/handshake"
  10. "github.com/lucas-clemente/quic-go/internal/protocol"
  11. "github.com/lucas-clemente/quic-go/internal/qerr"
  12. "github.com/lucas-clemente/quic-go/internal/utils"
  13. "github.com/lucas-clemente/quic-go/internal/wire"
  14. )
  15. type client struct {
  16. mutex sync.Mutex
  17. conn connection
  18. // If the client is created with DialAddr, we create a packet conn.
  19. // If it is started with Dial, we take a packet conn as a parameter.
  20. createdPacketConn bool
  21. packetHandlers packetHandlerManager
  22. token []byte
  23. versionNegotiated bool // has the server accepted our version
  24. receivedVersionNegotiationPacket bool
  25. negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
  26. tlsConf *tls.Config
  27. config *Config
  28. srcConnID protocol.ConnectionID
  29. destConnID protocol.ConnectionID
  30. origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry)
  31. initialVersion protocol.VersionNumber
  32. version protocol.VersionNumber
  33. handshakeChan chan struct{}
  34. session quicSession
  35. logger utils.Logger
  36. }
  37. var _ packetHandler = &client{}
  38. var (
  39. // make it possible to mock connection ID generation in the tests
  40. generateConnectionID = protocol.GenerateConnectionID
  41. generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
  42. errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
  43. errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
  44. )
  45. // DialAddr establishes a new QUIC connection to a server.
  46. // The hostname for SNI is taken from the given address.
  47. func DialAddr(
  48. addr string,
  49. tlsConf *tls.Config,
  50. config *Config,
  51. ) (Session, error) {
  52. return DialAddrContext(context.Background(), addr, tlsConf, config)
  53. }
  54. // DialAddrContext establishes a new QUIC connection to a server using the provided context.
  55. // The hostname for SNI is taken from the given address.
  56. func DialAddrContext(
  57. ctx context.Context,
  58. addr string,
  59. tlsConf *tls.Config,
  60. config *Config,
  61. ) (Session, error) {
  62. udpAddr, err := net.ResolveUDPAddr("udp", addr)
  63. if err != nil {
  64. return nil, err
  65. }
  66. udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
  67. if err != nil {
  68. return nil, err
  69. }
  70. return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
  71. }
  72. // Dial establishes a new QUIC connection to a server using a net.PacketConn.
  73. // The host parameter is used for SNI.
  74. func Dial(
  75. pconn net.PacketConn,
  76. remoteAddr net.Addr,
  77. host string,
  78. tlsConf *tls.Config,
  79. config *Config,
  80. ) (Session, error) {
  81. return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
  82. }
  83. // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
  84. // The host parameter is used for SNI.
  85. func DialContext(
  86. ctx context.Context,
  87. pconn net.PacketConn,
  88. remoteAddr net.Addr,
  89. host string,
  90. tlsConf *tls.Config,
  91. config *Config,
  92. ) (Session, error) {
  93. return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
  94. }
  95. func dialContext(
  96. ctx context.Context,
  97. pconn net.PacketConn,
  98. remoteAddr net.Addr,
  99. host string,
  100. tlsConf *tls.Config,
  101. config *Config,
  102. createdPacketConn bool,
  103. ) (Session, error) {
  104. config = populateClientConfig(config, createdPacketConn)
  105. packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
  106. if err != nil {
  107. return nil, err
  108. }
  109. c, err := newClient(pconn, remoteAddr, config, tlsConf, host, createdPacketConn)
  110. if err != nil {
  111. return nil, err
  112. }
  113. c.packetHandlers = packetHandlers
  114. if err := c.dial(ctx); err != nil {
  115. return nil, err
  116. }
  117. return c.session, nil
  118. }
  119. func newClient(
  120. pconn net.PacketConn,
  121. remoteAddr net.Addr,
  122. config *Config,
  123. tlsConf *tls.Config,
  124. host string,
  125. createdPacketConn bool,
  126. ) (*client, error) {
  127. if tlsConf == nil {
  128. tlsConf = &tls.Config{}
  129. }
  130. if tlsConf.ServerName == "" {
  131. var err error
  132. tlsConf.ServerName, _, err = net.SplitHostPort(host)
  133. if err != nil {
  134. return nil, err
  135. }
  136. }
  137. // check that all versions are actually supported
  138. if config != nil {
  139. for _, v := range config.Versions {
  140. if !protocol.IsValidVersion(v) {
  141. return nil, fmt.Errorf("%s is not a valid QUIC version", v)
  142. }
  143. }
  144. }
  145. c := &client{
  146. conn: &conn{pconn: pconn, currentAddr: remoteAddr},
  147. createdPacketConn: createdPacketConn,
  148. tlsConf: tlsConf,
  149. config: config,
  150. version: config.Versions[0],
  151. handshakeChan: make(chan struct{}),
  152. logger: utils.DefaultLogger.WithPrefix("client"),
  153. }
  154. return c, c.generateConnectionIDs()
  155. }
  156. // populateClientConfig populates fields in the quic.Config with their default values, if none are set
  157. // it may be called with nil
  158. func populateClientConfig(config *Config, createdPacketConn bool) *Config {
  159. if config == nil {
  160. config = &Config{}
  161. }
  162. versions := config.Versions
  163. if len(versions) == 0 {
  164. versions = protocol.SupportedVersions
  165. }
  166. handshakeTimeout := protocol.DefaultHandshakeTimeout
  167. if config.HandshakeTimeout != 0 {
  168. handshakeTimeout = config.HandshakeTimeout
  169. }
  170. idleTimeout := protocol.DefaultIdleTimeout
  171. if config.IdleTimeout != 0 {
  172. idleTimeout = config.IdleTimeout
  173. }
  174. maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
  175. if maxReceiveStreamFlowControlWindow == 0 {
  176. maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
  177. }
  178. maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
  179. if maxReceiveConnectionFlowControlWindow == 0 {
  180. maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
  181. }
  182. maxIncomingStreams := config.MaxIncomingStreams
  183. if maxIncomingStreams == 0 {
  184. maxIncomingStreams = protocol.DefaultMaxIncomingStreams
  185. } else if maxIncomingStreams < 0 {
  186. maxIncomingStreams = 0
  187. }
  188. maxIncomingUniStreams := config.MaxIncomingUniStreams
  189. if maxIncomingUniStreams == 0 {
  190. maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
  191. } else if maxIncomingUniStreams < 0 {
  192. maxIncomingUniStreams = 0
  193. }
  194. connIDLen := config.ConnectionIDLength
  195. if connIDLen == 0 && !createdPacketConn {
  196. connIDLen = protocol.DefaultConnectionIDLength
  197. }
  198. return &Config{
  199. Versions: versions,
  200. HandshakeTimeout: handshakeTimeout,
  201. IdleTimeout: idleTimeout,
  202. ConnectionIDLength: connIDLen,
  203. MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
  204. MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
  205. MaxIncomingStreams: maxIncomingStreams,
  206. MaxIncomingUniStreams: maxIncomingUniStreams,
  207. KeepAlive: config.KeepAlive,
  208. }
  209. }
  210. func (c *client) generateConnectionIDs() error {
  211. srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
  212. if err != nil {
  213. return err
  214. }
  215. destConnID, err := generateConnectionIDForInitial()
  216. if err != nil {
  217. return err
  218. }
  219. c.srcConnID = srcConnID
  220. c.destConnID = destConnID
  221. return nil
  222. }
  223. func (c *client) dial(ctx context.Context) error {
  224. c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
  225. if err := c.createNewTLSSession(c.version); err != nil {
  226. return err
  227. }
  228. err := c.establishSecureConnection(ctx)
  229. if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
  230. return c.dial(ctx)
  231. }
  232. return err
  233. }
  234. // establishSecureConnection runs the session, and tries to establish a secure connection
  235. // It returns:
  236. // - errCloseSessionForNewVersion when the server sends a version negotiation packet
  237. // - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
  238. // - any other error that might occur
  239. // - when the connection is forward-secure
  240. func (c *client) establishSecureConnection(ctx context.Context) error {
  241. errorChan := make(chan error, 1)
  242. go func() {
  243. err := c.session.run() // returns as soon as the session is closed
  244. if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
  245. c.conn.Close()
  246. }
  247. errorChan <- err
  248. }()
  249. select {
  250. case <-ctx.Done():
  251. // The session will send a PeerGoingAway error to the server.
  252. c.session.Close()
  253. return ctx.Err()
  254. case err := <-errorChan:
  255. return err
  256. case <-c.handshakeChan:
  257. // handshake successfully completed
  258. return nil
  259. }
  260. }
  261. func (c *client) handlePacket(p *receivedPacket) {
  262. if err := c.handlePacketImpl(p); err != nil {
  263. c.logger.Errorf("error handling packet: %s", err)
  264. }
  265. }
  266. func (c *client) handlePacketImpl(p *receivedPacket) error {
  267. c.mutex.Lock()
  268. defer c.mutex.Unlock()
  269. // handle Version Negotiation Packets
  270. if p.header.IsVersionNegotiation {
  271. err := c.handleVersionNegotiationPacket(p.header)
  272. if err != nil {
  273. c.session.destroy(err)
  274. }
  275. // version negotiation packets have no payload
  276. return err
  277. }
  278. // reject packets with the wrong connection ID
  279. if !p.header.DestConnectionID.Equal(c.srcConnID) {
  280. return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
  281. }
  282. if p.header.Type == protocol.PacketTypeRetry {
  283. c.handleRetryPacket(p.header)
  284. return nil
  285. }
  286. // this is the first packet we are receiving
  287. // since it is not a Version Negotiation Packet, this means the server supports the suggested version
  288. if !c.versionNegotiated {
  289. c.versionNegotiated = true
  290. }
  291. c.session.handlePacket(p)
  292. return nil
  293. }
  294. func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
  295. // ignore delayed / duplicated version negotiation packets
  296. if c.receivedVersionNegotiationPacket || c.versionNegotiated {
  297. c.logger.Debugf("Received a delayed Version Negotiation Packet.")
  298. return nil
  299. }
  300. for _, v := range hdr.SupportedVersions {
  301. if v == c.version {
  302. // the version negotiation packet contains the version that we offered
  303. // this might be a packet sent by an attacker (or by a terribly broken server implementation)
  304. // ignore it
  305. return nil
  306. }
  307. }
  308. c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
  309. newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
  310. if !ok {
  311. return qerr.InvalidVersion
  312. }
  313. c.receivedVersionNegotiationPacket = true
  314. c.negotiatedVersions = hdr.SupportedVersions
  315. // switch to negotiated version
  316. c.initialVersion = c.version
  317. c.version = newVersion
  318. if err := c.generateConnectionIDs(); err != nil {
  319. return err
  320. }
  321. c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
  322. c.session.destroy(errCloseSessionForNewVersion)
  323. return nil
  324. }
  325. func (c *client) handleRetryPacket(hdr *wire.Header) {
  326. c.logger.Debugf("<- Received Retry")
  327. hdr.Log(c.logger)
  328. if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
  329. c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
  330. return
  331. }
  332. if hdr.SrcConnectionID.Equal(c.destConnID) {
  333. c.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
  334. return
  335. }
  336. // If a token is already set, this means that we already received a Retry from the server.
  337. // Ignore this Retry packet.
  338. if len(c.token) > 0 {
  339. c.logger.Debugf("Ignoring Retry, since a Retry was already received.")
  340. return
  341. }
  342. c.origDestConnID = c.destConnID
  343. c.destConnID = hdr.SrcConnectionID
  344. c.token = hdr.Token
  345. c.session.destroy(errCloseSessionForRetry)
  346. }
  347. func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
  348. params := &handshake.TransportParameters{
  349. InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
  350. InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
  351. InitialMaxStreamDataUni: protocol.InitialMaxStreamData,
  352. InitialMaxData: protocol.InitialMaxData,
  353. IdleTimeout: c.config.IdleTimeout,
  354. MaxBidiStreams: uint64(c.config.MaxIncomingStreams),
  355. MaxUniStreams: uint64(c.config.MaxIncomingUniStreams),
  356. DisableMigration: true,
  357. }
  358. c.mutex.Lock()
  359. defer c.mutex.Unlock()
  360. runner := &runner{
  361. onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
  362. retireConnectionIDImpl: c.packetHandlers.Retire,
  363. removeConnectionIDImpl: c.packetHandlers.Remove,
  364. }
  365. sess, err := newClientSession(
  366. c.conn,
  367. runner,
  368. c.token,
  369. c.origDestConnID,
  370. c.destConnID,
  371. c.srcConnID,
  372. c.config,
  373. c.tlsConf,
  374. params,
  375. c.initialVersion,
  376. c.logger,
  377. c.version,
  378. )
  379. if err != nil {
  380. return err
  381. }
  382. c.session = sess
  383. c.packetHandlers.Add(c.srcConnID, c)
  384. return nil
  385. }
  386. func (c *client) Close() error {
  387. c.mutex.Lock()
  388. defer c.mutex.Unlock()
  389. if c.session == nil {
  390. return nil
  391. }
  392. return c.session.Close()
  393. }
  394. func (c *client) destroy(e error) {
  395. c.mutex.Lock()
  396. defer c.mutex.Unlock()
  397. if c.session == nil {
  398. return
  399. }
  400. c.session.destroy(e)
  401. }
  402. func (c *client) GetVersion() protocol.VersionNumber {
  403. c.mutex.Lock()
  404. v := c.version
  405. c.mutex.Unlock()
  406. return v
  407. }
  408. func (c *client) GetPerspective() protocol.Perspective {
  409. return protocol.PerspectiveClient
  410. }