server.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. package quic
  2. import (
  3. "bytes"
  4. "crypto/tls"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "net"
  9. "sync"
  10. "time"
  11. "github.com/lucas-clemente/quic-go/internal/handshake"
  12. "github.com/lucas-clemente/quic-go/internal/protocol"
  13. "github.com/lucas-clemente/quic-go/internal/utils"
  14. "github.com/lucas-clemente/quic-go/internal/wire"
  15. )
  16. // packetHandler handles packets
  17. type packetHandler interface {
  18. handlePacket(*receivedPacket)
  19. io.Closer
  20. destroy(error)
  21. GetVersion() protocol.VersionNumber
  22. GetPerspective() protocol.Perspective
  23. }
  24. type unknownPacketHandler interface {
  25. handlePacket(*receivedPacket)
  26. closeWithError(error) error
  27. }
  28. type packetHandlerManager interface {
  29. Add(protocol.ConnectionID, packetHandler)
  30. Retire(protocol.ConnectionID)
  31. Remove(protocol.ConnectionID)
  32. SetServer(unknownPacketHandler)
  33. CloseServer()
  34. }
  35. type quicSession interface {
  36. Session
  37. handlePacket(*receivedPacket)
  38. GetVersion() protocol.VersionNumber
  39. run() error
  40. destroy(error)
  41. closeRemote(error)
  42. }
  43. type sessionRunner interface {
  44. onHandshakeComplete(Session)
  45. retireConnectionID(protocol.ConnectionID)
  46. removeConnectionID(protocol.ConnectionID)
  47. }
  48. type runner struct {
  49. onHandshakeCompleteImpl func(Session)
  50. retireConnectionIDImpl func(protocol.ConnectionID)
  51. removeConnectionIDImpl func(protocol.ConnectionID)
  52. }
  53. func (r *runner) onHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) }
  54. func (r *runner) retireConnectionID(c protocol.ConnectionID) { r.retireConnectionIDImpl(c) }
  55. func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) }
  56. var _ sessionRunner = &runner{}
  57. // A Listener of QUIC
  58. type server struct {
  59. mutex sync.Mutex
  60. tlsConf *tls.Config
  61. config *Config
  62. conn net.PacketConn
  63. // If the server is started with ListenAddr, we create a packet conn.
  64. // If it is started with Listen, we take a packet conn as a parameter.
  65. createdPacketConn bool
  66. cookieGenerator *handshake.CookieGenerator
  67. sessionHandler packetHandlerManager
  68. // set as a member, so they can be set in the tests
  69. newSession func(connection, sessionRunner, protocol.ConnectionID /* original connection ID */, protocol.ConnectionID /* destination connection ID */, protocol.ConnectionID /* source connection ID */, *Config, *tls.Config, *handshake.TransportParameters, utils.Logger, protocol.VersionNumber) (quicSession, error)
  70. serverError error
  71. errorChan chan struct{}
  72. closed bool
  73. sessionQueue chan Session
  74. sessionRunner sessionRunner
  75. logger utils.Logger
  76. }
  77. var _ Listener = &server{}
  78. var _ unknownPacketHandler = &server{}
  79. // ListenAddr creates a QUIC server listening on a given address.
  80. // The tls.Config must not be nil, the quic.Config may be nil.
  81. func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
  82. udpAddr, err := net.ResolveUDPAddr("udp", addr)
  83. if err != nil {
  84. return nil, err
  85. }
  86. conn, err := net.ListenUDP("udp", udpAddr)
  87. if err != nil {
  88. return nil, err
  89. }
  90. serv, err := listen(conn, tlsConf, config)
  91. if err != nil {
  92. return nil, err
  93. }
  94. serv.createdPacketConn = true
  95. return serv, nil
  96. }
  97. // Listen listens for QUIC connections on a given net.PacketConn.
  98. // The tls.Config must not be nil, the quic.Config may be nil.
  99. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
  100. return listen(conn, tlsConf, config)
  101. }
  102. func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) {
  103. config = populateServerConfig(config)
  104. for _, v := range config.Versions {
  105. if !protocol.IsValidVersion(v) {
  106. return nil, fmt.Errorf("%s is not a valid QUIC version", v)
  107. }
  108. }
  109. sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength)
  110. if err != nil {
  111. return nil, err
  112. }
  113. s := &server{
  114. conn: conn,
  115. tlsConf: tlsConf,
  116. config: config,
  117. sessionHandler: sessionHandler,
  118. sessionQueue: make(chan Session, 5),
  119. errorChan: make(chan struct{}),
  120. newSession: newSession,
  121. logger: utils.DefaultLogger.WithPrefix("server"),
  122. }
  123. if err := s.setup(); err != nil {
  124. return nil, err
  125. }
  126. sessionHandler.SetServer(s)
  127. s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
  128. return s, nil
  129. }
  130. func (s *server) setup() error {
  131. s.sessionRunner = &runner{
  132. onHandshakeCompleteImpl: func(sess Session) { s.sessionQueue <- sess },
  133. retireConnectionIDImpl: s.sessionHandler.Retire,
  134. removeConnectionIDImpl: s.sessionHandler.Remove,
  135. }
  136. cookieGenerator, err := handshake.NewCookieGenerator()
  137. if err != nil {
  138. return err
  139. }
  140. s.cookieGenerator = cookieGenerator
  141. return nil
  142. }
  143. var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
  144. if cookie == nil {
  145. return false
  146. }
  147. if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) {
  148. return false
  149. }
  150. var sourceAddr string
  151. if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
  152. sourceAddr = udpAddr.IP.String()
  153. } else {
  154. sourceAddr = clientAddr.String()
  155. }
  156. return sourceAddr == cookie.RemoteAddr
  157. }
  158. // populateServerConfig populates fields in the quic.Config with their default values, if none are set
  159. // it may be called with nil
  160. func populateServerConfig(config *Config) *Config {
  161. if config == nil {
  162. config = &Config{}
  163. }
  164. versions := config.Versions
  165. if len(versions) == 0 {
  166. versions = protocol.SupportedVersions
  167. }
  168. vsa := defaultAcceptCookie
  169. if config.AcceptCookie != nil {
  170. vsa = config.AcceptCookie
  171. }
  172. handshakeTimeout := protocol.DefaultHandshakeTimeout
  173. if config.HandshakeTimeout != 0 {
  174. handshakeTimeout = config.HandshakeTimeout
  175. }
  176. idleTimeout := protocol.DefaultIdleTimeout
  177. if config.IdleTimeout != 0 {
  178. idleTimeout = config.IdleTimeout
  179. }
  180. maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
  181. if maxReceiveStreamFlowControlWindow == 0 {
  182. maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindow
  183. }
  184. maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
  185. if maxReceiveConnectionFlowControlWindow == 0 {
  186. maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindow
  187. }
  188. maxIncomingStreams := config.MaxIncomingStreams
  189. if maxIncomingStreams == 0 {
  190. maxIncomingStreams = protocol.DefaultMaxIncomingStreams
  191. } else if maxIncomingStreams < 0 {
  192. maxIncomingStreams = 0
  193. }
  194. maxIncomingUniStreams := config.MaxIncomingUniStreams
  195. if maxIncomingUniStreams == 0 {
  196. maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
  197. } else if maxIncomingUniStreams < 0 {
  198. maxIncomingUniStreams = 0
  199. }
  200. connIDLen := config.ConnectionIDLength
  201. if connIDLen == 0 {
  202. connIDLen = protocol.DefaultConnectionIDLength
  203. }
  204. return &Config{
  205. Versions: versions,
  206. HandshakeTimeout: handshakeTimeout,
  207. IdleTimeout: idleTimeout,
  208. AcceptCookie: vsa,
  209. KeepAlive: config.KeepAlive,
  210. MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
  211. MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
  212. MaxIncomingStreams: maxIncomingStreams,
  213. MaxIncomingUniStreams: maxIncomingUniStreams,
  214. ConnectionIDLength: connIDLen,
  215. }
  216. }
  217. // Accept returns newly openend sessions
  218. func (s *server) Accept() (Session, error) {
  219. var sess Session
  220. select {
  221. case sess = <-s.sessionQueue:
  222. return sess, nil
  223. case <-s.errorChan:
  224. return nil, s.serverError
  225. }
  226. }
  227. // Close the server
  228. func (s *server) Close() error {
  229. s.mutex.Lock()
  230. defer s.mutex.Unlock()
  231. if s.closed {
  232. return nil
  233. }
  234. return s.closeWithMutex()
  235. }
  236. func (s *server) closeWithMutex() error {
  237. s.sessionHandler.CloseServer()
  238. if s.serverError == nil {
  239. s.serverError = errors.New("server closed")
  240. }
  241. var err error
  242. // If the server was started with ListenAddr, we created the packet conn.
  243. // We need to close it in order to make the go routine reading from that conn return.
  244. if s.createdPacketConn {
  245. err = s.conn.Close()
  246. }
  247. s.closed = true
  248. close(s.errorChan)
  249. return err
  250. }
  251. func (s *server) closeWithError(e error) error {
  252. s.mutex.Lock()
  253. defer s.mutex.Unlock()
  254. if s.closed {
  255. return nil
  256. }
  257. s.serverError = e
  258. return s.closeWithMutex()
  259. }
  260. // Addr returns the server's network address
  261. func (s *server) Addr() net.Addr {
  262. return s.conn.LocalAddr()
  263. }
  264. func (s *server) handlePacket(p *receivedPacket) {
  265. if err := s.handlePacketImpl(p); err != nil {
  266. s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
  267. }
  268. }
  269. func (s *server) handlePacketImpl(p *receivedPacket) error {
  270. hdr := p.header
  271. // send a Version Negotiation Packet if the client is speaking a different protocol version
  272. if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
  273. return s.sendVersionNegotiationPacket(p)
  274. }
  275. if hdr.Type == protocol.PacketTypeInitial {
  276. go s.handleInitial(p)
  277. }
  278. // TODO(#943): send Stateless Reset
  279. return nil
  280. }
  281. func (s *server) handleInitial(p *receivedPacket) {
  282. // TODO: add a check that DestConnID == SrcConnID
  283. s.logger.Debugf("<- Received Initial packet.")
  284. sess, connID, err := s.handleInitialImpl(p)
  285. if err != nil {
  286. s.logger.Errorf("Error occurred handling initial packet: %s", err)
  287. return
  288. }
  289. if sess == nil { // a retry was done
  290. return
  291. }
  292. serverSession := newServerSession(sess, s.config, s.logger)
  293. s.sessionHandler.Add(connID, serverSession)
  294. }
  295. func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) {
  296. hdr := p.header
  297. if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
  298. return nil, nil, errors.New("dropping Initial packet with too short connection ID")
  299. }
  300. if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize {
  301. return nil, nil, errors.New("dropping too small Initial packet")
  302. }
  303. var cookie *Cookie
  304. var origDestConnectionID protocol.ConnectionID
  305. if len(hdr.Token) > 0 {
  306. c, err := s.cookieGenerator.DecodeToken(hdr.Token)
  307. if err == nil {
  308. cookie = &Cookie{
  309. RemoteAddr: c.RemoteAddr,
  310. SentTime: c.SentTime,
  311. }
  312. origDestConnectionID = c.OriginalDestConnectionID
  313. }
  314. }
  315. if !s.config.AcceptCookie(p.remoteAddr, cookie) {
  316. // Log the Initial packet now.
  317. // If no Retry is sent, the packet will be logged by the session.
  318. p.header.Log(s.logger)
  319. return nil, nil, s.sendRetry(p.remoteAddr, hdr)
  320. }
  321. connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
  322. if err != nil {
  323. return nil, nil, err
  324. }
  325. s.logger.Debugf("Changing connection ID to %s.", connID)
  326. sess, err := s.createNewSession(
  327. p.remoteAddr,
  328. origDestConnectionID,
  329. hdr.DestConnectionID,
  330. hdr.SrcConnectionID,
  331. connID,
  332. hdr.Version,
  333. )
  334. if err != nil {
  335. return nil, nil, err
  336. }
  337. sess.handlePacket(p)
  338. return sess, connID, nil
  339. }
  340. func (s *server) createNewSession(
  341. remoteAddr net.Addr,
  342. origDestConnID protocol.ConnectionID,
  343. clientDestConnID protocol.ConnectionID,
  344. destConnID protocol.ConnectionID,
  345. srcConnID protocol.ConnectionID,
  346. version protocol.VersionNumber,
  347. ) (quicSession, error) {
  348. params := &handshake.TransportParameters{
  349. InitialMaxStreamDataBidiLocal: protocol.InitialMaxStreamData,
  350. InitialMaxStreamDataBidiRemote: protocol.InitialMaxStreamData,
  351. InitialMaxStreamDataUni: protocol.InitialMaxStreamData,
  352. InitialMaxData: protocol.InitialMaxData,
  353. IdleTimeout: s.config.IdleTimeout,
  354. MaxBidiStreams: uint64(s.config.MaxIncomingStreams),
  355. MaxUniStreams: uint64(s.config.MaxIncomingUniStreams),
  356. DisableMigration: true,
  357. // TODO(#855): generate a real token
  358. StatelessResetToken: bytes.Repeat([]byte{42}, 16),
  359. OriginalConnectionID: origDestConnID,
  360. }
  361. sess, err := s.newSession(
  362. &conn{pconn: s.conn, currentAddr: remoteAddr},
  363. s.sessionRunner,
  364. clientDestConnID,
  365. destConnID,
  366. srcConnID,
  367. s.config,
  368. s.tlsConf,
  369. params,
  370. s.logger,
  371. version,
  372. )
  373. if err != nil {
  374. return nil, err
  375. }
  376. go sess.run()
  377. return sess, nil
  378. }
  379. func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
  380. token, err := s.cookieGenerator.NewToken(remoteAddr, hdr.DestConnectionID)
  381. if err != nil {
  382. return err
  383. }
  384. connID, err := protocol.GenerateConnectionID(s.config.ConnectionIDLength)
  385. if err != nil {
  386. return err
  387. }
  388. replyHdr := &wire.Header{
  389. IsLongHeader: true,
  390. Type: protocol.PacketTypeRetry,
  391. Version: hdr.Version,
  392. SrcConnectionID: connID,
  393. DestConnectionID: hdr.SrcConnectionID,
  394. OrigDestConnectionID: hdr.DestConnectionID,
  395. Token: token,
  396. }
  397. s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID)
  398. replyHdr.Log(s.logger)
  399. buf := &bytes.Buffer{}
  400. if err := replyHdr.Write(buf, protocol.PerspectiveServer, hdr.Version); err != nil {
  401. return err
  402. }
  403. if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
  404. s.logger.Debugf("Error sending Retry: %s", err)
  405. }
  406. return nil
  407. }
  408. func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
  409. hdr := p.header
  410. s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
  411. data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
  412. if err != nil {
  413. return err
  414. }
  415. _, err = s.conn.WriteTo(data, p.remoteAddr)
  416. return err
  417. }