|
|
@@ -260,10 +260,12 @@ type Conn struct {
|
|
|
newCompressionWriter func(io.WriteCloser, int) io.WriteCloser
|
|
|
|
|
|
// Read fields
|
|
|
- reader io.ReadCloser // the current reader returned to the application
|
|
|
- readErr error
|
|
|
- br *bufio.Reader
|
|
|
- readRemaining int64 // bytes remaining in current frame.
|
|
|
+ reader io.ReadCloser // the current reader returned to the application
|
|
|
+ readErr error
|
|
|
+ br *bufio.Reader
|
|
|
+ // bytes remaining in current frame.
|
|
|
+ // set setReadRemaining to safely update this value and prevent overflow
|
|
|
+ readRemaining int64
|
|
|
readFinal bool // true the current message has more frames.
|
|
|
readLength int64 // Message size.
|
|
|
readLimit int64 // Maximum message size.
|
|
|
@@ -320,6 +322,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
|
|
|
return c
|
|
|
}
|
|
|
|
|
|
+// setReadRemaining tracks the number of bytes remaining on the connection. If n
|
|
|
+// overflows, an ErrReadLimit is returned.
|
|
|
+func (c *Conn) setReadRemaining(n int64) error {
|
|
|
+ if n < 0 {
|
|
|
+ return ErrReadLimit
|
|
|
+ }
|
|
|
+
|
|
|
+ c.readRemaining = n
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
// Subprotocol returns the negotiated protocol for the connection.
|
|
|
func (c *Conn) Subprotocol() string {
|
|
|
return c.subprotocol
|
|
|
@@ -770,7 +783,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
final := p[0]&finalBit != 0
|
|
|
frameType := int(p[0] & 0xf)
|
|
|
mask := p[1]&maskBit != 0
|
|
|
- c.readRemaining = int64(p[1] & 0x7f)
|
|
|
+ c.setReadRemaining(int64(p[1] & 0x7f))
|
|
|
|
|
|
c.readDecompress = false
|
|
|
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
|
|
|
@@ -804,7 +817,17 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
|
|
|
}
|
|
|
|
|
|
- // 3. Read and parse frame length.
|
|
|
+ // 3. Read and parse frame length as per
|
|
|
+ // https://tools.ietf.org/html/rfc6455#section-5.2
|
|
|
+ //
|
|
|
+ // The length of the "Payload data", in bytes: if 0-125, that is the payload
|
|
|
+ // length.
|
|
|
+ // - If 126, the following 2 bytes interpreted as a 16-bit unsigned
|
|
|
+ // integer are the payload length.
|
|
|
+ // - If 127, the following 8 bytes interpreted as
|
|
|
+ // a 64-bit unsigned integer (the most significant bit MUST be 0) are the
|
|
|
+ // payload length. Multibyte length quantities are expressed in network byte
|
|
|
+ // order.
|
|
|
|
|
|
switch c.readRemaining {
|
|
|
case 126:
|
|
|
@@ -812,13 +835,19 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
if err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
- c.readRemaining = int64(binary.BigEndian.Uint16(p))
|
|
|
+
|
|
|
+ if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
|
|
|
+ return noFrame, err
|
|
|
+ }
|
|
|
case 127:
|
|
|
p, err := c.read(8)
|
|
|
if err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
- c.readRemaining = int64(binary.BigEndian.Uint64(p))
|
|
|
+
|
|
|
+ if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
|
|
|
+ return noFrame, err
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
// 4. Handle frame masking.
|
|
|
@@ -841,6 +870,12 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
|
|
|
|
|
|
c.readLength += c.readRemaining
|
|
|
+ // Don't allow readLength to overflow in the presence of a large readRemaining
|
|
|
+ // counter.
|
|
|
+ if c.readLength < 0 {
|
|
|
+ return noFrame, ErrReadLimit
|
|
|
+ }
|
|
|
+
|
|
|
if c.readLimit > 0 && c.readLength > c.readLimit {
|
|
|
c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
|
|
|
return noFrame, ErrReadLimit
|
|
|
@@ -854,7 +889,7 @@ func (c *Conn) advanceFrame() (int, error) {
|
|
|
var payload []byte
|
|
|
if c.readRemaining > 0 {
|
|
|
payload, err = c.read(int(c.readRemaining))
|
|
|
- c.readRemaining = 0
|
|
|
+ c.setReadRemaining(0)
|
|
|
if err != nil {
|
|
|
return noFrame, err
|
|
|
}
|
|
|
@@ -927,6 +962,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
|
|
|
c.readErr = hideTempErr(err)
|
|
|
break
|
|
|
}
|
|
|
+
|
|
|
if frameType == TextMessage || frameType == BinaryMessage {
|
|
|
c.messageReader = &messageReader{c}
|
|
|
c.reader = c.messageReader
|
|
|
@@ -967,7 +1003,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
|
|
|
if c.isServer {
|
|
|
c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
|
|
|
}
|
|
|
- c.readRemaining -= int64(n)
|
|
|
+ rem := c.readRemaining
|
|
|
+ rem -= int64(n)
|
|
|
+ c.setReadRemaining(rem)
|
|
|
if c.readRemaining > 0 && c.readErr == io.EOF {
|
|
|
c.readErr = errUnexpectedEOF
|
|
|
}
|