From e23b44e171c71d44e57d0c103a7daec1aaa7ad57 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 19 May 2021 14:50:34 +0100 Subject: [PATCH] Filled in missing type annotations and disallowed untyped defs (#289) --- docs/versionhistory.rst | 12 +++ setup.cfg | 1 + src/anyio/_backends/_asyncio.py | 120 +++++++++++++++------------- src/anyio/_backends/_trio.py | 103 +++++++++++++----------- src/anyio/_core/_compat.py | 57 ++++++++----- src/anyio/_core/_eventloop.py | 6 +- src/anyio/_core/_exceptions.py | 4 +- src/anyio/_core/_fileio.py | 16 ++-- src/anyio/_core/_sockets.py | 27 ++++--- src/anyio/_core/_streams.py | 8 +- src/anyio/_core/_subprocesses.py | 6 +- src/anyio/_core/_synchronization.py | 29 ++++--- src/anyio/_core/_tasks.py | 12 +-- src/anyio/_core/_testing.py | 31 ++++--- src/anyio/_core/_typedattr.py | 6 +- src/anyio/abc/_resources.py | 6 +- src/anyio/abc/_sockets.py | 19 +++-- src/anyio/abc/_streams.py | 2 +- src/anyio/abc/_tasks.py | 11 ++- src/anyio/abc/_testing.py | 13 ++- src/anyio/from_thread.py | 49 +++++++----- src/anyio/lowlevel.py | 56 ++++++++----- src/anyio/pytest_plugin.py | 27 ++++--- src/anyio/streams/buffered.py | 5 +- src/anyio/streams/file.py | 6 +- src/anyio/streams/memory.py | 4 +- src/anyio/streams/stapled.py | 20 ++--- src/anyio/streams/text.py | 22 ++--- src/anyio/streams/tls.py | 12 +-- src/anyio/to_process.py | 6 +- src/anyio/to_thread.py | 4 +- tests/streams/test_text.py | 1 + tests/test_compat.py | 4 + tests/test_debugging.py | 19 +++++ 34 files changed, 435 insertions(+), 289 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index d1df910d..d112e3a9 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -11,6 +11,7 @@ This library adheres to `Semantic Versioning 2.0 `_. - Changed asyncio task groups so that if the host and child tasks have only raised ``CancelledErrors``, just one ``CancelledError`` will now be raised instead of an ``ExceptionGroup``, allowing asyncio to ignore it when it propagates out of the task +- Changed task names to be converted to ``str`` early on asyncio (PR by Thomas Grainger) - Fixed ``sniffio._impl.AsyncLibraryNotFoundError: unknown async library, or not in async context`` on asyncio and Python 3.6 when ``to_thread.run_sync()`` is used from ``loop.run_until_complete()`` @@ -20,6 +21,17 @@ This library adheres to `Semantic Versioning 2.0 `_. task is cancelled (PR by Thomas Grainger) - Fixed declared return type of ``TaskGroup.start()`` (it was declared as ``None``, but anything can be returned from it) +- Fixed ``TextStream.extra_attributes`` raising ``AttributeError`` (PR by Thomas Grainger) +- Fixed ``await maybe_async(current_task())`` returning ``None`` (PR by Thomas Grainger) +- Fixed: ``pickle.dumps(current_task())`` now correctly raises ``TypeError`` instead of pickling to + ``None`` (PR by Thomas Grainger) +- Fixed return type annotation of ``Event.wait()`` (``bool`` → ``None``) (PR by Thomas Grainger) +- Fixed return type annotation of ``RunVar.get()`` to return either the type of the default value + or the type of the contained value (PR by Thomas Grainger) +- Fixed a deprecation warning message to refer to ``maybe_async()`` and not ``maybe_awaitable()`` + (PR by Thomas Grainger) +- Filled in argument and return types for all functions and methods previously missing them + (PR by Thomas Grainger) **3.0.1** diff --git a/setup.cfg b/setup.cfg index 353182af..931b930b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,3 +60,4 @@ pytest11 = [mypy] ignore_missing_imports = true +disallow_untyped_defs = true diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 5bcc6514..c8545e3f 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -100,8 +100,8 @@ def _cancel_all_tasks(loop): events.set_event_loop(None) loop.close() - def create_task(coro: Union[Generator[Any, None, _T], Awaitable[_T]], *, # type: ignore - name: Optional[str] = None) -> asyncio.Task: + def create_task(coro: Union[Generator[Any, None, _T], Awaitable[_T]], *, + name: object = None) -> asyncio.Task: return get_running_loop().create_task(coro) def get_running_loop() -> asyncio.AbstractEventLoop: @@ -213,11 +213,12 @@ def _maybe_set_event_loop_policy(policy: Optional[asyncio.AbstractEventLoopPolic asyncio.set_event_loop_policy(policy) -def run(func: Callable[..., T_Retval], *args, debug: bool = False, use_uvloop: bool = True, +def run(func: Callable[..., Awaitable[T_Retval]], *args: object, + debug: bool = False, use_uvloop: bool = True, policy: Optional[asyncio.AbstractEventLoopPolicy] = None) -> T_Retval: @wraps(func) - async def wrapper(): - task = current_task() + async def wrapper() -> T_Retval: + task = cast(asyncio.Task, current_task()) task_state = TaskState(None, get_callable_name(func), None) _task_states[task] = task_state if _native_task_names: @@ -247,7 +248,7 @@ async def wrapper(): class CancelScope(BaseCancelScope): - def __new__(cls, *, deadline: float = math.inf, shield: bool = False): + def __new__(cls, *, deadline: float = math.inf, shield: bool = False) -> "CancelScope": return object.__new__(cls) def __init__(self, deadline: float = math.inf, shield: bool = False): @@ -262,20 +263,20 @@ def __init__(self, deadline: float = math.inf, shield: bool = False): self._host_task: Optional[asyncio.Task] = None self._timeout_expired = False - def __enter__(self): + def __enter__(self) -> "CancelScope": if self._active: raise RuntimeError( "Each CancelScope may only be used for a single 'with' block" ) - self._host_task = current_task() - self._tasks.add(self._host_task) + self._host_task = host_task = cast(asyncio.Task, current_task()) + self._tasks.add(host_task) try: - task_state = _task_states[self._host_task] + task_state = _task_states[host_task] except KeyError: - task_name = self._host_task.get_name() if _native_task_names else None + task_name = host_task.get_name() if _native_task_names else None task_state = TaskState(None, task_name, self) - _task_states[self._host_task] = task_state + _task_states[host_task] = task_state else: self._parent_scope = task_state.cancel_scope task_state.cancel_scope = self @@ -326,7 +327,7 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Ba return None - def _timeout(self): + def _timeout(self) -> None: if self._deadline != math.inf: loop = get_running_loop() if loop.time() >= self._deadline: @@ -460,9 +461,9 @@ async def cancel_shielded_checkpoint() -> None: await sleep(0) -def current_effective_deadline(): +def current_effective_deadline() -> float: try: - cancel_scope = _task_states[current_task()].cancel_scope + cancel_scope = _task_states[current_task()].cancel_scope # type: ignore[index] except KeyError: return math.inf @@ -477,7 +478,7 @@ def current_effective_deadline(): return deadline -def current_time(): +def current_time() -> float: return get_running_loop().time() @@ -517,17 +518,17 @@ class _AsyncioTaskStatus(abc.TaskStatus): def __init__(self, future: asyncio.Future): self._future = future - def started(self, value=None) -> None: + def started(self, value: object = None) -> None: self._future.set_result(value) class TaskGroup(abc.TaskGroup): - def __init__(self): + def __init__(self) -> None: self.cancel_scope: CancelScope = CancelScope() self._active = False self._exceptions: List[BaseException] = [] - async def __aenter__(self): + async def __aenter__(self) -> "TaskGroup": self.cancel_scope.__enter__() self._active = True return self @@ -613,7 +614,7 @@ async def _run_wrapped_task( self.cancel_scope._tasks.remove(task) del _task_states[task] - def _spawn(self, func: Callable[..., Coroutine], args: tuple, name, + def _spawn(self, func: Callable[..., Coroutine], args: tuple, name: object, task_status_future: Optional[asyncio.Future] = None) -> asyncio.Task: def task_done(_task: asyncio.Task) -> None: # This is the code path for Python 3.8+ @@ -643,7 +644,7 @@ def task_done(_task: asyncio.Task) -> None: raise RuntimeError('This task group is not active; no new tasks can be started.') options = {} - name = name or get_callable_name(func) + name = get_callable_name(func) if name is None else str(name) if _native_task_names: options['name'] = name @@ -669,10 +670,12 @@ def task_done(_task: asyncio.Task) -> None: self.cancel_scope._tasks.add(task) return task - def start_soon(self, func: Callable[..., Coroutine], *args, name=None) -> None: + def start_soon(self, func: Callable[..., Coroutine], *args: object, + name: object = None) -> None: self._spawn(func, args, name) - async def start(self, func: Callable[..., Coroutine], *args, name=None) -> None: + async def start(self, func: Callable[..., Coroutine], *args: object, + name: object = None) -> None: future: asyncio.Future = asyncio.Future() task = self._spawn(func, args, name, future) @@ -745,7 +748,7 @@ def stop(self, f: Optional[asyncio.Task] = None) -> None: async def run_sync_in_worker_thread( - func: Callable[..., T_Retval], *args, cancellable: bool = False, + func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional['CapacityLimiter'] = None) -> T_Retval: await checkpoint() @@ -789,10 +792,10 @@ async def run_sync_in_worker_thread( idle_workers.append(worker) -def run_sync_from_thread(func: Callable[..., T_Retval], *args, +def run_sync_from_thread(func: Callable[..., T_Retval], *args: object, loop: Optional[asyncio.AbstractEventLoop] = None) -> T_Retval: @wraps(func) - def wrapper(): + def wrapper() -> None: try: f.set_result(func(*args)) except BaseException as exc: @@ -806,22 +809,24 @@ def wrapper(): return f.result() -def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval: +def run_async_from_thread( + func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object +) -> T_Retval: f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe( func(*args), threadlocals.loop) return f.result() class BlockingPortal(abc.BlockingPortal): - def __new__(cls): + def __new__(cls) -> "BlockingPortal": return object.__new__(cls) - def __init__(self): + def __init__(self) -> None: super().__init__() self._loop = get_running_loop() def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any], - name, future: Future) -> None: + name: object, future: Future) -> None: run_sync_from_thread( partial(self._task_group.start_soon, name=name), self._call_func, func, args, kwargs, future, loop=self._loop) @@ -908,13 +913,16 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]: return self._stderr -async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int, +async def open_process(command: Union[str, Sequence[str]], *, shell: bool, + stdin: int, stdout: int, stderr: int, cwd: Union[str, bytes, PathLike, None] = None, env: Optional[Mapping[str, str]] = None) -> Process: await checkpoint() if shell: - process = await asyncio.create_subprocess_shell(command, stdin=stdin, stdout=stdout, - stderr=stderr, cwd=cwd, env=env) + process = await asyncio.create_subprocess_shell( + command, stdin=stdin, stdout=stdout, # type: ignore[arg-type] + stderr=stderr, cwd=cwd, env=env, + ) else: process = await asyncio.create_subprocess_exec(*command, stdin=stdin, stdout=stdout, stderr=stderr, cwd=cwd, env=env) @@ -925,7 +933,7 @@ async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: return Process(process, stdin_stream, stdout_stream, stderr_stream) -def _forcibly_shutdown_process_pool_on_exit(workers: Set[Process], _task) -> None: +def _forcibly_shutdown_process_pool_on_exit(workers: Set[Process], _task: object) -> None: """ Forcibly shuts down worker processes belonging to this event loop.""" child_watcher: Optional[asyncio.AbstractChildWatcher] @@ -1142,7 +1150,7 @@ def _raw_socket(self) -> SocketType: return self.__raw_socket def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: - def callback(f): + def callback(f: object) -> None: del self._receive_future loop.remove_reader(self.__raw_socket) @@ -1152,7 +1160,7 @@ def callback(f): return f def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: - def callback(f): + def callback(f: object) -> None: del self._send_future loop.remove_writer(self.__raw_socket) @@ -1594,10 +1602,10 @@ async def wait_socket_writable(sock: socket.SocketType) -> None: # class Event(BaseEvent): - def __new__(cls): + def __new__(cls) -> "Event": return object.__new__(cls) - def __init__(self): + def __init__(self) -> None: self._event = asyncio.Event() def set(self) -> DeprecatedAwaitable: @@ -1607,18 +1615,18 @@ def set(self) -> DeprecatedAwaitable: def is_set(self) -> bool: return self._event.is_set() - async def wait(self): + async def wait(self) -> None: if await self._event.wait(): await checkpoint() def statistics(self) -> EventStatistics: - return EventStatistics(len(self._event._waiters)) + return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined] class CapacityLimiter(BaseCapacityLimiter): _total_tokens: float = 0 - def __new__(cls, total_tokens: float): + def __new__(cls, total_tokens: float) -> "CapacityLimiter": return object.__new__(cls) def __init__(self, total_tokens: float): @@ -1626,7 +1634,7 @@ def __init__(self, total_tokens: float): self._wait_queue: Dict[Any, asyncio.Event] = OrderedDict() self.total_tokens = total_tokens - async def __aenter__(self): + async def __aenter__(self) -> None: await self.acquire() async def __aexit__(self, exc_type: Optional[Type[BaseException]], @@ -1671,7 +1679,7 @@ def acquire_nowait(self) -> DeprecatedAwaitable: self.acquire_on_behalf_of_nowait(current_task()) return DeprecatedAwaitable(self.acquire_nowait) - def acquire_on_behalf_of_nowait(self, borrower) -> DeprecatedAwaitable: + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: if borrower in self._borrowers: raise RuntimeError("this borrower is already holding one of this CapacityLimiter's " "tokens") @@ -1685,7 +1693,7 @@ def acquire_on_behalf_of_nowait(self, borrower) -> DeprecatedAwaitable: async def acquire(self) -> None: return await self.acquire_on_behalf_of(current_task()) - async def acquire_on_behalf_of(self, borrower) -> None: + async def acquire_on_behalf_of(self, borrower: object) -> None: await checkpoint_if_cancelled() try: self.acquire_on_behalf_of_nowait(borrower) @@ -1705,7 +1713,7 @@ async def acquire_on_behalf_of(self, borrower) -> None: def release(self) -> None: self.release_on_behalf_of(current_task()) - def release_on_behalf_of(self, borrower) -> None: + def release_on_behalf_of(self, borrower: object) -> None: try: self._borrowers.remove(borrower) except KeyError: @@ -1725,7 +1733,7 @@ def statistics(self) -> CapacityLimiterStatistics: _default_thread_limiter: RunVar[CapacityLimiter] = RunVar('_default_thread_limiter') -def current_default_thread_limiter(): +def current_default_thread_limiter() -> CapacityLimiter: try: return _default_thread_limiter.get() except LookupError: @@ -1751,18 +1759,21 @@ def _deliver(self, signum: int) -> None: if not self._future.done(): self._future.set_result(None) - def __enter__(self): + def __enter__(self) -> "_SignalReceiver": for sig in set(self._signals): self._loop.add_signal_handler(sig, self._deliver, sig) self._handled_signals.add(sig) return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: for sig in self._handled_signals: self._loop.remove_signal_handler(sig) + return None - def __aiter__(self): + def __aiter__(self) -> "_SignalReceiver": return self async def __anext__(self) -> int: @@ -1825,7 +1836,7 @@ def __init__(self, debug: bool = False, use_uvloop: bool = True, self._loop.set_debug(debug) asyncio.set_event_loop(self._loop) - def _cancel_all_tasks(self): + def _cancel_all_tasks(self) -> None: to_cancel = all_tasks(self._loop) if not to_cancel: return @@ -1840,7 +1851,7 @@ def _cancel_all_tasks(self): if task.cancelled(): continue if task.exception() is not None: - raise task.exception() + raise cast(BaseException, task.exception()) def close(self) -> None: try: @@ -1850,16 +1861,17 @@ def close(self) -> None: asyncio.set_event_loop(None) self._loop.close() - def call(self, func: Callable[..., Awaitable], *args, **kwargs): + def call(self, func: Callable[..., Awaitable[T_Retval]], + *args: object, **kwargs: object) -> T_Retval: def exception_handler(loop: asyncio.AbstractEventLoop, context: Dict[str, Any]) -> None: exceptions.append(context['exception']) exceptions: List[Exception] = [] self._loop.set_exception_handler(exception_handler) try: - retval = self._loop.run_until_complete(func(*args, **kwargs)) + retval: T_Retval = self._loop.run_until_complete(func(*args, **kwargs)) except Exception as exc: - retval = None + retval = None # type: ignore[assignment] exceptions.append(exc) finally: self._loop.set_exception_handler(None) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 48be7a51..42764907 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -8,11 +8,12 @@ from os import PathLike from types import TracebackType from typing import ( - Any, Awaitable, Callable, Collection, Coroutine, Dict, Generic, List, Mapping, NoReturn, - Optional, Set, Tuple, Type, TypeVar, Union) + Any, Awaitable, Callable, Collection, ContextManager, Coroutine, Deque, Dict, Generic, List, + Mapping, NoReturn, Optional, Sequence, Set, Tuple, Type, TypeVar, Union) import trio.from_thread -from outcome import Error, Value +from outcome import Error, Outcome, Value +from trio.socket import SocketType as TrioSocketType from trio.to_thread import run_sync from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc @@ -36,6 +37,7 @@ else: from trio.lowlevel import wait_readable, wait_writable + T_Retval = TypeVar('T_Retval') T_SockAddr = TypeVar('T_SockAddr', str, IPSockAddrType) @@ -61,17 +63,20 @@ # class CancelScope(BaseCancelScope): - def __new__(cls, original: Optional[trio.CancelScope] = None, **kwargs): + def __new__(cls, original: Optional[trio.CancelScope] = None, + **kwargs: object) -> 'CancelScope': return object.__new__(cls) - def __init__(self, original: Optional[trio.CancelScope] = None, **kwargs): + def __init__(self, original: Optional[trio.CancelScope] = None, **kwargs: object) -> None: self.__original = original or trio.CancelScope(**kwargs) - def __enter__(self): + def __enter__(self) -> 'CancelScope': self.__original.__enter__() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: return self.__original.__exit__(exc_type, exc_val, exc_tb) def cancel(self) -> DeprecatedAwaitable: @@ -116,12 +121,12 @@ class ExceptionGroup(BaseExceptionGroup, trio.MultiError): class TaskGroup(abc.TaskGroup): - def __init__(self): + def __init__(self) -> None: self._active = False self._nursery_manager = trio.open_nursery() - self.cancel_scope = None + self.cancel_scope = None # type: ignore[assignment] - async def __aenter__(self): + async def __aenter__(self) -> 'TaskGroup': self._active = True self._nursery = await self._nursery_manager.__aenter__() self.cancel_scope = CancelScope(self._nursery.cancel_scope) @@ -137,13 +142,14 @@ async def __aexit__(self, exc_type: Optional[Type[BaseException]], finally: self._active = False - def start_soon(self, func: Callable, *args, name=None) -> None: + def start_soon(self, func: Callable, *args: object, name: object = None) -> None: if not self._active: raise RuntimeError('This task group is not active; no new tasks can be started.') self._nursery.start_soon(func, *args, name=name) - async def start(self, func: Callable[..., Coroutine], *args, name=None): + async def start(self, func: Callable[..., Coroutine], + *args: object, name: object = None) -> object: if not self._active: raise RuntimeError('This task group is not active; no new tasks can be started.') @@ -155,9 +161,9 @@ async def start(self, func: Callable[..., Coroutine], *args, name=None): async def run_sync_in_worker_thread( - func: Callable[..., T_Retval], *args, cancellable: bool = False, + func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional[trio.CapacityLimiter] = None) -> T_Retval: - def wrapper(): + def wrapper() -> T_Retval: with claim_worker_thread('trio'): return func(*args) @@ -168,15 +174,15 @@ def wrapper(): class BlockingPortal(abc.BlockingPortal): - def __new__(cls): + def __new__(cls) -> 'BlockingPortal': return object.__new__(cls) - def __init__(self): + def __init__(self) -> None: super().__init__() self._token = trio.lowlevel.current_trio_token() def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any], - name, future: Future) -> None: + name: object, future: Future) -> None: return trio.from_thread.run_sync( partial(self._task_group.start_soon, name=name), self._call_func, func, args, kwargs, future, trio_token=self._token) @@ -273,7 +279,8 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]: return self._stderr -async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int, +async def open_process(command: Union[str, Sequence[str]], *, shell: bool, + stdin: int, stdout: int, stderr: int, cwd: Union[str, bytes, PathLike, None] = None, env: Optional[Mapping[str, str]] = None) -> Process: process = await trio.open_process(command, stdin=stdin, stdout=stdout, stderr=stderr, @@ -285,7 +292,7 @@ async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: class _ProcessPoolShutdownInstrument(trio.abc.Instrument): - def after_run(self): + def after_run(self) -> None: super().after_run() @@ -316,7 +323,7 @@ def setup_process_pool_exit_at_shutdown(workers: Set[Process]) -> None: # class _TrioSocketMixin(Generic[T_SockAddr]): - def __init__(self, trio_socket): + def __init__(self, trio_socket: TrioSocketType) -> None: self._trio_socket = trio_socket self._closed = False @@ -347,7 +354,7 @@ def _convert_socket_error(self, exc: BaseException) -> 'NoReturn': class SocketStream(_TrioSocketMixin, abc.SocketStream): - def __init__(self, trio_socket): + def __init__(self, trio_socket: TrioSocketType) -> None: super().__init__(trio_socket) self._receive_guard = ResourceGuard('reading from') self._send_guard = ResourceGuard('writing to') @@ -467,7 +474,7 @@ async def accept(self) -> UNIXSocketStream: class UDPSocket(_TrioSocketMixin[IPSockAddrType], abc.UDPSocket): - def __init__(self, trio_socket): + def __init__(self, trio_socket: TrioSocketType) -> None: super().__init__(trio_socket) self._receive_guard = ResourceGuard('reading from') self._send_guard = ResourceGuard('writing to') @@ -489,7 +496,7 @@ async def send(self, item: UDPPacketType) -> None: class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocket): - def __init__(self, trio_socket): + def __init__(self, trio_socket: TrioSocketType) -> None: super().__init__(trio_socket) self._receive_guard = ResourceGuard('reading from') self._send_guard = ResourceGuard('writing to') @@ -562,7 +569,7 @@ async def create_udp_socket( getnameinfo = trio.socket.getnameinfo -async def wait_socket_readable(sock): +async def wait_socket_readable(sock: socket.SocketType) -> None: try: await wait_readable(sock) except trio.ClosedResourceError as exc: @@ -571,7 +578,7 @@ async def wait_socket_readable(sock): raise BusyResourceError('reading from') from None -async def wait_socket_writable(sock): +async def wait_socket_writable(sock: socket.SocketType) -> None: try: await wait_writable(sock) except trio.ClosedResourceError as exc: @@ -585,34 +592,34 @@ async def wait_socket_writable(sock): # class Event(BaseEvent): - def __new__(cls): + def __new__(cls) -> 'Event': return object.__new__(cls) - def __init__(self): + def __init__(self) -> None: self.__original = trio.Event() def is_set(self) -> bool: return self.__original.is_set() - async def wait(self) -> bool: + async def wait(self) -> None: return await self.__original.wait() def statistics(self) -> EventStatistics: return self.__original.statistics() - def set(self): + def set(self) -> DeprecatedAwaitable: self.__original.set() return DeprecatedAwaitable(self.set) class CapacityLimiter(BaseCapacityLimiter): - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: object, **kwargs: object) -> "CapacityLimiter": return object.__new__(cls) - def __init__(self, *args, original: Optional[trio.CapacityLimiter] = None): + def __init__(self, *args: object, original: Optional[trio.CapacityLimiter] = None) -> None: self.__original = original or trio.CapacityLimiter(*args) - async def __aenter__(self): + async def __aenter__(self) -> None: return await self.__original.__aenter__() async def __aexit__(self, exc_type: Optional[Type[BaseException]], @@ -636,24 +643,24 @@ def borrowed_tokens(self) -> int: def available_tokens(self) -> float: return self.__original.available_tokens - def acquire_nowait(self): + def acquire_nowait(self) -> DeprecatedAwaitable: self.__original.acquire_nowait() return DeprecatedAwaitable(self.acquire_nowait) - def acquire_on_behalf_of_nowait(self, borrower): + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: self.__original.acquire_on_behalf_of_nowait(borrower) return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait) async def acquire(self) -> None: await self.__original.acquire() - async def acquire_on_behalf_of(self, borrower) -> None: + async def acquire_on_behalf_of(self, borrower: object) -> None: await self.__original.acquire_on_behalf_of(borrower) def release(self) -> None: return self.__original.release() - def release_on_behalf_of(self, borrower) -> None: + def release_on_behalf_of(self, borrower: object) -> None: return self.__original.release_on_behalf_of(borrower) def statistics(self) -> CapacityLimiterStatistics: @@ -677,17 +684,19 @@ def current_default_thread_limiter() -> CapacityLimiter: # class _SignalReceiver(DeprecatedAsyncContextManager): - def __init__(self, cm): + def __init__(self, cm: ContextManager[T]): self._cm = cm def __enter__(self) -> T: return self._cm.__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: return self._cm.__exit__(exc_type, exc_val, exc_tb) -def open_signal_receiver(*signals: int): +def open_signal_receiver(*signals: int) -> _SignalReceiver: cm = trio.open_signal_receiver(*signals) return _SignalReceiver(cm) @@ -723,18 +732,18 @@ def get_running_tasks() -> List[TaskInfo]: return task_infos -def wait_all_tasks_blocked(): +def wait_all_tasks_blocked() -> Awaitable[None]: import trio.testing return trio.testing.wait_all_tasks_blocked() class TestRunner(abc.TestRunner): - def __init__(self, **options): + def __init__(self, **options: object) -> None: from collections import deque from queue import Queue - self._call_queue = Queue() - self._result_queue = deque() + self._call_queue: "Queue[Callable[..., object]]" = Queue() + self._result_queue: Deque[Outcome] = deque() self._stop_event: Optional[trio.Event] = None self._nursery: Optional[trio.Nursery] = None self._options = options @@ -744,7 +753,8 @@ async def _trio_main(self) -> None: async with trio.open_nursery() as self._nursery: await self._stop_event.wait() - async def _call_func(self, func, args, kwargs): + async def _call_func(self, func: Callable[..., Awaitable[object]], + args: tuple, kwargs: dict) -> None: try: retval = await func(*args, **kwargs) except BaseException as exc: @@ -752,7 +762,7 @@ async def _call_func(self, func, args, kwargs): else: self._result_queue.append(Value(retval)) - def _main_task_finished(self, outcome) -> None: + def _main_task_finished(self, outcome: object) -> None: self._nursery = None def close(self) -> None: @@ -761,7 +771,8 @@ def close(self) -> None: while self._nursery is not None: self._call_queue.get()() - def call(self, func: Callable[..., Awaitable], *args, **kwargs): + def call(self, func: Callable[..., Awaitable[T_Retval]], + *args: object, **kwargs: object) -> T_Retval: if self._nursery is None: trio.lowlevel.start_guest_run( self._trio_main, run_sync_soon_threadsafe=self._call_queue.put, diff --git a/src/anyio/_core/_compat.py b/src/anyio/_core/_compat.py index 9f44ea95..0ebfeb2d 100644 --- a/src/anyio/_core/_compat.py +++ b/src/anyio/_core/_compat.py @@ -1,13 +1,24 @@ from abc import ABCMeta, abstractmethod from contextlib import AbstractContextManager +from types import TracebackType from typing import ( - AsyncContextManager, Callable, ContextManager, Generic, List, Optional, TypeVar, Union, - overload) + TYPE_CHECKING, AsyncContextManager, Callable, ContextManager, Generic, Iterable, List, + Optional, Tuple, Type, TypeVar, Union, overload) from warnings import warn +if TYPE_CHECKING: + from ._testing import TaskInfo +else: + TaskInfo = object + T = TypeVar('T') AnyDeprecatedAwaitable = Union['DeprecatedAwaitable', 'DeprecatedAwaitableFloat', - 'DeprecatedAwaitableList'] + 'DeprecatedAwaitableList', TaskInfo] + + +@overload +async def maybe_async(__obj: TaskInfo) -> TaskInfo: + ... @overload @@ -16,7 +27,7 @@ async def maybe_async(__obj: 'DeprecatedAwaitableFloat') -> float: @overload -async def maybe_async(__obj: 'DeprecatedAwaitableList') -> list: +async def maybe_async(__obj: 'DeprecatedAwaitableList[T]') -> List[T]: ... @@ -25,7 +36,7 @@ async def maybe_async(__obj: 'DeprecatedAwaitable') -> None: ... -async def maybe_async(__obj: AnyDeprecatedAwaitable) -> Union[float, list, None]: +async def maybe_async(__obj: AnyDeprecatedAwaitable) -> Union[TaskInfo, float, list, None]: """ Await on the given object if necessary. @@ -49,7 +60,9 @@ def __init__(self, cm: ContextManager[T]): async def __aenter__(self) -> T: return self._cm.__enter__() - async def __aexit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]: + async def __aexit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: return self._cm.__exit__(exc_type, exc_val, exc_tb) @@ -74,7 +87,7 @@ def maybe_async_cm(cm: Union[ContextManager[T], AsyncContextManager[T]]) -> Asyn def _warn_deprecation(awaitable: AnyDeprecatedAwaitable, stacklevel: int = 1) -> None: warn(f'Awaiting on {awaitable._name}() is deprecated. Use "await ' - f'anyio.maybe_awaitable({awaitable._name}(...)) if you have to support both AnyIO 2.x ' + f'anyio.maybe_async({awaitable._name}(...)) if you have to support both AnyIO 2.x ' f'and 3.x, or just remove the "await" if you are completely migrating to AnyIO 3+.', DeprecationWarning, stacklevel=stacklevel + 1) @@ -83,33 +96,35 @@ class DeprecatedAwaitable: def __init__(self, func: Callable[..., 'DeprecatedAwaitable']): self._name = f'{func.__module__}.{func.__qualname__}' - def __await__(self): + def __await__(self) -> Iterable[None]: _warn_deprecation(self) if False: yield - def __reduce__(self): + def __reduce__(self) -> Tuple[Type[None], Tuple]: return type(None), () - def _unwrap(self): + def _unwrap(self) -> None: return None class DeprecatedAwaitableFloat(float): - def __new__(cls, x, func): + def __new__( + cls, x: float, func: Callable[..., 'DeprecatedAwaitableFloat'] + ) -> 'DeprecatedAwaitableFloat': return super().__new__(cls, x) def __init__(self, x: float, func: Callable[..., 'DeprecatedAwaitableFloat']): self._name = f'{func.__module__}.{func.__qualname__}' - def __await__(self): + def __await__(self) -> Iterable[float]: _warn_deprecation(self) if False: yield return float(self) - def __reduce__(self): + def __reduce__(self) -> Tuple[Type[float], Tuple[float]]: return float, (float(self),) def _unwrap(self) -> float: @@ -117,18 +132,18 @@ def _unwrap(self) -> float: class DeprecatedAwaitableList(List[T]): - def __init__(self, *args, func: Callable[..., 'DeprecatedAwaitableList']): + def __init__(self, *args: T, func: Callable[..., 'DeprecatedAwaitableList']): super().__init__(*args) self._name = f'{func.__module__}.{func.__qualname__}' - def __await__(self): + def __await__(self) -> Iterable[List[T]]: _warn_deprecation(self) if False: yield - return self + return list(self) - def __reduce__(self): + def __reduce__(self) -> Tuple[Type[list], Tuple[List[T]]]: return list, (list(self),) def _unwrap(self) -> List[T]: @@ -141,7 +156,9 @@ def __enter__(self) -> T: pass @abstractmethod - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: pass async def __aenter__(self) -> T: @@ -151,5 +168,7 @@ async def __aenter__(self) -> T: f'you are completely migrating to AnyIO 3+.', DeprecationWarning) return self.__enter__() - async def __aexit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]: + async def __aexit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: return self.__exit__(exc_type, exc_val, exc_tb) diff --git a/src/anyio/_core/_eventloop.py b/src/anyio/_core/_eventloop.py index 6021ab99..f2364a3b 100644 --- a/src/anyio/_core/_eventloop.py +++ b/src/anyio/_core/_eventloop.py @@ -16,7 +16,7 @@ threadlocals = threading.local() -def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args, +def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object, backend: str = 'asyncio', backend_options: Optional[Dict[str, Any]] = None) -> T_Retval: """ Run the given coroutine function in an asynchronous event loop. @@ -120,7 +120,7 @@ def get_cancelled_exc_class() -> Type[BaseException]: # @contextmanager -def claim_worker_thread(backend) -> Generator[Any, None, None]: +def claim_worker_thread(backend: str) -> Generator[Any, None, None]: module = sys.modules['anyio._backends._' + backend] threadlocals.current_async_module = module token = sniffio.current_async_library_cvar.set(backend) @@ -131,7 +131,7 @@ def claim_worker_thread(backend) -> Generator[Any, None, None]: del threadlocals.current_async_module -def get_asynclib(asynclib_name: Optional[str] = None): +def get_asynclib(asynclib_name: Optional[str] = None) -> Any: if asynclib_name is None: asynclib_name = sniffio.current_async_library() diff --git a/src/anyio/_core/_exceptions.py b/src/anyio/_core/_exceptions.py index 8025ef6f..52e59808 100644 --- a/src/anyio/_core/_exceptions.py +++ b/src/anyio/_core/_exceptions.py @@ -49,7 +49,7 @@ class ExceptionGroup(BaseException): #: the sequence of exceptions raised together exceptions: Sequence[BaseException] - def __str__(self): + def __str__(self) -> str: tracebacks = [''.join(format_exception(type(exc), exc, exc.__traceback__)) for exc in self.exceptions] return f'{len(self.exceptions)} exceptions were raised in the task group:\n' \ @@ -67,7 +67,7 @@ class IncompleteRead(Exception): connection is closed before the requested amount of bytes has been read. """ - def __init__(self): + def __init__(self) -> None: super().__init__('The stream was closed before the read operation could be completed') diff --git a/src/anyio/_core/_fileio.py b/src/anyio/_core/_fileio.py index fe2b53d6..ebfd34e3 100644 --- a/src/anyio/_core/_fileio.py +++ b/src/anyio/_core/_fileio.py @@ -1,12 +1,14 @@ import os from os import PathLike -from typing import Callable, Optional, Union +from typing import Any, AsyncIterator, Callable, Generic, Optional, TypeVar, Union from .. import to_thread from ..abc import AsyncResource +T_Fp = TypeVar("T_Fp") -class AsyncFile(AsyncResource): + +class AsyncFile(AsyncResource, Generic[T_Fp]): """ An asynchronous file object. @@ -38,18 +40,18 @@ class AsyncFile(AsyncResource): print(line) """ - def __init__(self, fp) -> None: - self._fp = fp + def __init__(self, fp: T_Fp) -> None: + self._fp: Any = fp - def __getattr__(self, name): + def __getattr__(self, name: str) -> object: return getattr(self._fp, name) @property - def wrapped(self): + def wrapped(self) -> T_Fp: """The wrapped file object.""" return self._fp - async def __aiter__(self): + async def __aiter__(self) -> AsyncIterator[bytes]: while True: line = await self.readline() if line: diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index e770861f..f58cdc3a 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -83,9 +83,11 @@ async def connect_tcp( async def connect_tcp( - remote_host, remote_port, *, local_host=None, tls=False, ssl_context=None, - tls_standard_compatible=True, tls_hostname=None, happy_eyeballs_delay=0.25 -): + remote_host: IPAddressType, remote_port: int, *, local_host: Optional[IPAddressType] = None, + tls: bool = False, ssl_context: Optional[ssl.SSLContext] = None, + tls_standard_compatible: bool = True, tls_hostname: Optional[str] = None, + happy_eyeballs_delay: float = 0.25 +) -> Union[SocketStream, TLSStream]: """ Connect to a host using the TCP protocol. @@ -119,7 +121,7 @@ async def connect_tcp( # Placed here due to https://github.com/python/mypy/issues/7057 connected_stream: Optional[SocketStream] = None - async def try_connect(remote_host: str, event: Event): + async def try_connect(remote_host: str, event: Event) -> None: nonlocal connected_stream try: stream = await asynclib.connect_tcp(remote_host, remote_port, local_address) @@ -184,7 +186,7 @@ async def try_connect(remote_host: str, event: Event): if tls or tls_hostname or ssl_context: try: return await TLSStream.wrap(connected_stream, server_side=False, - hostname=tls_hostname or remote_host, + hostname=tls_hostname or str(remote_host), ssl_context=ssl_context, standard_compatible=tls_standard_compatible) except BaseException: @@ -475,7 +477,9 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]: # Private API # -def convert_ipv6_sockaddr(sockaddr): +def convert_ipv6_sockaddr( + sockaddr: Union[Tuple[str, int, int, int], Tuple[str, int]] +) -> Tuple[str, int]: """ Convert a 4-tuple IPv6 socket address to a 2-tuple (address, port) format. @@ -489,10 +493,11 @@ def convert_ipv6_sockaddr(sockaddr): """ # This is more complicated than it should be because of MyPy if isinstance(sockaddr, tuple) and len(sockaddr) == 4: - if sockaddr[3]: - # Add scopeid to the address - return sockaddr[0] + '%' + str(sockaddr[3]), sockaddr[1] + host, port, flowinfo, scope_id = cast(Tuple[str, int, int, int], sockaddr) + if scope_id: + # Add scope_id to the address + return f"{host}%{scope_id}", port else: - return sockaddr[:2] + return host, port else: - return sockaddr + return cast(Tuple[str, int], sockaddr) diff --git a/src/anyio/_core/_streams.py b/src/anyio/_core/_streams.py index 02eabd4d..4a003bea 100644 --- a/src/anyio/_core/_streams.py +++ b/src/anyio/_core/_streams.py @@ -1,5 +1,5 @@ import math -from typing import Any, Tuple, Type, TypeVar, overload +from typing import Optional, Tuple, Type, TypeVar, overload from ..streams.memory import ( MemoryObjectReceiveStream, MemoryObjectSendStream, MemoryObjectStreamState) @@ -17,11 +17,13 @@ def create_memory_object_stream( @overload def create_memory_object_stream( max_buffer_size: float = 0 -) -> Tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: +) -> Tuple[MemoryObjectSendStream, MemoryObjectReceiveStream]: ... -def create_memory_object_stream(max_buffer_size=0, item_type=None): +def create_memory_object_stream( + max_buffer_size: float = 0, item_type: Optional[Type[T_Item]] = None +) -> Tuple[MemoryObjectSendStream, MemoryObjectReceiveStream]: """ Create a memory object stream. diff --git a/src/anyio/_core/_subprocesses.py b/src/anyio/_core/_subprocesses.py index f7e0232c..1daf6c10 100644 --- a/src/anyio/_core/_subprocesses.py +++ b/src/anyio/_core/_subprocesses.py @@ -1,6 +1,6 @@ from os import PathLike from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess -from typing import Mapping, Optional, Sequence, Union, cast +from typing import AsyncIterable, List, Mapping, Optional, Sequence, Union, cast from ..abc import Process from ._eventloop import get_asynclib @@ -32,13 +32,13 @@ async def run_process(command: Union[str, Sequence[str]], *, input: Optional[byt nonzero return code """ - async def drain_stream(stream, index): + async def drain_stream(stream: AsyncIterable[bytes], index: int) -> None: chunks = [chunk async for chunk in stream] stream_contents[index] = b''.join(chunks) async with await open_process(command, stdin=PIPE if input else DEVNULL, stdout=stdout, stderr=stderr, cwd=cwd, env=env) as process: - stream_contents = [None, None] + stream_contents: List[Optional[bytes]] = [None, None] try: async with create_task_group() as tg: if process.stdout: diff --git a/src/anyio/_core/_synchronization.py b/src/anyio/_core/_synchronization.py index f7070e98..6c691770 100644 --- a/src/anyio/_core/_synchronization.py +++ b/src/anyio/_core/_synchronization.py @@ -73,7 +73,7 @@ class SemaphoreStatistics: class Event: - def __new__(cls): + def __new__(cls) -> 'Event': return get_asynclib().Event() def set(self) -> DeprecatedAwaitable: @@ -84,7 +84,7 @@ def is_set(self) -> bool: """Return ``True`` if the flag is set, ``False`` if not.""" raise NotImplementedError - async def wait(self) -> bool: + async def wait(self) -> None: """ Wait until the flag has been set. @@ -101,10 +101,10 @@ def statistics(self) -> EventStatistics: class Lock: _owner_task: Optional[TaskInfo] = None - def __init__(self): + def __init__(self) -> None: self._waiters: Deque[Tuple[TaskInfo, Event]] = deque() - async def __aenter__(self): + async def __aenter__(self) -> None: await self.acquire() async def __aexit__(self, exc_type: Optional[Type[BaseException]], @@ -183,7 +183,7 @@ def __init__(self, lock: Optional[Lock] = None): self._lock = lock or Lock() self._waiters: Deque[Event] = deque() - async def __aenter__(self): + async def __aenter__(self) -> None: await self.acquire() async def __aexit__(self, exc_type: Optional[Type[BaseException]], @@ -351,10 +351,10 @@ def statistics(self) -> SemaphoreStatistics: class CapacityLimiter: - def __new__(cls, total_tokens: float): + def __new__(cls, total_tokens: float) -> 'CapacityLimiter': return get_asynclib().CapacityLimiter(total_tokens) - async def __aenter__(self): + async def __aenter__(self) -> None: raise NotImplementedError async def __aexit__(self, exc_type: Optional[Type[BaseException]], @@ -380,7 +380,7 @@ def total_tokens(self) -> float: def total_tokens(self, value: float) -> None: raise NotImplementedError - async def set_total_tokens(self, value) -> None: + async def set_total_tokens(self, value: float) -> None: warn('CapacityLimiter.set_total_tokens has been deprecated. Set the value of the' '"total_tokens" attribute directly.', DeprecationWarning) self.total_tokens = value @@ -404,7 +404,7 @@ def acquire_nowait(self) -> DeprecatedAwaitable: """ raise NotImplementedError - def acquire_on_behalf_of_nowait(self, borrower) -> DeprecatedAwaitable: + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: """ Acquire a token without waiting for one to become available. @@ -421,7 +421,7 @@ async def acquire(self) -> None: """ raise NotImplementedError - async def acquire_on_behalf_of(self, borrower) -> None: + async def acquire_on_behalf_of(self, borrower: object) -> None: """ Acquire a token, waiting if necessary for one to become available. @@ -438,7 +438,7 @@ def release(self) -> None: """ raise NotImplementedError - def release_on_behalf_of(self, borrower) -> None: + def release_on_behalf_of(self, borrower: object) -> None: """ Release the token held by the given borrower. @@ -541,11 +541,14 @@ def __init__(self, action: str): self.action = action self._guarded = False - def __enter__(self): + def __enter__(self) -> None: if self._guarded: raise BusyResourceError(self.action) self._guarded = True - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: self._guarded = False + return None diff --git a/src/anyio/_core/_tasks.py b/src/anyio/_core/_tasks.py index ac4e41da..8bbad974 100644 --- a/src/anyio/_core/_tasks.py +++ b/src/anyio/_core/_tasks.py @@ -9,7 +9,7 @@ class _IgnoredTaskStatus(TaskStatus): - def started(self, value=None) -> None: + def started(self, value: object = None) -> None: pass @@ -24,7 +24,7 @@ class CancelScope(DeprecatedAsyncContextManager['CancelScope']): :param shield: ``True`` to shield the cancel scope from external cancellation """ - def __new__(cls, *, deadline: float = math.inf, shield: bool = False): + def __new__(cls, *, deadline: float = math.inf, shield: bool = False) -> 'CancelScope': return get_asynclib().CancelScope(shield=shield, deadline=deadline) def cancel(self) -> DeprecatedAwaitable: @@ -64,7 +64,7 @@ def shield(self) -> bool: def shield(self, value: bool) -> None: raise NotImplementedError - def __enter__(self): + def __enter__(self) -> 'CancelScope': raise NotImplementedError def __exit__(self, exc_type: Optional[Type[BaseException]], @@ -92,10 +92,12 @@ class FailAfterContextManager(DeprecatedAsyncContextManager): def __init__(self, cancel_scope: CancelScope): self._cancel_scope = cancel_scope - def __enter__(self): + def __enter__(self) -> CancelScope: return self._cancel_scope.__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: retval = self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) if self._cancel_scope.cancel_called: raise TimeoutError diff --git a/src/anyio/_core/_testing.py b/src/anyio/_core/_testing.py index 269923da..c48bd45e 100644 --- a/src/anyio/_core/_testing.py +++ b/src/anyio/_core/_testing.py @@ -1,10 +1,10 @@ -from typing import Coroutine, Optional +from typing import Coroutine, Generator, Optional -from ._compat import DeprecatedAwaitable, DeprecatedAwaitableList +from ._compat import DeprecatedAwaitableList, _warn_deprecation from ._eventloop import get_asynclib -class TaskInfo(DeprecatedAwaitable): +class TaskInfo: """ Represents an asynchronous task. @@ -15,31 +15,38 @@ class TaskInfo(DeprecatedAwaitable): :ivar ~collections.abc.Coroutine coro: the coroutine object of the task """ - __slots__ = 'id', 'parent_id', 'name', 'coro' + __slots__ = '_name', 'id', 'parent_id', 'name', 'coro' def __init__(self, id: int, parent_id: Optional[int], name: Optional[str], coro: Coroutine): - super().__init__(get_current_task) + func = get_current_task + self._name = f'{func.__module__}.{func.__qualname__}' self.id = id self.parent_id = parent_id self.name = name self.coro = coro - def __await__(self): - yield from super().__await__() - return self - - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, TaskInfo): return self.id == other.id return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash(self.id) - def __repr__(self): + def __repr__(self) -> str: return f'{self.__class__.__name__}(id={self.id!r}, name={self.name!r})' + def __await__(self) -> Generator[None, None, "TaskInfo"]: + _warn_deprecation(self) + if False: + yield + + return self + + def _unwrap(self) -> 'TaskInfo': + return self + def get_current_task() -> TaskInfo: """ diff --git a/src/anyio/_core/_typedattr.py b/src/anyio/_core/_typedattr.py index 162426a1..e8377961 100644 --- a/src/anyio/_core/_typedattr.py +++ b/src/anyio/_core/_typedattr.py @@ -1,5 +1,5 @@ import sys -from typing import Callable, Mapping, TypeVar, Union, overload +from typing import Any, Callable, Mapping, TypeVar, Union, overload from ._exceptions import TypedAttributeLookupError @@ -13,7 +13,7 @@ undefined = object() -def typed_attribute(): +def typed_attribute() -> Any: """Return a unique object, used to mark typed attributes.""" return object() @@ -58,7 +58,7 @@ def extra(self, attribute: T_Attr, default: T_Default) -> Union[T_Attr, T_Defaul ... @final - def extra(self, attribute, default=undefined): + def extra(self, attribute: Any, default: object = undefined) -> object: """ extra(attribute, default=undefined) diff --git a/src/anyio/abc/_resources.py b/src/anyio/abc/_resources.py index d6ed168f..4594e6e9 100644 --- a/src/anyio/abc/_resources.py +++ b/src/anyio/abc/_resources.py @@ -1,6 +1,8 @@ from abc import ABCMeta, abstractmethod from types import TracebackType -from typing import Optional, Type +from typing import Optional, Type, TypeVar + +T = TypeVar("T") class AsyncResource(metaclass=ABCMeta): @@ -11,7 +13,7 @@ class AsyncResource(metaclass=ABCMeta): :meth:`aclose` on exit. """ - async def __aenter__(self): + async def __aenter__(self: T) -> T: return self async def __aexit__(self, exc_type: Optional[Type[BaseException]], diff --git a/src/anyio/abc/_sockets.py b/src/anyio/abc/_sockets.py index ef23cdac..2e460fbb 100644 --- a/src/anyio/abc/_sockets.py +++ b/src/anyio/abc/_sockets.py @@ -2,8 +2,10 @@ from io import IOBase from ipaddress import IPv4Address, IPv6Address from socket import AddressFamily, SocketType +from types import TracebackType from typing import ( - Any, AsyncContextManager, Callable, Collection, List, Optional, Tuple, TypeVar, Union) + Any, AsyncContextManager, Callable, Collection, List, Mapping, Optional, Tuple, Type, TypeVar, + Union) from .._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute from ._streams import ByteStream, Listener, T_Stream, UnreliableObjectStream @@ -17,11 +19,13 @@ class _NullAsyncContextManager: - async def __aenter__(self): + async def __aenter__(self) -> None: pass - async def __aexit__(self, exc_type, exc_val, exc_tb): - pass + async def __aexit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: + return None class SocketAttribute(TypedAttributeSet): @@ -41,7 +45,7 @@ class SocketAttribute(TypedAttributeSet): class _SocketProvider(TypedAttributeProvider): @property - def extra_attributes(self): + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: from .._core._sockets import convert_ipv6_sockaddr as convert attributes = { @@ -50,7 +54,7 @@ def extra_attributes(self): SocketAttribute.raw_socket: lambda: self._raw_socket } try: - peername = convert(self._raw_socket.getpeername()) + peername: Optional[Tuple[str, int]] = convert(self._raw_socket.getpeername()) except OSError: peername = None @@ -62,7 +66,8 @@ def extra_attributes(self): if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6): attributes[SocketAttribute.local_port] = lambda: self._raw_socket.getsockname()[1] if peername is not None: - attributes[SocketAttribute.remote_port] = lambda: peername[1] + remote_port = peername[1] + attributes[SocketAttribute.remote_port] = lambda: remote_port return attributes diff --git a/src/anyio/abc/_streams.py b/src/anyio/abc/_streams.py index 6747f543..3f5e783d 100644 --- a/src/anyio/abc/_streams.py +++ b/src/anyio/abc/_streams.py @@ -21,7 +21,7 @@ class UnreliableObjectReceiveStream(Generic[T_Item], AsyncResource, TypedAttribu parameter. """ - def __aiter__(self): + def __aiter__(self) -> "UnreliableObjectReceiveStream[T_Item]": return self async def __anext__(self) -> T_Item: diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index 3801d566..afa2d983 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -12,7 +12,7 @@ class TaskStatus(metaclass=ABCMeta): @abstractmethod - def started(self, value=None) -> None: + def started(self, value: object = None) -> None: """ Signal that the task has started. @@ -30,7 +30,8 @@ class TaskGroup(metaclass=ABCMeta): cancel_scope: 'CancelScope' - async def spawn(self, func: Callable[..., Coroutine], *args, name=None) -> None: + async def spawn(self, func: Callable[..., Coroutine], + *args: object, name: object = None) -> None: """ Start a new task in this task group. @@ -48,7 +49,8 @@ async def spawn(self, func: Callable[..., Coroutine], *args, name=None) -> None: self.start_soon(func, *args, name=name) @abstractmethod - def start_soon(self, func: Callable[..., Coroutine], *args, name=None) -> None: + def start_soon(self, func: Callable[..., Coroutine], + *args: object, name: object = None) -> None: """ Start a new task in this task group. @@ -60,7 +62,8 @@ def start_soon(self, func: Callable[..., Coroutine], *args, name=None) -> None: """ @abstractmethod - async def start(self, func: Callable[..., Coroutine], *args, name=None): + async def start(self, func: Callable[..., Coroutine], + *args: object, name: object = None) -> object: """ Start a new task and wait until it signals for readiness. diff --git a/src/anyio/abc/_testing.py b/src/anyio/abc/_testing.py index ace2d9b7..68aeb00e 100644 --- a/src/anyio/abc/_testing.py +++ b/src/anyio/abc/_testing.py @@ -1,5 +1,8 @@ +import types from abc import ABCMeta, abstractmethod -from typing import Any, Awaitable, Callable, Dict +from typing import Any, Awaitable, Callable, Dict, Optional, Type, TypeVar + +_T = TypeVar("_T") class TestRunner(metaclass=ABCMeta): @@ -11,15 +14,19 @@ class TestRunner(metaclass=ABCMeta): def __enter__(self) -> 'TestRunner': return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[types.TracebackType]) -> Optional[bool]: self.close() + return None @abstractmethod def close(self) -> None: """Close the event loop.""" @abstractmethod - def call(self, func: Callable[..., Awaitable], *args: tuple, **kwargs: Dict[str, Any]): + def call(self, func: Callable[..., Awaitable[_T]], + *args: tuple, **kwargs: Dict[str, Any]) -> _T: """ Call the given function within the backend's event loop. diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index 90d06cd6..29f7d03d 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -5,7 +5,7 @@ from types import TracebackType from typing import ( Any, AsyncContextManager, Callable, ContextManager, Coroutine, Dict, Generator, Iterable, - Optional, Tuple, Type, TypeVar, cast, overload) + Optional, Tuple, Type, TypeVar, Union, cast, overload) from warnings import warn from ._core import _eventloop @@ -17,7 +17,7 @@ T_co = TypeVar('T_co') -def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval: +def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval: """ Call a coroutine function from a worker thread. @@ -34,13 +34,14 @@ def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval: return asynclib.run_async_from_thread(func, *args) -def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval: +def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], + *args: object) -> T_Retval: warn('run_async_from_thread() has been deprecated, use anyio.from_thread.run() instead', DeprecationWarning) return run(func, *args) -def run_sync(func: Callable[..., T_Retval], *args) -> T_Retval: +def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval: """ Call a function in the event loop thread from a worker thread. @@ -57,7 +58,7 @@ def run_sync(func: Callable[..., T_Retval], *args) -> T_Retval: return asynclib.run_sync_from_thread(func, *args) -def run_sync_from_thread(func: Callable[..., T_Retval], *args) -> T_Retval: +def run_sync_from_thread(func: Callable[..., T_Retval], *args: object) -> T_Retval: warn('run_sync_from_thread() has been deprecated, use anyio.from_thread.run_sync() instead', DeprecationWarning) return run_sync(func, *args) @@ -74,7 +75,7 @@ def __init__(self, async_cm: AsyncContextManager[T_co], portal: 'BlockingPortal' self._async_cm = async_cm self._portal = portal - async def run_async_cm(self): + async def run_async_cm(self) -> Optional[bool]: try: self._exit_event = Event() value = await self._async_cm.__aenter__() @@ -105,18 +106,18 @@ class _BlockingPortalTaskStatus(TaskStatus): def __init__(self, future: Future): self._future = future - def started(self, value=None) -> None: + def started(self, value: object = None) -> None: self._future.set_result(value) class BlockingPortal: """An object that lets external threads run code in an asynchronous event loop.""" - def __new__(cls): + def __new__(cls) -> 'BlockingPortal': return get_asynclib().BlockingPortal() - def __init__(self): - self._event_loop_thread_id = threading.get_ident() + def __init__(self) -> None: + self._event_loop_thread_id: Optional[int] = threading.get_ident() self._stop_event = Event() self._task_group = create_task_group() self._cancelled_exc_class = get_cancelled_exc_class() @@ -125,7 +126,9 @@ async def __aenter__(self) -> 'BlockingPortal': await self._task_group.__aenter__() return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: await self.stop() return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) @@ -157,7 +160,7 @@ async def stop(self, cancel_remaining: bool = False) -> None: async def _call_func(self, func: Callable, args: tuple, kwargs: Dict[str, Any], future: Future) -> None: - def callback(f: Future): + def callback(f: Future) -> None: if f.cancelled(): self.call(scope.cancel) @@ -184,10 +187,10 @@ def callback(f: Future): if not future.cancelled(): future.set_result(retval) finally: - scope = None + scope = None # type: ignore[assignment] def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any], - name, future: Future) -> None: + name: object, future: Future) -> None: """ Spawn a new task using the given callable. @@ -204,14 +207,15 @@ def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, raise NotImplementedError @overload - def call(self, func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval: + def call(self, func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval: ... @overload - def call(self, func: Callable[..., T_Retval], *args) -> T_Retval: + def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval: ... - def call(self, func, *args): + def call(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], + *args: object) -> T_Retval: """ Call the given function in the event loop thread. @@ -224,7 +228,8 @@ def call(self, func, *args): """ return self.start_task_soon(func, *args).result() - def spawn_task(self, func: Callable[..., Coroutine], *args, name=None) -> Future: + def spawn_task(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], + *args: object, name: object = None) -> "Future[T_Retval]": """ Start a task in the portal's task group. @@ -245,7 +250,8 @@ def spawn_task(self, func: Callable[..., Coroutine], *args, name=None) -> Future warn('spawn_task() is deprecated -- use start_task_soon() instead', DeprecationWarning) return self.start_task_soon(func, *args, name=name) - def start_task_soon(self, func: Callable[..., Coroutine], *args, name=None) -> Future: + def start_task_soon(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], + *args: object, name: object = None) -> "Future[T_Retval]": """ Start a task in the portal's task group. @@ -268,7 +274,8 @@ def start_task_soon(self, func: Callable[..., Coroutine], *args, name=None) -> F self._spawn_task_from_thread(func, args, {}, name, f) return f - def start_task(self, func: Callable[..., Coroutine], *args, name=None) -> Tuple[Future, Any]: + def start_task(self, func: Callable[..., Coroutine], *args: object, + name: object = None) -> Tuple[Future, Any]: """ Start a task in the portal's task group and wait until it signals for readiness. @@ -350,7 +357,7 @@ def start_blocking_portal( Usage as a context manager is now required. """ - async def run_portal(): + async def run_portal() -> None: async with BlockingPortal() as portal_: if future.set_running_or_notify_cancel(): future.set_result(portal_) diff --git a/src/anyio/lowlevel.py b/src/anyio/lowlevel.py index 8a7d53fb..471b7e6b 100644 --- a/src/anyio/lowlevel.py +++ b/src/anyio/lowlevel.py @@ -1,8 +1,16 @@ -from typing import Any, Dict, Generic, Set, TypeVar, Union, cast +import enum +import sys +from dataclasses import dataclass +from typing import Any, Dict, Generic, Set, TypeVar, Union, overload from weakref import WeakKeyDictionary from ._core._eventloop import get_asynclib +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + T = TypeVar('T') D = TypeVar('D') @@ -22,7 +30,7 @@ async def checkpoint() -> None: await get_asynclib().checkpoint() -async def checkpoint_if_cancelled(): +async def checkpoint_if_cancelled() -> None: """ Enter a checkpoint if the enclosing cancel scope has been cancelled. @@ -58,25 +66,22 @@ def current_token() -> object: _token_wrappers: Dict[Any, '_TokenWrapper'] = {} +@dataclass(frozen=True) class _TokenWrapper: __slots__ = '_token', '__weakref__' + _token: object - def __init__(self, token): - self._token = token - - def __eq__(self, other): - return self._token is other._token - def __hash__(self): - return hash(self._token) +class _NoValueSet(enum.Enum): + NO_VALUE_SET = enum.auto() -class RunvarToken: +class RunvarToken(Generic[T]): __slots__ = '_var', '_value', '_redeemed' - def __init__(self, var: 'RunVar', value): + def __init__(self, var: 'RunVar', value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]]): self._var = var - self._value = value + self._value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = value self._redeemed = False @@ -84,11 +89,12 @@ class RunVar(Generic[T]): """Like a :class:`~contextvars.ContextVar`, expect scoped to the running event loop.""" __slots__ = '_name', '_default' - NO_VALUE_SET = object() + NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET _token_wrappers: Set[_TokenWrapper] = set() - def __init__(self, name: str, default: Union[T, object] = NO_VALUE_SET): + def __init__(self, name: str, + default: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET): self._name = name self._default = default @@ -108,31 +114,39 @@ def _current_vars(self) -> Dict[str, T]: run_vars = _run_vars[token] = {} return run_vars - def get(self, default: Union[T, object] = NO_VALUE_SET) -> T: + @overload + def get(self, default: D) -> Union[T, D]: ... + + @overload + def get(self) -> T: ... + + def get( + self, default: Union[D, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET + ) -> Union[T, D]: try: return self._current_vars[self._name] except KeyError: if default is not RunVar.NO_VALUE_SET: - return cast(T, default) + return default elif self._default is not RunVar.NO_VALUE_SET: - return cast(T, self._default) + return self._default raise LookupError(f'Run variable "{self._name}" has no value and no default set') - def set(self, value: T) -> RunvarToken: + def set(self, value: T) -> RunvarToken[T]: current_vars = self._current_vars token = RunvarToken(self, current_vars.get(self._name, RunVar.NO_VALUE_SET)) current_vars[self._name] = value return token - def reset(self, token: RunvarToken) -> None: + def reset(self, token: RunvarToken[T]) -> None: if token._var is not self: raise ValueError('This token does not belong to this RunVar') if token._redeemed: raise ValueError('This token has already been used') - if token._value is RunVar.NO_VALUE_SET: + if token._value is _NoValueSet.NO_VALUE_SET: try: del self._current_vars[self._name] except KeyError: @@ -142,5 +156,5 @@ def reset(self, token: RunvarToken) -> None: token._redeemed = True - def __repr__(self): + def __repr__(self) -> str: return f'' diff --git a/src/anyio/pytest_plugin.py b/src/anyio/pytest_plugin.py index 1db51764..0e99a456 100644 --- a/src/anyio/pytest_plugin.py +++ b/src/anyio/pytest_plugin.py @@ -1,7 +1,7 @@ import sys from contextlib import contextmanager from inspect import iscoroutinefunction -from typing import Any, Dict, Iterator, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, cast import pytest import sniffio @@ -14,10 +14,13 @@ else: from async_generator import isasyncgenfunction +if TYPE_CHECKING: + from _pytest.config import Config + _current_runner: Optional[TestRunner] = None -def extract_backend_and_options(backend) -> Tuple[str, Dict[str, Any]]: +def extract_backend_and_options(backend: object) -> Tuple[str, Dict[str, Any]]: if isinstance(backend, str): return backend, {} elif isinstance(backend, tuple) and len(backend) == 2: @@ -51,13 +54,13 @@ def get_runner(backend_name: str, backend_options: Dict[str, Any]) -> Iterator[T sniffio.current_async_library_cvar.reset(token) -def pytest_configure(config): +def pytest_configure(config: "Config") -> None: config.addinivalue_line('markers', 'anyio: mark the (coroutine function) test to be run ' 'asynchronously via anyio.') -def pytest_fixture_setup(fixturedef, request): - def wrapper(*args, anyio_backend, **kwargs): +def pytest_fixture_setup(fixturedef: Any, request: Any) -> None: + def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def] backend_name, backend_options = extract_backend_and_options(anyio_backend) if has_backend_arg: kwargs['anyio_backend'] = anyio_backend @@ -94,7 +97,7 @@ def wrapper(*args, anyio_backend, **kwargs): @pytest.hookimpl(tryfirst=True) -def pytest_pycollect_makeitem(collector, name, obj): +def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None: if collector.istestfunction(obj, name): inner_func = obj.hypothesis.inner_test if hasattr(obj, 'hypothesis') else obj if iscoroutinefunction(inner_func): @@ -105,8 +108,8 @@ def pytest_pycollect_makeitem(collector, name, obj): @pytest.hookimpl(tryfirst=True) -def pytest_pyfunc_call(pyfuncitem): - def run_with_hypothesis(**kwargs): +def pytest_pyfunc_call(pyfuncitem: Any) -> Optional[bool]: + def run_with_hypothesis(**kwargs: Any) -> None: with get_runner(backend_name, backend_options) as runner: runner.call(original_func, **kwargs) @@ -131,14 +134,16 @@ def run_with_hypothesis(**kwargs): return True + return None + @pytest.fixture(params=get_all_backends()) -def anyio_backend(request): +def anyio_backend(request: Any) -> Any: return request.param @pytest.fixture -def anyio_backend_name(anyio_backend) -> str: +def anyio_backend_name(anyio_backend: Any) -> str: if isinstance(anyio_backend, str): return anyio_backend else: @@ -146,7 +151,7 @@ def anyio_backend_name(anyio_backend) -> str: @pytest.fixture -def anyio_backend_options(anyio_backend) -> Dict[str, Any]: +def anyio_backend_options(anyio_backend: Any) -> Dict[str, Any]: if isinstance(anyio_backend, str): return {} else: diff --git a/src/anyio/streams/buffered.py b/src/anyio/streams/buffered.py index bed83246..53ef5176 100644 --- a/src/anyio/streams/buffered.py +++ b/src/anyio/streams/buffered.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Any, Callable, Mapping from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead from ..abc import AnyByteReceiveStream, ByteReceiveStream @@ -25,8 +26,8 @@ def buffer(self) -> bytes: return bytes(self._buffer) @property - def extra_attributes(self): - return self.receive_stream.extra_attributes + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return self.receive_stream.extra_attributes # type: ignore[return-value] async def receive(self, max_bytes: int = 65536) -> bytes: if self._closed: diff --git a/src/anyio/streams/file.py b/src/anyio/streams/file.py index 12442ad3..d3e22bae 100644 --- a/src/anyio/streams/file.py +++ b/src/anyio/streams/file.py @@ -1,6 +1,6 @@ from io import SEEK_SET, UnsupportedOperation from pathlib import Path -from typing import BinaryIO, Union, cast +from typing import Any, BinaryIO, Callable, Dict, Mapping, Union, cast from .. import ( BrokenResourceError, ClosedResourceError, EndOfStream, TypedAttributeSet, to_thread, @@ -25,8 +25,8 @@ async def aclose(self) -> None: await to_thread.run_sync(self._file.close) @property - def extra_attributes(self): - attributes = { + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + attributes: Dict[Any, Callable[[], Any]] = { FileStreamAttribute.file: lambda: self._file, } diff --git a/src/anyio/streams/memory.py b/src/anyio/streams/memory.py index 0d04ab46..47ff1f5d 100644 --- a/src/anyio/streams/memory.py +++ b/src/anyio/streams/memory.py @@ -41,7 +41,7 @@ class MemoryObjectReceiveStream(Generic[T_Item], ObjectReceiveStream[T_Item]): _state: MemoryObjectStreamState[T_Item] _closed: bool = field(init=False, default=False) - def __post_init__(self): + def __post_init__(self) -> None: self._state.open_receive_channels += 1 def receive_nowait(self) -> T_Item: @@ -135,7 +135,7 @@ class MemoryObjectSendStream(Generic[T_Item], ObjectSendStream[T_Item]): _state: MemoryObjectStreamState[T_Item] _closed: bool = field(init=False, default=False) - def __post_init__(self): + def __post_init__(self) -> None: self._state.open_send_channels += 1 def send_nowait(self, item: T_Item) -> DeprecatedAwaitable: diff --git a/src/anyio/streams/stapled.py b/src/anyio/streams/stapled.py index 372cac23..0d5e7fb2 100644 --- a/src/anyio/streams/stapled.py +++ b/src/anyio/streams/stapled.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Callable, Generic, Optional, Sequence, TypeVar +from typing import Any, Callable, Generic, List, Mapping, Optional, Sequence, TypeVar from ..abc import ( ByteReceiveStream, ByteSendStream, ByteStream, Listener, ObjectReceiveStream, ObjectSendStream, @@ -38,8 +38,8 @@ async def aclose(self) -> None: await self.receive_stream.aclose() @property - def extra_attributes(self): - return dict(**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes) + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return {**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes} @dataclass(eq=False) @@ -71,8 +71,8 @@ async def aclose(self) -> None: await self.receive_stream.aclose() @property - def extra_attributes(self): - return dict(**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes) + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return {**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes} @dataclass(eq=False) @@ -92,12 +92,12 @@ class MultiListener(Generic[T_Stream], Listener[T_Stream]): listeners: Sequence[Listener[T_Stream]] - def __post_init__(self): - listeners = [] + def __post_init__(self) -> None: + listeners: List[Listener[T_Stream]] = [] for listener in self.listeners: if isinstance(listener, MultiListener): listeners.extend(listener.listeners) - del listener.listeners[:] + del listener.listeners[:] # type: ignore[attr-defined] else: listeners.append(listener) @@ -116,8 +116,8 @@ async def aclose(self) -> None: await listener.aclose() @property - def extra_attributes(self): - attributes = {} + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + attributes: dict = {} for listener in self.listeners: attributes.update(listener.extra_attributes) diff --git a/src/anyio/streams/text.py b/src/anyio/streams/text.py index 351f4a45..0cdfe227 100644 --- a/src/anyio/streams/text.py +++ b/src/anyio/streams/text.py @@ -1,6 +1,6 @@ import codecs from dataclasses import InitVar, dataclass, field -from typing import Callable, Tuple +from typing import Any, Callable, Mapping, Tuple from ..abc import ( AnyByteReceiveStream, AnyByteSendStream, AnyByteStream, ObjectReceiveStream, ObjectSendStream, @@ -29,7 +29,7 @@ class TextReceiveStream(ObjectReceiveStream[str]): errors: InitVar[str] = 'strict' _decoder: codecs.IncrementalDecoder = field(init=False) - def __post_init__(self, encoding, errors): + def __post_init__(self, encoding: str, errors: str) -> None: decoder_class = codecs.getincrementaldecoder(encoding) self._decoder = decoder_class(errors=errors) @@ -45,8 +45,8 @@ async def aclose(self) -> None: self._decoder.reset() @property - def extra_attributes(self): - return self.transport_stream.extra_attributes + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return self.transport_stream.extra_attributes # type: ignore[return-value] @dataclass(eq=False) @@ -68,8 +68,8 @@ class TextSendStream(ObjectSendStream[str]): errors: str = 'strict' _encoder: Callable[..., Tuple[bytes, int]] = field(init=False) - def __post_init__(self, encoding): - self._encoder = codecs.getencoder(encoding) + def __post_init__(self, encoding: str) -> None: + self._encoder = codecs.getencoder(encoding) # type: ignore[assignment] async def send(self, item: str) -> None: encoded = self._encoder(item, self.errors)[0] @@ -79,8 +79,8 @@ async def aclose(self) -> None: await self.transport_stream.aclose() @property - def extra_attributes(self): - return self.transport_stream.extra_attributes + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return self.transport_stream.extra_attributes # type: ignore[return-value] @dataclass(eq=False) @@ -107,7 +107,7 @@ class TextStream(ObjectStream[str]): _receive_stream: TextReceiveStream = field(init=False) _send_stream: TextSendStream = field(init=False) - def __post_init__(self, encoding, errors): + def __post_init__(self, encoding: str, errors: str) -> None: self._receive_stream = TextReceiveStream(self.transport_stream, encoding=encoding, errors=errors) self._send_stream = TextSendStream(self.transport_stream, encoding=encoding, errors=errors) @@ -126,5 +126,5 @@ async def aclose(self) -> None: await self._receive_stream.aclose() @property - def extra_attributes(self): - return dict(**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes) + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return {**self._send_stream.extra_attributes, **self._receive_stream.extra_attributes} diff --git a/src/anyio/streams/tls.py b/src/anyio/streams/tls.py index c42254a4..1d84c05c 100644 --- a/src/anyio/streams/tls.py +++ b/src/anyio/streams/tls.py @@ -3,7 +3,7 @@ import ssl from dataclasses import dataclass from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, TypeVar, Union from .. import BrokenResourceError, EndOfStream, aclose_forcefully, get_cancelled_exc_class from .._core._typedattr import TypedAttributeSet, typed_attribute @@ -94,7 +94,9 @@ async def wrap(cls, transport_stream: AnyByteStream, *, server_side: Optional[bo await wrapper._call_sslobject_method(ssl_object.do_handshake) return wrapper - async def _call_sslobject_method(self, func: Callable[..., T_Retval], *args) -> T_Retval: + async def _call_sslobject_method( + self, func: Callable[..., T_Retval], *args: object + ) -> T_Retval: while True: try: result = func(*args) @@ -168,7 +170,7 @@ async def send_eof(self) -> None: raise NotImplementedError('send_eof() has not yet been implemented for TLS streams') @property - def extra_attributes(self): + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return { **self.transport_stream.extra_attributes, TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol, @@ -236,7 +238,7 @@ async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> N async def serve(self, handler: Callable[[TLSStream], Any], task_group: Optional[TaskGroup] = None) -> None: @wraps(handler) - async def handler_wrapper(stream: AnyByteStream): + async def handler_wrapper(stream: AnyByteStream) -> None: from .. import fail_after try: with fail_after(self.handshake_timeout): @@ -254,7 +256,7 @@ async def aclose(self) -> None: await self.listener.aclose() @property - def extra_attributes(self): + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return { TLSAttribute.standard_compatible: lambda: self.standard_compatible, } diff --git a/src/anyio/to_process.py b/src/anyio/to_process.py index 8c18cd79..5675e9eb 100644 --- a/src/anyio/to_process.py +++ b/src/anyio/to_process.py @@ -26,7 +26,7 @@ async def run_sync( - func: Callable[..., T_Retval], *args, cancellable: bool = False, + func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional[CapacityLimiter] = None) -> T_Retval: """ Call the given function with the given arguments in a worker process. @@ -43,7 +43,7 @@ async def run_sync( :return: an awaitable that yields the return value of the function. """ - async def send_raw_command(pickled_cmd: bytes): + async def send_raw_command(pickled_cmd: bytes) -> object: try: await stdin.send(pickled_cmd) response = await buffered.receive_until(b'\n', 50) @@ -145,7 +145,7 @@ async def send_raw_command(pickled_cmd: bytes): with CancelScope(shield=not cancellable): try: - return await send_raw_command(request) + return cast(T_Retval, await send_raw_command(request)) finally: if process in workers: idle_workers.append((process, current_time())) diff --git a/src/anyio/to_thread.py b/src/anyio/to_thread.py index e0428777..5fc95894 100644 --- a/src/anyio/to_thread.py +++ b/src/anyio/to_thread.py @@ -8,7 +8,7 @@ async def run_sync( - func: Callable[..., T_Retval], *args, cancellable: bool = False, + func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional[CapacityLimiter] = None) -> T_Retval: """ Call the given function with the given arguments in a worker thread. @@ -30,7 +30,7 @@ async def run_sync( async def run_sync_in_worker_thread( - func: Callable[..., T_Retval], *args, cancellable: bool = False, + func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional[CapacityLimiter] = None) -> T_Retval: warn('run_sync_in_worker_thread() has been deprecated, use anyio.to_thread.run_sync() instead', DeprecationWarning) diff --git a/tests/streams/test_text.py b/tests/streams/test_text.py index 47ba8f9a..16374dce 100644 --- a/tests/streams/test_text.py +++ b/tests/streams/test_text.py @@ -55,3 +55,4 @@ async def test_bidirectional_stream(): await send_stream.send(b'\xc3\xa6\xc3\xb8') assert await text_stream.receive() == 'æø' + assert text_stream.extra_attributes == {} diff --git a/tests/test_compat.py b/tests/test_compat.py index f662e826..007e541f 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -35,6 +35,10 @@ async def test_get_running_tasks(self): tasks = await maybe_async(get_running_tasks()) assert type(tasks) is list + async def test_get_current_task(self): + task = await maybe_async(get_current_task()) + assert type(task) is TaskInfo + async def test_maybe_async_cm(): async with maybe_async_cm(CancelScope()): diff --git a/tests/test_debugging.py b/tests/test_debugging.py index e7636096..7d982270 100644 --- a/tests/test_debugging.py +++ b/tests/test_debugging.py @@ -29,6 +29,25 @@ async def main(): loop.close() +@pytest.mark.parametrize( + "name_input,expected", + [ + (None, 'test_debugging.test_non_main_task_name..non_main'), + (b'name', "b'name'"), + ("name", "name"), + ("", ""), + ], +) +async def test_non_main_task_name(name_input, expected): + async def non_main(*, task_status): + task_status.started(anyio.get_current_task().name) + + async with anyio.create_task_group() as tg: + name = await tg.start(non_main, name=name_input) + + assert name == expected + + async def test_get_running_tasks(): async def inspect(): await wait_all_tasks_blocked()