Browse Source

sync quic lib

Darien Raymond 6 years ago
parent
commit
1cf07c3379

+ 3 - 2
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/interfaces.go

@@ -27,6 +27,7 @@ type SentPacketHandler interface {
 	// Before sending any packet, SendingAllowed() must be called to learn if we can actually send it.
 	ShouldSendNumPackets() int
 
+	// only to be called once the handshake is complete
 	GetLowestPacketNotConfirmedAcked() protocol.PacketNumber
 	DequeuePacketForRetransmission() *Packet
 	DequeueProbePacket() (*Packet, error)
@@ -40,9 +41,9 @@ type SentPacketHandler interface {
 
 // ReceivedPacketHandler handles ACKs needed to send for incoming packets
 type ReceivedPacketHandler interface {
-	ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error
+	ReceivedPacket(pn protocol.PacketNumber, encLevel protocol.EncryptionLevel, rcvTime time.Time, shouldInstigateAck bool) error
 	IgnoreBelow(protocol.PacketNumber)
 
 	GetAlarmTimeout() time.Time
-	GetAckFrame() *wire.AckFrame
+	GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame
 }

+ 44 - 161
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_handler.go

@@ -1,6 +1,7 @@
 package ackhandler
 
 import (
+	"fmt"
 	"time"
 
 	"github.com/lucas-clemente/quic-go/internal/congestion"
@@ -9,27 +10,6 @@ import (
 	"github.com/lucas-clemente/quic-go/internal/wire"
 )
 
-type receivedPacketHandler struct {
-	largestObserved             protocol.PacketNumber
-	ignoreBelow                 protocol.PacketNumber
-	largestObservedReceivedTime time.Time
-
-	packetHistory *receivedPacketHistory
-
-	ackSendDelay time.Duration
-	rttStats     *congestion.RTTStats
-
-	packetsReceivedSinceLastAck                int
-	retransmittablePacketsReceivedSinceLastAck int
-	ackQueued                                  bool
-	ackAlarm                                   time.Time
-	lastAck                                    *wire.AckFrame
-
-	logger utils.Logger
-
-	version protocol.VersionNumber
-}
-
 const (
 	// maximum delay that can be applied to an ACK for a retransmittable packet
 	ackSendDelay = 25 * time.Millisecond
@@ -53,6 +33,14 @@ const (
 	maxPacketsAfterNewMissing = 4
 )
 
+type receivedPacketHandler struct {
+	initialPackets   *receivedPacketTracker
+	handshakePackets *receivedPacketTracker
+	oneRTTPackets    *receivedPacketTracker
+}
+
+var _ ReceivedPacketHandler = &receivedPacketHandler{}
+
 // NewReceivedPacketHandler creates a new receivedPacketHandler
 func NewReceivedPacketHandler(
 	rttStats *congestion.RTTStats,
@@ -60,156 +48,51 @@ func NewReceivedPacketHandler(
 	version protocol.VersionNumber,
 ) ReceivedPacketHandler {
 	return &receivedPacketHandler{
-		packetHistory: newReceivedPacketHistory(),
-		ackSendDelay:  ackSendDelay,
-		rttStats:      rttStats,
-		logger:        logger,
-		version:       version,
+		initialPackets:   newReceivedPacketTracker(rttStats, logger, version),
+		handshakePackets: newReceivedPacketTracker(rttStats, logger, version),
+		oneRTTPackets:    newReceivedPacketTracker(rttStats, logger, version),
 	}
 }
 
-func (h *receivedPacketHandler) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
-	if packetNumber < h.ignoreBelow {
-		return nil
-	}
-
-	isMissing := h.isMissing(packetNumber)
-	if packetNumber >= h.largestObserved {
-		h.largestObserved = packetNumber
-		h.largestObservedReceivedTime = rcvTime
+func (h *receivedPacketHandler) ReceivedPacket(
+	pn protocol.PacketNumber,
+	encLevel protocol.EncryptionLevel,
+	rcvTime time.Time,
+	shouldInstigateAck bool,
+) error {
+	switch encLevel {
+	case protocol.EncryptionInitial:
+		return h.initialPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck)
+	case protocol.EncryptionHandshake:
+		return h.handshakePackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck)
+	case protocol.Encryption1RTT:
+		return h.oneRTTPackets.ReceivedPacket(pn, rcvTime, shouldInstigateAck)
+	default:
+		return fmt.Errorf("received packet with unknown encryption level: %s", encLevel)
 	}
-
-	if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
-		return err
-	}
-	h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
-	return nil
 }
 
-// IgnoreBelow sets a lower limit for acking packets.
-// Packets with packet numbers smaller than p will not be acked.
-func (h *receivedPacketHandler) IgnoreBelow(p protocol.PacketNumber) {
-	if p <= h.ignoreBelow {
-		return
-	}
-	h.ignoreBelow = p
-	h.packetHistory.DeleteBelow(p)
-	if h.logger.Debug() {
-		h.logger.Debugf("\tIgnoring all packets below %#x.", p)
-	}
+// only to be used with 1-RTT packets
+func (h *receivedPacketHandler) IgnoreBelow(pn protocol.PacketNumber) {
+	h.oneRTTPackets.IgnoreBelow(pn)
 }
 
-// isMissing says if a packet was reported missing in the last ACK.
-func (h *receivedPacketHandler) isMissing(p protocol.PacketNumber) bool {
-	if h.lastAck == nil || p < h.ignoreBelow {
-		return false
-	}
-	return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
+func (h *receivedPacketHandler) GetAlarmTimeout() time.Time {
+	initialAlarm := h.initialPackets.GetAlarmTimeout()
+	handshakeAlarm := h.handshakePackets.GetAlarmTimeout()
+	oneRTTAlarm := h.oneRTTPackets.GetAlarmTimeout()
+	return utils.MinNonZeroTime(utils.MinNonZeroTime(initialAlarm, handshakeAlarm), oneRTTAlarm)
 }
 
-func (h *receivedPacketHandler) hasNewMissingPackets() bool {
-	if h.lastAck == nil {
-		return false
-	}
-	highestRange := h.packetHistory.GetHighestAckRange()
-	return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
-}
-
-// maybeQueueAck queues an ACK, if necessary.
-// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
-// in ACK_DECIMATION_WITH_REORDERING mode.
-func (h *receivedPacketHandler) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
-	h.packetsReceivedSinceLastAck++
-
-	// always ack the first packet
-	if h.lastAck == nil {
-		h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
-		h.ackQueued = true
-		return
-	}
-
-	// Send an ACK if this packet was reported missing in an ACK sent before.
-	// Ack decimation with reordering relies on the timer to send an ACK, but if
-	// missing packets we reported in the previous ack, send an ACK immediately.
-	if wasMissing {
-		if h.logger.Debug() {
-			h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
-		}
-		h.ackQueued = true
-	}
-
-	if !h.ackQueued && shouldInstigateAck {
-		h.retransmittablePacketsReceivedSinceLastAck++
-
-		if packetNumber > minReceivedBeforeAckDecimation {
-			// ack up to 10 packets at once
-			if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
-				h.ackQueued = true
-				if h.logger.Debug() {
-					h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
-				}
-			} else if h.ackAlarm.IsZero() {
-				// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
-				ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
-				h.ackAlarm = rcvTime.Add(ackDelay)
-				if h.logger.Debug() {
-					h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
-				}
-			}
-		} else {
-			// send an ACK every 2 retransmittable packets
-			if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
-				if h.logger.Debug() {
-					h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
-				}
-				h.ackQueued = true
-			} else if h.ackAlarm.IsZero() {
-				if h.logger.Debug() {
-					h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
-				}
-				h.ackAlarm = rcvTime.Add(ackSendDelay)
-			}
-		}
-		// If there are new missing packets to report, set a short timer to send an ACK.
-		if h.hasNewMissingPackets() {
-			// wait the minimum of 1/8 min RTT and the existing ack time
-			ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
-			ackTime := rcvTime.Add(ackDelay)
-			if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
-				h.ackAlarm = ackTime
-				if h.logger.Debug() {
-					h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
-				}
-			}
-		}
-	}
-
-	if h.ackQueued {
-		// cancel the ack alarm
-		h.ackAlarm = time.Time{}
-	}
-}
-
-func (h *receivedPacketHandler) GetAckFrame() *wire.AckFrame {
-	now := time.Now()
-	if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
+func (h *receivedPacketHandler) GetAckFrame(encLevel protocol.EncryptionLevel) *wire.AckFrame {
+	switch encLevel {
+	case protocol.EncryptionInitial:
+		return h.initialPackets.GetAckFrame()
+	case protocol.EncryptionHandshake:
+		return h.handshakePackets.GetAckFrame()
+	case protocol.Encryption1RTT:
+		return h.oneRTTPackets.GetAckFrame()
+	default:
 		return nil
 	}
-	if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
-		h.logger.Debugf("Sending ACK because the ACK timer expired.")
-	}
-
-	ack := &wire.AckFrame{
-		AckRanges: h.packetHistory.GetAckRanges(),
-		DelayTime: now.Sub(h.largestObservedReceivedTime),
-	}
-
-	h.lastAck = ack
-	h.ackAlarm = time.Time{}
-	h.ackQueued = false
-	h.packetsReceivedSinceLastAck = 0
-	h.retransmittablePacketsReceivedSinceLastAck = 0
-	return ack
 }
-
-func (h *receivedPacketHandler) GetAlarmTimeout() time.Time { return h.ackAlarm }

+ 1 - 0
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_history.go

@@ -8,6 +8,7 @@ import (
 )
 
 // The receivedPacketHistory stores if a packet number has already been received.
+// It generates ACK ranges which can be used to assemble an ACK frame.
 // It does not store packet contents.
 type receivedPacketHistory struct {
 	ranges *utils.PacketIntervalList

+ 191 - 0
vendor/github.com/lucas-clemente/quic-go/internal/ackhandler/received_packet_tracker.go

@@ -0,0 +1,191 @@
+package ackhandler
+
+import (
+	"time"
+
+	"github.com/lucas-clemente/quic-go/internal/congestion"
+	"github.com/lucas-clemente/quic-go/internal/protocol"
+	"github.com/lucas-clemente/quic-go/internal/utils"
+	"github.com/lucas-clemente/quic-go/internal/wire"
+)
+
+type receivedPacketTracker struct {
+	largestObserved             protocol.PacketNumber
+	ignoreBelow                 protocol.PacketNumber
+	largestObservedReceivedTime time.Time
+
+	packetHistory *receivedPacketHistory
+
+	ackSendDelay time.Duration
+	rttStats     *congestion.RTTStats
+
+	packetsReceivedSinceLastAck                int
+	retransmittablePacketsReceivedSinceLastAck int
+	ackQueued                                  bool
+	ackAlarm                                   time.Time
+	lastAck                                    *wire.AckFrame
+
+	logger utils.Logger
+
+	version protocol.VersionNumber
+}
+
+func newReceivedPacketTracker(
+	rttStats *congestion.RTTStats,
+	logger utils.Logger,
+	version protocol.VersionNumber,
+) *receivedPacketTracker {
+	return &receivedPacketTracker{
+		packetHistory: newReceivedPacketHistory(),
+		ackSendDelay:  ackSendDelay,
+		rttStats:      rttStats,
+		logger:        logger,
+		version:       version,
+	}
+}
+
+func (h *receivedPacketTracker) ReceivedPacket(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck bool) error {
+	if packetNumber < h.ignoreBelow {
+		return nil
+	}
+
+	isMissing := h.isMissing(packetNumber)
+	if packetNumber >= h.largestObserved {
+		h.largestObserved = packetNumber
+		h.largestObservedReceivedTime = rcvTime
+	}
+
+	if err := h.packetHistory.ReceivedPacket(packetNumber); err != nil {
+		return err
+	}
+	h.maybeQueueAck(packetNumber, rcvTime, shouldInstigateAck, isMissing)
+	return nil
+}
+
+// IgnoreBelow sets a lower limit for acking packets.
+// Packets with packet numbers smaller than p will not be acked.
+func (h *receivedPacketTracker) IgnoreBelow(p protocol.PacketNumber) {
+	if p <= h.ignoreBelow {
+		return
+	}
+	h.ignoreBelow = p
+	h.packetHistory.DeleteBelow(p)
+	if h.logger.Debug() {
+		h.logger.Debugf("\tIgnoring all packets below %#x.", p)
+	}
+}
+
+// isMissing says if a packet was reported missing in the last ACK.
+func (h *receivedPacketTracker) isMissing(p protocol.PacketNumber) bool {
+	if h.lastAck == nil || p < h.ignoreBelow {
+		return false
+	}
+	return p < h.lastAck.LargestAcked() && !h.lastAck.AcksPacket(p)
+}
+
+func (h *receivedPacketTracker) hasNewMissingPackets() bool {
+	if h.lastAck == nil {
+		return false
+	}
+	highestRange := h.packetHistory.GetHighestAckRange()
+	return highestRange.Smallest >= h.lastAck.LargestAcked() && highestRange.Len() <= maxPacketsAfterNewMissing
+}
+
+// maybeQueueAck queues an ACK, if necessary.
+// It is implemented analogously to Chrome's QuicConnection::MaybeQueueAck()
+// in ACK_DECIMATION_WITH_REORDERING mode.
+func (h *receivedPacketTracker) maybeQueueAck(packetNumber protocol.PacketNumber, rcvTime time.Time, shouldInstigateAck, wasMissing bool) {
+	h.packetsReceivedSinceLastAck++
+
+	// always ack the first packet
+	if h.lastAck == nil {
+		h.logger.Debugf("\tQueueing ACK because the first packet should be acknowledged.")
+		h.ackQueued = true
+		return
+	}
+
+	// Send an ACK if this packet was reported missing in an ACK sent before.
+	// Ack decimation with reordering relies on the timer to send an ACK, but if
+	// missing packets we reported in the previous ack, send an ACK immediately.
+	if wasMissing {
+		if h.logger.Debug() {
+			h.logger.Debugf("\tQueueing ACK because packet %#x was missing before.", packetNumber)
+		}
+		h.ackQueued = true
+	}
+
+	if !h.ackQueued && shouldInstigateAck {
+		h.retransmittablePacketsReceivedSinceLastAck++
+
+		if packetNumber > minReceivedBeforeAckDecimation {
+			// ack up to 10 packets at once
+			if h.retransmittablePacketsReceivedSinceLastAck >= retransmittablePacketsBeforeAck {
+				h.ackQueued = true
+				if h.logger.Debug() {
+					h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, retransmittablePacketsBeforeAck)
+				}
+			} else if h.ackAlarm.IsZero() {
+				// wait for the minimum of the ack decimation delay or the delayed ack time before sending an ack
+				ackDelay := utils.MinDuration(ackSendDelay, time.Duration(float64(h.rttStats.MinRTT())*float64(ackDecimationDelay)))
+				h.ackAlarm = rcvTime.Add(ackDelay)
+				if h.logger.Debug() {
+					h.logger.Debugf("\tSetting ACK timer to min(1/4 min-RTT, max ack delay): %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
+				}
+			}
+		} else {
+			// send an ACK every 2 retransmittable packets
+			if h.retransmittablePacketsReceivedSinceLastAck >= initialRetransmittablePacketsBeforeAck {
+				if h.logger.Debug() {
+					h.logger.Debugf("\tQueueing ACK because packet %d packets were received after the last ACK (using initial threshold: %d).", h.retransmittablePacketsReceivedSinceLastAck, initialRetransmittablePacketsBeforeAck)
+				}
+				h.ackQueued = true
+			} else if h.ackAlarm.IsZero() {
+				if h.logger.Debug() {
+					h.logger.Debugf("\tSetting ACK timer to max ack delay: %s", ackSendDelay)
+				}
+				h.ackAlarm = rcvTime.Add(ackSendDelay)
+			}
+		}
+		// If there are new missing packets to report, set a short timer to send an ACK.
+		if h.hasNewMissingPackets() {
+			// wait the minimum of 1/8 min RTT and the existing ack time
+			ackDelay := time.Duration(float64(h.rttStats.MinRTT()) * float64(shortAckDecimationDelay))
+			ackTime := rcvTime.Add(ackDelay)
+			if h.ackAlarm.IsZero() || h.ackAlarm.After(ackTime) {
+				h.ackAlarm = ackTime
+				if h.logger.Debug() {
+					h.logger.Debugf("\tSetting ACK timer to 1/8 min-RTT: %s (%s from now)", ackDelay, time.Until(h.ackAlarm))
+				}
+			}
+		}
+	}
+
+	if h.ackQueued {
+		// cancel the ack alarm
+		h.ackAlarm = time.Time{}
+	}
+}
+
+func (h *receivedPacketTracker) GetAckFrame() *wire.AckFrame {
+	now := time.Now()
+	if !h.ackQueued && (h.ackAlarm.IsZero() || h.ackAlarm.After(now)) {
+		return nil
+	}
+	if h.logger.Debug() && !h.ackQueued && !h.ackAlarm.IsZero() {
+		h.logger.Debugf("Sending ACK because the ACK timer expired.")
+	}
+
+	ack := &wire.AckFrame{
+		AckRanges: h.packetHistory.GetAckRanges(),
+		DelayTime: now.Sub(h.largestObservedReceivedTime),
+	}
+
+	h.lastAck = ack
+	h.ackAlarm = time.Time{}
+	h.ackQueued = false
+	h.packetsReceivedSinceLastAck = 0
+	h.retransmittablePacketsReceivedSinceLastAck = 0
+	return ack
+}
+
+func (h *receivedPacketTracker) GetAlarmTimeout() time.Time { return h.ackAlarm }

+ 6 - 6
vendor/github.com/lucas-clemente/quic-go/internal/handshake/crypto_setup.go

@@ -359,6 +359,12 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
 		case <-h.handshakeErrChan:
 			return false
 		}
+		// get the handshake write key
+		select {
+		case <-h.receivedWriteKey:
+		case <-h.handshakeErrChan:
+			return false
+		}
 		return true
 	case typeEncryptedExtensions:
 		select {
@@ -372,12 +378,6 @@ func (h *cryptoSetup) handleMessageForClient(msgType messageType) bool {
 		// nothing to do
 		return false
 	case typeFinished:
-		// get the handshake write key
-		select {
-		case <-h.receivedWriteKey:
-		case <-h.handshakeErrChan:
-			return false
-		}
 		// While the order of these two is not defined by the TLS spec,
 		// we have to do it on the same order as our TLS library does it.
 		// get the handshake write key

+ 8 - 8
vendor/github.com/lucas-clemente/quic-go/internal/mocks/ackhandler/received_packet_handler.go

@@ -37,15 +37,15 @@ func (m *MockReceivedPacketHandler) EXPECT() *MockReceivedPacketHandlerMockRecor
 }
 
 // GetAckFrame mocks base method
-func (m *MockReceivedPacketHandler) GetAckFrame() *wire.AckFrame {
-	ret := m.ctrl.Call(m, "GetAckFrame")
+func (m *MockReceivedPacketHandler) GetAckFrame(arg0 protocol.EncryptionLevel) *wire.AckFrame {
+	ret := m.ctrl.Call(m, "GetAckFrame", arg0)
 	ret0, _ := ret[0].(*wire.AckFrame)
 	return ret0
 }
 
 // GetAckFrame indicates an expected call of GetAckFrame
-func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame() *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame))
+func (mr *MockReceivedPacketHandlerMockRecorder) GetAckFrame(arg0 interface{}) *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAckFrame", reflect.TypeOf((*MockReceivedPacketHandler)(nil).GetAckFrame), arg0)
 }
 
 // GetAlarmTimeout mocks base method
@@ -71,13 +71,13 @@ func (mr *MockReceivedPacketHandlerMockRecorder) IgnoreBelow(arg0 interface{}) *
 }
 
 // ReceivedPacket mocks base method
-func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 time.Time, arg2 bool) error {
-	ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2)
+func (m *MockReceivedPacketHandler) ReceivedPacket(arg0 protocol.PacketNumber, arg1 protocol.EncryptionLevel, arg2 time.Time, arg3 bool) error {
+	ret := m.ctrl.Call(m, "ReceivedPacket", arg0, arg1, arg2, arg3)
 	ret0, _ := ret[0].(error)
 	return ret0
 }
 
 // ReceivedPacket indicates an expected call of ReceivedPacket
-func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2 interface{}) *gomock.Call {
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2)
+func (mr *MockReceivedPacketHandlerMockRecorder) ReceivedPacket(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReceivedPacket", reflect.TypeOf((*MockReceivedPacketHandler)(nil).ReceivedPacket), arg0, arg1, arg2, arg3)
 }

+ 12 - 0
vendor/github.com/lucas-clemente/quic-go/internal/utils/minmax.go

@@ -122,6 +122,18 @@ func MinTime(a, b time.Time) time.Time {
 	return a
 }
 
+// MinNonZeroTime returns the earlist time that is not time.Time{}
+// If both a and b are time.Time{}, it returns time.Time{}
+func MinNonZeroTime(a, b time.Time) time.Time {
+	if a.IsZero() {
+		return b
+	}
+	if b.IsZero() {
+		return a
+	}
+	return MinTime(a, b)
+}
+
 // MaxTime returns the later time
 func MaxTime(a, b time.Time) time.Time {
 	if a.After(b) {

+ 13 - 0
vendor/github.com/lucas-clemente/quic-go/multiplexer.go

@@ -15,6 +15,7 @@ var (
 
 type multiplexer interface {
 	AddConn(net.PacketConn, int) (packetHandlerManager, error)
+	RemoveConn(net.PacketConn) error
 }
 
 type connManager struct {
@@ -61,3 +62,15 @@ func (m *connMultiplexer) AddConn(c net.PacketConn, connIDLen int) (packetHandle
 	}
 	return p.manager, nil
 }
+
+func (m *connMultiplexer) RemoveConn(c net.PacketConn) error {
+	m.mutex.Lock()
+	defer m.mutex.Unlock()
+
+	if _, ok := m.conns[c]; !ok {
+		return fmt.Errorf("cannote remove connection, connection is unknown")
+	}
+
+	delete(m.conns, c)
+	return nil
+}

+ 1 - 1
vendor/github.com/lucas-clemente/quic-go/packet_handler_map.go

@@ -139,7 +139,7 @@ func (h *packetHandlerMap) close(e error) error {
 	}
 	h.mutex.Unlock()
 	wg.Wait()
-	return nil
+	return getMultiplexer().RemoveConn(h.conn)
 }
 
 func (h *packetHandlerMap) listen() {

+ 23 - 12
vendor/github.com/lucas-clemente/quic-go/packet_packer.go

@@ -90,7 +90,7 @@ type frameSource interface {
 }
 
 type ackFrameSource interface {
-	GetAckFrame() *wire.AckFrame
+	GetAckFrame(protocol.EncryptionLevel) *wire.AckFrame
 }
 
 type packetPacker struct {
@@ -155,7 +155,7 @@ func (p *packetPacker) PackConnectionClose(ccf *wire.ConnectionCloseFrame) (*pac
 }
 
 func (p *packetPacker) MaybePackAckPacket() (*packedPacket, error) {
-	ack := p.acks.GetAckFrame()
+	ack := p.acks.GetAckFrame(protocol.Encryption1RTT)
 	if ack == nil {
 		return nil, nil
 	}
@@ -285,30 +285,41 @@ func (p *packetPacker) PackPacket() (*packedPacket, error) {
 func (p *packetPacker) maybePackCryptoPacket() (*packedPacket, error) {
 	var s cryptoStream
 	var encLevel protocol.EncryptionLevel
-	if p.initialStream.HasData() {
+
+	hasData := p.initialStream.HasData()
+	ack := p.acks.GetAckFrame(protocol.EncryptionInitial)
+	if hasData || ack != nil {
 		s = p.initialStream
 		encLevel = protocol.EncryptionInitial
-	} else if p.handshakeStream.HasData() {
-		s = p.handshakeStream
-		encLevel = protocol.EncryptionHandshake
+	} else {
+		hasData = p.handshakeStream.HasData()
+		ack = p.acks.GetAckFrame(protocol.EncryptionHandshake)
+		if hasData || ack != nil {
+			s = p.handshakeStream
+			encLevel = protocol.EncryptionHandshake
+		}
 	}
 	if s == nil {
 		return nil, nil
 	}
-	hdr := p.getHeader(encLevel)
-	hdrLen := hdr.GetLength(p.version)
 	sealer, err := p.cryptoSetup.GetSealerWithEncryptionLevel(encLevel)
 	if err != nil {
+		// The sealer
 		return nil, err
 	}
+
+	hdr := p.getHeader(encLevel)
+	hdrLen := hdr.GetLength(p.version)
 	var length protocol.ByteCount
 	frames := make([]wire.Frame, 0, 2)
-	if ack := p.acks.GetAckFrame(); ack != nil {
+	if ack != nil {
 		frames = append(frames, ack)
 		length += ack.Length(p.version)
 	}
-	cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length)
-	frames = append(frames, cf)
+	if hasData {
+		cf := s.PopCryptoFrame(p.maxPacketSize - hdrLen - protocol.ByteCount(sealer.Overhead()) - length)
+		frames = append(frames, cf)
+	}
 	return p.writeAndSealPacket(hdr, frames, sealer)
 }
 
@@ -317,7 +328,7 @@ func (p *packetPacker) composeNextPacket(maxFrameSize protocol.ByteCount) ([]wir
 	var frames []wire.Frame
 
 	// ACKs need to go first, so that the sentPacketHandler will recognize them
-	if ack := p.acks.GetAckFrame(); ack != nil {
+	if ack := p.acks.GetAckFrame(protocol.Encryption1RTT); ack != nil {
 		frames = append(frames, ack)
 		length += ack.Length(p.version)
 	}

+ 4 - 2
vendor/github.com/lucas-clemente/quic-go/session.go

@@ -566,7 +566,7 @@ func (s *session) handleUnpackedPacket(packet *unpackedPacket, rcvTime time.Time
 		}
 	}
 
-	if err := s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, rcvTime, isRetransmittable); err != nil {
+	if err := s.receivedPacketHandler.ReceivedPacket(packet.packetNumber, packet.encryptionLevel, rcvTime, isRetransmittable); err != nil {
 		return err
 	}
 	return nil
@@ -726,7 +726,9 @@ func (s *session) handleAckFrame(frame *wire.AckFrame, pn protocol.PacketNumber,
 	if err := s.sentPacketHandler.ReceivedAck(frame, pn, encLevel, s.lastNetworkActivityTime); err != nil {
 		return err
 	}
-	s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
+	if encLevel == protocol.Encryption1RTT {
+		s.receivedPacketHandler.IgnoreBelow(s.sentPacketHandler.GetLowestPacketNotConfirmedAcked())
+	}
 	return nil
 }
 

+ 28 - 12
vendor/github.com/lucas-clemente/quic-go/vendor/github.com/marten-seemann/qtls/13.go

@@ -116,9 +116,9 @@ func (ks *keySchedule13) setSecret(secret []byte) {
 	salt := ks.secret
 	if salt != nil {
 		h0 := hash.New().Sum(nil)
-		salt = HkdfExpandLabel(hash, salt, h0, "derived", hash.Size())
+		salt = hkdfExpandLabel(hash, salt, h0, "derived", hash.Size())
 	}
-	ks.secret = HkdfExtract(hash, secret, salt)
+	ks.secret = hkdfExtract(hash, secret, salt)
 }
 
 // Depending on role returns pair of key variant to be used by
@@ -168,7 +168,7 @@ func (ks *keySchedule13) deriveSecret(secretLabel secretLabel) []byte {
 		ks.handshakeCtx = ks.transcriptHash.Sum(nil)
 	}
 	hash := hashForSuite(ks.suite)
-	secret := HkdfExpandLabel(hash, ks.secret, ks.handshakeCtx, label, hash.Size())
+	secret := hkdfExpandLabel(hash, ks.secret, ks.handshakeCtx, label, hash.Size())
 	if keylogType != "" && ks.config != nil {
 		ks.config.writeKeyLog(keylogType, ks.clientRandom, secret)
 	}
@@ -177,8 +177,8 @@ func (ks *keySchedule13) deriveSecret(secretLabel secretLabel) []byte {
 
 func (ks *keySchedule13) prepareCipher(trafficSecret []byte) cipher.AEAD {
 	hash := hashForSuite(ks.suite)
-	key := HkdfExpandLabel(hash, trafficSecret, nil, "key", ks.suite.keyLen)
-	iv := HkdfExpandLabel(hash, trafficSecret, nil, "iv", ks.suite.ivLen)
+	key := hkdfExpandLabel(hash, trafficSecret, nil, "key", ks.suite.keyLen)
+	iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", ks.suite.ivLen)
 	return ks.suite.aead(key, iv)
 }
 
@@ -254,10 +254,11 @@ CurvePreferenceLoop:
 	hs.keySchedule.setSecret(ecdheSecret)
 	hs.hsClientTrafficSecret = hs.keySchedule.deriveSecret(secretHandshakeClient)
 	hsServerTrafficSecret := hs.keySchedule.deriveSecret(secretHandshakeServer)
+	c.out.exportKey(hs.keySchedule.suite, hsServerTrafficSecret)
 	c.out.setKey(c.vers, hs.keySchedule.suite, hsServerTrafficSecret)
 
-	serverFinishedKey := HkdfExpandLabel(hash, hsServerTrafficSecret, nil, "finished", hashSize)
-	hs.clientFinishedKey = HkdfExpandLabel(hash, hs.hsClientTrafficSecret, nil, "finished", hashSize)
+	serverFinishedKey := hkdfExpandLabel(hash, hsServerTrafficSecret, nil, "finished", hashSize)
+	hs.clientFinishedKey = hkdfExpandLabel(hash, hs.hsClientTrafficSecret, nil, "finished", hashSize)
 
 	// EncryptedExtensions
 	hs.keySchedule.write(hs.hello13Enc.marshal())
@@ -296,6 +297,7 @@ CurvePreferenceLoop:
 
 	hs.keySchedule.setSecret(nil) // derive master secret
 	serverAppTrafficSecret := hs.keySchedule.deriveSecret(secretApplicationServer)
+	c.out.exportKey(hs.keySchedule.suite, serverAppTrafficSecret)
 	c.out.setKey(c.vers, hs.keySchedule.suite, serverAppTrafficSecret)
 
 	if c.hand.Len() > 0 {
@@ -303,9 +305,11 @@ CurvePreferenceLoop:
 	}
 	hs.appClientTrafficSecret = hs.keySchedule.deriveSecret(secretApplicationClient)
 	if hs.hello13Enc.earlyData {
+		c.in.exportKey(hs.keySchedule.suite, earlyClientTrafficSecret)
 		c.in.setKey(c.vers, hs.keySchedule.suite, earlyClientTrafficSecret)
 		c.phase = readingEarlyData
 	} else {
+		c.in.exportKey(hs.keySchedule.suite, hs.hsClientTrafficSecret)
 		c.in.setKey(c.vers, hs.keySchedule.suite, hs.hsClientTrafficSecret)
 		if hs.clientHello.earlyData {
 			c.phase = discardingEarlyData
@@ -418,6 +422,7 @@ func (hs *serverHandshakeState) readClientFinished13(hasConfirmLock bool) error
 	if c.hand.Len() > 0 {
 		return c.sendAlert(alertUnexpectedMessage)
 	}
+	c.in.exportKey(hs.keySchedule.suite, hs.appClientTrafficSecret)
 	c.in.setKey(c.vers, hs.keySchedule.suite, hs.appClientTrafficSecret)
 	c.in.traceErr, c.out.traceErr = nil, nil
 	c.phase = handshakeConfirmed
@@ -514,6 +519,7 @@ func (c *Conn) handleEndOfEarlyData() error {
 	}
 	c.hs.keySchedule.write(endOfEarlyData.marshal())
 	c.phase = waitingClientFinished
+	c.in.exportKey(c.hs.keySchedule.suite, c.hs.hsClientTrafficSecret)
 	c.in.setKey(c.vers, c.hs.keySchedule.suite, c.hs.hsClientTrafficSecret)
 	return nil
 }
@@ -618,6 +624,10 @@ func (c *Conn) deriveDHESecret(ks keyShare, secretKey []byte) []byte {
 
 // HkdfExpandLabel HKDF expands a label
 func HkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte {
+	return hkdfExpandLabel(hash, secret, hashValue, label, L)
+}
+
+func hkdfExpandLabel(hash crypto.Hash, secret, hashValue []byte, label string, L int) []byte {
 	prefix := "tls13 "
 	hkdfLabel := make([]byte, 4+len(prefix)+len(label)+len(hashValue))
 	hkdfLabel[0] = byte(L >> 8)
@@ -710,7 +720,7 @@ func (hs *serverHandshakeState) checkPSK() (isResumed bool, alert alert) {
 
 		hs.keySchedule.setSecret(s.pskSecret)
 		binderKey := hs.keySchedule.deriveSecret(secretResumptionPskBinder)
-		binderFinishedKey := HkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
+		binderFinishedKey := hkdfExpandLabel(hash, binderKey, nil, "finished", hashSize)
 		chHash := hash.New()
 		chHash.Write(hs.clientHello.rawTruncated)
 		expectedBinder := hmacOfSum(hash, chHash, binderFinishedKey)
@@ -781,7 +791,7 @@ func (hs *serverHandshakeState) sendSessionTicket13() error {
 		// tickets might have the same PSK which could be a problem if
 		// one of them is compromised.
 		ticketNonce := []byte{byte(i)}
-		sessionState.pskSecret = HkdfExpandLabel(hash, resumptionMasterSecret, ticketNonce, "resumption", hash.Size())
+		sessionState.pskSecret = hkdfExpandLabel(hash, resumptionMasterSecret, ticketNonce, "resumption", hash.Size())
 		ticket := sessionState.marshal()
 		var err error
 		if c.config.SessionTicketSealer != nil {
@@ -1006,13 +1016,17 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
 		c.sendAlert(alertUnexpectedMessage)
 		return errors.New("tls: unexpected data after Server Hello")
 	}
-	// Do not change the sender key yet, the server must authenticate first.
 	serverHandshakeSecret := hs.keySchedule.deriveSecret(secretHandshakeServer)
+	c.in.exportKey(hs.keySchedule.suite, serverHandshakeSecret)
+	// Already the sender key yet, when using an alternative record layer.
+	// QUIC needs the handshake write key in order to acknowlege Handshake packets.
+	c.out.exportKey(hs.keySchedule.suite, clientHandshakeSecret)
+	// Do not change the sender key yet, the server must authenticate first.
 	c.in.setKey(c.vers, hs.keySchedule.suite, serverHandshakeSecret)
 
 	// Calculate MAC key for Finished messages.
-	serverFinishedKey := HkdfExpandLabel(hash, serverHandshakeSecret, nil, "finished", hashSize)
-	clientFinishedKey := HkdfExpandLabel(hash, clientHandshakeSecret, nil, "finished", hashSize)
+	serverFinishedKey := hkdfExpandLabel(hash, serverHandshakeSecret, nil, "finished", hashSize)
+	clientFinishedKey := hkdfExpandLabel(hash, clientHandshakeSecret, nil, "finished", hashSize)
 
 	msg, err := c.readHandshake()
 	if err != nil {
@@ -1155,11 +1169,13 @@ func (hs *clientHandshakeState) doTLS13Handshake() error {
 
 	// Handshake done, set application traffic secret
 	// TODO store initial traffic secret key for KeyUpdate GH #85
+	c.out.exportKey(hs.keySchedule.suite, clientAppTrafficSecret)
 	c.out.setKey(c.vers, hs.keySchedule.suite, clientAppTrafficSecret)
 	if c.hand.Len() > 0 {
 		c.sendAlert(alertUnexpectedMessage)
 		return errors.New("tls: unexpected data after handshake")
 	}
+	c.in.exportKey(hs.keySchedule.suite, serverAppTrafficSecret)
 	c.in.setKey(c.vers, hs.keySchedule.suite, serverAppTrafficSecret)
 	return nil
 }

+ 8 - 3
vendor/github.com/lucas-clemente/quic-go/vendor/github.com/marten-seemann/qtls/conn.go

@@ -234,15 +234,20 @@ func (hc *halfConn) changeCipherSpec() error {
 	return nil
 }
 
-func (hc *halfConn) setKey(version uint16, suite *cipherSuite, trafficSecret []byte) {
+func (hc *halfConn) exportKey(suite *cipherSuite, trafficSecret []byte) {
 	if hc.setKeyCallback != nil {
 		hc.setKeyCallback(&CipherSuite{*suite}, trafficSecret)
+	}
+}
+
+func (hc *halfConn) setKey(version uint16, suite *cipherSuite, trafficSecret []byte) {
+	if hc.setKeyCallback != nil {
 		return
 	}
 	hc.version = version
 	hash := hashForSuite(suite)
-	key := HkdfExpandLabel(hash, trafficSecret, nil, "key", suite.keyLen)
-	iv := HkdfExpandLabel(hash, trafficSecret, nil, "iv", suite.ivLen)
+	key := hkdfExpandLabel(hash, trafficSecret, nil, "key", suite.keyLen)
+	iv := hkdfExpandLabel(hash, trafficSecret, nil, "iv", suite.ivLen)
 	hc.cipher = suite.aead(key, iv)
 	for i := range hc.seq {
 		hc.seq[i] = 0

+ 4 - 0
vendor/github.com/lucas-clemente/quic-go/vendor/github.com/marten-seemann/qtls/hkdf.go

@@ -47,6 +47,10 @@ func hkdfExpand(hash crypto.Hash, prk, info []byte, l int) []byte {
 
 // HkdfExtract generates a pseudorandom key for use with Expand from an input secret and an optional independent salt.
 func HkdfExtract(hash crypto.Hash, secret, salt []byte) []byte {
+	return hkdfExtract(hash, secret, salt)
+}
+
+func hkdfExtract(hash crypto.Hash, secret, salt []byte) []byte {
 	if salt == nil {
 		salt = make([]byte, hash.Size())
 	}