diff --git a/tests/benchmarks/api.py b/tests/benchmarks/api.py index 157b18b556..4584d944e2 100644 --- a/tests/benchmarks/api.py +++ b/tests/benchmarks/api.py @@ -73,6 +73,11 @@ class Subscription: async def something(self) -> AsyncIterator[str]: yield "Hello World!" + @strawberry.subscription + async def long_running(self, count: int) -> AsyncIterator[int]: + for i in range(count): + yield i + @strawberry.directive(locations=[DirectiveLocation.FIELD]) def uppercase(value: str) -> str: diff --git a/tests/benchmarks/test_subscriptions.py b/tests/benchmarks/test_subscriptions.py index 075c993904..0bbd512835 100644 --- a/tests/benchmarks/test_subscriptions.py +++ b/tests/benchmarks/test_subscriptions.py @@ -1,6 +1,8 @@ import asyncio +from typing import AsyncIterator import pytest +from graphql import ExecutionResult from pytest_codspeed.plugin import BenchmarkFixture from .api import schema @@ -24,3 +26,25 @@ async def _run(): assert value.data["something"] == "Hello World!" benchmark(lambda: asyncio.run(_run())) + + +@pytest.mark.benchmark +@pytest.mark.parametrize("count", [1000, 20000]) +def test_subscription_long_run(benchmark: BenchmarkFixture, count: int) -> None: + s = """#graphql + subscription LongRunning($count: Int!) { + longRunning(count: $count) + } + """ + + async def _run(): + i = 0 + aiterator: AsyncIterator[ExecutionResult] = await schema.subscribe( + s, variable_values={"count": count} + ) # type: ignore[assignment] + async for res in aiterator: + assert res.data is not None + assert res.data["longRunning"] == i + i += 1 + + benchmark(lambda: asyncio.run(_run()))