diff --git a/httpx/__init__.py b/httpx/__init__.py index dec5cff19c..c2925d5c81 100644 --- a/httpx/__init__.py +++ b/httpx/__init__.py @@ -1,7 +1,7 @@ from .__version__ import __description__, __title__, __version__ from .api import delete, get, head, options, patch, post, put, request from .client import AsyncClient, Client -from .concurrency import AsyncioBackend +from .concurrency.asyncio import AsyncioBackend from .config import ( USER_AGENT, CertTypes, diff --git a/httpx/client.py b/httpx/client.py index fd86fb0c83..fcb8937048 100644 --- a/httpx/client.py +++ b/httpx/client.py @@ -5,7 +5,7 @@ import hstspreload from .auth import HTTPBasicAuth -from .concurrency import AsyncioBackend +from .concurrency.asyncio import AsyncioBackend from .config import ( DEFAULT_MAX_REDIRECTS, DEFAULT_POOL_LIMITS, @@ -77,7 +77,7 @@ def __init__( if param_count == 2: dispatch = WSGIDispatch(app=app) else: - dispatch = ASGIDispatch(app=app) + dispatch = ASGIDispatch(app=app, backend=backend) if dispatch is None: async_dispatch: AsyncDispatcher = ConnectionPool( diff --git a/httpx/concurrency/__init__.py b/httpx/concurrency/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/httpx/concurrency.py b/httpx/concurrency/asyncio.py similarity index 75% rename from httpx/concurrency.py rename to httpx/concurrency/asyncio.py index f1bf585448..2d7c4fd88c 100644 --- a/httpx/concurrency.py +++ b/httpx/concurrency/asyncio.py @@ -14,16 +14,20 @@ import typing from types import TracebackType -from .config import PoolLimits, TimeoutConfig -from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout -from .interfaces import ( +from ..config import PoolLimits, TimeoutConfig +from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout +from ..interfaces import ( + BaseAsyncContextManager, BaseBackgroundManager, + BaseBodyIterator, + BaseEvent, BasePoolSemaphore, BaseReader, BaseWriter, ConcurrencyBackend, Protocol, ) +from .utils import TimeoutFlag SSL_MONKEY_PATCH_APPLIED = False @@ -49,38 +53,6 @@ def _fixed_write(self, data: bytes) -> None: # type: ignore MonkeyPatch.write = _fixed_write -class TimeoutFlag: - """ - A timeout flag holds a state of either read-timeout or write-timeout mode. - - We use this so that we can attempt both reads and writes concurrently, while - only enforcing timeouts in one direction. - - During a request/response cycle we start in write-timeout mode. - - Once we've sent a request fully, or once we start seeing a response, - then we switch to read-timeout mode instead. - """ - - def __init__(self) -> None: - self.raise_on_read_timeout = False - self.raise_on_write_timeout = True - - def set_read_timeouts(self) -> None: - """ - Set the flag to read-timeout mode. - """ - self.raise_on_read_timeout = True - self.raise_on_write_timeout = False - - def set_write_timeouts(self) -> None: - """ - Set the flag to write-timeout mode. - """ - self.raise_on_read_timeout = False - self.raise_on_write_timeout = True - - class Reader(BaseReader): def __init__( self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig @@ -247,21 +219,26 @@ def run( def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore: return PoolSemaphore(limits) - def background_manager( - self, coroutine: typing.Callable, args: typing.Any - ) -> "BackgroundManager": - return BackgroundManager(coroutine, args) + def create_event(self) -> BaseEvent: + return typing.cast(BaseEvent, asyncio.Event()) + + def background_manager(self) -> "BackgroundManager": + return BackgroundManager() + + def body_iterator(self) -> "BodyIterator": + return BodyIterator() class BackgroundManager(BaseBackgroundManager): - def __init__(self, coroutine: typing.Callable, args: typing.Any) -> None: - self.coroutine = coroutine - self.args = args + def __init__(self) -> None: + self.tasks: typing.Set[asyncio.Task] = set() - async def __aenter__(self) -> "BackgroundManager": + def start_soon(self, coroutine: typing.Callable, *args: typing.Any) -> None: loop = asyncio.get_event_loop() - self.task = loop.create_task(self.coroutine(*self.args)) - return self + self.tasks.add(loop.create_task(coroutine(*args))) + + def wait_first_completed(self) -> "WaitFirstCompleted": + return WaitFirstCompleted(self) async def __aexit__( self, @@ -269,6 +246,53 @@ async def __aexit__( exc_value: BaseException = None, traceback: TracebackType = None, ) -> None: - await self.task - if exc_type is None: - self.task.result() + done, pending = await asyncio.wait(self.tasks, timeout=1e-3) + + for task in pending: + task.cancel() + + for task in done: + await task + if exc_type is None: + task.result() + + +class WaitFirstCompleted(BaseAsyncContextManager): + def __init__(self, background: BackgroundManager): + self.background = background + self.initial_tasks: typing.Set[asyncio.Task] = set() + + async def __aenter__(self) -> "WaitFirstCompleted": + self.initial_tasks = self.background.tasks + self.background.tasks = set() + return self + + async def __aexit__(self, *args: typing.Any) -> None: + _, pending = await asyncio.wait( + self.background.tasks, return_when=asyncio.FIRST_COMPLETED + ) + self.background.tasks = self.initial_tasks | typing.cast( + typing.Set[asyncio.Task], pending + ) + + +class BodyIterator(BaseBodyIterator): + def __init__(self) -> None: + self._queue: asyncio.Queue[typing.Union[bytes, object]] = asyncio.Queue( + maxsize=1 + ) + self._done = object() + + async def iterate(self) -> typing.AsyncIterator[bytes]: + while True: + data = await self._queue.get() + if data is self._done: + break + assert isinstance(data, bytes) + yield data + + async def put(self, data: bytes) -> None: + await self._queue.put(data) + + async def done(self) -> None: + await self._queue.put(self._done) diff --git a/httpx/concurrency/utils.py b/httpx/concurrency/utils.py new file mode 100644 index 0000000000..a10e9d2827 --- /dev/null +++ b/httpx/concurrency/utils.py @@ -0,0 +1,30 @@ +class TimeoutFlag: + """ + A timeout flag holds a state of either read-timeout or write-timeout mode. + + We use this so that we can attempt both reads and writes concurrently, while + only enforcing timeouts in one direction. + + During a request/response cycle we start in write-timeout mode. + + Once we've sent a request fully, or once we start seeing a response, + then we switch to read-timeout mode instead. + """ + + def __init__(self) -> None: + self.raise_on_read_timeout = False + self.raise_on_write_timeout = True + + def set_read_timeouts(self) -> None: + """ + Set the flag to read-timeout mode. + """ + self.raise_on_read_timeout = True + self.raise_on_write_timeout = False + + def set_write_timeouts(self) -> None: + """ + Set the flag to write-timeout mode. + """ + self.raise_on_read_timeout = False + self.raise_on_write_timeout = True diff --git a/httpx/dispatch/asgi.py b/httpx/dispatch/asgi.py index 23eebb0fc9..64a0344685 100644 --- a/httpx/dispatch/asgi.py +++ b/httpx/dispatch/asgi.py @@ -1,8 +1,8 @@ -import asyncio import typing +from ..concurrency.asyncio import AsyncioBackend from ..config import CertTypes, TimeoutTypes, VerifyTypes -from ..interfaces import AsyncDispatcher +from ..interfaces import AsyncDispatcher, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse @@ -44,11 +44,13 @@ class ASGIDispatch(AsyncDispatcher): def __init__( self, app: typing.Callable, + backend: ConcurrencyBackend = None, raise_app_exceptions: bool = True, root_path: str = "", client: typing.Tuple[str, int] = ("127.0.0.1", 123), ) -> None: self.app = app + self.backend = AsyncioBackend() if backend is None else backend self.raise_app_exceptions = raise_app_exceptions self.root_path = root_path self.client = client @@ -78,8 +80,8 @@ async def send( app_exc = None status_code = None headers = None - response_started = asyncio.Event() - response_body = BodyIterator() + response_started = self.backend.create_event() + response_body = self.backend.body_iterator() request_stream = request.stream() async def receive() -> dict: @@ -115,20 +117,15 @@ async def run_app() -> None: finally: await response_body.done() - # Really we'd like to push all `asyncio` logic into concurrency.py, - # with a standardized interface, so that we can support other event - # loop implementations, such as Trio and Curio. - # That's a bit fiddly here, so we're not yet supporting using a custom - # `ConcurrencyBackend` with the `Client(app=asgi_app)` case. - loop = asyncio.get_event_loop() - app_task = loop.create_task(run_app()) - response_task = loop.create_task(response_started.wait()) + background = self.backend.background_manager() + await background.__aenter__() - tasks = {app_task, response_task} # type: typing.Set[asyncio.Task] - - await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + async with background.wait_first_completed(): + background.start_soon(run_app) + background.start_soon(response_started.wait) if app_exc is not None and self.raise_app_exceptions: + await background.close() raise app_exc assert response_started.is_set(), "application did not return a response." @@ -136,9 +133,9 @@ async def run_app() -> None: assert headers is not None async def on_close() -> None: - nonlocal app_task, response_body + nonlocal background, response_body await response_body.drain() - await app_task + await background.close() if app_exc is not None and self.raise_app_exceptions: raise app_exc @@ -150,47 +147,3 @@ async def on_close() -> None: on_close=on_close, request=request, ) - - -class BodyIterator: - """ - Provides a byte-iterator interface that the client can use to - ingest the response content from. - """ - - def __init__(self) -> None: - self._queue = asyncio.Queue( - maxsize=1 - ) # type: asyncio.Queue[typing.Union[bytes, object]] - self._done = object() - - async def iterate(self) -> typing.AsyncIterator[bytes]: - """ - A byte-iterator, used by the client to consume the response body. - """ - while True: - data = await self._queue.get() - if data is self._done: - break - assert isinstance(data, bytes) - yield data - - async def drain(self) -> None: - """ - Drain any remaining body, in order to allow any blocked `put()` calls - to complete. - """ - async for chunk in self.iterate(): - pass # pragma: no cover - - async def put(self, data: bytes) -> None: - """ - Used by the server to add data to the response body. - """ - await self._queue.put(data) - - async def done(self) -> None: - """ - Used by the server to signal the end of the response body. - """ - await self._queue.put(self._done) diff --git a/httpx/dispatch/connection.py b/httpx/dispatch/connection.py index b51fec688b..fa68966a5c 100644 --- a/httpx/dispatch/connection.py +++ b/httpx/dispatch/connection.py @@ -1,7 +1,7 @@ import functools import typing -from ..concurrency import AsyncioBackend +from ..concurrency.asyncio import AsyncioBackend from ..config import ( DEFAULT_TIMEOUT_CONFIG, CertTypes, @@ -32,7 +32,9 @@ def __init__( self.origin = Origin(origin) if isinstance(origin, str) else origin self.ssl = SSLConfig(cert=cert, verify=verify) self.timeout = TimeoutConfig(timeout) - self.backend = AsyncioBackend() if backend is None else backend + self.backend = typing.cast( + ConcurrencyBackend, AsyncioBackend() if backend is None else backend + ) self.release_func = release_func self.h11_connection = None # type: typing.Optional[HTTP11Connection] self.h2_connection = None # type: typing.Optional[HTTP2Connection] diff --git a/httpx/dispatch/connection_pool.py b/httpx/dispatch/connection_pool.py index 3090a9a514..563d709ffd 100644 --- a/httpx/dispatch/connection_pool.py +++ b/httpx/dispatch/connection_pool.py @@ -1,6 +1,6 @@ import typing -from ..concurrency import AsyncioBackend +from ..concurrency.asyncio import AsyncioBackend from ..config import ( DEFAULT_POOL_LIMITS, DEFAULT_TIMEOUT_CONFIG, @@ -91,7 +91,9 @@ def __init__( self.keepalive_connections = ConnectionStore() self.active_connections = ConnectionStore() - self.backend = AsyncioBackend() if backend is None else backend + self.backend = typing.cast( + ConcurrencyBackend, AsyncioBackend() if backend is None else backend + ) self.max_connections = self.backend.get_semaphore(pool_limits) @property diff --git a/httpx/dispatch/http11.py b/httpx/dispatch/http11.py index 623ac1df6b..97b6b81b7e 100644 --- a/httpx/dispatch/http11.py +++ b/httpx/dispatch/http11.py @@ -2,7 +2,7 @@ import h11 -from ..concurrency import TimeoutFlag +from ..concurrency.utils import TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse @@ -48,7 +48,8 @@ async def send( await self._send_request(request, timeout) task, args = self._send_request_data, [request.stream(), timeout] - async with self.backend.background_manager(task, args=args): + async with self.backend.background_manager() as background: + background.start_soon(task, *args) http_version, status_code, headers = await self._receive_response(timeout) content = self._receive_response_data(timeout) diff --git a/httpx/dispatch/http2.py b/httpx/dispatch/http2.py index 980b07b25c..c118b54a6e 100644 --- a/httpx/dispatch/http2.py +++ b/httpx/dispatch/http2.py @@ -4,7 +4,7 @@ import h2.connection import h2.events -from ..concurrency import TimeoutFlag +from ..concurrency.utils import TimeoutFlag from ..config import TimeoutConfig, TimeoutTypes from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend from ..models import AsyncRequest, AsyncResponse @@ -44,7 +44,8 @@ async def send( self.timeout_flags[stream_id] = TimeoutFlag() task, args = self.send_request_data, [stream_id, request.stream(), timeout] - async with self.backend.background_manager(task, args=args): + async with self.backend.background_manager() as background: + background.start_soon(task, *args) status_code, headers = await self.receive_response(stream_id, timeout) content = self.body_iter(stream_id, timeout) on_close = functools.partial(self.response_closed, stream_id=stream_id) diff --git a/httpx/interfaces.py b/httpx/interfaces.py index 2b4edf4d3c..f5919b2fc6 100644 --- a/httpx/interfaces.py +++ b/httpx/interfaces.py @@ -118,6 +118,17 @@ def __exit__( self.close() +class BaseEvent: + def set(self) -> None: + raise NotImplementedError() # pragma: no cover + + def is_set(self) -> bool: + raise NotImplementedError() # pragma: no cover + + async def wait(self) -> None: + raise NotImplementedError() # pragma: no cover + + class BaseReader: """ A stream reader. Abstracts away any asyncio-specific interfaces @@ -211,20 +222,71 @@ def iterate(self, async_iterator): # type: ignore except StopAsyncIteration: break - def background_manager( - self, coroutine: typing.Callable, args: typing.Any - ) -> "BaseBackgroundManager": + def create_event(self) -> BaseEvent: raise NotImplementedError() # pragma: no cover + def background_manager(self) -> "BaseBackgroundManager": + raise NotImplementedError() # pragma: no cover -class BaseBackgroundManager: - async def __aenter__(self) -> "BaseBackgroundManager": + def body_iterator(self) -> "BaseBodyIterator": raise NotImplementedError() # pragma: no cover + +class BaseAsyncContextManager: + async def __aenter__(self: typing.T) -> typing.T: + return self # pragma: no cover + async def __aexit__( self, exc_type: typing.Type[BaseException] = None, exc_value: BaseException = None, traceback: TracebackType = None, ) -> None: + pass # pragma: no cover + + +class BaseBackgroundManager(BaseAsyncContextManager): + def start_soon(self, coroutine: typing.Callable, *args: typing.Any) -> None: + raise NotImplementedError() # pragma: no cover + + def wait_first_completed(self) -> BaseAsyncContextManager: + """ + On exit, wait until at least one task started within the block has completed. + """ + raise NotImplementedError() # pragma: no cover + + async def close(self) -> None: + await self.__aexit__(None, None, None) + + +class BaseBodyIterator: + """ + Provides a byte-iterator interface that the client can use to + ingest the response content from. + """ + + def iterate(self) -> typing.AsyncIterator[bytes]: + """ + A byte-iterator, used by the client to consume the response body. + """ + raise NotImplementedError() # pragma: no cover + + async def drain(self) -> None: + """ + Drain any remaining body, in order to allow any blocked `put()` calls + to complete. + """ + async for chunk in self.iterate(): + pass # pragma: no cover + + async def put(self, data: bytes) -> None: + """ + Used by the server to add data to the response body. + """ + raise NotImplementedError() # pragma: no cover + + async def done(self) -> None: + """ + Used by the server to signal the end of the response body. + """ raise NotImplementedError() # pragma: no cover diff --git a/tests/client/test_async_client.py b/tests/client/test_async_client.py index b037085f25..3bcf64b6f2 100644 --- a/tests/client/test_async_client.py +++ b/tests/client/test_async_client.py @@ -3,10 +3,14 @@ import httpx -@pytest.mark.asyncio -async def test_get(server): +@pytest.fixture +def client(backend): + return httpx.AsyncClient(backend=backend) + + +async def test_get(server, client: httpx.AsyncClient): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with client: response = await client.get(url) assert response.status_code == 200 assert response.text == "Hello, world!" @@ -15,25 +19,22 @@ async def test_get(server): assert repr(response) == "" -@pytest.mark.asyncio -async def test_post(server): +async def test_post(server, client: httpx.AsyncClient): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with client: response = await client.post(url, data=b"Hello, world!") assert response.status_code == 200 -@pytest.mark.asyncio -async def test_post_json(server): +async def test_post_json(server, client: httpx.AsyncClient): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with client: response = await client.post(url, json={"text": "Hello, world!"}) assert response.status_code == 200 -@pytest.mark.asyncio -async def test_stream_response(server): - async with httpx.AsyncClient() as client: +async def test_stream_response(server, client: httpx.AsyncClient): + async with client: response = await client.request("GET", "http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 body = await response.read() @@ -41,31 +42,28 @@ async def test_stream_response(server): assert response.content == b"Hello, world!" -@pytest.mark.asyncio -async def test_access_content_stream_response(server): - async with httpx.AsyncClient() as client: +async def test_access_content_stream_response(server, client: httpx.AsyncClient): + async with client: response = await client.request("GET", "http://127.0.0.1:8000/", stream=True) assert response.status_code == 200 with pytest.raises(httpx.ResponseNotRead): response.content -@pytest.mark.asyncio -async def test_stream_request(server): +async def test_stream_request(server, client: httpx.AsyncClient): async def hello_world(): yield b"Hello, " yield b"world!" - async with httpx.AsyncClient() as client: + async with client: response = await client.request( "POST", "http://127.0.0.1:8000/", data=hello_world() ) assert response.status_code == 200 -@pytest.mark.asyncio -async def test_raise_for_status(server): - async with httpx.AsyncClient() as client: +async def test_raise_for_status(server, client: httpx.AsyncClient): + async with client: for status_code in (200, 400, 404, 500, 505): response = await client.request( "GET", f"http://127.0.0.1:8000/status/{status_code}" @@ -79,56 +77,50 @@ async def test_raise_for_status(server): assert response.raise_for_status() is None -@pytest.mark.asyncio -async def test_options(server): +async def test_options(server, client: httpx.AsyncClient): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with client: response = await client.options(url) assert response.status_code == 200 assert response.text == "Hello, world!" -@pytest.mark.asyncio -async def test_head(server): +async def test_head(server, client: httpx.AsyncClient): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with client: response = await client.head(url) assert response.status_code == 200 assert response.text == "" -@pytest.mark.asyncio -async def test_put(server): +async def test_put(server, client: httpx.AsyncClient): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with client: response = await client.put(url, data=b"Hello, world!") assert response.status_code == 200 -@pytest.mark.asyncio -async def test_patch(server): +async def test_patch(server, client: httpx.AsyncClient): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with client: response = await client.patch(url, data=b"Hello, world!") assert response.status_code == 200 -@pytest.mark.asyncio -async def test_delete(server): +async def test_delete(server, client: httpx.AsyncClient): url = "http://127.0.0.1:8000/" - async with httpx.AsyncClient() as client: + async with client: response = await client.delete(url) assert response.status_code == 200 assert response.text == "Hello, world!" -@pytest.mark.asyncio -async def test_100_continue(server): +async def test_100_continue(server, client: httpx.AsyncClient): url = "http://127.0.0.1:8000/echo_body" headers = {"Expect": "100-continue"} data = b"Echo request body" - async with httpx.AsyncClient() as client: + async with client: response = await client.post(url, headers=headers, data=data) assert response.status_code == 200 diff --git a/tests/client/test_redirects.py b/tests/client/test_redirects.py index 3062733a73..8f5335d6e5 100644 --- a/tests/client/test_redirects.py +++ b/tests/client/test_redirects.py @@ -100,36 +100,33 @@ async def send( return AsyncResponse(codes.OK, content=b"Hello, world!", request=request) -@pytest.mark.asyncio -async def test_redirect_301(): - client = AsyncClient(dispatch=MockDispatch()) +@pytest.fixture +def client(backend): + return AsyncClient(dispatch=MockDispatch(), backend=backend) + + +async def test_redirect_301(client: AsyncClient): response = await client.post("https://example.org/redirect_301") assert response.status_code == codes.OK assert response.url == URL("https://example.org/") assert len(response.history) == 1 -@pytest.mark.asyncio -async def test_redirect_302(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_redirect_302(client: AsyncClient): response = await client.post("https://example.org/redirect_302") assert response.status_code == codes.OK assert response.url == URL("https://example.org/") assert len(response.history) == 1 -@pytest.mark.asyncio -async def test_redirect_303(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_redirect_303(client: AsyncClient): response = await client.get("https://example.org/redirect_303") assert response.status_code == codes.OK assert response.url == URL("https://example.org/") assert len(response.history) == 1 -@pytest.mark.asyncio -async def test_disallow_redirects(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_disallow_redirects(client: AsyncClient): response = await client.post( "https://example.org/redirect_303", allow_redirects=False ) @@ -145,36 +142,28 @@ async def test_disallow_redirects(): assert len(response.history) == 1 -@pytest.mark.asyncio -async def test_relative_redirect(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_relative_redirect(client: AsyncClient): response = await client.get("https://example.org/relative_redirect") assert response.status_code == codes.OK assert response.url == URL("https://example.org/") assert len(response.history) == 1 -@pytest.mark.asyncio -async def test_no_scheme_redirect(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_no_scheme_redirect(client: AsyncClient): response = await client.get("https://example.org/no_scheme_redirect") assert response.status_code == codes.OK assert response.url == URL("https://example.org/") assert len(response.history) == 1 -@pytest.mark.asyncio -async def test_fragment_redirect(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_fragment_redirect(client: AsyncClient): response = await client.get("https://example.org/relative_redirect#fragment") assert response.status_code == codes.OK assert response.url == URL("https://example.org/#fragment") assert len(response.history) == 1 -@pytest.mark.asyncio -async def test_multiple_redirects(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_multiple_redirects(client: AsyncClient): response = await client.get("https://example.org/multiple_redirects?count=20") assert response.status_code == codes.OK assert response.url == URL("https://example.org/multiple_redirects") @@ -189,16 +178,12 @@ async def test_multiple_redirects(): assert len(response.history[1].history) == 1 -@pytest.mark.asyncio -async def test_too_many_redirects(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_too_many_redirects(client: AsyncClient): with pytest.raises(TooManyRedirects): await client.get("https://example.org/multiple_redirects?count=21") -@pytest.mark.asyncio -async def test_too_many_redirects_calling_next(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_too_many_redirects_calling_next(client: AsyncClient): url = "https://example.org/multiple_redirects?count=21" response = await client.get(url, allow_redirects=False) with pytest.raises(TooManyRedirects): @@ -206,16 +191,12 @@ async def test_too_many_redirects_calling_next(): response = await response.next() -@pytest.mark.asyncio -async def test_redirect_loop(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_redirect_loop(client: AsyncClient): with pytest.raises(RedirectLoop): await client.get("https://example.org/redirect_loop") -@pytest.mark.asyncio -async def test_redirect_loop_calling_next(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_redirect_loop_calling_next(client: AsyncClient): url = "https://example.org/redirect_loop" response = await client.get(url, allow_redirects=False) with pytest.raises(RedirectLoop): @@ -223,9 +204,7 @@ async def test_redirect_loop_calling_next(): response = await response.next() -@pytest.mark.asyncio -async def test_cross_domain_redirect(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_cross_domain_redirect(client: AsyncClient): url = "https://example.com/cross_domain" headers = {"Authorization": "abc"} response = await client.get(url, headers=headers) @@ -233,9 +212,7 @@ async def test_cross_domain_redirect(): assert "authorization" not in response.json()["headers"] -@pytest.mark.asyncio -async def test_same_domain_redirect(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_same_domain_redirect(client: AsyncClient): url = "https://example.org/cross_domain" headers = {"Authorization": "abc"} response = await client.get(url, headers=headers) @@ -243,9 +220,7 @@ async def test_same_domain_redirect(): assert response.json()["headers"]["authorization"] == "abc" -@pytest.mark.asyncio -async def test_body_redirect(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_body_redirect(client: AsyncClient): url = "https://example.org/redirect_body" data = b"Example request body" response = await client.post(url, data=data) @@ -253,9 +228,7 @@ async def test_body_redirect(): assert response.json() == {"body": "Example request body"} -@pytest.mark.asyncio -async def test_cannot_redirect_streaming_body(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_cannot_redirect_streaming_body(client: AsyncClient): url = "https://example.org/redirect_body" async def streaming_body(): @@ -265,9 +238,7 @@ async def streaming_body(): await client.post(url, data=streaming_body()) -@pytest.mark.asyncio -async def test_cross_dubdomain_redirect(): - client = AsyncClient(dispatch=MockDispatch()) +async def test_cross_dubdomain_redirect(client: AsyncClient): url = "https://example.com/cross_subdomain" response = await client.get(url) assert response.url == URL("https://www.example.org/cross_subdomain") diff --git a/tests/conftest.py b/tests/conftest.py index f8a6d0dc04..5f3e4f2b7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,15 @@ from uvicorn.config import Config from uvicorn.main import Server +from httpx.concurrency.asyncio import AsyncioBackend + + +@pytest.fixture(params=[pytest.param(AsyncioBackend, marks=pytest.mark.asyncio)]) +def backend(request): + backend_cls = request.param + backend = backend_cls() + return backend + async def app(scope, receive, send): assert scope["type"] == "http" diff --git a/tests/dispatch/test_connection_pools.py b/tests/dispatch/test_connection_pools.py index 1bd564b030..7cee8449b5 100644 --- a/tests/dispatch/test_connection_pools.py +++ b/tests/dispatch/test_connection_pools.py @@ -3,12 +3,11 @@ import httpx -@pytest.mark.asyncio -async def test_keepalive_connections(server): +async def test_keepalive_connections(server, backend): """ Connections should default to staying in a keep-alive state. """ - async with httpx.ConnectionPool() as http: + async with httpx.ConnectionPool(backend=backend) as http: response = await http.request("GET", "http://127.0.0.1:8000/") await response.read() assert len(http.active_connections) == 0 @@ -20,12 +19,11 @@ async def test_keepalive_connections(server): assert len(http.keepalive_connections) == 1 -@pytest.mark.asyncio -async def test_differing_connection_keys(server): +async def test_differing_connection_keys(server, backend): """ Connections to differing connection keys should result in multiple connections. """ - async with httpx.ConnectionPool() as http: + async with httpx.ConnectionPool(backend=backend) as http: response = await http.request("GET", "http://127.0.0.1:8000/") await response.read() assert len(http.active_connections) == 0 @@ -37,14 +35,13 @@ async def test_differing_connection_keys(server): assert len(http.keepalive_connections) == 2 -@pytest.mark.asyncio -async def test_soft_limit(server): +async def test_soft_limit(server, backend): """ The soft_limit config should limit the maximum number of keep-alive connections. """ pool_limits = httpx.PoolLimits(soft_limit=1) - async with httpx.ConnectionPool(pool_limits=pool_limits) as http: + async with httpx.ConnectionPool(pool_limits=pool_limits, backend=backend) as http: response = await http.request("GET", "http://127.0.0.1:8000/") await response.read() assert len(http.active_connections) == 0 @@ -56,12 +53,11 @@ async def test_soft_limit(server): assert len(http.keepalive_connections) == 1 -@pytest.mark.asyncio -async def test_streaming_response_holds_connection(server): +async def test_streaming_response_holds_connection(server, backend): """ A streaming request should hold the connection open until the response is read. """ - async with httpx.ConnectionPool() as http: + async with httpx.ConnectionPool(backend=backend) as http: response = await http.request("GET", "http://127.0.0.1:8000/") assert len(http.active_connections) == 1 assert len(http.keepalive_connections) == 0 @@ -72,12 +68,11 @@ async def test_streaming_response_holds_connection(server): assert len(http.keepalive_connections) == 1 -@pytest.mark.asyncio -async def test_multiple_concurrent_connections(server): +async def test_multiple_concurrent_connections(server, backend): """ Multiple conncurrent requests should open multiple conncurrent connections. """ - async with httpx.ConnectionPool() as http: + async with httpx.ConnectionPool(backend=backend) as http: response_a = await http.request("GET", "http://127.0.0.1:8000/") assert len(http.active_connections) == 1 assert len(http.keepalive_connections) == 0 @@ -95,25 +90,23 @@ async def test_multiple_concurrent_connections(server): assert len(http.keepalive_connections) == 2 -@pytest.mark.asyncio -async def test_close_connections(server): +async def test_close_connections(server, backend): """ Using a `Connection: close` header should close the connection. """ headers = [(b"connection", b"close")] - async with httpx.ConnectionPool() as http: + async with httpx.ConnectionPool(backend=backend) as http: response = await http.request("GET", "http://127.0.0.1:8000/", headers=headers) await response.read() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 0 -@pytest.mark.asyncio -async def test_standard_response_close(server): +async def test_standard_response_close(server, backend): """ A standard close should keep the connection open. """ - async with httpx.ConnectionPool() as http: + async with httpx.ConnectionPool(backend=backend) as http: response = await http.request("GET", "http://127.0.0.1:8000/") await response.read() await response.close() @@ -121,25 +114,23 @@ async def test_standard_response_close(server): assert len(http.keepalive_connections) == 1 -@pytest.mark.asyncio -async def test_premature_response_close(server): +async def test_premature_response_close(server, backend): """ A premature close should close the connection. """ - async with httpx.ConnectionPool() as http: + async with httpx.ConnectionPool(backend=backend) as http: response = await http.request("GET", "http://127.0.0.1:8000/") await response.close() assert len(http.active_connections) == 0 assert len(http.keepalive_connections) == 0 -@pytest.mark.asyncio -async def test_keepalive_connection_closed_by_server_is_reestablished(server): +async def test_keepalive_connection_closed_by_server_is_reestablished(server, backend): """ Upon keep-alive connection closed by remote a new connection should be reestablished. """ - async with httpx.ConnectionPool() as http: + async with httpx.ConnectionPool(backend=backend) as http: response = await http.request("GET", "http://127.0.0.1:8000/") await response.read() @@ -153,13 +144,14 @@ async def test_keepalive_connection_closed_by_server_is_reestablished(server): assert len(http.keepalive_connections) == 1 -@pytest.mark.asyncio -async def test_keepalive_http2_connection_closed_by_server_is_reestablished(server): +async def test_keepalive_http2_connection_closed_by_server_is_reestablished( + server, backend +): """ Upon keep-alive connection closed by remote a new connection should be reestablished. """ - async with httpx.ConnectionPool() as http: + async with httpx.ConnectionPool(backend=backend) as http: response = await http.request("GET", "http://127.0.0.1:8000/") await response.read() diff --git a/tests/dispatch/test_connections.py b/tests/dispatch/test_connections.py index 5273536b21..59e9bb1412 100644 --- a/tests/dispatch/test_connections.py +++ b/tests/dispatch/test_connections.py @@ -1,44 +1,40 @@ -import pytest - from httpx import HTTPConnection -@pytest.mark.asyncio -async def test_get(server): - conn = HTTPConnection(origin="http://127.0.0.1:8000/") +async def test_get(server, backend): + conn = HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) response = await conn.request("GET", "http://127.0.0.1:8000/") await response.read() assert response.status_code == 200 assert response.content == b"Hello, world!" -@pytest.mark.asyncio -async def test_post(server): - conn = HTTPConnection(origin="http://127.0.0.1:8000/") +async def test_post(server, backend): + conn = HTTPConnection(origin="http://127.0.0.1:8000/", backend=backend) response = await conn.request( "GET", "http://127.0.0.1:8000/", data=b"Hello, world!" ) assert response.status_code == 200 -@pytest.mark.asyncio -async def test_https_get_with_ssl_defaults(https_server): +async def test_https_get_with_ssl_defaults(https_server, backend): """ An HTTPS request, with default SSL configuration set on the client. """ - conn = HTTPConnection(origin="https://127.0.0.1:8001/", verify=False) + conn = HTTPConnection( + origin="https://127.0.0.1:8001/", verify=False, backend=backend + ) response = await conn.request("GET", "https://127.0.0.1:8001/") await response.read() assert response.status_code == 200 assert response.content == b"Hello, world!" -@pytest.mark.asyncio -async def test_https_get_with_sll_overrides(https_server): +async def test_https_get_with_sll_overrides(https_server, backend): """ An HTTPS request, with SSL configuration set on the request. """ - conn = HTTPConnection(origin="https://127.0.0.1:8001/") + conn = HTTPConnection(origin="https://127.0.0.1:8001/", backend=backend) response = await conn.request("GET", "https://127.0.0.1:8001/", verify=False) await response.read() assert response.status_code == 200 diff --git a/tests/test_asgi.py b/tests/test_asgi.py index ca3d02ddf3..d681067b9a 100644 --- a/tests/test_asgi.py +++ b/tests/test_asgi.py @@ -46,6 +46,13 @@ def test_asgi(): assert response.text == "Hello, World!" +async def test_asgi_async(backend): + client = httpx.AsyncClient(app=hello_world, backend=backend) + response = await client.get("http://www.example.org/") + assert response.status_code == 200 + assert response.text == "Hello, World!" + + def test_asgi_upload(): client = httpx.Client(app=echo_body) response = client.post("http://www.example.org/", data=b"example") diff --git a/tests/test_timeouts.py b/tests/test_timeouts.py index ab69e690c6..ac8efa130f 100644 --- a/tests/test_timeouts.py +++ b/tests/test_timeouts.py @@ -11,40 +11,36 @@ ) -@pytest.mark.asyncio -async def test_read_timeout(server): +async def test_read_timeout(server, backend): timeout = TimeoutConfig(read_timeout=0.000001) - async with AsyncClient(timeout=timeout) as client: + async with AsyncClient(timeout=timeout, backend=backend) as client: with pytest.raises(ReadTimeout): await client.get("http://127.0.0.1:8000/slow_response") -@pytest.mark.asyncio -async def test_write_timeout(server): +async def test_write_timeout(server, backend): timeout = TimeoutConfig(write_timeout=0.000001) - async with AsyncClient(timeout=timeout) as client: + async with AsyncClient(timeout=timeout, backend=backend) as client: with pytest.raises(WriteTimeout): data = b"*" * 1024 * 1024 * 100 await client.put("http://127.0.0.1:8000/slow_response", data=data) -@pytest.mark.asyncio -async def test_connect_timeout(server): +async def test_connect_timeout(server, backend): timeout = TimeoutConfig(connect_timeout=0.000001) - async with AsyncClient(timeout=timeout) as client: + async with AsyncClient(timeout=timeout, backend=backend) as client: with pytest.raises(ConnectTimeout): # See https://stackoverflow.com/questions/100841/ await client.get("http://10.255.255.1/") -@pytest.mark.asyncio -async def test_pool_timeout(server): +async def test_pool_timeout(server, backend): pool_limits = PoolLimits(hard_limit=1, pool_timeout=0.000001) - async with AsyncClient(pool_limits=pool_limits) as client: + async with AsyncClient(pool_limits=pool_limits, backend=backend) as client: response = await client.get("http://127.0.0.1:8000/", stream=True) with pytest.raises(PoolTimeout):