Browse Source

update quic vendor

Darien Raymond 7 years ago
parent
commit
135bf169c0

+ 44 - 61
vendor/github.com/lucas-clemente/quic-go/client.go

@@ -27,7 +27,7 @@ type client struct {
 
 	token []byte
 
-	versionNegotiated                bool // has the server accepted our version
+	versionNegotiated                utils.AtomicBool // has the server accepted our version
 	receivedVersionNegotiationPacket bool
 	negotiatedVersions               []protocol.VersionNumber // the list of versions from the version negotiation packet
 
@@ -59,6 +59,7 @@ var (
 )
 
 // DialAddr establishes a new QUIC connection to a server.
+// It uses a new UDP connection and closes this connection when the QUIC session is closed.
 // The hostname for SNI is taken from the given address.
 func DialAddr(
 	addr string,
@@ -69,7 +70,7 @@ func DialAddr(
 }
 
 // DialAddrContext establishes a new QUIC connection to a server using the provided context.
-// The hostname for SNI is taken from the given address.
+// See DialAddr for details.
 func DialAddrContext(
 	ctx context.Context,
 	addr string,
@@ -88,6 +89,8 @@ func DialAddrContext(
 }
 
 // Dial establishes a new QUIC connection to a server using a net.PacketConn.
+// The same PacketConn can be used for multiple calls to Dial and Listen,
+// QUIC connection IDs are used for demultiplexing the different connections.
 // The host parameter is used for SNI.
 func Dial(
 	pconn net.PacketConn,
@@ -100,7 +103,7 @@ func Dial(
 }
 
 // DialContext establishes a new QUIC connection to a server using a net.PacketConn using the provided context.
-// The host parameter is used for SNI.
+// See Dial for details.
 func DialContext(
 	ctx context.Context,
 	pconn net.PacketConn,
@@ -164,7 +167,18 @@ func newClient(
 			}
 		}
 	}
+
+	srcConnID, err := generateConnectionID(config.ConnectionIDLength)
+	if err != nil {
+		return nil, err
+	}
+	destConnID, err := generateConnectionIDForInitial()
+	if err != nil {
+		return nil, err
+	}
 	c := &client{
+		srcConnID:         srcConnID,
+		destConnID:        destConnID,
 		conn:              &conn{pconn: pconn, currentAddr: remoteAddr},
 		createdPacketConn: createdPacketConn,
 		tlsConf:           tlsConf,
@@ -173,7 +187,7 @@ func newClient(
 		handshakeChan:     make(chan struct{}),
 		logger:            utils.DefaultLogger.WithPrefix("client"),
 	}
-	return c, c.generateConnectionIDs()
+	return c, nil
 }
 
 // populateClientConfig populates fields in the quic.Config with their default values, if none are set
@@ -234,20 +248,6 @@ func populateClientConfig(config *Config, createdPacketConn bool) *Config {
 	}
 }
 
-func (c *client) generateConnectionIDs() error {
-	srcConnID, err := generateConnectionID(c.config.ConnectionIDLength)
-	if err != nil {
-		return err
-	}
-	destConnID, err := generateConnectionIDForInitial()
-	if err != nil {
-		return err
-	}
-	c.srcConnID = srcConnID
-	c.destConnID = destConnID
-	return nil
-}
-
 func (c *client) dial(ctx context.Context) error {
 	c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.tlsConf.ServerName, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)
 
@@ -292,65 +292,49 @@ func (c *client) establishSecureConnection(ctx context.Context) error {
 }
 
 func (c *client) handlePacket(p *receivedPacket) {
-	if err := c.handlePacketImpl(p); err != nil {
-		c.logger.Errorf("error handling packet: %s", err)
-	}
-}
-
-func (c *client) handlePacketImpl(p *receivedPacket) error {
-	c.mutex.Lock()
-	defer c.mutex.Unlock()
-
-	// handle Version Negotiation Packets
-	if p.header.IsVersionNegotiation {
-		err := c.handleVersionNegotiationPacket(p.header)
-		if err != nil {
-			c.session.destroy(err)
-		}
-		// version negotiation packets have no payload
-		return err
-	}
-
-	// reject packets with the wrong connection ID
-	if !p.header.DestConnectionID.Equal(c.srcConnID) {
-		return fmt.Errorf("received a packet with an unexpected connection ID (%s, expected %s)", p.header.DestConnectionID, c.srcConnID)
+	if p.hdr.IsVersionNegotiation() {
+		go c.handleVersionNegotiationPacket(p.hdr)
+		return
 	}
 
-	if p.header.Type == protocol.PacketTypeRetry {
-		c.handleRetryPacket(p.header)
-		return nil
+	if p.hdr.Type == protocol.PacketTypeRetry {
+		go c.handleRetryPacket(p.hdr)
+		return
 	}
 
 	// this is the first packet we are receiving
 	// since it is not a Version Negotiation Packet, this means the server supports the suggested version
-	if !c.versionNegotiated {
-		c.versionNegotiated = true
+	if !c.versionNegotiated.Get() {
+		c.versionNegotiated.Set(true)
 	}
 
 	c.session.handlePacket(p)
-	return nil
 }
 
-func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
+func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
 	// ignore delayed / duplicated version negotiation packets
-	if c.receivedVersionNegotiationPacket || c.versionNegotiated {
-		c.logger.Debugf("Received a delayed Version Negotiation Packet.")
-		return nil
+	if c.receivedVersionNegotiationPacket || c.versionNegotiated.Get() {
+		c.logger.Debugf("Received a delayed Version Negotiation packet.")
+		return
 	}
 
 	for _, v := range hdr.SupportedVersions {
 		if v == c.version {
-			// the version negotiation packet contains the version that we offered
-			// this might be a packet sent by an attacker (or by a terribly broken server implementation)
-			// ignore it
-			return nil
+			// The Version Negotiation packet contains the version that we offered.
+			// This might be a packet sent by an attacker (or by a terribly broken server implementation).
+			return
 		}
 	}
 
-	c.logger.Infof("Received a Version Negotiation Packet. Supported Versions: %s", hdr.SupportedVersions)
+	c.logger.Infof("Received a Version Negotiation packet. Supported Versions: %s", hdr.SupportedVersions)
 	newVersion, ok := protocol.ChooseSupportedVersion(c.config.Versions, hdr.SupportedVersions)
 	if !ok {
-		return qerr.InvalidVersion
+		c.session.destroy(qerr.InvalidVersion)
+		c.logger.Debugf("No compatible version found.")
+		return
 	}
 	c.receivedVersionNegotiationPacket = true
 	c.negotiatedVersions = hdr.SupportedVersions
@@ -358,18 +342,17 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
 	// switch to negotiated version
 	c.initialVersion = c.version
 	c.version = newVersion
-	if err := c.generateConnectionIDs(); err != nil {
-		return err
-	}
 
 	c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
 	c.session.destroy(errCloseSessionForNewVersion)
-	return nil
 }
 
 func (c *client) handleRetryPacket(hdr *wire.Header) {
+	c.mutex.Lock()
+	defer c.mutex.Unlock()
+
 	c.logger.Debugf("<- Received Retry")
-	hdr.Log(c.logger)
+	(&wire.ExtendedHeader{Header: *hdr}).Log(c.logger)
 	if !hdr.OrigDestConnectionID.Equal(c.destConnID) {
 		c.logger.Debugf("Ignoring spoofed Retry. Original Destination Connection ID: %s, expected: %s", hdr.OrigDestConnectionID, c.destConnID)
 		return

+ 2 - 5
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/sent_packet_handler.go

@@ -75,12 +75,10 @@ type sentPacketHandler struct {
 	alarm time.Time
 
 	logger utils.Logger
-
-	version protocol.VersionNumber
 }
 
 // NewSentPacketHandler creates a new sentPacketHandler
-func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, version protocol.VersionNumber) SentPacketHandler {
+func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger) SentPacketHandler {
 	congestion := congestion.NewCubicSender(
 		congestion.DefaultClock{},
 		rttStats,
@@ -95,7 +93,6 @@ func NewSentPacketHandler(rttStats *congestion.RTTStats, logger utils.Logger, ve
 		rttStats:              rttStats,
 		congestion:            congestion,
 		logger:                logger,
-		version:               version,
 	}
 }
 
@@ -516,7 +513,7 @@ func (h *sentPacketHandler) DequeueProbePacket() (*Packet, error) {
 
 func (h *sentPacketHandler) PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen) {
 	pn := h.packetNumberGenerator.Peek()
-	return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked(), h.version)
+	return pn, protocol.GetPacketNumberLengthForHeader(pn, h.lowestUnacked())
 }
 
 func (h *sentPacketHandler) PopPacketNumber() protocol.PacketNumber {

+ 8 - 6
vendor/github.com/lucas-clemente/quic-go/internal/handshake/qtls.go

@@ -11,11 +11,13 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
 		c = &tls.Config{}
 	}
 	// QUIC requires TLS 1.3 or newer
-	if c.MinVersion < qtls.VersionTLS13 {
-		c.MinVersion = qtls.VersionTLS13
+	minVersion := c.MinVersion
+	if minVersion < qtls.VersionTLS13 {
+		minVersion = qtls.VersionTLS13
 	}
-	if c.MaxVersion < qtls.VersionTLS13 {
-		c.MaxVersion = qtls.VersionTLS13
+	maxVersion := c.MaxVersion
+	if maxVersion < qtls.VersionTLS13 {
+		maxVersion = qtls.VersionTLS13
 	}
 	return &qtls.Config{
 		Rand:              c.Rand,
@@ -38,8 +40,8 @@ func tlsConfigToQtlsConfig(c *tls.Config) *qtls.Config {
 		PreferServerCipherSuites:    c.PreferServerCipherSuites,
 		SessionTicketsDisabled:      c.SessionTicketsDisabled,
 		SessionTicketKey:            c.SessionTicketKey,
-		MinVersion:                  c.MinVersion,
-		MaxVersion:                  c.MaxVersion,
+		MinVersion:                  minVersion,
+		MaxVersion:                  maxVersion,
 		CurvePreferences:            c.CurvePreferences,
 		DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
 		Renegotiation:               c.Renegotiation,

+ 29 - 6
vendor/github.com/lucas-clemente/quic-go/internal/protocol/packet_number.go

@@ -1,20 +1,37 @@
 package protocol
 
+// PacketNumberLen is the length of the packet number in bytes
+type PacketNumberLen uint8
+
+const (
+	// PacketNumberLenInvalid is the default value and not a valid length for a packet number
+	PacketNumberLenInvalid PacketNumberLen = 0
+	// PacketNumberLen1 is a packet number length of 1 byte
+	PacketNumberLen1 PacketNumberLen = 1
+	// PacketNumberLen2 is a packet number length of 2 bytes
+	PacketNumberLen2 PacketNumberLen = 2
+	// PacketNumberLen3 is a packet number length of 3 bytes
+	PacketNumberLen3 PacketNumberLen = 3
+	// PacketNumberLen4 is a packet number length of 4 bytes
+	PacketNumberLen4 PacketNumberLen = 4
+)
+
 // InferPacketNumber calculates the packet number based on the received packet number, its length and the last seen packet number
 func InferPacketNumber(
 	packetNumberLength PacketNumberLen,
 	lastPacketNumber PacketNumber,
 	wirePacketNumber PacketNumber,
-	version VersionNumber,
 ) PacketNumber {
 	var epochDelta PacketNumber
 	switch packetNumberLength {
 	case PacketNumberLen1:
-		epochDelta = PacketNumber(1) << 7
+		epochDelta = PacketNumber(1) << 8
 	case PacketNumberLen2:
-		epochDelta = PacketNumber(1) << 14
+		epochDelta = PacketNumber(1) << 16
+	case PacketNumberLen3:
+		epochDelta = PacketNumber(1) << 24
 	case PacketNumberLen4:
-		epochDelta = PacketNumber(1) << 30
+		epochDelta = PacketNumber(1) << 32
 	}
 	epoch := lastPacketNumber & ^(epochDelta - 1)
 	prevEpochBegin := epoch - epochDelta
@@ -42,11 +59,14 @@ func delta(a, b PacketNumber) PacketNumber {
 
 // GetPacketNumberLengthForHeader gets the length of the packet number for the public header
 // it never chooses a PacketNumberLen of 1 byte, since this is too short under certain circumstances
-func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber, version VersionNumber) PacketNumberLen {
+func GetPacketNumberLengthForHeader(packetNumber, leastUnacked PacketNumber) PacketNumberLen {
 	diff := uint64(packetNumber - leastUnacked)
-	if diff < (1 << (14 - 1)) {
+	if diff < (1 << (16 - 1)) {
 		return PacketNumberLen2
 	}
+	if diff < (1 << (24 - 1)) {
+		return PacketNumberLen3
+	}
 	return PacketNumberLen4
 }
 
@@ -58,5 +78,8 @@ func GetPacketNumberLength(packetNumber PacketNumber) PacketNumberLen {
 	if packetNumber < (1 << (uint8(PacketNumberLen2) * 8)) {
 		return PacketNumberLen2
 	}
+	if packetNumber < (1 << (uint8(PacketNumberLen3) * 8)) {
+		return PacketNumberLen3
+	}
 	return PacketNumberLen4
 }

+ 4 - 24
vendor/github.com/lucas-clemente/quic-go/internal/protocol/protocol.go

@@ -7,32 +7,18 @@ import (
 // A PacketNumber in QUIC
 type PacketNumber uint64
 
-// PacketNumberLen is the length of the packet number in bytes
-type PacketNumberLen uint8
-
-const (
-	// PacketNumberLenInvalid is the default value and not a valid length for a packet number
-	PacketNumberLenInvalid PacketNumberLen = 0
-	// PacketNumberLen1 is a packet number length of 1 byte
-	PacketNumberLen1 PacketNumberLen = 1
-	// PacketNumberLen2 is a packet number length of 2 bytes
-	PacketNumberLen2 PacketNumberLen = 2
-	// PacketNumberLen4 is a packet number length of 4 bytes
-	PacketNumberLen4 PacketNumberLen = 4
-)
-
 // The PacketType is the Long Header Type
 type PacketType uint8
 
 const (
 	// PacketTypeInitial is the packet type of an Initial packet
-	PacketTypeInitial PacketType = 0x7f
+	PacketTypeInitial PacketType = 1 + iota
 	// PacketTypeRetry is the packet type of a Retry packet
-	PacketTypeRetry PacketType = 0x7e
+	PacketTypeRetry
 	// PacketTypeHandshake is the packet type of a Handshake packet
-	PacketTypeHandshake PacketType = 0x7d
+	PacketTypeHandshake
 	// PacketType0RTT is the packet type of a 0-RTT packet
-	PacketType0RTT PacketType = 0x7c
+	PacketType0RTT
 )
 
 func (t PacketType) String() string {
@@ -72,11 +58,5 @@ const DefaultTCPMSS ByteCount = 1460
 // MinInitialPacketSize is the minimum size an Initial packet is required to have.
 const MinInitialPacketSize = 1200
 
-// MaxClientHellos is the maximum number of times we'll send a client hello
-// The value 3 accounts for:
-// * one failure due to an incorrect or missing source-address token
-// * one failure due the server's certificate chain being unavailable and the server being unwilling to send it without a valid source-address token
-const MaxClientHellos = 3
-
 // MinConnectionIDLenInitial is the minimum length of the destination connection ID on an Initial packet.
 const MinConnectionIDLenInitial = 8

+ 1 - 2
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder.go

@@ -8,11 +8,10 @@ import (
 // A ByteOrder specifies how to convert byte sequences into 16-, 32-, or 64-bit unsigned integers.
 type ByteOrder interface {
 	ReadUintN(b io.ByteReader, length uint8) (uint64, error)
-	ReadUint64(io.ByteReader) (uint64, error)
 	ReadUint32(io.ByteReader) (uint32, error)
 	ReadUint16(io.ByteReader) (uint16, error)
 
-	WriteUint64(*bytes.Buffer, uint64)
+	WriteUintN(b *bytes.Buffer, length uint8, value uint64)
 	WriteUint32(*bytes.Buffer, uint32)
 	WriteUint16(*bytes.Buffer, uint16)
 }

+ 4 - 37
vendor/github.com/lucas-clemente/quic-go/internal/utils/byteorder_big_endian.go

@@ -25,37 +25,6 @@ func (bigEndian) ReadUintN(b io.ByteReader, length uint8) (uint64, error) {
 	return res, nil
 }
 
-// ReadUint64 reads a uint64
-func (bigEndian) ReadUint64(b io.ByteReader) (uint64, error) {
-	var b1, b2, b3, b4, b5, b6, b7, b8 uint8
-	var err error
-	if b8, err = b.ReadByte(); err != nil {
-		return 0, err
-	}
-	if b7, err = b.ReadByte(); err != nil {
-		return 0, err
-	}
-	if b6, err = b.ReadByte(); err != nil {
-		return 0, err
-	}
-	if b5, err = b.ReadByte(); err != nil {
-		return 0, err
-	}
-	if b4, err = b.ReadByte(); err != nil {
-		return 0, err
-	}
-	if b3, err = b.ReadByte(); err != nil {
-		return 0, err
-	}
-	if b2, err = b.ReadByte(); err != nil {
-		return 0, err
-	}
-	if b1, err = b.ReadByte(); err != nil {
-		return 0, err
-	}
-	return uint64(b1) + uint64(b2)<<8 + uint64(b3)<<16 + uint64(b4)<<24 + uint64(b5)<<32 + uint64(b6)<<40 + uint64(b7)<<48 + uint64(b8)<<56, nil
-}
-
 // ReadUint32 reads a uint32
 func (bigEndian) ReadUint32(b io.ByteReader) (uint32, error) {
 	var b1, b2, b3, b4 uint8
@@ -88,12 +57,10 @@ func (bigEndian) ReadUint16(b io.ByteReader) (uint16, error) {
 	return uint16(b1) + uint16(b2)<<8, nil
 }
 
-// WriteUint64 writes a uint64
-func (bigEndian) WriteUint64(b *bytes.Buffer, i uint64) {
-	b.Write([]byte{
-		uint8(i >> 56), uint8(i >> 48), uint8(i >> 40), uint8(i >> 32),
-		uint8(i >> 24), uint8(i >> 16), uint8(i >> 8), uint8(i),
-	})
+func (bigEndian) WriteUintN(b *bytes.Buffer, length uint8, i uint64) {
+	for j := length; j > 0; j-- {
+		b.WriteByte(uint8(i >> (8 * (j - 1))))
+	}
 }
 
 // WriteUint32 writes a uint32

+ 205 - 0
vendor/github.com/lucas-clemente/quic-go/internal/wire/extended_header.go

@@ -0,0 +1,205 @@
+package wire
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+	"io"
+
+	"github.com/lucas-clemente/quic-go/internal/protocol"
+	"github.com/lucas-clemente/quic-go/internal/utils"
+)
+
+// ExtendedHeader is the header of a QUIC packet.
+type ExtendedHeader struct {
+	Header
+
+	typeByte byte
+	Raw      []byte
+
+	PacketNumberLen protocol.PacketNumberLen
+	PacketNumber    protocol.PacketNumber
+
+	KeyPhase int
+}
+
+func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
+	// read the (now unencrypted) first byte
+	var err error
+	h.typeByte, err = b.ReadByte()
+	if err != nil {
+		return nil, err
+	}
+	if _, err := b.Seek(int64(h.len)-1, io.SeekCurrent); err != nil {
+		return nil, err
+	}
+	if h.IsLongHeader {
+		return h.parseLongHeader(b, v)
+	}
+	return h.parseShortHeader(b, v)
+}
+
+func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
+	if h.typeByte&0xc != 0 {
+		return nil, errors.New("5th and 6th bit must be 0")
+	}
+	if err := h.readPacketNumber(b); err != nil {
+		return nil, err
+	}
+	return h, nil
+}
+
+func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, v protocol.VersionNumber) (*ExtendedHeader, error) {
+	if h.typeByte&0x18 != 0 {
+		return nil, errors.New("4th and 5th bit must be 0")
+	}
+
+	h.KeyPhase = int(h.typeByte&0x4) >> 2
+
+	if err := h.readPacketNumber(b); err != nil {
+		return nil, err
+	}
+	return h, nil
+}
+
+func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
+	h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1
+	pn, err := utils.BigEndian.ReadUintN(b, uint8(h.PacketNumberLen))
+	if err != nil {
+		return err
+	}
+	h.PacketNumber = protocol.PacketNumber(pn)
+	return nil
+}
+
+// Write writes the Header.
+func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
+	if h.IsLongHeader {
+		return h.writeLongHeader(b, ver)
+	}
+	return h.writeShortHeader(b, ver)
+}
+
+func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
+	var packetType uint8
+	switch h.Type {
+	case protocol.PacketTypeInitial:
+		packetType = 0x0
+	case protocol.PacketType0RTT:
+		packetType = 0x1
+	case protocol.PacketTypeHandshake:
+		packetType = 0x2
+	case protocol.PacketTypeRetry:
+		packetType = 0x3
+	}
+	firstByte := 0xc0 | packetType<<4
+	if h.Type == protocol.PacketTypeRetry {
+		odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
+		if err != nil {
+			return err
+		}
+		firstByte |= odcil
+	} else { // Retry packets don't have a packet number
+		firstByte |= uint8(h.PacketNumberLen - 1)
+	}
+
+	b.WriteByte(firstByte)
+	utils.BigEndian.WriteUint32(b, uint32(h.Version))
+	connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
+	if err != nil {
+		return err
+	}
+	b.WriteByte(connIDLen)
+	b.Write(h.DestConnectionID.Bytes())
+	b.Write(h.SrcConnectionID.Bytes())
+
+	switch h.Type {
+	case protocol.PacketTypeRetry:
+		b.Write(h.OrigDestConnectionID.Bytes())
+		b.Write(h.Token)
+		return nil
+	case protocol.PacketTypeInitial:
+		utils.WriteVarInt(b, uint64(len(h.Token)))
+		b.Write(h.Token)
+	}
+
+	utils.WriteVarInt(b, uint64(h.Length))
+	return h.writePacketNumber(b)
+}
+
+// TODO: add support for the key phase
+func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
+	typeByte := 0x40 | uint8(h.PacketNumberLen-1)
+	typeByte |= byte(h.KeyPhase << 2)
+
+	b.WriteByte(typeByte)
+	b.Write(h.DestConnectionID.Bytes())
+	return h.writePacketNumber(b)
+}
+
+func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
+	if h.PacketNumberLen == protocol.PacketNumberLenInvalid || h.PacketNumberLen > protocol.PacketNumberLen4 {
+		return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
+	}
+	utils.BigEndian.WriteUintN(b, uint8(h.PacketNumberLen), uint64(h.PacketNumber))
+	return nil
+}
+
+// GetLength determines the length of the Header.
+func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
+	if h.IsLongHeader {
+		length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
+		if h.Type == protocol.PacketTypeInitial {
+			length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
+		}
+		return length
+	}
+
+	length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
+	length += protocol.ByteCount(h.PacketNumberLen)
+	return length
+}
+
+// Log logs the Header
+func (h *ExtendedHeader) Log(logger utils.Logger) {
+	if h.IsLongHeader {
+		var token string
+		if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
+			if len(h.Token) == 0 {
+				token = "Token: (empty), "
+			} else {
+				token = fmt.Sprintf("Token: %#x, ", h.Token)
+			}
+			if h.Type == protocol.PacketTypeRetry {
+				logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
+				return
+			}
+		}
+		logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
+	} else {
+		logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
+	}
+}
+
+func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
+	dcil, err := encodeSingleConnIDLen(dest)
+	if err != nil {
+		return 0, err
+	}
+	scil, err := encodeSingleConnIDLen(src)
+	if err != nil {
+		return 0, err
+	}
+	return scil | dcil<<4, nil
+}
+
+func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
+	len := id.Len()
+	if len == 0 {
+		return 0, nil
+	}
+	if len < 4 || len > 18 {
+		return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
+	}
+	return byte(len - 3), nil
+}

+ 138 - 105
vendor/github.com/lucas-clemente/quic-go/internal/wire/header.go

@@ -2,150 +2,183 @@ package wire
 
 import (
 	"bytes"
-	"crypto/rand"
-	"fmt"
+	"errors"
+	"io"
 
 	"github.com/lucas-clemente/quic-go/internal/protocol"
+	"github.com/lucas-clemente/quic-go/internal/qerr"
 	"github.com/lucas-clemente/quic-go/internal/utils"
 )
 
-// Header is the header of a QUIC packet.
+// The Header is the version independent part of the header
 type Header struct {
-	Raw []byte
+	Version          protocol.VersionNumber
+	SrcConnectionID  protocol.ConnectionID
+	DestConnectionID protocol.ConnectionID
 
-	Version protocol.VersionNumber
-
-	DestConnectionID     protocol.ConnectionID
-	SrcConnectionID      protocol.ConnectionID
-	OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet
+	IsLongHeader bool
+	Type         protocol.PacketType
+	Length       protocol.ByteCount
 
-	PacketNumberLen protocol.PacketNumberLen
-	PacketNumber    protocol.PacketNumber
+	Token                []byte
+	SupportedVersions    []protocol.VersionNumber // sent in a Version Negotiation Packet
+	OrigDestConnectionID protocol.ConnectionID    // sent in the Retry packet
 
-	IsVersionNegotiation bool
-	SupportedVersions    []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
+	typeByte byte
+	len      int // how many bytes were read while parsing this header
+}
 
-	Type         protocol.PacketType
-	IsLongHeader bool
-	KeyPhase     int
-	Length       protocol.ByteCount
-	Token        []byte
+// ParseHeader parses the header.
+// For short header packets: up to the packet number.
+// For long header packets:
+// * if we understand the version: up to the packet number
+// * if not, only the invariant part of the header
+func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
+	startLen := b.Len()
+	h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
+	if err != nil {
+		return nil, err
+	}
+	h.len = startLen - b.Len()
+	return h, nil
 }
 
-// Write writes the Header.
-func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
-	if h.IsLongHeader {
-		return h.writeLongHeader(b, ver)
+func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
+	typeByte, err := b.ReadByte()
+	if err != nil {
+		return nil, err
+	}
+
+	h := &Header{
+		typeByte:     typeByte,
+		IsLongHeader: typeByte&0x80 > 0,
+	}
+
+	if !h.IsLongHeader {
+		if h.typeByte&0x40 == 0 {
+			return nil, errors.New("not a QUIC packet")
+		}
+		if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
+			return nil, err
+		}
+		return h, nil
 	}
-	return h.writeShortHeader(b, ver)
+	if err := h.parseLongHeader(b); err != nil {
+		return nil, err
+	}
+	return h, nil
+}
+
+func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
+	var err error
+	h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
+	return err
 }
 
-// TODO: add support for the key phase
-func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
-	b.WriteByte(byte(0x80 | h.Type))
-	utils.BigEndian.WriteUint32(b, uint32(h.Version))
-	connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
+func (h *Header) parseLongHeader(b *bytes.Reader) error {
+	v, err := utils.BigEndian.ReadUint32(b)
+	if err != nil {
+		return err
+	}
+	h.Version = protocol.VersionNumber(v)
+	if !h.IsVersionNegotiation() && h.typeByte&0x40 == 0 {
+		return errors.New("not a QUIC packet")
+	}
+	connIDLenByte, err := b.ReadByte()
+	if err != nil {
+		return err
+	}
+	dcil, scil := decodeConnIDLen(connIDLenByte)
+	h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil)
+	if err != nil {
+		return err
+	}
+	h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil)
 	if err != nil {
 		return err
 	}
-	b.WriteByte(connIDLen)
-	b.Write(h.DestConnectionID.Bytes())
-	b.Write(h.SrcConnectionID.Bytes())
+	if h.Version == 0 {
+		return h.parseVersionNegotiationPacket(b)
+	}
+	// If we don't understand the version, we have no idea how to interpret the rest of the bytes
+	if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
+		return nil
+	}
 
-	if h.Type == protocol.PacketTypeInitial {
-		utils.WriteVarInt(b, uint64(len(h.Token)))
-		b.Write(h.Token)
+	switch (h.typeByte & 0x30) >> 4 {
+	case 0x0:
+		h.Type = protocol.PacketTypeInitial
+	case 0x1:
+		h.Type = protocol.PacketType0RTT
+	case 0x2:
+		h.Type = protocol.PacketTypeHandshake
+	case 0x3:
+		h.Type = protocol.PacketTypeRetry
 	}
 
 	if h.Type == protocol.PacketTypeRetry {
-		odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
+		odcil := decodeSingleConnIDLen(h.typeByte & 0xf)
+		h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
 		if err != nil {
 			return err
 		}
-		// randomize the first 4 bits
-		odcilByte := make([]byte, 1)
-		_, _ = rand.Read(odcilByte) // it's safe to ignore the error here
-		odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
-		b.Write(odcilByte)
-		b.Write(h.OrigDestConnectionID.Bytes())
-		b.Write(h.Token)
+		h.Token = make([]byte, b.Len())
+		if _, err := io.ReadFull(b, h.Token); err != nil {
+			return err
+		}
 		return nil
 	}
 
-	utils.WriteVarInt(b, uint64(h.Length))
-	return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
-}
-
-func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
-	typeByte := byte(0x30)
-	typeByte |= byte(h.KeyPhase << 6)
-
-	b.WriteByte(typeByte)
-	b.Write(h.DestConnectionID.Bytes())
-	return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
-}
-
-// GetLength determines the length of the Header.
-func (h *Header) GetLength(v protocol.VersionNumber) protocol.ByteCount {
-	if h.IsLongHeader {
-		length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
-		if h.Type == protocol.PacketTypeInitial {
-			length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
+	if h.Type == protocol.PacketTypeInitial {
+		tokenLen, err := utils.ReadVarInt(b)
+		if err != nil {
+			return err
+		}
+		if tokenLen > uint64(b.Len()) {
+			return io.EOF
+		}
+		h.Token = make([]byte, tokenLen)
+		if _, err := io.ReadFull(b, h.Token); err != nil {
+			return err
 		}
-		return length
 	}
 
-	length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
-	length += protocol.ByteCount(h.PacketNumberLen)
-	return length
+	pl, err := utils.ReadVarInt(b)
+	if err != nil {
+		return err
+	}
+	h.Length = protocol.ByteCount(pl)
+	return nil
 }
 
-// Log logs the Header
-func (h *Header) Log(logger utils.Logger) {
-	if h.IsLongHeader {
-		if h.Version == 0 {
-			logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
-		} else {
-			var token string
-			if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
-				if len(h.Token) == 0 {
-					token = "Token: (empty), "
-				} else {
-					token = fmt.Sprintf("Token: %#x, ", h.Token)
-				}
-			}
-			if h.Type == protocol.PacketTypeRetry {
-				logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
-				return
-			}
-			logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
+func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error {
+	if b.Len() == 0 {
+		return qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
+	}
+	h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
+	for i := 0; b.Len() > 0; i++ {
+		v, err := utils.BigEndian.ReadUint32(b)
+		if err != nil {
+			return qerr.InvalidVersionNegotiationPacket
 		}
-	} else {
-		logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
+		h.SupportedVersions[i] = protocol.VersionNumber(v)
 	}
+	return nil
 }
 
-func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
-	dcil, err := encodeSingleConnIDLen(dest)
-	if err != nil {
-		return 0, err
-	}
-	scil, err := encodeSingleConnIDLen(src)
-	if err != nil {
-		return 0, err
-	}
-	return scil | dcil<<4, nil
+// IsVersionNegotiation says if this a version negotiation packet
+func (h *Header) IsVersionNegotiation() bool {
+	return h.IsLongHeader && h.Version == 0
 }
 
-func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
-	len := id.Len()
-	if len == 0 {
-		return 0, nil
-	}
-	if len < 4 || len > 18 {
-		return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
-	}
-	return byte(len - 3), nil
+// ParseExtended parses the version dependent part of the header.
+// The Reader has to be set such that it points to the first byte of the header.
+func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
+	return h.toExtendedHeader().parse(b, ver)
+}
+
+func (h *Header) toExtendedHeader() *ExtendedHeader {
+	return &ExtendedHeader{Header: *h}
 }
 
 func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {

+ 29 - 58
vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go

@@ -162,75 +162,46 @@ func (h *packetHandlerMap) listen() {
 }
 
 func (h *packetHandlerMap) handlePacket(addr net.Addr, data []byte) error {
-	rcvTime := time.Now()
-
 	r := bytes.NewReader(data)
-	iHdr, err := wire.ParseInvariantHeader(r, h.connIDLen)
+	hdr, err := wire.ParseHeader(r, h.connIDLen)
 	// drop the packet if we can't parse the header
 	if err != nil {
-		return fmt.Errorf("error parsing invariant header: %s", err)
+		return fmt.Errorf("error parsing header: %s", err)
+	}
+
+	p := &receivedPacket{
+		remoteAddr: addr,
+		hdr:        hdr,
+		data:       data,
+		rcvTime:    time.Now(),
 	}
 
 	h.mutex.RLock()
-	handlerEntry, handlerFound := h.handlers[string(iHdr.DestConnectionID)]
-	server := h.server
+	defer h.mutex.RUnlock()
+
+	handlerEntry, handlerFound := h.handlers[string(hdr.DestConnectionID)]
 
-	var sentBy protocol.Perspective
-	var version protocol.VersionNumber
-	var handlePacket func(*receivedPacket)
 	if handlerFound { // existing session
-		handler := handlerEntry.handler
-		sentBy = handler.GetPerspective().Opposite()
-		version = handler.GetVersion()
-		handlePacket = handler.handlePacket
-	} else { // no session found
-		// this might be a stateless reset
-		if !iHdr.IsLongHeader {
-			if len(data) >= protocol.MinStatelessResetSize {
-				var token [16]byte
-				copy(token[:], data[len(data)-16:])
-				if sess, ok := h.resetTokens[token]; ok {
-					h.mutex.RUnlock()
-					sess.destroy(errors.New("received a stateless reset"))
-					return nil
-				}
+		handlerEntry.handler.handlePacket(p)
+		return nil
+	}
+	// No session found.
+	// This might be a stateless reset.
+	if !hdr.IsLongHeader {
+		if len(data) >= protocol.MinStatelessResetSize {
+			var token [16]byte
+			copy(token[:], data[len(data)-16:])
+			if sess, ok := h.resetTokens[token]; ok {
+				sess.destroy(errors.New("received a stateless reset"))
+				return nil
 			}
-			// TODO(#943): send a stateless reset
-			return fmt.Errorf("received a short header packet with an unexpected connection ID %s", iHdr.DestConnectionID)
-		}
-		if server == nil { // no server set
-			h.mutex.RUnlock()
-			return fmt.Errorf("received a packet with an unexpected connection ID %s", iHdr.DestConnectionID)
 		}
-		handlePacket = server.handlePacket
-		sentBy = protocol.PerspectiveClient
-		version = iHdr.Version
+		// TODO(#943): send a stateless reset
+		return fmt.Errorf("received a short header packet with an unexpected connection ID %s", hdr.DestConnectionID)
 	}
-	h.mutex.RUnlock()
-
-	hdr, err := iHdr.Parse(r, sentBy, version)
-	if err != nil {
-		return fmt.Errorf("error parsing header: %s", err)
-	}
-	hdr.Raw = data[:len(data)-r.Len()]
-	packetData := data[len(data)-r.Len():]
-
-	if hdr.IsLongHeader {
-		if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
-			return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
-		}
-		if protocol.ByteCount(len(packetData))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
-			return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(packetData)+int(hdr.PacketNumberLen), hdr.Length)
-		}
-		packetData = packetData[:int(hdr.Length)-int(hdr.PacketNumberLen)]
-		// TODO(#1312): implement parsing of compound packets
+	if h.server == nil { // no server set
+		return fmt.Errorf("received a packet with an unexpected connection ID %s", hdr.DestConnectionID)
 	}
-
-	handlePacket(&receivedPacket{
-		remoteAddr: addr,
-		header:     hdr,
-		data:       packetData,
-		rcvTime:    rcvTime,
-	})
+	h.server.handlePacket(p)
 	return nil
 }

+ 9 - 11
vendor/github.com/lucas-clemente/quic-go/packet_packer.go

@@ -25,7 +25,7 @@ type packer interface {
 }
 
 type packedPacket struct {
-	header          *wire.Header
+	header          *wire.ExtendedHeader
 	raw             []byte
 	frames          []wire.Frame
 	encryptionLevel protocol.EncryptionLevel
@@ -397,14 +397,13 @@ func (p *packetPacker) composeNextPacket(
 	return frames, nil
 }
 
-func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header {
+func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
 	pn, pnLen := p.pnManager.PeekPacketNumber()
-	header := &wire.Header{
-		PacketNumber:     pn,
-		PacketNumberLen:  pnLen,
-		Version:          p.version,
-		DestConnectionID: p.destConnID,
-	}
+	header := &wire.ExtendedHeader{}
+	header.PacketNumber = pn
+	header.PacketNumberLen = pnLen
+	header.Version = p.version
+	header.DestConnectionID = p.destConnID
 
 	if encLevel != protocol.Encryption1RTT {
 		header.IsLongHeader = true
@@ -424,8 +423,7 @@ func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.Header
 }
 
 func (p *packetPacker) writeAndSealPacket(
-	header *wire.Header,
-	frames []wire.Frame,
+	header *wire.ExtendedHeader, frames []wire.Frame,
 	sealer handshake.Sealer,
 ) ([]byte, error) {
 	raw := *getPacketBuffer()
@@ -450,7 +448,7 @@ func (p *packetPacker) writeAndSealPacket(
 		}
 	}
 
-	if err := header.Write(buffer, p.perspective, p.version); err != nil {
+	if err := header.Write(buffer, p.version); err != nil {
 		return nil, err
 	}
 	payloadStartIndex := buffer.Len()

+ 1 - 1
vendor/github.com/lucas-clemente/quic-go/packet_unpacker.go

@@ -35,7 +35,7 @@ func newPacketUnpacker(aead quicAEAD, version protocol.VersionNumber) unpacker {
 	}
 }
 
-func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error) {
+func (u *packetUnpacker) Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error) {
 	buf := *getPacketBuffer()
 	buf = buf[:0]
 	defer putPacketBuffer(&buf)

+ 30 - 32
vendor/github.com/lucas-clemente/quic-go/server.go

@@ -21,7 +21,6 @@ type packetHandler interface {
 	handlePacket(*receivedPacket)
 	io.Closer
 	destroy(error)
-	GetVersion() protocol.VersionNumber
 	GetPerspective() protocol.Perspective
 }
 
@@ -99,7 +98,8 @@ var _ Listener = &server{}
 var _ unknownPacketHandler = &server{}
 
 // ListenAddr creates a QUIC server listening on a given address.
-// The tls.Config must not be nil, the quic.Config may be nil.
+// The tls.Config must not be nil and must contain a certificate configuration.
+// The quic.Config may be nil, in that case the default values will be used.
 func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, error) {
 	udpAddr, err := net.ResolveUDPAddr("udp", addr)
 	if err != nil {
@@ -118,7 +118,11 @@ func ListenAddr(addr string, tlsConf *tls.Config, config *Config) (Listener, err
 }
 
 // Listen listens for QUIC connections on a given net.PacketConn.
-// The tls.Config must not be nil, the quic.Config may be nil.
+// A single PacketConn only be used for a single call to Listen.
+// The PacketConn can be used for simultaneous calls to Dial.
+// QUIC connection IDs are used for demultiplexing the different connections.
+// The tls.Config must not be nil and must contain a certificate configuration.
+// The quic.Config may be nil, in that case the default values will be used.
 func Listen(conn net.PacketConn, tlsConf *tls.Config, config *Config) (Listener, error) {
 	return listen(conn, tlsConf, config)
 }
@@ -300,23 +304,17 @@ func (s *server) Addr() net.Addr {
 }
 
 func (s *server) handlePacket(p *receivedPacket) {
-	if err := s.handlePacketImpl(p); err != nil {
-		s.logger.Debugf("error handling packet from %s: %s", p.remoteAddr, err)
-	}
-}
-
-func (s *server) handlePacketImpl(p *receivedPacket) error {
-	hdr := p.header
+	hdr := p.hdr
 
 	// send a Version Negotiation Packet if the client is speaking a different protocol version
 	if !protocol.IsSupportedVersion(s.config.Versions, hdr.Version) {
-		return s.sendVersionNegotiationPacket(p)
+		go s.sendVersionNegotiationPacket(p)
+		return
 	}
 	if hdr.Type == protocol.PacketTypeInitial {
 		go s.handleInitial(p)
 	}
 	// TODO(#943): send Stateless Reset
-	return nil
 }
 
 func (s *server) handleInitial(p *receivedPacket) {
@@ -335,11 +333,11 @@ func (s *server) handleInitial(p *receivedPacket) {
 }
 
 func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.ConnectionID, error) {
-	hdr := p.header
+	hdr := p.hdr
 	if len(hdr.Token) == 0 && hdr.DestConnectionID.Len() < protocol.MinConnectionIDLenInitial {
 		return nil, nil, errors.New("dropping Initial packet with too short connection ID")
 	}
-	if len(hdr.Raw)+len(p.data) < protocol.MinInitialPacketSize {
+	if len(p.data) < protocol.MinInitialPacketSize {
 		return nil, nil, errors.New("dropping too small Initial packet")
 	}
 
@@ -358,7 +356,7 @@ func (s *server) handleInitialImpl(p *receivedPacket) (quicSession, protocol.Con
 	if !s.config.AcceptCookie(p.remoteAddr, cookie) {
 		// Log the Initial packet now.
 		// If no Retry is sent, the packet will be logged by the session.
-		p.header.Log(s.logger)
+		(&wire.ExtendedHeader{Header: *p.hdr}).Log(s.logger)
 		return nil, nil, s.sendRetry(p.remoteAddr, hdr)
 	}
 
@@ -431,19 +429,18 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
 	if err != nil {
 		return err
 	}
-	replyHdr := &wire.Header{
-		IsLongHeader:         true,
-		Type:                 protocol.PacketTypeRetry,
-		Version:              hdr.Version,
-		SrcConnectionID:      connID,
-		DestConnectionID:     hdr.SrcConnectionID,
-		OrigDestConnectionID: hdr.DestConnectionID,
-		Token:                token,
-	}
+	replyHdr := &wire.ExtendedHeader{}
+	replyHdr.IsLongHeader = true
+	replyHdr.Type = protocol.PacketTypeRetry
+	replyHdr.Version = hdr.Version
+	replyHdr.SrcConnectionID = connID
+	replyHdr.DestConnectionID = hdr.SrcConnectionID
+	replyHdr.OrigDestConnectionID = hdr.DestConnectionID
+	replyHdr.Token = token
 	s.logger.Debugf("Changing connection ID to %s.\n-> Sending Retry", connID)
 	replyHdr.Log(s.logger)
 	buf := &bytes.Buffer{}
-	if err := replyHdr.Write(buf, protocol.PerspectiveServer, hdr.Version); err != nil {
+	if err := replyHdr.Write(buf, hdr.Version); err != nil {
 		return err
 	}
 	if _, err := s.conn.WriteTo(buf.Bytes(), remoteAddr); err != nil {
@@ -452,14 +449,15 @@ func (s *server) sendRetry(remoteAddr net.Addr, hdr *wire.Header) error {
 	return nil
 }
 
-func (s *server) sendVersionNegotiationPacket(p *receivedPacket) error {
-	hdr := p.header
-	s.logger.Debugf("Client offered version %s, sending VersionNegotiationPacket", hdr.Version)
-
+func (s *server) sendVersionNegotiationPacket(p *receivedPacket) {
+	hdr := p.hdr
+	s.logger.Debugf("Client offered version %s, sending Version Negotiation", hdr.Version)
 	data, err := wire.ComposeVersionNegotiation(hdr.SrcConnectionID, hdr.DestConnectionID, s.config.Versions)
 	if err != nil {
-		return err
+		s.logger.Debugf("Error composing Version Negotiation: %s", err)
+		return
+	}
+	if _, err := s.conn.WriteTo(data, p.remoteAddr); err != nil {
+		s.logger.Debugf("Error sending Version Negotiation: %s", err)
 	}
-	_, err = s.conn.WriteTo(data, p.remoteAddr)
-	return err
 }

+ 1 - 1
vendor/github.com/lucas-clemente/quic-go/server_session.go

@@ -32,7 +32,7 @@ func (s *serverSession) handlePacket(p *receivedPacket) {
 }
 
 func (s *serverSession) handlePacketImpl(p *receivedPacket) error {
-	hdr := p.header
+	hdr := p.hdr
 
 	// Probably an old packet that was sent by the client before the version was negotiated.
 	// It is safe to drop it.

+ 36 - 20
vendor/github.com/lucas-clemente/quic-go/session.go

@@ -1,6 +1,7 @@
 package quic
 
 import (
+	"bytes"
 	"context"
 	"crypto/tls"
 	"errors"
@@ -21,7 +22,7 @@ import (
 )
 
 type unpacker interface {
-	Unpack(headerBinary []byte, hdr *wire.Header, data []byte) (*unpackedPacket, error)
+	Unpack(headerBinary []byte, hdr *wire.ExtendedHeader, data []byte) (*unpackedPacket, error)
 }
 
 type streamGetter interface {
@@ -52,7 +53,7 @@ type cryptoStreamHandler interface {
 
 type receivedPacket struct {
 	remoteAddr net.Addr
-	header     *wire.Header
+	hdr        *wire.Header
 	data       []byte
 	rcvTime    time.Time
 }
@@ -113,7 +114,6 @@ type session struct {
 
 	receivedFirstPacket              bool // since packet numbers start at 0, we can't use largestRcvdPacketNumber != 0 for this
 	receivedFirstForwardSecurePacket bool
-	lastRcvdPacketNumber             protocol.PacketNumber
 	// Used to calculate the next packet number from the truncated wire
 	// representation, and sent back in public reset packets
 	largestRcvdPacketNumber protocol.PacketNumber
@@ -289,7 +289,7 @@ var newClientSession = func(
 
 func (s *session) preSetup() {
 	s.rttStats = &congestion.RTTStats{}
-	s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger, s.version)
+	s.sentPacketHandler = ackhandler.NewSentPacketHandler(s.rttStats, s.logger)
 	s.receivedPacketHandler = ackhandler.NewReceivedPacketHandler(s.rttStats, s.logger, s.version)
 	s.connFlowController = flowcontrol.NewConnectionFlowController(
 		protocol.InitialMaxData,
@@ -374,7 +374,7 @@ runLoop:
 			}
 			// This is a bit unclean, but works properly, since the packet always
 			// begins with the public header and we never copy it.
-			putPacketBuffer(&p.header.Raw)
+			// TODO: putPacketBuffer(&p.extHdr.Raw)
 		case <-s.handshakeCompleteChan:
 			s.handleHandshakeComplete()
 		}
@@ -479,24 +479,41 @@ func (s *session) handleHandshakeComplete() {
 }
 
 func (s *session) handlePacketImpl(p *receivedPacket) error {
-	hdr := p.header
 	// The server can change the source connection ID with the first Handshake packet.
 	// After this, all packets with a different source connection have to be ignored.
-	if s.receivedFirstPacket && hdr.IsLongHeader && !hdr.SrcConnectionID.Equal(s.destConnID) {
-		s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.header.SrcConnectionID, s.destConnID)
+	if s.receivedFirstPacket && p.hdr.IsLongHeader && !p.hdr.SrcConnectionID.Equal(s.destConnID) {
+		s.logger.Debugf("Dropping packet with unexpected source connection ID: %s (expected %s)", p.hdr.SrcConnectionID, s.destConnID)
 		return nil
 	}
 
-	p.rcvTime = time.Now()
+	data := p.data
+	r := bytes.NewReader(data)
+	hdr, err := p.hdr.ParseExtended(r, s.version)
+	if err != nil {
+		return fmt.Errorf("error parsing extended header: %s", err)
+	}
+	hdr.Raw = data[:len(data)-r.Len()]
+	data = data[len(data)-r.Len():]
+
+	if hdr.IsLongHeader {
+		if hdr.Length < protocol.ByteCount(hdr.PacketNumberLen) {
+			return fmt.Errorf("packet length (%d bytes) shorter than packet number (%d bytes)", hdr.Length, hdr.PacketNumberLen)
+		}
+		if protocol.ByteCount(len(data))+protocol.ByteCount(hdr.PacketNumberLen) < hdr.Length {
+			return fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)+int(hdr.PacketNumberLen), hdr.Length)
+		}
+		data = data[:int(hdr.Length)-int(hdr.PacketNumberLen)]
+		// TODO(#1312): implement parsing of compound packets
+	}
+
 	// Calculate packet number
 	hdr.PacketNumber = protocol.InferPacketNumber(
 		hdr.PacketNumberLen,
 		s.largestRcvdPacketNumber,
 		hdr.PacketNumber,
-		s.version,
 	)
 
-	packet, err := s.unpacker.Unpack(hdr.Raw, hdr, p.data)
+	packet, err := s.unpacker.Unpack(hdr.Raw, hdr, data)
 	if s.logger.Debug() {
 		if err != nil {
 			s.logger.Debugf("<- Reading packet 0x%x (%d bytes) for connection %s", hdr.PacketNumber, len(p.data)+len(hdr.Raw), hdr.DestConnectionID)
@@ -530,7 +547,6 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
 		}
 	}
 
-	s.lastRcvdPacketNumber = hdr.PacketNumber
 	// Only do this after decrypting, so we are sure the packet is not attacker-controlled
 	s.largestRcvdPacketNumber = utils.MaxPacketNumber(s.largestRcvdPacketNumber, hdr.PacketNumber)
 
@@ -543,10 +559,10 @@ func (s *session) handlePacketImpl(p *receivedPacket) error {
 		}
 	}
 
-	return s.handleFrames(packet.frames, packet.encryptionLevel)
+	return s.handleFrames(packet.frames, hdr.PacketNumber, packet.encryptionLevel)
 }
 
-func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLevel) error {
+func (s *session) handleFrames(fs []wire.Frame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
 	for _, ff := range fs {
 		var err error
 		wire.LogFrame(s.logger, ff, false)
@@ -556,7 +572,7 @@ func (s *session) handleFrames(fs []wire.Frame, encLevel protocol.EncryptionLeve
 		case *wire.StreamFrame:
 			err = s.handleStreamFrame(frame, encLevel)
 		case *wire.AckFrame:
-			err = s.handleAckFrame(frame, encLevel)
+			err = s.handleAckFrame(frame, pn, encLevel)
 		case *wire.ConnectionCloseFrame:
 			s.closeRemote(qerr.Error(frame.ErrorCode, frame.ReasonPhrase))
 		case *wire.ResetStreamFrame:
@@ -702,8 +718,8 @@ func (s *session) handlePathChallengeFrame(frame *wire.PathChallengeFrame) {
 	s.queueControlFrame(&wire.PathResponseFrame{Data: frame.Data})
 }
 
-func (s *session) handleAckFrame(frame *wire.AckFrame, encLevel protocol.EncryptionLevel) error {
-	if err := s.sentPacketHandler.ReceivedAck(frame, s.lastRcvdPacketNumber, encLevel, s.lastNetworkActivityTime); err != nil {
+func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber, encLevel protocol.EncryptionLevel) error {
+	if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastNetworkActivityTime); err != nil {
 		return err
 	}
 	s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
@@ -1065,14 +1081,14 @@ func (s *session) scheduleSending() {
 
 func (s *session) tryQueueingUndecryptablePacket(p *receivedPacket) {
 	if s.handshakeComplete {
-		s.logger.Debugf("Received undecryptable packet from %s after the handshake: %#v, %d bytes data", p.remoteAddr.String(), p.header, len(p.data))
+		s.logger.Debugf("Received undecryptable packet from %s after the handshake (%d bytes)", p.remoteAddr.String(), len(p.data))
 		return
 	}
 	if len(s.undecryptablePackets)+1 > protocol.MaxUndecryptablePackets {
-		s.logger.Infof("Dropping undecrytable packet 0x%x (undecryptable packet queue full)", p.header.PacketNumber)
+		s.logger.Infof("Dropping undecrytable packet (%d bytes). Undecryptable packet queue full.", len(p.data))
 		return
 	}
-	s.logger.Infof("Queueing packet 0x%x for later decryption", p.header.PacketNumber)
+	s.logger.Infof("Queueing packet (%d bytes) for later decryption", len(p.data))
 	s.undecryptablePackets = append(s.undecryptablePackets, p)
 }