diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index d1df910d..d112e3a9 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -11,6 +11,7 @@ This library adheres to `Semantic Versioning 2.0 `_.
- Changed asyncio task groups so that if the host and child tasks have only raised
``CancelledErrors``, just one ``CancelledError`` will now be raised instead of an
``ExceptionGroup``, allowing asyncio to ignore it when it propagates out of the task
+- Changed task names to be converted to ``str`` early on asyncio (PR by Thomas Grainger)
- Fixed ``sniffio._impl.AsyncLibraryNotFoundError: unknown async library, or not in async context``
on asyncio and Python 3.6 when ``to_thread.run_sync()`` is used from
``loop.run_until_complete()``
@@ -20,6 +21,17 @@ This library adheres to `Semantic Versioning 2.0 `_.
task is cancelled (PR by Thomas Grainger)
- Fixed declared return type of ``TaskGroup.start()`` (it was declared as ``None``, but anything
can be returned from it)
+- Fixed ``TextStream.extra_attributes`` raising ``AttributeError`` (PR by Thomas Grainger)
+- Fixed ``await maybe_async(current_task())`` returning ``None`` (PR by Thomas Grainger)
+- Fixed: ``pickle.dumps(current_task())`` now correctly raises ``TypeError`` instead of pickling to
+ ``None`` (PR by Thomas Grainger)
+- Fixed return type annotation of ``Event.wait()`` (``bool`` → ``None``) (PR by Thomas Grainger)
+- Fixed return type annotation of ``RunVar.get()`` to return either the type of the default value
+ or the type of the contained value (PR by Thomas Grainger)
+- Fixed a deprecation warning message to refer to ``maybe_async()`` and not ``maybe_awaitable()``
+ (PR by Thomas Grainger)
+- Filled in argument and return types for all functions and methods previously missing them
+ (PR by Thomas Grainger)
**3.0.1**
diff --git a/setup.cfg b/setup.cfg
index 353182af..931b930b 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -60,3 +60,4 @@ pytest11 =
[mypy]
ignore_missing_imports = true
+disallow_untyped_defs = true
diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py
index 5bcc6514..c8545e3f 100644
--- a/src/anyio/_backends/_asyncio.py
+++ b/src/anyio/_backends/_asyncio.py
@@ -100,8 +100,8 @@ def _cancel_all_tasks(loop):
events.set_event_loop(None)
loop.close()
- def create_task(coro: Union[Generator[Any, None, _T], Awaitable[_T]], *, # type: ignore
- name: Optional[str] = None) -> asyncio.Task:
+ def create_task(coro: Union[Generator[Any, None, _T], Awaitable[_T]], *,
+ name: object = None) -> asyncio.Task:
return get_running_loop().create_task(coro)
def get_running_loop() -> asyncio.AbstractEventLoop:
@@ -213,11 +213,12 @@ def _maybe_set_event_loop_policy(policy: Optional[asyncio.AbstractEventLoopPolic
asyncio.set_event_loop_policy(policy)
-def run(func: Callable[..., T_Retval], *args, debug: bool = False, use_uvloop: bool = True,
+def run(func: Callable[..., Awaitable[T_Retval]], *args: object,
+ debug: bool = False, use_uvloop: bool = True,
policy: Optional[asyncio.AbstractEventLoopPolicy] = None) -> T_Retval:
@wraps(func)
- async def wrapper():
- task = current_task()
+ async def wrapper() -> T_Retval:
+ task = cast(asyncio.Task, current_task())
task_state = TaskState(None, get_callable_name(func), None)
_task_states[task] = task_state
if _native_task_names:
@@ -247,7 +248,7 @@ async def wrapper():
class CancelScope(BaseCancelScope):
- def __new__(cls, *, deadline: float = math.inf, shield: bool = False):
+ def __new__(cls, *, deadline: float = math.inf, shield: bool = False) -> "CancelScope":
return object.__new__(cls)
def __init__(self, deadline: float = math.inf, shield: bool = False):
@@ -262,20 +263,20 @@ def __init__(self, deadline: float = math.inf, shield: bool = False):
self._host_task: Optional[asyncio.Task] = None
self._timeout_expired = False
- def __enter__(self):
+ def __enter__(self) -> "CancelScope":
if self._active:
raise RuntimeError(
"Each CancelScope may only be used for a single 'with' block"
)
- self._host_task = current_task()
- self._tasks.add(self._host_task)
+ self._host_task = host_task = cast(asyncio.Task, current_task())
+ self._tasks.add(host_task)
try:
- task_state = _task_states[self._host_task]
+ task_state = _task_states[host_task]
except KeyError:
- task_name = self._host_task.get_name() if _native_task_names else None
+ task_name = host_task.get_name() if _native_task_names else None
task_state = TaskState(None, task_name, self)
- _task_states[self._host_task] = task_state
+ _task_states[host_task] = task_state
else:
self._parent_scope = task_state.cancel_scope
task_state.cancel_scope = self
@@ -326,7 +327,7 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Ba
return None
- def _timeout(self):
+ def _timeout(self) -> None:
if self._deadline != math.inf:
loop = get_running_loop()
if loop.time() >= self._deadline:
@@ -460,9 +461,9 @@ async def cancel_shielded_checkpoint() -> None:
await sleep(0)
-def current_effective_deadline():
+def current_effective_deadline() -> float:
try:
- cancel_scope = _task_states[current_task()].cancel_scope
+ cancel_scope = _task_states[current_task()].cancel_scope # type: ignore[index]
except KeyError:
return math.inf
@@ -477,7 +478,7 @@ def current_effective_deadline():
return deadline
-def current_time():
+def current_time() -> float:
return get_running_loop().time()
@@ -517,17 +518,17 @@ class _AsyncioTaskStatus(abc.TaskStatus):
def __init__(self, future: asyncio.Future):
self._future = future
- def started(self, value=None) -> None:
+ def started(self, value: object = None) -> None:
self._future.set_result(value)
class TaskGroup(abc.TaskGroup):
- def __init__(self):
+ def __init__(self) -> None:
self.cancel_scope: CancelScope = CancelScope()
self._active = False
self._exceptions: List[BaseException] = []
- async def __aenter__(self):
+ async def __aenter__(self) -> "TaskGroup":
self.cancel_scope.__enter__()
self._active = True
return self
@@ -613,7 +614,7 @@ async def _run_wrapped_task(
self.cancel_scope._tasks.remove(task)
del _task_states[task]
- def _spawn(self, func: Callable[..., Coroutine], args: tuple, name,
+ def _spawn(self, func: Callable[..., Coroutine], args: tuple, name: object,
task_status_future: Optional[asyncio.Future] = None) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
# This is the code path for Python 3.8+
@@ -643,7 +644,7 @@ def task_done(_task: asyncio.Task) -> None:
raise RuntimeError('This task group is not active; no new tasks can be started.')
options = {}
- name = name or get_callable_name(func)
+ name = get_callable_name(func) if name is None else str(name)
if _native_task_names:
options['name'] = name
@@ -669,10 +670,12 @@ def task_done(_task: asyncio.Task) -> None:
self.cancel_scope._tasks.add(task)
return task
- def start_soon(self, func: Callable[..., Coroutine], *args, name=None) -> None:
+ def start_soon(self, func: Callable[..., Coroutine], *args: object,
+ name: object = None) -> None:
self._spawn(func, args, name)
- async def start(self, func: Callable[..., Coroutine], *args, name=None) -> None:
+ async def start(self, func: Callable[..., Coroutine], *args: object,
+ name: object = None) -> None:
future: asyncio.Future = asyncio.Future()
task = self._spawn(func, args, name, future)
@@ -745,7 +748,7 @@ def stop(self, f: Optional[asyncio.Task] = None) -> None:
async def run_sync_in_worker_thread(
- func: Callable[..., T_Retval], *args, cancellable: bool = False,
+ func: Callable[..., T_Retval], *args: object, cancellable: bool = False,
limiter: Optional['CapacityLimiter'] = None) -> T_Retval:
await checkpoint()
@@ -789,10 +792,10 @@ async def run_sync_in_worker_thread(
idle_workers.append(worker)
-def run_sync_from_thread(func: Callable[..., T_Retval], *args,
+def run_sync_from_thread(func: Callable[..., T_Retval], *args: object,
loop: Optional[asyncio.AbstractEventLoop] = None) -> T_Retval:
@wraps(func)
- def wrapper():
+ def wrapper() -> None:
try:
f.set_result(func(*args))
except BaseException as exc:
@@ -806,22 +809,24 @@ def wrapper():
return f.result()
-def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval:
+def run_async_from_thread(
+ func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object
+) -> T_Retval:
f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe(
func(*args), threadlocals.loop)
return f.result()
class BlockingPortal(abc.BlockingPortal):
- def __new__(cls):
+ def __new__(cls) -> "BlockingPortal":
return object.__new__(cls)
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self._loop = get_running_loop()
def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any],
- name, future: Future) -> None:
+ name: object, future: Future) -> None:
run_sync_from_thread(
partial(self._task_group.start_soon, name=name), self._call_func, func, args, kwargs,
future, loop=self._loop)
@@ -908,13 +913,16 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]:
return self._stderr
-async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int,
+async def open_process(command: Union[str, Sequence[str]], *, shell: bool,
+ stdin: int, stdout: int, stderr: int,
cwd: Union[str, bytes, PathLike, None] = None,
env: Optional[Mapping[str, str]] = None) -> Process:
await checkpoint()
if shell:
- process = await asyncio.create_subprocess_shell(command, stdin=stdin, stdout=stdout,
- stderr=stderr, cwd=cwd, env=env)
+ process = await asyncio.create_subprocess_shell(
+ command, stdin=stdin, stdout=stdout, # type: ignore[arg-type]
+ stderr=stderr, cwd=cwd, env=env,
+ )
else:
process = await asyncio.create_subprocess_exec(*command, stdin=stdin, stdout=stdout,
stderr=stderr, cwd=cwd, env=env)
@@ -925,7 +933,7 @@ async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr:
return Process(process, stdin_stream, stdout_stream, stderr_stream)
-def _forcibly_shutdown_process_pool_on_exit(workers: Set[Process], _task) -> None:
+def _forcibly_shutdown_process_pool_on_exit(workers: Set[Process], _task: object) -> None:
"""
Forcibly shuts down worker processes belonging to this event loop."""
child_watcher: Optional[asyncio.AbstractChildWatcher]
@@ -1142,7 +1150,7 @@ def _raw_socket(self) -> SocketType:
return self.__raw_socket
def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
- def callback(f):
+ def callback(f: object) -> None:
del self._receive_future
loop.remove_reader(self.__raw_socket)
@@ -1152,7 +1160,7 @@ def callback(f):
return f
def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
- def callback(f):
+ def callback(f: object) -> None:
del self._send_future
loop.remove_writer(self.__raw_socket)
@@ -1594,10 +1602,10 @@ async def wait_socket_writable(sock: socket.SocketType) -> None:
#
class Event(BaseEvent):
- def __new__(cls):
+ def __new__(cls) -> "Event":
return object.__new__(cls)
- def __init__(self):
+ def __init__(self) -> None:
self._event = asyncio.Event()
def set(self) -> DeprecatedAwaitable:
@@ -1607,18 +1615,18 @@ def set(self) -> DeprecatedAwaitable:
def is_set(self) -> bool:
return self._event.is_set()
- async def wait(self):
+ async def wait(self) -> None:
if await self._event.wait():
await checkpoint()
def statistics(self) -> EventStatistics:
- return EventStatistics(len(self._event._waiters))
+ return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined]
class CapacityLimiter(BaseCapacityLimiter):
_total_tokens: float = 0
- def __new__(cls, total_tokens: float):
+ def __new__(cls, total_tokens: float) -> "CapacityLimiter":
return object.__new__(cls)
def __init__(self, total_tokens: float):
@@ -1626,7 +1634,7 @@ def __init__(self, total_tokens: float):
self._wait_queue: Dict[Any, asyncio.Event] = OrderedDict()
self.total_tokens = total_tokens
- async def __aenter__(self):
+ async def __aenter__(self) -> None:
await self.acquire()
async def __aexit__(self, exc_type: Optional[Type[BaseException]],
@@ -1671,7 +1679,7 @@ def acquire_nowait(self) -> DeprecatedAwaitable:
self.acquire_on_behalf_of_nowait(current_task())
return DeprecatedAwaitable(self.acquire_nowait)
- def acquire_on_behalf_of_nowait(self, borrower) -> DeprecatedAwaitable:
+ def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
if borrower in self._borrowers:
raise RuntimeError("this borrower is already holding one of this CapacityLimiter's "
"tokens")
@@ -1685,7 +1693,7 @@ def acquire_on_behalf_of_nowait(self, borrower) -> DeprecatedAwaitable:
async def acquire(self) -> None:
return await self.acquire_on_behalf_of(current_task())
- async def acquire_on_behalf_of(self, borrower) -> None:
+ async def acquire_on_behalf_of(self, borrower: object) -> None:
await checkpoint_if_cancelled()
try:
self.acquire_on_behalf_of_nowait(borrower)
@@ -1705,7 +1713,7 @@ async def acquire_on_behalf_of(self, borrower) -> None:
def release(self) -> None:
self.release_on_behalf_of(current_task())
- def release_on_behalf_of(self, borrower) -> None:
+ def release_on_behalf_of(self, borrower: object) -> None:
try:
self._borrowers.remove(borrower)
except KeyError:
@@ -1725,7 +1733,7 @@ def statistics(self) -> CapacityLimiterStatistics:
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar('_default_thread_limiter')
-def current_default_thread_limiter():
+def current_default_thread_limiter() -> CapacityLimiter:
try:
return _default_thread_limiter.get()
except LookupError:
@@ -1751,18 +1759,21 @@ def _deliver(self, signum: int) -> None:
if not self._future.done():
self._future.set_result(None)
- def __enter__(self):
+ def __enter__(self) -> "_SignalReceiver":
for sig in set(self._signals):
self._loop.add_signal_handler(sig, self._deliver, sig)
self._handled_signals.add(sig)
return self
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(self, exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType]) -> Optional[bool]:
for sig in self._handled_signals:
self._loop.remove_signal_handler(sig)
+ return None
- def __aiter__(self):
+ def __aiter__(self) -> "_SignalReceiver":
return self
async def __anext__(self) -> int:
@@ -1825,7 +1836,7 @@ def __init__(self, debug: bool = False, use_uvloop: bool = True,
self._loop.set_debug(debug)
asyncio.set_event_loop(self._loop)
- def _cancel_all_tasks(self):
+ def _cancel_all_tasks(self) -> None:
to_cancel = all_tasks(self._loop)
if not to_cancel:
return
@@ -1840,7 +1851,7 @@ def _cancel_all_tasks(self):
if task.cancelled():
continue
if task.exception() is not None:
- raise task.exception()
+ raise cast(BaseException, task.exception())
def close(self) -> None:
try:
@@ -1850,16 +1861,17 @@ def close(self) -> None:
asyncio.set_event_loop(None)
self._loop.close()
- def call(self, func: Callable[..., Awaitable], *args, **kwargs):
+ def call(self, func: Callable[..., Awaitable[T_Retval]],
+ *args: object, **kwargs: object) -> T_Retval:
def exception_handler(loop: asyncio.AbstractEventLoop, context: Dict[str, Any]) -> None:
exceptions.append(context['exception'])
exceptions: List[Exception] = []
self._loop.set_exception_handler(exception_handler)
try:
- retval = self._loop.run_until_complete(func(*args, **kwargs))
+ retval: T_Retval = self._loop.run_until_complete(func(*args, **kwargs))
except Exception as exc:
- retval = None
+ retval = None # type: ignore[assignment]
exceptions.append(exc)
finally:
self._loop.set_exception_handler(None)
diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py
index 48be7a51..42764907 100644
--- a/src/anyio/_backends/_trio.py
+++ b/src/anyio/_backends/_trio.py
@@ -8,11 +8,12 @@
from os import PathLike
from types import TracebackType
from typing import (
- Any, Awaitable, Callable, Collection, Coroutine, Dict, Generic, List, Mapping, NoReturn,
- Optional, Set, Tuple, Type, TypeVar, Union)
+ Any, Awaitable, Callable, Collection, ContextManager, Coroutine, Deque, Dict, Generic, List,
+ Mapping, NoReturn, Optional, Sequence, Set, Tuple, Type, TypeVar, Union)
import trio.from_thread
-from outcome import Error, Value
+from outcome import Error, Outcome, Value
+from trio.socket import SocketType as TrioSocketType
from trio.to_thread import run_sync
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
@@ -36,6 +37,7 @@
else:
from trio.lowlevel import wait_readable, wait_writable
+
T_Retval = TypeVar('T_Retval')
T_SockAddr = TypeVar('T_SockAddr', str, IPSockAddrType)
@@ -61,17 +63,20 @@
#
class CancelScope(BaseCancelScope):
- def __new__(cls, original: Optional[trio.CancelScope] = None, **kwargs):
+ def __new__(cls, original: Optional[trio.CancelScope] = None,
+ **kwargs: object) -> 'CancelScope':
return object.__new__(cls)
- def __init__(self, original: Optional[trio.CancelScope] = None, **kwargs):
+ def __init__(self, original: Optional[trio.CancelScope] = None, **kwargs: object) -> None:
self.__original = original or trio.CancelScope(**kwargs)
- def __enter__(self):
+ def __enter__(self) -> 'CancelScope':
self.__original.__enter__()
return self
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(self, exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType]) -> Optional[bool]:
return self.__original.__exit__(exc_type, exc_val, exc_tb)
def cancel(self) -> DeprecatedAwaitable:
@@ -116,12 +121,12 @@ class ExceptionGroup(BaseExceptionGroup, trio.MultiError):
class TaskGroup(abc.TaskGroup):
- def __init__(self):
+ def __init__(self) -> None:
self._active = False
self._nursery_manager = trio.open_nursery()
- self.cancel_scope = None
+ self.cancel_scope = None # type: ignore[assignment]
- async def __aenter__(self):
+ async def __aenter__(self) -> 'TaskGroup':
self._active = True
self._nursery = await self._nursery_manager.__aenter__()
self.cancel_scope = CancelScope(self._nursery.cancel_scope)
@@ -137,13 +142,14 @@ async def __aexit__(self, exc_type: Optional[Type[BaseException]],
finally:
self._active = False
- def start_soon(self, func: Callable, *args, name=None) -> None:
+ def start_soon(self, func: Callable, *args: object, name: object = None) -> None:
if not self._active:
raise RuntimeError('This task group is not active; no new tasks can be started.')
self._nursery.start_soon(func, *args, name=name)
- async def start(self, func: Callable[..., Coroutine], *args, name=None):
+ async def start(self, func: Callable[..., Coroutine],
+ *args: object, name: object = None) -> object:
if not self._active:
raise RuntimeError('This task group is not active; no new tasks can be started.')
@@ -155,9 +161,9 @@ async def start(self, func: Callable[..., Coroutine], *args, name=None):
async def run_sync_in_worker_thread(
- func: Callable[..., T_Retval], *args, cancellable: bool = False,
+ func: Callable[..., T_Retval], *args: object, cancellable: bool = False,
limiter: Optional[trio.CapacityLimiter] = None) -> T_Retval:
- def wrapper():
+ def wrapper() -> T_Retval:
with claim_worker_thread('trio'):
return func(*args)
@@ -168,15 +174,15 @@ def wrapper():
class BlockingPortal(abc.BlockingPortal):
- def __new__(cls):
+ def __new__(cls) -> 'BlockingPortal':
return object.__new__(cls)
- def __init__(self):
+ def __init__(self) -> None:
super().__init__()
self._token = trio.lowlevel.current_trio_token()
def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any],
- name, future: Future) -> None:
+ name: object, future: Future) -> None:
return trio.from_thread.run_sync(
partial(self._task_group.start_soon, name=name), self._call_func, func, args, kwargs,
future, trio_token=self._token)
@@ -273,7 +279,8 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]:
return self._stderr
-async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int,
+async def open_process(command: Union[str, Sequence[str]], *, shell: bool,
+ stdin: int, stdout: int, stderr: int,
cwd: Union[str, bytes, PathLike, None] = None,
env: Optional[Mapping[str, str]] = None) -> Process:
process = await trio.open_process(command, stdin=stdin, stdout=stdout, stderr=stderr,
@@ -285,7 +292,7 @@ async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr:
class _ProcessPoolShutdownInstrument(trio.abc.Instrument):
- def after_run(self):
+ def after_run(self) -> None:
super().after_run()
@@ -316,7 +323,7 @@ def setup_process_pool_exit_at_shutdown(workers: Set[Process]) -> None:
#
class _TrioSocketMixin(Generic[T_SockAddr]):
- def __init__(self, trio_socket):
+ def __init__(self, trio_socket: TrioSocketType) -> None:
self._trio_socket = trio_socket
self._closed = False
@@ -347,7 +354,7 @@ def _convert_socket_error(self, exc: BaseException) -> 'NoReturn':
class SocketStream(_TrioSocketMixin, abc.SocketStream):
- def __init__(self, trio_socket):
+ def __init__(self, trio_socket: TrioSocketType) -> None:
super().__init__(trio_socket)
self._receive_guard = ResourceGuard('reading from')
self._send_guard = ResourceGuard('writing to')
@@ -467,7 +474,7 @@ async def accept(self) -> UNIXSocketStream:
class UDPSocket(_TrioSocketMixin[IPSockAddrType], abc.UDPSocket):
- def __init__(self, trio_socket):
+ def __init__(self, trio_socket: TrioSocketType) -> None:
super().__init__(trio_socket)
self._receive_guard = ResourceGuard('reading from')
self._send_guard = ResourceGuard('writing to')
@@ -489,7 +496,7 @@ async def send(self, item: UDPPacketType) -> None:
class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocket):
- def __init__(self, trio_socket):
+ def __init__(self, trio_socket: TrioSocketType) -> None:
super().__init__(trio_socket)
self._receive_guard = ResourceGuard('reading from')
self._send_guard = ResourceGuard('writing to')
@@ -562,7 +569,7 @@ async def create_udp_socket(
getnameinfo = trio.socket.getnameinfo
-async def wait_socket_readable(sock):
+async def wait_socket_readable(sock: socket.SocketType) -> None:
try:
await wait_readable(sock)
except trio.ClosedResourceError as exc:
@@ -571,7 +578,7 @@ async def wait_socket_readable(sock):
raise BusyResourceError('reading from') from None
-async def wait_socket_writable(sock):
+async def wait_socket_writable(sock: socket.SocketType) -> None:
try:
await wait_writable(sock)
except trio.ClosedResourceError as exc:
@@ -585,34 +592,34 @@ async def wait_socket_writable(sock):
#
class Event(BaseEvent):
- def __new__(cls):
+ def __new__(cls) -> 'Event':
return object.__new__(cls)
- def __init__(self):
+ def __init__(self) -> None:
self.__original = trio.Event()
def is_set(self) -> bool:
return self.__original.is_set()
- async def wait(self) -> bool:
+ async def wait(self) -> None:
return await self.__original.wait()
def statistics(self) -> EventStatistics:
return self.__original.statistics()
- def set(self):
+ def set(self) -> DeprecatedAwaitable:
self.__original.set()
return DeprecatedAwaitable(self.set)
class CapacityLimiter(BaseCapacityLimiter):
- def __new__(cls, *args, **kwargs):
+ def __new__(cls, *args: object, **kwargs: object) -> "CapacityLimiter":
return object.__new__(cls)
- def __init__(self, *args, original: Optional[trio.CapacityLimiter] = None):
+ def __init__(self, *args: object, original: Optional[trio.CapacityLimiter] = None) -> None:
self.__original = original or trio.CapacityLimiter(*args)
- async def __aenter__(self):
+ async def __aenter__(self) -> None:
return await self.__original.__aenter__()
async def __aexit__(self, exc_type: Optional[Type[BaseException]],
@@ -636,24 +643,24 @@ def borrowed_tokens(self) -> int:
def available_tokens(self) -> float:
return self.__original.available_tokens
- def acquire_nowait(self):
+ def acquire_nowait(self) -> DeprecatedAwaitable:
self.__original.acquire_nowait()
return DeprecatedAwaitable(self.acquire_nowait)
- def acquire_on_behalf_of_nowait(self, borrower):
+ def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
self.__original.acquire_on_behalf_of_nowait(borrower)
return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait)
async def acquire(self) -> None:
await self.__original.acquire()
- async def acquire_on_behalf_of(self, borrower) -> None:
+ async def acquire_on_behalf_of(self, borrower: object) -> None:
await self.__original.acquire_on_behalf_of(borrower)
def release(self) -> None:
return self.__original.release()
- def release_on_behalf_of(self, borrower) -> None:
+ def release_on_behalf_of(self, borrower: object) -> None:
return self.__original.release_on_behalf_of(borrower)
def statistics(self) -> CapacityLimiterStatistics:
@@ -677,17 +684,19 @@ def current_default_thread_limiter() -> CapacityLimiter:
#
class _SignalReceiver(DeprecatedAsyncContextManager):
- def __init__(self, cm):
+ def __init__(self, cm: ContextManager[T]):
self._cm = cm
def __enter__(self) -> T:
return self._cm.__enter__()
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(self, exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType]) -> Optional[bool]:
return self._cm.__exit__(exc_type, exc_val, exc_tb)
-def open_signal_receiver(*signals: int):
+def open_signal_receiver(*signals: int) -> _SignalReceiver:
cm = trio.open_signal_receiver(*signals)
return _SignalReceiver(cm)
@@ -723,18 +732,18 @@ def get_running_tasks() -> List[TaskInfo]:
return task_infos
-def wait_all_tasks_blocked():
+def wait_all_tasks_blocked() -> Awaitable[None]:
import trio.testing
return trio.testing.wait_all_tasks_blocked()
class TestRunner(abc.TestRunner):
- def __init__(self, **options):
+ def __init__(self, **options: object) -> None:
from collections import deque
from queue import Queue
- self._call_queue = Queue()
- self._result_queue = deque()
+ self._call_queue: "Queue[Callable[..., object]]" = Queue()
+ self._result_queue: Deque[Outcome] = deque()
self._stop_event: Optional[trio.Event] = None
self._nursery: Optional[trio.Nursery] = None
self._options = options
@@ -744,7 +753,8 @@ async def _trio_main(self) -> None:
async with trio.open_nursery() as self._nursery:
await self._stop_event.wait()
- async def _call_func(self, func, args, kwargs):
+ async def _call_func(self, func: Callable[..., Awaitable[object]],
+ args: tuple, kwargs: dict) -> None:
try:
retval = await func(*args, **kwargs)
except BaseException as exc:
@@ -752,7 +762,7 @@ async def _call_func(self, func, args, kwargs):
else:
self._result_queue.append(Value(retval))
- def _main_task_finished(self, outcome) -> None:
+ def _main_task_finished(self, outcome: object) -> None:
self._nursery = None
def close(self) -> None:
@@ -761,7 +771,8 @@ def close(self) -> None:
while self._nursery is not None:
self._call_queue.get()()
- def call(self, func: Callable[..., Awaitable], *args, **kwargs):
+ def call(self, func: Callable[..., Awaitable[T_Retval]],
+ *args: object, **kwargs: object) -> T_Retval:
if self._nursery is None:
trio.lowlevel.start_guest_run(
self._trio_main, run_sync_soon_threadsafe=self._call_queue.put,
diff --git a/src/anyio/_core/_compat.py b/src/anyio/_core/_compat.py
index 9f44ea95..0ebfeb2d 100644
--- a/src/anyio/_core/_compat.py
+++ b/src/anyio/_core/_compat.py
@@ -1,13 +1,24 @@
from abc import ABCMeta, abstractmethod
from contextlib import AbstractContextManager
+from types import TracebackType
from typing import (
- AsyncContextManager, Callable, ContextManager, Generic, List, Optional, TypeVar, Union,
- overload)
+ TYPE_CHECKING, AsyncContextManager, Callable, ContextManager, Generic, Iterable, List,
+ Optional, Tuple, Type, TypeVar, Union, overload)
from warnings import warn
+if TYPE_CHECKING:
+ from ._testing import TaskInfo
+else:
+ TaskInfo = object
+
T = TypeVar('T')
AnyDeprecatedAwaitable = Union['DeprecatedAwaitable', 'DeprecatedAwaitableFloat',
- 'DeprecatedAwaitableList']
+ 'DeprecatedAwaitableList', TaskInfo]
+
+
+@overload
+async def maybe_async(__obj: TaskInfo) -> TaskInfo:
+ ...
@overload
@@ -16,7 +27,7 @@ async def maybe_async(__obj: 'DeprecatedAwaitableFloat') -> float:
@overload
-async def maybe_async(__obj: 'DeprecatedAwaitableList') -> list:
+async def maybe_async(__obj: 'DeprecatedAwaitableList[T]') -> List[T]:
...
@@ -25,7 +36,7 @@ async def maybe_async(__obj: 'DeprecatedAwaitable') -> None:
...
-async def maybe_async(__obj: AnyDeprecatedAwaitable) -> Union[float, list, None]:
+async def maybe_async(__obj: AnyDeprecatedAwaitable) -> Union[TaskInfo, float, list, None]:
"""
Await on the given object if necessary.
@@ -49,7 +60,9 @@ def __init__(self, cm: ContextManager[T]):
async def __aenter__(self) -> T:
return self._cm.__enter__()
- async def __aexit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]:
+ async def __aexit__(self, exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType]) -> Optional[bool]:
return self._cm.__exit__(exc_type, exc_val, exc_tb)
@@ -74,7 +87,7 @@ def maybe_async_cm(cm: Union[ContextManager[T], AsyncContextManager[T]]) -> Asyn
def _warn_deprecation(awaitable: AnyDeprecatedAwaitable, stacklevel: int = 1) -> None:
warn(f'Awaiting on {awaitable._name}() is deprecated. Use "await '
- f'anyio.maybe_awaitable({awaitable._name}(...)) if you have to support both AnyIO 2.x '
+ f'anyio.maybe_async({awaitable._name}(...)) if you have to support both AnyIO 2.x '
f'and 3.x, or just remove the "await" if you are completely migrating to AnyIO 3+.',
DeprecationWarning, stacklevel=stacklevel + 1)
@@ -83,33 +96,35 @@ class DeprecatedAwaitable:
def __init__(self, func: Callable[..., 'DeprecatedAwaitable']):
self._name = f'{func.__module__}.{func.__qualname__}'
- def __await__(self):
+ def __await__(self) -> Iterable[None]:
_warn_deprecation(self)
if False:
yield
- def __reduce__(self):
+ def __reduce__(self) -> Tuple[Type[None], Tuple]:
return type(None), ()
- def _unwrap(self):
+ def _unwrap(self) -> None:
return None
class DeprecatedAwaitableFloat(float):
- def __new__(cls, x, func):
+ def __new__(
+ cls, x: float, func: Callable[..., 'DeprecatedAwaitableFloat']
+ ) -> 'DeprecatedAwaitableFloat':
return super().__new__(cls, x)
def __init__(self, x: float, func: Callable[..., 'DeprecatedAwaitableFloat']):
self._name = f'{func.__module__}.{func.__qualname__}'
- def __await__(self):
+ def __await__(self) -> Iterable[float]:
_warn_deprecation(self)
if False:
yield
return float(self)
- def __reduce__(self):
+ def __reduce__(self) -> Tuple[Type[float], Tuple[float]]:
return float, (float(self),)
def _unwrap(self) -> float:
@@ -117,18 +132,18 @@ def _unwrap(self) -> float:
class DeprecatedAwaitableList(List[T]):
- def __init__(self, *args, func: Callable[..., 'DeprecatedAwaitableList']):
+ def __init__(self, *args: T, func: Callable[..., 'DeprecatedAwaitableList']):
super().__init__(*args)
self._name = f'{func.__module__}.{func.__qualname__}'
- def __await__(self):
+ def __await__(self) -> Iterable[List[T]]:
_warn_deprecation(self)
if False:
yield
- return self
+ return list(self)
- def __reduce__(self):
+ def __reduce__(self) -> Tuple[Type[list], Tuple[List[T]]]:
return list, (list(self),)
def _unwrap(self) -> List[T]:
@@ -141,7 +156,9 @@ def __enter__(self) -> T:
pass
@abstractmethod
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(self, exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType]) -> Optional[bool]:
pass
async def __aenter__(self) -> T:
@@ -151,5 +168,7 @@ async def __aenter__(self) -> T:
f'you are completely migrating to AnyIO 3+.', DeprecationWarning)
return self.__enter__()
- async def __aexit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]:
+ async def __aexit__(self, exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType]) -> Optional[bool]:
return self.__exit__(exc_type, exc_val, exc_tb)
diff --git a/src/anyio/_core/_eventloop.py b/src/anyio/_core/_eventloop.py
index 6021ab99..f2364a3b 100644
--- a/src/anyio/_core/_eventloop.py
+++ b/src/anyio/_core/_eventloop.py
@@ -16,7 +16,7 @@
threadlocals = threading.local()
-def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args,
+def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object,
backend: str = 'asyncio', backend_options: Optional[Dict[str, Any]] = None) -> T_Retval:
"""
Run the given coroutine function in an asynchronous event loop.
@@ -120,7 +120,7 @@ def get_cancelled_exc_class() -> Type[BaseException]:
#
@contextmanager
-def claim_worker_thread(backend) -> Generator[Any, None, None]:
+def claim_worker_thread(backend: str) -> Generator[Any, None, None]:
module = sys.modules['anyio._backends._' + backend]
threadlocals.current_async_module = module
token = sniffio.current_async_library_cvar.set(backend)
@@ -131,7 +131,7 @@ def claim_worker_thread(backend) -> Generator[Any, None, None]:
del threadlocals.current_async_module
-def get_asynclib(asynclib_name: Optional[str] = None):
+def get_asynclib(asynclib_name: Optional[str] = None) -> Any:
if asynclib_name is None:
asynclib_name = sniffio.current_async_library()
diff --git a/src/anyio/_core/_exceptions.py b/src/anyio/_core/_exceptions.py
index 8025ef6f..52e59808 100644
--- a/src/anyio/_core/_exceptions.py
+++ b/src/anyio/_core/_exceptions.py
@@ -49,7 +49,7 @@ class ExceptionGroup(BaseException):
#: the sequence of exceptions raised together
exceptions: Sequence[BaseException]
- def __str__(self):
+ def __str__(self) -> str:
tracebacks = [''.join(format_exception(type(exc), exc, exc.__traceback__))
for exc in self.exceptions]
return f'{len(self.exceptions)} exceptions were raised in the task group:\n' \
@@ -67,7 +67,7 @@ class IncompleteRead(Exception):
connection is closed before the requested amount of bytes has been read.
"""
- def __init__(self):
+ def __init__(self) -> None:
super().__init__('The stream was closed before the read operation could be completed')
diff --git a/src/anyio/_core/_fileio.py b/src/anyio/_core/_fileio.py
index fe2b53d6..ebfd34e3 100644
--- a/src/anyio/_core/_fileio.py
+++ b/src/anyio/_core/_fileio.py
@@ -1,12 +1,14 @@
import os
from os import PathLike
-from typing import Callable, Optional, Union
+from typing import Any, AsyncIterator, Callable, Generic, Optional, TypeVar, Union
from .. import to_thread
from ..abc import AsyncResource
+T_Fp = TypeVar("T_Fp")
-class AsyncFile(AsyncResource):
+
+class AsyncFile(AsyncResource, Generic[T_Fp]):
"""
An asynchronous file object.
@@ -38,18 +40,18 @@ class AsyncFile(AsyncResource):
print(line)
"""
- def __init__(self, fp) -> None:
- self._fp = fp
+ def __init__(self, fp: T_Fp) -> None:
+ self._fp: Any = fp
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> object:
return getattr(self._fp, name)
@property
- def wrapped(self):
+ def wrapped(self) -> T_Fp:
"""The wrapped file object."""
return self._fp
- async def __aiter__(self):
+ async def __aiter__(self) -> AsyncIterator[bytes]:
while True:
line = await self.readline()
if line:
diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py
index e770861f..f58cdc3a 100644
--- a/src/anyio/_core/_sockets.py
+++ b/src/anyio/_core/_sockets.py
@@ -83,9 +83,11 @@ async def connect_tcp(
async def connect_tcp(
- remote_host, remote_port, *, local_host=None, tls=False, ssl_context=None,
- tls_standard_compatible=True, tls_hostname=None, happy_eyeballs_delay=0.25
-):
+ remote_host: IPAddressType, remote_port: int, *, local_host: Optional[IPAddressType] = None,
+ tls: bool = False, ssl_context: Optional[ssl.SSLContext] = None,
+ tls_standard_compatible: bool = True, tls_hostname: Optional[str] = None,
+ happy_eyeballs_delay: float = 0.25
+) -> Union[SocketStream, TLSStream]:
"""
Connect to a host using the TCP protocol.
@@ -119,7 +121,7 @@ async def connect_tcp(
# Placed here due to https://github.com/python/mypy/issues/7057
connected_stream: Optional[SocketStream] = None
- async def try_connect(remote_host: str, event: Event):
+ async def try_connect(remote_host: str, event: Event) -> None:
nonlocal connected_stream
try:
stream = await asynclib.connect_tcp(remote_host, remote_port, local_address)
@@ -184,7 +186,7 @@ async def try_connect(remote_host: str, event: Event):
if tls or tls_hostname or ssl_context:
try:
return await TLSStream.wrap(connected_stream, server_side=False,
- hostname=tls_hostname or remote_host,
+ hostname=tls_hostname or str(remote_host),
ssl_context=ssl_context,
standard_compatible=tls_standard_compatible)
except BaseException:
@@ -475,7 +477,9 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
# Private API
#
-def convert_ipv6_sockaddr(sockaddr):
+def convert_ipv6_sockaddr(
+ sockaddr: Union[Tuple[str, int, int, int], Tuple[str, int]]
+) -> Tuple[str, int]:
"""
Convert a 4-tuple IPv6 socket address to a 2-tuple (address, port) format.
@@ -489,10 +493,11 @@ def convert_ipv6_sockaddr(sockaddr):
"""
# This is more complicated than it should be because of MyPy
if isinstance(sockaddr, tuple) and len(sockaddr) == 4:
- if sockaddr[3]:
- # Add scopeid to the address
- return sockaddr[0] + '%' + str(sockaddr[3]), sockaddr[1]
+ host, port, flowinfo, scope_id = cast(Tuple[str, int, int, int], sockaddr)
+ if scope_id:
+ # Add scope_id to the address
+ return f"{host}%{scope_id}", port
else:
- return sockaddr[:2]
+ return host, port
else:
- return sockaddr
+ return cast(Tuple[str, int], sockaddr)
diff --git a/src/anyio/_core/_streams.py b/src/anyio/_core/_streams.py
index 02eabd4d..4a003bea 100644
--- a/src/anyio/_core/_streams.py
+++ b/src/anyio/_core/_streams.py
@@ -1,5 +1,5 @@
import math
-from typing import Any, Tuple, Type, TypeVar, overload
+from typing import Optional, Tuple, Type, TypeVar, overload
from ..streams.memory import (
MemoryObjectReceiveStream, MemoryObjectSendStream, MemoryObjectStreamState)
@@ -17,11 +17,13 @@ def create_memory_object_stream(
@overload
def create_memory_object_stream(
max_buffer_size: float = 0
-) -> Tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]:
+) -> Tuple[MemoryObjectSendStream, MemoryObjectReceiveStream]:
...
-def create_memory_object_stream(max_buffer_size=0, item_type=None):
+def create_memory_object_stream(
+ max_buffer_size: float = 0, item_type: Optional[Type[T_Item]] = None
+) -> Tuple[MemoryObjectSendStream, MemoryObjectReceiveStream]:
"""
Create a memory object stream.
diff --git a/src/anyio/_core/_subprocesses.py b/src/anyio/_core/_subprocesses.py
index f7e0232c..1daf6c10 100644
--- a/src/anyio/_core/_subprocesses.py
+++ b/src/anyio/_core/_subprocesses.py
@@ -1,6 +1,6 @@
from os import PathLike
from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess
-from typing import Mapping, Optional, Sequence, Union, cast
+from typing import AsyncIterable, List, Mapping, Optional, Sequence, Union, cast
from ..abc import Process
from ._eventloop import get_asynclib
@@ -32,13 +32,13 @@ async def run_process(command: Union[str, Sequence[str]], *, input: Optional[byt
nonzero return code
"""
- async def drain_stream(stream, index):
+ async def drain_stream(stream: AsyncIterable[bytes], index: int) -> None:
chunks = [chunk async for chunk in stream]
stream_contents[index] = b''.join(chunks)
async with await open_process(command, stdin=PIPE if input else DEVNULL, stdout=stdout,
stderr=stderr, cwd=cwd, env=env) as process:
- stream_contents = [None, None]
+ stream_contents: List[Optional[bytes]] = [None, None]
try:
async with create_task_group() as tg:
if process.stdout:
diff --git a/src/anyio/_core/_synchronization.py b/src/anyio/_core/_synchronization.py
index f7070e98..6c691770 100644
--- a/src/anyio/_core/_synchronization.py
+++ b/src/anyio/_core/_synchronization.py
@@ -73,7 +73,7 @@ class SemaphoreStatistics:
class Event:
- def __new__(cls):
+ def __new__(cls) -> 'Event':
return get_asynclib().Event()
def set(self) -> DeprecatedAwaitable:
@@ -84,7 +84,7 @@ def is_set(self) -> bool:
"""Return ``True`` if the flag is set, ``False`` if not."""
raise NotImplementedError
- async def wait(self) -> bool:
+ async def wait(self) -> None:
"""
Wait until the flag has been set.
@@ -101,10 +101,10 @@ def statistics(self) -> EventStatistics:
class Lock:
_owner_task: Optional[TaskInfo] = None
- def __init__(self):
+ def __init__(self) -> None:
self._waiters: Deque[Tuple[TaskInfo, Event]] = deque()
- async def __aenter__(self):
+ async def __aenter__(self) -> None:
await self.acquire()
async def __aexit__(self, exc_type: Optional[Type[BaseException]],
@@ -183,7 +183,7 @@ def __init__(self, lock: Optional[Lock] = None):
self._lock = lock or Lock()
self._waiters: Deque[Event] = deque()
- async def __aenter__(self):
+ async def __aenter__(self) -> None:
await self.acquire()
async def __aexit__(self, exc_type: Optional[Type[BaseException]],
@@ -351,10 +351,10 @@ def statistics(self) -> SemaphoreStatistics:
class CapacityLimiter:
- def __new__(cls, total_tokens: float):
+ def __new__(cls, total_tokens: float) -> 'CapacityLimiter':
return get_asynclib().CapacityLimiter(total_tokens)
- async def __aenter__(self):
+ async def __aenter__(self) -> None:
raise NotImplementedError
async def __aexit__(self, exc_type: Optional[Type[BaseException]],
@@ -380,7 +380,7 @@ def total_tokens(self) -> float:
def total_tokens(self, value: float) -> None:
raise NotImplementedError
- async def set_total_tokens(self, value) -> None:
+ async def set_total_tokens(self, value: float) -> None:
warn('CapacityLimiter.set_total_tokens has been deprecated. Set the value of the'
'"total_tokens" attribute directly.', DeprecationWarning)
self.total_tokens = value
@@ -404,7 +404,7 @@ def acquire_nowait(self) -> DeprecatedAwaitable:
"""
raise NotImplementedError
- def acquire_on_behalf_of_nowait(self, borrower) -> DeprecatedAwaitable:
+ def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
"""
Acquire a token without waiting for one to become available.
@@ -421,7 +421,7 @@ async def acquire(self) -> None:
"""
raise NotImplementedError
- async def acquire_on_behalf_of(self, borrower) -> None:
+ async def acquire_on_behalf_of(self, borrower: object) -> None:
"""
Acquire a token, waiting if necessary for one to become available.
@@ -438,7 +438,7 @@ def release(self) -> None:
"""
raise NotImplementedError
- def release_on_behalf_of(self, borrower) -> None:
+ def release_on_behalf_of(self, borrower: object) -> None:
"""
Release the token held by the given borrower.
@@ -541,11 +541,14 @@ def __init__(self, action: str):
self.action = action
self._guarded = False
- def __enter__(self):
+ def __enter__(self) -> None:
if self._guarded:
raise BusyResourceError(self.action)
self._guarded = True
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(self, exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType]) -> Optional[bool]:
self._guarded = False
+ return None
diff --git a/src/anyio/_core/_tasks.py b/src/anyio/_core/_tasks.py
index ac4e41da..8bbad974 100644
--- a/src/anyio/_core/_tasks.py
+++ b/src/anyio/_core/_tasks.py
@@ -9,7 +9,7 @@
class _IgnoredTaskStatus(TaskStatus):
- def started(self, value=None) -> None:
+ def started(self, value: object = None) -> None:
pass
@@ -24,7 +24,7 @@ class CancelScope(DeprecatedAsyncContextManager['CancelScope']):
:param shield: ``True`` to shield the cancel scope from external cancellation
"""
- def __new__(cls, *, deadline: float = math.inf, shield: bool = False):
+ def __new__(cls, *, deadline: float = math.inf, shield: bool = False) -> 'CancelScope':
return get_asynclib().CancelScope(shield=shield, deadline=deadline)
def cancel(self) -> DeprecatedAwaitable:
@@ -64,7 +64,7 @@ def shield(self) -> bool:
def shield(self, value: bool) -> None:
raise NotImplementedError
- def __enter__(self):
+ def __enter__(self) -> 'CancelScope':
raise NotImplementedError
def __exit__(self, exc_type: Optional[Type[BaseException]],
@@ -92,10 +92,12 @@ class FailAfterContextManager(DeprecatedAsyncContextManager):
def __init__(self, cancel_scope: CancelScope):
self._cancel_scope = cancel_scope
- def __enter__(self):
+ def __enter__(self) -> CancelScope:
return self._cancel_scope.__enter__()
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(self, exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType]) -> Optional[bool]:
retval = self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
if self._cancel_scope.cancel_called:
raise TimeoutError
diff --git a/src/anyio/_core/_testing.py b/src/anyio/_core/_testing.py
index 269923da..c48bd45e 100644
--- a/src/anyio/_core/_testing.py
+++ b/src/anyio/_core/_testing.py
@@ -1,10 +1,10 @@
-from typing import Coroutine, Optional
+from typing import Coroutine, Generator, Optional
-from ._compat import DeprecatedAwaitable, DeprecatedAwaitableList
+from ._compat import DeprecatedAwaitableList, _warn_deprecation
from ._eventloop import get_asynclib
-class TaskInfo(DeprecatedAwaitable):
+class TaskInfo:
"""
Represents an asynchronous task.
@@ -15,31 +15,38 @@ class TaskInfo(DeprecatedAwaitable):
:ivar ~collections.abc.Coroutine coro: the coroutine object of the task
"""
- __slots__ = 'id', 'parent_id', 'name', 'coro'
+ __slots__ = '_name', 'id', 'parent_id', 'name', 'coro'
def __init__(self, id: int, parent_id: Optional[int], name: Optional[str], coro: Coroutine):
- super().__init__(get_current_task)
+ func = get_current_task
+ self._name = f'{func.__module__}.{func.__qualname__}'
self.id = id
self.parent_id = parent_id
self.name = name
self.coro = coro
- def __await__(self):
- yield from super().__await__()
- return self
-
- def __eq__(self, other):
+ def __eq__(self, other: object) -> bool:
if isinstance(other, TaskInfo):
return self.id == other.id
return NotImplemented
- def __hash__(self):
+ def __hash__(self) -> int:
return hash(self.id)
- def __repr__(self):
+ def __repr__(self) -> str:
return f'{self.__class__.__name__}(id={self.id!r}, name={self.name!r})'
+ def __await__(self) -> Generator[None, None, "TaskInfo"]:
+ _warn_deprecation(self)
+ if False:
+ yield
+
+ return self
+
+ def _unwrap(self) -> 'TaskInfo':
+ return self
+
def get_current_task() -> TaskInfo:
"""
diff --git a/src/anyio/_core/_typedattr.py b/src/anyio/_core/_typedattr.py
index 162426a1..e8377961 100644
--- a/src/anyio/_core/_typedattr.py
+++ b/src/anyio/_core/_typedattr.py
@@ -1,5 +1,5 @@
import sys
-from typing import Callable, Mapping, TypeVar, Union, overload
+from typing import Any, Callable, Mapping, TypeVar, Union, overload
from ._exceptions import TypedAttributeLookupError
@@ -13,7 +13,7 @@
undefined = object()
-def typed_attribute():
+def typed_attribute() -> Any:
"""Return a unique object, used to mark typed attributes."""
return object()
@@ -58,7 +58,7 @@ def extra(self, attribute: T_Attr, default: T_Default) -> Union[T_Attr, T_Defaul
...
@final
- def extra(self, attribute, default=undefined):
+ def extra(self, attribute: Any, default: object = undefined) -> object:
"""
extra(attribute, default=undefined)
diff --git a/src/anyio/abc/_resources.py b/src/anyio/abc/_resources.py
index d6ed168f..4594e6e9 100644
--- a/src/anyio/abc/_resources.py
+++ b/src/anyio/abc/_resources.py
@@ -1,6 +1,8 @@
from abc import ABCMeta, abstractmethod
from types import TracebackType
-from typing import Optional, Type
+from typing import Optional, Type, TypeVar
+
+T = TypeVar("T")
class AsyncResource(metaclass=ABCMeta):
@@ -11,7 +13,7 @@ class AsyncResource(metaclass=ABCMeta):
:meth:`aclose` on exit.
"""
- async def __aenter__(self):
+ async def __aenter__(self: T) -> T:
return self
async def __aexit__(self, exc_type: Optional[Type[BaseException]],
diff --git a/src/anyio/abc/_sockets.py b/src/anyio/abc/_sockets.py
index ef23cdac..2e460fbb 100644
--- a/src/anyio/abc/_sockets.py
+++ b/src/anyio/abc/_sockets.py
@@ -2,8 +2,10 @@
from io import IOBase
from ipaddress import IPv4Address, IPv6Address
from socket import AddressFamily, SocketType
+from types import TracebackType
from typing import (
- Any, AsyncContextManager, Callable, Collection, List, Optional, Tuple, TypeVar, Union)
+ Any, AsyncContextManager, Callable, Collection, List, Mapping, Optional, Tuple, Type, TypeVar,
+ Union)
from .._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute
from ._streams import ByteStream, Listener, T_Stream, UnreliableObjectStream
@@ -17,11 +19,13 @@
class _NullAsyncContextManager:
- async def __aenter__(self):
+ async def __aenter__(self) -> None:
pass
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- pass
+ async def __aexit__(self, exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType]) -> Optional[bool]:
+ return None
class SocketAttribute(TypedAttributeSet):
@@ -41,7 +45,7 @@ class SocketAttribute(TypedAttributeSet):
class _SocketProvider(TypedAttributeProvider):
@property
- def extra_attributes(self):
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
from .._core._sockets import convert_ipv6_sockaddr as convert
attributes = {
@@ -50,7 +54,7 @@ def extra_attributes(self):
SocketAttribute.raw_socket: lambda: self._raw_socket
}
try:
- peername = convert(self._raw_socket.getpeername())
+ peername: Optional[Tuple[str, int]] = convert(self._raw_socket.getpeername())
except OSError:
peername = None
@@ -62,7 +66,8 @@ def extra_attributes(self):
if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
attributes[SocketAttribute.local_port] = lambda: self._raw_socket.getsockname()[1]
if peername is not None:
- attributes[SocketAttribute.remote_port] = lambda: peername[1]
+ remote_port = peername[1]
+ attributes[SocketAttribute.remote_port] = lambda: remote_port
return attributes
diff --git a/src/anyio/abc/_streams.py b/src/anyio/abc/_streams.py
index 6747f543..3f5e783d 100644
--- a/src/anyio/abc/_streams.py
+++ b/src/anyio/abc/_streams.py
@@ -21,7 +21,7 @@ class UnreliableObjectReceiveStream(Generic[T_Item], AsyncResource, TypedAttribu
parameter.
"""
- def __aiter__(self):
+ def __aiter__(self) -> "UnreliableObjectReceiveStream[T_Item]":
return self
async def __anext__(self) -> T_Item:
diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py
index 3801d566..afa2d983 100644
--- a/src/anyio/abc/_tasks.py
+++ b/src/anyio/abc/_tasks.py
@@ -12,7 +12,7 @@
class TaskStatus(metaclass=ABCMeta):
@abstractmethod
- def started(self, value=None) -> None:
+ def started(self, value: object = None) -> None:
"""
Signal that the task has started.
@@ -30,7 +30,8 @@ class TaskGroup(metaclass=ABCMeta):
cancel_scope: 'CancelScope'
- async def spawn(self, func: Callable[..., Coroutine], *args, name=None) -> None:
+ async def spawn(self, func: Callable[..., Coroutine],
+ *args: object, name: object = None) -> None:
"""
Start a new task in this task group.
@@ -48,7 +49,8 @@ async def spawn(self, func: Callable[..., Coroutine], *args, name=None) -> None:
self.start_soon(func, *args, name=name)
@abstractmethod
- def start_soon(self, func: Callable[..., Coroutine], *args, name=None) -> None:
+ def start_soon(self, func: Callable[..., Coroutine],
+ *args: object, name: object = None) -> None:
"""
Start a new task in this task group.
@@ -60,7 +62,8 @@ def start_soon(self, func: Callable[..., Coroutine], *args, name=None) -> None:
"""
@abstractmethod
- async def start(self, func: Callable[..., Coroutine], *args, name=None):
+ async def start(self, func: Callable[..., Coroutine],
+ *args: object, name: object = None) -> object:
"""
Start a new task and wait until it signals for readiness.
diff --git a/src/anyio/abc/_testing.py b/src/anyio/abc/_testing.py
index ace2d9b7..68aeb00e 100644
--- a/src/anyio/abc/_testing.py
+++ b/src/anyio/abc/_testing.py
@@ -1,5 +1,8 @@
+import types
from abc import ABCMeta, abstractmethod
-from typing import Any, Awaitable, Callable, Dict
+from typing import Any, Awaitable, Callable, Dict, Optional, Type, TypeVar
+
+_T = TypeVar("_T")
class TestRunner(metaclass=ABCMeta):
@@ -11,15 +14,19 @@ class TestRunner(metaclass=ABCMeta):
def __enter__(self) -> 'TestRunner':
return self
- def __exit__(self, exc_type, exc_val, exc_tb):
+ def __exit__(self, exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[types.TracebackType]) -> Optional[bool]:
self.close()
+ return None
@abstractmethod
def close(self) -> None:
"""Close the event loop."""
@abstractmethod
- def call(self, func: Callable[..., Awaitable], *args: tuple, **kwargs: Dict[str, Any]):
+ def call(self, func: Callable[..., Awaitable[_T]],
+ *args: tuple, **kwargs: Dict[str, Any]) -> _T:
"""
Call the given function within the backend's event loop.
diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py
index 90d06cd6..29f7d03d 100644
--- a/src/anyio/from_thread.py
+++ b/src/anyio/from_thread.py
@@ -5,7 +5,7 @@
from types import TracebackType
from typing import (
Any, AsyncContextManager, Callable, ContextManager, Coroutine, Dict, Generator, Iterable,
- Optional, Tuple, Type, TypeVar, cast, overload)
+ Optional, Tuple, Type, TypeVar, Union, cast, overload)
from warnings import warn
from ._core import _eventloop
@@ -17,7 +17,7 @@
T_co = TypeVar('T_co')
-def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval:
+def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval:
"""
Call a coroutine function from a worker thread.
@@ -34,13 +34,14 @@ def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval:
return asynclib.run_async_from_thread(func, *args)
-def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval:
+def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]],
+ *args: object) -> T_Retval:
warn('run_async_from_thread() has been deprecated, use anyio.from_thread.run() instead',
DeprecationWarning)
return run(func, *args)
-def run_sync(func: Callable[..., T_Retval], *args) -> T_Retval:
+def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval:
"""
Call a function in the event loop thread from a worker thread.
@@ -57,7 +58,7 @@ def run_sync(func: Callable[..., T_Retval], *args) -> T_Retval:
return asynclib.run_sync_from_thread(func, *args)
-def run_sync_from_thread(func: Callable[..., T_Retval], *args) -> T_Retval:
+def run_sync_from_thread(func: Callable[..., T_Retval], *args: object) -> T_Retval:
warn('run_sync_from_thread() has been deprecated, use anyio.from_thread.run_sync() instead',
DeprecationWarning)
return run_sync(func, *args)
@@ -74,7 +75,7 @@ def __init__(self, async_cm: AsyncContextManager[T_co], portal: 'BlockingPortal'
self._async_cm = async_cm
self._portal = portal
- async def run_async_cm(self):
+ async def run_async_cm(self) -> Optional[bool]:
try:
self._exit_event = Event()
value = await self._async_cm.__aenter__()
@@ -105,18 +106,18 @@ class _BlockingPortalTaskStatus(TaskStatus):
def __init__(self, future: Future):
self._future = future
- def started(self, value=None) -> None:
+ def started(self, value: object = None) -> None:
self._future.set_result(value)
class BlockingPortal:
"""An object that lets external threads run code in an asynchronous event loop."""
- def __new__(cls):
+ def __new__(cls) -> 'BlockingPortal':
return get_asynclib().BlockingPortal()
- def __init__(self):
- self._event_loop_thread_id = threading.get_ident()
+ def __init__(self) -> None:
+ self._event_loop_thread_id: Optional[int] = threading.get_ident()
self._stop_event = Event()
self._task_group = create_task_group()
self._cancelled_exc_class = get_cancelled_exc_class()
@@ -125,7 +126,9 @@ async def __aenter__(self) -> 'BlockingPortal':
await self._task_group.__aenter__()
return self
- async def __aexit__(self, exc_type, exc_val, exc_tb):
+ async def __aexit__(self, exc_type: Optional[Type[BaseException]],
+ exc_val: Optional[BaseException],
+ exc_tb: Optional[TracebackType]) -> Optional[bool]:
await self.stop()
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
@@ -157,7 +160,7 @@ async def stop(self, cancel_remaining: bool = False) -> None:
async def _call_func(self, func: Callable, args: tuple, kwargs: Dict[str, Any],
future: Future) -> None:
- def callback(f: Future):
+ def callback(f: Future) -> None:
if f.cancelled():
self.call(scope.cancel)
@@ -184,10 +187,10 @@ def callback(f: Future):
if not future.cancelled():
future.set_result(retval)
finally:
- scope = None
+ scope = None # type: ignore[assignment]
def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any],
- name, future: Future) -> None:
+ name: object, future: Future) -> None:
"""
Spawn a new task using the given callable.
@@ -204,14 +207,15 @@ def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str,
raise NotImplementedError
@overload
- def call(self, func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval:
+ def call(self, func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval:
...
@overload
- def call(self, func: Callable[..., T_Retval], *args) -> T_Retval:
+ def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval:
...
- def call(self, func, *args):
+ def call(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]],
+ *args: object) -> T_Retval:
"""
Call the given function in the event loop thread.
@@ -224,7 +228,8 @@ def call(self, func, *args):
"""
return self.start_task_soon(func, *args).result()
- def spawn_task(self, func: Callable[..., Coroutine], *args, name=None) -> Future:
+ def spawn_task(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]],
+ *args: object, name: object = None) -> "Future[T_Retval]":
"""
Start a task in the portal's task group.
@@ -245,7 +250,8 @@ def spawn_task(self, func: Callable[..., Coroutine], *args, name=None) -> Future
warn('spawn_task() is deprecated -- use start_task_soon() instead', DeprecationWarning)
return self.start_task_soon(func, *args, name=name)
- def start_task_soon(self, func: Callable[..., Coroutine], *args, name=None) -> Future:
+ def start_task_soon(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]],
+ *args: object, name: object = None) -> "Future[T_Retval]":
"""
Start a task in the portal's task group.
@@ -268,7 +274,8 @@ def start_task_soon(self, func: Callable[..., Coroutine], *args, name=None) -> F
self._spawn_task_from_thread(func, args, {}, name, f)
return f
- def start_task(self, func: Callable[..., Coroutine], *args, name=None) -> Tuple[Future, Any]:
+ def start_task(self, func: Callable[..., Coroutine], *args: object,
+ name: object = None) -> Tuple[Future, Any]:
"""
Start a task in the portal's task group and wait until it signals for readiness.
@@ -350,7 +357,7 @@ def start_blocking_portal(
Usage as a context manager is now required.
"""
- async def run_portal():
+ async def run_portal() -> None:
async with BlockingPortal() as portal_:
if future.set_running_or_notify_cancel():
future.set_result(portal_)
diff --git a/src/anyio/lowlevel.py b/src/anyio/lowlevel.py
index 8a7d53fb..471b7e6b 100644
--- a/src/anyio/lowlevel.py
+++ b/src/anyio/lowlevel.py
@@ -1,8 +1,16 @@
-from typing import Any, Dict, Generic, Set, TypeVar, Union, cast
+import enum
+import sys
+from dataclasses import dataclass
+from typing import Any, Dict, Generic, Set, TypeVar, Union, overload
from weakref import WeakKeyDictionary
from ._core._eventloop import get_asynclib
+if sys.version_info >= (3, 8):
+ from typing import Literal
+else:
+ from typing_extensions import Literal
+
T = TypeVar('T')
D = TypeVar('D')
@@ -22,7 +30,7 @@ async def checkpoint() -> None:
await get_asynclib().checkpoint()
-async def checkpoint_if_cancelled():
+async def checkpoint_if_cancelled() -> None:
"""
Enter a checkpoint if the enclosing cancel scope has been cancelled.
@@ -58,25 +66,22 @@ def current_token() -> object:
_token_wrappers: Dict[Any, '_TokenWrapper'] = {}
+@dataclass(frozen=True)
class _TokenWrapper:
__slots__ = '_token', '__weakref__'
+ _token: object
- def __init__(self, token):
- self._token = token
-
- def __eq__(self, other):
- return self._token is other._token
- def __hash__(self):
- return hash(self._token)
+class _NoValueSet(enum.Enum):
+ NO_VALUE_SET = enum.auto()
-class RunvarToken:
+class RunvarToken(Generic[T]):
__slots__ = '_var', '_value', '_redeemed'
- def __init__(self, var: 'RunVar', value):
+ def __init__(self, var: 'RunVar', value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]]):
self._var = var
- self._value = value
+ self._value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = value
self._redeemed = False
@@ -84,11 +89,12 @@ class RunVar(Generic[T]):
"""Like a :class:`~contextvars.ContextVar`, expect scoped to the running event loop."""
__slots__ = '_name', '_default'
- NO_VALUE_SET = object()
+ NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET
_token_wrappers: Set[_TokenWrapper] = set()
- def __init__(self, name: str, default: Union[T, object] = NO_VALUE_SET):
+ def __init__(self, name: str,
+ default: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET):
self._name = name
self._default = default
@@ -108,31 +114,39 @@ def _current_vars(self) -> Dict[str, T]:
run_vars = _run_vars[token] = {}
return run_vars
- def get(self, default: Union[T, object] = NO_VALUE_SET) -> T:
+ @overload
+ def get(self, default: D) -> Union[T, D]: ...
+
+ @overload
+ def get(self) -> T: ...
+
+ def get(
+ self, default: Union[D, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET
+ ) -> Union[T, D]:
try:
return self._current_vars[self._name]
except KeyError:
if default is not RunVar.NO_VALUE_SET:
- return cast(T, default)
+ return default
elif self._default is not RunVar.NO_VALUE_SET:
- return cast(T, self._default)
+ return self._default
raise LookupError(f'Run variable "{self._name}" has no value and no default set')
- def set(self, value: T) -> RunvarToken:
+ def set(self, value: T) -> RunvarToken[T]:
current_vars = self._current_vars
token = RunvarToken(self, current_vars.get(self._name, RunVar.NO_VALUE_SET))
current_vars[self._name] = value
return token
- def reset(self, token: RunvarToken) -> None:
+ def reset(self, token: RunvarToken[T]) -> None:
if token._var is not self:
raise ValueError('This token does not belong to this RunVar')
if token._redeemed:
raise ValueError('This token has already been used')
- if token._value is RunVar.NO_VALUE_SET:
+ if token._value is _NoValueSet.NO_VALUE_SET:
try:
del self._current_vars[self._name]
except KeyError:
@@ -142,5 +156,5 @@ def reset(self, token: RunvarToken) -> None:
token._redeemed = True
- def __repr__(self):
+ def __repr__(self) -> str:
return f''
diff --git a/src/anyio/pytest_plugin.py b/src/anyio/pytest_plugin.py
index 1db51764..0e99a456 100644
--- a/src/anyio/pytest_plugin.py
+++ b/src/anyio/pytest_plugin.py
@@ -1,7 +1,7 @@
import sys
from contextlib import contextmanager
from inspect import iscoroutinefunction
-from typing import Any, Dict, Iterator, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, cast
import pytest
import sniffio
@@ -14,10 +14,13 @@
else:
from async_generator import isasyncgenfunction
+if TYPE_CHECKING:
+ from _pytest.config import Config
+
_current_runner: Optional[TestRunner] = None
-def extract_backend_and_options(backend) -> Tuple[str, Dict[str, Any]]:
+def extract_backend_and_options(backend: object) -> Tuple[str, Dict[str, Any]]:
if isinstance(backend, str):
return backend, {}
elif isinstance(backend, tuple) and len(backend) == 2:
@@ -51,13 +54,13 @@ def get_runner(backend_name: str, backend_options: Dict[str, Any]) -> Iterator[T
sniffio.current_async_library_cvar.reset(token)
-def pytest_configure(config):
+def pytest_configure(config: "Config") -> None:
config.addinivalue_line('markers', 'anyio: mark the (coroutine function) test to be run '
'asynchronously via anyio.')
-def pytest_fixture_setup(fixturedef, request):
- def wrapper(*args, anyio_backend, **kwargs):
+def pytest_fixture_setup(fixturedef: Any, request: Any) -> None:
+ def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def]
backend_name, backend_options = extract_backend_and_options(anyio_backend)
if has_backend_arg:
kwargs['anyio_backend'] = anyio_backend
@@ -94,7 +97,7 @@ def wrapper(*args, anyio_backend, **kwargs):
@pytest.hookimpl(tryfirst=True)
-def pytest_pycollect_makeitem(collector, name, obj):
+def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None:
if collector.istestfunction(obj, name):
inner_func = obj.hypothesis.inner_test if hasattr(obj, 'hypothesis') else obj
if iscoroutinefunction(inner_func):
@@ -105,8 +108,8 @@ def pytest_pycollect_makeitem(collector, name, obj):
@pytest.hookimpl(tryfirst=True)
-def pytest_pyfunc_call(pyfuncitem):
- def run_with_hypothesis(**kwargs):
+def pytest_pyfunc_call(pyfuncitem: Any) -> Optional[bool]:
+ def run_with_hypothesis(**kwargs: Any) -> None:
with get_runner(backend_name, backend_options) as runner:
runner.call(original_func, **kwargs)
@@ -131,14 +134,16 @@ def run_with_hypothesis(**kwargs):
return True
+ return None
+
@pytest.fixture(params=get_all_backends())
-def anyio_backend(request):
+def anyio_backend(request: Any) -> Any:
return request.param
@pytest.fixture
-def anyio_backend_name(anyio_backend) -> str:
+def anyio_backend_name(anyio_backend: Any) -> str:
if isinstance(anyio_backend, str):
return anyio_backend
else:
@@ -146,7 +151,7 @@ def anyio_backend_name(anyio_backend) -> str:
@pytest.fixture
-def anyio_backend_options(anyio_backend) -> Dict[str, Any]:
+def anyio_backend_options(anyio_backend: Any) -> Dict[str, Any]:
if isinstance(anyio_backend, str):
return {}
else:
diff --git a/src/anyio/streams/buffered.py b/src/anyio/streams/buffered.py
index bed83246..53ef5176 100644
--- a/src/anyio/streams/buffered.py
+++ b/src/anyio/streams/buffered.py
@@ -1,4 +1,5 @@
from dataclasses import dataclass, field
+from typing import Any, Callable, Mapping
from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead
from ..abc import AnyByteReceiveStream, ByteReceiveStream
@@ -25,8 +26,8 @@ def buffer(self) -> bytes:
return bytes(self._buffer)
@property
- def extra_attributes(self):
- return self.receive_stream.extra_attributes
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+ return self.receive_stream.extra_attributes # type: ignore[return-value]
async def receive(self, max_bytes: int = 65536) -> bytes:
if self._closed:
diff --git a/src/anyio/streams/file.py b/src/anyio/streams/file.py
index 12442ad3..d3e22bae 100644
--- a/src/anyio/streams/file.py
+++ b/src/anyio/streams/file.py
@@ -1,6 +1,6 @@
from io import SEEK_SET, UnsupportedOperation
from pathlib import Path
-from typing import BinaryIO, Union, cast
+from typing import Any, BinaryIO, Callable, Dict, Mapping, Union, cast
from .. import (
BrokenResourceError, ClosedResourceError, EndOfStream, TypedAttributeSet, to_thread,
@@ -25,8 +25,8 @@ async def aclose(self) -> None:
await to_thread.run_sync(self._file.close)
@property
- def extra_attributes(self):
- attributes = {
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+ attributes: Dict[Any, Callable[[], Any]] = {
FileStreamAttribute.file: lambda: self._file,
}
diff --git a/src/anyio/streams/memory.py b/src/anyio/streams/memory.py
index 0d04ab46..47ff1f5d 100644
--- a/src/anyio/streams/memory.py
+++ b/src/anyio/streams/memory.py
@@ -41,7 +41,7 @@ class MemoryObjectReceiveStream(Generic[T_Item], ObjectReceiveStream[T_Item]):
_state: MemoryObjectStreamState[T_Item]
_closed: bool = field(init=False, default=False)
- def __post_init__(self):
+ def __post_init__(self) -> None:
self._state.open_receive_channels += 1
def receive_nowait(self) -> T_Item:
@@ -135,7 +135,7 @@ class MemoryObjectSendStream(Generic[T_Item], ObjectSendStream[T_Item]):
_state: MemoryObjectStreamState[T_Item]
_closed: bool = field(init=False, default=False)
- def __post_init__(self):
+ def __post_init__(self) -> None:
self._state.open_send_channels += 1
def send_nowait(self, item: T_Item) -> DeprecatedAwaitable:
diff --git a/src/anyio/streams/stapled.py b/src/anyio/streams/stapled.py
index 372cac23..0d5e7fb2 100644
--- a/src/anyio/streams/stapled.py
+++ b/src/anyio/streams/stapled.py
@@ -1,5 +1,5 @@
from dataclasses import dataclass
-from typing import Any, Callable, Generic, Optional, Sequence, TypeVar
+from typing import Any, Callable, Generic, List, Mapping, Optional, Sequence, TypeVar
from ..abc import (
ByteReceiveStream, ByteSendStream, ByteStream, Listener, ObjectReceiveStream, ObjectSendStream,
@@ -38,8 +38,8 @@ async def aclose(self) -> None:
await self.receive_stream.aclose()
@property
- def extra_attributes(self):
- return dict(**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes)
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+ return {**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes}
@dataclass(eq=False)
@@ -71,8 +71,8 @@ async def aclose(self) -> None:
await self.receive_stream.aclose()
@property
- def extra_attributes(self):
- return dict(**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes)
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+ return {**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes}
@dataclass(eq=False)
@@ -92,12 +92,12 @@ class MultiListener(Generic[T_Stream], Listener[T_Stream]):
listeners: Sequence[Listener[T_Stream]]
- def __post_init__(self):
- listeners = []
+ def __post_init__(self) -> None:
+ listeners: List[Listener[T_Stream]] = []
for listener in self.listeners:
if isinstance(listener, MultiListener):
listeners.extend(listener.listeners)
- del listener.listeners[:]
+ del listener.listeners[:] # type: ignore[attr-defined]
else:
listeners.append(listener)
@@ -116,8 +116,8 @@ async def aclose(self) -> None:
await listener.aclose()
@property
- def extra_attributes(self):
- attributes = {}
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+ attributes: dict = {}
for listener in self.listeners:
attributes.update(listener.extra_attributes)
diff --git a/src/anyio/streams/text.py b/src/anyio/streams/text.py
index 351f4a45..0cdfe227 100644
--- a/src/anyio/streams/text.py
+++ b/src/anyio/streams/text.py
@@ -1,6 +1,6 @@
import codecs
from dataclasses import InitVar, dataclass, field
-from typing import Callable, Tuple
+from typing import Any, Callable, Mapping, Tuple
from ..abc import (
AnyByteReceiveStream, AnyByteSendStream, AnyByteStream, ObjectReceiveStream, ObjectSendStream,
@@ -29,7 +29,7 @@ class TextReceiveStream(ObjectReceiveStream[str]):
errors: InitVar[str] = 'strict'
_decoder: codecs.IncrementalDecoder = field(init=False)
- def __post_init__(self, encoding, errors):
+ def __post_init__(self, encoding: str, errors: str) -> None:
decoder_class = codecs.getincrementaldecoder(encoding)
self._decoder = decoder_class(errors=errors)
@@ -45,8 +45,8 @@ async def aclose(self) -> None:
self._decoder.reset()
@property
- def extra_attributes(self):
- return self.transport_stream.extra_attributes
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+ return self.transport_stream.extra_attributes # type: ignore[return-value]
@dataclass(eq=False)
@@ -68,8 +68,8 @@ class TextSendStream(ObjectSendStream[str]):
errors: str = 'strict'
_encoder: Callable[..., Tuple[bytes, int]] = field(init=False)
- def __post_init__(self, encoding):
- self._encoder = codecs.getencoder(encoding)
+ def __post_init__(self, encoding: str) -> None:
+ self._encoder = codecs.getencoder(encoding) # type: ignore[assignment]
async def send(self, item: str) -> None:
encoded = self._encoder(item, self.errors)[0]
@@ -79,8 +79,8 @@ async def aclose(self) -> None:
await self.transport_stream.aclose()
@property
- def extra_attributes(self):
- return self.transport_stream.extra_attributes
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+ return self.transport_stream.extra_attributes # type: ignore[return-value]
@dataclass(eq=False)
@@ -107,7 +107,7 @@ class TextStream(ObjectStream[str]):
_receive_stream: TextReceiveStream = field(init=False)
_send_stream: TextSendStream = field(init=False)
- def __post_init__(self, encoding, errors):
+ def __post_init__(self, encoding: str, errors: str) -> None:
self._receive_stream = TextReceiveStream(self.transport_stream, encoding=encoding,
errors=errors)
self._send_stream = TextSendStream(self.transport_stream, encoding=encoding, errors=errors)
@@ -126,5 +126,5 @@ async def aclose(self) -> None:
await self._receive_stream.aclose()
@property
- def extra_attributes(self):
- return dict(**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes)
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
+ return {**self._send_stream.extra_attributes, **self._receive_stream.extra_attributes}
diff --git a/src/anyio/streams/tls.py b/src/anyio/streams/tls.py
index c42254a4..1d84c05c 100644
--- a/src/anyio/streams/tls.py
+++ b/src/anyio/streams/tls.py
@@ -3,7 +3,7 @@
import ssl
from dataclasses import dataclass
from functools import wraps
-from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, TypeVar, Union
from .. import BrokenResourceError, EndOfStream, aclose_forcefully, get_cancelled_exc_class
from .._core._typedattr import TypedAttributeSet, typed_attribute
@@ -94,7 +94,9 @@ async def wrap(cls, transport_stream: AnyByteStream, *, server_side: Optional[bo
await wrapper._call_sslobject_method(ssl_object.do_handshake)
return wrapper
- async def _call_sslobject_method(self, func: Callable[..., T_Retval], *args) -> T_Retval:
+ async def _call_sslobject_method(
+ self, func: Callable[..., T_Retval], *args: object
+ ) -> T_Retval:
while True:
try:
result = func(*args)
@@ -168,7 +170,7 @@ async def send_eof(self) -> None:
raise NotImplementedError('send_eof() has not yet been implemented for TLS streams')
@property
- def extra_attributes(self):
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return {
**self.transport_stream.extra_attributes,
TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol,
@@ -236,7 +238,7 @@ async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> N
async def serve(self, handler: Callable[[TLSStream], Any],
task_group: Optional[TaskGroup] = None) -> None:
@wraps(handler)
- async def handler_wrapper(stream: AnyByteStream):
+ async def handler_wrapper(stream: AnyByteStream) -> None:
from .. import fail_after
try:
with fail_after(self.handshake_timeout):
@@ -254,7 +256,7 @@ async def aclose(self) -> None:
await self.listener.aclose()
@property
- def extra_attributes(self):
+ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return {
TLSAttribute.standard_compatible: lambda: self.standard_compatible,
}
diff --git a/src/anyio/to_process.py b/src/anyio/to_process.py
index 8c18cd79..5675e9eb 100644
--- a/src/anyio/to_process.py
+++ b/src/anyio/to_process.py
@@ -26,7 +26,7 @@
async def run_sync(
- func: Callable[..., T_Retval], *args, cancellable: bool = False,
+ func: Callable[..., T_Retval], *args: object, cancellable: bool = False,
limiter: Optional[CapacityLimiter] = None) -> T_Retval:
"""
Call the given function with the given arguments in a worker process.
@@ -43,7 +43,7 @@ async def run_sync(
:return: an awaitable that yields the return value of the function.
"""
- async def send_raw_command(pickled_cmd: bytes):
+ async def send_raw_command(pickled_cmd: bytes) -> object:
try:
await stdin.send(pickled_cmd)
response = await buffered.receive_until(b'\n', 50)
@@ -145,7 +145,7 @@ async def send_raw_command(pickled_cmd: bytes):
with CancelScope(shield=not cancellable):
try:
- return await send_raw_command(request)
+ return cast(T_Retval, await send_raw_command(request))
finally:
if process in workers:
idle_workers.append((process, current_time()))
diff --git a/src/anyio/to_thread.py b/src/anyio/to_thread.py
index e0428777..5fc95894 100644
--- a/src/anyio/to_thread.py
+++ b/src/anyio/to_thread.py
@@ -8,7 +8,7 @@
async def run_sync(
- func: Callable[..., T_Retval], *args, cancellable: bool = False,
+ func: Callable[..., T_Retval], *args: object, cancellable: bool = False,
limiter: Optional[CapacityLimiter] = None) -> T_Retval:
"""
Call the given function with the given arguments in a worker thread.
@@ -30,7 +30,7 @@ async def run_sync(
async def run_sync_in_worker_thread(
- func: Callable[..., T_Retval], *args, cancellable: bool = False,
+ func: Callable[..., T_Retval], *args: object, cancellable: bool = False,
limiter: Optional[CapacityLimiter] = None) -> T_Retval:
warn('run_sync_in_worker_thread() has been deprecated, use anyio.to_thread.run_sync() instead',
DeprecationWarning)
diff --git a/tests/streams/test_text.py b/tests/streams/test_text.py
index 47ba8f9a..16374dce 100644
--- a/tests/streams/test_text.py
+++ b/tests/streams/test_text.py
@@ -55,3 +55,4 @@ async def test_bidirectional_stream():
await send_stream.send(b'\xc3\xa6\xc3\xb8')
assert await text_stream.receive() == 'æø'
+ assert text_stream.extra_attributes == {}
diff --git a/tests/test_compat.py b/tests/test_compat.py
index f662e826..007e541f 100644
--- a/tests/test_compat.py
+++ b/tests/test_compat.py
@@ -35,6 +35,10 @@ async def test_get_running_tasks(self):
tasks = await maybe_async(get_running_tasks())
assert type(tasks) is list
+ async def test_get_current_task(self):
+ task = await maybe_async(get_current_task())
+ assert type(task) is TaskInfo
+
async def test_maybe_async_cm():
async with maybe_async_cm(CancelScope()):
diff --git a/tests/test_debugging.py b/tests/test_debugging.py
index e7636096..7d982270 100644
--- a/tests/test_debugging.py
+++ b/tests/test_debugging.py
@@ -29,6 +29,25 @@ async def main():
loop.close()
+@pytest.mark.parametrize(
+ "name_input,expected",
+ [
+ (None, 'test_debugging.test_non_main_task_name..non_main'),
+ (b'name', "b'name'"),
+ ("name", "name"),
+ ("", ""),
+ ],
+)
+async def test_non_main_task_name(name_input, expected):
+ async def non_main(*, task_status):
+ task_status.started(anyio.get_current_task().name)
+
+ async with anyio.create_task_group() as tg:
+ name = await tg.start(non_main, name=name_input)
+
+ assert name == expected
+
+
async def test_get_running_tasks():
async def inspect():
await wait_all_tasks_blocked()