Skip to content

Commit

Permalink
Added support for wait_readable() and wait_writable() on ProactorEven…
Browse files Browse the repository at this point in the history
…tLoop (#831)
  • Loading branch information
agronholm authored Dec 3, 2024
1 parent 97d5fe6 commit 0f80611
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 65 deletions.
4 changes: 3 additions & 1 deletion docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.
- Added the ``wait_readable()`` and ``wait_writable()`` functions which will accept
an object with a ``.fileno()`` method or an integer handle, and deprecated
their now obsolete versions (``wait_socket_readable()`` and
``wait_socket_writable()`` (PR by @davidbrochart)
``wait_socket_writable()``) (PR by @davidbrochart)
- Added support for ``wait_readable()`` and ``wait_writable()`` on ``ProactorEventLoop``
(used on asyncio + Windows by default)
- Fixed the return type annotations of ``readinto()`` and ``readinto1()`` methods in the
``anyio.AsyncFile`` class
(`#825 <https://github.com/agronholm/anyio/issues/825>`_)
Expand Down
62 changes: 37 additions & 25 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

if TYPE_CHECKING:
from _typeshed import HasFileno
from _typeshed import FileDescriptorLike
else:
FileDescriptorLike = object

if sys.version_info >= (3, 10):
from typing import ParamSpec
Expand Down Expand Up @@ -2734,7 +2736,7 @@ async def getnameinfo(
return await get_running_loop().getnameinfo(sockaddr, flags)

@classmethod
async def wait_readable(cls, obj: HasFileno | int) -> None:
async def wait_readable(cls, obj: FileDescriptorLike) -> None:
await cls.checkpoint()
try:
read_events = _read_events.get()
Expand All @@ -2746,25 +2748,30 @@ async def wait_readable(cls, obj: HasFileno | int) -> None:
obj = obj.fileno()

if read_events.get(obj):
raise BusyResourceError("reading from") from None
raise BusyResourceError("reading from")

loop = get_running_loop()
event = read_events[obj] = asyncio.Event()
loop.add_reader(obj, event.set)
event = asyncio.Event()
try:
loop.add_reader(obj, event.set)
except NotImplementedError:
from anyio._core._asyncio_selector_thread import get_selector

selector = get_selector()
selector.add_reader(obj, event.set)
remove_reader = selector.remove_reader
else:
remove_reader = loop.remove_reader

read_events[obj] = event
try:
await event.wait()
finally:
if read_events.pop(obj, None) is not None:
loop.remove_reader(obj)
readable = True
else:
readable = False

if not readable:
raise ClosedResourceError
remove_reader(obj)
del read_events[obj]

@classmethod
async def wait_writable(cls, obj: HasFileno | int) -> None:
async def wait_writable(cls, obj: FileDescriptorLike) -> None:
await cls.checkpoint()
try:
write_events = _write_events.get()
Expand All @@ -2776,22 +2783,27 @@ async def wait_writable(cls, obj: HasFileno | int) -> None:
obj = obj.fileno()

if write_events.get(obj):
raise BusyResourceError("writing to") from None
raise BusyResourceError("writing to")

loop = get_running_loop()
event = write_events[obj] = asyncio.Event()
loop.add_writer(obj, event.set)
event = asyncio.Event()
try:
loop.add_writer(obj, event.set)
except NotImplementedError:
from anyio._core._asyncio_selector_thread import get_selector

selector = get_selector()
selector.add_writer(obj, event.set)
remove_writer = selector.remove_writer
else:
remove_writer = loop.remove_writer

write_events[obj] = event
try:
await event.wait()
finally:
if write_events.pop(obj, None) is not None:
loop.remove_writer(obj)
writable = True
else:
writable = False

if not writable:
raise ClosedResourceError
del write_events[obj]
remove_writer(obj)

@classmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
Expand Down
150 changes: 150 additions & 0 deletions src/anyio/_core/_asyncio_selector_thread.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
from __future__ import annotations

import asyncio
import socket
import threading
from collections.abc import Callable
from selectors import EVENT_READ, EVENT_WRITE, DefaultSelector
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from _typeshed import FileDescriptorLike

_selector_lock = threading.Lock()
_selector: Selector | None = None


class Selector:
def __init__(self) -> None:
self._thread = threading.Thread(target=self.run, name="AnyIO socket selector")
self._selector = DefaultSelector()
self._send, self._receive = socket.socketpair()
self._send.setblocking(False)
self._receive.setblocking(False)
self._selector.register(self._receive, EVENT_READ)
self._closed = False

def start(self) -> None:
self._thread.start()
threading._register_atexit(self._stop) # type: ignore[attr-defined]

def _stop(self) -> None:
global _selector
self._closed = True
self._notify_self()
self._send.close()
self._thread.join()
self._selector.unregister(self._receive)
self._receive.close()
self._selector.close()
_selector = None
assert (
not self._selector.get_map()
), "selector still has registered file descriptors after shutdown"

def _notify_self(self) -> None:
try:
self._send.send(b"\x00")
except BlockingIOError:
pass

def add_reader(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
loop = asyncio.get_running_loop()
try:
key = self._selector.get_key(fd)
except KeyError:
self._selector.register(fd, EVENT_READ, {EVENT_READ: (loop, callback)})
else:
if EVENT_READ in key.data:
raise ValueError(
"this file descriptor is already registered for reading"
)

key.data[EVENT_READ] = loop, callback
self._selector.modify(fd, key.events | EVENT_READ, key.data)

self._notify_self()

def add_writer(self, fd: FileDescriptorLike, callback: Callable[[], Any]) -> None:
loop = asyncio.get_running_loop()
try:
key = self._selector.get_key(fd)
except KeyError:
self._selector.register(fd, EVENT_WRITE, {EVENT_WRITE: (loop, callback)})
else:
if EVENT_WRITE in key.data:
raise ValueError(
"this file descriptor is already registered for writing"
)

key.data[EVENT_WRITE] = loop, callback
self._selector.modify(fd, key.events | EVENT_WRITE, key.data)

self._notify_self()

def remove_reader(self, fd: FileDescriptorLike) -> bool:
try:
key = self._selector.get_key(fd)
except KeyError:
return False

if new_events := key.events ^ EVENT_READ:
del key.data[EVENT_READ]
self._selector.modify(fd, new_events, key.data)
else:
self._selector.unregister(fd)

return True

def remove_writer(self, fd: FileDescriptorLike) -> bool:
try:
key = self._selector.get_key(fd)
except KeyError:
return False

if new_events := key.events ^ EVENT_WRITE:
del key.data[EVENT_WRITE]
self._selector.modify(fd, new_events, key.data)
else:
self._selector.unregister(fd)

return True

def run(self) -> None:
while not self._closed:
for key, events in self._selector.select():
if key.fileobj is self._receive:
try:
while self._receive.recv(4096):
pass
except BlockingIOError:
pass

continue

if events & EVENT_READ:
loop, callback = key.data[EVENT_READ]
self.remove_reader(key.fd)
try:
loop.call_soon_threadsafe(callback)
except RuntimeError:
pass # the loop was already closed

if events & EVENT_WRITE:
loop, callback = key.data[EVENT_WRITE]
self.remove_writer(key.fd)
try:
loop.call_soon_threadsafe(callback)
except RuntimeError:
pass # the loop was already closed


def get_selector() -> Selector:
global _selector

with _selector_lock:
if _selector is None:
_selector = Selector()
_selector.start()

return _selector
33 changes: 14 additions & 19 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
from ._tasks import create_task_group, move_on_after

if TYPE_CHECKING:
from _typeshed import HasFileno
from _typeshed import FileDescriptorLike
else:
HasFileno = object
FileDescriptorLike = object

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup
Expand Down Expand Up @@ -609,9 +609,6 @@ def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
Wait until the given socket has data to be read.
This does **NOT** work on Windows when using the asyncio backend with a proactor
event loop (default on py3.8+).
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!
Expand Down Expand Up @@ -649,7 +646,7 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
return get_async_backend().wait_writable(sock.fileno())


def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
def wait_readable(obj: FileDescriptorLike) -> Awaitable[None]:
"""
Wait until the given object has data to be read.
Expand All @@ -663,10 +660,11 @@ def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
descriptors aren't supported, and neither are handles that refer to anything besides
a ``SOCKET``.
This does **NOT** work on Windows when using the asyncio backend with a proactor
event loop (default on py3.8+).
On backends where this functionality is not natively provided (asyncio
``ProactorEventLoop`` on Windows), it is provided using a separate selector thread
which is set to shut down when the interpreter shuts down.
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
.. warning:: Don't use this on raw sockets that have been wrapped by any higher
level constructs like socket streams!
:param obj: an object with a ``.fileno()`` method or an integer handle
Expand All @@ -679,25 +677,22 @@ def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
return get_async_backend().wait_readable(obj)


def wait_writable(obj: HasFileno | int) -> Awaitable[None]:
def wait_writable(obj: FileDescriptorLike) -> Awaitable[None]:
"""
Wait until the given object can be written to.
This does **NOT** work on Windows when using the asyncio backend with a proactor
event loop (default on py3.8+).
.. seealso:: See the documentation of :func:`wait_readable` for the definition of
``obj``.
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!
:param obj: an object with a ``.fileno()`` method or an integer handle
:raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
object to become writable
:raises ~anyio.BusyResourceError: if another task is already waiting for the object
to become writable
.. seealso:: See the documentation of :func:`wait_readable` for the definition of
``obj`` and notes on backend compatibility.
.. warning:: Don't use this on raw sockets that have been wrapped by any higher
level constructs like socket streams!
"""
return get_async_backend().wait_writable(obj)

Expand Down
26 changes: 6 additions & 20 deletions tests/test_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from exceptiongroup import ExceptionGroup

if TYPE_CHECKING:
from _typeshed import HasFileno
from _typeshed import FileDescriptorLike

AnyIPAddressFamily = Literal[
AddressFamily.AF_UNSPEC, AddressFamily.AF_INET, AddressFamily.AF_INET6
Expand Down Expand Up @@ -1858,16 +1858,7 @@ async def test_connect_tcp_getaddrinfo_context() -> None:

@pytest.mark.parametrize("socket_type", ["socket", "fd"])
@pytest.mark.parametrize("event", ["readable", "writable"])
async def test_wait_socket(
anyio_backend_name: str, event: str, socket_type: str
) -> None:
if anyio_backend_name == "asyncio" and platform.system() == "Windows":
import asyncio

policy = asyncio.get_event_loop_policy()
if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy":
pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop")

async def test_wait_socket(event: str, socket_type: str) -> None:
wait = wait_readable if event == "readable" else wait_writable

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as server_sock:
Expand All @@ -1880,20 +1871,15 @@ async def test_wait_socket(

conn, addr = server_sock.accept()
with conn:
sock_or_fd: HasFileno | int = conn.fileno() if socket_type == "fd" else conn
with fail_after(10):
sock_or_fd: FileDescriptorLike = (
conn.fileno() if socket_type == "fd" else conn
)
with fail_after(3):
await wait(sock_or_fd)
assert conn.recv(1024) == b"Hello, world"


async def test_deprecated_wait_socket(anyio_backend_name: str) -> None:
if anyio_backend_name == "asyncio" and platform.system() == "Windows":
import asyncio

policy = asyncio.get_event_loop_policy()
if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy":
pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop")

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
with pytest.warns(
DeprecationWarning,
Expand Down

0 comments on commit 0f80611

Please sign in to comment.