packet_packer.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. package quic
  2. import (
  3. "bytes"
  4. "errors"
  5. "fmt"
  6. "net"
  7. "time"
  8. "github.com/lucas-clemente/quic-go/internal/ackhandler"
  9. "github.com/lucas-clemente/quic-go/internal/handshake"
  10. "github.com/lucas-clemente/quic-go/internal/protocol"
  11. "github.com/lucas-clemente/quic-go/internal/utils"
  12. "github.com/lucas-clemente/quic-go/internal/wire"
  13. )
  14. type packer interface {
  15. PackPacket() (*packedPacket, error)
  16. MaybePackAckPacket() (*packedPacket, error)
  17. PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error)
  18. PackConnectionClose(*wire.ConnectionCloseFrame) (*packedPacket, error)
  19. HandleTransportParameters(*handshake.TransportParameters)
  20. ChangeDestConnectionID(protocol.ConnectionID)
  21. }
  22. type packedPacket struct {
  23. header *wire.ExtendedHeader
  24. raw []byte
  25. frames []wire.Frame
  26. encryptionLevel protocol.EncryptionLevel
  27. }
  28. func (p *packedPacket) ToAckHandlerPacket() *ackhandler.Packet {
  29. return &ackhandler.Packet{
  30. PacketNumber: p.header.PacketNumber,
  31. PacketType: p.header.Type,
  32. Frames: p.frames,
  33. Length: protocol.ByteCount(len(p.raw)),
  34. EncryptionLevel: p.encryptionLevel,
  35. SendTime: time.Now(),
  36. }
  37. }
  38. func getMaxPacketSize(addr net.Addr) protocol.ByteCount {
  39. maxSize := protocol.ByteCount(protocol.MinInitialPacketSize)
  40. // If this is not a UDP address, we don't know anything about the MTU.
  41. // Use the minimum size of an Initial packet as the max packet size.
  42. if udpAddr, ok := addr.(*net.UDPAddr); ok {
  43. // If ip is not an IPv4 address, To4 returns nil.
  44. // Note that there might be some corner cases, where this is not correct.
  45. // See https://stackoverflow.com/questions/22751035/golang-distinguish-ipv4-ipv6.
  46. if udpAddr.IP.To4() == nil {
  47. maxSize = protocol.MaxPacketSizeIPv6
  48. } else {
  49. maxSize = protocol.MaxPacketSizeIPv4
  50. }
  51. }
  52. return maxSize
  53. }
  54. type packetNumberManager interface {
  55. PeekPacketNumber() (protocol.PacketNumber, protocol.PacketNumberLen)
  56. PopPacketNumber() protocol.PacketNumber
  57. }
  58. type sealingManager interface {
  59. GetSealer() (protocol.EncryptionLevel, handshake.Sealer)
  60. GetSealerWithEncryptionLevel(protocol.EncryptionLevel) (handshake.Sealer, error)
  61. }
  62. type frameSource interface {
  63. AppendStreamFrames([]wire.Frame, protocol.ByteCount) []wire.Frame
  64. AppendControlFrames([]wire.Frame, protocol.ByteCount) ([]wire.Frame, protocol.ByteCount)
  65. }
  66. type ackFrameSource interface {
  67. GetAckFrame() *wire.AckFrame
  68. }
  69. type packetPacker struct {
  70. destConnID protocol.ConnectionID
  71. srcConnID protocol.ConnectionID
  72. perspective protocol.Perspective
  73. version protocol.VersionNumber
  74. cryptoSetup sealingManager
  75. initialStream cryptoStream
  76. handshakeStream cryptoStream
  77. token []byte
  78. pnManager packetNumberManager
  79. framer frameSource
  80. acks ackFrameSource
  81. maxPacketSize protocol.ByteCount
  82. hasSentPacket bool // has the packetPacker already sent a packet
  83. numNonRetransmittableAcks int
  84. }
  85. var _ packer = &packetPacker{}
  86. func newPacketPacker(
  87. destConnID protocol.ConnectionID,
  88. srcConnID protocol.ConnectionID,
  89. initialStream cryptoStream,
  90. handshakeStream cryptoStream,
  91. packetNumberManager packetNumberManager,
  92. remoteAddr net.Addr, // only used for determining the max packet size
  93. token []byte,
  94. cryptoSetup sealingManager,
  95. framer frameSource,
  96. acks ackFrameSource,
  97. perspective protocol.Perspective,
  98. version protocol.VersionNumber,
  99. ) *packetPacker {
  100. return &packetPacker{
  101. cryptoSetup: cryptoSetup,
  102. token: token,
  103. destConnID: destConnID,
  104. srcConnID: srcConnID,
  105. initialStream: initialStream,
  106. handshakeStream: handshakeStream,
  107. perspective: perspective,
  108. version: version,
  109. framer: framer,
  110. acks: acks,
  111. pnManager: packetNumberManager,
  112. maxPacketSize: getMaxPacketSize(remoteAddr),
  113. }
  114. }
  115. // PackConnectionClose packs a packet that ONLY contains a ConnectionCloseFrame
  116. func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*packedPacket, error) {
  117. frames := []wire.Frame{ccf}
  118. encLevel, sealer := p.cryptoSetup.GetSealer()
  119. header := p.getHeader(encLevel)
  120. raw, err := p.writeAndSealPacket(header, frames, sealer)
  121. return &packedPacket{
  122. header: header,
  123. raw: raw,
  124. frames: frames,
  125. encryptionLevel: encLevel,
  126. }, err
  127. }
  128. func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
  129. ack := p.acks.GetAckFrame()
  130. if ack == nil {
  131. return nil, nil
  132. }
  133. // TODO(#1534): only pack ACKs with the right encryption level
  134. encLevel, sealer := p.cryptoSetup.GetSealer()
  135. header := p.getHeader(encLevel)
  136. frames := []wire.Frame{ack}
  137. raw, err := p.writeAndSealPacket(header, frames, sealer)
  138. return &packedPacket{
  139. header: header,
  140. raw: raw,
  141. frames: frames,
  142. encryptionLevel: encLevel,
  143. }, err
  144. }
  145. // PackRetransmission packs a retransmission
  146. // For packets sent after completion of the handshake, it might happen that 2 packets have to be sent.
  147. // This can happen e.g. when a longer packet number is used in the header.
  148. func (p *packetPacker) PackRetransmission(packet *ackhandler.Packet) ([]*packedPacket, error) {
  149. if packet.EncryptionLevel != protocol.Encryption1RTT {
  150. p, err := p.packHandshakeRetransmission(packet)
  151. return []*packedPacket{p}, err
  152. }
  153. var controlFrames []wire.Frame
  154. var streamFrames []*wire.StreamFrame
  155. for _, f := range packet.Frames {
  156. if sf, ok := f.(*wire.StreamFrame); ok {
  157. sf.DataLenPresent = true
  158. streamFrames = append(streamFrames, sf)
  159. } else {
  160. controlFrames = append(controlFrames, f)
  161. }
  162. }
  163. var packets []*packedPacket
  164. encLevel := packet.EncryptionLevel
  165. sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
  166. if err != nil {
  167. return nil, err
  168. }
  169. for len(controlFrames) > 0 || len(streamFrames) > 0 {
  170. var frames []wire.Frame
  171. var length protocol.ByteCount
  172. header := p.getHeader(encLevel)
  173. headerLen := header.GetLength(p.version)
  174. maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen
  175. for len(controlFrames) > 0 {
  176. frame := controlFrames[0]
  177. frameLen := frame.Length(p.version)
  178. if length+frameLen > maxSize {
  179. break
  180. }
  181. length += frameLen
  182. frames = append(frames, frame)
  183. controlFrames = controlFrames[1:]
  184. }
  185. for len(streamFrames) > 0 && length+protocol.MinStreamFrameSize < maxSize {
  186. frame := streamFrames[0]
  187. frame.DataLenPresent = false
  188. frameToAdd := frame
  189. sf, err := frame.MaybeSplitOffFrame(maxSize-length, p.version)
  190. if err != nil {
  191. return nil, err
  192. }
  193. if sf != nil {
  194. frameToAdd = sf
  195. } else {
  196. streamFrames = streamFrames[1:]
  197. }
  198. frame.DataLenPresent = true
  199. length += frameToAdd.Length(p.version)
  200. frames = append(frames, frameToAdd)
  201. }
  202. if sf, ok := frames[len(frames)-1].(*wire.StreamFrame); ok {
  203. sf.DataLenPresent = false
  204. }
  205. raw, err := p.writeAndSealPacket(header, frames, sealer)
  206. if err != nil {
  207. return nil, err
  208. }
  209. packets = append(packets, &packedPacket{
  210. header: header,
  211. raw: raw,
  212. frames: frames,
  213. encryptionLevel: encLevel,
  214. })
  215. }
  216. return packets, nil
  217. }
  218. // packHandshakeRetransmission retransmits a handshake packet
  219. func (p *packetPacker) packHandshakeRetransmission(packet *ackhandler.Packet) (*packedPacket, error) {
  220. sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(packet.EncryptionLevel)
  221. if err != nil {
  222. return nil, err
  223. }
  224. // make sure that the retransmission for an Initial packet is sent as an Initial packet
  225. if packet.PacketType == protocol.PacketTypeInitial {
  226. p.hasSentPacket = false
  227. }
  228. header := p.getHeader(packet.EncryptionLevel)
  229. header.Type = packet.PacketType
  230. raw, err := p.writeAndSealPacket(header, packet.Frames, sealer)
  231. return &packedPacket{
  232. header: header,
  233. raw: raw,
  234. frames: packet.Frames,
  235. encryptionLevel: packet.EncryptionLevel,
  236. }, err
  237. }
  238. // PackPacket packs a new packet
  239. // the other controlFrames are sent in the next packet, but might be queued and sent in the next packet if the packet would overflow MaxPacketSize otherwise
  240. func (p *packetPacker) PackPacket() (*packedPacket, error) {
  241. packet, err := p.maybePackCryptoPacket()
  242. if err != nil {
  243. return nil, err
  244. }
  245. if packet != nil {
  246. return packet, nil
  247. }
  248. // if this is the first packet to be send, make sure it contains stream data
  249. if !p.hasSentPacket && packet == nil {
  250. return nil, nil
  251. }
  252. encLevel, sealer := p.cryptoSetup.GetSealer()
  253. header := p.getHeader(encLevel)
  254. headerLen := header.GetLength(p.version)
  255. if err != nil {
  256. return nil, err
  257. }
  258. maxSize := p.maxPacketSize - protocol.ByteCount(sealer.Overhead()) - headerLen
  259. frames, err := p.composeNextPacket(maxSize, p.canSendData(encLevel))
  260. if err != nil {
  261. return nil, err
  262. }
  263. // Check if we have enough frames to send
  264. if len(frames) == 0 {
  265. return nil, nil
  266. }
  267. // check if this packet only contains an ACK
  268. if !ackhandler.HasRetransmittableFrames(frames) {
  269. if p.numNonRetransmittableAcks >= protocol.MaxNonRetransmittableAcks {
  270. frames = append(frames, &wire.PingFrame{})
  271. p.numNonRetransmittableAcks = 0
  272. } else {
  273. p.numNonRetransmittableAcks++
  274. }
  275. } else {
  276. p.numNonRetransmittableAcks = 0
  277. }
  278. raw, err := p.writeAndSealPacket(header, frames, sealer)
  279. if err != nil {
  280. return nil, err
  281. }
  282. return &packedPacket{
  283. header: header,
  284. raw: raw,
  285. frames: frames,
  286. encryptionLevel: encLevel,
  287. }, nil
  288. }
  289. func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
  290. var s cryptoStream
  291. var encLevel protocol.EncryptionLevel
  292. if p.initialStream.HasData() {
  293. s = p.initialStream
  294. encLevel = protocol.EncryptionInitial
  295. } else if p.handshakeStream.HasData() {
  296. s = p.handshakeStream
  297. encLevel = protocol.EncryptionHandshake
  298. }
  299. if s == nil {
  300. return nil, nil
  301. }
  302. hdr := p.getHeader(encLevel)
  303. hdrLen := hdr.GetLength(p.version)
  304. sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
  305. if err != nil {
  306. return nil, err
  307. }
  308. var length protocol.ByteCount
  309. frames := make([]wire.Frame, 0, 2)
  310. if ack := p.acks.GetAckFrame(); ack != nil {
  311. frames = append(frames, ack)
  312. length += ack.Length(p.version)
  313. }
  314. cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length)
  315. frames = append(frames, cf)
  316. raw, err := p.writeAndSealPacket(hdr, frames, sealer)
  317. if err != nil {
  318. return nil, err
  319. }
  320. return &packedPacket{
  321. header: hdr,
  322. raw: raw,
  323. frames: frames,
  324. encryptionLevel: encLevel,
  325. }, nil
  326. }
  327. func (p *packetPacker) composeNextPacket(
  328. maxFrameSize protocol.ByteCount,
  329. canSendStreamFrames bool,
  330. ) ([]wire.Frame, error) {
  331. var length protocol.ByteCount
  332. var frames []wire.Frame
  333. // ACKs need to go first, so that the sentPacketHandler will recognize them
  334. if ack := p.acks.GetAckFrame(); ack != nil {
  335. frames = append(frames, ack)
  336. length += ack.Length(p.version)
  337. }
  338. var lengthAdded protocol.ByteCount
  339. frames, lengthAdded = p.framer.AppendControlFrames(frames, maxFrameSize-length)
  340. length += lengthAdded
  341. if !canSendStreamFrames {
  342. return frames, nil
  343. }
  344. // temporarily increase the maxFrameSize by the (minimum) length of the DataLen field
  345. // this leads to a properly sized packet in all cases, since we do all the packet length calculations with STREAM frames that have the DataLen set
  346. // however, for the last STREAM frame in the packet, we can omit the DataLen, thus yielding a packet of exactly the correct size
  347. // the length is encoded to either 1 or 2 bytes
  348. maxFrameSize++
  349. frames = p.framer.AppendStreamFrames(frames, maxFrameSize-length)
  350. if len(frames) > 0 {
  351. lastFrame := frames[len(frames)-1]
  352. if sf, ok := lastFrame.(*wire.StreamFrame); ok {
  353. sf.DataLenPresent = false
  354. }
  355. }
  356. return frames, nil
  357. }
  358. func (p *packetPacker) getHeader(encLevel protocol.EncryptionLevel) *wire.ExtendedHeader {
  359. pn, pnLen := p.pnManager.PeekPacketNumber()
  360. header := &wire.ExtendedHeader{}
  361. header.PacketNumber = pn
  362. header.PacketNumberLen = pnLen
  363. header.Version = p.version
  364. header.DestConnectionID = p.destConnID
  365. if encLevel != protocol.Encryption1RTT {
  366. header.IsLongHeader = true
  367. header.SrcConnectionID = p.srcConnID
  368. // Set the length to the maximum packet size.
  369. // Since it is encoded as a varint, this guarantees us that the header will end up at most as big as GetLength() returns.
  370. header.Length = p.maxPacketSize
  371. switch encLevel {
  372. case protocol.EncryptionInitial:
  373. header.Type = protocol.PacketTypeInitial
  374. case protocol.EncryptionHandshake:
  375. header.Type = protocol.PacketTypeHandshake
  376. }
  377. }
  378. return header
  379. }
  380. func (p *packetPacker) writeAndSealPacket(
  381. header *wire.ExtendedHeader, frames []wire.Frame,
  382. sealer handshake.Sealer,
  383. ) ([]byte, error) {
  384. raw := *getPacketBuffer()
  385. buffer := bytes.NewBuffer(raw[:0])
  386. addPadding := p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial && !p.hasSentPacket
  387. // the length is only needed for Long Headers
  388. if header.IsLongHeader {
  389. if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
  390. header.Token = p.token
  391. }
  392. if addPadding {
  393. headerLen := header.GetLength(p.version)
  394. header.Length = protocol.ByteCount(header.PacketNumberLen) + protocol.MinInitialPacketSize - headerLen
  395. } else {
  396. length := protocol.ByteCount(sealer.Overhead()) + protocol.ByteCount(header.PacketNumberLen)
  397. for _, frame := range frames {
  398. length += frame.Length(p.version)
  399. }
  400. header.Length = length
  401. }
  402. }
  403. if err := header.Write(buffer, p.version); err != nil {
  404. return nil, err
  405. }
  406. payloadStartIndex := buffer.Len()
  407. // the Initial packet needs to be padded, so the last STREAM frame must have the data length present
  408. if p.perspective == protocol.PerspectiveClient && header.Type == protocol.PacketTypeInitial {
  409. lastFrame := frames[len(frames)-1]
  410. if sf, ok := lastFrame.(*wire.StreamFrame); ok {
  411. sf.DataLenPresent = true
  412. }
  413. }
  414. for _, frame := range frames {
  415. if err := frame.Write(buffer, p.version); err != nil {
  416. return nil, err
  417. }
  418. }
  419. if addPadding {
  420. paddingLen := protocol.MinInitialPacketSize - sealer.Overhead() - buffer.Len()
  421. if paddingLen > 0 {
  422. buffer.Write(bytes.Repeat([]byte{0}, paddingLen))
  423. }
  424. }
  425. if size := protocol.ByteCount(buffer.Len() + sealer.Overhead()); size > p.maxPacketSize {
  426. return nil, fmt.Errorf("PacketPacker BUG: packet too large (%d bytes, allowed %d bytes)", size, p.maxPacketSize)
  427. }
  428. raw = raw[0:buffer.Len()]
  429. _ = sealer.Seal(raw[payloadStartIndex:payloadStartIndex], raw[payloadStartIndex:], header.PacketNumber, raw[:payloadStartIndex])
  430. raw = raw[0 : buffer.Len()+sealer.Overhead()]
  431. num := p.pnManager.PopPacketNumber()
  432. if num != header.PacketNumber {
  433. return nil, errors.New("packetPacker BUG: Peeked and Popped packet numbers do not match")
  434. }
  435. p.hasSentPacket = true
  436. return raw, nil
  437. }
  438. func (p *packetPacker) canSendData(encLevel protocol.EncryptionLevel) bool {
  439. return encLevel == protocol.Encryption1RTT
  440. }
  441. func (p *packetPacker) ChangeDestConnectionID(connID protocol.ConnectionID) {
  442. p.destConnID = connID
  443. }
  444. func (p *packetPacker) HandleTransportParameters(params *handshake.TransportParameters) {
  445. if params.MaxPacketSize != 0 {
  446. p.maxPacketSize = utils.MinByteCount(p.maxPacketSize, params.MaxPacketSize)
  447. }
  448. }