Browse Source

Merge pull request #1950 from keepalivesrc/patch-2

Websocket Read Limit Fix
Kslr 6 years ago
parent
commit
31a647bcf0
1 changed files with 48 additions and 10 deletions
  1. 48 10
      external/github.com/gorilla/websocket/conn.go

+ 48 - 10
external/github.com/gorilla/websocket/conn.go

@@ -260,10 +260,12 @@ type Conn struct {
 	newCompressionWriter   func(io.WriteCloser, int) io.WriteCloser
 
 	// Read fields
-	reader        io.ReadCloser // the current reader returned to the application
-	readErr       error
-	br            *bufio.Reader
-	readRemaining int64 // bytes remaining in current frame.
+	reader  io.ReadCloser // the current reader returned to the application
+	readErr error
+	br      *bufio.Reader
+	// bytes remaining in current frame.
+	// set setReadRemaining to safely update this value and prevent overflow
+	readRemaining int64
 	readFinal     bool  // true the current message has more frames.
 	readLength    int64 // Message size.
 	readLimit     int64 // Maximum message size.
@@ -320,6 +322,17 @@ func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int,
 	return c
 }
 
+// setReadRemaining tracks the number of bytes remaining on the connection. If n
+// overflows, an ErrReadLimit is returned.
+func (c *Conn) setReadRemaining(n int64) error {
+	if n < 0 {
+		return ErrReadLimit
+	}
+
+	c.readRemaining = n
+	return nil
+}
+
 // Subprotocol returns the negotiated protocol for the connection.
 func (c *Conn) Subprotocol() string {
 	return c.subprotocol
@@ -770,7 +783,7 @@ func (c *Conn) advanceFrame() (int, error) {
 	final := p[0]&finalBit != 0
 	frameType := int(p[0] & 0xf)
 	mask := p[1]&maskBit != 0
-	c.readRemaining = int64(p[1] & 0x7f)
+	c.setReadRemaining(int64(p[1] & 0x7f))
 
 	c.readDecompress = false
 	if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
@@ -804,7 +817,17 @@ func (c *Conn) advanceFrame() (int, error) {
 		return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
 	}
 
-	// 3. Read and parse frame length.
+	// 3. Read and parse frame length as per
+	// https://tools.ietf.org/html/rfc6455#section-5.2
+	//
+	// The length of the "Payload data", in bytes: if 0-125, that is the payload
+	// length.
+	// - If 126, the following 2 bytes interpreted as a 16-bit unsigned
+	// integer are the payload length.
+	// - If 127, the following 8 bytes interpreted as
+	// a 64-bit unsigned integer (the most significant bit MUST be 0) are the
+	// payload length. Multibyte length quantities are expressed in network byte
+	// order.
 
 	switch c.readRemaining {
 	case 126:
@@ -812,13 +835,19 @@ func (c *Conn) advanceFrame() (int, error) {
 		if err != nil {
 			return noFrame, err
 		}
-		c.readRemaining = int64(binary.BigEndian.Uint16(p))
+
+		if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil {
+			return noFrame, err
+		}
 	case 127:
 		p, err := c.read(8)
 		if err != nil {
 			return noFrame, err
 		}
-		c.readRemaining = int64(binary.BigEndian.Uint64(p))
+
+		if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil {
+			return noFrame, err
+		}
 	}
 
 	// 4. Handle frame masking.
@@ -841,6 +870,12 @@ func (c *Conn) advanceFrame() (int, error) {
 	if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage {
 
 		c.readLength += c.readRemaining
+		// Don't allow readLength to overflow in the presence of a large readRemaining
+		// counter.
+		if c.readLength < 0 {
+			return noFrame, ErrReadLimit
+		}
+
 		if c.readLimit > 0 && c.readLength > c.readLimit {
 			c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait))
 			return noFrame, ErrReadLimit
@@ -854,7 +889,7 @@ func (c *Conn) advanceFrame() (int, error) {
 	var payload []byte
 	if c.readRemaining > 0 {
 		payload, err = c.read(int(c.readRemaining))
-		c.readRemaining = 0
+		c.setReadRemaining(0)
 		if err != nil {
 			return noFrame, err
 		}
@@ -927,6 +962,7 @@ func (c *Conn) NextReader() (messageType int, r io.Reader, err error) {
 			c.readErr = hideTempErr(err)
 			break
 		}
+
 		if frameType == TextMessage || frameType == BinaryMessage {
 			c.messageReader = &messageReader{c}
 			c.reader = c.messageReader
@@ -967,7 +1003,9 @@ func (r *messageReader) Read(b []byte) (int, error) {
 			if c.isServer {
 				c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n])
 			}
-			c.readRemaining -= int64(n)
+			rem := c.readRemaining
+			rem -= int64(n)
+			c.setReadRemaining(rem)
 			if c.readRemaining > 0 && c.readErr == io.EOF {
 				c.readErr = errUnexpectedEOF
 			}