| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420 |
- package quic
- import (
- "crypto/tls"
- "errors"
- "fmt"
- "io"
- "net"
- "sync"
- "time"
- "github.com/lucas-clemente/quic-go/internal/crypto"
- "github.com/lucas-clemente/quic-go/internal/handshake"
- "github.com/lucas-clemente/quic-go/internal/protocol"
- "github.com/lucas-clemente/quic-go/internal/utils"
- "github.com/lucas-clemente/quic-go/internal/wire"
- )
- // packetHandler handles packets
- type packetHandler interface {
- handlePacket(*receivedPacket)
- io.Closer
- destroy(error)
- GetVersion() protocol.VersionNumber
- GetPerspective() protocol.Perspective
- }
- type unknownPacketHandler interface {
- handlePacket(*receivedPacket)
- closeWithError(error) error
- }
- type packetHandlerManager interface {
- Add(protocol.ConnectionID, packetHandler)
- SetServer(unknownPacketHandler)
- Remove(protocol.ConnectionID)
- CloseServer()
- }
- type quicSession interface {
- Session
- handlePacket(*receivedPacket)
- GetVersion() protocol.VersionNumber
- run() error
- destroy(error)
- closeRemote(error)
- }
- type sessionRunner interface {
- onHandshakeComplete(Session)
- removeConnectionID(protocol.ConnectionID)
- }
- type runner struct {
- onHandshakeCompleteImpl func(Session)
- removeConnectionIDImpl func(protocol.ConnectionID)
- }
- func (r *runner) onHandshakeComplete(s Session) { r.onHandshakeCompleteImpl(s) }
- func (r *runner) removeConnectionID(c protocol.ConnectionID) { r.removeConnectionIDImpl(c) }
- var _ sessionRunner = &runner{}
- // A Listener of QUIC
- type server struct {
- mutex sync.Mutex
- tlsConf *tls.Config
- config *Config
- conn net.PacketConn
- // If the server is started with ListenAddr, we create a packet conn.
- // If it is started with Listen, we take a packet conn as a parameter.
- createdPacketConn bool
- supportsTLS bool
- serverTLS *serverTLS
- certChain crypto.CertChain
- scfg *handshake.ServerConfig
- sessionHandler packetHandlerManager
- serverError error
- errorChan chan struct{}
- closed bool
- sessionQueue chan Session
- sessionRunner sessionRunner
- // set as a member, so they can be set in the tests
- newSession func(connection, sessionRunner, protocol.VersionNumber, protocol.ConnectionID, protocol.ConnectionID, *handshake.ServerConfig, *tls.Config, *Config, utils.Logger) (quicSession, error)
- logger utils.Logger
- }
- var _ Listener = &server{}
- var _ unknownPacketHandler = &server{}
- // ListenAddr creates a QUIC server listening on a given address.
- // The tls.Config must not be nil, the quic.Config may be nil.
- func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
- udpAddr, err := net.ResolveUDPAddr("udp", addr)
- if err != nil {
- return nil, err
- }
- conn, err := net.ListenUDP("udp", udpAddr)
- if err != nil {
- return nil, err
- }
- serv, err := listen(conn, tlsConf, config)
- if err != nil {
- return nil, err
- }
- serv.createdPacketConn = true
- return serv, nil
- }
- // Listen listens for QUIC connections on a given net.PacketConn.
- // The tls.Config must not be nil, the quic.Config may be nil.
- func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
- return listen(conn, tlsConf, config)
- }
- func listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (*server, error) {
- certChain := crypto.NewCertChain(tlsConf)
- kex, err := crypto.NewCurve25519KEX()
- if err != nil {
- return nil, err
- }
- scfg, err := handshake.NewServerConfig(kex, certChain)
- if err != nil {
- return nil, err
- }
- config = populateServerConfig(config)
- var supportsTLS bool
- for _, v := range config.Versions {
- if !protocol.IsValidVersion(v) {
- return nil, fmt.Errorf("%s is not a valid QUIC version", v)
- }
- // check if any of the supported versions supports TLS
- if v.UsesTLS() {
- supportsTLS = true
- break
- }
- }
- sessionHandler, err := getMultiplexer().AddConn(conn, config.ConnectionIDLength)
- if err != nil {
- return nil, err
- }
- s := &server{
- conn: conn,
- tlsConf: tlsConf,
- config: config,
- certChain: certChain,
- scfg: scfg,
- newSession: newSession,
- sessionHandler: sessionHandler,
- sessionQueue: make(chan Session, 5),
- errorChan: make(chan struct{}),
- supportsTLS: supportsTLS,
- logger: utils.DefaultLogger.WithPrefix("server"),
- }
- s.setup()
- if supportsTLS {
- if err := s.setupTLS(); err != nil {
- return nil, err
- }
- }
- sessionHandler.SetServer(s)
- s.logger.Debugf("Listening for %s connections on %s", conn.LocalAddr().Network(), conn.LocalAddr().String())
- return s, nil
- }
- func (s *server) setup() {
- s.sessionRunner = &runner{
- onHandshakeCompleteImpl: func(sess Session) { s.sessionQueue <- sess },
- removeConnectionIDImpl: s.sessionHandler.Remove,
- }
- }
- func (s *server) setupTLS() error {
- serverTLS, sessionChan, err := newServerTLS(s.conn, s.config, s.sessionRunner, s.tlsConf, s.logger)
- if err != nil {
- return err
- }
- s.serverTLS = serverTLS
- // handle TLS connection establishment statelessly
- go func() {
- for {
- select {
- case <-s.errorChan:
- return
- case tlsSession := <-sessionChan:
- // The connection ID is a randomly chosen value.
- // It is safe to assume that it doesn't collide with other randomly chosen values.
- serverSession := newServerSession(tlsSession.sess, s.config, s.logger)
- s.sessionHandler.Add(tlsSession.connID, serverSession)
- }
- }
- }()
- return nil
- }
- var defaultAcceptCookie = func(clientAddr net.Addr, cookie *Cookie) bool {
- if cookie == nil {
- return false
- }
- if time.Now().After(cookie.SentTime.Add(protocol.CookieExpiryTime)) {
- return false
- }
- var sourceAddr string
- if udpAddr, ok := clientAddr.(*net.UDPAddr); ok {
- sourceAddr = udpAddr.IP.String()
- } else {
- sourceAddr = clientAddr.String()
- }
- return sourceAddr == cookie.RemoteAddr
- }
- // populateServerConfig populates fields in the quic.Config with their default values, if none are set
- // it may be called with nil
- func populateServerConfig(config *Config) *Config {
- if config == nil {
- config = &Config{}
- }
- versions := config.Versions
- if len(versions) == 0 {
- versions = protocol.SupportedVersions
- }
- vsa := defaultAcceptCookie
- if config.AcceptCookie != nil {
- vsa = config.AcceptCookie
- }
- handshakeTimeout := protocol.DefaultHandshakeTimeout
- if config.HandshakeTimeout != 0 {
- handshakeTimeout = config.HandshakeTimeout
- }
- idleTimeout := protocol.DefaultIdleTimeout
- if config.IdleTimeout != 0 {
- idleTimeout = config.IdleTimeout
- }
- maxReceiveStreamFlowControlWindow := config.MaxReceiveStreamFlowControlWindow
- if maxReceiveStreamFlowControlWindow == 0 {
- maxReceiveStreamFlowControlWindow = protocol.DefaultMaxReceiveStreamFlowControlWindowServer
- }
- maxReceiveConnectionFlowControlWindow := config.MaxReceiveConnectionFlowControlWindow
- if maxReceiveConnectionFlowControlWindow == 0 {
- maxReceiveConnectionFlowControlWindow = protocol.DefaultMaxReceiveConnectionFlowControlWindowServer
- }
- maxIncomingStreams := config.MaxIncomingStreams
- if maxIncomingStreams == 0 {
- maxIncomingStreams = protocol.DefaultMaxIncomingStreams
- } else if maxIncomingStreams < 0 {
- maxIncomingStreams = 0
- }
- maxIncomingUniStreams := config.MaxIncomingUniStreams
- if maxIncomingUniStreams == 0 {
- maxIncomingUniStreams = protocol.DefaultMaxIncomingUniStreams
- } else if maxIncomingUniStreams < 0 {
- maxIncomingUniStreams = 0
- }
- connIDLen := config.ConnectionIDLength
- if connIDLen == 0 {
- connIDLen = protocol.DefaultConnectionIDLength
- }
- for _, v := range versions {
- if v == protocol.Version44 {
- connIDLen = protocol.ConnectionIDLenGQUIC
- }
- }
- return &Config{
- Versions: versions,
- HandshakeTimeout: handshakeTimeout,
- IdleTimeout: idleTimeout,
- AcceptCookie: vsa,
- KeepAlive: config.KeepAlive,
- MaxReceiveStreamFlowControlWindow: maxReceiveStreamFlowControlWindow,
- MaxReceiveConnectionFlowControlWindow: maxReceiveConnectionFlowControlWindow,
- MaxIncomingStreams: maxIncomingStreams,
- MaxIncomingUniStreams: maxIncomingUniStreams,
- ConnectionIDLength: connIDLen,
- }
- }
- // Accept returns newly openend sessions
- func (s *server) Accept() (Session, error) {
- var sess Session
- select {
- case sess = <-s.sessionQueue:
- return sess, nil
- case <-s.errorChan:
- return nil, s.serverError
- }
- }
- // Close the server
- func (s *server) Close() error {
- s.mutex.Lock()
- defer s.mutex.Unlock()
- if s.closed {
- return nil
- }
- return s.closeWithMutex()
- }
- func (s *server) closeWithMutex() error {
- s.sessionHandler.CloseServer()
- if s.serverError == nil {
- s.serverError = errors.New("server closed")
- }
- var err error
- // If the server was started with ListenAddr, we created the packet conn.
- // We need to close it in order to make the go routine reading from that conn return.
- if s.createdPacketConn {
- err = s.conn.Close()
- }
- s.closed = true
- close(s.errorChan)
- return err
- }
- func (s *server) closeWithError(e error) error {
- s.mutex.Lock()
- defer s.mutex.Unlock()
- if s.closed {
- return nil
- }
- s.serverError = e
- return s.closeWithMutex()
- }
- // Addr returns the server's network address
- func (s *server) Addr() net.Addr {
- return s.conn.LocalAddr()
- }
- func (s *server) handlePacket(p *receivedPacket) {
- if err := s.handlePacketImpl(p); err != nil {
- s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
- }
- }
// handlePacketImpl processes a packet passed to the server. The checks run
// in this order:
//   - a packet with the version flag or a long header whose version is not
//     in s.config.Versions is answered with a Version Negotiation packet;
//   - an Initial packet of a TLS version is handed to serverTLS.HandleInitial
//     on a new goroutine;
//   - a packet without the version flag that does not use the IETF header
//     format is answered with a Public Reset;
//   - a packet shorter than protocol.MinClientHelloSize is dropped with an error;
//   - otherwise a new session is created, registered with the sessionHandler
//     under the packet's destination connection ID, started on its own
//     goroutine, and given the packet.
//
// handlePacket logs the returned error.
func (s *server) handlePacketImpl(p *receivedPacket) error {
	hdr := p.header
	if hdr.VersionFlag || hdr.IsLongHeader {
		// send a Version Negotiation Packet if the client is speaking a different protocol version
		if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
			return s.sendVersionNegotiationPacket(p)
		}
	}
	// TLS handshakes are delegated to serverTLS, which listen only sets up
	// when at least one configured version uses TLS.
	if hdr.Type == protocol.PacketTypeInitial && hdr.Version.UsesTLS() {
		go s.serverTLS.HandleInitial(p)
		return nil
	}
	// TODO(#943): send Stateless Reset, if this is an IETF QUIC packet
	if !hdr.VersionFlag && !hdr.Version.UsesIETFHeaderFormat() {
		_, err := s.conn.WriteTo(wire.WritePublicReset(hdr.DestConnectionID, 0, 0), p.remoteAddr)
		return err
	}
	// This is (potentially) a Client Hello.
	// Make sure it has the minimum required size before spending any more resources on it.
	if len(p.data) < protocol.MinClientHelloSize {
		return errors.New("dropping small packet for unknown connection")
	}
	// With the IETF header format, the session's source connection ID is the
	// client's destination connection ID and its destination connection ID is
	// left empty (presumably because the client omits its own ID here —
	// confirm). With gQUIC public headers, both use the same ID.
	var destConnID, srcConnID protocol.ConnectionID
	if hdr.Version.UsesIETFHeaderFormat() {
		srcConnID = hdr.DestConnectionID
	} else {
		destConnID = hdr.DestConnectionID
		srcConnID = hdr.DestConnectionID
	}
	s.logger.Infof("Serving new connection: %s, version %s from %v", hdr.DestConnectionID, hdr.Version, p.remoteAddr)
	sess, err := s.newSession(
		&conn{pconn: s.conn, currentAddr: p.remoteAddr},
		s.sessionRunner,
		hdr.Version,
		destConnID,
		srcConnID,
		s.scfg,
		s.tlsConf,
		s.config,
		s.logger,
	)
	if err != nil {
		return err
	}
	// Register before running, so that subsequent packets for this
	// connection ID are routed to the new session.
	s.sessionHandler.Add(hdr.DestConnectionID, newServerSession(sess, s.config, s.logger))
	// The error returned by run is not observed here.
	go sess.run()
	sess.handlePacket(p)
	return nil
}
- func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
- hdr := p.header
- s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
- var data []byte
- if hdr.IsPublicHeader {
- data = wire.ComposeGQUICVersionNegotiation(hdr.DestConnectionID, s.config.Versions)
- } else {
- var err error
- data, err = wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
- if err != nil {
- return err
- }
- }
- _, err := s.conn.WriteTo(data, p.remoteAddr)
- return err
- }
|