packet_packer.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472
  1. package quic
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "time"
  8. "github.com/lucas-clemente/quic-go/internal/ackhandler"
  9. "github.com/lucas-clemente/quic-go/internal/handshake"
  10. "github.com/lucas-clemente/quic-go/internal/protocol"
  11. "github.com/lucas-clemente/quic-go/internal/utils"
  12. "github.com/lucas-clemente/quic-go/internal/wire"
  13. )
  14. type packer interface {
  15. PackPacket() (*packedPacket, error)
  16. MaybePackAckPacket() (*packedPacket, error)
  17. PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error)
  18. PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error)
  19. HandleTransportParameters(*handshake.TransportParameters)
  20. ChangeDestConnectionID(protocol.ConnectionID)
  21. }
  22. type packedPacket struct {
  23. header *wire.ExtendedHeader
  24. raw []byte
  25. frames []wire.Frame
  26. buffer *packetBuffer
  27. }
  28. func (p *packedPacket) EncryptionLevel() protocol.EncryptionLevel {
  29. if !p.header.IsLongHeader {
  30. return protocol.Encryption1RTT
  31. }
  32. switch p.header.Type {
  33. case protocol.PacketTypeInitial:
  34. return protocol.EncryptionInitial
  35. case protocol.PacketTypeHandshake:
  36. return protocol.EncryptionHandshake
  37. default:
  38. return protocol.EncryptionUnspecified
  39. }
  40. }
  41. func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
  42. return &ackhandler.Packet{
  43. PacketNumber: p.header.PacketNumber,
  44. PacketType: p.header.Type,
  45. Frames: p.frames,
  46. Length: protocol.ByteCount(len(p.raw)),
  47. EncryptionLevel: p.EncryptionLevel(),
  48. SendTime: time.Now(),
  49. }
  50. }
  51. func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
  52. maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
  53. // If this is not a UDP address, we don't know anything about the MTU.
  54. // Use the minimum size of an Initial packet as the max packet size.
  55. if udpAddr, ok := addr.(*net.UDPAddr); ok {
  56. // If ip is not an IPv4 address, To4 returns nil.
  57. // Note that there might be some corner cases, where this is not correct.
  58. // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6.
  59. if udpAddr.IP.To4() == nil {
  60. maxSize = protocol.MaxPacketSizeIPv6
  61. } else {
  62. maxSize = protocol.MaxPacketSizeIPv4
  63. }
  64. }
  65. return maxSize
  66. }
  67. type packetNumberManager interface {
  68. PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen)
  69. PopPacketNumber() protocol.PacketNumber
  70. }
  71. type sealingManager interface {
  72. GetSealer() (protocol.EncryptionLevel, handshake.Sealer)
  73. GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error)
  74. }
  75. type frameSource interface {
  76. AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame
  77. AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount)
  78. }
  79. type ackFrameSource interface {
  80. GetAckFrame() *wire.AckFrame
  81. }
  82. type packetPacker struct {
  83. destConnID protocol.ConnectionID
  84. srcConnID protocol.ConnectionID
  85. perspective protocol.Perspective
  86. version protocol.VersionNumber
  87. cryptoSetup sealingManager
  88. initialStream cryptoStream
  89. handshakeStream cryptoStream
  90. token []byte
  91. pnManager packetNumberManager
  92. framer frameSource
  93. acks ackFrameSource
  94. maxPacketSize protocol.ByteCount
  95. numNonRetransmittableAcks int
  96. }
  97. var _ packer = &packetPacker{}
  98. func newPacketPacker(
  99. destConnID protocol.ConnectionID,
  100. srcConnID protocol.ConnectionID,
  101. initialStream cryptoStream,
  102. handshakeStream cryptoStream,
  103. packetNumberManager packetNumberManager,
  104. remoteAddr net.Addr, // only used for determining the max packet size
  105. token []byte,
  106. cryptoSetup sealingManager,
  107. framer frameSource,
  108. acks ackFrameSource,
  109. perspective protocol.Perspective,
  110. version protocol.VersionNumber,
  111. ) *packetPacker {
  112. return &packetPacker{
  113. cryptoSetup: cryptoSetup,
  114. token: token,
  115. destConnID: destConnID,
  116. srcConnID: srcConnID,
  117. initialStream: initialStream,
  118. handshakeStream: handshakeStream,
  119. perspective: perspective,
  120. version: version,
  121. framer: framer,
  122. acks: acks,
  123. pnManager: packetNumberManager,
  124. maxPacketSize: getMaxPacketSize(remoteAddr),
  125. }
  126. }
  127. // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
  128. func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) {
  129. frames := []wire.Frame{ccf}
  130. encLevel, sealer := p.cryptoSetup.GetSealer()
  131. header := p.getHeader(encLevel)
  132. return p.writeAndSealPacket(header, frames, sealer)
  133. }
  134. func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
  135. ack := p.acks.GetAckFrame()
  136. if ack == nil {
  137. return nil, nil
  138. }
  139. // TODO(#1534): only pack ACKs with the right encryption level
  140. encLevel, sealer := p.cryptoSetup.GetSealer()
  141. header := p.getHeader(encLevel)
  142. frames := []wire.Frame{ack}
  143. return p.writeAndSealPacket(header, frames, sealer)
  144. }
  145. // PackRetransmission packs a retransmission
  146. // For packets sent after completion of the handshake, it might happen that 2 packets have to be sent.
  147. // This can happen e.g. when a longer packet number is used in the header.
  148. func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) {
  149. var controlFrames []wire.Frame
  150. var streamFrames []*wire.StreamFrame
  151. for _, f := range packet.Frames {
  152. // CRYPTO frames are treated as control frames here.
  153. // Since we're making sure that the header can never be larger for a retransmission,
  154. // we never have to split CRYPTO frames.
  155. if sf, ok := f.(*wire.StreamFrame); ok {
  156. sf.DataLenPresent = true
  157. streamFrames = append(streamFrames, sf)
  158. } else {
  159. controlFrames = append(controlFrames, f)
  160. }
  161. }
  162. var packets []*packedPacket
  163. encLevel := packet.EncryptionLevel
  164. sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
  165. if err != nil {
  166. return nil, err
  167. }
  168. for len(controlFrames) > 0 || len(streamFrames) > 0 {
  169. var frames []wire.Frame
  170. var length protocol.ByteCount
  171. header := p.getHeader(encLevel)
  172. headerLen := header.GetLength(p.version)
  173. maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen
  174. for len(controlFrames) > 0 {
  175. frame := controlFrames[0]
  176. frameLen := frame.Length(p.version)
  177. if length+frameLen > maxSize {
  178. break
  179. }
  180. length += frameLen
  181. frames = append(frames, frame)
  182. controlFrames = controlFrames[1:]
  183. }
  184. for len(streamFrames) > 0 && length+protocol.MinStreamFrameSize < maxSize {
  185. frame := streamFrames[0]
  186. frame.DataLenPresent = false
  187. frameToAdd := frame
  188. sf, err := frame.MaybeSplitOffFrame(maxSize-length, p.version)
  189. if err != nil {
  190. return nil, err
  191. }
  192. if sf != nil {
  193. frameToAdd = sf
  194. } else {
  195. streamFrames = streamFrames[1:]
  196. }
  197. frame.DataLenPresent = true
  198. length += frameToAdd.Length(p.version)
  199. frames = append(frames, frameToAdd)
  200. }
  201. if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
  202. sf.DataLenPresent = false
  203. }
  204. p, err := p.writeAndSealPacket(header, frames, sealer)
  205. if err != nil {
  206. return nil, err
  207. }
  208. packets = append(packets, p)
  209. }
  210. return packets, nil
  211. }
  212. // PackPacket packs a new packet
  213. // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
  214. func (p *packetPacker) PackPacket() (*packedPacket, error) {
  215. packet, err := p.maybePackCryptoPacket()
  216. if err != nil {
  217. return nil, err
  218. }
  219. if packet != nil {
  220. return packet, nil
  221. }
  222. encLevel, sealer := p.cryptoSetup.GetSealer()
  223. header := p.getHeader(encLevel)
  224. headerLen := header.GetLength(p.version)
  225. if err != nil {
  226. return nil, err
  227. }
  228. maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen
  229. frames, err := p.composeNextPacket(maxSize)
  230. if err != nil {
  231. return nil, err
  232. }
  233. // Check if we have enough frames to send
  234. if len(frames) == 0 {
  235. return nil, nil
  236. }
  237. // check if this packet only contains an ACK
  238. if !ackhandler.HasRetransmittableFrames(frames) {
  239. if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks {
  240. frames = append(frames, &wire.PingFrame{})
  241. p.numNonRetransmittableAcks = 0
  242. } else {
  243. p.numNonRetransmittableAcks++
  244. }
  245. } else {
  246. p.numNonRetransmittableAcks = 0
  247. }
  248. return p.writeAndSealPacket(header, frames, sealer)
  249. }
  250. func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
  251. var s cryptoStream
  252. var encLevel protocol.EncryptionLevel
  253. if p.initialStream.HasData() {
  254. s = p.initialStream
  255. encLevel = protocol.EncryptionInitial
  256. } else if p.handshakeStream.HasData() {
  257. s = p.handshakeStream
  258. encLevel = protocol.EncryptionHandshake
  259. }
  260. if s == nil {
  261. return nil, nil
  262. }
  263. hdr := p.getHeader(encLevel)
  264. hdrLen := hdr.GetLength(p.version)
  265. sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
  266. if err != nil {
  267. return nil, err
  268. }
  269. var length protocol.ByteCount
  270. frames := make([]wire.Frame, 0, 2)
  271. if ack := p.acks.GetAckFrame(); ack != nil {
  272. frames = append(frames, ack)
  273. length += ack.Length(p.version)
  274. }
  275. cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length)
  276. frames = append(frames, cf)
  277. return p.writeAndSealPacket(hdr, frames, sealer)
  278. }
  279. func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wire.Frame, error) {
  280. var length protocol.ByteCount
  281. var frames []wire.Frame
  282. // ACKs need to go first, so that the sentPacketHandler will recognize them
  283. if ack := p.acks.GetAckFrame(); ack != nil {
  284. frames = append(frames, ack)
  285. length += ack.Length(p.version)
  286. }
  287. var lengthAdded protocol.ByteCount
  288. frames, lengthAdded = p.framer.AppendControlFrames(frames, maxFrameSize-length)
  289. length += lengthAdded
  290. // temporarily increase the maxFrameSize by the (minimum) length of the DataLen field
  291. // this leads to a properly sized packet in all cases, since we do all the packet length calculations with STREAM frames that have the DataLen set
  292. // however, for the last STREAM frame in the packet, we can omit the DataLen, thus yielding a packet of exactly the correct size
  293. // the length is encoded to either 1 or 2 bytes
  294. maxFrameSize++
  295. frames = p.framer.AppendStreamFrames(frames, maxFrameSize-length)
  296. if len(frames) > 0 {
  297. lastFrame := frames[len(frames)-1]
  298. if sf, ok := lastFrame.(*wire.StreamFrame); ok {
  299. sf.DataLenPresent = false
  300. }
  301. }
  302. return frames, nil
  303. }
  304. func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
  305. pn, pnLen := p.pnManager.PeekPacketNumber()
  306. header := &wire.ExtendedHeader{}
  307. header.PacketNumber = pn
  308. header.PacketNumberLen = pnLen
  309. header.Version = p.version
  310. header.DestConnectionID = p.destConnID
  311. if encLevel != protocol.Encryption1RTT {
  312. header.IsLongHeader = true
  313. // Always send Initial and Handshake packets with the maximum packet number length.
  314. // This simplifies retransmissions: Since the header can't get any larger,
  315. // we don't need to split CRYPTO frames.
  316. header.PacketNumberLen = protocol.PacketNumberLen4
  317. header.SrcConnectionID = p.srcConnID
  318. // Set the length to the maximum packet size.
  319. // Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns.
  320. header.Length = p.maxPacketSize
  321. switch encLevel {
  322. case protocol.EncryptionInitial:
  323. header.Type = protocol.PacketTypeInitial
  324. case protocol.EncryptionHandshake:
  325. header.Type = protocol.PacketTypeHandshake
  326. }
  327. }
  328. return header
  329. }
  330. func (p *packetPacker) writeAndSealPacket(
  331. header *wire.ExtendedHeader,
  332. frames []wire.Frame,
  333. sealer handshake.Sealer,
  334. ) (*packedPacket, error) {
  335. packetBuffer := getPacketBuffer()
  336. buffer := bytes.NewBuffer(packetBuffer.Slice[:0])
  337. addPaddingForInitial := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial
  338. if header.IsLongHeader {
  339. if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
  340. header.Token = p.token
  341. }
  342. if addPaddingForInitial {
  343. headerLen := header.GetLength(p.version)
  344. header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen
  345. } else {
  346. // long header packets always use 4 byte packet number, so we never need to pad short payloads
  347. length := protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen)
  348. for _, frame := range frames {
  349. length += frame.Length(p.version)
  350. }
  351. header.Length = length
  352. }
  353. }
  354. if err := header.Write(buffer, p.version); err != nil {
  355. return nil, err
  356. }
  357. payloadOffset := buffer.Len()
  358. // write all frames but the last one
  359. for _, frame := range frames[:len(frames)-1] {
  360. if err := frame.Write(buffer, p.version); err != nil {
  361. return nil, err
  362. }
  363. }
  364. lastFrame := frames[len(frames)-1]
  365. if addPaddingForInitial {
  366. // when appending padding, we need to make sure that the last STREAM frames has the data length set
  367. if sf, ok := lastFrame.(*wire.StreamFrame); ok {
  368. sf.DataLenPresent = true
  369. }
  370. } else {
  371. payloadLen := buffer.Len() - payloadOffset + int(lastFrame.Length(p.version))
  372. if paddingLen := 4 - int(header.PacketNumberLen) - payloadLen; paddingLen > 0 {
  373. // Pad the packet such that packet number length + payload length is 4 bytes.
  374. // This is needed to enable the peer to get a 16 byte sample for header protection.
  375. buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
  376. }
  377. }
  378. if err := lastFrame.Write(buffer, p.version); err != nil {
  379. return nil, err
  380. }
  381. if addPaddingForInitial {
  382. paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len()
  383. if paddingLen > 0 {
  384. buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
  385. }
  386. }
  387. if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize {
  388. return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
  389. }
  390. raw := buffer.Bytes()
  391. _ = sealer.Seal(raw[payloadOffset:payloadOffset], raw[payloadOffset:], header.PacketNumber, raw[:payloadOffset])
  392. raw = raw[0 : buffer.Len()+sealer.Overhead()]
  393. pnOffset := payloadOffset - int(header.PacketNumberLen)
  394. sealer.EncryptHeader(
  395. raw[pnOffset+4:pnOffset+4+16],
  396. &raw[0],
  397. raw[pnOffset:payloadOffset],
  398. )
  399. num := p.pnManager.PopPacketNumber()
  400. if num != header.PacketNumber {
  401. return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
  402. }
  403. return &packedPacket{
  404. header: header,
  405. raw: raw,
  406. frames: frames,
  407. buffer: packetBuffer,
  408. }, nil
  409. }
  410. func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {
  411. p.destConnID = connID
  412. }
  413. func (p *packetPacker) HandleTransportParameters(params *handshake.TransportParameters) {
  414. if params.MaxPacketSize != 0 {
  415. p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxPacketSize)
  416. }
  417. }