Skip to content

Commit

Permalink
Add test for wrapped receive
Browse files Browse the repository at this point in the history
  • Loading branch information
RobbeSneyders committed Mar 29, 2023
1 parent 62422d1 commit 607c1c6
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
6 changes: 3 additions & 3 deletions connexion/validators/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ def _insert_messages(
receive: Receive, *, messages: t.Iterable[t.MutableMapping[str, t.Any]]
) -> Receive:
"""Insert messages at the start of the `receive` channel."""
# Ensure that messages in an iterator
messages = iter(messages)
# Ensure that messages is an iterator so each message is replayed once.
message_iterator = iter(messages)

async def receive_() -> t.MutableMapping[str, t.Any]:
try:
return next(iter(messages))
return next(message_iterator)
except StopIteration:
return await receive()

Expand Down
29 changes: 28 additions & 1 deletion tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from connexion.exceptions import BadRequestProblem
from connexion.uri_parsing import Swagger2URIParser
from connexion.validators.parameter import ParameterValidator
from connexion.validators import AbstractRequestBodyValidator, ParameterValidator
from starlette.datastructures import QueryParams


Expand Down Expand Up @@ -140,3 +140,30 @@ def test_parameter_validator(monkeypatch):
with pytest.raises(BadRequestProblem) as exc:
validator.validate_request(request)
assert exc.value.detail.startswith("'x' is not one of ['a', 'b']")


async def test_stream_replay():
messages = [
{"body": b"message 1", "more_body": True},
{"body": b"message 2", "more_body": False},
]

async def receive():
return b""

wrapped_receive = AbstractRequestBodyValidator._insert_messages(
receive, messages=messages
)

replay = []
more_body = True
while more_body:
message = await wrapped_receive()
replay.append(message)
more_body = message.get("more_body", False)

assert len(replay) <= len(messages), (
"Replayed more messages than received, " "break out of while loop"
)

assert messages == replay

0 comments on commit 607c1c6

Please sign in to comment.