From a4413cb73925ebae5da29a8424fd05a69c56e63a Mon Sep 17 00:00:00 2001 From: Jamie Hewland Date: Sat, 21 Sep 2019 17:13:05 +0200 Subject: [PATCH 1/3] Use Python 3.8 asyncio.Stream where possible --- httpx/concurrency/asyncio/__init__.py | 3 + .../{asyncio.py => asyncio/backend.py} | 66 +++------ httpx/concurrency/asyncio/compat.py | 138 ++++++++++++++++++ tests/test_concurrency.py | 4 +- 4 files changed, 161 insertions(+), 50 deletions(-) create mode 100644 httpx/concurrency/asyncio/__init__.py rename httpx/concurrency/{asyncio.py => asyncio/backend.py} (76%) create mode 100644 httpx/concurrency/asyncio/compat.py diff --git a/httpx/concurrency/asyncio/__init__.py b/httpx/concurrency/asyncio/__init__.py new file mode 100644 index 0000000000..3543542c17 --- /dev/null +++ b/httpx/concurrency/asyncio/__init__.py @@ -0,0 +1,3 @@ +from .backend import AsyncioBackend, BackgroundManager, PoolSemaphore, TCPStream + +__all__ = ["AsyncioBackend", "BackgroundManager", "PoolSemaphore", "TCPStream"] diff --git a/httpx/concurrency/asyncio.py b/httpx/concurrency/asyncio/backend.py similarity index 76% rename from httpx/concurrency/asyncio.py rename to httpx/concurrency/asyncio/backend.py index 1a145bed90..84f4dd62d5 100644 --- a/httpx/concurrency/asyncio.py +++ b/httpx/concurrency/asyncio/backend.py @@ -4,9 +4,9 @@ import typing from types import TracebackType -from ..config import PoolLimits, TimeoutConfig -from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout -from .base import ( +from ...config import PoolLimits, TimeoutConfig +from ...exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout +from ..base import ( BaseBackgroundManager, BaseEvent, BasePoolSemaphore, @@ -15,6 +15,7 @@ ConcurrencyBackend, TimeoutFlag, ) +from .compat import Stream, connect_compat SSL_MONKEY_PATCH_APPLIED = False @@ -41,18 +42,12 @@ def _fixed_write(self, data: bytes) -> None: # type: ignore class TCPStream(BaseTCPStream): - def __init__( - self, - stream_reader: asyncio.StreamReader, - stream_writer: asyncio.StreamWriter, - timeout: TimeoutConfig, - ): - self.stream_reader = stream_reader - self.stream_writer = stream_writer + def __init__(self, stream: Stream, timeout: TimeoutConfig): + self.stream = stream self.timeout = timeout def get_http_version(self) -> str: - ssl_object = self.stream_writer.get_extra_info("ssl_object") + ssl_object = self.stream.get_extra_info("ssl_object") if ssl_object is None: return "HTTP/1.1" @@ -76,7 +71,7 @@ async def read( should_raise = flag is None or flag.raise_on_read_timeout read_timeout = timeout.read_timeout if should_raise else 0.01 try: - data = await asyncio.wait_for(self.stream_reader.read(n), read_timeout) + data = await asyncio.wait_for(self.stream.read(n), read_timeout) break except asyncio.TimeoutError: if should_raise: @@ -85,7 +80,7 @@ async def read( return data def write_no_block(self, data: bytes) -> None: - self.stream_writer.write(data) # pragma: nocover + self.stream.write(data) # pragma: nocover async def write( self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None @@ -96,11 +91,11 @@ async def write( if timeout is None: timeout = self.timeout - self.stream_writer.write(data) + self.stream.write(data) while True: try: await asyncio.wait_for( # type: ignore - self.stream_writer.drain(), timeout.write_timeout + self.stream.drain(), timeout.write_timeout ) break except asyncio.TimeoutError: @@ -112,10 +107,10 @@ async def write( raise WriteTimeout() from None def is_connection_dropped(self) -> bool: - return self.stream_reader.at_eof() + return self.stream.at_eof() async def close(self) -> None: - self.stream_writer.close() + await self.stream.close() class PoolSemaphore(BasePoolSemaphore): @@ -174,16 +169,13 @@ async def open_tcp_stream( timeout: TimeoutConfig, ) -> BaseTCPStream: try: - stream_reader, stream_writer = await asyncio.wait_for( # type: ignore - asyncio.open_connection(hostname, port, ssl=ssl_context), - timeout.connect_timeout, + stream = await asyncio.wait_for( # type: ignore + connect_compat(hostname, port, ssl=ssl_context), timeout.connect_timeout ) except asyncio.TimeoutError: raise ConnectTimeout() - return TCPStream( - stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout - ) + return TCPStream(stream=stream, timeout=timeout) async def start_tls( self, @@ -192,35 +184,13 @@ async def start_tls( ssl_context: ssl.SSLContext, timeout: TimeoutConfig, ) -> BaseTCPStream: - - loop = self.loop - if not hasattr(loop, "start_tls"): # pragma: no cover - raise NotImplementedError( - "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+" - ) - assert isinstance(stream, TCPStream) - stream_reader = asyncio.StreamReader() - protocol = asyncio.StreamReaderProtocol(stream_reader) - transport = stream.stream_writer.transport - - loop_start_tls = loop.start_tls # type: ignore - transport = await asyncio.wait_for( - loop_start_tls( - transport=transport, - protocol=protocol, - sslcontext=ssl_context, - server_hostname=hostname, - ), + await asyncio.wait_for( + stream.stream.start_tls(ssl_context, server_hostname=hostname), timeout=timeout.connect_timeout, ) - stream_reader.set_transport(transport) - stream.stream_reader = stream_reader - stream.stream_writer = asyncio.StreamWriter( - transport=transport, protocol=protocol, reader=stream_reader, loop=loop - ) return stream async def run_in_threadpool( diff --git a/httpx/concurrency/asyncio/compat.py b/httpx/concurrency/asyncio/compat.py new file mode 100644 index 0000000000..3c8b74d026 --- /dev/null +++ b/httpx/concurrency/asyncio/compat.py @@ -0,0 +1,138 @@ +import asyncio +import ssl +import sys +import typing + + +if sys.version_info >= (3, 8): + from typing import Protocol +else: + from typing_extensions import Protocol + + +class Stream(Protocol): # pragma: no cover + """Protocol defining just the methods we use from asyncio.Stream.""" + + def at_eof(self) -> bool: + ... + + def close(self) -> typing.Awaitable[None]: + ... + + async def drain(self) -> None: + ... + + def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any: + ... + + async def read(self, n: int = -1) -> bytes: + ... + + async def start_tls( + self, + sslContext: ssl.SSLContext, + *, + server_hostname: typing.Optional[str] = None, + ssl_handshake_timeout: typing.Optional[float] = None, + ) -> None: + ... + + def write(self, data: bytes) -> typing.Awaitable[None]: + ... + + +async def connect_compat(*args: typing.Any, **kwargs: typing.Any) -> Stream: + if sys.version_info >= (3, 8): + return await asyncio.connect(*args, **kwargs) + else: + reader, writer = await asyncio.open_connection(*args, **kwargs) + return StreamCompat(reader, writer) + + +class StreamCompat: + """ + Thin wrapper around asyncio.StreamReader/StreamWriter to make them look and + behave similarly to an asyncio.Stream. + """ + + def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + self.reader = reader + self.writer = writer + + def at_eof(self) -> bool: + return self.reader.at_eof() + + def close(self) -> typing.Awaitable[None]: + self.writer.close() + return _OptionalAwait(self.wait_closed) + + async def drain(self) -> None: + await self.writer.drain() + + def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any: + return self.writer.get_extra_info(name, default) + + async def read(self, n: int = -1) -> bytes: + return await self.reader.read(n) + + async def start_tls( + self, + sslContext: ssl.SSLContext, + *, + server_hostname: typing.Optional[str] = None, + ssl_handshake_timeout: typing.Optional[float] = None, + ) -> None: + if not sys.version_info >= (3, 7): # pragma: no cover + raise NotImplementedError( + "asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+" + ) + else: + # This code is in an else branch to appease mypy on Python < 3.7 + + reader = asyncio.StreamReader() + protocol = asyncio.StreamReaderProtocol(reader) + transport = self.writer.transport + + loop = asyncio.get_event_loop() + loop_start_tls = loop.start_tls # type: ignore + tls_transport = await loop_start_tls( + transport=transport, + protocol=protocol, + sslcontext=sslContext, + server_hostname=server_hostname, + ssl_handshake_timeout=ssl_handshake_timeout, + ) + + reader.set_transport(tls_transport) + self.reader = reader + self.writer = asyncio.StreamWriter( + transport=tls_transport, protocol=protocol, reader=reader, loop=loop + ) + + def write(self, data: bytes) -> typing.Awaitable[None]: + self.writer.write(data) + return _OptionalAwait(self.drain) + + async def wait_closed(self) -> None: + if sys.version_info >= (3, 7): + await self.writer.wait_closed() + # else not much we can do to wait for the connection to close + + +# This code is copied from cPython 3.8 but with type annotations added: +# https://github.com/python/cpython/blob/v3.8.0b4/Lib/asyncio/streams.py#L1262-L1273 +_T = typing.TypeVar("_T") + + +class _OptionalAwait(typing.Generic[_T]): + # The class doesn't create a coroutine + # if not awaited + # It prevents "coroutine is never awaited" message + + __slots___ = ("_method",) + + def __init__(self, method: typing.Callable[[], typing.Awaitable[_T]]): + self._method = method + + def __await__(self) -> typing.Generator[typing.Any, None, _T]: + return self._method().__await__() diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index ab93b30282..03f8c0aaea 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -25,11 +25,11 @@ async def test_start_tls_on_socket_stream(https_server): try: assert stream.is_connection_dropped() is False - assert stream.stream_writer.get_extra_info("cipher", default=None) is None + assert stream.stream.get_extra_info("cipher", default=None) is None stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout) assert stream.is_connection_dropped() is False - assert stream.stream_writer.get_extra_info("cipher", default=None) is not None + assert stream.stream.get_extra_info("cipher", default=None) is not None await stream.write(b"GET / HTTP/1.1\r\n\r\n") assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n") From 1d2e5251e2b327e3da3926af57d22a7aad7982d3 Mon Sep 17 00:00:00 2001 From: Jamie Hewland Date: Sat, 21 Sep 2019 21:09:01 +0200 Subject: [PATCH 2/3] Clean up imports --- httpx/concurrency/asyncio/backend.py | 5 +++-- httpx/concurrency/asyncio/compat.py | 1 - 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/httpx/concurrency/asyncio/backend.py b/httpx/concurrency/asyncio/backend.py index 84f4dd62d5..ddac291b13 100644 --- a/httpx/concurrency/asyncio/backend.py +++ b/httpx/concurrency/asyncio/backend.py @@ -4,8 +4,9 @@ import typing from types import TracebackType -from ...config import PoolLimits, TimeoutConfig -from ...exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout +from httpx.config import PoolLimits, TimeoutConfig +from httpx.exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout + from ..base import ( BaseBackgroundManager, BaseEvent, diff --git a/httpx/concurrency/asyncio/compat.py b/httpx/concurrency/asyncio/compat.py index 3c8b74d026..d83b2098e2 100644 --- a/httpx/concurrency/asyncio/compat.py +++ b/httpx/concurrency/asyncio/compat.py @@ -3,7 +3,6 @@ import sys import typing - if sys.version_info >= (3, 8): from typing import Protocol else: From d629af9d6ca13736c176c0c048bda5b3d06e5fd2 Mon Sep 17 00:00:00 2001 From: Jamie Hewland Date: Sat, 21 Sep 2019 21:45:28 +0200 Subject: [PATCH 3/3] Give up on await-ing on the stream closing (for now) --- httpx/concurrency/asyncio/backend.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/httpx/concurrency/asyncio/backend.py b/httpx/concurrency/asyncio/backend.py index ddac291b13..607da37fba 100644 --- a/httpx/concurrency/asyncio/backend.py +++ b/httpx/concurrency/asyncio/backend.py @@ -111,7 +111,9 @@ def is_connection_dropped(self) -> bool: return self.stream.at_eof() async def close(self) -> None: - await self.stream.close() + # FIXME: We should await on this call, but need a workaround for this first: + # https://github.com/aio-libs/aiohttp/issues/3535 + self.stream.close() class PoolSemaphore(BasePoolSemaphore):