Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use Python 3.8 asyncio.Stream where possible #369

Merged
merged 6 commits into from
Sep 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions httpx/concurrency/asyncio/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .backend import AsyncioBackend, BackgroundManager, PoolSemaphore, TCPStream

__all__ = ["AsyncioBackend", "BackgroundManager", "PoolSemaphore", "TCPStream"]
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import typing
from types import TracebackType

from ..config import PoolLimits, TimeoutConfig
from ..exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout
from .base import (
from httpx.config import PoolLimits, TimeoutConfig
from httpx.exceptions import ConnectTimeout, PoolTimeout, ReadTimeout, WriteTimeout

from ..base import (
BaseBackgroundManager,
BaseEvent,
BasePoolSemaphore,
Expand All @@ -15,6 +16,7 @@
ConcurrencyBackend,
TimeoutFlag,
)
from .compat import Stream, connect_compat

SSL_MONKEY_PATCH_APPLIED = False

Expand All @@ -41,18 +43,12 @@ def _fixed_write(self, data: bytes) -> None: # type: ignore


class TCPStream(BaseTCPStream):
def __init__(
self,
stream_reader: asyncio.StreamReader,
stream_writer: asyncio.StreamWriter,
timeout: TimeoutConfig,
):
self.stream_reader = stream_reader
self.stream_writer = stream_writer
def __init__(self, stream: Stream, timeout: TimeoutConfig):
self.stream = stream
self.timeout = timeout

def get_http_version(self) -> str:
ssl_object = self.stream_writer.get_extra_info("ssl_object")
ssl_object = self.stream.get_extra_info("ssl_object")

if ssl_object is None:
return "HTTP/1.1"
Expand All @@ -76,7 +72,7 @@ async def read(
should_raise = flag is None or flag.raise_on_read_timeout
read_timeout = timeout.read_timeout if should_raise else 0.01
try:
data = await asyncio.wait_for(self.stream_reader.read(n), read_timeout)
data = await asyncio.wait_for(self.stream.read(n), read_timeout)
break
except asyncio.TimeoutError:
if should_raise:
Expand All @@ -91,7 +87,7 @@ async def read(
return data

def write_no_block(self, data: bytes) -> None:
self.stream_writer.write(data) # pragma: nocover
self.stream.write(data) # pragma: nocover

async def write(
self, data: bytes, timeout: TimeoutConfig = None, flag: TimeoutFlag = None
Expand All @@ -102,11 +98,11 @@ async def write(
if timeout is None:
timeout = self.timeout

self.stream_writer.write(data)
self.stream.write(data)
while True:
try:
await asyncio.wait_for( # type: ignore
self.stream_writer.drain(), timeout.write_timeout
self.stream.drain(), timeout.write_timeout
)
break
except asyncio.TimeoutError:
Expand All @@ -132,10 +128,12 @@ def is_connection_dropped(self) -> bool:
# (For a solution that uses private asyncio APIs, see:
# https://github.com/encode/httpx/pull/143#issuecomment-515202982)

return self.stream_reader.at_eof()
return self.stream.at_eof()

async def close(self) -> None:
self.stream_writer.close()
# FIXME: We should await on this call, but need a workaround for this first:
# https://github.com/aio-libs/aiohttp/issues/3535
Copy link
Member Author

@JayH5 JayH5 Sep 21, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the await, the test_start_tls_on_socket_stream test fails transiently with the error mentioned in this thread in about 1 out of 5-30 test runs (thanks, pytest-repeat) for me on Python 3.7.2, 3.7.4, and 3.8.0b4.

There is a possible fix (a custom exception handler set on the event loop) in the thread, but it's fairly complicated and I'm not sure how to integrate it nicely. This is no worse than we had so I think it can be fixed in a separate PR.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, that thread doesn't look good. 😨

Copy link
Member

@florimondmanca florimondmanca Sep 22, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it OK to not await this .close() call, though (since we did await it before)? Will it close itself in the background in pure asyncio magic?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@florimondmanca

since we did await it before

afaics, we don't/never have await-ed on the close/wait_closed?

We should probably await the close, though. I've looked at this a bit more and read the aiohttp thread a bit more thoroughly and...

  • The SSL error occurs sporadically when await-ing on StreamWriter.wait_closed() (under the hood, Python 3.8 (for now) calls Stream.wait_closed() when you await Stream.close())
  • Python 3.6 doesn't have wait_closed so this doesn't apply.
  • This seems to only apply to OpenSSL 1.1.1+ since that is when this particular type of error was added.

The best I've come up with is something like:

async def close(self) -> None:
    try:
        await self.stream.close()
    except ssl.SSLError as e:  # pragma: no cover
        if e.reason == "KRB5_S_INIT":
            logger.debug("Ignoring asyncio SSL KRB5_S_INIT error on stream close")
            return

        raise

Thoughts?

I'm not sure exactly what state the socket is in when that error is raised 😕. Once we try to close and get that error there doesn't seem to be anything we can do without just getting that error again.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, all these details are very helpful. :)

Since the 3.8 docs it’s only possible to await on .close() (which I find kind of weird, having in mind all those « coroutine was never awaited » exceptions that usually causes), I’m okay with keeping it as it is currently. As you said, we actually never waited for the socket to close before (and if we did, we’d have probably only encountered this issue earlier?).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One (not very good) reason to include this try/except is that without it I don't think this change has 100% test coverage since wait_closed is never called.

self.stream.close()


class PoolSemaphore(BasePoolSemaphore):
Expand Down Expand Up @@ -194,16 +192,13 @@ async def open_tcp_stream(
timeout: TimeoutConfig,
) -> BaseTCPStream:
try:
stream_reader, stream_writer = await asyncio.wait_for( # type: ignore
asyncio.open_connection(hostname, port, ssl=ssl_context),
timeout.connect_timeout,
stream = await asyncio.wait_for( # type: ignore
connect_compat(hostname, port, ssl=ssl_context), timeout.connect_timeout
)
except asyncio.TimeoutError:
raise ConnectTimeout()

return TCPStream(
stream_reader=stream_reader, stream_writer=stream_writer, timeout=timeout
)
return TCPStream(stream=stream, timeout=timeout)

async def start_tls(
self,
Expand All @@ -212,35 +207,13 @@ async def start_tls(
ssl_context: ssl.SSLContext,
timeout: TimeoutConfig,
) -> BaseTCPStream:

loop = self.loop
if not hasattr(loop, "start_tls"): # pragma: no cover
raise NotImplementedError(
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
)

assert isinstance(stream, TCPStream)

stream_reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(stream_reader)
transport = stream.stream_writer.transport

loop_start_tls = loop.start_tls # type: ignore
transport = await asyncio.wait_for(
loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=ssl_context,
server_hostname=hostname,
),
await asyncio.wait_for(
stream.stream.start_tls(ssl_context, server_hostname=hostname),
timeout=timeout.connect_timeout,
)

stream_reader.set_transport(transport)
stream.stream_reader = stream_reader
stream.stream_writer = asyncio.StreamWriter(
transport=transport, protocol=protocol, reader=stream_reader, loop=loop
)
return stream

async def run_in_threadpool(
Expand Down
137 changes: 137 additions & 0 deletions httpx/concurrency/asyncio/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import asyncio
import ssl
import sys
import typing

if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol


class Stream(Protocol): # pragma: no cover
"""Protocol defining just the methods we use from asyncio.Stream."""

def at_eof(self) -> bool:
...

def close(self) -> typing.Awaitable[None]:
...

async def drain(self) -> None:
...

def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
...

async def read(self, n: int = -1) -> bytes:
...

async def start_tls(
self,
sslContext: ssl.SSLContext,
*,
server_hostname: typing.Optional[str] = None,
ssl_handshake_timeout: typing.Optional[float] = None,
) -> None:
...

def write(self, data: bytes) -> typing.Awaitable[None]:
...


async def connect_compat(*args: typing.Any, **kwargs: typing.Any) -> Stream:
if sys.version_info >= (3, 8):
return await asyncio.connect(*args, **kwargs)
else:
reader, writer = await asyncio.open_connection(*args, **kwargs)
return StreamCompat(reader, writer)


class StreamCompat:
"""
Thin wrapper around asyncio.StreamReader/StreamWriter to make them look and
behave similarly to an asyncio.Stream.
"""

def __init__(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter):
self.reader = reader
self.writer = writer

def at_eof(self) -> bool:
return self.reader.at_eof()

def close(self) -> typing.Awaitable[None]:
self.writer.close()
return _OptionalAwait(self.wait_closed)

async def drain(self) -> None:
await self.writer.drain()

def get_extra_info(self, name: str, default: typing.Any = None) -> typing.Any:
return self.writer.get_extra_info(name, default)

async def read(self, n: int = -1) -> bytes:
return await self.reader.read(n)

async def start_tls(
self,
sslContext: ssl.SSLContext,
*,
server_hostname: typing.Optional[str] = None,
ssl_handshake_timeout: typing.Optional[float] = None,
) -> None:
if not sys.version_info >= (3, 7): # pragma: no cover
raise NotImplementedError(
"asyncio.AbstractEventLoop.start_tls() is only available in Python 3.7+"
)
else:
# This code is in an else branch to appease mypy on Python < 3.7

reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
transport = self.writer.transport

loop = asyncio.get_event_loop()
loop_start_tls = loop.start_tls # type: ignore
tls_transport = await loop_start_tls(
transport=transport,
protocol=protocol,
sslcontext=sslContext,
server_hostname=server_hostname,
ssl_handshake_timeout=ssl_handshake_timeout,
)

reader.set_transport(tls_transport)
self.reader = reader
self.writer = asyncio.StreamWriter(
transport=tls_transport, protocol=protocol, reader=reader, loop=loop
)

def write(self, data: bytes) -> typing.Awaitable[None]:
self.writer.write(data)
return _OptionalAwait(self.drain)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

await-ing on the drain call is not exactly the same as how things work in Python 3.8 when you await on Stream.write() it first tries a fast-path to flush data and then falls back to a full drain.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Apparently drain() was deprecated in 3.8 too, and it is now recommended to await stream.write(). Should we switch to this latter API? This means we'd need to refactor AsyncioBackend.write() a bit, and potentially use a manual buffering mechanism for .write_no_block(). (There's an example of this in the trio backend: #276.)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was hoping this could tie into #341 quite nicely maybe?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, actually it does. :)


async def wait_closed(self) -> None:
if sys.version_info >= (3, 7):
await self.writer.wait_closed()
# else not much we can do to wait for the connection to close


# This code is copied from cPython 3.8 but with type annotations added:
# https://github.com/python/cpython/blob/v3.8.0b4/Lib/asyncio/streams.py#L1262-L1273
_T = typing.TypeVar("_T")


class _OptionalAwait(typing.Generic[_T]):
# The class doesn't create a coroutine
# if not awaited
# It prevents "coroutine is never awaited" message

__slots___ = ("_method",)

def __init__(self, method: typing.Callable[[], typing.Awaitable[_T]]):
self._method = method

def __await__(self) -> typing.Generator[typing.Any, None, _T]:
return self._method().__await__()
4 changes: 2 additions & 2 deletions tests/test_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ async def test_start_tls_on_socket_stream(https_server):

try:
assert stream.is_connection_dropped() is False
assert stream.stream_writer.get_extra_info("cipher", default=None) is None
assert stream.stream.get_extra_info("cipher", default=None) is None

stream = await backend.start_tls(stream, https_server.url.host, ctx, timeout)
assert stream.is_connection_dropped() is False
assert stream.stream_writer.get_extra_info("cipher", default=None) is not None
assert stream.stream.get_extra_info("cipher", default=None) is not None

await stream.write(b"GET / HTTP/1.1\r\n\r\n")
assert (await stream.read(8192, timeout)).startswith(b"HTTP/1.1 200 OK\r\n")
Expand Down