Skip to content

Commit

Permalink
Restore async concurrency safety to websocket compressor (aio-libs#7865)
Browse files Browse the repository at this point in the history
Fixes aio-libs#7859

(cherry picked from commit 86a2396)
  • Loading branch information
bdraco committed Nov 24, 2023
1 parent 41a9f1f commit b548b4e
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGES/7865.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Restore async concurrency safety to websocket compressor
22 changes: 14 additions & 8 deletions aiohttp/compression_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,25 @@ def __init__(
self._compressor = zlib.compressobj(
wbits=self._mode, strategy=strategy, level=level
)
self._compress_lock = asyncio.Lock()

def compress_sync(self, data: bytes) -> bytes:
return self._compressor.compress(data)

async def compress(self, data: bytes) -> bytes:
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_event_loop().run_in_executor(
self._executor, self.compress_sync, data
)
return self.compress_sync(data)
async with self._compress_lock:
# To ensure the stream is consistent in the event
# there are multiple writers, we need to lock
# the compressor so that only one writer can
# compress at a time.
if (
self._max_sync_chunk_size is not None
and len(data) > self._max_sync_chunk_size
):
return await asyncio.get_event_loop().run_in_executor(
self._executor, self.compress_sync, data
)
return self.compress_sync(data)

def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
return self._compressor.flush(mode)
Expand Down
26 changes: 16 additions & 10 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,21 +635,17 @@ async def _send_frame(
if (compress or self.compress) and opcode < 8:
if compress:
# Do not set self._compress if compressing is for this frame
compressobj = ZLibCompressor(
level=zlib.Z_BEST_SPEED,
wbits=-compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
compressobj = self._make_compress_obj(compress)
else: # self.compress
if not self._compressobj:
self._compressobj = ZLibCompressor(
level=zlib.Z_BEST_SPEED,
wbits=-self.compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)
self._compressobj = self._make_compress_obj(self.compress)
compressobj = self._compressobj

message = await compressobj.compress(message)
# Its critical that we do not return control to the event
# loop until we have finished sending all the compressed
# data. Otherwise we could end up mixing compressed frames
# if there are multiple coroutines compressing data.
message += compressobj.flush(
zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
)
Expand Down Expand Up @@ -687,10 +683,20 @@ async def _send_frame(

self._output_size += len(header) + len(message)

# It is safe to return control to the event loop when using compression
# after this point as we have already sent or buffered all the data.

if self._output_size > self._limit:
self._output_size = 0
await self.protocol._drain_helper()

def _make_compress_obj(self, compress: int) -> ZLibCompressor:
return ZLibCompressor(
level=zlib.Z_BEST_SPEED,
wbits=-compress,
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
)

def _write(self, data: bytes) -> None:
if self.transport is None or self.transport.is_closing():
raise ConnectionResetError("Cannot write to closing transport")
Expand Down
68 changes: 67 additions & 1 deletion tests/test_websocket_writer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# type: ignore
import asyncio
import random
from typing import Any, Callable
from unittest import mock

import pytest

from aiohttp.http import WebSocketWriter
from aiohttp import DataQueue, WSMessage
from aiohttp.http import WebSocketReader, WebSocketWriter
from aiohttp.test_utils import make_mocked_coro


Expand Down Expand Up @@ -104,3 +108,65 @@ async def test_send_compress_text_per_message(protocol, transport) -> None:
writer.transport.write.assert_called_with(b"\x81\x04text")
await writer.send(b"text", compress=15)
writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00")


@pytest.mark.parametrize(
("max_sync_chunk_size", "payload_point_generator"),
(
(16, lambda count: count),
(4096, lambda count: count),
(32, lambda count: 64 + count if count % 2 else count),
),
)
async def test_concurrent_messages(
protocol: Any,
transport: Any,
max_sync_chunk_size: int,
payload_point_generator: Callable[[int], int],
) -> None:
"""Ensure messages are compressed correctly when there are multiple concurrent writers.
This test generates is parametrized to
- Generate messages that are larger than patch
WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 16
where compression will run in the executor
- Generate messages that are smaller than patch
WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 4096
where compression will run in the event loop
- Interleave generated messages with a
WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 32
where compression will run in the event loop
and in the executor
"""
with mock.patch(
"aiohttp.http_websocket.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", max_sync_chunk_size
):
writer = WebSocketWriter(protocol, transport, compress=15)
queue: DataQueue[WSMessage] = DataQueue(asyncio.get_running_loop())
reader = WebSocketReader(queue, 50000)
writers = []
payloads = []
for count in range(1, 64 + 1):
point = payload_point_generator(count)
payload = bytes((point,)) * point
payloads.append(payload)
writers.append(writer.send(payload, binary=True))
await asyncio.gather(*writers)

for call in writer.transport.write.call_args_list:
call_bytes = call[0][0]
result, _ = reader.feed_data(call_bytes)
assert result is False
msg = await queue.read()
bytes_data: bytes = msg.data
first_char = bytes_data[0:1]
char_val = ord(first_char)
assert len(bytes_data) == char_val
# If we have a concurrency problem, the data
# tends to get mixed up between messages so
# we want to validate that all the bytes are
# the same value
assert bytes_data == bytes_data[0:1] * char_val

0 comments on commit b548b4e

Please sign in to comment.