From 8feabc2d0fea01b5601566e09de535f675e7ae66 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 15 Aug 2019 23:00:16 +0200 Subject: [PATCH 1/9] Rely on concurrency backend everywhere --- httpx/concurrency.py | 2 +- tests/client/test_async_client.py | 65 +++++++++++++------------------ tests/client/test_client.py | 65 ++++++++++++++----------------- tests/conftest.py | 12 ++++++ 4 files changed, 69 insertions(+), 75 deletions(-) diff --git a/httpx/concurrency.py b/httpx/concurrency.py index f07eef4c64..ff9e02bfd6 100644 --- a/httpx/concurrency.py +++ b/httpx/concurrency.py @@ -254,7 +254,7 @@ def background_manager( class BackgroundManager(BaseBackgroundManager): - def __init__(self, coroutine: typing.Callable, args: typing.Any) -> None: + def __init__(self, coroutine: typing.Callable, args: typing.Sequence) -> None: self.coroutine = coroutine self.args = args diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index b037085f25..08b9d778c0 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -3,10 +3,9 @@ import httpx -@pytest.mark.asyncio -async def test_get(server): +async def test_get(backend, server): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(backend=backend) as client: response = await client.get(url) assert response.status_code == 200 assert response.text == "Hello, world!" @@ -15,25 +14,22 @@ async def test_get(server): assert repr(response) == "" -@pytest.mark.asyncio -async def test_post(server): +async def test_post(backend, server): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(backend=backend) as client: response = await client.post(url, data=b"Hello, world!") assert response.status_code == 200 -@pytest.mark.asyncio -async def test_post_json(server): +async def test_post_json(backend, server): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(backend=backend) as client: response = await client.post(url, json={"text": "Hello, world!"}) assert response.status_code == 200 -@pytest.mark.asyncio -async def test_stream_response(server): - async with httpx.AsyncClient() as client: +async def test_stream_response(backend, server): + async with httpx.AsyncClient(backend=backend) as client: response = await client.request("GET", "http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 body = await response.read() @@ -41,31 +37,28 @@ async def test_stream_response(server): assert response.content == b"Hello, world!" -@pytest.mark.asyncio -async def test_access_content_stream_response(server): - async with httpx.AsyncClient() as client: +async def test_access_content_stream_response(backend, server): + async with httpx.AsyncClient(backend=backend) as client: response = await client.request("GET", "http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 with pytest.raises(httpx.ResponseNotRead): response.content -@pytest.mark.asyncio -async def test_stream_request(server): +async def test_stream_request(backend, server): async def hello_world(): yield b"Hello, " yield b"world!" - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(backend=backend) as client: response = await client.request( "POST", "http://127.0.0.1:8000/", data=hello_world() ) assert response.status_code == 200 -@pytest.mark.asyncio -async def test_raise_for_status(server): - async with httpx.AsyncClient() as client: +async def test_raise_for_status(backend, server): + async with httpx.AsyncClient(backend=backend) as client: for status_code in (200, 400, 404, 500, 505): response = await client.request( "GET", f"http://127.0.0.1:8000/status/{status_code}" @@ -79,56 +72,50 @@ async def test_raise_for_status(server): assert response.raise_for_status() is None -@pytest.mark.asyncio -async def test_options(server): +async def test_options(backend, server): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(backend=backend) as client: response = await client.options(url) assert response.status_code == 200 assert response.text == "Hello, world!" -@pytest.mark.asyncio -async def test_head(server): +async def test_head(backend, server): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(backend=backend) as client: response = await client.head(url) assert response.status_code == 200 assert response.text == "" -@pytest.mark.asyncio -async def test_put(server): +async def test_put(backend, server): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(backend=backend) as client: response = await client.put(url, data=b"Hello, world!") assert response.status_code == 200 -@pytest.mark.asyncio -async def test_patch(server): +async def test_patch(backend, server): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(backend=backend) as client: response = await client.patch(url, data=b"Hello, world!") assert response.status_code == 200 -@pytest.mark.asyncio -async def test_delete(server): +async def test_delete(backend, server): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(backend=backend) as client: response = await client.delete(url) assert response.status_code == 200 assert response.text == "Hello, world!" -@pytest.mark.asyncio -async def test_100_continue(server): +async def test_100_continue(backend, server): url = "http://127.0.0.1:8000/echo_body" headers = {"Expect": "100-continue"} data = b"Echo request body" - async with httpx.AsyncClient() as client: + async with httpx.AsyncClient(backend=backend) as client: response = await client.post(url, headers=headers, data=data) assert response.status_code == 200 diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 97ae0277dd..0c7c49e608 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1,4 +1,3 @@ -import asyncio import functools import pytest @@ -12,21 +11,17 @@ def threadpool(func): """ @functools.wraps(func) - async def wrapped(*args, **kwargs): - nonlocal func + async def wrapped(backend, *args, **kwargs): + backend_for_thread = type(backend)() + await backend.run_in_threadpool(func, backend_for_thread, *args, **kwargs) - loop = asyncio.get_event_loop() - if kwargs: - func = functools.partial(func, **kwargs) - await loop.run_in_executor(None, func, *args) - - return pytest.mark.asyncio(wrapped) + return wrapped @threadpool -def test_get(server): +def test_get(backend, server): url = "http://127.0.0.1:8000/" - with httpx.Client() as http: + with httpx.Client(backend=backend) as http: response = http.get(url) assert response.status_code == 200 assert response.url == httpx.URL(url) @@ -41,24 +36,24 @@ def test_get(server): @threadpool -def test_post(server): - with httpx.Client() as http: +def test_post(backend, server): + with httpx.Client(backend=backend) as http: response = http.post("http://127.0.0.1:8000/", data=b"Hello, world!") assert response.status_code == 200 assert response.reason_phrase == "OK" @threadpool -def test_post_json(server): - with httpx.Client() as http: +def test_post_json(backend, server): + with httpx.Client(backend=backend) as http: response = http.post("http://127.0.0.1:8000/", json={"text": "Hello, world!"}) assert response.status_code == 200 assert response.reason_phrase == "OK" @threadpool -def test_stream_response(server): - with httpx.Client() as http: +def test_stream_response(backend, server): + with httpx.Client(backend=backend) as http: response = http.get("http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 content = response.read() @@ -66,8 +61,8 @@ def test_stream_response(server): @threadpool -def test_stream_iterator(server): - with httpx.Client() as http: +def test_stream_iterator(backend, server): + with httpx.Client(backend=backend) as http: response = http.get("http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 body = b"" @@ -77,8 +72,8 @@ def test_stream_iterator(server): @threadpool -def test_raw_iterator(server): - with httpx.Client() as http: +def test_raw_iterator(backend, server): + with httpx.Client(backend=backend) as http: response = http.get("http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 body = b"" @@ -89,8 +84,8 @@ def test_raw_iterator(server): @threadpool -def test_raise_for_status(server): - with httpx.Client() as client: +def test_raise_for_status(backend, server): + with httpx.Client(backend=backend) as client: for status_code in (200, 400, 404, 500, 505): response = client.request( "GET", "http://127.0.0.1:8000/status/{}".format(status_code) @@ -104,49 +99,49 @@ def test_raise_for_status(server): @threadpool -def test_options(server): - with httpx.Client() as http: +def test_options(backend, server): + with httpx.Client(backend=backend) as http: response = http.options("http://127.0.0.1:8000/") assert response.status_code == 200 assert response.reason_phrase == "OK" @threadpool -def test_head(server): - with httpx.Client() as http: +def test_head(backend, server): + with httpx.Client(backend=backend) as http: response = http.head("http://127.0.0.1:8000/") assert response.status_code == 200 assert response.reason_phrase == "OK" @threadpool -def test_put(server): - with httpx.Client() as http: +def test_put(backend, server): + with httpx.Client(backend=backend) as http: response = http.put("http://127.0.0.1:8000/", data=b"Hello, world!") assert response.status_code == 200 assert response.reason_phrase == "OK" @threadpool -def test_patch(server): - with httpx.Client() as http: +def test_patch(backend, server): + with httpx.Client(backend=backend) as http: response = http.patch("http://127.0.0.1:8000/", data=b"Hello, world!") assert response.status_code == 200 assert response.reason_phrase == "OK" @threadpool -def test_delete(server): - with httpx.Client() as http: +def test_delete(backend, server): + with httpx.Client(backend=backend) as http: response = http.delete("http://127.0.0.1:8000/") assert response.status_code == 200 assert response.reason_phrase == "OK" @threadpool -def test_base_url(server): +def test_base_url(backend, server): base_url = "http://127.0.0.1:8000/" - with httpx.Client(base_url=base_url) as http: + with httpx.Client(base_url=base_url, backend=backend) as http: response = http.get("/") assert response.status_code == 200 assert str(response.url) == base_url diff --git a/tests/conftest.py b/tests/conftest.py index f8a6d0dc04..fde7c4915f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,18 @@ from uvicorn.config import Config from uvicorn.main import Server +from httpx.concurrency import AsyncioBackend + + +@pytest.fixture( + params=[ + pytest.param(AsyncioBackend, marks=pytest.mark.asyncio) # type: ignore + ] +) +def backend(request): + backend_cls = request.param + return backend_cls() + async def app(scope, receive, send): assert scope["type"] == "http" From 602837c6021f3a89394a74da4218cf0ff73496a7 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 15 Aug 2019 20:28:46 +0200 Subject: [PATCH 2/9] Rely on concurrency backend in ASGIDispatch --- httpx/client.py | 2 +- httpx/concurrency.py | 42 ++++++++++++++++++++++++++++------------ httpx/dispatch/asgi.py | 34 +++++++++++++------------------- httpx/dispatch/http11.py | 3 ++- httpx/dispatch/http2.py | 3 ++- httpx/interfaces.py | 38 +++++++++++++++++++++++++++++++++--- 6 files changed, 83 insertions(+), 39 deletions(-) diff --git a/httpx/client.py b/httpx/client.py index 9d704b33e9..5b195679c6 100644 --- a/httpx/client.py +++ b/httpx/client.py @@ -79,7 +79,7 @@ def __init__( ) else: dispatch = ASGIDispatch( - app=app, raise_app_exceptions=raise_app_exceptions + app=app, raise_app_exceptions=raise_app_exceptions, backend=backend ) if dispatch is None: diff --git a/httpx/concurrency.py b/httpx/concurrency.py index ff9e02bfd6..dd1d386a76 100644 --- a/httpx/concurrency.py +++ b/httpx/concurrency.py @@ -14,11 +14,15 @@ import typing from types import TracebackType +from async_generator import asynccontextmanager + from .config import PoolLimits, TimeoutConfig from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .interfaces import ( BaseBackgroundManager, + BaseEvent, BasePoolSemaphore, + BaseQueue, BaseReader, BaseWriter, ConcurrencyBackend, @@ -244,31 +248,45 @@ def run( finally: self._loop = loop + def create_event(self) -> BaseEvent: + return asyncio.Event() # type: ignore + + def create_queue(self, max_size: int) -> BaseQueue: + return asyncio.Queue(maxsize=max_size) # type: ignore + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: return PoolSemaphore(limits) - def background_manager( - self, coroutine: typing.Callable, args: typing.Any - ) -> "BackgroundManager": - return BackgroundManager(coroutine, args) + def background_manager(self) -> "BackgroundManager": + return BackgroundManager() class BackgroundManager(BaseBackgroundManager): - def __init__(self, coroutine: typing.Callable, args: typing.Sequence) -> None: - self.coroutine = coroutine - self.args = args + def __init__(self) -> None: + self.tasks: typing.Set[asyncio.Task] = set() async def __aenter__(self) -> "BackgroundManager": - loop = asyncio.get_event_loop() - self.task = loop.create_task(self.coroutine(*self.args)) return self + def start_soon(self, coroutine: typing.Callable, *args: typing.Any) -> None: + loop = asyncio.get_event_loop() + self.tasks.add(loop.create_task(coroutine(*args))) + + @asynccontextmanager # type: ignore + async def will_wait_for_first_completed(self) -> typing.AsyncContextManager: + initial_tasks = self.tasks + self.tasks = set() + yield + await asyncio.wait(self.tasks, return_when=asyncio.FIRST_COMPLETED) + self.tasks = initial_tasks.union(self.tasks) + async def __aexit__( self, exc_type: typing.Type[BaseException] = None, exc_value: BaseException = None, traceback: TracebackType = None, ) -> None: - await self.task - if exc_type is None: - self.task.result() + for task in self.tasks: + await task + if exc_type is None: + task.result() diff --git a/httpx/dispatch/asgi.py b/httpx/dispatch/asgi.py index 8164b18e22..4773f0ad97 100644 --- a/httpx/dispatch/asgi.py +++ b/httpx/dispatch/asgi.py @@ -1,8 +1,7 @@ -import asyncio import typing from ..config import CertTypes, TimeoutTypes, VerifyTypes -from ..interfaces import AsyncDispatcher +from ..interfaces import AsyncDispatcher, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse @@ -35,6 +34,7 @@ class ASGIDispatch(AsyncDispatcher): def __init__( self, app: typing.Callable, + backend: ConcurrencyBackend, raise_app_exceptions: bool = True, root_path: str = "", client: typing.Tuple[str, int] = ("127.0.0.1", 123), @@ -43,6 +43,7 @@ def __init__( self.raise_app_exceptions = raise_app_exceptions self.root_path = root_path self.client = client + self.backend = backend async def send( self, @@ -69,8 +70,8 @@ async def send( app_exc = None status_code = None headers = None - response_started = asyncio.Event() - response_body = BodyIterator() + response_started = self.backend.create_event() + response_body = BodyIterator(self.backend) request_stream = request.stream() async def receive() -> dict: @@ -106,18 +107,11 @@ async def run_app() -> None: finally: await response_body.done() - # Really we'd like to push all `asyncio` logic into concurrency.py, - # with a standardized interface, so that we can support other event - # loop implementations, such as Trio and Curio. - # That's a bit fiddly here, so we're not yet supporting using a custom - # `ConcurrencyBackend` with the `Client(app=asgi_app)` case. - loop = asyncio.get_event_loop() - app_task = loop.create_task(run_app()) - response_task = loop.create_task(response_started.wait()) + background = self.backend.background_manager() - tasks = {app_task, response_task} # type: typing.Set[asyncio.Task] - - await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + async with background.will_wait_for_first_completed(): + background.start_soon(run_app) + background.start_soon(response_started.wait) if app_exc is not None and self.raise_app_exceptions: raise app_exc @@ -127,9 +121,8 @@ async def run_app() -> None: assert headers is not None async def on_close() -> None: - nonlocal app_task, response_body await response_body.drain() - await app_task + await background.close() if app_exc is not None and self.raise_app_exceptions: raise app_exc @@ -149,10 +142,9 @@ class BodyIterator: ingest the response content from. """ - def __init__(self) -> None: - self._queue = asyncio.Queue( - maxsize=1 - ) # type: asyncio.Queue[typing.Union[bytes, object]] + def __init__(self, backend: ConcurrencyBackend) -> None: + self._backend = backend + self._queue = self._backend.create_queue(max_size=1) self._done = object() async def iterate(self) -> typing.AsyncIterator[bytes]: diff --git a/httpx/dispatch/http11.py b/httpx/dispatch/http11.py index 0f34191e97..e6cd8f064d 100644 --- a/httpx/dispatch/http11.py +++ b/httpx/dispatch/http11.py @@ -48,7 +48,8 @@ async def send( await self._send_request(request, timeout) task, args = self._send_request_data, [request.stream(), timeout] - async with self.backend.background_manager(task, args=args): + async with self.backend.background_manager() as background: + background.start_soon(task, *args) http_version, status_code, headers = await self._receive_response(timeout) content = self._receive_response_data(timeout) diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index 980b07b25c..a9980f7227 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -44,7 +44,8 @@ async def send( self.timeout_flags[stream_id] = TimeoutFlag() task, args = self.send_request_data, [stream_id, request.stream(), timeout] - async with self.backend.background_manager(task, args=args): + async with self.backend.background_manager() as background: + background.start_soon(task, *args) status_code, headers = await self.receive_response(stream_id, timeout) content = self.body_iter(stream_id, timeout) on_close = functools.partial(self.response_closed, stream_id=stream_id) diff --git a/httpx/interfaces.py b/httpx/interfaces.py index 6c79c9e22e..6c386858f1 100644 --- a/httpx/interfaces.py +++ b/httpx/interfaces.py @@ -151,6 +151,25 @@ async def close(self) -> None: raise NotImplementedError() # pragma: no cover +class BaseEvent: + def set(self) -> None: + raise NotImplementedError() # pragma: no cover + + def is_set(self) -> bool: + raise NotImplementedError() # pragma: no cover + + async def wait(self) -> None: + raise NotImplementedError() # pragma: no cover + + +class BaseQueue: + async def get(self) -> typing.Any: + raise NotImplementedError() # pragma: no cover + + async def put(self, value: typing.Any) -> None: + raise NotImplementedError() # pragma: no cover + + class BasePoolSemaphore: """ A semaphore for use with connection pooling. @@ -204,6 +223,12 @@ def run( ) -> typing.Any: raise NotImplementedError() # pragma: no cover + def create_event(self) -> BaseEvent: + raise NotImplementedError() # pragma: no cover + + def create_queue(self, max_size: int) -> BaseQueue: + raise NotImplementedError() # pragma: no cover + def iterate(self, async_iterator): # type: ignore while True: try: @@ -211,13 +236,17 @@ def iterate(self, async_iterator): # type: ignore except StopAsyncIteration: break - def background_manager( - self, coroutine: typing.Callable, args: typing.Any - ) -> "BaseBackgroundManager": + def background_manager(self) -> "BaseBackgroundManager": raise NotImplementedError() # pragma: no cover class BaseBackgroundManager: + def start_soon(self, coroutine: typing.Callable, *args: typing.Any) -> None: + raise NotImplementedError() # pragma: no cover + + def will_wait_for_first_completed(self) -> typing.AsyncContextManager: + raise NotImplementedError() # pragma: no cover + async def __aenter__(self) -> "BaseBackgroundManager": raise NotImplementedError() # pragma: no cover @@ -228,3 +257,6 @@ async def __aexit__( traceback: TracebackType = None, ) -> None: raise NotImplementedError() # pragma: no cover + + async def close(self) -> None: + await self.__aexit__(None, None, None) From ad6a2abcfaeef2f16d2891d87db0321245b24feb Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 15 Aug 2019 21:00:34 +0200 Subject: [PATCH 3/9] Refactor: body iterator interface --- httpx/concurrency.py | 28 +++++++++++++++++++---- httpx/dispatch/asgi.py | 45 +------------------------------------ httpx/interfaces.py | 50 ++++++++++++++++++++++++++++++++---------- 3 files changed, 64 insertions(+), 59 deletions(-) diff --git a/httpx/concurrency.py b/httpx/concurrency.py index dd1d386a76..41e0ae745c 100644 --- a/httpx/concurrency.py +++ b/httpx/concurrency.py @@ -20,9 +20,9 @@ from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .interfaces import ( BaseBackgroundManager, + BaseBodyIterator, BaseEvent, BasePoolSemaphore, - BaseQueue, BaseReader, BaseWriter, ConcurrencyBackend, @@ -251,15 +251,35 @@ def run( def create_event(self) -> BaseEvent: return asyncio.Event() # type: ignore - def create_queue(self, max_size: int) -> BaseQueue: - return asyncio.Queue(maxsize=max_size) # type: ignore - def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: return PoolSemaphore(limits) def background_manager(self) -> "BackgroundManager": return BackgroundManager() + def body_iterator(self) -> "BodyIterator": + return BodyIterator() + + +class BodyIterator(BaseBodyIterator): + def __init__(self) -> None: + self._queue: asyncio.Queue = asyncio.Queue(maxsize=1) + self._done = object() + + async def iterate(self) -> typing.AsyncIterator[bytes]: + while True: + data = await self._queue.get() + if data is self._done: + break + assert isinstance(data, bytes) + yield data + + async def put(self, data: bytes) -> None: + await self._queue.put(data) + + async def done(self) -> None: + await self._queue.put(self._done) + class BackgroundManager(BaseBackgroundManager): def __init__(self) -> None: diff --git a/httpx/dispatch/asgi.py b/httpx/dispatch/asgi.py index 4773f0ad97..752f1830b8 100644 --- a/httpx/dispatch/asgi.py +++ b/httpx/dispatch/asgi.py @@ -71,7 +71,7 @@ async def send( status_code = None headers = None response_started = self.backend.create_event() - response_body = BodyIterator(self.backend) + response_body = self.backend.body_iterator() request_stream = request.stream() async def receive() -> dict: @@ -134,46 +134,3 @@ async def on_close() -> None: on_close=on_close, request=request, ) - - -class BodyIterator: - """ - Provides a byte-iterator interface that the client can use to - ingest the response content from. - """ - - def __init__(self, backend: ConcurrencyBackend) -> None: - self._backend = backend - self._queue = self._backend.create_queue(max_size=1) - self._done = object() - - async def iterate(self) -> typing.AsyncIterator[bytes]: - """ - A byte-iterator, used by the client to consume the response body. - """ - while True: - data = await self._queue.get() - if data is self._done: - break - assert isinstance(data, bytes) - yield data - - async def drain(self) -> None: - """ - Drain any remaining body, in order to allow any blocked `put()` calls - to complete. - """ - async for chunk in self.iterate(): - pass # pragma: no cover - - async def put(self, data: bytes) -> None: - """ - Used by the server to add data to the response body. - """ - await self._queue.put(data) - - async def done(self) -> None: - """ - Used by the server to signal the end of the response body. - """ - await self._queue.put(self._done) diff --git a/httpx/interfaces.py b/httpx/interfaces.py index 6c386858f1..42fcb6fc88 100644 --- a/httpx/interfaces.py +++ b/httpx/interfaces.py @@ -162,14 +162,6 @@ async def wait(self) -> None: raise NotImplementedError() # pragma: no cover -class BaseQueue: - async def get(self) -> typing.Any: - raise NotImplementedError() # pragma: no cover - - async def put(self, value: typing.Any) -> None: - raise NotImplementedError() # pragma: no cover - - class BasePoolSemaphore: """ A semaphore for use with connection pooling. @@ -226,9 +218,6 @@ def run( def create_event(self) -> BaseEvent: raise NotImplementedError() # pragma: no cover - def create_queue(self, max_size: int) -> BaseQueue: - raise NotImplementedError() # pragma: no cover - def iterate(self, async_iterator): # type: ignore while True: try: @@ -239,6 +228,45 @@ def iterate(self, async_iterator): # type: ignore def background_manager(self) -> "BaseBackgroundManager": raise NotImplementedError() # pragma: no cover + def body_iterator(self) -> "BaseBodyIterator": + raise NotImplementedError() # pragma: no cover + + +class BaseBodyIterator: + """ + Provides a byte-iterator interface that the client can use to + ingest the response content from. + """ + + def __init__(self, backend: ConcurrencyBackend) -> None: + self.backend = backend + + def iterate(self) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator, used by the client to consume the response body. + """ + raise NotImplementedError() # pragma: no cover + + async def drain(self) -> None: + """ + Drain any remaining body, in order to allow any blocked `put()` calls + to complete. + """ + async for chunk in self.iterate(): + pass # pragma: no cover + + async def put(self, data: bytes) -> None: + """ + Used by the server to add data to the response body. + """ + raise NotImplementedError() # pragma: no cover + + async def done(self) -> None: + """ + Used by the server to signal the end of the response body. + """ + raise NotImplementedError() # pragma: no cover + class BaseBackgroundManager: def start_soon(self, coroutine: typing.Callable, *args: typing.Any) -> None: From 158799d5858c333371c6719b83c38840189024da Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 15 Aug 2019 21:36:03 +0200 Subject: [PATCH 4/9] Refactor server fixtures, improve test coverage --- httpx/interfaces.py | 3 --- tests/client/test_async_client.py | 39 ++++++++++++++++++++----------- tests/client/test_client.py | 39 ++++++++++++++++++++----------- 3 files changed, 52 insertions(+), 29 deletions(-) diff --git a/httpx/interfaces.py b/httpx/interfaces.py index 42fcb6fc88..ca82a62b62 100644 --- a/httpx/interfaces.py +++ b/httpx/interfaces.py @@ -238,9 +238,6 @@ class BaseBodyIterator: ingest the response content from. """ - def __init__(self, backend: ConcurrencyBackend) -> None: - self.backend = backend - def iterate(self) -> typing.AsyncIterator[bytes]: """ A byte-iterator, used by the client to consume the response body. diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index 08b9d778c0..ba6971ca84 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -3,7 +3,8 @@ import httpx -async def test_get(backend, server): +@pytest.mark.usefixtures("server") +async def test_get(backend): url = "http://127.0.0.1:8000/" async with httpx.AsyncClient(backend=backend) as client: response = await client.get(url) @@ -14,21 +15,24 @@ async def test_get(backend, server): assert repr(response) == "" -async def test_post(backend, server): +@pytest.mark.usefixtures("server") +async def test_post(backend): url = "http://127.0.0.1:8000/" async with httpx.AsyncClient(backend=backend) as client: response = await client.post(url, data=b"Hello, world!") assert response.status_code == 200 -async def test_post_json(backend, server): +@pytest.mark.usefixtures("server") +async def test_post_json(backend): url = "http://127.0.0.1:8000/" async with httpx.AsyncClient(backend=backend) as client: response = await client.post(url, json={"text": "Hello, world!"}) assert response.status_code == 200 -async def test_stream_response(backend, server): +@pytest.mark.usefixtures("server") +async def test_stream_response(backend): async with httpx.AsyncClient(backend=backend) as client: response = await client.request("GET", "http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 @@ -37,7 +41,8 @@ async def test_stream_response(backend, server): assert response.content == b"Hello, world!" -async def test_access_content_stream_response(backend, server): +@pytest.mark.usefixtures("server") +async def test_access_content_stream_response(backend): async with httpx.AsyncClient(backend=backend) as client: response = await client.request("GET", "http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 @@ -45,7 +50,8 @@ async def test_access_content_stream_response(backend, server): response.content -async def test_stream_request(backend, server): +@pytest.mark.usefixtures("server") +async def test_stream_request(backend): async def hello_world(): yield b"Hello, " yield b"world!" @@ -57,7 +63,8 @@ async def hello_world(): assert response.status_code == 200 -async def test_raise_for_status(backend, server): +@pytest.mark.usefixtures("server") +async def test_raise_for_status(backend): async with httpx.AsyncClient(backend=backend) as client: for status_code in (200, 400, 404, 500, 505): response = await client.request( @@ -72,7 +79,8 @@ async def test_raise_for_status(backend, server): assert response.raise_for_status() is None -async def test_options(backend, server): +@pytest.mark.usefixtures("server") +async def test_options(backend): url = "http://127.0.0.1:8000/" async with httpx.AsyncClient(backend=backend) as client: response = await client.options(url) @@ -80,7 +88,8 @@ async def test_options(backend, server): assert response.text == "Hello, world!" -async def test_head(backend, server): +@pytest.mark.usefixtures("server") +async def test_head(backend): url = "http://127.0.0.1:8000/" async with httpx.AsyncClient(backend=backend) as client: response = await client.head(url) @@ -88,21 +97,24 @@ async def test_head(backend, server): assert response.text == "" -async def test_put(backend, server): +@pytest.mark.usefixtures("server") +async def test_put(backend): url = "http://127.0.0.1:8000/" async with httpx.AsyncClient(backend=backend) as client: response = await client.put(url, data=b"Hello, world!") assert response.status_code == 200 -async def test_patch(backend, server): +@pytest.mark.usefixtures("server") +async def test_patch(backend): url = "http://127.0.0.1:8000/" async with httpx.AsyncClient(backend=backend) as client: response = await client.patch(url, data=b"Hello, world!") assert response.status_code == 200 -async def test_delete(backend, server): +@pytest.mark.usefixtures("server") +async def test_delete(backend): url = "http://127.0.0.1:8000/" async with httpx.AsyncClient(backend=backend) as client: response = await client.delete(url) @@ -110,7 +122,8 @@ async def test_delete(backend, server): assert response.text == "Hello, world!" -async def test_100_continue(backend, server): +@pytest.mark.usefixtures("server") +async def test_100_continue(backend): url = "http://127.0.0.1:8000/echo_body" headers = {"Expect": "100-continue"} data = b"Echo request body" diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 0c7c49e608..9334a1ea14 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -19,7 +19,8 @@ async def wrapped(backend, *args, **kwargs): @threadpool -def test_get(backend, server): +@pytest.mark.usefixtures("server") +def test_get(backend): url = "http://127.0.0.1:8000/" with httpx.Client(backend=backend) as http: response = http.get(url) @@ -36,7 +37,8 @@ def test_get(backend, server): @threadpool -def test_post(backend, server): +@pytest.mark.usefixtures("server") +def test_post(backend): with httpx.Client(backend=backend) as http: response = http.post("http://127.0.0.1:8000/", data=b"Hello, world!") assert response.status_code == 200 @@ -44,7 +46,8 @@ def test_post(backend, server): @threadpool -def test_post_json(backend, server): +@pytest.mark.usefixtures("server") +def test_post_json(backend): with httpx.Client(backend=backend) as http: response = http.post("http://127.0.0.1:8000/", json={"text": "Hello, world!"}) assert response.status_code == 200 @@ -52,7 +55,8 @@ def test_post_json(backend, server): @threadpool -def test_stream_response(backend, server): +@pytest.mark.usefixtures("server") +def test_stream_response(backend): with httpx.Client(backend=backend) as http: response = http.get("http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 @@ -61,7 +65,8 @@ def test_stream_response(backend, server): @threadpool -def test_stream_iterator(backend, server): +@pytest.mark.usefixtures("server") +def test_stream_iterator(backend): with httpx.Client(backend=backend) as http: response = http.get("http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 @@ -72,7 +77,8 @@ def test_stream_iterator(backend, server): @threadpool -def test_raw_iterator(backend, server): +@pytest.mark.usefixtures("server") +def test_raw_iterator(backend): with httpx.Client(backend=backend) as http: response = http.get("http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 @@ -84,7 +90,8 @@ def test_raw_iterator(backend, server): @threadpool -def test_raise_for_status(backend, server): +@pytest.mark.usefixtures("server") +def test_raise_for_status(backend): with httpx.Client(backend=backend) as client: for status_code in (200, 400, 404, 500, 505): response = client.request( @@ -99,7 +106,8 @@ def test_raise_for_status(backend, server): @threadpool -def test_options(backend, server): +@pytest.mark.usefixtures("server") +def test_options(backend): with httpx.Client(backend=backend) as http: response = http.options("http://127.0.0.1:8000/") assert response.status_code == 200 @@ -107,7 +115,8 @@ def test_options(backend, server): @threadpool -def test_head(backend, server): +@pytest.mark.usefixtures("server") +def test_head(backend): with httpx.Client(backend=backend) as http: response = http.head("http://127.0.0.1:8000/") assert response.status_code == 200 @@ -115,7 +124,8 @@ def test_head(backend, server): @threadpool -def test_put(backend, server): +@pytest.mark.usefixtures("server") +def test_put(backend): with httpx.Client(backend=backend) as http: response = http.put("http://127.0.0.1:8000/", data=b"Hello, world!") assert response.status_code == 200 @@ -123,7 +133,8 @@ def test_put(backend, server): @threadpool -def test_patch(backend, server): +@pytest.mark.usefixtures("server") +def test_patch(backend): with httpx.Client(backend=backend) as http: response = http.patch("http://127.0.0.1:8000/", data=b"Hello, world!") assert response.status_code == 200 @@ -131,7 +142,8 @@ def test_patch(backend, server): @threadpool -def test_delete(backend, server): +@pytest.mark.usefixtures("server") +def test_delete(backend): with httpx.Client(backend=backend) as http: response = http.delete("http://127.0.0.1:8000/") assert response.status_code == 200 @@ -139,7 +151,8 @@ def test_delete(backend, server): @threadpool -def test_base_url(backend, server): +@pytest.mark.usefixtures("server") +def test_base_url(backend): base_url = "http://127.0.0.1:8000/" with httpx.Client(base_url=base_url, backend=backend) as http: response = http.get("/") From 803a31911a0e81ea2e5b30189ea5faaf0182179b Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 15 Aug 2019 21:49:52 +0200 Subject: [PATCH 5/9] Fix compatibility issues with <3.7 --- httpx/concurrency.py | 23 +++++++++++++++-------- httpx/interfaces.py | 18 ++++++++++-------- setup.py | 1 + test-requirements.txt | 8 +------- 4 files changed, 27 insertions(+), 23 deletions(-) diff --git a/httpx/concurrency.py b/httpx/concurrency.py index 41e0ae745c..eecaf9f09f 100644 --- a/httpx/concurrency.py +++ b/httpx/concurrency.py @@ -14,11 +14,15 @@ import typing from types import TracebackType -from async_generator import asynccontextmanager +try: + from contextlib import asynccontextmanager +except ImportError: # pragma: no cover + from async_generator import asynccontextmanager # type: ignore from .config import PoolLimits, TimeoutConfig from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .interfaces import ( + BaseAsyncContextManager, BaseBackgroundManager, BaseBodyIterator, BaseEvent, @@ -292,13 +296,16 @@ def start_soon(self, coroutine: typing.Callable, *args: typing.Any) -> None: loop = asyncio.get_event_loop() self.tasks.add(loop.create_task(coroutine(*args))) - @asynccontextmanager # type: ignore - async def will_wait_for_first_completed(self) -> typing.AsyncContextManager: - initial_tasks = self.tasks - self.tasks = set() - yield - await asyncio.wait(self.tasks, return_when=asyncio.FIRST_COMPLETED) - self.tasks = initial_tasks.union(self.tasks) + def will_wait_for_first_completed(self) -> BaseAsyncContextManager: + @asynccontextmanager # type: ignore + async def context() -> None: + initial_tasks = self.tasks + self.tasks = set() + yield + await asyncio.wait(self.tasks, return_when=asyncio.FIRST_COMPLETED) + self.tasks = initial_tasks.union(self.tasks) + + return context() # type: ignore async def __aexit__( self, diff --git a/httpx/interfaces.py b/httpx/interfaces.py index ca82a62b62..83b8d95aa4 100644 --- a/httpx/interfaces.py +++ b/httpx/interfaces.py @@ -265,14 +265,8 @@ async def done(self) -> None: raise NotImplementedError() # pragma: no cover -class BaseBackgroundManager: - def start_soon(self, coroutine: typing.Callable, *args: typing.Any) -> None: - raise NotImplementedError() # pragma: no cover - - def will_wait_for_first_completed(self) -> typing.AsyncContextManager: - raise NotImplementedError() # pragma: no cover - - async def __aenter__(self) -> "BaseBackgroundManager": +class BaseAsyncContextManager: + async def __aenter__(self: typing.T) -> typing.T: raise NotImplementedError() # pragma: no cover async def __aexit__( @@ -283,5 +277,13 @@ async def __aexit__( ) -> None: raise NotImplementedError() # pragma: no cover + +class BaseBackgroundManager(BaseAsyncContextManager): + def start_soon(self, coroutine: typing.Callable, *args: typing.Any) -> None: + raise NotImplementedError() # pragma: no cover + + def will_wait_for_first_completed(self) -> BaseAsyncContextManager: + raise NotImplementedError() # pragma: no cover + async def close(self) -> None: await self.__aexit__(None, None, None) diff --git a/setup.py b/setup.py index dee79a46e2..11939f7232 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ def get_packages(package): "hstspreload", "idna==2.*", "rfc3986==1.*", + "async_generator==1.*;python_version<'3.7'" ], classifiers=[ "Development Status :: 3 - Alpha", diff --git a/test-requirements.txt b/test-requirements.txt index 089189b389..d799f50399 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,10 +1,4 @@ -certifi -chardet==3.* -h11==0.8.* -h2==3.* -hstspreload -idna==2.* -rfc3986==1.* +-e . # Optional brotlipy==0.7.* From 369761125c823429d92926b63dc87ab5a91a9005 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 15 Aug 2019 21:58:31 +0200 Subject: [PATCH 6/9] Fix lint --- setup.py | 2 +- tests/conftest.py | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 11939f7232..083e268cf2 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,7 @@ def get_packages(package): "hstspreload", "idna==2.*", "rfc3986==1.*", - "async_generator==1.*;python_version<'3.7'" + "async_generator==1.*;python_version<'3.7'", ], classifiers=[ "Development Status :: 3 - Alpha", diff --git a/tests/conftest.py b/tests/conftest.py index fde7c4915f..e89e4f0b7c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,11 +13,7 @@ from httpx.concurrency import AsyncioBackend -@pytest.fixture( - params=[ - pytest.param(AsyncioBackend, marks=pytest.mark.asyncio) # type: ignore - ] -) +@pytest.fixture(params=[pytest.param(AsyncioBackend, marks=pytest.mark.asyncio)]) def backend(request): backend_cls = request.param return backend_cls() From 387c5a3dee6b86a9a884f1be0fdcbccae83de710 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 15 Aug 2019 23:18:12 +0200 Subject: [PATCH 7/9] Remove dependency on async_generator --- httpx/concurrency.py | 29 +++++++++++++++-------------- setup.py | 1 - 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/httpx/concurrency.py b/httpx/concurrency.py index eecaf9f09f..f4551ce8c8 100644 --- a/httpx/concurrency.py +++ b/httpx/concurrency.py @@ -14,11 +14,6 @@ import typing from types import TracebackType -try: - from contextlib import asynccontextmanager -except ImportError: # pragma: no cover - from async_generator import asynccontextmanager # type: ignore - from .config import PoolLimits, TimeoutConfig from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout from .interfaces import ( @@ -297,15 +292,7 @@ def start_soon(self, coroutine: typing.Callable, *args: typing.Any) -> None: self.tasks.add(loop.create_task(coroutine(*args))) def will_wait_for_first_completed(self) -> BaseAsyncContextManager: - @asynccontextmanager # type: ignore - async def context() -> None: - initial_tasks = self.tasks - self.tasks = set() - yield - await asyncio.wait(self.tasks, return_when=asyncio.FIRST_COMPLETED) - self.tasks = initial_tasks.union(self.tasks) - - return context() # type: ignore + return WillWaitForFirstCompleted(self) async def __aexit__( self, @@ -317,3 +304,17 @@ async def __aexit__( await task if exc_type is None: task.result() + + +class WillWaitForFirstCompleted(BaseAsyncContextManager): + def __init__(self, background: BackgroundManager): + self.background = background + self.initial_tasks: typing.Set[asyncio.Task] = set() + + async def __aenter__(self) -> None: + self.initial_tasks = self.background.tasks + self.background.tasks = set() + + async def __aexit__(self, *args: typing.Any) -> None: + await asyncio.wait(self.background.tasks, return_when=asyncio.FIRST_COMPLETED) + self.background.tasks = self.initial_tasks.union(self.background.tasks) diff --git a/setup.py b/setup.py index 083e268cf2..dee79a46e2 100644 --- a/setup.py +++ b/setup.py @@ -51,7 +51,6 @@ def get_packages(package): "hstspreload", "idna==2.*", "rfc3986==1.*", - "async_generator==1.*;python_version<'3.7'", ], classifiers=[ "Development Status :: 3 - Alpha", From 8fbe15932a1bf23a86ef39c8742da5bdfa707e15 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Thu, 15 Aug 2019 00:21:46 +0200 Subject: [PATCH 8/9] Trio concurency backend PoC --- httpx/contrib/__init__.py | 0 httpx/contrib/trio.py | 259 ++++++++++++++++++++++++++++++ test-requirements.txt | 2 + tests/client/test_async_client.py | 31 ++++ tests/conftest.py | 14 +- 5 files changed, 305 insertions(+), 1 deletion(-) create mode 100644 httpx/contrib/__init__.py create mode 100644 httpx/contrib/trio.py diff --git a/httpx/contrib/__init__.py b/httpx/contrib/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/httpx/contrib/trio.py b/httpx/contrib/trio.py new file mode 100644 index 0000000000..37b1b75608 --- /dev/null +++ b/httpx/contrib/trio.py @@ -0,0 +1,259 @@ +import functools +import ssl +import typing +from types import TracebackType + +import trio +import trio.abc + +from httpx.config import PoolLimits, TimeoutConfig +from httpx.exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout +from httpx.interfaces import ( + BaseAsyncContextManager, + BaseBackgroundManager, + BaseBodyIterator, + BaseEvent, + BasePoolSemaphore, + BaseReader, + BaseWriter, + ConcurrencyBackend, + Protocol, +) +from httpx.concurrency import TimeoutFlag + + +class Reader(BaseReader): + def __init__( + self, receive_stream: trio.abc.ReceiveStream, timeout: TimeoutConfig + ) -> None: + self.receive_stream = receive_stream + self.timeout = timeout + self.is_eof = False + + async def read( + self, n: int, timeout: TimeoutConfig = None, flag: TimeoutFlag = None + ) -> bytes: + if timeout is None: + timeout = self.timeout + + while True: + # Check our flag at the first possible moment, and use a fine + # grained retry loop if we're not yet in read-timeout mode. + should_raise = flag is None or flag.raise_on_read_timeout + read_timeout = timeout.read_timeout if should_raise else 0.01 + with trio.move_on_after(read_timeout) as cancel_scope: + data = await self.receive_stream.receive_some(max_bytes=n) + if cancel_scope.cancelled_caught: + if should_raise: + raise ReadTimeout() from None + else: + if data == b"": + self.is_eof = True + return data + + return data + + def is_connection_dropped(self) -> bool: + return self.is_eof + + +class Writer(BaseWriter): + def __init__(self, send_stream: trio.abc.SendStream, timeout: TimeoutConfig): + self.send_stream = send_stream + self.timeout = timeout + + def write_no_block(self, data: bytes) -> None: + self.send_stream.send_all(data) # pragma: nocover + + async def write( + self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None + ) -> None: + if not data: + return + + if timeout is None: + timeout = self.timeout + + while True: + with trio.move_on_after(timeout.write_timeout) as cancel_scope: + await self.send_stream.wait_send_all_might_not_block() + await self.send_stream.send_all(data) + break + if cancel_scope.cancelled_caught: + # We check our flag at the possible moment, in order to + # allow us to suppress write timeouts, if we've since + # switched over to read-timeout mode. + should_raise = flag is None or flag.raise_on_write_timeout + if should_raise: + raise WriteTimeout() from None + + async def close(self) -> None: + await self.send_stream.aclose() + + +class PoolSemaphore(BasePoolSemaphore): + def __init__(self, pool_limits: PoolLimits): + self.pool_limits = pool_limits + + @property + def semaphore(self) -> typing.Optional[trio.Semaphore]: + if not hasattr(self, "_semaphore"): + max_connections = self.pool_limits.hard_limit + if max_connections is None: + self._semaphore = None + else: + self._semaphore = trio.Semaphore( + initial_value=1, max_value=max_connections + ) + return self._semaphore + + async def acquire(self) -> None: + if self.semaphore is None: + return + + timeout = self.pool_limits.pool_timeout + with trio.move_on_after(timeout) as cancel_scope: + await self.semaphore.acquire() + if cancel_scope.cancelled_caught: + raise PoolTimeout() + + def release(self) -> None: + if self.semaphore is None: + return + + self.semaphore.release() + + +class TrioBackend(ConcurrencyBackend): + async def connect( + self, + hostname: str, + port: int, + ssl_context: typing.Optional[ssl.SSLContext], + timeout: TimeoutConfig, + ) -> typing.Tuple[BaseReader, BaseWriter, Protocol]: + with trio.move_on_after(timeout.connect_timeout) as cancel_scope: + if ssl_context is None: + stream = await trio.open_tcp_stream(hostname, port) + else: + stream = await trio.open_ssl_over_tcp_stream( + hostname, port, ssl_context=ssl_context + ) + await stream.do_handshake() + if cancel_scope.cancelled_caught: + raise ConnectTimeout() + + if ssl_context is None: + ident = "http/1.1" # TODO + else: + ident = stream.selected_alpn_protocol() + if ident is None: + ident = stream.selected_npn_protocol() + + reader = Reader(receive_stream=stream, timeout=timeout) + writer = Writer(send_stream=stream, timeout=timeout) + protocol = Protocol.HTTP_2 if ident == "h2" else Protocol.HTTP_11 + + return reader, writer, protocol + + def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: + return PoolSemaphore(limits) + + async def run_in_threadpool( + self, func: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + if kwargs: + # trio.to_thread.run_async doesn't accept 'kwargs', so bind them in here + func = functools.partial(func, **kwargs) + return await trio.to_thread.run_sync(func, *args) + + def run( + self, coroutine: typing.Callable, *args: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + if kwargs: + coroutine = functools.partial(coroutine, **kwargs) + return trio.run(coroutine, *args) + + async def sleep(self, seconds: float) -> None: + await trio.sleep(seconds) + + def create_event(self) -> BaseEvent: + return trio.Event() # type: ignore + + def background_manager(self) -> "BackgroundManager": + return BackgroundManager() + + def body_iterator(self) -> "BodyIterator": + return BodyIterator() + + +class BackgroundManager(BaseBackgroundManager): + nursery: trio.Nursery + + def __init__(self) -> None: + self.nursery_manager = trio.open_nursery() + self.convert = lambda coroutine: coroutine + + def start_soon(self, coroutine: typing.Callable, *args: typing.Any) -> None: + self.nursery.start_soon(self.convert(coroutine), *args) + + def will_wait_for_first_completed(self) -> BaseAsyncContextManager: + return WillWaitForFirstCompleted(self) + + async def __aenter__(self) -> "BackgroundManager": + self.nursery = await self.nursery_manager.__aenter__() + return self + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.nursery_manager.__aexit__(exc_type, exc_value, traceback) + + +class BodyIterator(BaseBodyIterator): + def __init__(self) -> None: + self.send_channel, self.receive_channel = trio.open_memory_channel() + + async def iterate(self) -> typing.AsyncIterator[bytes]: + async with self.receive_channel: + async for data in self.receive_channel: + assert isinstance(data, bytes) + yield data + + async def put(self, data: bytes) -> None: + await self.send_channel.send(data) + + async def done(self) -> None: + await self.send_channel.aclose() + + +class WillWaitForFirstCompleted(BaseAsyncContextManager): + nursery: trio.Nursery + + def __init__(self, background: BackgroundManager): + self.background = background + self.send_channel, self.receive_channel = trio.open_memory_channel(0) + self.initial_convert = self.background.convert + self.initial_nursery = self.background.nursery + self.nursery_manager = trio.open_nursery() + + def convert(self, coroutine: typing.Callable) -> typing.Callable: + async def wrapped(*args: typing.Any) -> None: + await self.send_channel.send(await coroutine(*args)) + + return wrapped + + async def __aenter__(self) -> None: + self.background.convert = self.convert + self.nursery = await self.nursery_manager.__aenter__() + self.background.nursery = self.nursery + + async def __aexit__(self, *args: typing.Any) -> None: + await self.receive_channel.receive() + self.nursery.cancel_scope.cancel() + await self.nursery.__aexit__(*args) + self.background.convert = self.initial_convert + self.background.nursery = self.initial_nursery diff --git a/test-requirements.txt b/test-requirements.txt index d799f50399..e3917e53ba 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -11,6 +11,8 @@ isort mypy pytest pytest-asyncio +trio +pytest-trio pytest-cov trustme uvicorn diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index ba6971ca84..2c45f0e42d 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -1,8 +1,27 @@ +import functools + import pytest import httpx +def threadpool(func): + """ + Async tests should run in a separate thread to the uvicorn server to prevent event + loop clashes (e.g. asyncio for uvicorn, trio for tests). + """ + + @functools.wraps(func) + async def wrapped(backend, *args, **kwargs): + backend_for_thread = type(backend)() + await backend.run_in_threadpool( + backend_for_thread.run, func, backend_for_thread, *args, **kwargs + ) + + return wrapped + + +@threadpool @pytest.mark.usefixtures("server") async def test_get(backend): url = "http://127.0.0.1:8000/" @@ -15,6 +34,7 @@ async def test_get(backend): assert repr(response) == "" +@threadpool @pytest.mark.usefixtures("server") async def test_post(backend): url = "http://127.0.0.1:8000/" @@ -23,6 +43,7 @@ async def test_post(backend): assert response.status_code == 200 +@threadpool @pytest.mark.usefixtures("server") async def test_post_json(backend): url = "http://127.0.0.1:8000/" @@ -31,6 +52,7 @@ async def test_post_json(backend): assert response.status_code == 200 +@threadpool @pytest.mark.usefixtures("server") async def test_stream_response(backend): async with httpx.AsyncClient(backend=backend) as client: @@ -41,6 +63,7 @@ async def test_stream_response(backend): assert response.content == b"Hello, world!" +@threadpool @pytest.mark.usefixtures("server") async def test_access_content_stream_response(backend): async with httpx.AsyncClient(backend=backend) as client: @@ -50,6 +73,7 @@ async def test_access_content_stream_response(backend): response.content +@threadpool @pytest.mark.usefixtures("server") async def test_stream_request(backend): async def hello_world(): @@ -63,6 +87,7 @@ async def hello_world(): assert response.status_code == 200 +@threadpool @pytest.mark.usefixtures("server") async def test_raise_for_status(backend): async with httpx.AsyncClient(backend=backend) as client: @@ -79,6 +104,7 @@ async def test_raise_for_status(backend): assert response.raise_for_status() is None +@threadpool @pytest.mark.usefixtures("server") async def test_options(backend): url = "http://127.0.0.1:8000/" @@ -88,6 +114,7 @@ async def test_options(backend): assert response.text == "Hello, world!" +@threadpool @pytest.mark.usefixtures("server") async def test_head(backend): url = "http://127.0.0.1:8000/" @@ -97,6 +124,7 @@ async def test_head(backend): assert response.text == "" +@threadpool @pytest.mark.usefixtures("server") async def test_put(backend): url = "http://127.0.0.1:8000/" @@ -105,6 +133,7 @@ async def test_put(backend): assert response.status_code == 200 +@threadpool @pytest.mark.usefixtures("server") async def test_patch(backend): url = "http://127.0.0.1:8000/" @@ -113,6 +142,7 @@ async def test_patch(backend): assert response.status_code == 200 +@threadpool @pytest.mark.usefixtures("server") async def test_delete(backend): url = "http://127.0.0.1:8000/" @@ -122,6 +152,7 @@ async def test_delete(backend): assert response.text == "Hello, world!" +@threadpool @pytest.mark.usefixtures("server") async def test_100_continue(backend): url = "http://127.0.0.1:8000/echo_body" diff --git a/tests/conftest.py b/tests/conftest.py index e89e4f0b7c..13d2e63d01 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,10 +12,22 @@ from httpx.concurrency import AsyncioBackend +try: + from httpx.contrib.trio import TrioBackend +except ImportError: + TrioBackend = None # type: ignore -@pytest.fixture(params=[pytest.param(AsyncioBackend, marks=pytest.mark.asyncio)]) + +@pytest.fixture( + params=[ + pytest.param(AsyncioBackend, marks=pytest.mark.asyncio), + pytest.param(TrioBackend, marks=pytest.mark.trio), + ] +) def backend(request): backend_cls = request.param + if backend_cls is None: + pytest.skip() return backend_cls() From 8fc2648923f91b14fa83433ccea1f3baac1658e0 Mon Sep 17 00:00:00 2001 From: florimondmanca Date: Fri, 16 Aug 2019 22:59:36 +0200 Subject: [PATCH 9/9] Resolve event loop clashes in async client tests --- tests/client/test_async_client.py | 31 ---------------------- tests/client/test_client.py | 28 -------------------- tests/conftest.py | 44 +++++++++++++++++++++++++++++-- 3 files changed, 42 insertions(+), 61 deletions(-) diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index 2c45f0e42d..ba6971ca84 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -1,27 +1,8 @@ -import functools - import pytest import httpx -def threadpool(func): - """ - Async tests should run in a separate thread to the uvicorn server to prevent event - loop clashes (e.g. asyncio for uvicorn, trio for tests). - """ - - @functools.wraps(func) - async def wrapped(backend, *args, **kwargs): - backend_for_thread = type(backend)() - await backend.run_in_threadpool( - backend_for_thread.run, func, backend_for_thread, *args, **kwargs - ) - - return wrapped - - -@threadpool @pytest.mark.usefixtures("server") async def test_get(backend): url = "http://127.0.0.1:8000/" @@ -34,7 +15,6 @@ async def test_get(backend): assert repr(response) == "" -@threadpool @pytest.mark.usefixtures("server") async def test_post(backend): url = "http://127.0.0.1:8000/" @@ -43,7 +23,6 @@ async def test_post(backend): assert response.status_code == 200 -@threadpool @pytest.mark.usefixtures("server") async def test_post_json(backend): url = "http://127.0.0.1:8000/" @@ -52,7 +31,6 @@ async def test_post_json(backend): assert response.status_code == 200 -@threadpool @pytest.mark.usefixtures("server") async def test_stream_response(backend): async with httpx.AsyncClient(backend=backend) as client: @@ -63,7 +41,6 @@ async def test_stream_response(backend): assert response.content == b"Hello, world!" -@threadpool @pytest.mark.usefixtures("server") async def test_access_content_stream_response(backend): async with httpx.AsyncClient(backend=backend) as client: @@ -73,7 +50,6 @@ async def test_access_content_stream_response(backend): response.content -@threadpool @pytest.mark.usefixtures("server") async def test_stream_request(backend): async def hello_world(): @@ -87,7 +63,6 @@ async def hello_world(): assert response.status_code == 200 -@threadpool @pytest.mark.usefixtures("server") async def test_raise_for_status(backend): async with httpx.AsyncClient(backend=backend) as client: @@ -104,7 +79,6 @@ async def test_raise_for_status(backend): assert response.raise_for_status() is None -@threadpool @pytest.mark.usefixtures("server") async def test_options(backend): url = "http://127.0.0.1:8000/" @@ -114,7 +88,6 @@ async def test_options(backend): assert response.text == "Hello, world!" -@threadpool @pytest.mark.usefixtures("server") async def test_head(backend): url = "http://127.0.0.1:8000/" @@ -124,7 +97,6 @@ async def test_head(backend): assert response.text == "" -@threadpool @pytest.mark.usefixtures("server") async def test_put(backend): url = "http://127.0.0.1:8000/" @@ -133,7 +105,6 @@ async def test_put(backend): assert response.status_code == 200 -@threadpool @pytest.mark.usefixtures("server") async def test_patch(backend): url = "http://127.0.0.1:8000/" @@ -142,7 +113,6 @@ async def test_patch(backend): assert response.status_code == 200 -@threadpool @pytest.mark.usefixtures("server") async def test_delete(backend): url = "http://127.0.0.1:8000/" @@ -152,7 +122,6 @@ async def test_delete(backend): assert response.text == "Hello, world!" -@threadpool @pytest.mark.usefixtures("server") async def test_100_continue(backend): url = "http://127.0.0.1:8000/echo_body" diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 9334a1ea14..8d09ad5a3c 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -1,24 +1,8 @@ -import functools - import pytest import httpx -def threadpool(func): - """ - Our sync tests should run in seperate thread to the uvicorn server. - """ - - @functools.wraps(func) - async def wrapped(backend, *args, **kwargs): - backend_for_thread = type(backend)() - await backend.run_in_threadpool(func, backend_for_thread, *args, **kwargs) - - return wrapped - - -@threadpool @pytest.mark.usefixtures("server") def test_get(backend): url = "http://127.0.0.1:8000/" @@ -36,7 +20,6 @@ def test_get(backend): assert repr(response) == "" -@threadpool @pytest.mark.usefixtures("server") def test_post(backend): with httpx.Client(backend=backend) as http: @@ -45,7 +28,6 @@ def test_post(backend): assert response.reason_phrase == "OK" -@threadpool @pytest.mark.usefixtures("server") def test_post_json(backend): with httpx.Client(backend=backend) as http: @@ -54,7 +36,6 @@ def test_post_json(backend): assert response.reason_phrase == "OK" -@threadpool @pytest.mark.usefixtures("server") def test_stream_response(backend): with httpx.Client(backend=backend) as http: @@ -64,7 +45,6 @@ def test_stream_response(backend): assert content == b"Hello, world!" -@threadpool @pytest.mark.usefixtures("server") def test_stream_iterator(backend): with httpx.Client(backend=backend) as http: @@ -76,7 +56,6 @@ def test_stream_iterator(backend): assert body == b"Hello, world!" -@threadpool @pytest.mark.usefixtures("server") def test_raw_iterator(backend): with httpx.Client(backend=backend) as http: @@ -89,7 +68,6 @@ def test_raw_iterator(backend): response.close() # TODO: should Response be available as context managers? -@threadpool @pytest.mark.usefixtures("server") def test_raise_for_status(backend): with httpx.Client(backend=backend) as client: @@ -105,7 +83,6 @@ def test_raise_for_status(backend): assert response.raise_for_status() is None -@threadpool @pytest.mark.usefixtures("server") def test_options(backend): with httpx.Client(backend=backend) as http: @@ -114,7 +91,6 @@ def test_options(backend): assert response.reason_phrase == "OK" -@threadpool @pytest.mark.usefixtures("server") def test_head(backend): with httpx.Client(backend=backend) as http: @@ -123,7 +99,6 @@ def test_head(backend): assert response.reason_phrase == "OK" -@threadpool @pytest.mark.usefixtures("server") def test_put(backend): with httpx.Client(backend=backend) as http: @@ -132,7 +107,6 @@ def test_put(backend): assert response.reason_phrase == "OK" -@threadpool @pytest.mark.usefixtures("server") def test_patch(backend): with httpx.Client(backend=backend) as http: @@ -141,7 +115,6 @@ def test_patch(backend): assert response.reason_phrase == "OK" -@threadpool @pytest.mark.usefixtures("server") def test_delete(backend): with httpx.Client(backend=backend) as http: @@ -150,7 +123,6 @@ def test_delete(backend): assert response.reason_phrase == "OK" -@threadpool @pytest.mark.usefixtures("server") def test_base_url(backend): base_url = "http://127.0.0.1:8000/" diff --git a/tests/conftest.py b/tests/conftest.py index 13d2e63d01..d0fdbd3fa0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,7 @@ import asyncio +import functools +import inspect +import threading import pytest import trustme @@ -18,10 +21,15 @@ TrioBackend = None # type: ignore +# All backends should cause tests to be marked (and run under) asyncio, +# because that is the only I/O implementation uvicorn can run on. +MARK_ASYNC = pytest.mark.asyncio + + @pytest.fixture( params=[ - pytest.param(AsyncioBackend, marks=pytest.mark.asyncio), - pytest.param(TrioBackend, marks=pytest.mark.trio), + pytest.param(AsyncioBackend, marks=MARK_ASYNC), + pytest.param(TrioBackend, marks=MARK_ASYNC), ] ) def backend(request): @@ -31,6 +39,38 @@ def backend(request): return backend_cls() +@pytest.hookimpl(tryfirst=True, hookwrapper=True) +def pytest_pyfunc_call(pyfuncitem): + """ + Run test functions parametrized by the concurrency `backend` in the asyncio + threadpool instead of a normal function call. + + We do this to prevent the backend-specific event loop from clashing with asyncio. + """ + if "backend" in pyfuncitem.fixturenames: + func = pyfuncitem.obj + + if inspect.iscoroutinefunction(func): + + @functools.wraps(func) + async def wrapped(backend, *args, **kwargs): + asyncio_backend = AsyncioBackend() + await asyncio_backend.run_in_threadpool( + backend.run, func, backend, *args, **kwargs + ) + + else: + + @functools.wraps(func) + async def wrapped(backend, *args, **kwargs): + asyncio_backend = AsyncioBackend() + await asyncio_backend.run_in_threadpool(func, backend, *args, **kwargs) + + pyfuncitem.obj = wrapped + + yield + + async def app(scope, receive, send): assert scope["type"] == "http" if scope["path"] == "/slow_response":