| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333 |
- package wire
- import (
- "bytes"
- "crypto/rand"
- "errors"
- "fmt"
- "github.com/lucas-clemente/quic-go/internal/protocol"
- "github.com/lucas-clemente/quic-go/internal/utils"
- )
- // Header is the header of a QUIC packet.
- // It contains fields that are only needed for the gQUIC Public Header and the IETF draft Header.
- type Header struct {
- IsPublicHeader bool
- Raw []byte
- Version protocol.VersionNumber
- DestConnectionID protocol.ConnectionID
- SrcConnectionID protocol.ConnectionID
- OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet
- PacketNumberLen protocol.PacketNumberLen
- PacketNumber protocol.PacketNumber
- IsVersionNegotiation bool
- SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
- // only needed for the gQUIC Public Header
- VersionFlag bool
- ResetFlag bool
- DiversificationNonce []byte
- // only needed for the IETF Header
- Type protocol.PacketType
- IsLongHeader bool
- KeyPhase int
- PayloadLen protocol.ByteCount
- Token []byte
- }
var (
	// errInvalidPacketNumberLen is returned when a header uses a packet number
	// length that the header format being written or measured cannot encode.
	errInvalidPacketNumberLen = errors.New("invalid packet number length")
)
- // Write writes the Header.
- func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
- if !ver.UsesIETFHeaderFormat() {
- h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
- return h.writePublicHeader(b, pers, ver)
- }
- // write an IETF QUIC header
- if h.IsLongHeader {
- return h.writeLongHeader(b, ver)
- }
- return h.writeShortHeader(b, ver)
- }
- // TODO: add support for the key phase
- func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
- b.WriteByte(byte(0x80 | h.Type))
- utils.BigEndian.WriteUint32(b, uint32(h.Version))
- connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
- if err != nil {
- return err
- }
- b.WriteByte(connIDLen)
- b.Write(h.DestConnectionID.Bytes())
- b.Write(h.SrcConnectionID.Bytes())
- if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
- utils.WriteVarInt(b, uint64(len(h.Token)))
- b.Write(h.Token)
- }
- if h.Type == protocol.PacketTypeRetry {
- odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
- if err != nil {
- return err
- }
- // randomize the first 4 bits
- odcilByte := make([]byte, 1)
- _, _ = rand.Read(odcilByte) // it's safe to ignore the error here
- odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
- b.Write(odcilByte)
- b.Write(h.OrigDestConnectionID.Bytes())
- b.Write(h.Token)
- return nil
- }
- if v.UsesLengthInHeader() {
- utils.WriteVarInt(b, uint64(h.PayloadLen))
- }
- if v.UsesVarintPacketNumbers() {
- return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
- }
- utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
- if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
- if len(h.DiversificationNonce) != 32 {
- return errors.New("invalid diversification nonce length")
- }
- b.Write(h.DiversificationNonce)
- }
- return nil
- }
- func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
- typeByte := byte(0x30)
- typeByte |= byte(h.KeyPhase << 6)
- if !v.UsesVarintPacketNumbers() {
- switch h.PacketNumberLen {
- case protocol.PacketNumberLen1:
- case protocol.PacketNumberLen2:
- typeByte |= 0x1
- case protocol.PacketNumberLen4:
- typeByte |= 0x2
- default:
- return errInvalidPacketNumberLen
- }
- }
- b.WriteByte(typeByte)
- b.Write(h.DestConnectionID.Bytes())
- if !v.UsesVarintPacketNumbers() {
- switch h.PacketNumberLen {
- case protocol.PacketNumberLen1:
- b.WriteByte(uint8(h.PacketNumber))
- case protocol.PacketNumberLen2:
- utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
- case protocol.PacketNumberLen4:
- utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
- }
- return nil
- }
- return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
- }
- // writePublicHeader writes a Public Header.
- func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
- if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) {
- return errors.New("PublicHeader: Can only write regular packets")
- }
- if h.SrcConnectionID.Len() != 0 {
- return errors.New("PublicHeader: SrcConnectionID must not be set")
- }
- if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
- return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
- }
- publicFlagByte := uint8(0x00)
- if h.VersionFlag {
- publicFlagByte |= 0x01
- }
- if h.DestConnectionID.Len() > 0 {
- publicFlagByte |= 0x08
- }
- if len(h.DiversificationNonce) > 0 {
- if len(h.DiversificationNonce) != 32 {
- return errors.New("invalid diversification nonce length")
- }
- publicFlagByte |= 0x04
- }
- switch h.PacketNumberLen {
- case protocol.PacketNumberLen1:
- publicFlagByte |= 0x00
- case protocol.PacketNumberLen2:
- publicFlagByte |= 0x10
- case protocol.PacketNumberLen4:
- publicFlagByte |= 0x20
- }
- b.WriteByte(publicFlagByte)
- if h.DestConnectionID.Len() > 0 {
- b.Write(h.DestConnectionID)
- }
- if h.VersionFlag && pers == protocol.PerspectiveClient {
- utils.BigEndian.WriteUint32(b, uint32(h.Version))
- }
- if len(h.DiversificationNonce) > 0 {
- b.Write(h.DiversificationNonce)
- }
- switch h.PacketNumberLen {
- case protocol.PacketNumberLen1:
- b.WriteByte(uint8(h.PacketNumber))
- case protocol.PacketNumberLen2:
- utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
- case protocol.PacketNumberLen4:
- utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
- case protocol.PacketNumberLen6:
- return errInvalidPacketNumberLen
- default:
- return errors.New("PublicHeader: PacketNumberLen not set")
- }
- return nil
- }
- // GetLength determines the length of the Header.
- func (h *Header) GetLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
- if !v.UsesIETFHeaderFormat() {
- return h.getPublicHeaderLength()
- }
- return h.getHeaderLength(v)
- }
- func (h *Header) getHeaderLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
- if h.IsLongHeader {
- length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen)
- if v.UsesLengthInHeader() {
- length += utils.VarIntLen(uint64(h.PayloadLen))
- }
- if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
- length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
- }
- if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
- length += protocol.ByteCount(len(h.DiversificationNonce))
- }
- return length, nil
- }
- length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
- if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
- return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
- }
- length += protocol.ByteCount(h.PacketNumberLen)
- return length, nil
- }
- // getPublicHeaderLength gets the length of the publicHeader in bytes.
- // It can only be called for regular packets.
- func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) {
- length := protocol.ByteCount(1) // 1 byte for public flags
- if h.PacketNumberLen == protocol.PacketNumberLen6 {
- return 0, errInvalidPacketNumberLen
- }
- if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
- return 0, errPacketNumberLenNotSet
- }
- length += protocol.ByteCount(h.PacketNumberLen)
- length += protocol.ByteCount(h.DestConnectionID.Len())
- // Version Number in packets sent by the client
- if h.VersionFlag {
- length += 4
- }
- length += protocol.ByteCount(len(h.DiversificationNonce))
- return length, nil
- }
- // Log logs the Header
- func (h *Header) Log(logger utils.Logger) {
- if h.IsPublicHeader {
- h.logPublicHeader(logger)
- } else {
- h.logHeader(logger)
- }
- }
- func (h *Header) logHeader(logger utils.Logger) {
- if h.IsLongHeader {
- if h.Version == 0 {
- logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
- } else {
- var token string
- if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
- if len(h.Token) == 0 {
- token = "Token: (empty), "
- } else {
- token = fmt.Sprintf("Token: %#x, ", h.Token)
- }
- }
- if h.Type == protocol.PacketTypeRetry {
- logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
- return
- }
- if h.Version == protocol.Version44 {
- var divNonce string
- if h.Type == protocol.PacketType0RTT {
- divNonce = fmt.Sprintf("Diversification Nonce: %#x, ", h.DiversificationNonce)
- }
- logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, divNonce, h.Version)
- return
- }
- logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
- }
- } else {
- logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
- }
- }
- func (h *Header) logPublicHeader(logger utils.Logger) {
- ver := "(unset)"
- if h.Version != 0 {
- ver = h.Version.String()
- }
- logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
- }
- func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
- dcil, err := encodeSingleConnIDLen(dest)
- if err != nil {
- return 0, err
- }
- scil, err := encodeSingleConnIDLen(src)
- if err != nil {
- return 0, err
- }
- return scil | dcil<<4, nil
- }
- func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
- len := id.Len()
- if len == 0 {
- return 0, nil
- }
- if len < 4 || len > 18 {
- return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
- }
- return byte(len - 3), nil
- }
- func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
- return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf)
- }
// decodeSingleConnIDLen reverses encodeSingleConnIDLen: 0 means an empty
// connection ID, and any other value v means a length of v+3 bytes.
func decodeSingleConnIDLen(enc uint8) int {
	if enc != 0 {
		return int(enc) + 3
	}
	return 0
}
|