|
@@ -2,150 +2,183 @@ package wire
|
|
|
|
|
|
|
|
import (
|
|
import (
|
|
|
"bytes"
|
|
"bytes"
|
|
|
- "crypto/rand"
|
|
|
|
|
- "fmt"
|
|
|
|
|
|
|
+ "errors"
|
|
|
|
|
+ "io"
|
|
|
|
|
|
|
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
"github.com/lucas-clemente/quic-go/internal/protocol"
|
|
|
|
|
+ "github.com/lucas-clemente/quic-go/internal/qerr"
|
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
"github.com/lucas-clemente/quic-go/internal/utils"
|
|
|
)
|
|
)
|
|
|
|
|
|
|
|
-// Header is the header of a QUIC packet.
|
|
|
|
|
|
|
+// The Header is the version independent part of the header
|
|
|
type Header struct {
|
|
type Header struct {
|
|
|
- Raw []byte
|
|
|
|
|
|
|
+ Version protocol.VersionNumber
|
|
|
|
|
+ SrcConnectionID protocol.ConnectionID
|
|
|
|
|
+ DestConnectionID protocol.ConnectionID
|
|
|
|
|
|
|
|
- Version protocol.VersionNumber
|
|
|
|
|
-
|
|
|
|
|
- DestConnectionID protocol.ConnectionID
|
|
|
|
|
- SrcConnectionID protocol.ConnectionID
|
|
|
|
|
- OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet
|
|
|
|
|
|
|
+ IsLongHeader bool
|
|
|
|
|
+ Type protocol.PacketType
|
|
|
|
|
+ Length protocol.ByteCount
|
|
|
|
|
|
|
|
- PacketNumberLen protocol.PacketNumberLen
|
|
|
|
|
- PacketNumber protocol.PacketNumber
|
|
|
|
|
|
|
+ Token []byte
|
|
|
|
|
+ SupportedVersions []protocol.VersionNumber // sent in a Version Negotiation Packet
|
|
|
|
|
+ OrigDestConnectionID protocol.ConnectionID // sent in the Retry packet
|
|
|
|
|
|
|
|
- IsVersionNegotiation bool
|
|
|
|
|
- SupportedVersions []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server
|
|
|
|
|
|
|
+ typeByte byte
|
|
|
|
|
+ len int // how many bytes were read while parsing this header
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
- Type protocol.PacketType
|
|
|
|
|
- IsLongHeader bool
|
|
|
|
|
- KeyPhase int
|
|
|
|
|
- Length protocol.ByteCount
|
|
|
|
|
- Token []byte
|
|
|
|
|
|
|
+// ParseHeader parses the header.
|
|
|
|
|
+// For short header packets: up to the packet number.
|
|
|
|
|
+// For long header packets:
|
|
|
|
|
+// * if we understand the version: up to the packet number
|
|
|
|
|
+// * if not, only the invariant part of the header
|
|
|
|
|
+func ParseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
|
|
|
|
|
+ startLen := b.Len()
|
|
|
|
|
+ h, err := parseHeaderImpl(b, shortHeaderConnIDLen)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ h.len = startLen - b.Len()
|
|
|
|
|
+ return h, nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// Write writes the Header.
|
|
|
|
|
-func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
|
|
|
|
|
- if h.IsLongHeader {
|
|
|
|
|
- return h.writeLongHeader(b, ver)
|
|
|
|
|
|
|
+func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
|
|
|
|
|
+ typeByte, err := b.ReadByte()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ h := &Header{
|
|
|
|
|
+ typeByte: typeByte,
|
|
|
|
|
+ IsLongHeader: typeByte&0x80 > 0,
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ if !h.IsLongHeader {
|
|
|
|
|
+ if h.typeByte&0x40 == 0 {
|
|
|
|
|
+ return nil, errors.New("not a QUIC packet")
|
|
|
|
|
+ }
|
|
|
|
|
+ if err := h.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ return h, nil
|
|
|
}
|
|
}
|
|
|
- return h.writeShortHeader(b, ver)
|
|
|
|
|
|
|
+ if err := h.parseLongHeader(b); err != nil {
|
|
|
|
|
+ return nil, err
|
|
|
|
|
+ }
|
|
|
|
|
+ return h, nil
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
|
|
|
|
|
+ var err error
|
|
|
|
|
+ h.DestConnectionID, err = protocol.ReadConnectionID(b, shortHeaderConnIDLen)
|
|
|
|
|
+ return err
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// TODO: add support for the key phase
|
|
|
|
|
-func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
|
|
|
|
- b.WriteByte(byte(0x80 | h.Type))
|
|
|
|
|
- utils.BigEndian.WriteUint32(b, uint32(h.Version))
|
|
|
|
|
- connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
|
|
|
|
|
|
|
+func (h *Header) parseLongHeader(b *bytes.Reader) error {
|
|
|
|
|
+ v, err := utils.BigEndian.ReadUint32(b)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+ h.Version = protocol.VersionNumber(v)
|
|
|
|
|
+ if !h.IsVersionNegotiation() && h.typeByte&0x40 == 0 {
|
|
|
|
|
+ return errors.New("not a QUIC packet")
|
|
|
|
|
+ }
|
|
|
|
|
+ connIDLenByte, err := b.ReadByte()
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+ dcil, scil := decodeConnIDLen(connIDLenByte)
|
|
|
|
|
+ h.DestConnectionID, err = protocol.ReadConnectionID(b, dcil)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+ h.SrcConnectionID, err = protocol.ReadConnectionID(b, scil)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
|
- b.WriteByte(connIDLen)
|
|
|
|
|
- b.Write(h.DestConnectionID.Bytes())
|
|
|
|
|
- b.Write(h.SrcConnectionID.Bytes())
|
|
|
|
|
|
|
+ if h.Version == 0 {
|
|
|
|
|
+ return h.parseVersionNegotiationPacket(b)
|
|
|
|
|
+ }
|
|
|
|
|
+ // If we don't understand the version, we have no idea how to interpret the rest of the bytes
|
|
|
|
|
+ if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
|
|
|
|
|
+ return nil
|
|
|
|
|
+ }
|
|
|
|
|
|
|
|
- if h.Type == protocol.PacketTypeInitial {
|
|
|
|
|
- utils.WriteVarInt(b, uint64(len(h.Token)))
|
|
|
|
|
- b.Write(h.Token)
|
|
|
|
|
|
|
+ switch (h.typeByte & 0x30) >> 4 {
|
|
|
|
|
+ case 0x0:
|
|
|
|
|
+ h.Type = protocol.PacketTypeInitial
|
|
|
|
|
+ case 0x1:
|
|
|
|
|
+ h.Type = protocol.PacketType0RTT
|
|
|
|
|
+ case 0x2:
|
|
|
|
|
+ h.Type = protocol.PacketTypeHandshake
|
|
|
|
|
+ case 0x3:
|
|
|
|
|
+ h.Type = protocol.PacketTypeRetry
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
if h.Type == protocol.PacketTypeRetry {
|
|
if h.Type == protocol.PacketTypeRetry {
|
|
|
- odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
|
|
|
|
|
|
|
+ odcil := decodeSingleConnIDLen(h.typeByte & 0xf)
|
|
|
|
|
+ h.OrigDestConnectionID, err = protocol.ReadConnectionID(b, odcil)
|
|
|
if err != nil {
|
|
if err != nil {
|
|
|
return err
|
|
return err
|
|
|
}
|
|
}
|
|
|
- // randomize the first 4 bits
|
|
|
|
|
- odcilByte := make([]byte, 1)
|
|
|
|
|
- _, _ = rand.Read(odcilByte) // it's safe to ignore the error here
|
|
|
|
|
- odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
|
|
|
|
|
- b.Write(odcilByte)
|
|
|
|
|
- b.Write(h.OrigDestConnectionID.Bytes())
|
|
|
|
|
- b.Write(h.Token)
|
|
|
|
|
|
|
+ h.Token = make([]byte, b.Len())
|
|
|
|
|
+ if _, err := io.ReadFull(b, h.Token); err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
return nil
|
|
return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- utils.WriteVarInt(b, uint64(h.Length))
|
|
|
|
|
- return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
|
|
|
|
|
- typeByte := byte(0x30)
|
|
|
|
|
- typeByte |= byte(h.KeyPhase << 6)
|
|
|
|
|
-
|
|
|
|
|
- b.WriteByte(typeByte)
|
|
|
|
|
- b.Write(h.DestConnectionID.Bytes())
|
|
|
|
|
- return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
|
|
|
|
|
-}
|
|
|
|
|
-
|
|
|
|
|
-// GetLength determines the length of the Header.
|
|
|
|
|
-func (h *Header) GetLength(v protocol.VersionNumber) protocol.ByteCount {
|
|
|
|
|
- if h.IsLongHeader {
|
|
|
|
|
- length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen) + utils.VarIntLen(uint64(h.Length))
|
|
|
|
|
- if h.Type == protocol.PacketTypeInitial {
|
|
|
|
|
- length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
|
|
|
|
|
|
|
+ if h.Type == protocol.PacketTypeInitial {
|
|
|
|
|
+ tokenLen, err := utils.ReadVarInt(b)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+ if tokenLen > uint64(b.Len()) {
|
|
|
|
|
+ return io.EOF
|
|
|
|
|
+ }
|
|
|
|
|
+ h.Token = make([]byte, tokenLen)
|
|
|
|
|
+ if _, err := io.ReadFull(b, h.Token); err != nil {
|
|
|
|
|
+ return err
|
|
|
}
|
|
}
|
|
|
- return length
|
|
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
- length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
|
|
|
|
|
- length += protocol.ByteCount(h.PacketNumberLen)
|
|
|
|
|
- return length
|
|
|
|
|
|
|
+ pl, err := utils.ReadVarInt(b)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return err
|
|
|
|
|
+ }
|
|
|
|
|
+ h.Length = protocol.ByteCount(pl)
|
|
|
|
|
+ return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-// Log logs the Header
|
|
|
|
|
-func (h *Header) Log(logger utils.Logger) {
|
|
|
|
|
- if h.IsLongHeader {
|
|
|
|
|
- if h.Version == 0 {
|
|
|
|
|
- logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
|
|
|
|
|
- } else {
|
|
|
|
|
- var token string
|
|
|
|
|
- if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
|
|
|
|
|
- if len(h.Token) == 0 {
|
|
|
|
|
- token = "Token: (empty), "
|
|
|
|
|
- } else {
|
|
|
|
|
- token = fmt.Sprintf("Token: %#x, ", h.Token)
|
|
|
|
|
- }
|
|
|
|
|
- }
|
|
|
|
|
- if h.Type == protocol.PacketTypeRetry {
|
|
|
|
|
- logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
|
|
|
|
|
- return
|
|
|
|
|
- }
|
|
|
|
|
- logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
|
|
|
|
|
|
|
+func (h *Header) parseVersionNegotiationPacket(b *bytes.Reader) error {
|
|
|
|
|
+ if b.Len() == 0 {
|
|
|
|
|
+ return qerr.Error(qerr.InvalidVersionNegotiationPacket, "empty version list")
|
|
|
|
|
+ }
|
|
|
|
|
+ h.SupportedVersions = make([]protocol.VersionNumber, b.Len()/4)
|
|
|
|
|
+ for i := 0; b.Len() > 0; i++ {
|
|
|
|
|
+ v, err := utils.BigEndian.ReadUint32(b)
|
|
|
|
|
+ if err != nil {
|
|
|
|
|
+ return qerr.InvalidVersionNegotiationPacket
|
|
|
}
|
|
}
|
|
|
- } else {
|
|
|
|
|
- logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
|
|
|
|
|
|
|
+ h.SupportedVersions[i] = protocol.VersionNumber(v)
|
|
|
}
|
|
}
|
|
|
|
|
+ return nil
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
|
|
|
|
|
- dcil, err := encodeSingleConnIDLen(dest)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return 0, err
|
|
|
|
|
- }
|
|
|
|
|
- scil, err := encodeSingleConnIDLen(src)
|
|
|
|
|
- if err != nil {
|
|
|
|
|
- return 0, err
|
|
|
|
|
- }
|
|
|
|
|
- return scil | dcil<<4, nil
|
|
|
|
|
|
|
+// IsVersionNegotiation says if this a version negotiation packet
|
|
|
|
|
+func (h *Header) IsVersionNegotiation() bool {
|
|
|
|
|
+ return h.IsLongHeader && h.Version == 0
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
-func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
|
|
|
|
|
- len := id.Len()
|
|
|
|
|
- if len == 0 {
|
|
|
|
|
- return 0, nil
|
|
|
|
|
- }
|
|
|
|
|
- if len < 4 || len > 18 {
|
|
|
|
|
- return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
|
|
|
|
|
- }
|
|
|
|
|
- return byte(len - 3), nil
|
|
|
|
|
|
|
+// ParseExtended parses the version dependent part of the header.
|
|
|
|
|
+// The Reader has to be set such that it points to the first byte of the header.
|
|
|
|
|
+func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
|
|
|
|
|
+ return h.toExtendedHeader().parse(b, ver)
|
|
|
|
|
+}
|
|
|
|
|
+
|
|
|
|
|
+func (h *Header) toExtendedHeader() *ExtendedHeader {
|
|
|
|
|
+ return &ExtendedHeader{Header: *h}
|
|
|
}
|
|
}
|
|
|
|
|
|
|
|
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
|
|
func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
|