server.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420
  1. package quic
  2. import (
  3. "crypto/tls"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "net"
  8. "sync"
  9. "time"
  10. "github.com/lucas-clemente/quic-go/internal/crypto"
  11. "github.com/lucas-clemente/quic-go/internal/handshake"
  12. "github.com/lucas-clemente/quic-go/internal/protocol"
  13. "github.com/lucas-clemente/quic-go/internal/utils"
  14. "github.com/lucas-clemente/quic-go/internal/wire"
  15. )
  16. // packetHandler handles packets
  17. type packetHandler interface {
  18. handlePacket(*receivedPacket)
  19. io.Closer
  20. destroy(error)
  21. GetVersion() protocol.VersionNumber
  22. GetPerspective() protocol.Perspective
  23. }
  24. type unknownPacketHandler interface {
  25. handlePacket(*receivedPacket)
  26. closeWithError(error) error
  27. }
  28. type packetHandlerManager interface {
  29. Add(protocol.ConnectionID, packetHandler)
  30. SetServer(unknownPacketHandler)
  31. Remove(protocol.ConnectionID)
  32. CloseServer()
  33. }
  34. type quicSession interface {
  35. Session
  36. handlePacket(*receivedPacket)
  37. GetVersion() protocol.VersionNumber
  38. run() error
  39. destroy(error)
  40. closeRemote(error)
  41. }
  42. type sessionRunner interface {
  43. onHandshakeComplete(Session)
  44. removeConnectionID(protocol.ConnectionID)
  45. }
  46. type runner struct {
  47. onHandshakeCompleteImpl func(Session)
  48. removeConnectionIDImpl func(protocol.ConnectionID)
  49. }
  50. func (r *runner) onHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) }
  51. func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) }
  52. var _ sessionRunner = &runner{}
  53. // A Listener of QUIC
  54. type server struct {
  55. mutex sync.Mutex
  56. tlsConf *tls.Config
  57. config *Config
  58. conn net.PacketConn
  59. // If the server is started with ListenAddr, we create a packet conn.
  60. // If it is started with Listen, we take a packet conn as a parameter.
  61. createdPacketConn bool
  62. supportsTLS bool
  63. serverTLS *serverTLS
  64. certChain crypto.CertChain
  65. scfg *handshake.ServerConfig
  66. sessionHandler packetHandlerManager
  67. serverError error
  68. errorChan chan struct{}
  69. closed bool
  70. sessionQueue chan Session
  71. sessionRunner sessionRunner
  72. // set as a member, so they can be set in the tests
  73. newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (quicSession, error)
  74. logger utils.Logger
  75. }
  76. var _ Listener = &server{}
  77. var _ unknownPacketHandler = &server{}
  78. // ListenAddr creates a QUIC server listening on a given address.
  79. // The tls.Config must not be nil, the quic.Config may be nil.
  80. func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
  81. udpAddr, err := net.ResolveUDPAddr("udp", addr)
  82. if err != nil {
  83. return nil, err
  84. }
  85. conn, err := net.ListenUDP("udp", udpAddr)
  86. if err != nil {
  87. return nil, err
  88. }
  89. serv, err := listen(conn, tlsConf, config)
  90. if err != nil {
  91. return nil, err
  92. }
  93. serv.createdPacketConn = true
  94. return serv, nil
  95. }
  96. // Listen listens for QUIC connections on a given net.PacketConn.
  97. // The tls.Config must not be nil, the quic.Config may be nil.
  98. func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
  99. return listen(conn, tlsConf, config)
  100. }
  101. func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) {
  102. certChain := crypto.NewCertChain(tlsConf)
  103. kex, err := crypto.NewCurve25519KEX()
  104. if err != nil {
  105. return nil, err
  106. }
  107. scfg, err := handshake.NewServerConfig(kex, certChain)
  108. if err != nil {
  109. return nil, err
  110. }
  111. config = populateServerConfig(config)
  112. var supportsTLS bool
  113. for _, v := range config.Versions {
  114. if !protocol.IsValidVersion(v) {
  115. return nil, fmt.Errorf("%s is not a valid QUIC version", v)
  116. }
  117. // check if any of the supported versions supports TLS
  118. if v.UsesTLS() {
  119. supportsTLS = true
  120. break
  121. }
  122. }
  123. sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength)
  124. if err != nil {
  125. return nil, err
  126. }
  127. s := &server{
  128. conn: conn,
  129. tlsConf: tlsConf,
  130. config: config,
  131. certChain: certChain,
  132. scfg: scfg,
  133. newSession: newSession,
  134. sessionHandler: sessionHandler,
  135. sessionQueue: make(chan Session, 5),
  136. errorChan: make(chan struct{}),
  137. supportsTLS: supportsTLS,
  138. logger: utils.DefaultLogger.WithPrefix("server"),
  139. }
  140. s.setup()
  141. if supportsTLS {
  142. if err := s.setupTLS(); err != nil {
  143. return nil, err
  144. }
  145. }
  146. sessionHandler.SetServer(s)
  147. s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
  148. return s, nil
  149. }
  150. func (s *server) setup() {
  151. s.sessionRunner = &runner{
  152. onHandshakeCompleteImpl: func(sess Session) { s.sessionQueue <- sess },
  153. removeConnectionIDImpl: s.sessionHandler.Remove,
  154. }
  155. }
  156. func (s *server) setupTLS() error {
  157. serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, s.tlsConf, s.logger)
  158. if err != nil {
  159. return err
  160. }
  161. s.serverTLS = serverTLS
  162. // handle TLS connection establishment statelessly
  163. go func() {
  164. for {
  165. select {
  166. case <-s.errorChan:
  167. return
  168. case tlsSession := <-sessionChan:
  169. // The connection ID is a randomly chosen value.
  170. // It is safe to assume that it doesn't collide with other randomly chosen values.
  171. serverSession := newServerSession(tlsSession.sess, s.config, s.logger)
  172. s.sessionHandler.Add(tlsSession.connID, serverSession)
  173. }
  174. }
  175. }()
  176. return nil
  177. }
  178. var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
  179. if cookie == nil {
  180. return false
  181. }
  182. if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) {
  183. return false
  184. }
  185. var sourceAddr string
  186. if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
  187. sourceAddr = udpAddr.IP.String()
  188. } else {
  189. sourceAddr = clientAddr.String()
  190. }
  191. return sourceAddr == cookie.RemoteAddr
  192. }
  193. // populateServerConfig populates fields in the quic.Config with their default values, if none are set
  194. // it may be called with nil
  195. func populateServerConfig(config *Config) *Config {
  196. if config == nil {
  197. config = &Config{}
  198. }
  199. versions := config.Versions
  200. if len(versions) == 0 {
  201. versions = protocol.SupportedVersions
  202. }
  203. vsa := defaultAcceptCookie
  204. if config.AcceptCookie != nil {
  205. vsa = config.AcceptCookie
  206. }
  207. handshakeTimeout := protocol.DefaultHandshakeTimeout
  208. if config.HandshakeTimeout != 0 {
  209. handshakeTimeout = config.HandshakeTimeout
  210. }
  211. idleTimeout := protocol.DefaultIdleTimeout
  212. if config.IdleTimeout != 0 {
  213. idleTimeout = config.IdleTimeout
  214. }
  215. maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
  216. if maxReceiveStreamFlowControlWindow == 0 {
  217. maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowServer
  218. }
  219. maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
  220. if maxReceiveConnectionFlowControlWindow == 0 {
  221. maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer
  222. }
  223. maxIncomingStreams := config.MaxIncomingStreams
  224. if maxIncomingStreams == 0 {
  225. maxIncomingStreams = protocol.DefaultMaxIncomingStreams
  226. } else if maxIncomingStreams < 0 {
  227. maxIncomingStreams = 0
  228. }
  229. maxIncomingUniStreams := config.MaxIncomingUniStreams
  230. if maxIncomingUniStreams == 0 {
  231. maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
  232. } else if maxIncomingUniStreams < 0 {
  233. maxIncomingUniStreams = 0
  234. }
  235. connIDLen := config.ConnectionIDLength
  236. if connIDLen == 0 {
  237. connIDLen = protocol.DefaultConnectionIDLength
  238. }
  239. for _, v := range versions {
  240. if v == protocol.Version44 {
  241. connIDLen = protocol.ConnectionIDLenGQUIC
  242. }
  243. }
  244. return &Config{
  245. Versions: versions,
  246. HandshakeTimeout: handshakeTimeout,
  247. IdleTimeout: idleTimeout,
  248. AcceptCookie: vsa,
  249. KeepAlive: config.KeepAlive,
  250. MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
  251. MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
  252. MaxIncomingStreams: maxIncomingStreams,
  253. MaxIncomingUniStreams: maxIncomingUniStreams,
  254. ConnectionIDLength: connIDLen,
  255. }
  256. }
  257. // Accept returns newly openend sessions
  258. func (s *server) Accept() (Session, error) {
  259. var sess Session
  260. select {
  261. case sess = <-s.sessionQueue:
  262. return sess, nil
  263. case <-s.errorChan:
  264. return nil, s.serverError
  265. }
  266. }
  267. // Close the server
  268. func (s *server) Close() error {
  269. s.mutex.Lock()
  270. defer s.mutex.Unlock()
  271. if s.closed {
  272. return nil
  273. }
  274. return s.closeWithMutex()
  275. }
  276. func (s *server) closeWithMutex() error {
  277. s.sessionHandler.CloseServer()
  278. if s.serverError == nil {
  279. s.serverError = errors.New("server closed")
  280. }
  281. var err error
  282. // If the server was started with ListenAddr, we created the packet conn.
  283. // We need to close it in order to make the go routine reading from that conn return.
  284. if s.createdPacketConn {
  285. err = s.conn.Close()
  286. }
  287. s.closed = true
  288. close(s.errorChan)
  289. return err
  290. }
  291. func (s *server) closeWithError(e error) error {
  292. s.mutex.Lock()
  293. defer s.mutex.Unlock()
  294. if s.closed {
  295. return nil
  296. }
  297. s.serverError = e
  298. return s.closeWithMutex()
  299. }
  300. // Addr returns the server's network address
  301. func (s *server) Addr() net.Addr {
  302. return s.conn.LocalAddr()
  303. }
  304. func (s *server) handlePacket(p *receivedPacket) {
  305. if err := s.handlePacketImpl(p); err != nil {
  306. s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
  307. }
  308. }
  309. func (s *server) handlePacketImpl(p *receivedPacket) error {
  310. hdr := p.header
  311. if hdr.VersionFlag || hdr.IsLongHeader {
  312. // send a Version Negotiation Packet if the client is speaking a different protocol version
  313. if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
  314. return s.sendVersionNegotiationPacket(p)
  315. }
  316. }
  317. if hdr.Type == protocol.PacketTypeInitial && hdr.Version.UsesTLS() {
  318. go s.serverTLS.HandleInitial(p)
  319. return nil
  320. }
  321. // TODO(#943): send Stateless Reset, if this an IETF QUIC packet
  322. if !hdr.VersionFlag && !hdr.Version.UsesIETFHeaderFormat() {
  323. _, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr)
  324. return err
  325. }
  326. // This is (potentially) a Client Hello.
  327. // Make sure it has the minimum required size before spending any more ressources on it.
  328. if len(p.data) < protocol.MinClientHelloSize {
  329. return errors.New("dropping small packet for unknown connection")
  330. }
  331. var destConnID, srcConnID protocol.ConnectionID
  332. if hdr.Version.UsesIETFHeaderFormat() {
  333. srcConnID = hdr.DestConnectionID
  334. } else {
  335. destConnID = hdr.DestConnectionID
  336. srcConnID = hdr.DestConnectionID
  337. }
  338. s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, hdr.Version, p.remoteAddr)
  339. sess, err := s.newSession(
  340. &conn{pconn: s.conn, currentAddr: p.remoteAddr},
  341. s.sessionRunner,
  342. hdr.Version,
  343. destConnID,
  344. srcConnID,
  345. s.scfg,
  346. s.tlsConf,
  347. s.config,
  348. s.logger,
  349. )
  350. if err != nil {
  351. return err
  352. }
  353. s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger))
  354. go sess.run()
  355. sess.handlePacket(p)
  356. return nil
  357. }
  358. func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
  359. hdr := p.header
  360. s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
  361. var data []byte
  362. if hdr.IsPublicHeader {
  363. data = wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions)
  364. } else {
  365. var err error
  366. data, err = wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
  367. if err != nil {
  368. return err
  369. }
  370. }
  371. _, err := s.conn.WriteTo(data, p.remoteAddr)
  372. return err
  373. }