Skip to content

Commit

Permalink
Drop messages on the client for closed streams/subscriptions (#133)
Browse files Browse the repository at this point in the history
Why
===

Each stream/subscription has a messages channel with a capacity of 128
messages. In our main receive loop, we push messages into the channel,
blocking until the channel has room. This adds some backpressure, but
becomes problematic if the stream is not making any progress. For
example, the client could start a stream and then decide to cancel it
and not read any of the messages. If the server sends >128 messages, it
will fill up the stream's channel leading to a deadlock for the session.

This will be more correctly fixed when river v2 support is landed, as
that adds support for proper cancellation. In the meantime, we can close
the channel when we know we are not going to be reading from it anymore,
and then drop any messages destiined for a closed channel.

rpc/upload are not affected because the server is only allowed to send 1
payload, and the channel has a buffer size of 1, so there will always be
room.

> [!Note]
> A deadlock can still occur if the client holds a reference to the
`AsyncGenerator` but doesn't service it. This is likely a client bug if
it happens, but we can probably add some timeouts to putting messages in
the stream channel for defense in depth. I'll do this as a followup as I
need to think a little bit more about how to properly handle that case.
This PR as-is should be a quick win for our usage since we shouldn't be
holding references to async generators that we aren't also actively
servicing.


What changed
============

- Close stream aiochannel in a finalizer in stream/subscription impls
- Ignore `ChannelClosed` errors when adding message to stream
- Fix tests to use correct method kinds, this was causing
subscription/upload RPCs to not work correctly in tests. Luckily things
are fine on the server codegen side.

Test plan
=========

- Added a test which caused a deadlock on the client before this change,
but works properly after this change.
  • Loading branch information
cbrewster authored Jan 8, 2025
1 parent 7d011c7 commit ecd1b17
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 6 deletions.
4 changes: 4 additions & 0 deletions replit_river/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ async def send_subscription(
) from e
except Exception as e:
raise e
finally:
output.close()

async def send_stream(
self,
Expand Down Expand Up @@ -335,6 +337,8 @@ async def _encode_stream() -> None:
) from e
except Exception as e:
raise e
finally:
output.close()

async def send_close_stream(
self,
Expand Down
6 changes: 5 additions & 1 deletion replit_river/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,11 @@ async def _add_msg_to_stream(
return
try:
await stream.put(msg.payload)
except (RuntimeError, ChannelClosed) as e:
except ChannelClosed:
# The client is no longer interested in this stream,
# just drop the message.
pass
except RuntimeError as e:
raise InvalidMessageException(e) from e

async def _remove_acked_messages_in_buffer(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/common_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ async def upload_handler(

basic_upload: HandlerMapping = {
("test_service", "upload_method"): (
"upload",
"upload-stream",
upload_method_handler(upload_handler, deserialize_request, serialize_response),
),
}
Expand All @@ -54,7 +54,7 @@ async def subscription_handler(

basic_subscription: HandlerMapping = {
("test_service", "subscription_method"): (
"subscription",
"subscription-stream",
subscription_method_handler(
subscription_handler, deserialize_request, serialize_response
),
Expand Down
5 changes: 3 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Mapping
from typing import Any, Literal, Mapping

import nanoid
import pytest
Expand All @@ -16,7 +16,8 @@
# Modular fixtures
pytest_plugins = ["tests.river_fixtures.logging", "tests.river_fixtures.clientserver"]

HandlerMapping = Mapping[tuple[str, str], tuple[str, GenericRpcHandler]]
HandlerKind = Literal["rpc", "subscription-stream", "upload-stream", "stream"]
HandlerMapping = Mapping[tuple[str, str], tuple[HandlerKind, GenericRpcHandler]]


def transport_message(
Expand Down
57 changes: 56 additions & 1 deletion tests/test_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
from typing import AsyncGenerator

import pytest
from grpc.aio import grpc

from replit_river.client import Client
from replit_river.error_schema import RiverError
from replit_river.rpc import subscription_method_handler
from replit_river.transport_options import MAX_MESSAGE_BUFFER_SIZE
from tests.common_handlers import (
basic_rpc_method,
Expand All @@ -14,9 +16,12 @@
basic_upload,
)
from tests.conftest import (
HandlerMapping,
deserialize_error,
deserialize_request,
deserialize_response,
serialize_request,
serialize_response,
)


Expand Down Expand Up @@ -101,6 +106,7 @@ async def upload_data(enabled: bool = False) -> AsyncGenerator[str, None]:
@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_subscription}])
async def test_subscription_method(client: Client) -> None:
messages = []
async for response in client.send_subscription(
"test_service",
"subscription_method",
Expand All @@ -110,7 +116,8 @@ async def test_subscription_method(client: Client) -> None:
deserialize_error,
):
assert isinstance(response, str)
assert "Subscription message" in response
messages.append(response)
assert messages == [f"Subscription message {i} for Bob" for i in range(5)]


@pytest.mark.asyncio
Expand Down Expand Up @@ -213,3 +220,51 @@ async def stream_data() -> AsyncGenerator[str, None]:
"Stream response for Stream Data 1",
"Stream response for Stream Data 2",
]


async def flood_subscription_handler(
request: str, context: grpc.aio.ServicerContext
) -> AsyncGenerator[str, None]:
for i in range(1024):
yield f"Subscription message {i} for {request}"


flood_subscription: HandlerMapping = {
("test_service", "flood_subscription_method"): (
"subscription-stream",
subscription_method_handler(
flood_subscription_handler, deserialize_request, serialize_response
),
),
}


@pytest.mark.asyncio
@pytest.mark.parametrize("handlers", [{**basic_rpc_method, **flood_subscription}])
async def test_ignore_flood_subscription(client: Client) -> None:
sub = client.send_subscription(
"test_service",
"flood_subscription_method",
"Initial Subscription Data",
serialize_request,
deserialize_response,
deserialize_error,
)

# read one entry to start the subscription
await sub.__anext__()
# close the subscription so we can signal that we're not
# interested in the rest of the subscription.
await sub.aclose()

# ensure that subsequent RPCs still work
response = await client.send_rpc(
"test_service",
"rpc_method",
"Alice",
serialize_request,
deserialize_response,
deserialize_error,
timedelta(seconds=20),
)
assert response == "Hello, Alice!"

0 comments on commit ecd1b17

Please sign in to comment.