From 62422d1189e73af9d02d4ae8d9e444c1ccdd065d Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Tue, 28 Mar 2023 16:04:10 +0200 Subject: [PATCH 1/3] Fix stream replay in validators --- connexion/validators/abstract.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/connexion/validators/abstract.py b/connexion/validators/abstract.py index d911fc947..8d31b1720 100644 --- a/connexion/validators/abstract.py +++ b/connexion/validators/abstract.py @@ -99,11 +99,14 @@ def _insert_messages( receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]] ) -> Receive: """Insert messages at the start of the `receive` channel.""" + # Ensure that messages in an iterator + messages = iter(messages) async def receive_() -> t.MutableMapping[str, t.Any]: - for message in messages: - return message - return await receive() + try: + return next(iter(messages)) + except StopIteration: + return await receive() return receive_ From 607c1c614efbda213df36d8466ddbb2cf0ed7911 Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Tue, 28 Mar 2023 16:49:11 +0200 Subject: [PATCH 2/3] Add test for wrapped receive --- connexion/validators/abstract.py | 6 +++--- tests/test_validation.py | 29 ++++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/connexion/validators/abstract.py b/connexion/validators/abstract.py index 8d31b1720..bb6599eac 100644 --- a/connexion/validators/abstract.py +++ b/connexion/validators/abstract.py @@ -99,12 +99,12 @@ def _insert_messages( receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]] ) -> Receive: """Insert messages at the start of the `receive` channel.""" - # Ensure that messages in an iterator - messages = iter(messages) + # Ensure that messages is an iterator so each message is replayed once. + message_iterator = iter(messages) async def receive_() -> t.MutableMapping[str, t.Any]: try: - return next(iter(messages)) + return next(message_iterator) except StopIteration: return await receive() diff --git a/tests/test_validation.py b/tests/test_validation.py index b62eca32e..3750dd494 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -4,7 +4,7 @@ import pytest from connexion.exceptions import BadRequestProblem from connexion.uri_parsing import Swagger2URIParser -from connexion.validators.parameter import ParameterValidator +from connexion.validators import AbstractRequestBodyValidator, ParameterValidator from starlette.datastructures import QueryParams @@ -140,3 +140,30 @@ def test_parameter_validator(monkeypatch): with pytest.raises(BadRequestProblem) as exc: validator.validate_request(request) assert exc.value.detail.startswith("'x' is not one of ['a', 'b']") + + +async def test_stream_replay(): + messages = [ + {"body": b"message 1", "more_body": True}, + {"body": b"message 2", "more_body": False}, + ] + + async def receive(): + return b"" + + wrapped_receive = AbstractRequestBodyValidator._insert_messages( + receive, messages=messages + ) + + replay = [] + more_body = True + while more_body: + message = await wrapped_receive() + replay.append(message) + more_body = message.get("more_body", False) + + assert len(replay) <= len(messages), ( + "Replayed more messages than received, " "break out of while loop" + ) + + assert messages == replay From 9fe589393715540d7aac4a2d65b81d22c4db1d9d Mon Sep 17 00:00:00 2001 From: Robbe Sneyders Date: Thu, 30 Mar 2023 21:47:33 +0200 Subject: [PATCH 3/3] Fix black --- tests/test_validation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_validation.py b/tests/test_validation.py index 3750dd494..c15501f6c 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -162,8 +162,8 @@ async def receive(): replay.append(message) more_body = message.get("more_body", False) - assert len(replay) <= len(messages), ( - "Replayed more messages than received, " "break out of while loop" - ) + assert len(replay) <= len( + messages + ), "Replayed more messages than received, break out of while loop" assert messages == replay