Restore async concurrency safety to websocket compressor (#7865) #7890

Merged
1 change: 1 addition & 0 deletions CHANGES/7865.bugfix
@@ -0,0 +1 @@
+Restore async concurrency safety to websocket compressor
22 changes: 14 additions & 8 deletions aiohttp/compression_utils.py
@@ -62,19 +62,25 @@ def __init__(
         self._compressor = zlib.compressobj(
             wbits=self._mode, strategy=strategy, level=level
         )
+        self._compress_lock = asyncio.Lock()

     def compress_sync(self, data: bytes) -> bytes:
         return self._compressor.compress(data)

     async def compress(self, data: bytes) -> bytes:
-        if (
-            self._max_sync_chunk_size is not None
-            and len(data) > self._max_sync_chunk_size
-        ):
-            return await asyncio.get_event_loop().run_in_executor(
-                self._executor, self.compress_sync, data
-            )
-        return self.compress_sync(data)
+        async with self._compress_lock:
+            # To ensure the stream is consistent in the event
+            # there are multiple writers, we need to lock
+            # the compressor so that only one writer can
+            # compress at a time.
+            if (
+                self._max_sync_chunk_size is not None
+                and len(data) > self._max_sync_chunk_size
+            ):
+                return await asyncio.get_event_loop().run_in_executor(
+                    self._executor, self.compress_sync, data
+                )
+            return self.compress_sync(data)

     def flush(self, mode: int = zlib.Z_FINISH) -> bytes:
         return self._compressor.flush(mode)
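The lock matters because a zlib compressobj is stateful: once a large payload is handed to the executor, the event loop can schedule another writer's compress() against the same stream, interleaving the compressed output. A minimal standalone sketch of the pattern (the LockedCompressor name and wbits value are illustrative, not part of the PR):

    import asyncio
    import zlib

    class LockedCompressor:
        # Toy version of ZLibCompressor's locking, for illustration only.
        def __init__(self) -> None:
            self._compressor = zlib.compressobj(wbits=-15)
            self._lock = asyncio.Lock()

        async def compress(self, data: bytes) -> bytes:
            # Serialize access: only one coroutine may feed the shared
            # zlib stream at a time, even while the actual compression
            # is offloaded to an executor thread.
            async with self._lock:
                return await asyncio.get_running_loop().run_in_executor(
                    None, self._compressor.compress, data
                )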
26 changes: 16 additions & 10 deletions aiohttp/http_websocket.py
@@ -635,21 +635,17 @@ async def _send_frame(
         if (compress or self.compress) and opcode < 8:
             if compress:
                 # Do not set self._compress if compressing is for this frame
-                compressobj = ZLibCompressor(
-                    level=zlib.Z_BEST_SPEED,
-                    wbits=-compress,
-                    max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
-                )
+                compressobj = self._make_compress_obj(compress)
             else:  # self.compress
                 if not self._compressobj:
-                    self._compressobj = ZLibCompressor(
-                        level=zlib.Z_BEST_SPEED,
-                        wbits=-self.compress,
-                        max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
-                    )
+                    self._compressobj = self._make_compress_obj(self.compress)
                 compressobj = self._compressobj

             message = await compressobj.compress(message)
+            # It's critical that we do not return control to the event
+            # loop until we have finished sending all the compressed
+            # data. Otherwise we could end up mixing compressed frames
+            # if there are multiple coroutines compressing data.
             message += compressobj.flush(
                 zlib.Z_FULL_FLUSH if self.notakeover else zlib.Z_SYNC_FLUSH
             )
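The flush mode mirrors RFC 7692's context-takeover negotiation: Z_SYNC_FLUSH byte-aligns the output but keeps the compression window, so later messages can back-reference earlier ones, while Z_FULL_FLUSH also resets the state (no context takeover). A hedged standalone sketch of the effect of keeping the window:

    import zlib

    # Raw deflate stream as used by permessage-deflate (wbits=-15).
    compressor = zlib.compressobj(wbits=-15)
    msg = b"hello websocket " * 64

    first = compressor.compress(msg) + compressor.flush(zlib.Z_SYNC_FLUSH)
    # With Z_SYNC_FLUSH the window survives the flush, so a repeated
    # message is encoded almost entirely as back-references.
    second = compressor.compress(msg) + compressor.flush(zlib.Z_SYNC_FLUSH)
    print(len(first), len(second))  # the second frame is far smaller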
@@ -687,10 +683,20 @@ async def _send_frame(

         self._output_size += len(header) + len(message)

+        # It is safe to return control to the event loop when using compression
+        # after this point as we have already sent or buffered all the data.
+
         if self._output_size > self._limit:
             self._output_size = 0
             await self.protocol._drain_helper()

+    def _make_compress_obj(self, compress: int) -> ZLibCompressor:
+        return ZLibCompressor(
+            level=zlib.Z_BEST_SPEED,
+            wbits=-compress,
+            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
+        )
+
     def _write(self, data: bytes) -> None:
         if self.transport is None or self.transport.is_closing():
             raise ConnectionResetError("Cannot write to closing transport")
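The new _make_compress_obj helper centralizes the two compressor kinds _send_frame needs: a throwaway object when a per-frame compress level is passed, and the cached self._compressobj for connection-level compression. A sketch of the two call paths, assuming protocol and transport stand-ins set up as in the tests below:

    writer = WebSocketWriter(protocol, transport, compress=15)

    # Connection-level path: builds self._compressobj via
    # _make_compress_obj(15) on first use, then reuses it so the
    # compression context spans messages.
    await writer.send(b"payload", binary=True)

    # Per-frame path: builds a one-off compressor via
    # _make_compress_obj(9) that is discarded after this frame.
    await writer.send(b"payload", binary=True, compress=9)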
67 changes: 66 additions & 1 deletion tests/test_websocket_writer.py
@@ -1,9 +1,12 @@
+import asyncio
 import random
+from typing import Any, Callable
 from unittest import mock

 import pytest

-from aiohttp.http import WebSocketWriter
+from aiohttp import DataQueue, WSMessage
+from aiohttp.http import WebSocketReader, WebSocketWriter
 from aiohttp.test_utils import make_mocked_coro


@@ -104,3 +107,65 @@ async def test_send_compress_text_per_message(protocol, transport) -> None:
     writer.transport.write.assert_called_with(b"\x81\x04text")
     await writer.send(b"text", compress=15)
     writer.transport.write.assert_called_with(b"\xc1\x06*I\xad(\x01\x00")
+
+
+@pytest.mark.parametrize(
+    ("max_sync_chunk_size", "payload_point_generator"),
+    (
+        (16, lambda count: count),
+        (4096, lambda count: count),
+        (32, lambda count: 64 + count if count % 2 else count),
+    ),
+)
+async def test_concurrent_messages(
+    protocol: Any,
+    transport: Any,
+    max_sync_chunk_size: int,
+    payload_point_generator: Callable[[int], int],
+) -> None:
+    """Ensure messages are compressed correctly when there are multiple concurrent writers.
+
+    This test is parametrized to:
+
+    - Generate messages that are larger than the patched
+      WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 16,
+      where compression will run in the executor
+
+    - Generate messages that are smaller than the patched
+      WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 4096,
+      where compression will run in the event loop
+
+    - Interleave generated messages with a
+      WEBSOCKET_MAX_SYNC_CHUNK_SIZE of 32,
+      where compression will run both in the event loop
+      and in the executor
+    """
+    with mock.patch(
+        "aiohttp.http_websocket.WEBSOCKET_MAX_SYNC_CHUNK_SIZE", max_sync_chunk_size
+    ):
+        writer = WebSocketWriter(protocol, transport, compress=15)
+        queue: DataQueue[WSMessage] = DataQueue(asyncio.get_running_loop())
+        reader = WebSocketReader(queue, 50000)
+        writers = []
+        payloads = []
+        for count in range(1, 64 + 1):
+            point = payload_point_generator(count)
+            payload = bytes((point,)) * point
+            payloads.append(payload)
+            writers.append(writer.send(payload, binary=True))
+        await asyncio.gather(*writers)
+
+    for call in writer.transport.write.call_args_list:
+        call_bytes = call[0][0]
+        result, _ = reader.feed_data(call_bytes)
+        assert result is False
+        msg = await queue.read()
+        bytes_data: bytes = msg.data
+        first_char = bytes_data[0:1]
+        char_val = ord(first_char)
+        assert len(bytes_data) == char_val
+        # If we have a concurrency problem, the data
+        # tends to get mixed up between messages, so
+        # we want to validate that all the bytes are
+        # the same value
+        assert bytes_data == bytes_data[0:1] * char_val