client.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. package quic
import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"strings"
	"sync"

	"github.com/lucas-clemente/quic-go/internal/handshake"
	"github.com/lucas-clemente/quic-go/internal/protocol"
	"github.com/lucas-clemente/quic-go/internal/qerr"
	"github.com/lucas-clemente/quic-go/internal/utils"
	"github.com/lucas-clemente/quic-go/internal/wire"
)
  15. type client struct {
  16. mutex sync.Mutex
  17. conn connection
  18. // If the client is created with DialAddr, we create a packet conn.
  19. // If it is started with Dial, we take a packet conn as a parameter.
  20. createdPacketConn bool
  21. packetHandlers packetHandlerManager
  22. token []byte
  23. versionNegotiated utils.AtomicBool // has the server accepted our version
  24. receivedVersionNegotiationPacket bool
  25. negotiatedVersions []protocol.VersionNumber // the list of versions from the version negotiation packet
  26. tlsConf *tls.Config
  27. config *Config
  28. srcConnID protocol.ConnectionID
  29. destConnID protocol.ConnectionID
  30. origDestConnID protocol.ConnectionID // the destination conn ID used on the first Initial (before a Retry)
  31. initialVersion protocol.VersionNumber
  32. version protocol.VersionNumber
  33. handshakeChan chan struct{}
  34. session quicSession
  35. logger utils.Logger
  36. }
  37. var _ packetHandler = &client{}
  38. var (
  39. // make it possible to mock connection ID generation in the tests
  40. generateConnectionID = protocol.GenerateConnectionID
  41. generateConnectionIDForInitial = protocol.GenerateConnectionIDForInitial
  42. errCloseSessionForNewVersion = errors.New("closing session in order to recreate it with a new version")
  43. errCloseSessionForRetry = errors.New("closing session in response to a stateless retry")
  44. )
  45. // DialAddr establishes a new QUIC connection to a server.
  46. // It uses a new UDP connection and closes this connection when the QUIC session is closed.
  47. // The hostname for SNI is taken from the given address.
  48. func DialAddr(
  49. addr string,
  50. tlsConf *tls.Config,
  51. config *Config,
  52. ) (Session, error) {
  53. return DialAddrContext(context.Background(), addr, tlsConf, config)
  54. }
  55. // DialAddrContext establishes a new QUIC connection to a server using the provided context.
  56. // See DialAddr for details.
  57. func DialAddrContext(
  58. ctx context.Context,
  59. addr string,
  60. tlsConf *tls.Config,
  61. config *Config,
  62. ) (Session, error) {
  63. udpAddr, err := net.ResolveUDPAddr("udp", addr)
  64. if err != nil {
  65. return nil, err
  66. }
  67. udpConn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.IPv4zero, Port: 0})
  68. if err != nil {
  69. return nil, err
  70. }
  71. return dialContext(ctx, udpConn, udpAddr, addr, tlsConf, config, true)
  72. }
  73. // Dial establishes a new QUIC connection to a server using a net.PacketConn.
  74. // The same PacketConn can be used for multiple calls to Dial and Listen,
  75. // QUIC connection IDs are used for demultiplexing the different connections.
  76. // The host parameter is used for SNI.
  77. func Dial(
  78. pconn net.PacketConn,
  79. remoteAddr net.Addr,
  80. host string,
  81. tlsConf *tls.Config,
  82. config *Config,
  83. ) (Session, error) {
  84. return DialContext(context.Background(), pconn, remoteAddr, host, tlsConf, config)
  85. }
  86. // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
  87. // See Dial for details.
  88. func DialContext(
  89. ctx context.Context,
  90. pconn net.PacketConn,
  91. remoteAddr net.Addr,
  92. host string,
  93. tlsConf *tls.Config,
  94. config *Config,
  95. ) (Session, error) {
  96. return dialContext(ctx, pconn, remoteAddr, host, tlsConf, config, false)
  97. }
  98. func dialContext(
  99. ctx context.Context,
  100. pconn net.PacketConn,
  101. remoteAddr net.Addr,
  102. host string,
  103. tlsConf *tls.Config,
  104. config *Config,
  105. createdPacketConn bool,
  106. ) (Session, error) {
  107. config = populateClientConfig(config, createdPacketConn)
  108. packetHandlers, err := getMultiplexer().AddConn(pconn, config.ConnectionIDLength)
  109. if err != nil {
  110. return nil, err
  111. }
  112. c, err := newClient(pconn, remoteAddr, config, tlsConf, host, createdPacketConn)
  113. if err != nil {
  114. return nil, err
  115. }
  116. c.packetHandlers = packetHandlers
  117. if err := c.dial(ctx); err != nil {
  118. return nil, err
  119. }
  120. return c.session, nil
  121. }
  122. func newClient(
  123. pconn net.PacketConn,
  124. remoteAddr net.Addr,
  125. config *Config,
  126. tlsConf *tls.Config,
  127. host string,
  128. createdPacketConn bool,
  129. ) (*client, error) {
  130. if tlsConf == nil {
  131. tlsConf = &tls.Config{}
  132. }
  133. if tlsConf.ServerName == "" {
  134. var err error
  135. tlsConf.ServerName, _, err = net.SplitHostPort(host)
  136. if err != nil {
  137. return nil, err
  138. }
  139. }
  140. // check that all versions are actually supported
  141. if config != nil {
  142. for _, v := range config.Versions {
  143. if !protocol.IsValidVersion(v) {
  144. return nil, fmt.Errorf("%s is not a valid QUIC version", v)
  145. }
  146. }
  147. }
  148. srcConnID, err := generateConnectionID(config.ConnectionIDLength)
  149. if err != nil {
  150. return nil, err
  151. }
  152. destConnID, err := generateConnectionIDForInitial()
  153. if err != nil {
  154. return nil, err
  155. }
  156. c := &client{
  157. srcConnID: srcConnID,
  158. destConnID: destConnID,
  159. conn: &conn{pconn: pconn, currentAddr: remoteAddr},
  160. createdPacketConn: createdPacketConn,
  161. tlsConf: tlsConf,
  162. config: config,
  163. version: config.Versions[0],
  164. handshakeChan: make(chan struct{}),
  165. logger: utils.DefaultLogger.WithPrefix("client"),
  166. }
  167. return c, nil
  168. }
  169. // populateClientConfig populates fields in the quic.Config with their default values, if none are set
  170. // it may be called with nil
  171. func populateClientConfig(config *Config, createdPacketConn bool) *Config {
  172. if config == nil {
  173. config = &Config{}
  174. }
  175. versions := config.Versions
  176. if len(versions) == 0 {
  177. versions = protocol.SupportedVersions
  178. }
  179. handshakeTimeout := protocol.DefaultHandshakeTimeout
  180. if config.HandshakeTimeout != 0 {
  181. handshakeTimeout = config.HandshakeTimeout
  182. }
  183. idleTimeout := protocol.DefaultIdleTimeout
  184. if config.IdleTimeout != 0 {
  185. idleTimeout = config.IdleTimeout
  186. }
  187. maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
  188. if maxReceiveStreamFlowControlWindow == 0 {
  189. maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
  190. }
  191. maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
  192. if maxReceiveConnectionFlowControlWindow == 0 {
  193. maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
  194. }
  195. maxIncomingStreams := config.MaxIncomingStreams
  196. if maxIncomingStreams == 0 {
  197. maxIncomingStreams = protocol.DefaultMaxIncomingStreams
  198. } else if maxIncomingStreams < 0 {
  199. maxIncomingStreams = 0
  200. }
  201. maxIncomingUniStreams := config.MaxIncomingUniStreams
  202. if maxIncomingUniStreams == 0 {
  203. maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
  204. } else if maxIncomingUniStreams < 0 {
  205. maxIncomingUniStreams = 0
  206. }
  207. connIDLen := config.ConnectionIDLength
  208. if connIDLen == 0 && !createdPacketConn {
  209. connIDLen = protocol.DefaultConnectionIDLength
  210. }
  211. return &Config{
  212. Versions: versions,
  213. HandshakeTimeout: handshakeTimeout,
  214. IdleTimeout: idleTimeout,
  215. ConnectionIDLength: connIDLen,
  216. MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
  217. MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
  218. MaxIncomingStreams: maxIncomingStreams,
  219. MaxIncomingUniStreams: maxIncomingUniStreams,
  220. KeepAlive: config.KeepAlive,
  221. }
  222. }
  223. func (c *client) dial(ctx context.Context) error {
  224. c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
  225. if err := c.createNewTLSSession(c.version); err != nil {
  226. return err
  227. }
  228. err := c.establishSecureConnection(ctx)
  229. if err == errCloseSessionForRetry || err == errCloseSessionForNewVersion {
  230. return c.dial(ctx)
  231. }
  232. return err
  233. }
  234. // establishSecureConnection runs the session, and tries to establish a secure connection
  235. // It returns:
  236. // - errCloseSessionForNewVersion when the server sends a version negotiation packet
  237. // - handshake.ErrCloseSessionForRetry when the server performs a stateless retry
  238. // - any other error that might occur
  239. // - when the connection is forward-secure
  240. func (c *client) establishSecureConnection(ctx context.Context) error {
  241. errorChan := make(chan error, 1)
  242. go func() {
  243. err := c.session.run() // returns as soon as the session is closed
  244. if err != errCloseSessionForRetry && err != errCloseSessionForNewVersion && c.createdPacketConn {
  245. c.conn.Close()
  246. }
  247. errorChan <- err
  248. }()
  249. select {
  250. case <-ctx.Done():
  251. // The session will send a PeerGoingAway error to the server.
  252. c.session.Close()
  253. return ctx.Err()
  254. case err := <-errorChan:
  255. return err
  256. case <-c.handshakeChan:
  257. // handshake successfully completed
  258. return nil
  259. }
  260. }
  261. func (c *client) handlePacket(p *receivedPacket) {
  262. if p.hdr.IsVersionNegotiation() {
  263. go c.handleVersionNegotiationPacket(p.hdr)
  264. return
  265. }
  266. if p.hdr.Type == protocol.PacketTypeRetry {
  267. go c.handleRetryPacket(p.hdr)
  268. return
  269. }
  270. // this is the first packet we are receiving
  271. // since it is not a Version Negotiation Packet, this means the server supports the suggested version
  272. if !c.versionNegotiated.Get() {
  273. c.versionNegotiated.Set(true)
  274. }
  275. c.session.handlePacket(p)
  276. }
  277. func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
  278. c.mutex.Lock()
  279. defer c.mutex.Unlock()
  280. // ignore delayed / duplicated version negotiation packets
  281. if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
  282. c.logger.Debugf("Received a delayed Version Negotiation packet.")
  283. return
  284. }
  285. for _, v := range hdr.SupportedVersions {
  286. if v == c.version {
  287. // The Version Negotiation packet contains the version that we offered.
  288. // This might be a packet sent by an attacker (or by a terribly broken server implementation).
  289. return
  290. }
  291. }
  292. c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
  293. newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
  294. if !ok {
  295. c.session.destroy(qerr.InvalidVersion)
  296. c.logger.Debugf("No compatible version found.")
  297. return
  298. }
  299. c.receivedVersionNegotiationPacket = true
  300. c.negotiatedVersions = hdr.SupportedVersions
  301. // switch to negotiated version
  302. c.initialVersion = c.version
  303. c.version = newVersion
  304. c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
  305. c.session.destroy(errCloseSessionForNewVersion)
  306. }
  307. func (c *client) handleRetryPacket(hdr *wire.Header) {
  308. c.mutex.Lock()
  309. defer c.mutex.Unlock()
  310. c.logger.Debugf("<- Received Retry")
  311. (&wire.ExtendedHeader{Header: *hdr}).Log(c.logger)
  312. if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
  313. c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
  314. return
  315. }
  316. if hdr.SrcConnectionID.Equal(c.destConnID) {
  317. c.logger.Debugf("Ignoring Retry, since the server didn't change the Source Connection ID.")
  318. return
  319. }
  320. // If a token is already set, this means that we already received a Retry from the server.
  321. // Ignore this Retry packet.
  322. if len(c.token) > 0 {
  323. c.logger.Debugf("Ignoring Retry, since a Retry was already received.")
  324. return
  325. }
  326. c.origDestConnID = c.destConnID
  327. c.destConnID = hdr.SrcConnectionID
  328. c.token = hdr.Token
  329. c.session.destroy(errCloseSessionForRetry)
  330. }
  331. func (c *client) createNewTLSSession(version protocol.VersionNumber) error {
  332. params := &handshake.TransportParameters{
  333. InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
  334. InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
  335. InitialMaxStreamDataUni: protocol.InitialMaxStreamData,
  336. InitialMaxData: protocol.InitialMaxData,
  337. IdleTimeout: c.config.IdleTimeout,
  338. MaxBidiStreams: uint64(c.config.MaxIncomingStreams),
  339. MaxUniStreams: uint64(c.config.MaxIncomingUniStreams),
  340. DisableMigration: true,
  341. }
  342. c.mutex.Lock()
  343. defer c.mutex.Unlock()
  344. runner := &runner{
  345. onHandshakeCompleteImpl: func(_ Session) { close(c.handshakeChan) },
  346. retireConnectionIDImpl: c.packetHandlers.Retire,
  347. removeConnectionIDImpl: c.packetHandlers.Remove,
  348. }
  349. sess, err := newClientSession(
  350. c.conn,
  351. runner,
  352. c.token,
  353. c.origDestConnID,
  354. c.destConnID,
  355. c.srcConnID,
  356. c.config,
  357. c.tlsConf,
  358. params,
  359. c.initialVersion,
  360. c.logger,
  361. c.version,
  362. )
  363. if err != nil {
  364. return err
  365. }
  366. c.session = sess
  367. c.packetHandlers.Add(c.srcConnID, c)
  368. return nil
  369. }
  370. func (c *client) Close() error {
  371. c.mutex.Lock()
  372. defer c.mutex.Unlock()
  373. if c.session == nil {
  374. return nil
  375. }
  376. return c.session.Close()
  377. }
  378. func (c *client) destroy(e error) {
  379. c.mutex.Lock()
  380. defer c.mutex.Unlock()
  381. if c.session == nil {
  382. return
  383. }
  384. c.session.destroy(e)
  385. }
  386. func (c *client) GetVersion() protocol.VersionNumber {
  387. c.mutex.Lock()
  388. v := c.version
  389. c.mutex.Unlock()
  390. return v
  391. }
  392. func (c *client) GetPerspective() protocol.Perspective {
  393. return protocol.PerspectiveClient
  394. }