Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalized concurrency backend usage #217

Closed
wants to merge 9 commits into from
2 changes: 1 addition & 1 deletion httpx/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .__version__ import __description__, __title__, __version__
from .api import delete, get, head, options, patch, post, put, request
from .client import AsyncClient, Client
from .concurrency import AsyncioBackend
from .concurrency.asyncio import AsyncioBackend
from .config import (
USER_AGENT,
CertTypes,
Expand Down
4 changes: 2 additions & 2 deletions httpx/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import hstspreload

from .auth import HTTPBasicAuth
from .concurrency import AsyncioBackend
from .concurrency.asyncio import AsyncioBackend
from .config import (
DEFAULT_MAX_REDIRECTS,
DEFAULT_POOL_LIMITS,
Expand Down Expand Up @@ -77,7 +77,7 @@ def __init__(
if param_count == 2:
dispatch = WSGIDispatch(app=app)
else:
dispatch = ASGIDispatch(app=app)
dispatch = ASGIDispatch(app=app, backend=backend)

if dispatch is None:
async_dispatch: AsyncDispatcher = ConnectionPool(
Expand Down
Empty file added httpx/concurrency/__init__.py
Empty file.
120 changes: 72 additions & 48 deletions httpx/concurrency.py → httpx/concurrency/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,20 @@
import typing
from types import TracebackType

from .config import PoolLimits, TimeoutConfig
from .exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .interfaces import (
from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from ..interfaces import (
BaseAsyncContextManager,
BaseBackgroundManager,
BaseBodyIterator,
BaseEvent,
BasePoolSemaphore,
BaseReader,
BaseWriter,
ConcurrencyBackend,
Protocol,
)
from .utils import TimeoutFlag

SSL_MONKEY_PATCH_APPLIED = False

Expand All @@ -49,38 +53,6 @@ def _fixed_write(self, data: bytes) -> None: # type: ignore
MonkeyPatch.write = _fixed_write


class TimeoutFlag:
"""
A timeout flag holds a state of either read-timeout or write-timeout mode.

We use this so that we can attempt both reads and writes concurrently, while
only enforcing timeouts in one direction.

During a request/response cycle we start in write-timeout mode.

Once we've sent a request fully, or once we start seeing a response,
then we switch to read-timeout mode instead.
"""

def __init__(self) -> None:
self.raise_on_read_timeout = False
self.raise_on_write_timeout = True

def set_read_timeouts(self) -> None:
"""
Set the flag to read-timeout mode.
"""
self.raise_on_read_timeout = True
self.raise_on_write_timeout = False

def set_write_timeouts(self) -> None:
"""
Set the flag to write-timeout mode.
"""
self.raise_on_read_timeout = False
self.raise_on_write_timeout = True


class Reader(BaseReader):
def __init__(
self, stream_reader: asyncio.StreamReader, timeout: TimeoutConfig
Expand Down Expand Up @@ -247,28 +219,80 @@ def run(
def get_semaphore(self, limits: PoolLimits) -> BasePoolSemaphore:
return PoolSemaphore(limits)

def background_manager(
self, coroutine: typing.Callable, args: typing.Any
) -> "BackgroundManager":
return BackgroundManager(coroutine, args)
def create_event(self) -> BaseEvent:
return typing.cast(BaseEvent, asyncio.Event())

def background_manager(self) -> "BackgroundManager":
return BackgroundManager()

def body_iterator(self) -> "BodyIterator":
return BodyIterator()


class BackgroundManager(BaseBackgroundManager):
def __init__(self, coroutine: typing.Callable, args: typing.Any) -> None:
self.coroutine = coroutine
self.args = args
def __init__(self) -> None:
self.tasks: typing.Set[asyncio.Task] = set()

async def __aenter__(self) -> "BackgroundManager":
def start_soon(self, coroutine: typing.Callable, *args: typing.Any) -> None:
loop = asyncio.get_event_loop()
self.task = loop.create_task(self.coroutine(*self.args))
return self
self.tasks.add(loop.create_task(coroutine(*args)))

def wait_first_completed(self) -> "WaitFirstCompleted":
return WaitFirstCompleted(self)

async def __aexit__(
self,
exc_type: typing.Type[BaseException] = None,
exc_value: BaseException = None,
traceback: TracebackType = None,
) -> None:
await self.task
if exc_type is None:
self.task.result()
done, pending = await asyncio.wait(self.tasks, timeout=1e-3)

for task in pending:
task.cancel()

for task in done:
await task
if exc_type is None:
task.result()


class WaitFirstCompleted(BaseAsyncContextManager):
def __init__(self, background: BackgroundManager):
self.background = background
self.initial_tasks: typing.Set[asyncio.Task] = set()

async def __aenter__(self) -> "WaitFirstCompleted":
self.initial_tasks = self.background.tasks
self.background.tasks = set()
return self

async def __aexit__(self, *args: typing.Any) -> None:
_, pending = await asyncio.wait(
self.background.tasks, return_when=asyncio.FIRST_COMPLETED
)
self.background.tasks = self.initial_tasks | typing.cast(
typing.Set[asyncio.Task], pending
)


class BodyIterator(BaseBodyIterator):
def __init__(self) -> None:
self._queue: asyncio.Queue[typing.Union[bytes, object]] = asyncio.Queue(
maxsize=1
)
self._done = object()

async def iterate(self) -> typing.AsyncIterator[bytes]:
while True:
data = await self._queue.get()
if data is self._done:
break
assert isinstance(data, bytes)
yield data

async def put(self, data: bytes) -> None:
await self._queue.put(data)

async def done(self) -> None:
await self._queue.put(self._done)
30 changes: 30 additions & 0 deletions httpx/concurrency/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
class TimeoutFlag:
"""
A timeout flag holds a state of either read-timeout or write-timeout mode.

We use this so that we can attempt both reads and writes concurrently, while
only enforcing timeouts in one direction.

During a request/response cycle we start in write-timeout mode.

Once we've sent a request fully, or once we start seeing a response,
then we switch to read-timeout mode instead.
"""

def __init__(self) -> None:
self.raise_on_read_timeout = False
self.raise_on_write_timeout = True

def set_read_timeouts(self) -> None:
"""
Set the flag to read-timeout mode.
"""
self.raise_on_read_timeout = True
self.raise_on_write_timeout = False

def set_write_timeouts(self) -> None:
"""
Set the flag to write-timeout mode.
"""
self.raise_on_read_timeout = False
self.raise_on_write_timeout = True
75 changes: 14 additions & 61 deletions httpx/dispatch/asgi.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import typing

from ..concurrency.asyncio import AsyncioBackend
from ..config import CertTypes, TimeoutTypes, VerifyTypes
from ..interfaces import AsyncDispatcher
from ..interfaces import AsyncDispatcher, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse


Expand Down Expand Up @@ -44,11 +44,13 @@ class ASGIDispatch(AsyncDispatcher):
def __init__(
self,
app: typing.Callable,
backend: ConcurrencyBackend = None,
raise_app_exceptions: bool = True,
root_path: str = "",
client: typing.Tuple[str, int] = ("127.0.0.1", 123),
) -> None:
self.app = app
self.backend = AsyncioBackend() if backend is None else backend
self.raise_app_exceptions = raise_app_exceptions
self.root_path = root_path
self.client = client
Expand Down Expand Up @@ -78,8 +80,8 @@ async def send(
app_exc = None
status_code = None
headers = None
response_started = asyncio.Event()
response_body = BodyIterator()
response_started = self.backend.create_event()
response_body = self.backend.body_iterator()
request_stream = request.stream()

async def receive() -> dict:
Expand Down Expand Up @@ -115,30 +117,25 @@ async def run_app() -> None:
finally:
await response_body.done()

# Really we'd like to push all `asyncio` logic into concurrency.py,
# with a standardized interface, so that we can support other event
# loop implementations, such as Trio and Curio.
# That's a bit fiddly here, so we're not yet supporting using a custom
# `ConcurrencyBackend` with the `Client(app=asgi_app)` case.
loop = asyncio.get_event_loop()
app_task = loop.create_task(run_app())
response_task = loop.create_task(response_started.wait())
background = self.backend.background_manager()
await background.__aenter__()

tasks = {app_task, response_task} # type: typing.Set[asyncio.Task]

await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
async with background.wait_first_completed():
background.start_soon(run_app)
background.start_soon(response_started.wait)

if app_exc is not None and self.raise_app_exceptions:
await background.close()
raise app_exc

assert response_started.is_set(), "application did not return a response."
assert status_code is not None
assert headers is not None

async def on_close() -> None:
nonlocal app_task, response_body
nonlocal background, response_body
await response_body.drain()
await app_task
await background.close()
if app_exc is not None and self.raise_app_exceptions:
raise app_exc

Expand All @@ -150,47 +147,3 @@ async def on_close() -> None:
on_close=on_close,
request=request,
)


class BodyIterator:
"""
Provides a byte-iterator interface that the client can use to
ingest the response content from.
"""

def __init__(self) -> None:
self._queue = asyncio.Queue(
maxsize=1
) # type: asyncio.Queue[typing.Union[bytes, object]]
self._done = object()

async def iterate(self) -> typing.AsyncIterator[bytes]:
"""
A byte-iterator, used by the client to consume the response body.
"""
while True:
data = await self._queue.get()
if data is self._done:
break
assert isinstance(data, bytes)
yield data

async def drain(self) -> None:
"""
Drain any remaining body, in order to allow any blocked `put()` calls
to complete.
"""
async for chunk in self.iterate():
pass # pragma: no cover

async def put(self, data: bytes) -> None:
"""
Used by the server to add data to the response body.
"""
await self._queue.put(data)

async def done(self) -> None:
"""
Used by the server to signal the end of the response body.
"""
await self._queue.put(self._done)
6 changes: 4 additions & 2 deletions httpx/dispatch/connection.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import functools
import typing

from ..concurrency import AsyncioBackend
from ..concurrency.asyncio import AsyncioBackend
from ..config import (
DEFAULT_TIMEOUT_CONFIG,
CertTypes,
Expand Down Expand Up @@ -32,7 +32,9 @@ def __init__(
self.origin = Origin(origin) if isinstance(origin, str) else origin
self.ssl = SSLConfig(cert=cert, verify=verify)
self.timeout = TimeoutConfig(timeout)
self.backend = AsyncioBackend() if backend is None else backend
self.backend = typing.cast(
ConcurrencyBackend, AsyncioBackend() if backend is None else backend
)
self.release_func = release_func
self.h11_connection = None # type: typing.Optional[HTTP11Connection]
self.h2_connection = None # type: typing.Optional[HTTP2Connection]
Expand Down
6 changes: 4 additions & 2 deletions httpx/dispatch/connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import typing

from ..concurrency import AsyncioBackend
from ..concurrency.asyncio import AsyncioBackend
from ..config import (
DEFAULT_POOL_LIMITS,
DEFAULT_TIMEOUT_CONFIG,
Expand Down Expand Up @@ -91,7 +91,9 @@ def __init__(
self.keepalive_connections = ConnectionStore()
self.active_connections = ConnectionStore()

self.backend = AsyncioBackend() if backend is None else backend
self.backend = typing.cast(
ConcurrencyBackend, AsyncioBackend() if backend is None else backend
)
self.max_connections = self.backend.get_semaphore(pool_limits)

@property
Expand Down
5 changes: 3 additions & 2 deletions httpx/dispatch/http11.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import h11

from ..concurrency import TimeoutFlag
from ..concurrency.utils import TimeoutFlag
from ..config import TimeoutConfig, TimeoutTypes
from ..interfaces import BaseReader, BaseWriter, ConcurrencyBackend
from ..models import AsyncRequest, AsyncResponse
Expand Down Expand Up @@ -48,7 +48,8 @@ async def send(
await self._send_request(request, timeout)

task, args = self._send_request_data, [request.stream(), timeout]
async with self.backend.background_manager(task, args=args):
async with self.backend.background_manager() as background:
background.start_soon(task, *args)
http_version, status_code, headers = await self._receive_response(timeout)
content = self._receive_response_data(timeout)

Expand Down
Loading