From 5a564bd9ec8f3ced5e0a750f701f5b7374971178 Mon Sep 17 00:00:00 2001 From: Khafra Date: Mon, 20 May 2024 01:21:52 -0400 Subject: [PATCH] permessage-deflate decompression support in websocket (#3263) * handshake * fixup * +85 autobahn test passes * fixup * fixup * fixup * fixup * fixup * fixup * fixup * fixup * fixup * fixup * fixup * fixup * add basic send queue * fixup * fixup * fixup * fix EVERY FAILURE!!!! --- lib/web/fetch/data-url.js | 1 + lib/web/websocket/connection.js | 27 +++++--- lib/web/websocket/constants.js | 10 ++- lib/web/websocket/permessage-deflate.js | 70 ++++++++++++++++++++ lib/web/websocket/receiver.js | 79 ++++++++++++++++++----- lib/web/websocket/sender.js | 85 +++++++++++++++++++++++++ lib/web/websocket/util.js | 47 +++++++++++++- lib/web/websocket/websocket.js | 75 ++++++++-------------- test/autobahn/client.js | 1 + test/websocket/send-mutable.js | 34 ++++++++++ 10 files changed, 355 insertions(+), 74 deletions(-) create mode 100644 lib/web/websocket/permessage-deflate.js create mode 100644 lib/web/websocket/sender.js create mode 100644 test/websocket/send-mutable.js diff --git a/lib/web/fetch/data-url.js b/lib/web/fetch/data-url.js index 3f42e2eb6b2..7a74db6bde8 100644 --- a/lib/web/fetch/data-url.js +++ b/lib/web/fetch/data-url.js @@ -737,6 +737,7 @@ module.exports = { collectAnHTTPQuotedString, serializeAMimeType, removeChars, + removeHTTPWhitespace, minimizeSupportedMimeType, HTTP_TOKEN_CODEPOINTS, isomorphicDecode diff --git a/lib/web/websocket/connection.js b/lib/web/websocket/connection.js index 664fc3f0780..bb87d361e4b 100644 --- a/lib/web/websocket/connection.js +++ b/lib/web/websocket/connection.js @@ -8,7 +8,7 @@ const { kReceivedClose, kResponse } = require('./symbols') -const { fireEvent, failWebsocketConnection, isClosing, isClosed, isEstablished } = require('./util') +const { fireEvent, failWebsocketConnection, isClosing, isClosed, isEstablished, parseExtensions } = require('./util') const { channels } = require('../../core/diagnostics') const { CloseEvent } = require('./events') const { makeRequest } = require('../fetch/request') @@ -31,7 +31,7 @@ try { * @param {URL} url * @param {string|string[]} protocols * @param {import('./websocket').WebSocket} ws - * @param {(response: any) => void} onEstablish + * @param {(response: any, extensions: string[] | undefined) => void} onEstablish * @param {Partial} options */ function establishWebSocketConnection (url, protocols, client, ws, onEstablish, options) { @@ -91,12 +91,11 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish, // 9. Let permessageDeflate be a user-agent defined // "permessage-deflate" extension header value. // https://github.com/mozilla/gecko-dev/blob/ce78234f5e653a5d3916813ff990f053510227bc/netwerk/protocol/websocket/WebSocketChannel.cpp#L2673 - // TODO: enable once permessage-deflate is supported - const permessageDeflate = '' // 'permessage-deflate; 15' + const permessageDeflate = 'permessage-deflate; client_max_window_bits' // 10. Append (`Sec-WebSocket-Extensions`, permessageDeflate) to // request’s header list. - // request.headersList.append('sec-websocket-extensions', permessageDeflate) + request.headersList.append('sec-websocket-extensions', permessageDeflate) // 11. Fetch request with useParallelQueue set to true, and // processResponse given response being these steps: @@ -167,10 +166,15 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish, // header field to determine which extensions are requested is // discussed in Section 9.1.) const secExtension = response.headersList.get('Sec-WebSocket-Extensions') + let extensions - if (secExtension !== null && secExtension !== permessageDeflate) { - failWebsocketConnection(ws, 'Received different permessage-deflate than the one set.') - return + if (secExtension !== null) { + extensions = parseExtensions(secExtension) + + if (!extensions.has('permessage-deflate')) { + failWebsocketConnection(ws, 'Sec-WebSocket-Extensions header does not match.') + return + } } // 6. If the response includes a |Sec-WebSocket-Protocol| header field @@ -206,7 +210,7 @@ function establishWebSocketConnection (url, protocols, client, ws, onEstablish, }) } - onEstablish(response) + onEstablish(response, extensions) } }) @@ -290,6 +294,11 @@ function onSocketData (chunk) { */ function onSocketClose () { const { ws } = this + const { [kResponse]: response } = ws + + response.socket.off('data', onSocketData) + response.socket.off('close', onSocketClose) + response.socket.off('error', onSocketError) // If the TCP connection was closed after the // WebSocket closing handshake was completed, the WebSocket connection diff --git a/lib/web/websocket/constants.js b/lib/web/websocket/constants.js index d5de91460f5..2019b5b67a7 100644 --- a/lib/web/websocket/constants.js +++ b/lib/web/websocket/constants.js @@ -46,6 +46,13 @@ const parserStates = { const emptyBuffer = Buffer.allocUnsafe(0) +const sendHints = { + string: 1, + typedArray: 2, + arrayBuffer: 3, + blob: 4 +} + module.exports = { uid, sentCloseFrameState, @@ -54,5 +61,6 @@ module.exports = { opcodes, maxUnsigned16Bit, parserStates, - emptyBuffer + emptyBuffer, + sendHints } diff --git a/lib/web/websocket/permessage-deflate.js b/lib/web/websocket/permessage-deflate.js new file mode 100644 index 00000000000..76cb366d5e5 --- /dev/null +++ b/lib/web/websocket/permessage-deflate.js @@ -0,0 +1,70 @@ +'use strict' + +const { createInflateRaw, Z_DEFAULT_WINDOWBITS } = require('node:zlib') +const { isValidClientWindowBits } = require('./util') + +const tail = Buffer.from([0x00, 0x00, 0xff, 0xff]) +const kBuffer = Symbol('kBuffer') +const kLength = Symbol('kLength') + +class PerMessageDeflate { + /** @type {import('node:zlib').InflateRaw} */ + #inflate + + #options = {} + + constructor (extensions) { + this.#options.serverNoContextTakeover = extensions.has('server_no_context_takeover') + this.#options.serverMaxWindowBits = extensions.get('server_max_window_bits') + } + + decompress (chunk, fin, callback) { + // An endpoint uses the following algorithm to decompress a message. + // 1. Append 4 octets of 0x00 0x00 0xff 0xff to the tail end of the + // payload of the message. + // 2. Decompress the resulting data using DEFLATE. + + if (!this.#inflate) { + let windowBits = Z_DEFAULT_WINDOWBITS + + if (this.#options.serverMaxWindowBits) { // empty values default to Z_DEFAULT_WINDOWBITS + if (!isValidClientWindowBits(this.#options.serverMaxWindowBits)) { + callback(new Error('Invalid server_max_window_bits')) + return + } + + windowBits = Number.parseInt(this.#options.serverMaxWindowBits) + } + + this.#inflate = createInflateRaw({ windowBits }) + this.#inflate[kBuffer] = [] + this.#inflate[kLength] = 0 + + this.#inflate.on('data', (data) => { + this.#inflate[kBuffer].push(data) + this.#inflate[kLength] += data.length + }) + + this.#inflate.on('error', (err) => { + this.#inflate = null + callback(err) + }) + } + + this.#inflate.write(chunk) + if (fin) { + this.#inflate.write(tail) + } + + this.#inflate.flush(() => { + const full = Buffer.concat(this.#inflate[kBuffer], this.#inflate[kLength]) + + this.#inflate[kBuffer].length = 0 + this.#inflate[kLength] = 0 + + callback(null, full) + }) + } +} + +module.exports = { PerMessageDeflate } diff --git a/lib/web/websocket/receiver.js b/lib/web/websocket/receiver.js index 85b6edf649c..3a8b2abb611 100644 --- a/lib/web/websocket/receiver.js +++ b/lib/web/websocket/receiver.js @@ -17,6 +17,7 @@ const { } = require('./util') const { WebsocketFrameSend } = require('./frame') const { closeWebSocketConnection } = require('./connection') +const { PerMessageDeflate } = require('./permessage-deflate') // This code was influenced by ws released under the MIT license. // Copyright (c) 2011 Einar Otto Stangvik @@ -33,10 +34,18 @@ class ByteParser extends Writable { #info = {} #fragments = [] - constructor (ws) { + /** @type {Map} */ + #extensions + + constructor (ws, extensions) { super() this.ws = ws + this.#extensions = extensions == null ? new Map() : extensions + + if (this.#extensions.has('permessage-deflate')) { + this.#extensions.set('permessage-deflate', new PerMessageDeflate(extensions)) + } } /** @@ -91,7 +100,16 @@ class ByteParser extends Writable { // the negotiated extensions defines the meaning of such a nonzero // value, the receiving endpoint MUST _Fail the WebSocket // Connection_. - if (rsv1 !== 0 || rsv2 !== 0 || rsv3 !== 0) { + // This document allocates the RSV1 bit of the WebSocket header for + // PMCEs and calls the bit the "Per-Message Compressed" bit. On a + // WebSocket connection where a PMCE is in use, this bit indicates + // whether a message is compressed or not. + if (rsv1 !== 0 && !this.#extensions.has('permessage-deflate')) { + failWebsocketConnection(this.ws, 'Expected RSV1 to be clear.') + return + } + + if (rsv2 !== 0 || rsv3 !== 0) { failWebsocketConnection(this.ws, 'RSV1, RSV2, RSV3 must be clear') return } @@ -122,7 +140,7 @@ class ByteParser extends Writable { return } - if (isContinuationFrame(opcode) && this.#fragments.length === 0) { + if (isContinuationFrame(opcode) && this.#fragments.length === 0 && !this.#info.compressed) { failWebsocketConnection(this.ws, 'Unexpected continuation frame') return } @@ -138,6 +156,7 @@ class ByteParser extends Writable { if (isTextBinaryFrame(opcode)) { this.#info.binaryType = opcode + this.#info.compressed = rsv1 !== 0 } this.#info.opcode = opcode @@ -185,21 +204,50 @@ class ByteParser extends Writable { if (isControlFrame(this.#info.opcode)) { this.#loop = this.parseControlFrame(body) + this.#state = parserStates.INFO } else { - this.#fragments.push(body) - - // If the frame is not fragmented, a message has been received. - // If the frame is fragmented, it will terminate with a fin bit set - // and an opcode of 0 (continuation), therefore we handle that when - // parsing continuation frames, not here. - if (!this.#info.fragmented && this.#info.fin) { - const fullMessage = Buffer.concat(this.#fragments) - websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage) - this.#fragments.length = 0 + if (!this.#info.compressed) { + this.#fragments.push(body) + + // If the frame is not fragmented, a message has been received. + // If the frame is fragmented, it will terminate with a fin bit set + // and an opcode of 0 (continuation), therefore we handle that when + // parsing continuation frames, not here. + if (!this.#info.fragmented && this.#info.fin) { + const fullMessage = Buffer.concat(this.#fragments) + websocketMessageReceived(this.ws, this.#info.binaryType, fullMessage) + this.#fragments.length = 0 + } + + this.#state = parserStates.INFO + } else { + this.#extensions.get('permessage-deflate').decompress(body, this.#info.fin, (error, data) => { + if (error) { + closeWebSocketConnection(this.ws, 1007, error.message, error.message.length) + return + } + + this.#fragments.push(data) + + if (!this.#info.fin) { + this.#state = parserStates.INFO + this.#loop = true + this.run(callback) + return + } + + websocketMessageReceived(this.ws, this.#info.binaryType, Buffer.concat(this.#fragments)) + + this.#loop = true + this.#state = parserStates.INFO + this.run(callback) + this.#fragments.length = 0 + }) + + this.#loop = false + break } } - - this.#state = parserStates.INFO } } } @@ -333,7 +381,6 @@ class ByteParser extends Writable { this.ws[kReadyState] = states.CLOSING this.ws[kReceivedClose] = true - this.end() return false } else if (opcode === opcodes.PING) { // Upon receipt of a Ping frame, an endpoint MUST send a Pong frame in diff --git a/lib/web/websocket/sender.js b/lib/web/websocket/sender.js new file mode 100644 index 00000000000..b9fc7a72364 --- /dev/null +++ b/lib/web/websocket/sender.js @@ -0,0 +1,85 @@ +'use strict' + +const { WebsocketFrameSend } = require('./frame') +const { opcodes, sendHints } = require('./constants') + +/** @type {Uint8Array} */ +const FastBuffer = Buffer[Symbol.species] + +class SendQueue { + #queued = new Set() + #size = 0 + + /** @type {import('net').Socket} */ + #socket + + constructor (socket) { + this.#socket = socket + } + + add (item, cb, hint) { + if (hint !== sendHints.blob) { + const data = clone(item, hint) + + if (this.#size === 0) { + this.#dispatch(data, cb, hint) + } else { + this.#queued.add([data, cb, true, hint]) + this.#size++ + + this.#run() + } + + return + } + + const promise = item.arrayBuffer() + const queue = [null, cb, false, hint] + promise.then((ab) => { + queue[0] = clone(ab, hint) + queue[2] = true + + this.#run() + }) + + this.#queued.add(queue) + this.#size++ + } + + #run () { + for (const queued of this.#queued) { + const [data, cb, done, hint] = queued + + if (!done) return + + this.#queued.delete(queued) + this.#size-- + + this.#dispatch(data, cb, hint) + } + } + + #dispatch (data, cb, hint) { + const frame = new WebsocketFrameSend() + const opcode = hint === sendHints.string ? opcodes.TEXT : opcodes.BINARY + + frame.frameData = data + const buffer = frame.createFrame(opcode) + + this.#socket.write(buffer, cb) + } +} + +function clone (data, hint) { + switch (hint) { + case sendHints.string: + return Buffer.from(data) + case sendHints.arrayBuffer: + case sendHints.blob: + return new FastBuffer(data) + case sendHints.typedArray: + return Buffer.copyBytesFrom(data) + } +} + +module.exports = { SendQueue } diff --git a/lib/web/websocket/util.js b/lib/web/websocket/util.js index ea5b29d3549..e5ce7899752 100644 --- a/lib/web/websocket/util.js +++ b/lib/web/websocket/util.js @@ -4,6 +4,7 @@ const { kReadyState, kController, kResponse, kBinaryType, kWebSocketURL } = requ const { states, opcodes } = require('./constants') const { ErrorEvent, createFastMessageEvent } = require('./events') const { isUtf8 } = require('node:buffer') +const { collectASequenceOfCodePointsFast, removeHTTPWhitespace } = require('../fetch/data-url') /* globals Blob */ @@ -234,6 +235,48 @@ function isValidOpcode (opcode) { return isTextBinaryFrame(opcode) || isContinuationFrame(opcode) || isControlFrame(opcode) } +/** + * Parses a Sec-WebSocket-Extensions header value. + * @param {string} extensions + * @returns {Map} + */ +// TODO(@Uzlopak, @KhafraDev): make compliant https://datatracker.ietf.org/doc/html/rfc6455#section-9.1 +function parseExtensions (extensions) { + const position = { position: 0 } + const extensionList = new Map() + + while (position.position < extensions.length) { + const pair = collectASequenceOfCodePointsFast(';', extensions, position) + const [name, value = ''] = pair.split('=') + + extensionList.set( + removeHTTPWhitespace(name, true, false), + removeHTTPWhitespace(value, false, true) + ) + + position.position++ + } + + return extensionList +} + +/** + * @see https://www.rfc-editor.org/rfc/rfc7692#section-7.1.2.2 + * @description "client-max-window-bits = 1*DIGIT" + * @param {string} value + */ +function isValidClientWindowBits (value) { + for (let i = 0; i < value.length; i++) { + const byte = value.charCodeAt(i) + + if (byte < 0x30 || byte > 0x39) { + return false + } + } + + return true +} + // https://nodejs.org/api/intl.html#detecting-internationalization-support const hasIntl = typeof process.versions.icu === 'string' const fatalDecoder = hasIntl ? new TextDecoder('utf-8', { fatal: true }) : undefined @@ -265,5 +308,7 @@ module.exports = { isControlFrame, isContinuationFrame, isTextBinaryFrame, - isValidOpcode + isValidOpcode, + parseExtensions, + isValidClientWindowBits } diff --git a/lib/web/websocket/websocket.js b/lib/web/websocket/websocket.js index 7b62dde43c6..83d4ee94e30 100644 --- a/lib/web/websocket/websocket.js +++ b/lib/web/websocket/websocket.js @@ -3,7 +3,7 @@ const { webidl } = require('../fetch/webidl') const { URLSerializer } = require('../fetch/data-url') const { environmentSettingsObject } = require('../fetch/util') -const { staticPropertyDescriptors, states, sentCloseFrameState, opcodes } = require('./constants') +const { staticPropertyDescriptors, states, sentCloseFrameState, sendHints } = require('./constants') const { kWebSocketURL, kReadyState, @@ -21,17 +21,15 @@ const { fireEvent } = require('./util') const { establishWebSocketConnection, closeWebSocketConnection } = require('./connection') -const { WebsocketFrameSend } = require('./frame') const { ByteParser } = require('./receiver') const { kEnumerableProperty, isBlobLike } = require('../../core/util') const { getGlobalDispatcher } = require('../../global') const { types } = require('node:util') const { ErrorEvent, CloseEvent } = require('./events') +const { SendQueue } = require('./sender') let experimentalWarned = false -const FastBuffer = Buffer[Symbol.species] - // https://websockets.spec.whatwg.org/#interface-definition class WebSocket extends EventTarget { #events = { @@ -45,6 +43,9 @@ class WebSocket extends EventTarget { #protocol = '' #extensions = '' + /** @type {SendQueue} */ + #sendQueue + /** * @param {string} url * @param {string|string[]} protocols @@ -135,7 +136,7 @@ class WebSocket extends EventTarget { protocols, client, this, - (response) => this.#onConnectionEstablished(response), + (response, extensions) => this.#onConnectionEstablished(response, extensions), options ) @@ -229,9 +230,6 @@ class WebSocket extends EventTarget { return } - /** @type {import('stream').Duplex} */ - const socket = this[kResponse].socket - // If data is a string if (typeof data === 'string') { // If the WebSocket connection is established and the WebSocket @@ -245,14 +243,12 @@ class WebSocket extends EventTarget { // the bufferedAmount attribute by the number of bytes needed to // express the argument as UTF-8. - const value = Buffer.from(data) - const frame = new WebsocketFrameSend(value) - const buffer = frame.createFrame(opcodes.TEXT) + const length = Buffer.byteLength(data) - this.#bufferedAmount += value.byteLength - socket.write(buffer, () => { - this.#bufferedAmount -= value.byteLength - }) + this.#bufferedAmount += length + this.#sendQueue.add(data, () => { + this.#bufferedAmount -= length + }, sendHints.string) } else if (types.isArrayBuffer(data)) { // If the WebSocket connection is established, and the WebSocket // closing handshake has not yet started, then the user agent must @@ -266,14 +262,10 @@ class WebSocket extends EventTarget { // increase the bufferedAmount attribute by the length of the // ArrayBuffer in bytes. - const value = new FastBuffer(data) - const frame = new WebsocketFrameSend(value) - const buffer = frame.createFrame(opcodes.BINARY) - - this.#bufferedAmount += value.byteLength - socket.write(buffer, () => { - this.#bufferedAmount -= value.byteLength - }) + this.#bufferedAmount += data.byteLength + this.#sendQueue.add(data, () => { + this.#bufferedAmount -= data.byteLength + }, sendHints.arrayBuffer) } else if (ArrayBuffer.isView(data)) { // If the WebSocket connection is established, and the WebSocket // closing handshake has not yet started, then the user agent must @@ -287,15 +279,10 @@ class WebSocket extends EventTarget { // not throw an exception must increase the bufferedAmount attribute // by the length of data’s buffer in bytes. - const ab = new FastBuffer(data.buffer, data.byteOffset, data.byteLength) - - const frame = new WebsocketFrameSend(ab) - const buffer = frame.createFrame(opcodes.BINARY) - - this.#bufferedAmount += ab.byteLength - socket.write(buffer, () => { - this.#bufferedAmount -= ab.byteLength - }) + this.#bufferedAmount += data.byteLength + this.#sendQueue.add(data, () => { + this.#bufferedAmount -= data.byteLength + }, sendHints.typedArray) } else if (isBlobLike(data)) { // If the WebSocket connection is established, and the WebSocket // closing handshake has not yet started, then the user agent must @@ -308,18 +295,10 @@ class WebSocket extends EventTarget { // an exception must increase the bufferedAmount attribute by the size // of the Blob object’s raw data, in bytes. - const frame = new WebsocketFrameSend() - - data.arrayBuffer().then((ab) => { - const value = new FastBuffer(ab) - frame.frameData = value - const buffer = frame.createFrame(opcodes.BINARY) - - this.#bufferedAmount += value.byteLength - socket.write(buffer, () => { - this.#bufferedAmount -= value.byteLength - }) - }) + this.#bufferedAmount += data.size + this.#sendQueue.add(data, () => { + this.#bufferedAmount -= data.size + }, sendHints.blob) } } @@ -458,18 +437,20 @@ class WebSocket extends EventTarget { /** * @see https://websockets.spec.whatwg.org/#feedback-from-the-protocol */ - #onConnectionEstablished (response) { + #onConnectionEstablished (response, parsedExtensions) { // processResponse is called when the "response’s header list has been received and initialized." // once this happens, the connection is open this[kResponse] = response - const parser = new ByteParser(this) + const parser = new ByteParser(this, parsedExtensions) parser.on('drain', onParserDrain) parser.on('error', onParserError.bind(this)) response.socket.ws = this this[kByteParser] = parser + this.#sendQueue = new SendQueue(response.socket) + // 1. Change the ready state to OPEN (1). this[kReadyState] = states.OPEN @@ -558,7 +539,7 @@ webidl.converters.WebSocketInit = webidl.dictionaryConverter([ }, { key: 'dispatcher', - converter: (V) => V, + converter: webidl.converters.any, defaultValue: () => getGlobalDispatcher() }, { diff --git a/test/autobahn/client.js b/test/autobahn/client.js index 41bf1d61063..53bc17e722c 100644 --- a/test/autobahn/client.js +++ b/test/autobahn/client.js @@ -12,6 +12,7 @@ function nextTest () { if (currentTest > testCount) { ws = new WebSocket(`${autobahnFuzzingserverUrl}/updateReports?agent=undici`) + ws.addEventListener('close', () => require('./report')) return } diff --git a/test/websocket/send-mutable.js b/test/websocket/send-mutable.js new file mode 100644 index 00000000000..fa1cc86aecc --- /dev/null +++ b/test/websocket/send-mutable.js @@ -0,0 +1,34 @@ +'use strict' + +const { test } = require('node:test') +const { WebSocketServer } = require('ws') +const { WebSocket } = require('../..') +const { tspl } = require('@matteo.collina/tspl') + +test('check cloned', async (t) => { + const assert = tspl(t, { plan: 2 }) + + const server = new WebSocketServer({ port: 0 }) + const buffer = new Uint8Array([0x61]) + + server.on('connection', (ws) => { + ws.on('message', (data) => { + assert.deepStrictEqual(data, Buffer.from([0x61])) + }) + }) + + const ws = new WebSocket(`ws://localhost:${server.address().port}`) + + ws.addEventListener('open', () => { + ws.send(new Blob([buffer])) + ws.send(buffer) + buffer[0] = 1 + }) + + t.after(() => { + server.close() + ws.close() + }) + + await assert.completed +})