From d2f23ce4182be09e5e7c9d16938fd4e6012e778b Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 00:04:32 +0100 Subject: [PATCH 01/31] disallow untyped defs --- setup.cfg | 1 + src/anyio/_backends/_asyncio.py | 109 ++++++++++++++-------------- src/anyio/_backends/_trio.py | 99 +++++++++++++------------ src/anyio/_core/_compat.py | 45 +++++++----- src/anyio/_core/_eventloop.py | 6 +- src/anyio/_core/_exceptions.py | 4 +- src/anyio/_core/_fileio.py | 16 ++-- src/anyio/_core/_sockets.py | 37 +++++++--- src/anyio/_core/_streams.py | 4 +- src/anyio/_core/_subprocesses.py | 6 +- src/anyio/_core/_synchronization.py | 29 ++++---- src/anyio/_core/_tasks.py | 12 +-- src/anyio/_core/_testing.py | 32 +++++--- src/anyio/_core/_typedattr.py | 6 +- src/anyio/abc/_resources.py | 6 +- src/anyio/abc/_sockets.py | 18 +++-- src/anyio/abc/_streams.py | 2 +- src/anyio/abc/_tasks.py | 8 +- src/anyio/abc/_testing.py | 10 ++- src/anyio/from_thread.py | 44 +++++------ src/anyio/lowlevel.py | 35 +++++---- src/anyio/pytest_plugin.py | 27 ++++--- src/anyio/streams/buffered.py | 4 +- src/anyio/streams/file.py | 4 +- src/anyio/streams/memory.py | 4 +- src/anyio/streams/stapled.py | 20 ++--- src/anyio/streams/text.py | 20 ++--- src/anyio/streams/tls.py | 8 +- src/anyio/to_process.py | 4 +- src/anyio/to_thread.py | 4 +- 30 files changed, 343 insertions(+), 281 deletions(-) diff --git a/setup.cfg b/setup.cfg index 353182af..931b930b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -60,3 +60,4 @@ pytest11 = [mypy] ignore_missing_imports = true +disallow_untyped_defs = true diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 5bcc6514..93df4e56 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -213,11 +213,11 @@ def _maybe_set_event_loop_policy(policy: Optional[asyncio.AbstractEventLoopPolic asyncio.set_event_loop_policy(policy) -def run(func: Callable[..., T_Retval], *args, debug: bool = False, use_uvloop: bool = True, +def run(func: Callable[..., Awaitable[T_Retval]], *args: object, debug: bool = False, use_uvloop: bool = True, policy: Optional[asyncio.AbstractEventLoopPolicy] = None) -> T_Retval: @wraps(func) - async def wrapper(): - task = current_task() + async def wrapper() -> T_Retval: + task = cast(asyncio.Task, current_task()) task_state = TaskState(None, get_callable_name(func), None) _task_states[task] = task_state if _native_task_names: @@ -247,7 +247,7 @@ async def wrapper(): class CancelScope(BaseCancelScope): - def __new__(cls, *, deadline: float = math.inf, shield: bool = False): + def __new__(cls, *, deadline: float = math.inf, shield: bool = False) -> "CancelScope": return object.__new__(cls) def __init__(self, deadline: float = math.inf, shield: bool = False): @@ -262,20 +262,20 @@ def __init__(self, deadline: float = math.inf, shield: bool = False): self._host_task: Optional[asyncio.Task] = None self._timeout_expired = False - def __enter__(self): + def __enter__(self) -> "CancelScope": if self._active: raise RuntimeError( "Each CancelScope may only be used for a single 'with' block" ) - self._host_task = current_task() - self._tasks.add(self._host_task) + self._host_task = host_task = cast(asyncio.Task, current_task()) + self._tasks.add(host_task) try: - task_state = _task_states[self._host_task] + task_state = _task_states[host_task] except KeyError: - task_name = self._host_task.get_name() if _native_task_names else None + task_name = host_task.get_name() if _native_task_names else None task_state = TaskState(None, task_name, self) - _task_states[self._host_task] = task_state + _task_states[host_task] = task_state else: self._parent_scope = task_state.cancel_scope task_state.cancel_scope = self @@ -326,7 +326,7 @@ def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[Ba return None - def _timeout(self): + def _timeout(self) -> None: if self._deadline != math.inf: loop = get_running_loop() if loop.time() >= self._deadline: @@ -460,9 +460,9 @@ async def cancel_shielded_checkpoint() -> None: await sleep(0) -def current_effective_deadline(): +def current_effective_deadline() -> float: try: - cancel_scope = _task_states[current_task()].cancel_scope + cancel_scope = _task_states[current_task()].cancel_scope # type: ignore[index] except KeyError: return math.inf @@ -477,7 +477,7 @@ def current_effective_deadline(): return deadline -def current_time(): +def current_time() -> float: return get_running_loop().time() @@ -517,17 +517,17 @@ class _AsyncioTaskStatus(abc.TaskStatus): def __init__(self, future: asyncio.Future): self._future = future - def started(self, value=None) -> None: + def started(self, value: object = None) -> None: self._future.set_result(value) class TaskGroup(abc.TaskGroup): - def __init__(self): + def __init__(self) -> None: self.cancel_scope: CancelScope = CancelScope() self._active = False self._exceptions: List[BaseException] = [] - async def __aenter__(self): + async def __aenter__(self) -> "TaskGroup": self.cancel_scope.__enter__() self._active = True return self @@ -613,7 +613,7 @@ async def _run_wrapped_task( self.cancel_scope._tasks.remove(task) del _task_states[task] - def _spawn(self, func: Callable[..., Coroutine], args: tuple, name, + def _spawn(self, func: Callable[..., Coroutine], args: tuple, name: Optional[str], task_status_future: Optional[asyncio.Future] = None) -> asyncio.Task: def task_done(_task: asyncio.Task) -> None: # This is the code path for Python 3.8+ @@ -669,10 +669,10 @@ def task_done(_task: asyncio.Task) -> None: self.cancel_scope._tasks.add(task) return task - def start_soon(self, func: Callable[..., Coroutine], *args, name=None) -> None: + def start_soon(self, func: Callable[..., Coroutine], *args: object, name: str = None) -> None: self._spawn(func, args, name) - async def start(self, func: Callable[..., Coroutine], *args, name=None) -> None: + async def start(self, func: Callable[..., Coroutine], *args: object, name: str = None) -> None: future: asyncio.Future = asyncio.Future() task = self._spawn(func, args, name, future) @@ -745,7 +745,7 @@ def stop(self, f: Optional[asyncio.Task] = None) -> None: async def run_sync_in_worker_thread( - func: Callable[..., T_Retval], *args, cancellable: bool = False, + func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional['CapacityLimiter'] = None) -> T_Retval: await checkpoint() @@ -789,10 +789,10 @@ async def run_sync_in_worker_thread( idle_workers.append(worker) -def run_sync_from_thread(func: Callable[..., T_Retval], *args, +def run_sync_from_thread(func: Callable[..., T_Retval], *args: object, loop: Optional[asyncio.AbstractEventLoop] = None) -> T_Retval: @wraps(func) - def wrapper(): + def wrapper() -> None: try: f.set_result(func(*args)) except BaseException as exc: @@ -806,22 +806,22 @@ def wrapper(): return f.result() -def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval: +def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval: f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe( func(*args), threadlocals.loop) return f.result() class BlockingPortal(abc.BlockingPortal): - def __new__(cls): + def __new__(cls) -> "BlockingPortal": return object.__new__(cls) - def __init__(self): + def __init__(self) -> None: super().__init__() self._loop = get_running_loop() def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any], - name, future: Future) -> None: + name: Optional[str], future: Future) -> None: run_sync_from_thread( partial(self._task_group.start_soon, name=name), self._call_func, func, args, kwargs, future, loop=self._loop) @@ -908,12 +908,12 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]: return self._stderr -async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int, +async def open_process(command: Union[str, Sequence[str]], *, shell: bool, stdin: int, stdout: int, stderr: int, cwd: Union[str, bytes, PathLike, None] = None, env: Optional[Mapping[str, str]] = None) -> Process: await checkpoint() if shell: - process = await asyncio.create_subprocess_shell(command, stdin=stdin, stdout=stdout, + process = await asyncio.create_subprocess_shell(command, stdin=stdin, stdout=stdout, # type: ignore[arg-type] stderr=stderr, cwd=cwd, env=env) else: process = await asyncio.create_subprocess_exec(*command, stdin=stdin, stdout=stdout, @@ -925,7 +925,7 @@ async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: return Process(process, stdin_stream, stdout_stream, stderr_stream) -def _forcibly_shutdown_process_pool_on_exit(workers: Set[Process], _task) -> None: +def _forcibly_shutdown_process_pool_on_exit(workers: Set[Process], _task: object) -> None: """ Forcibly shuts down worker processes belonging to this event loop.""" child_watcher: Optional[asyncio.AbstractChildWatcher] @@ -1142,7 +1142,7 @@ def _raw_socket(self) -> SocketType: return self.__raw_socket def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: - def callback(f): + def callback(f: object) -> None: del self._receive_future loop.remove_reader(self.__raw_socket) @@ -1152,7 +1152,7 @@ def callback(f): return f def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: - def callback(f): + def callback(f: object) -> None: del self._send_future loop.remove_writer(self.__raw_socket) @@ -1594,10 +1594,10 @@ async def wait_socket_writable(sock: socket.SocketType) -> None: # class Event(BaseEvent): - def __new__(cls): + def __new__(cls) -> "Event": return object.__new__(cls) - def __init__(self): + def __init__(self) -> None: self._event = asyncio.Event() def set(self) -> DeprecatedAwaitable: @@ -1607,18 +1607,18 @@ def set(self) -> DeprecatedAwaitable: def is_set(self) -> bool: return self._event.is_set() - async def wait(self): + async def wait(self) -> None: if await self._event.wait(): await checkpoint() def statistics(self) -> EventStatistics: - return EventStatistics(len(self._event._waiters)) + return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined] class CapacityLimiter(BaseCapacityLimiter): _total_tokens: float = 0 - def __new__(cls, total_tokens: float): + def __new__(cls, total_tokens: float) -> "CapacityLimiter": return object.__new__(cls) def __init__(self, total_tokens: float): @@ -1626,7 +1626,7 @@ def __init__(self, total_tokens: float): self._wait_queue: Dict[Any, asyncio.Event] = OrderedDict() self.total_tokens = total_tokens - async def __aenter__(self): + async def __aenter__(self) -> None: await self.acquire() async def __aexit__(self, exc_type: Optional[Type[BaseException]], @@ -1671,7 +1671,7 @@ def acquire_nowait(self) -> DeprecatedAwaitable: self.acquire_on_behalf_of_nowait(current_task()) return DeprecatedAwaitable(self.acquire_nowait) - def acquire_on_behalf_of_nowait(self, borrower) -> DeprecatedAwaitable: + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: if borrower in self._borrowers: raise RuntimeError("this borrower is already holding one of this CapacityLimiter's " "tokens") @@ -1685,7 +1685,7 @@ def acquire_on_behalf_of_nowait(self, borrower) -> DeprecatedAwaitable: async def acquire(self) -> None: return await self.acquire_on_behalf_of(current_task()) - async def acquire_on_behalf_of(self, borrower) -> None: + async def acquire_on_behalf_of(self, borrower: object) -> None: await checkpoint_if_cancelled() try: self.acquire_on_behalf_of_nowait(borrower) @@ -1705,7 +1705,7 @@ async def acquire_on_behalf_of(self, borrower) -> None: def release(self) -> None: self.release_on_behalf_of(current_task()) - def release_on_behalf_of(self, borrower) -> None: + def release_on_behalf_of(self, borrower: object) -> None: try: self._borrowers.remove(borrower) except KeyError: @@ -1725,7 +1725,7 @@ def statistics(self) -> CapacityLimiterStatistics: _default_thread_limiter: RunVar[CapacityLimiter] = RunVar('_default_thread_limiter') -def current_default_thread_limiter(): +def current_default_thread_limiter() -> CapacityLimiter: try: return _default_thread_limiter.get() except LookupError: @@ -1751,18 +1751,19 @@ def _deliver(self, signum: int) -> None: if not self._future.done(): self._future.set_result(None) - def __enter__(self): + def __enter__(self) -> "_SignalReceiver": for sig in set(self._signals): self._loop.add_signal_handler(sig, self._deliver, sig) self._handled_signals.add(sig) return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]: for sig in self._handled_signals: self._loop.remove_signal_handler(sig) + return None - def __aiter__(self): + def __aiter__(self) -> "_SignalReceiver": return self async def __anext__(self) -> int: @@ -1825,7 +1826,7 @@ def __init__(self, debug: bool = False, use_uvloop: bool = True, self._loop.set_debug(debug) asyncio.set_event_loop(self._loop) - def _cancel_all_tasks(self): + def _cancel_all_tasks(self) -> None: to_cancel = all_tasks(self._loop) if not to_cancel: return @@ -1840,7 +1841,7 @@ def _cancel_all_tasks(self): if task.cancelled(): continue if task.exception() is not None: - raise task.exception() + raise cast(BaseException, task.exception()) def close(self) -> None: try: @@ -1850,23 +1851,23 @@ def close(self) -> None: asyncio.set_event_loop(None) self._loop.close() - def call(self, func: Callable[..., Awaitable], *args, **kwargs): + def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object) -> T_Retval: def exception_handler(loop: asyncio.AbstractEventLoop, context: Dict[str, Any]) -> None: exceptions.append(context['exception']) exceptions: List[Exception] = [] self._loop.set_exception_handler(exception_handler) try: - retval = self._loop.run_until_complete(func(*args, **kwargs)) + retval: T_Retval = self._loop.run_until_complete(func(*args, **kwargs)) except Exception as exc: - retval = None + retval = None # type: ignore[assignment] exceptions.append(exc) finally: self._loop.set_exception_handler(None) - if len(exceptions) == 1: - raise exceptions[0] - elif exceptions: - raise ExceptionGroup(exceptions) + if len(exceptions) == 1: + raise exceptions[0] + elif exceptions: + raise ExceptionGroup(exceptions) return retval diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 48be7a51..8cad535f 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -8,11 +8,11 @@ from os import PathLike from types import TracebackType from typing import ( - Any, Awaitable, Callable, Collection, Coroutine, Dict, Generic, List, Mapping, NoReturn, - Optional, Set, Tuple, Type, TypeVar, Union) + TYPE_CHECKING, Any, Awaitable, Callable, Collection, ContextManager, Coroutine, Dict, Generic, + List, Mapping, NoReturn, Optional, Sequence, Set, Tuple, Type, TypeVar, Union) import trio.from_thread -from outcome import Error, Value +from outcome import Error, Outcome, Value from trio.to_thread import run_sync from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc @@ -36,6 +36,11 @@ else: from trio.lowlevel import wait_readable, wait_writable +if TYPE_CHECKING: + from trio.socket import SocketType as TrioSocketType +else: + TrioSocketType = object + T_Retval = TypeVar('T_Retval') T_SockAddr = TypeVar('T_SockAddr', str, IPSockAddrType) @@ -61,17 +66,17 @@ # class CancelScope(BaseCancelScope): - def __new__(cls, original: Optional[trio.CancelScope] = None, **kwargs): + def __new__(cls, original: Optional[trio.CancelScope] = None, **kwargs: object) -> 'CancelScope': return object.__new__(cls) - def __init__(self, original: Optional[trio.CancelScope] = None, **kwargs): + def __init__(self, original: Optional[trio.CancelScope] = None, **kwargs: object) -> None: self.__original = original or trio.CancelScope(**kwargs) - def __enter__(self): + def __enter__(self) -> 'CancelScope': self.__original.__enter__() return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]: return self.__original.__exit__(exc_type, exc_val, exc_tb) def cancel(self) -> DeprecatedAwaitable: @@ -116,12 +121,12 @@ class ExceptionGroup(BaseExceptionGroup, trio.MultiError): class TaskGroup(abc.TaskGroup): - def __init__(self): + def __init__(self) -> None: self._active = False self._nursery_manager = trio.open_nursery() - self.cancel_scope = None + self.cancel_scope = None # type: ignore[assignment] - async def __aenter__(self): + async def __aenter__(self) -> 'TaskGroup': self._active = True self._nursery = await self._nursery_manager.__aenter__() self.cancel_scope = CancelScope(self._nursery.cancel_scope) @@ -137,13 +142,13 @@ async def __aexit__(self, exc_type: Optional[Type[BaseException]], finally: self._active = False - def start_soon(self, func: Callable, *args, name=None) -> None: + def start_soon(self, func: Callable, *args: object, name: str = None) -> None: if not self._active: raise RuntimeError('This task group is not active; no new tasks can be started.') self._nursery.start_soon(func, *args, name=name) - async def start(self, func: Callable[..., Coroutine], *args, name=None): + async def start(self, func: Callable[..., Coroutine], *args: object, name: str = None) -> object: if not self._active: raise RuntimeError('This task group is not active; no new tasks can be started.') @@ -155,9 +160,9 @@ async def start(self, func: Callable[..., Coroutine], *args, name=None): async def run_sync_in_worker_thread( - func: Callable[..., T_Retval], *args, cancellable: bool = False, + func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional[trio.CapacityLimiter] = None) -> T_Retval: - def wrapper(): + def wrapper() -> T_Retval: with claim_worker_thread('trio'): return func(*args) @@ -168,15 +173,15 @@ def wrapper(): class BlockingPortal(abc.BlockingPortal): - def __new__(cls): + def __new__(cls) -> 'BlockingPortal': return object.__new__(cls) - def __init__(self): + def __init__(self) -> None: super().__init__() self._token = trio.lowlevel.current_trio_token() def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any], - name, future: Future) -> None: + name: Optional[str], future: Future) -> None: return trio.from_thread.run_sync( partial(self._task_group.start_soon, name=name), self._call_func, func, args, kwargs, future, trio_token=self._token) @@ -273,7 +278,7 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]: return self._stderr -async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: int, +async def open_process(command: Union[str, Sequence[str]], *, shell: bool, stdin: int, stdout: int, stderr: int, cwd: Union[str, bytes, PathLike, None] = None, env: Optional[Mapping[str, str]] = None) -> Process: process = await trio.open_process(command, stdin=stdin, stdout=stdout, stderr=stderr, @@ -285,7 +290,7 @@ async def open_process(command, *, shell: bool, stdin: int, stdout: int, stderr: class _ProcessPoolShutdownInstrument(trio.abc.Instrument): - def after_run(self): + def after_run(self) -> None: super().after_run() @@ -316,7 +321,7 @@ def setup_process_pool_exit_at_shutdown(workers: Set[Process]) -> None: # class _TrioSocketMixin(Generic[T_SockAddr]): - def __init__(self, trio_socket): + def __init__(self, trio_socket: TrioSocketType) -> None: self._trio_socket = trio_socket self._closed = False @@ -347,7 +352,7 @@ def _convert_socket_error(self, exc: BaseException) -> 'NoReturn': class SocketStream(_TrioSocketMixin, abc.SocketStream): - def __init__(self, trio_socket): + def __init__(self, trio_socket: TrioSocketType) -> None: super().__init__(trio_socket) self._receive_guard = ResourceGuard('reading from') self._send_guard = ResourceGuard('writing to') @@ -467,7 +472,7 @@ async def accept(self) -> UNIXSocketStream: class UDPSocket(_TrioSocketMixin[IPSockAddrType], abc.UDPSocket): - def __init__(self, trio_socket): + def __init__(self, trio_socket: TrioSocketType) -> None: super().__init__(trio_socket) self._receive_guard = ResourceGuard('reading from') self._send_guard = ResourceGuard('writing to') @@ -489,7 +494,7 @@ async def send(self, item: UDPPacketType) -> None: class ConnectedUDPSocket(_TrioSocketMixin[IPSockAddrType], abc.ConnectedUDPSocket): - def __init__(self, trio_socket): + def __init__(self, trio_socket: TrioSocketType) -> None: super().__init__(trio_socket) self._receive_guard = ResourceGuard('reading from') self._send_guard = ResourceGuard('writing to') @@ -562,7 +567,7 @@ async def create_udp_socket( getnameinfo = trio.socket.getnameinfo -async def wait_socket_readable(sock): +async def wait_socket_readable(sock: socket.SocketType) -> None: try: await wait_readable(sock) except trio.ClosedResourceError as exc: @@ -571,7 +576,7 @@ async def wait_socket_readable(sock): raise BusyResourceError('reading from') from None -async def wait_socket_writable(sock): +async def wait_socket_writable(sock: socket.SocketType) -> None: try: await wait_writable(sock) except trio.ClosedResourceError as exc: @@ -585,34 +590,34 @@ async def wait_socket_writable(sock): # class Event(BaseEvent): - def __new__(cls): + def __new__(cls) -> 'Event': return object.__new__(cls) - def __init__(self): + def __init__(self) -> None: self.__original = trio.Event() def is_set(self) -> bool: return self.__original.is_set() - async def wait(self) -> bool: + async def wait(self) -> None: return await self.__original.wait() def statistics(self) -> EventStatistics: return self.__original.statistics() - def set(self): + def set(self) -> DeprecatedAwaitable: self.__original.set() return DeprecatedAwaitable(self.set) class CapacityLimiter(BaseCapacityLimiter): - def __new__(cls, *args, **kwargs): + def __new__(cls, *args: object, **kwargs: object) -> "CapacityLimiter": return object.__new__(cls) - def __init__(self, *args, original: Optional[trio.CapacityLimiter] = None): + def __init__(self, *args: object, original: Optional[trio.CapacityLimiter] = None) -> None: self.__original = original or trio.CapacityLimiter(*args) - async def __aenter__(self): + async def __aenter__(self) -> None: return await self.__original.__aenter__() async def __aexit__(self, exc_type: Optional[Type[BaseException]], @@ -636,24 +641,24 @@ def borrowed_tokens(self) -> int: def available_tokens(self) -> float: return self.__original.available_tokens - def acquire_nowait(self): + def acquire_nowait(self) -> DeprecatedAwaitable: self.__original.acquire_nowait() return DeprecatedAwaitable(self.acquire_nowait) - def acquire_on_behalf_of_nowait(self, borrower): + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: self.__original.acquire_on_behalf_of_nowait(borrower) return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait) async def acquire(self) -> None: await self.__original.acquire() - async def acquire_on_behalf_of(self, borrower) -> None: + async def acquire_on_behalf_of(self, borrower: object) -> None: await self.__original.acquire_on_behalf_of(borrower) def release(self) -> None: return self.__original.release() - def release_on_behalf_of(self, borrower) -> None: + def release_on_behalf_of(self, borrower: object) -> None: return self.__original.release_on_behalf_of(borrower) def statistics(self) -> CapacityLimiterStatistics: @@ -677,17 +682,19 @@ def current_default_thread_limiter() -> CapacityLimiter: # class _SignalReceiver(DeprecatedAsyncContextManager): - def __init__(self, cm): + def __init__(self, cm: ContextManager[T]): self._cm = cm def __enter__(self) -> T: return self._cm.__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: return self._cm.__exit__(exc_type, exc_val, exc_tb) -def open_signal_receiver(*signals: int): +def open_signal_receiver(*signals: int) -> _SignalReceiver: cm = trio.open_signal_receiver(*signals) return _SignalReceiver(cm) @@ -723,18 +730,18 @@ def get_running_tasks() -> List[TaskInfo]: return task_infos -def wait_all_tasks_blocked(): +def wait_all_tasks_blocked() -> Awaitable[None]: import trio.testing return trio.testing.wait_all_tasks_blocked() class TestRunner(abc.TestRunner): - def __init__(self, **options): + def __init__(self, **options: object) -> None: from collections import deque from queue import Queue - self._call_queue = Queue() - self._result_queue = deque() + self._call_queue: "Queue[Callable[..., object]]" = Queue() + self._result_queue: Outcome = deque() self._stop_event: Optional[trio.Event] = None self._nursery: Optional[trio.Nursery] = None self._options = options @@ -744,7 +751,7 @@ async def _trio_main(self) -> None: async with trio.open_nursery() as self._nursery: await self._stop_event.wait() - async def _call_func(self, func, args, kwargs): + async def _call_func(self, func: Callable[..., Awaitable[object]], args: tuple, kwargs: dict) -> None: try: retval = await func(*args, **kwargs) except BaseException as exc: @@ -752,7 +759,7 @@ async def _call_func(self, func, args, kwargs): else: self._result_queue.append(Value(retval)) - def _main_task_finished(self, outcome) -> None: + def _main_task_finished(self, outcome: object) -> None: self._nursery = None def close(self) -> None: @@ -761,7 +768,7 @@ def close(self) -> None: while self._nursery is not None: self._call_queue.get()() - def call(self, func: Callable[..., Awaitable], *args, **kwargs): + def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object) -> T_Retval: if self._nursery is None: trio.lowlevel.start_guest_run( self._trio_main, run_sync_soon_threadsafe=self._call_queue.put, diff --git a/src/anyio/_core/_compat.py b/src/anyio/_core/_compat.py index 9f44ea95..9a25725c 100644 --- a/src/anyio/_core/_compat.py +++ b/src/anyio/_core/_compat.py @@ -1,13 +1,22 @@ from abc import ABCMeta, abstractmethod from contextlib import AbstractContextManager +from types import TracebackType from typing import ( - AsyncContextManager, Callable, ContextManager, Generic, List, Optional, TypeVar, Union, - overload) + TYPE_CHECKING, AsyncContextManager, Callable, ContextManager, Generic, Iterable, List, + Optional, Tuple, Type, TypeVar, Union, overload) from warnings import warn +if TYPE_CHECKING: + from ._testing import TaskInfo + T = TypeVar('T') AnyDeprecatedAwaitable = Union['DeprecatedAwaitable', 'DeprecatedAwaitableFloat', - 'DeprecatedAwaitableList'] + 'DeprecatedAwaitableList', 'TaskInfo'] + + +@overload +async def maybe_async(__obj: 'TaskInfo') -> 'TaskInfo': + ... @overload @@ -16,7 +25,7 @@ async def maybe_async(__obj: 'DeprecatedAwaitableFloat') -> float: @overload -async def maybe_async(__obj: 'DeprecatedAwaitableList') -> list: +async def maybe_async(__obj: 'DeprecatedAwaitableList[T]') -> List[T]: ... @@ -25,7 +34,7 @@ async def maybe_async(__obj: 'DeprecatedAwaitable') -> None: ... -async def maybe_async(__obj: AnyDeprecatedAwaitable) -> Union[float, list, None]: +async def maybe_async(__obj: AnyDeprecatedAwaitable) -> Union[TaskInfo, float, list, None]: """ Await on the given object if necessary. @@ -49,7 +58,7 @@ def __init__(self, cm: ContextManager[T]): async def __aenter__(self) -> T: return self._cm.__enter__() - async def __aexit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]: + async def __aexit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]: return self._cm.__exit__(exc_type, exc_val, exc_tb) @@ -83,33 +92,33 @@ class DeprecatedAwaitable: def __init__(self, func: Callable[..., 'DeprecatedAwaitable']): self._name = f'{func.__module__}.{func.__qualname__}' - def __await__(self): + def __await__(self) -> Iterable[None]: _warn_deprecation(self) if False: yield - def __reduce__(self): + def __reduce__(self) -> Tuple[Type[None], Tuple]: return type(None), () - def _unwrap(self): + def _unwrap(self) -> None: return None class DeprecatedAwaitableFloat(float): - def __new__(cls, x, func): + def __new__(cls, x: float, func: Callable[..., 'DeprecatedAwaitableFloat']) -> DeprecatedAwaitableFloat: return super().__new__(cls, x) def __init__(self, x: float, func: Callable[..., 'DeprecatedAwaitableFloat']): self._name = f'{func.__module__}.{func.__qualname__}' - def __await__(self): + def __await__(self) -> Iterable[float]: _warn_deprecation(self) if False: yield return float(self) - def __reduce__(self): + def __reduce__(self) -> Tuple[Type[float], Tuple[float]]: return float, (float(self),) def _unwrap(self) -> float: @@ -117,18 +126,18 @@ def _unwrap(self) -> float: class DeprecatedAwaitableList(List[T]): - def __init__(self, *args, func: Callable[..., 'DeprecatedAwaitableList']): + def __init__(self, *args: T, func: Callable[..., 'DeprecatedAwaitableList']): super().__init__(*args) self._name = f'{func.__module__}.{func.__qualname__}' - def __await__(self): + def __await__(self) -> Iterable[List[T]]: _warn_deprecation(self) if False: yield - return self + return list(self) - def __reduce__(self): + def __reduce__(self) -> Tuple[Type[list], Tuple[List[T]]]: return list, (list(self),) def _unwrap(self) -> List[T]: @@ -141,7 +150,7 @@ def __enter__(self) -> T: pass @abstractmethod - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]: pass async def __aenter__(self) -> T: @@ -151,5 +160,5 @@ async def __aenter__(self) -> T: f'you are completely migrating to AnyIO 3+.', DeprecationWarning) return self.__enter__() - async def __aexit__(self, exc_type, exc_val, exc_tb) -> Optional[bool]: + async def __aexit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]: return self.__exit__(exc_type, exc_val, exc_tb) diff --git a/src/anyio/_core/_eventloop.py b/src/anyio/_core/_eventloop.py index 6021ab99..f2364a3b 100644 --- a/src/anyio/_core/_eventloop.py +++ b/src/anyio/_core/_eventloop.py @@ -16,7 +16,7 @@ threadlocals = threading.local() -def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args, +def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object, backend: str = 'asyncio', backend_options: Optional[Dict[str, Any]] = None) -> T_Retval: """ Run the given coroutine function in an asynchronous event loop. @@ -120,7 +120,7 @@ def get_cancelled_exc_class() -> Type[BaseException]: # @contextmanager -def claim_worker_thread(backend) -> Generator[Any, None, None]: +def claim_worker_thread(backend: str) -> Generator[Any, None, None]: module = sys.modules['anyio._backends._' + backend] threadlocals.current_async_module = module token = sniffio.current_async_library_cvar.set(backend) @@ -131,7 +131,7 @@ def claim_worker_thread(backend) -> Generator[Any, None, None]: del threadlocals.current_async_module -def get_asynclib(asynclib_name: Optional[str] = None): +def get_asynclib(asynclib_name: Optional[str] = None) -> Any: if asynclib_name is None: asynclib_name = sniffio.current_async_library() diff --git a/src/anyio/_core/_exceptions.py b/src/anyio/_core/_exceptions.py index 8025ef6f..52e59808 100644 --- a/src/anyio/_core/_exceptions.py +++ b/src/anyio/_core/_exceptions.py @@ -49,7 +49,7 @@ class ExceptionGroup(BaseException): #: the sequence of exceptions raised together exceptions: Sequence[BaseException] - def __str__(self): + def __str__(self) -> str: tracebacks = [''.join(format_exception(type(exc), exc, exc.__traceback__)) for exc in self.exceptions] return f'{len(self.exceptions)} exceptions were raised in the task group:\n' \ @@ -67,7 +67,7 @@ class IncompleteRead(Exception): connection is closed before the requested amount of bytes has been read. """ - def __init__(self): + def __init__(self) -> None: super().__init__('The stream was closed before the read operation could be completed') diff --git a/src/anyio/_core/_fileio.py b/src/anyio/_core/_fileio.py index fe2b53d6..ebfd34e3 100644 --- a/src/anyio/_core/_fileio.py +++ b/src/anyio/_core/_fileio.py @@ -1,12 +1,14 @@ import os from os import PathLike -from typing import Callable, Optional, Union +from typing import Any, AsyncIterator, Callable, Generic, Optional, TypeVar, Union from .. import to_thread from ..abc import AsyncResource +T_Fp = TypeVar("T_Fp") -class AsyncFile(AsyncResource): + +class AsyncFile(AsyncResource, Generic[T_Fp]): """ An asynchronous file object. @@ -38,18 +40,18 @@ class AsyncFile(AsyncResource): print(line) """ - def __init__(self, fp) -> None: - self._fp = fp + def __init__(self, fp: T_Fp) -> None: + self._fp: Any = fp - def __getattr__(self, name): + def __getattr__(self, name: str) -> object: return getattr(self._fp, name) @property - def wrapped(self): + def wrapped(self) -> T_Fp: """The wrapped file object.""" return self._fp - async def __aiter__(self): + async def __aiter__(self) -> AsyncIterator[bytes]: while True: line = await self.readline() if line: diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index e770861f..6abefad1 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -1,6 +1,7 @@ import socket import ssl import sys +from dataclasses import dataclass from ipaddress import IPv6Address, ip_address from os import PathLike, chmod from pathlib import Path @@ -83,9 +84,9 @@ async def connect_tcp( async def connect_tcp( - remote_host, remote_port, *, local_host=None, tls=False, ssl_context=None, - tls_standard_compatible=True, tls_hostname=None, happy_eyeballs_delay=0.25 -): + remote_host: IPAddressType, remote_port: int, *, local_host: Optional[IPAddressType] = None, tls: bool = False, ssl_context: Optional[ssl.SSLContext] = None, + tls_standard_compatible: bool = True, tls_hostname: Optional[str] = None, happy_eyeballs_delay: float = 0.25 +) -> Union[SocketStream, TLSStream]: """ Connect to a host using the TCP protocol. @@ -119,7 +120,7 @@ async def connect_tcp( # Placed here due to https://github.com/python/mypy/issues/7057 connected_stream: Optional[SocketStream] = None - async def try_connect(remote_host: str, event: Event): + async def try_connect(remote_host: str, event: Event) -> None: nonlocal connected_stream try: stream = await asynclib.connect_tcp(remote_host, remote_port, local_address) @@ -184,7 +185,7 @@ async def try_connect(remote_host: str, event: Event): if tls or tls_hostname or ssl_context: try: return await TLSStream.wrap(connected_stream, server_side=False, - hostname=tls_hostname or remote_host, + hostname=tls_hostname or cast(Optional[str], remote_host), ssl_context=ssl_context, standard_compatible=tls_standard_compatible) except BaseException: @@ -475,7 +476,22 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]: # Private API # -def convert_ipv6_sockaddr(sockaddr): +@dataclass +class _SockAddr: + host: str + port: int + flowinfo: Optional[str] = None + scope_id: int = 0 + + def as_two_tuple(self) -> Tuple[str, int]: + if self.scope_id: + # Add scopeid to the address + return f"{self.host}%{self.scope_id}", self.port + else: + return self.host, self.port + + +def convert_ipv6_sockaddr(sockaddr: Union[Tuple[str, int, object, object], Tuple[str, int]]) -> Tuple[str, int]: """ Convert a 4-tuple IPv6 socket address to a 2-tuple (address, port) format. @@ -487,12 +503,9 @@ def convert_ipv6_sockaddr(sockaddr): :return: the converted socket address """ + # This is more complicated than it should be because of MyPy if isinstance(sockaddr, tuple) and len(sockaddr) == 4: - if sockaddr[3]: - # Add scopeid to the address - return sockaddr[0] + '%' + str(sockaddr[3]), sockaddr[1] - else: - return sockaddr[:2] + return _SockAddr(*sockaddr).as_two_tuple() else: - return sockaddr + return cast(Tuple[str, int], sockaddr) diff --git a/src/anyio/_core/_streams.py b/src/anyio/_core/_streams.py index 02eabd4d..2cf62226 100644 --- a/src/anyio/_core/_streams.py +++ b/src/anyio/_core/_streams.py @@ -1,5 +1,5 @@ import math -from typing import Any, Tuple, Type, TypeVar, overload +from typing import Any, Optional, Tuple, Type, TypeVar, overload from ..streams.memory import ( MemoryObjectReceiveStream, MemoryObjectSendStream, MemoryObjectStreamState) @@ -21,7 +21,7 @@ def create_memory_object_stream( ... -def create_memory_object_stream(max_buffer_size=0, item_type=None): +def create_memory_object_stream(max_buffer_size: float = 0, item_type: Optional[Type[T_Item]] = None) -> Tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: """ Create a memory object stream. diff --git a/src/anyio/_core/_subprocesses.py b/src/anyio/_core/_subprocesses.py index f7e0232c..1daf6c10 100644 --- a/src/anyio/_core/_subprocesses.py +++ b/src/anyio/_core/_subprocesses.py @@ -1,6 +1,6 @@ from os import PathLike from subprocess import DEVNULL, PIPE, CalledProcessError, CompletedProcess -from typing import Mapping, Optional, Sequence, Union, cast +from typing import AsyncIterable, List, Mapping, Optional, Sequence, Union, cast from ..abc import Process from ._eventloop import get_asynclib @@ -32,13 +32,13 @@ async def run_process(command: Union[str, Sequence[str]], *, input: Optional[byt nonzero return code """ - async def drain_stream(stream, index): + async def drain_stream(stream: AsyncIterable[bytes], index: int) -> None: chunks = [chunk async for chunk in stream] stream_contents[index] = b''.join(chunks) async with await open_process(command, stdin=PIPE if input else DEVNULL, stdout=stdout, stderr=stderr, cwd=cwd, env=env) as process: - stream_contents = [None, None] + stream_contents: List[Optional[bytes]] = [None, None] try: async with create_task_group() as tg: if process.stdout: diff --git a/src/anyio/_core/_synchronization.py b/src/anyio/_core/_synchronization.py index f7070e98..6c691770 100644 --- a/src/anyio/_core/_synchronization.py +++ b/src/anyio/_core/_synchronization.py @@ -73,7 +73,7 @@ class SemaphoreStatistics: class Event: - def __new__(cls): + def __new__(cls) -> 'Event': return get_asynclib().Event() def set(self) -> DeprecatedAwaitable: @@ -84,7 +84,7 @@ def is_set(self) -> bool: """Return ``True`` if the flag is set, ``False`` if not.""" raise NotImplementedError - async def wait(self) -> bool: + async def wait(self) -> None: """ Wait until the flag has been set. @@ -101,10 +101,10 @@ def statistics(self) -> EventStatistics: class Lock: _owner_task: Optional[TaskInfo] = None - def __init__(self): + def __init__(self) -> None: self._waiters: Deque[Tuple[TaskInfo, Event]] = deque() - async def __aenter__(self): + async def __aenter__(self) -> None: await self.acquire() async def __aexit__(self, exc_type: Optional[Type[BaseException]], @@ -183,7 +183,7 @@ def __init__(self, lock: Optional[Lock] = None): self._lock = lock or Lock() self._waiters: Deque[Event] = deque() - async def __aenter__(self): + async def __aenter__(self) -> None: await self.acquire() async def __aexit__(self, exc_type: Optional[Type[BaseException]], @@ -351,10 +351,10 @@ def statistics(self) -> SemaphoreStatistics: class CapacityLimiter: - def __new__(cls, total_tokens: float): + def __new__(cls, total_tokens: float) -> 'CapacityLimiter': return get_asynclib().CapacityLimiter(total_tokens) - async def __aenter__(self): + async def __aenter__(self) -> None: raise NotImplementedError async def __aexit__(self, exc_type: Optional[Type[BaseException]], @@ -380,7 +380,7 @@ def total_tokens(self) -> float: def total_tokens(self, value: float) -> None: raise NotImplementedError - async def set_total_tokens(self, value) -> None: + async def set_total_tokens(self, value: float) -> None: warn('CapacityLimiter.set_total_tokens has been deprecated. Set the value of the' '"total_tokens" attribute directly.', DeprecationWarning) self.total_tokens = value @@ -404,7 +404,7 @@ def acquire_nowait(self) -> DeprecatedAwaitable: """ raise NotImplementedError - def acquire_on_behalf_of_nowait(self, borrower) -> DeprecatedAwaitable: + def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: """ Acquire a token without waiting for one to become available. @@ -421,7 +421,7 @@ async def acquire(self) -> None: """ raise NotImplementedError - async def acquire_on_behalf_of(self, borrower) -> None: + async def acquire_on_behalf_of(self, borrower: object) -> None: """ Acquire a token, waiting if necessary for one to become available. @@ -438,7 +438,7 @@ def release(self) -> None: """ raise NotImplementedError - def release_on_behalf_of(self, borrower) -> None: + def release_on_behalf_of(self, borrower: object) -> None: """ Release the token held by the given borrower. @@ -541,11 +541,14 @@ def __init__(self, action: str): self.action = action self._guarded = False - def __enter__(self): + def __enter__(self) -> None: if self._guarded: raise BusyResourceError(self.action) self._guarded = True - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: self._guarded = False + return None diff --git a/src/anyio/_core/_tasks.py b/src/anyio/_core/_tasks.py index ac4e41da..8bbad974 100644 --- a/src/anyio/_core/_tasks.py +++ b/src/anyio/_core/_tasks.py @@ -9,7 +9,7 @@ class _IgnoredTaskStatus(TaskStatus): - def started(self, value=None) -> None: + def started(self, value: object = None) -> None: pass @@ -24,7 +24,7 @@ class CancelScope(DeprecatedAsyncContextManager['CancelScope']): :param shield: ``True`` to shield the cancel scope from external cancellation """ - def __new__(cls, *, deadline: float = math.inf, shield: bool = False): + def __new__(cls, *, deadline: float = math.inf, shield: bool = False) -> 'CancelScope': return get_asynclib().CancelScope(shield=shield, deadline=deadline) def cancel(self) -> DeprecatedAwaitable: @@ -64,7 +64,7 @@ def shield(self) -> bool: def shield(self, value: bool) -> None: raise NotImplementedError - def __enter__(self): + def __enter__(self) -> 'CancelScope': raise NotImplementedError def __exit__(self, exc_type: Optional[Type[BaseException]], @@ -92,10 +92,12 @@ class FailAfterContextManager(DeprecatedAsyncContextManager): def __init__(self, cancel_scope: CancelScope): self._cancel_scope = cancel_scope - def __enter__(self): + def __enter__(self) -> CancelScope: return self._cancel_scope.__enter__() - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: retval = self._cancel_scope.__exit__(exc_type, exc_val, exc_tb) if self._cancel_scope.cancel_called: raise TimeoutError diff --git a/src/anyio/_core/_testing.py b/src/anyio/_core/_testing.py index 269923da..2076db52 100644 --- a/src/anyio/_core/_testing.py +++ b/src/anyio/_core/_testing.py @@ -1,10 +1,10 @@ -from typing import Coroutine, Optional +from typing import Coroutine, Iterable, Optional, Tuple, Type -from ._compat import DeprecatedAwaitable, DeprecatedAwaitableList +from ._compat import DeprecatedAwaitable, DeprecatedAwaitableList, _warn_deprecation from ._eventloop import get_asynclib -class TaskInfo(DeprecatedAwaitable): +class TaskInfo: """ Represents an asynchronous task. @@ -15,31 +15,39 @@ class TaskInfo(DeprecatedAwaitable): :ivar ~collections.abc.Coroutine coro: the coroutine object of the task """ - __slots__ = 'id', 'parent_id', 'name', 'coro' + __slots__ = '_name', 'id', 'parent_id', 'name', 'coro' def __init__(self, id: int, parent_id: Optional[int], name: Optional[str], coro: Coroutine): - super().__init__(get_current_task) + func = get_current_task + self._name = f'{func.__module__}.{func.__qualname__}' self.id = id self.parent_id = parent_id self.name = name self.coro = coro - def __await__(self): - yield from super().__await__() - return self - - def __eq__(self, other): + def __eq__(self, other: object) -> bool: if isinstance(other, TaskInfo): return self.id == other.id return NotImplemented - def __hash__(self): + def __hash__(self) -> int: return hash(self.id) - def __repr__(self): + def __repr__(self) -> str: return f'{self.__class__.__name__}(id={self.id!r}, name={self.name!r})' + def __await__(self) -> Iterable["TaskInfo"]: + _warn_deprecation(self) + if False: + yield + + def __reduce__(self) -> Tuple[Type["TaskInfo"], Tuple[int, Optional[int], Optional[str], Coroutine]]: + return TaskInfo, (self.id, self.parent_id, self.name, self.coro) + + def _unwrap(self) -> 'TaskInfo': + return self + def get_current_task() -> TaskInfo: """ diff --git a/src/anyio/_core/_typedattr.py b/src/anyio/_core/_typedattr.py index 162426a1..e8377961 100644 --- a/src/anyio/_core/_typedattr.py +++ b/src/anyio/_core/_typedattr.py @@ -1,5 +1,5 @@ import sys -from typing import Callable, Mapping, TypeVar, Union, overload +from typing import Any, Callable, Mapping, TypeVar, Union, overload from ._exceptions import TypedAttributeLookupError @@ -13,7 +13,7 @@ undefined = object() -def typed_attribute(): +def typed_attribute() -> Any: """Return a unique object, used to mark typed attributes.""" return object() @@ -58,7 +58,7 @@ def extra(self, attribute: T_Attr, default: T_Default) -> Union[T_Attr, T_Defaul ... @final - def extra(self, attribute, default=undefined): + def extra(self, attribute: Any, default: object = undefined) -> object: """ extra(attribute, default=undefined) diff --git a/src/anyio/abc/_resources.py b/src/anyio/abc/_resources.py index d6ed168f..886366b9 100644 --- a/src/anyio/abc/_resources.py +++ b/src/anyio/abc/_resources.py @@ -1,6 +1,8 @@ from abc import ABCMeta, abstractmethod from types import TracebackType -from typing import Optional, Type +from typing import Optional, Type, TypeVar + +T = TypeVar("T") class AsyncResource(metaclass=ABCMeta): @@ -11,7 +13,7 @@ class AsyncResource(metaclass=ABCMeta): :meth:`aclose` on exit. """ - async def __aenter__(self): + async def __aenter__(self: T) -> "T": return self async def __aexit__(self, exc_type: Optional[Type[BaseException]], diff --git a/src/anyio/abc/_sockets.py b/src/anyio/abc/_sockets.py index ef23cdac..65a597ec 100644 --- a/src/anyio/abc/_sockets.py +++ b/src/anyio/abc/_sockets.py @@ -2,8 +2,9 @@ from io import IOBase from ipaddress import IPv4Address, IPv6Address from socket import AddressFamily, SocketType +from types import TracebackType from typing import ( - Any, AsyncContextManager, Callable, Collection, List, Optional, Tuple, TypeVar, Union) + Any, AsyncContextManager, Callable, Collection, List, Optional, Tuple, Type, TypeVar, Union) from .._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute from ._streams import ByteStream, Listener, T_Stream, UnreliableObjectStream @@ -17,11 +18,13 @@ class _NullAsyncContextManager: - async def __aenter__(self): + async def __aenter__(self) -> None: pass - async def __aexit__(self, exc_type, exc_val, exc_tb): - pass + async def __aexit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: + return None class SocketAttribute(TypedAttributeSet): @@ -41,7 +44,7 @@ class SocketAttribute(TypedAttributeSet): class _SocketProvider(TypedAttributeProvider): @property - def extra_attributes(self): + def extra_attributes(self) -> dict: from .._core._sockets import convert_ipv6_sockaddr as convert attributes = { @@ -50,7 +53,7 @@ def extra_attributes(self): SocketAttribute.raw_socket: lambda: self._raw_socket } try: - peername = convert(self._raw_socket.getpeername()) + peername: Optional[Tuple[str, int]] = convert(self._raw_socket.getpeername()) except OSError: peername = None @@ -62,7 +65,8 @@ def extra_attributes(self): if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6): attributes[SocketAttribute.local_port] = lambda: self._raw_socket.getsockname()[1] if peername is not None: - attributes[SocketAttribute.remote_port] = lambda: peername[1] + remote_port = peername[1] + attributes[SocketAttribute.remote_port] = lambda: remote_port return attributes diff --git a/src/anyio/abc/_streams.py b/src/anyio/abc/_streams.py index 6747f543..3f5e783d 100644 --- a/src/anyio/abc/_streams.py +++ b/src/anyio/abc/_streams.py @@ -21,7 +21,7 @@ class UnreliableObjectReceiveStream(Generic[T_Item], AsyncResource, TypedAttribu parameter. """ - def __aiter__(self): + def __aiter__(self) -> "UnreliableObjectReceiveStream[T_Item]": return self async def __anext__(self) -> T_Item: diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index 3801d566..e36a0e47 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -12,7 +12,7 @@ class TaskStatus(metaclass=ABCMeta): @abstractmethod - def started(self, value=None) -> None: + def started(self, value: object = None) -> None: """ Signal that the task has started. @@ -30,7 +30,7 @@ class TaskGroup(metaclass=ABCMeta): cancel_scope: 'CancelScope' - async def spawn(self, func: Callable[..., Coroutine], *args, name=None) -> None: + async def spawn(self, func: Callable[..., Coroutine], *args: object, name: Optional[str] = None) -> None: """ Start a new task in this task group. @@ -48,7 +48,7 @@ async def spawn(self, func: Callable[..., Coroutine], *args, name=None) -> None: self.start_soon(func, *args, name=name) @abstractmethod - def start_soon(self, func: Callable[..., Coroutine], *args, name=None) -> None: + def start_soon(self, func: Callable[..., Coroutine], *args: object, name: Optional[str] = None) -> None: """ Start a new task in this task group. @@ -60,7 +60,7 @@ def start_soon(self, func: Callable[..., Coroutine], *args, name=None) -> None: """ @abstractmethod - async def start(self, func: Callable[..., Coroutine], *args, name=None): + async def start(self, func: Callable[..., Coroutine], *args: object, name: Optional[str] = None) -> object: """ Start a new task and wait until it signals for readiness. diff --git a/src/anyio/abc/_testing.py b/src/anyio/abc/_testing.py index ace2d9b7..012271b4 100644 --- a/src/anyio/abc/_testing.py +++ b/src/anyio/abc/_testing.py @@ -1,5 +1,8 @@ +import types from abc import ABCMeta, abstractmethod -from typing import Any, Awaitable, Callable, Dict +from typing import Any, Awaitable, Callable, Dict, Optional, Type, TypeVar + +_T = TypeVar("_T") class TestRunner(metaclass=ABCMeta): @@ -11,15 +14,16 @@ class TestRunner(metaclass=ABCMeta): def __enter__(self) -> 'TestRunner': return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[types.TracebackType]) -> Optional[bool]: self.close() + return None @abstractmethod def close(self) -> None: """Close the event loop.""" @abstractmethod - def call(self, func: Callable[..., Awaitable], *args: tuple, **kwargs: Dict[str, Any]): + def call(self, func: Callable[..., Awaitable[_T]], *args: tuple, **kwargs: Dict[str, Any]) -> _T: """ Call the given function within the backend's event loop. diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index 48fc0015..83b9e426 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -5,7 +5,7 @@ from types import TracebackType from typing import ( Any, AsyncContextManager, Callable, ContextManager, Coroutine, Dict, Generator, Iterable, - Optional, Tuple, Type, TypeVar, cast, overload) + Optional, Tuple, Type, TypeVar, Union, cast, overload) from warnings import warn from ._core import _eventloop @@ -17,7 +17,7 @@ T_co = TypeVar('T_co') -def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval: +def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval: """ Call a coroutine function from a worker thread. @@ -34,13 +34,13 @@ def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval: return asynclib.run_async_from_thread(func, *args) -def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval: +def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval: warn('run_async_from_thread() has been deprecated, use anyio.from_thread.run() instead', DeprecationWarning) return run(func, *args) -def run_sync(func: Callable[..., T_Retval], *args) -> T_Retval: +def run_sync(func: Callable[..., T_Retval], *args: object) -> T_Retval: """ Call a function in the event loop thread from a worker thread. @@ -57,7 +57,7 @@ def run_sync(func: Callable[..., T_Retval], *args) -> T_Retval: return asynclib.run_sync_from_thread(func, *args) -def run_sync_from_thread(func: Callable[..., T_Retval], *args) -> T_Retval: +def run_sync_from_thread(func: Callable[..., T_Retval], *args: object) -> T_Retval: warn('run_sync_from_thread() has been deprecated, use anyio.from_thread.run_sync() instead', DeprecationWarning) return run_sync(func, *args) @@ -74,7 +74,7 @@ def __init__(self, async_cm: AsyncContextManager[T_co], portal: 'BlockingPortal' self._async_cm = async_cm self._portal = portal - async def run_async_cm(self): + async def run_async_cm(self) -> Optional[bool]: try: self._exit_event = Event() value = await self._async_cm.__aenter__() @@ -105,18 +105,18 @@ class _BlockingPortalTaskStatus(TaskStatus): def __init__(self, future: Future): self._future = future - def started(self, value=None) -> None: + def started(self, value: object = None) -> None: self._future.set_result(value) class BlockingPortal: """An object tied that lets external threads run code in an asynchronous event loop.""" - def __new__(cls): + def __new__(cls) -> 'BlockingPortal': return get_asynclib().BlockingPortal() - def __init__(self): - self._event_loop_thread_id = threading.get_ident() + def __init__(self) -> None: + self._event_loop_thread_id: Optional[int] = threading.get_ident() self._stop_event = Event() self._task_group = create_task_group() self._cancelled_exc_class = get_cancelled_exc_class() @@ -125,7 +125,9 @@ async def __aenter__(self) -> 'BlockingPortal': await self._task_group.__aenter__() return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: await self.stop() return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) @@ -157,7 +159,7 @@ async def stop(self, cancel_remaining: bool = False) -> None: async def _call_func(self, func: Callable, args: tuple, kwargs: Dict[str, Any], future: Future) -> None: - def callback(f: Future): + def callback(f: Future) -> None: if f.cancelled(): self.call(scope.cancel) @@ -184,10 +186,10 @@ def callback(f: Future): if not future.cancelled(): future.set_result(retval) finally: - scope = None + del scope def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any], - name, future: Future) -> None: + name: Optional[str], future: Future) -> None: """ Spawn a new task using the given callable. @@ -204,14 +206,14 @@ def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, raise NotImplementedError @overload - def call(self, func: Callable[..., Coroutine[Any, Any, T_Retval]], *args) -> T_Retval: + def call(self, func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval: ... @overload - def call(self, func: Callable[..., T_Retval], *args) -> T_Retval: + def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval: ... - def call(self, func, *args): + def call(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], *args: object) -> T_Retval: """ Call the given function in the event loop thread. @@ -224,7 +226,7 @@ def call(self, func, *args): """ return self.start_task_soon(func, *args).result() - def spawn_task(self, func: Callable[..., Coroutine], *args, name=None) -> Future: + def spawn_task(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], *args: object, name: str = None) -> "Future[T_Retval]": """ Start a task in the portal's task group. @@ -245,7 +247,7 @@ def spawn_task(self, func: Callable[..., Coroutine], *args, name=None) -> Future warn('spawn_task() is deprecated -- use start_task_soon() instead', DeprecationWarning) return self.start_task_soon(func, *args, name=name) - def start_task_soon(self, func: Callable[..., Coroutine], *args, name=None) -> Future: + def start_task_soon(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], *args: object, name: str = None) -> "Future[T_Retval]": """ Start a task in the portal's task group. @@ -268,7 +270,7 @@ def start_task_soon(self, func: Callable[..., Coroutine], *args, name=None) -> F self._spawn_task_from_thread(func, args, {}, name, f) return f - def start_task(self, func: Callable[..., Coroutine], *args, name=None) -> Tuple[Future, Any]: + def start_task(self, func: Callable[..., Coroutine], *args: object, name: str = None) -> Tuple[Future, Any]: """ Start a task in the portal's task group and wait until it signals for readiness. @@ -350,7 +352,7 @@ def start_blocking_portal( Usage as a context manager is now required. """ - async def run_portal(): + async def run_portal() -> None: async with BlockingPortal() as portal_: if future.set_running_or_notify_cancel(): future.set_result(portal_) diff --git a/src/anyio/lowlevel.py b/src/anyio/lowlevel.py index 8a7d53fb..8ec36953 100644 --- a/src/anyio/lowlevel.py +++ b/src/anyio/lowlevel.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, Generic, Set, TypeVar, Union, cast +import enum +from dataclasses import dataclass +from typing import Any, Dict, Generic, NewType, Set, TypeVar, Union, cast from weakref import WeakKeyDictionary from ._core._eventloop import get_asynclib @@ -22,7 +24,7 @@ async def checkpoint() -> None: await get_asynclib().checkpoint() -async def checkpoint_if_cancelled(): +async def checkpoint_if_cancelled() -> None: """ Enter a checkpoint if the enclosing cancel scope has been cancelled. @@ -58,25 +60,22 @@ def current_token() -> object: _token_wrappers: Dict[Any, '_TokenWrapper'] = {} +@dataclass(frozen=True) class _TokenWrapper: - __slots__ = '_token', '__weakref__' + __slots__ = ("_token", "__weakref__") + _token: object - def __init__(self, token): - self._token = token - def __eq__(self, other): - return self._token is other._token +class _NoValueSet(enum.Enum): + NO_VALUE_SET = enum.auto() - def __hash__(self): - return hash(self._token) - -class RunvarToken: +class RunvarToken(Generic[T]): __slots__ = '_var', '_value', '_redeemed' - def __init__(self, var: 'RunVar', value): + def __init__(self, var: 'RunVar', value: Union[T, _NoValueSet]): self._var = var - self._value = value + self._value: Union[T, _NoValueSet] = value self._redeemed = False @@ -84,7 +83,7 @@ class RunVar(Generic[T]): """Like a :class:`~contextvars.ContextVar`, expect scoped to the running event loop.""" __slots__ = '_name', '_default' - NO_VALUE_SET = object() + NO_VALUE_SET = _NoValueSet.NO_VALUE_SET _token_wrappers: Set[_TokenWrapper] = set() @@ -119,20 +118,20 @@ def get(self, default: Union[T, object] = NO_VALUE_SET) -> T: raise LookupError(f'Run variable "{self._name}" has no value and no default set') - def set(self, value: T) -> RunvarToken: + def set(self, value: T) -> RunvarToken[T]: current_vars = self._current_vars token = RunvarToken(self, current_vars.get(self._name, RunVar.NO_VALUE_SET)) current_vars[self._name] = value return token - def reset(self, token: RunvarToken) -> None: + def reset(self, token: RunvarToken[T]) -> None: if token._var is not self: raise ValueError('This token does not belong to this RunVar') if token._redeemed: raise ValueError('This token has already been used') - if token._value is RunVar.NO_VALUE_SET: + if isinstance(token._value, _NoValueSet): try: del self._current_vars[self._name] except KeyError: @@ -142,5 +141,5 @@ def reset(self, token: RunvarToken) -> None: token._redeemed = True - def __repr__(self): + def __repr__(self) -> str: return f'' diff --git a/src/anyio/pytest_plugin.py b/src/anyio/pytest_plugin.py index 1db51764..1eadafe5 100644 --- a/src/anyio/pytest_plugin.py +++ b/src/anyio/pytest_plugin.py @@ -1,7 +1,7 @@ import sys from contextlib import contextmanager from inspect import iscoroutinefunction -from typing import Any, Dict, Iterator, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Tuple, cast import pytest import sniffio @@ -14,10 +14,14 @@ else: from async_generator import isasyncgenfunction +if TYPE_CHECKING: + from _pytest.config import Config + from _pytest.fixtures import FixtureDef + _current_runner: Optional[TestRunner] = None -def extract_backend_and_options(backend) -> Tuple[str, Dict[str, Any]]: +def extract_backend_and_options(backend: object) -> Tuple[str, Dict[str, Any]]: if isinstance(backend, str): return backend, {} elif isinstance(backend, tuple) and len(backend) == 2: @@ -51,13 +55,13 @@ def get_runner(backend_name: str, backend_options: Dict[str, Any]) -> Iterator[T sniffio.current_async_library_cvar.reset(token) -def pytest_configure(config): +def pytest_configure(config: "Config") -> None: config.addinivalue_line('markers', 'anyio: mark the (coroutine function) test to be run ' 'asynchronously via anyio.') -def pytest_fixture_setup(fixturedef, request): - def wrapper(*args, anyio_backend, **kwargs): +def pytest_fixture_setup(fixturedef: Any, request: Any) -> None: + def wrapper(*args, anyio_backend, **kwargs): # type: ignore[no-untyped-def] backend_name, backend_options = extract_backend_and_options(anyio_backend) if has_backend_arg: kwargs['anyio_backend'] = anyio_backend @@ -94,7 +98,7 @@ def wrapper(*args, anyio_backend, **kwargs): @pytest.hookimpl(tryfirst=True) -def pytest_pycollect_makeitem(collector, name, obj): +def pytest_pycollect_makeitem(collector: Any, name: Any, obj: Any) -> None: if collector.istestfunction(obj, name): inner_func = obj.hypothesis.inner_test if hasattr(obj, 'hypothesis') else obj if iscoroutinefunction(inner_func): @@ -105,8 +109,8 @@ def pytest_pycollect_makeitem(collector, name, obj): @pytest.hookimpl(tryfirst=True) -def pytest_pyfunc_call(pyfuncitem): - def run_with_hypothesis(**kwargs): +def pytest_pyfunc_call(pyfuncitem: Any) -> Optional[bool]: + def run_with_hypothesis(**kwargs: Any) -> None: with get_runner(backend_name, backend_options) as runner: runner.call(original_func, **kwargs) @@ -130,15 +134,16 @@ def run_with_hypothesis(**kwargs): runner.call(pyfuncitem.obj, **testargs) return True + return None @pytest.fixture(params=get_all_backends()) -def anyio_backend(request): +def anyio_backend(request: Any) -> Any: return request.param @pytest.fixture -def anyio_backend_name(anyio_backend) -> str: +def anyio_backend_name(anyio_backend: Any) -> str: if isinstance(anyio_backend, str): return anyio_backend else: @@ -146,7 +151,7 @@ def anyio_backend_name(anyio_backend) -> str: @pytest.fixture -def anyio_backend_options(anyio_backend) -> Dict[str, Any]: +def anyio_backend_options(anyio_backend: Any) -> Dict[str, Any]: if isinstance(anyio_backend, str): return {} else: diff --git a/src/anyio/streams/buffered.py b/src/anyio/streams/buffered.py index bed83246..7bbe2085 100644 --- a/src/anyio/streams/buffered.py +++ b/src/anyio/streams/buffered.py @@ -25,8 +25,8 @@ def buffer(self) -> bytes: return bytes(self._buffer) @property - def extra_attributes(self): - return self.receive_stream.extra_attributes + def extra_attributes(self) -> dict: + return self.receive_stream.extra_attributes # type: ignore[return-value] async def receive(self, max_bytes: int = 65536) -> bytes: if self._closed: diff --git a/src/anyio/streams/file.py b/src/anyio/streams/file.py index 12442ad3..970fe44c 100644 --- a/src/anyio/streams/file.py +++ b/src/anyio/streams/file.py @@ -25,8 +25,8 @@ async def aclose(self) -> None: await to_thread.run_sync(self._file.close) @property - def extra_attributes(self): - attributes = { + def extra_attributes(self) -> dict: + attributes: dict = { FileStreamAttribute.file: lambda: self._file, } diff --git a/src/anyio/streams/memory.py b/src/anyio/streams/memory.py index 0d04ab46..47ff1f5d 100644 --- a/src/anyio/streams/memory.py +++ b/src/anyio/streams/memory.py @@ -41,7 +41,7 @@ class MemoryObjectReceiveStream(Generic[T_Item], ObjectReceiveStream[T_Item]): _state: MemoryObjectStreamState[T_Item] _closed: bool = field(init=False, default=False) - def __post_init__(self): + def __post_init__(self) -> None: self._state.open_receive_channels += 1 def receive_nowait(self) -> T_Item: @@ -135,7 +135,7 @@ class MemoryObjectSendStream(Generic[T_Item], ObjectSendStream[T_Item]): _state: MemoryObjectStreamState[T_Item] _closed: bool = field(init=False, default=False) - def __post_init__(self): + def __post_init__(self) -> None: self._state.open_send_channels += 1 def send_nowait(self, item: T_Item) -> DeprecatedAwaitable: diff --git a/src/anyio/streams/stapled.py b/src/anyio/streams/stapled.py index 372cac23..771eab6b 100644 --- a/src/anyio/streams/stapled.py +++ b/src/anyio/streams/stapled.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Callable, Generic, Optional, Sequence, TypeVar +from typing import Any, Callable, Generic, List, Optional, Sequence, TypeVar from ..abc import ( ByteReceiveStream, ByteSendStream, ByteStream, Listener, ObjectReceiveStream, ObjectSendStream, @@ -38,8 +38,8 @@ async def aclose(self) -> None: await self.receive_stream.aclose() @property - def extra_attributes(self): - return dict(**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes) + def extra_attributes(self) -> dict: + return {**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes} @dataclass(eq=False) @@ -71,8 +71,8 @@ async def aclose(self) -> None: await self.receive_stream.aclose() @property - def extra_attributes(self): - return dict(**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes) + def extra_attributes(self) -> dict: + return {**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes} @dataclass(eq=False) @@ -92,12 +92,12 @@ class MultiListener(Generic[T_Stream], Listener[T_Stream]): listeners: Sequence[Listener[T_Stream]] - def __post_init__(self): - listeners = [] + def __post_init__(self) -> None: + listeners: List[Listener[T_Stream]] = [] for listener in self.listeners: if isinstance(listener, MultiListener): listeners.extend(listener.listeners) - del listener.listeners[:] + del listener.listeners[:] # type: ignore[attr-defined] else: listeners.append(listener) @@ -116,8 +116,8 @@ async def aclose(self) -> None: await listener.aclose() @property - def extra_attributes(self): - attributes = {} + def extra_attributes(self) -> dict: + attributes: dict = {} for listener in self.listeners: attributes.update(listener.extra_attributes) diff --git a/src/anyio/streams/text.py b/src/anyio/streams/text.py index 351f4a45..cfc1cb65 100644 --- a/src/anyio/streams/text.py +++ b/src/anyio/streams/text.py @@ -29,7 +29,7 @@ class TextReceiveStream(ObjectReceiveStream[str]): errors: InitVar[str] = 'strict' _decoder: codecs.IncrementalDecoder = field(init=False) - def __post_init__(self, encoding, errors): + def __post_init__(self, encoding: str, errors: str) -> None: decoder_class = codecs.getincrementaldecoder(encoding) self._decoder = decoder_class(errors=errors) @@ -45,8 +45,8 @@ async def aclose(self) -> None: self._decoder.reset() @property - def extra_attributes(self): - return self.transport_stream.extra_attributes + def extra_attributes(self) -> dict: + return self.transport_stream.extra_attributes # type: ignore[return-value] @dataclass(eq=False) @@ -68,8 +68,8 @@ class TextSendStream(ObjectSendStream[str]): errors: str = 'strict' _encoder: Callable[..., Tuple[bytes, int]] = field(init=False) - def __post_init__(self, encoding): - self._encoder = codecs.getencoder(encoding) + def __post_init__(self, encoding: str) -> None: + self._encoder = codecs.getencoder(encoding) # type: ignore[assignment] async def send(self, item: str) -> None: encoded = self._encoder(item, self.errors)[0] @@ -79,8 +79,8 @@ async def aclose(self) -> None: await self.transport_stream.aclose() @property - def extra_attributes(self): - return self.transport_stream.extra_attributes + def extra_attributes(self) -> dict: + return self.transport_stream.extra_attributes # type: ignore[return-value] @dataclass(eq=False) @@ -107,7 +107,7 @@ class TextStream(ObjectStream[str]): _receive_stream: TextReceiveStream = field(init=False) _send_stream: TextSendStream = field(init=False) - def __post_init__(self, encoding, errors): + def __post_init__(self, encoding: str, errors: str) -> None: self._receive_stream = TextReceiveStream(self.transport_stream, encoding=encoding, errors=errors) self._send_stream = TextSendStream(self.transport_stream, encoding=encoding, errors=errors) @@ -126,5 +126,5 @@ async def aclose(self) -> None: await self._receive_stream.aclose() @property - def extra_attributes(self): - return dict(**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes) + def extra_attributes(self) -> dict: + return {**self._send_stream.extra_attributes, **self._receive_stream.extra_attributes} diff --git a/src/anyio/streams/tls.py b/src/anyio/streams/tls.py index c42254a4..b4d53585 100644 --- a/src/anyio/streams/tls.py +++ b/src/anyio/streams/tls.py @@ -94,7 +94,7 @@ async def wrap(cls, transport_stream: AnyByteStream, *, server_side: Optional[bo await wrapper._call_sslobject_method(ssl_object.do_handshake) return wrapper - async def _call_sslobject_method(self, func: Callable[..., T_Retval], *args) -> T_Retval: + async def _call_sslobject_method(self, func: Callable[..., T_Retval], *args: object) -> T_Retval: while True: try: result = func(*args) @@ -168,7 +168,7 @@ async def send_eof(self) -> None: raise NotImplementedError('send_eof() has not yet been implemented for TLS streams') @property - def extra_attributes(self): + def extra_attributes(self) -> dict: return { **self.transport_stream.extra_attributes, TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol, @@ -236,7 +236,7 @@ async def handle_handshake_error(exc: BaseException, stream: AnyByteStream) -> N async def serve(self, handler: Callable[[TLSStream], Any], task_group: Optional[TaskGroup] = None) -> None: @wraps(handler) - async def handler_wrapper(stream: AnyByteStream): + async def handler_wrapper(stream: AnyByteStream) -> None: from .. import fail_after try: with fail_after(self.handshake_timeout): @@ -254,7 +254,7 @@ async def aclose(self) -> None: await self.listener.aclose() @property - def extra_attributes(self): + def extra_attributes(self) -> dict: return { TLSAttribute.standard_compatible: lambda: self.standard_compatible, } diff --git a/src/anyio/to_process.py b/src/anyio/to_process.py index 8c18cd79..a06a93f8 100644 --- a/src/anyio/to_process.py +++ b/src/anyio/to_process.py @@ -26,7 +26,7 @@ async def run_sync( - func: Callable[..., T_Retval], *args, cancellable: bool = False, + func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional[CapacityLimiter] = None) -> T_Retval: """ Call the given function with the given arguments in a worker process. @@ -43,7 +43,7 @@ async def run_sync( :return: an awaitable that yields the return value of the function. """ - async def send_raw_command(pickled_cmd: bytes): + async def send_raw_command(pickled_cmd: bytes) -> T_Retval: try: await stdin.send(pickled_cmd) response = await buffered.receive_until(b'\n', 50) diff --git a/src/anyio/to_thread.py b/src/anyio/to_thread.py index e0428777..0f3218e8 100644 --- a/src/anyio/to_thread.py +++ b/src/anyio/to_thread.py @@ -8,7 +8,7 @@ async def run_sync( - func: Callable[..., T_Retval], *args, cancellable: bool = False, + func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional[CapacityLimiter] = None) -> T_Retval: """ Call the given function with the given arguments in a worker thread. @@ -30,7 +30,7 @@ async def run_sync( async def run_sync_in_worker_thread( - func: Callable[..., T_Retval], *args, cancellable: bool = False, + func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional[CapacityLimiter] = None) -> T_Retval: warn('run_sync_in_worker_thread() has been deprecated, use anyio.to_thread.run_sync() instead', DeprecationWarning) From c9088acc4470c988518905ef118b647373da9fe6 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 14:46:43 +0100 Subject: [PATCH 02/31] fix forward annotations --- src/anyio/_core/_compat.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anyio/_core/_compat.py b/src/anyio/_core/_compat.py index 9a25725c..cdf75cf0 100644 --- a/src/anyio/_core/_compat.py +++ b/src/anyio/_core/_compat.py @@ -8,6 +8,8 @@ if TYPE_CHECKING: from ._testing import TaskInfo +else: + TaskInfo = object T = TypeVar('T') AnyDeprecatedAwaitable = Union['DeprecatedAwaitable', 'DeprecatedAwaitableFloat', @@ -105,7 +107,7 @@ def _unwrap(self) -> None: class DeprecatedAwaitableFloat(float): - def __new__(cls, x: float, func: Callable[..., 'DeprecatedAwaitableFloat']) -> DeprecatedAwaitableFloat: + def __new__(cls, x: float, func: Callable[..., 'DeprecatedAwaitableFloat']) -> 'DeprecatedAwaitableFloat': return super().__new__(cls, x) def __init__(self, x: float, func: Callable[..., 'DeprecatedAwaitableFloat']): From 0a4c2a6746a5057659d7fbd4a47c8256d2f415d9 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 14:47:06 +0100 Subject: [PATCH 03/31] actually return a TaskInfo when awaited --- src/anyio/_core/_testing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/anyio/_core/_testing.py b/src/anyio/_core/_testing.py index 2076db52..c41e38c4 100644 --- a/src/anyio/_core/_testing.py +++ b/src/anyio/_core/_testing.py @@ -1,4 +1,4 @@ -from typing import Coroutine, Iterable, Optional, Tuple, Type +from typing import Coroutine, Generator, Iterable, Optional, Tuple, Type from ._compat import DeprecatedAwaitable, DeprecatedAwaitableList, _warn_deprecation from ._eventloop import get_asynclib @@ -37,10 +37,11 @@ def __hash__(self) -> int: def __repr__(self) -> str: return f'{self.__class__.__name__}(id={self.id!r}, name={self.name!r})' - def __await__(self) -> Iterable["TaskInfo"]: + def __await__(self) -> Generator[None, None, "TaskInfo"]: _warn_deprecation(self) if False: yield + return self def __reduce__(self) -> Tuple[Type["TaskInfo"], Tuple[int, Optional[int], Optional[str], Coroutine]]: return TaskInfo, (self.id, self.parent_id, self.name, self.coro) From bff90a0dc404ac095e637f92b94098c39be5cbe5 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 14:47:27 +0100 Subject: [PATCH 04/31] can't del scope --- src/anyio/from_thread.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index 83b9e426..2df6b868 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -186,7 +186,7 @@ def callback(f: Future) -> None: if not future.cancelled(): future.set_result(retval) finally: - del scope + scope = None def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any], name: Optional[str], future: Future) -> None: From ae24c0674b8a0c9ad2f60fb74ff2f1e63ee38950 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 14:48:23 +0100 Subject: [PATCH 05/31] can't delete the scope --- src/anyio/from_thread.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index 2df6b868..71ed6153 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -186,7 +186,7 @@ def callback(f: Future) -> None: if not future.cancelled(): future.set_result(retval) finally: - scope = None + scope = None # type: ignore[assignment] def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any], name: Optional[str], future: Future) -> None: From ef44368bdc064785c2c06c45ec0a46e477678cec Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 14:51:17 +0100 Subject: [PATCH 06/31] remove unused imports --- src/anyio/_core/_testing.py | 4 ++-- src/anyio/lowlevel.py | 2 +- src/anyio/pytest_plugin.py | 1 - 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/anyio/_core/_testing.py b/src/anyio/_core/_testing.py index c41e38c4..5fd08a20 100644 --- a/src/anyio/_core/_testing.py +++ b/src/anyio/_core/_testing.py @@ -1,6 +1,6 @@ -from typing import Coroutine, Generator, Iterable, Optional, Tuple, Type +from typing import Coroutine, Generator, Optional, Tuple, Type -from ._compat import DeprecatedAwaitable, DeprecatedAwaitableList, _warn_deprecation +from ._compat import DeprecatedAwaitableList, _warn_deprecation from ._eventloop import get_asynclib diff --git a/src/anyio/lowlevel.py b/src/anyio/lowlevel.py index 8ec36953..540dc11b 100644 --- a/src/anyio/lowlevel.py +++ b/src/anyio/lowlevel.py @@ -1,6 +1,6 @@ import enum from dataclasses import dataclass -from typing import Any, Dict, Generic, NewType, Set, TypeVar, Union, cast +from typing import Any, Dict, Generic, Set, TypeVar, Union, cast from weakref import WeakKeyDictionary from ._core._eventloop import get_asynclib diff --git a/src/anyio/pytest_plugin.py b/src/anyio/pytest_plugin.py index 1eadafe5..79180ec5 100644 --- a/src/anyio/pytest_plugin.py +++ b/src/anyio/pytest_plugin.py @@ -16,7 +16,6 @@ if TYPE_CHECKING: from _pytest.config import Config - from _pytest.fixtures import FixtureDef _current_runner: Optional[TestRunner] = None From df6061eec9446479e1ce8fff038dc73c7ab610ee Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 15:05:01 +0100 Subject: [PATCH 07/31] flake8 >99 length lines --- src/anyio/_backends/_asyncio.py | 23 ++++++++++++++++------- src/anyio/_backends/_trio.py | 19 +++++++++++++------ src/anyio/_core/_compat.py | 16 ++++++++++++---- src/anyio/_core/_sockets.py | 10 +++++++--- src/anyio/_core/_streams.py | 4 +++- src/anyio/_core/_testing.py | 4 +++- src/anyio/abc/_tasks.py | 9 ++++++--- src/anyio/abc/_testing.py | 7 +++++-- src/anyio/from_thread.py | 15 ++++++++++----- src/anyio/streams/tls.py | 4 +++- 10 files changed, 78 insertions(+), 33 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 93df4e56..3c43289e 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -213,7 +213,8 @@ def _maybe_set_event_loop_policy(policy: Optional[asyncio.AbstractEventLoopPolic asyncio.set_event_loop_policy(policy) -def run(func: Callable[..., Awaitable[T_Retval]], *args: object, debug: bool = False, use_uvloop: bool = True, +def run(func: Callable[..., Awaitable[T_Retval]], *args: object, + debug: bool = False, use_uvloop: bool = True, policy: Optional[asyncio.AbstractEventLoopPolicy] = None) -> T_Retval: @wraps(func) async def wrapper() -> T_Retval: @@ -806,7 +807,9 @@ def wrapper() -> None: return f.result() -def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval: +def run_async_from_thread( + func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object +) -> T_Retval: f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe( func(*args), threadlocals.loop) return f.result() @@ -908,13 +911,16 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]: return self._stderr -async def open_process(command: Union[str, Sequence[str]], *, shell: bool, stdin: int, stdout: int, stderr: int, +async def open_process(command: Union[str, Sequence[str]], *, shell: bool, + stdin: int, stdout: int, stderr: int, cwd: Union[str, bytes, PathLike, None] = None, env: Optional[Mapping[str, str]] = None) -> Process: await checkpoint() if shell: - process = await asyncio.create_subprocess_shell(command, stdin=stdin, stdout=stdout, # type: ignore[arg-type] - stderr=stderr, cwd=cwd, env=env) + process = await asyncio.create_subprocess_shell( + command, stdin=stdin, stdout=stdout, # type: ignore[arg-type] + stderr=stderr, cwd=cwd, env=env, + ) else: process = await asyncio.create_subprocess_exec(*command, stdin=stdin, stdout=stdout, stderr=stderr, cwd=cwd, env=env) @@ -1758,7 +1764,9 @@ def __enter__(self) -> "_SignalReceiver": return self - def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]: + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: for sig in self._handled_signals: self._loop.remove_signal_handler(sig) return None @@ -1851,7 +1859,8 @@ def close(self) -> None: asyncio.set_event_loop(None) self._loop.close() - def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object) -> T_Retval: + def call(self, func: Callable[..., Awaitable[T_Retval]], + *args: object, **kwargs: object) -> T_Retval: def exception_handler(loop: asyncio.AbstractEventLoop, context: Dict[str, Any]) -> None: exceptions.append(context['exception']) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 8cad535f..5d65e833 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -66,7 +66,8 @@ # class CancelScope(BaseCancelScope): - def __new__(cls, original: Optional[trio.CancelScope] = None, **kwargs: object) -> 'CancelScope': + def __new__(cls, original: Optional[trio.CancelScope] = None, + **kwargs: object) -> 'CancelScope': return object.__new__(cls) def __init__(self, original: Optional[trio.CancelScope] = None, **kwargs: object) -> None: @@ -76,7 +77,9 @@ def __enter__(self) -> 'CancelScope': self.__original.__enter__() return self - def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]: + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: return self.__original.__exit__(exc_type, exc_val, exc_tb) def cancel(self) -> DeprecatedAwaitable: @@ -148,7 +151,8 @@ def start_soon(self, func: Callable, *args: object, name: str = None) -> None: self._nursery.start_soon(func, *args, name=name) - async def start(self, func: Callable[..., Coroutine], *args: object, name: str = None) -> object: + async def start(self, func: Callable[..., Coroutine], + *args: object, name: str = None) -> object: if not self._active: raise RuntimeError('This task group is not active; no new tasks can be started.') @@ -278,7 +282,8 @@ def stderr(self) -> Optional[abc.ByteReceiveStream]: return self._stderr -async def open_process(command: Union[str, Sequence[str]], *, shell: bool, stdin: int, stdout: int, stderr: int, +async def open_process(command: Union[str, Sequence[str]], *, shell: bool, + stdin: int, stdout: int, stderr: int, cwd: Union[str, bytes, PathLike, None] = None, env: Optional[Mapping[str, str]] = None) -> Process: process = await trio.open_process(command, stdin=stdin, stdout=stdout, stderr=stderr, @@ -751,7 +756,8 @@ async def _trio_main(self) -> None: async with trio.open_nursery() as self._nursery: await self._stop_event.wait() - async def _call_func(self, func: Callable[..., Awaitable[object]], args: tuple, kwargs: dict) -> None: + async def _call_func(self, func: Callable[..., Awaitable[object]], + args: tuple, kwargs: dict) -> None: try: retval = await func(*args, **kwargs) except BaseException as exc: @@ -768,7 +774,8 @@ def close(self) -> None: while self._nursery is not None: self._call_queue.get()() - def call(self, func: Callable[..., Awaitable[T_Retval]], *args: object, **kwargs: object) -> T_Retval: + def call(self, func: Callable[..., Awaitable[T_Retval]], + *args: object, **kwargs: object) -> T_Retval: if self._nursery is None: trio.lowlevel.start_guest_run( self._trio_main, run_sync_soon_threadsafe=self._call_queue.put, diff --git a/src/anyio/_core/_compat.py b/src/anyio/_core/_compat.py index cdf75cf0..755cb989 100644 --- a/src/anyio/_core/_compat.py +++ b/src/anyio/_core/_compat.py @@ -60,7 +60,9 @@ def __init__(self, cm: ContextManager[T]): async def __aenter__(self) -> T: return self._cm.__enter__() - async def __aexit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]: + async def __aexit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: return self._cm.__exit__(exc_type, exc_val, exc_tb) @@ -107,7 +109,9 @@ def _unwrap(self) -> None: class DeprecatedAwaitableFloat(float): - def __new__(cls, x: float, func: Callable[..., 'DeprecatedAwaitableFloat']) -> 'DeprecatedAwaitableFloat': + def __new__( + cls, x: float, func: Callable[..., 'DeprecatedAwaitableFloat'] + ) -> 'DeprecatedAwaitableFloat': return super().__new__(cls, x) def __init__(self, x: float, func: Callable[..., 'DeprecatedAwaitableFloat']): @@ -152,7 +156,9 @@ def __enter__(self) -> T: pass @abstractmethod - def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]: + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: pass async def __aenter__(self) -> T: @@ -162,5 +168,7 @@ async def __aenter__(self) -> T: f'you are completely migrating to AnyIO 3+.', DeprecationWarning) return self.__enter__() - async def __aexit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]) -> Optional[bool]: + async def __aexit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType]) -> Optional[bool]: return self.__exit__(exc_type, exc_val, exc_tb) diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index 6abefad1..b7c2255c 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -84,8 +84,10 @@ async def connect_tcp( async def connect_tcp( - remote_host: IPAddressType, remote_port: int, *, local_host: Optional[IPAddressType] = None, tls: bool = False, ssl_context: Optional[ssl.SSLContext] = None, - tls_standard_compatible: bool = True, tls_hostname: Optional[str] = None, happy_eyeballs_delay: float = 0.25 + remote_host: IPAddressType, remote_port: int, *, local_host: Optional[IPAddressType] = None, + tls: bool = False, ssl_context: Optional[ssl.SSLContext] = None, + tls_standard_compatible: bool = True, tls_hostname: Optional[str] = None, + happy_eyeballs_delay: float = 0.25 ) -> Union[SocketStream, TLSStream]: """ Connect to a host using the TCP protocol. @@ -491,7 +493,9 @@ def as_two_tuple(self) -> Tuple[str, int]: return self.host, self.port -def convert_ipv6_sockaddr(sockaddr: Union[Tuple[str, int, object, object], Tuple[str, int]]) -> Tuple[str, int]: +def convert_ipv6_sockaddr( + sockaddr: Union[Tuple[str, int, object, object], Tuple[str, int]] +) -> Tuple[str, int]: """ Convert a 4-tuple IPv6 socket address to a 2-tuple (address, port) format. diff --git a/src/anyio/_core/_streams.py b/src/anyio/_core/_streams.py index 2cf62226..f43875c5 100644 --- a/src/anyio/_core/_streams.py +++ b/src/anyio/_core/_streams.py @@ -21,7 +21,9 @@ def create_memory_object_stream( ... -def create_memory_object_stream(max_buffer_size: float = 0, item_type: Optional[Type[T_Item]] = None) -> Tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: +def create_memory_object_stream( + max_buffer_size: float = 0, item_type: Optional[Type[T_Item]] = None +) -> Tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: """ Create a memory object stream. diff --git a/src/anyio/_core/_testing.py b/src/anyio/_core/_testing.py index 5fd08a20..941232a4 100644 --- a/src/anyio/_core/_testing.py +++ b/src/anyio/_core/_testing.py @@ -43,7 +43,9 @@ def __await__(self) -> Generator[None, None, "TaskInfo"]: yield return self - def __reduce__(self) -> Tuple[Type["TaskInfo"], Tuple[int, Optional[int], Optional[str], Coroutine]]: + def __reduce__(self) -> Tuple[ + Type["TaskInfo"], Tuple[int, Optional[int], Optional[str], Coroutine] + ]: return TaskInfo, (self.id, self.parent_id, self.name, self.coro) def _unwrap(self) -> 'TaskInfo': diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index e36a0e47..2860d66f 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -30,7 +30,8 @@ class TaskGroup(metaclass=ABCMeta): cancel_scope: 'CancelScope' - async def spawn(self, func: Callable[..., Coroutine], *args: object, name: Optional[str] = None) -> None: + async def spawn(self, func: Callable[..., Coroutine], + *args: object, name: Optional[str] = None) -> None: """ Start a new task in this task group. @@ -48,7 +49,8 @@ async def spawn(self, func: Callable[..., Coroutine], *args: object, name: Optio self.start_soon(func, *args, name=name) @abstractmethod - def start_soon(self, func: Callable[..., Coroutine], *args: object, name: Optional[str] = None) -> None: + def start_soon(self, func: Callable[..., Coroutine], + *args: object, name: Optional[str] = None) -> None: """ Start a new task in this task group. @@ -60,7 +62,8 @@ def start_soon(self, func: Callable[..., Coroutine], *args: object, name: Option """ @abstractmethod - async def start(self, func: Callable[..., Coroutine], *args: object, name: Optional[str] = None) -> object: + async def start(self, func: Callable[..., Coroutine], + *args: object, name: Optional[str] = None) -> object: """ Start a new task and wait until it signals for readiness. diff --git a/src/anyio/abc/_testing.py b/src/anyio/abc/_testing.py index 012271b4..68aeb00e 100644 --- a/src/anyio/abc/_testing.py +++ b/src/anyio/abc/_testing.py @@ -14,7 +14,9 @@ class TestRunner(metaclass=ABCMeta): def __enter__(self) -> 'TestRunner': return self - def __exit__(self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[types.TracebackType]) -> Optional[bool]: + def __exit__(self, exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[types.TracebackType]) -> Optional[bool]: self.close() return None @@ -23,7 +25,8 @@ def close(self) -> None: """Close the event loop.""" @abstractmethod - def call(self, func: Callable[..., Awaitable[_T]], *args: tuple, **kwargs: Dict[str, Any]) -> _T: + def call(self, func: Callable[..., Awaitable[_T]], + *args: tuple, **kwargs: Dict[str, Any]) -> _T: """ Call the given function within the backend's event loop. diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index 71ed6153..572822d2 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -34,7 +34,8 @@ def run(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_ return asynclib.run_async_from_thread(func, *args) -def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object) -> T_Retval: +def run_async_from_thread(func: Callable[..., Coroutine[Any, Any, T_Retval]], + *args: object) -> T_Retval: warn('run_async_from_thread() has been deprecated, use anyio.from_thread.run() instead', DeprecationWarning) return run(func, *args) @@ -213,7 +214,8 @@ def call(self, func: Callable[..., Coroutine[Any, Any, T_Retval]], *args: object def call(self, func: Callable[..., T_Retval], *args: object) -> T_Retval: ... - def call(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], *args: object) -> T_Retval: + def call(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], + *args: object) -> T_Retval: """ Call the given function in the event loop thread. @@ -226,7 +228,8 @@ def call(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval """ return self.start_task_soon(func, *args).result() - def spawn_task(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], *args: object, name: str = None) -> "Future[T_Retval]": + def spawn_task(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], + *args: object, name: str = None) -> "Future[T_Retval]": """ Start a task in the portal's task group. @@ -247,7 +250,8 @@ def spawn_task(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_ warn('spawn_task() is deprecated -- use start_task_soon() instead', DeprecationWarning) return self.start_task_soon(func, *args, name=name) - def start_task_soon(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], *args: object, name: str = None) -> "Future[T_Retval]": + def start_task_soon(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], + *args: object, name: str = None) -> "Future[T_Retval]": """ Start a task in the portal's task group. @@ -270,7 +274,8 @@ def start_task_soon(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval self._spawn_task_from_thread(func, args, {}, name, f) return f - def start_task(self, func: Callable[..., Coroutine], *args: object, name: str = None) -> Tuple[Future, Any]: + def start_task(self, func: Callable[..., Coroutine], *args: object, + name: str = None) -> Tuple[Future, Any]: """ Start a task in the portal's task group and wait until it signals for readiness. diff --git a/src/anyio/streams/tls.py b/src/anyio/streams/tls.py index b4d53585..cfb6606e 100644 --- a/src/anyio/streams/tls.py +++ b/src/anyio/streams/tls.py @@ -94,7 +94,9 @@ async def wrap(cls, transport_stream: AnyByteStream, *, server_side: Optional[bo await wrapper._call_sslobject_method(ssl_object.do_handshake) return wrapper - async def _call_sslobject_method(self, func: Callable[..., T_Retval], *args: object) -> T_Retval: + async def _call_sslobject_method( + self, func: Callable[..., T_Retval], *args: object + ) -> T_Retval: while True: try: result = func(*args) From fcc28ea2f63608aff6c66e6bbdf6d9d869187483 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 15:07:16 +0100 Subject: [PATCH 08/31] test TextStream.extra_attributes --- tests/streams/test_text.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/streams/test_text.py b/tests/streams/test_text.py index 47ba8f9a..16374dce 100644 --- a/tests/streams/test_text.py +++ b/tests/streams/test_text.py @@ -55,3 +55,4 @@ async def test_bidirectional_stream(): await send_stream.send(b'\xc3\xa6\xc3\xb8') assert await text_stream.receive() == 'æø' + assert text_stream.extra_attributes == {} From 6fffa998bafad031310d54d8543aceb6803efa37 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 14:45:41 +0100 Subject: [PATCH 09/31] add a test for await maybe_async(current_task()) --- tests/test_compat.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/test_compat.py b/tests/test_compat.py index f662e826..007e541f 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -35,6 +35,10 @@ async def test_get_running_tasks(self): tasks = await maybe_async(get_running_tasks()) assert type(tasks) is list + async def test_get_current_task(self): + task = await maybe_async(get_current_task()) + assert type(task) is TaskInfo + async def test_maybe_async_cm(): async with maybe_async_cm(CancelScope()): From 9b14209d88393465f52a57e4e35a408b49daece5 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 15:08:24 +0100 Subject: [PATCH 10/31] fix warning message to refer to anyio.maybe_async --- src/anyio/_core/_compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anyio/_core/_compat.py b/src/anyio/_core/_compat.py index 755cb989..fec5dc78 100644 --- a/src/anyio/_core/_compat.py +++ b/src/anyio/_core/_compat.py @@ -87,7 +87,7 @@ def maybe_async_cm(cm: Union[ContextManager[T], AsyncContextManager[T]]) -> Asyn def _warn_deprecation(awaitable: AnyDeprecatedAwaitable, stacklevel: int = 1) -> None: warn(f'Awaiting on {awaitable._name}() is deprecated. Use "await ' - f'anyio.maybe_awaitable({awaitable._name}(...)) if you have to support both AnyIO 2.x ' + f'anyio.maybe_async({awaitable._name}(...)) if you have to support both AnyIO 2.x ' f'and 3.x, or just remove the "await" if you are completely migrating to AnyIO 3+.', DeprecationWarning, stacklevel=stacklevel + 1) From ebb72c3645fa7d868e641f50777583fb997002aa Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 15:23:43 +0100 Subject: [PATCH 11/31] remove pointless TaskInfo.__reduce__ you cannot pickle a coroutine --- src/anyio/_core/_testing.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/anyio/_core/_testing.py b/src/anyio/_core/_testing.py index 941232a4..06b4d9c2 100644 --- a/src/anyio/_core/_testing.py +++ b/src/anyio/_core/_testing.py @@ -1,4 +1,4 @@ -from typing import Coroutine, Generator, Optional, Tuple, Type +from typing import Coroutine, Generator, Optional from ._compat import DeprecatedAwaitableList, _warn_deprecation from ._eventloop import get_asynclib @@ -43,11 +43,6 @@ def __await__(self) -> Generator[None, None, "TaskInfo"]: yield return self - def __reduce__(self) -> Tuple[ - Type["TaskInfo"], Tuple[int, Optional[int], Optional[str], Coroutine] - ]: - return TaskInfo, (self.id, self.parent_id, self.name, self.coro) - def _unwrap(self) -> 'TaskInfo': return self From fbb3f48ab7f80158e1bb3799247213b68f47da28 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 16:36:33 +0100 Subject: [PATCH 12/31] convert remote_address to str explicitly --- src/anyio/_core/_sockets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index b7c2255c..39902a97 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -187,7 +187,7 @@ async def try_connect(remote_host: str, event: Event) -> None: if tls or tls_hostname or ssl_context: try: return await TLSStream.wrap(connected_stream, server_side=False, - hostname=tls_hostname or cast(Optional[str], remote_host), + hostname=tls_hostname or str(remote_host), ssl_context=ssl_context, standard_compatible=tls_standard_compatible) except BaseException: From 8eae7fceeeb8f899a99b83e65ee79c2a72ae4ab2 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 16:47:00 +0100 Subject: [PATCH 13/31] remove _SockAddr class --- src/anyio/_core/_sockets.py | 25 +++++++------------------ 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index 39902a97..715499f1 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -1,7 +1,6 @@ import socket import ssl import sys -from dataclasses import dataclass from ipaddress import IPv6Address, ip_address from os import PathLike, chmod from pathlib import Path @@ -478,23 +477,8 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]: # Private API # -@dataclass -class _SockAddr: - host: str - port: int - flowinfo: Optional[str] = None - scope_id: int = 0 - - def as_two_tuple(self) -> Tuple[str, int]: - if self.scope_id: - # Add scopeid to the address - return f"{self.host}%{self.scope_id}", self.port - else: - return self.host, self.port - - def convert_ipv6_sockaddr( - sockaddr: Union[Tuple[str, int, object, object], Tuple[str, int]] + sockaddr: Union[Tuple[str, int, int, int], Tuple[str, int]] ) -> Tuple[str, int]: """ Convert a 4-tuple IPv6 socket address to a 2-tuple (address, port) format. @@ -510,6 +494,11 @@ def convert_ipv6_sockaddr( # This is more complicated than it should be because of MyPy if isinstance(sockaddr, tuple) and len(sockaddr) == 4: - return _SockAddr(*sockaddr).as_two_tuple() + host, port, flowinfo, scope_id = cast(Tuple[str, int, int, int], sockaddr) + if scope_id: + # Add scope_id to the address + return f"{host}%{scope_id}", port + else: + return host, port else: return cast(Tuple[str, int], sockaddr) From 4744d9b1c1c6ad0afc0ea50fa747c19d4ec96b6f Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 16:49:26 +0100 Subject: [PATCH 14/31] remove pointless whitespace --- src/anyio/_core/_sockets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index 715499f1..f58cdc3a 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -491,7 +491,6 @@ def convert_ipv6_sockaddr( :return: the converted socket address """ - # This is more complicated than it should be because of MyPy if isinstance(sockaddr, tuple) and len(sockaddr) == 4: host, port, flowinfo, scope_id = cast(Tuple[str, int, int, int], sockaddr) From b0e984cc0c4314de1eb78f5e60b3c30b7f963ad9 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 12 May 2021 16:50:46 +0100 Subject: [PATCH 15/31] move asyncio test runner exception handling back out of the finally --- src/anyio/_backends/_asyncio.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 3c43289e..a4336859 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -1874,9 +1874,9 @@ def exception_handler(loop: asyncio.AbstractEventLoop, context: Dict[str, Any]) finally: self._loop.set_exception_handler(None) - if len(exceptions) == 1: - raise exceptions[0] - elif exceptions: - raise ExceptionGroup(exceptions) + if len(exceptions) == 1: + raise exceptions[0] + elif exceptions: + raise ExceptionGroup(exceptions) return retval From 36fd77061169ead2d35783be7b2fdb06d2b8600a Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 18 May 2021 09:48:25 +0100 Subject: [PATCH 16/31] restore _TokenWrapper.__slots__ code style --- src/anyio/lowlevel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anyio/lowlevel.py b/src/anyio/lowlevel.py index 540dc11b..00d08d79 100644 --- a/src/anyio/lowlevel.py +++ b/src/anyio/lowlevel.py @@ -62,7 +62,7 @@ def current_token() -> object: @dataclass(frozen=True) class _TokenWrapper: - __slots__ = ("_token", "__weakref__") + __slots__ = '_token', '__weakref__' _token: object From bb9ec2fe36bd5214e418ad42c4b8689ff9c32fa1 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 18 May 2021 10:16:25 +0100 Subject: [PATCH 17/31] convert start_soon(..., name=object) to str and support falsy name="" --- src/anyio/_backends/_asyncio.py | 2 +- tests/test_debugging.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index a4336859..1b3388ab 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -644,7 +644,7 @@ def task_done(_task: asyncio.Task) -> None: raise RuntimeError('This task group is not active; no new tasks can be started.') options = {} - name = name or get_callable_name(func) + name = get_callable_name(func) if name is None else str(name) if _native_task_names: options['name'] = name diff --git a/tests/test_debugging.py b/tests/test_debugging.py index e7636096..7d982270 100644 --- a/tests/test_debugging.py +++ b/tests/test_debugging.py @@ -29,6 +29,25 @@ async def main(): loop.close() +@pytest.mark.parametrize( + "name_input,expected", + [ + (None, 'test_debugging.test_non_main_task_name..non_main'), + (b'name', "b'name'"), + ("name", "name"), + ("", ""), + ], +) +async def test_non_main_task_name(name_input, expected): + async def non_main(*, task_status): + task_status.started(anyio.get_current_task().name) + + async with anyio.create_task_group() as tg: + name = await tg.start(non_main, name=name_input) + + assert name == expected + + async def test_get_running_tasks(): async def inspect(): await wait_all_tasks_blocked() From 13e5ac3654422a771da3db4a59c132cf52fe1017 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 18 May 2021 10:20:58 +0100 Subject: [PATCH 18/31] type annotate tg.start_soon(..., name: object) --- src/anyio/_backends/_asyncio.py | 14 ++++++++------ src/anyio/_backends/_trio.py | 6 +++--- src/anyio/abc/_tasks.py | 6 +++--- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 1b3388ab..c8545e3f 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -100,8 +100,8 @@ def _cancel_all_tasks(loop): events.set_event_loop(None) loop.close() - def create_task(coro: Union[Generator[Any, None, _T], Awaitable[_T]], *, # type: ignore - name: Optional[str] = None) -> asyncio.Task: + def create_task(coro: Union[Generator[Any, None, _T], Awaitable[_T]], *, + name: object = None) -> asyncio.Task: return get_running_loop().create_task(coro) def get_running_loop() -> asyncio.AbstractEventLoop: @@ -614,7 +614,7 @@ async def _run_wrapped_task( self.cancel_scope._tasks.remove(task) del _task_states[task] - def _spawn(self, func: Callable[..., Coroutine], args: tuple, name: Optional[str], + def _spawn(self, func: Callable[..., Coroutine], args: tuple, name: object, task_status_future: Optional[asyncio.Future] = None) -> asyncio.Task: def task_done(_task: asyncio.Task) -> None: # This is the code path for Python 3.8+ @@ -670,10 +670,12 @@ def task_done(_task: asyncio.Task) -> None: self.cancel_scope._tasks.add(task) return task - def start_soon(self, func: Callable[..., Coroutine], *args: object, name: str = None) -> None: + def start_soon(self, func: Callable[..., Coroutine], *args: object, + name: object = None) -> None: self._spawn(func, args, name) - async def start(self, func: Callable[..., Coroutine], *args: object, name: str = None) -> None: + async def start(self, func: Callable[..., Coroutine], *args: object, + name: object = None) -> None: future: asyncio.Future = asyncio.Future() task = self._spawn(func, args, name, future) @@ -824,7 +826,7 @@ def __init__(self) -> None: self._loop = get_running_loop() def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any], - name: Optional[str], future: Future) -> None: + name: object, future: Future) -> None: run_sync_from_thread( partial(self._task_group.start_soon, name=name), self._call_func, func, args, kwargs, future, loop=self._loop) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 5d65e833..b43827ab 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -145,14 +145,14 @@ async def __aexit__(self, exc_type: Optional[Type[BaseException]], finally: self._active = False - def start_soon(self, func: Callable, *args: object, name: str = None) -> None: + def start_soon(self, func: Callable, *args: object, name: object = None) -> None: if not self._active: raise RuntimeError('This task group is not active; no new tasks can be started.') self._nursery.start_soon(func, *args, name=name) async def start(self, func: Callable[..., Coroutine], - *args: object, name: str = None) -> object: + *args: object, name: object = None) -> object: if not self._active: raise RuntimeError('This task group is not active; no new tasks can be started.') @@ -185,7 +185,7 @@ def __init__(self) -> None: self._token = trio.lowlevel.current_trio_token() def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any], - name: Optional[str], future: Future) -> None: + name: object, future: Future) -> None: return trio.from_thread.run_sync( partial(self._task_group.start_soon, name=name), self._call_func, func, args, kwargs, future, trio_token=self._token) diff --git a/src/anyio/abc/_tasks.py b/src/anyio/abc/_tasks.py index 2860d66f..afa2d983 100644 --- a/src/anyio/abc/_tasks.py +++ b/src/anyio/abc/_tasks.py @@ -31,7 +31,7 @@ class TaskGroup(metaclass=ABCMeta): cancel_scope: 'CancelScope' async def spawn(self, func: Callable[..., Coroutine], - *args: object, name: Optional[str] = None) -> None: + *args: object, name: object = None) -> None: """ Start a new task in this task group. @@ -50,7 +50,7 @@ async def spawn(self, func: Callable[..., Coroutine], @abstractmethod def start_soon(self, func: Callable[..., Coroutine], - *args: object, name: Optional[str] = None) -> None: + *args: object, name: object = None) -> None: """ Start a new task in this task group. @@ -63,7 +63,7 @@ def start_soon(self, func: Callable[..., Coroutine], @abstractmethod async def start(self, func: Callable[..., Coroutine], - *args: object, name: Optional[str] = None) -> object: + *args: object, name: object = None) -> object: """ Start a new task and wait until it signals for readiness. From 3f726be069e621681c11ac11831b31449f8c0853 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 18 May 2021 10:42:46 +0100 Subject: [PATCH 19/31] fix type annotation of send_raw_command --- src/anyio/to_process.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anyio/to_process.py b/src/anyio/to_process.py index a06a93f8..5675e9eb 100644 --- a/src/anyio/to_process.py +++ b/src/anyio/to_process.py @@ -43,7 +43,7 @@ async def run_sync( :return: an awaitable that yields the return value of the function. """ - async def send_raw_command(pickled_cmd: bytes) -> T_Retval: + async def send_raw_command(pickled_cmd: bytes) -> object: try: await stdin.send(pickled_cmd) response = await buffered.receive_until(b'\n', 50) @@ -145,7 +145,7 @@ async def send_raw_command(pickled_cmd: bytes) -> T_Retval: with CancelScope(shield=not cancellable): try: - return await send_raw_command(request) + return cast(T_Retval, await send_raw_command(request)) finally: if process in workers: idle_workers.append((process, current_time())) From 686ded77348790a6d3cf9d56fc85c560f3892fb2 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 18 May 2021 10:44:46 +0100 Subject: [PATCH 20/31] avoid redundant forward references to TaskInfo --- src/anyio/_core/_compat.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anyio/_core/_compat.py b/src/anyio/_core/_compat.py index fec5dc78..0ebfeb2d 100644 --- a/src/anyio/_core/_compat.py +++ b/src/anyio/_core/_compat.py @@ -13,11 +13,11 @@ T = TypeVar('T') AnyDeprecatedAwaitable = Union['DeprecatedAwaitable', 'DeprecatedAwaitableFloat', - 'DeprecatedAwaitableList', 'TaskInfo'] + 'DeprecatedAwaitableList', TaskInfo] @overload -async def maybe_async(__obj: 'TaskInfo') -> 'TaskInfo': +async def maybe_async(__obj: TaskInfo) -> TaskInfo: ... From 3f1d35e4201b435e442bd1ead9d0c11087b25408 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 18 May 2021 10:46:44 +0100 Subject: [PATCH 21/31] removed redundant [Any] type parameters --- src/anyio/_core/_streams.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anyio/_core/_streams.py b/src/anyio/_core/_streams.py index f43875c5..4a003bea 100644 --- a/src/anyio/_core/_streams.py +++ b/src/anyio/_core/_streams.py @@ -1,5 +1,5 @@ import math -from typing import Any, Optional, Tuple, Type, TypeVar, overload +from typing import Optional, Tuple, Type, TypeVar, overload from ..streams.memory import ( MemoryObjectReceiveStream, MemoryObjectSendStream, MemoryObjectStreamState) @@ -17,13 +17,13 @@ def create_memory_object_stream( @overload def create_memory_object_stream( max_buffer_size: float = 0 -) -> Tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: +) -> Tuple[MemoryObjectSendStream, MemoryObjectReceiveStream]: ... def create_memory_object_stream( max_buffer_size: float = 0, item_type: Optional[Type[T_Item]] = None -) -> Tuple[MemoryObjectSendStream[Any], MemoryObjectReceiveStream[Any]]: +) -> Tuple[MemoryObjectSendStream, MemoryObjectReceiveStream]: """ Create a memory object stream. From f5543a479f92314aa69cdd66fcc7f2b105fd446f Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Tue, 18 May 2021 22:46:22 +0100 Subject: [PATCH 22/31] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Alex Grönholm --- src/anyio/_core/_testing.py | 1 + src/anyio/abc/_resources.py | 2 +- src/anyio/pytest_plugin.py | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anyio/_core/_testing.py b/src/anyio/_core/_testing.py index 06b4d9c2..c48bd45e 100644 --- a/src/anyio/_core/_testing.py +++ b/src/anyio/_core/_testing.py @@ -41,6 +41,7 @@ def __await__(self) -> Generator[None, None, "TaskInfo"]: _warn_deprecation(self) if False: yield + return self def _unwrap(self) -> 'TaskInfo': diff --git a/src/anyio/abc/_resources.py b/src/anyio/abc/_resources.py index 886366b9..4594e6e9 100644 --- a/src/anyio/abc/_resources.py +++ b/src/anyio/abc/_resources.py @@ -13,7 +13,7 @@ class AsyncResource(metaclass=ABCMeta): :meth:`aclose` on exit. """ - async def __aenter__(self: T) -> "T": + async def __aenter__(self: T) -> T: return self async def __aexit__(self, exc_type: Optional[Type[BaseException]], diff --git a/src/anyio/pytest_plugin.py b/src/anyio/pytest_plugin.py index 79180ec5..0e99a456 100644 --- a/src/anyio/pytest_plugin.py +++ b/src/anyio/pytest_plugin.py @@ -133,6 +133,7 @@ def run_with_hypothesis(**kwargs: Any) -> None: runner.call(pyfuncitem.obj, **testargs) return True + return None From 8ff7a2a8a052623022b4ea1f76c573325ff9d736 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 19 May 2021 10:28:40 +0100 Subject: [PATCH 23/31] remove if TYPE_CHECKING for trio.socket --- src/anyio/_backends/_trio.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index b43827ab..967e74c5 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -8,11 +8,12 @@ from os import PathLike from types import TracebackType from typing import ( - TYPE_CHECKING, Any, Awaitable, Callable, Collection, ContextManager, Coroutine, Dict, Generic, - List, Mapping, NoReturn, Optional, Sequence, Set, Tuple, Type, TypeVar, Union) + Any, Awaitable, Callable, Collection, ContextManager, Coroutine, Dict, Generic, List, Mapping, + NoReturn, Optional, Sequence, Set, Tuple, Type, TypeVar, Union) import trio.from_thread from outcome import Error, Outcome, Value +from trio.socket import SocketType as TrioSocketType from trio.to_thread import run_sync from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc @@ -36,10 +37,6 @@ else: from trio.lowlevel import wait_readable, wait_writable -if TYPE_CHECKING: - from trio.socket import SocketType as TrioSocketType -else: - TrioSocketType = object T_Retval = TypeVar('T_Retval') T_SockAddr = TypeVar('T_SockAddr', str, IPSockAddrType) From 0b479b5150410e24b47e89e18393c681ceba6f16 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 19 May 2021 10:31:23 +0100 Subject: [PATCH 24/31] fix TestRunner._result_queue type --- src/anyio/_backends/_trio.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index 967e74c5..42764907 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -8,8 +8,8 @@ from os import PathLike from types import TracebackType from typing import ( - Any, Awaitable, Callable, Collection, ContextManager, Coroutine, Dict, Generic, List, Mapping, - NoReturn, Optional, Sequence, Set, Tuple, Type, TypeVar, Union) + Any, Awaitable, Callable, Collection, ContextManager, Coroutine, Deque, Dict, Generic, List, + Mapping, NoReturn, Optional, Sequence, Set, Tuple, Type, TypeVar, Union) import trio.from_thread from outcome import Error, Outcome, Value @@ -743,7 +743,7 @@ def __init__(self, **options: object) -> None: from queue import Queue self._call_queue: "Queue[Callable[..., object]]" = Queue() - self._result_queue: Outcome = deque() + self._result_queue: Deque[Outcome] = deque() self._stop_event: Optional[trio.Event] = None self._nursery: Optional[trio.Nursery] = None self._options = options From e46ed143ac2a823b5631d13f1c9166001e7d1a2c Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 19 May 2021 10:38:19 +0100 Subject: [PATCH 25/31] remove cast from RunVar.get and restore RunvarToken identity check --- src/anyio/lowlevel.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/src/anyio/lowlevel.py b/src/anyio/lowlevel.py index 00d08d79..471b7e6b 100644 --- a/src/anyio/lowlevel.py +++ b/src/anyio/lowlevel.py @@ -1,10 +1,16 @@ import enum +import sys from dataclasses import dataclass -from typing import Any, Dict, Generic, Set, TypeVar, Union, cast +from typing import Any, Dict, Generic, Set, TypeVar, Union, overload from weakref import WeakKeyDictionary from ._core._eventloop import get_asynclib +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + T = TypeVar('T') D = TypeVar('D') @@ -73,9 +79,9 @@ class _NoValueSet(enum.Enum): class RunvarToken(Generic[T]): __slots__ = '_var', '_value', '_redeemed' - def __init__(self, var: 'RunVar', value: Union[T, _NoValueSet]): + def __init__(self, var: 'RunVar', value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]]): self._var = var - self._value: Union[T, _NoValueSet] = value + self._value: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = value self._redeemed = False @@ -83,11 +89,12 @@ class RunVar(Generic[T]): """Like a :class:`~contextvars.ContextVar`, expect scoped to the running event loop.""" __slots__ = '_name', '_default' - NO_VALUE_SET = _NoValueSet.NO_VALUE_SET + NO_VALUE_SET: Literal[_NoValueSet.NO_VALUE_SET] = _NoValueSet.NO_VALUE_SET _token_wrappers: Set[_TokenWrapper] = set() - def __init__(self, name: str, default: Union[T, object] = NO_VALUE_SET): + def __init__(self, name: str, + default: Union[T, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET): self._name = name self._default = default @@ -107,14 +114,22 @@ def _current_vars(self) -> Dict[str, T]: run_vars = _run_vars[token] = {} return run_vars - def get(self, default: Union[T, object] = NO_VALUE_SET) -> T: + @overload + def get(self, default: D) -> Union[T, D]: ... + + @overload + def get(self) -> T: ... + + def get( + self, default: Union[D, Literal[_NoValueSet.NO_VALUE_SET]] = NO_VALUE_SET + ) -> Union[T, D]: try: return self._current_vars[self._name] except KeyError: if default is not RunVar.NO_VALUE_SET: - return cast(T, default) + return default elif self._default is not RunVar.NO_VALUE_SET: - return cast(T, self._default) + return self._default raise LookupError(f'Run variable "{self._name}" has no value and no default set') @@ -131,7 +146,7 @@ def reset(self, token: RunvarToken[T]) -> None: if token._redeemed: raise ValueError('This token has already been used') - if isinstance(token._value, _NoValueSet): + if token._value is _NoValueSet.NO_VALUE_SET: try: del self._current_vars[self._name] except KeyError: From f23d8573098b11694312cfcb6fecc16c509398dd Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 19 May 2021 10:50:06 +0100 Subject: [PATCH 26/31] name can still be any object, even if you're calling from a thread --- src/anyio/from_thread.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/anyio/from_thread.py b/src/anyio/from_thread.py index 7a3afc33..29f7d03d 100644 --- a/src/anyio/from_thread.py +++ b/src/anyio/from_thread.py @@ -190,7 +190,7 @@ def callback(f: Future) -> None: scope = None # type: ignore[assignment] def _spawn_task_from_thread(self, func: Callable, args: tuple, kwargs: Dict[str, Any], - name: Optional[str], future: Future) -> None: + name: object, future: Future) -> None: """ Spawn a new task using the given callable. @@ -229,7 +229,7 @@ def call(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval return self.start_task_soon(func, *args).result() def spawn_task(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], - *args: object, name: str = None) -> "Future[T_Retval]": + *args: object, name: object = None) -> "Future[T_Retval]": """ Start a task in the portal's task group. @@ -251,7 +251,7 @@ def spawn_task(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_ return self.start_task_soon(func, *args, name=name) def start_task_soon(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval], T_Retval]], - *args: object, name: str = None) -> "Future[T_Retval]": + *args: object, name: object = None) -> "Future[T_Retval]": """ Start a task in the portal's task group. @@ -275,7 +275,7 @@ def start_task_soon(self, func: Callable[..., Union[Coroutine[Any, Any, T_Retval return f def start_task(self, func: Callable[..., Coroutine], *args: object, - name: str = None) -> Tuple[Future, Any]: + name: object = None) -> Tuple[Future, Any]: """ Start a task in the portal's task group and wait until it signals for readiness. From 89bab0decc40784b137defd698c66973254d4cda Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 19 May 2021 10:50:58 +0100 Subject: [PATCH 27/31] undo dedent of run_sync_in_worker_thread --- src/anyio/to_thread.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anyio/to_thread.py b/src/anyio/to_thread.py index 0f3218e8..5fc95894 100644 --- a/src/anyio/to_thread.py +++ b/src/anyio/to_thread.py @@ -30,7 +30,7 @@ async def run_sync( async def run_sync_in_worker_thread( - func: Callable[..., T_Retval], *args: object, cancellable: bool = False, + func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: Optional[CapacityLimiter] = None) -> T_Retval: warn('run_sync_in_worker_thread() has been deprecated, use anyio.to_thread.run_sync() instead', DeprecationWarning) From f3d7411dbc7904aa774d81ea2beafabecb996d2f Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 19 May 2021 11:03:58 +0100 Subject: [PATCH 28/31] use Mapping[Any, Callable[[], Any] for extra_attributes --- src/anyio/abc/_sockets.py | 5 +++-- src/anyio/streams/buffered.py | 3 ++- src/anyio/streams/file.py | 6 +++--- src/anyio/streams/stapled.py | 8 ++++---- src/anyio/streams/text.py | 8 ++++---- src/anyio/streams/tls.py | 6 +++--- 6 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/anyio/abc/_sockets.py b/src/anyio/abc/_sockets.py index 65a597ec..2e460fbb 100644 --- a/src/anyio/abc/_sockets.py +++ b/src/anyio/abc/_sockets.py @@ -4,7 +4,8 @@ from socket import AddressFamily, SocketType from types import TracebackType from typing import ( - Any, AsyncContextManager, Callable, Collection, List, Optional, Tuple, Type, TypeVar, Union) + Any, AsyncContextManager, Callable, Collection, List, Mapping, Optional, Tuple, Type, TypeVar, + Union) from .._core._typedattr import TypedAttributeProvider, TypedAttributeSet, typed_attribute from ._streams import ByteStream, Listener, T_Stream, UnreliableObjectStream @@ -44,7 +45,7 @@ class SocketAttribute(TypedAttributeSet): class _SocketProvider(TypedAttributeProvider): @property - def extra_attributes(self) -> dict: + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: from .._core._sockets import convert_ipv6_sockaddr as convert attributes = { diff --git a/src/anyio/streams/buffered.py b/src/anyio/streams/buffered.py index 7bbe2085..53ef5176 100644 --- a/src/anyio/streams/buffered.py +++ b/src/anyio/streams/buffered.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, field +from typing import Any, Callable, Mapping from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead from ..abc import AnyByteReceiveStream, ByteReceiveStream @@ -25,7 +26,7 @@ def buffer(self) -> bytes: return bytes(self._buffer) @property - def extra_attributes(self) -> dict: + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.receive_stream.extra_attributes # type: ignore[return-value] async def receive(self, max_bytes: int = 65536) -> bytes: diff --git a/src/anyio/streams/file.py b/src/anyio/streams/file.py index 970fe44c..d3e22bae 100644 --- a/src/anyio/streams/file.py +++ b/src/anyio/streams/file.py @@ -1,6 +1,6 @@ from io import SEEK_SET, UnsupportedOperation from pathlib import Path -from typing import BinaryIO, Union, cast +from typing import Any, BinaryIO, Callable, Dict, Mapping, Union, cast from .. import ( BrokenResourceError, ClosedResourceError, EndOfStream, TypedAttributeSet, to_thread, @@ -25,8 +25,8 @@ async def aclose(self) -> None: await to_thread.run_sync(self._file.close) @property - def extra_attributes(self) -> dict: - attributes: dict = { + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + attributes: Dict[Any, Callable[[], Any]] = { FileStreamAttribute.file: lambda: self._file, } diff --git a/src/anyio/streams/stapled.py b/src/anyio/streams/stapled.py index 771eab6b..0d5e7fb2 100644 --- a/src/anyio/streams/stapled.py +++ b/src/anyio/streams/stapled.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Callable, Generic, List, Optional, Sequence, TypeVar +from typing import Any, Callable, Generic, List, Mapping, Optional, Sequence, TypeVar from ..abc import ( ByteReceiveStream, ByteSendStream, ByteStream, Listener, ObjectReceiveStream, ObjectSendStream, @@ -38,7 +38,7 @@ async def aclose(self) -> None: await self.receive_stream.aclose() @property - def extra_attributes(self) -> dict: + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return {**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes} @@ -71,7 +71,7 @@ async def aclose(self) -> None: await self.receive_stream.aclose() @property - def extra_attributes(self) -> dict: + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return {**self.send_stream.extra_attributes, **self.receive_stream.extra_attributes} @@ -116,7 +116,7 @@ async def aclose(self) -> None: await listener.aclose() @property - def extra_attributes(self) -> dict: + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: attributes: dict = {} for listener in self.listeners: attributes.update(listener.extra_attributes) diff --git a/src/anyio/streams/text.py b/src/anyio/streams/text.py index cfc1cb65..0cdfe227 100644 --- a/src/anyio/streams/text.py +++ b/src/anyio/streams/text.py @@ -1,6 +1,6 @@ import codecs from dataclasses import InitVar, dataclass, field -from typing import Callable, Tuple +from typing import Any, Callable, Mapping, Tuple from ..abc import ( AnyByteReceiveStream, AnyByteSendStream, AnyByteStream, ObjectReceiveStream, ObjectSendStream, @@ -45,7 +45,7 @@ async def aclose(self) -> None: self._decoder.reset() @property - def extra_attributes(self) -> dict: + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.transport_stream.extra_attributes # type: ignore[return-value] @@ -79,7 +79,7 @@ async def aclose(self) -> None: await self.transport_stream.aclose() @property - def extra_attributes(self) -> dict: + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.transport_stream.extra_attributes # type: ignore[return-value] @@ -126,5 +126,5 @@ async def aclose(self) -> None: await self._receive_stream.aclose() @property - def extra_attributes(self) -> dict: + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return {**self._send_stream.extra_attributes, **self._receive_stream.extra_attributes} diff --git a/src/anyio/streams/tls.py b/src/anyio/streams/tls.py index cfb6606e..1d84c05c 100644 --- a/src/anyio/streams/tls.py +++ b/src/anyio/streams/tls.py @@ -3,7 +3,7 @@ import ssl from dataclasses import dataclass from functools import wraps -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, TypeVar, Union from .. import BrokenResourceError, EndOfStream, aclose_forcefully, get_cancelled_exc_class from .._core._typedattr import TypedAttributeSet, typed_attribute @@ -170,7 +170,7 @@ async def send_eof(self) -> None: raise NotImplementedError('send_eof() has not yet been implemented for TLS streams') @property - def extra_attributes(self) -> dict: + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return { **self.transport_stream.extra_attributes, TLSAttribute.alpn_protocol: self._ssl_object.selected_alpn_protocol, @@ -256,7 +256,7 @@ async def aclose(self) -> None: await self.listener.aclose() @property - def extra_attributes(self) -> dict: + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return { TLSAttribute.standard_compatible: lambda: self.standard_compatible, } From 51a5f331dc3bb7c60a92618b3c7305537b1ea323 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 19 May 2021 15:11:06 +0300 Subject: [PATCH 29/31] Updated the version history --- docs/versionhistory.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index d1df910d..5571c89c 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -11,6 +11,7 @@ This library adheres to `Semantic Versioning 2.0 `_. - Changed asyncio task groups so that if the host and child tasks have only raised ``CancelledErrors``, just one ``CancelledError`` will now be raised instead of an ``ExceptionGroup``, allowing asyncio to ignore it when it propagates out of the task +- Changed task names to be converted to ``str`` early on asyncio (PR by Thomas Grainger) - Fixed ``sniffio._impl.AsyncLibraryNotFoundError: unknown async library, or not in async context`` on asyncio and Python 3.6 when ``to_thread.run_sync()`` is used from ``loop.run_until_complete()`` @@ -20,6 +21,15 @@ This library adheres to `Semantic Versioning 2.0 `_. task is cancelled (PR by Thomas Grainger) - Fixed declared return type of ``TaskGroup.start()`` (it was declared as ``None``, but anything can be returned from it) +- Fixed ``TextStream.extra_attributes`` raising ``AttributeError`` (PR by Thomas Grainger) +- Fixed ``await maybe_async(current_task())`` returning ``None`` (PR by Thomas Grainger) +- Fixed: ``pickle.dumps(current_task())`` now correctly raises ``TypeError`` instead of pickling to + ``None`` (PR by Thomas Grainger) +- Fixed return type annotation of ``Event.wait()`` (``bool`` → ``None``) (PR by Thomas Grainger) +- Fixed return type annotation of ``RunVar.get()`` to return either the type of the default value + or the type of the contained value (PR by Thomas Grainger) +- Fixed a deprecation warning message to refer to ``maybe_async()`` and not ``maybe_awaitable()`` + (PR by Thomas Grainger) **3.0.1** From eecb90b5e8847e026e864af640c571dd736b2314 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Wed, 19 May 2021 14:36:22 +0100 Subject: [PATCH 30/31] Update docs/versionhistory.rst regarding untyped defs --- docs/versionhistory.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 5571c89c..e3eee7a1 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -30,6 +30,9 @@ This library adheres to `Semantic Versioning 2.0 `_. or the type of the contained value (PR by Thomas Grainger) - Fixed a deprecation warning message to refer to ``maybe_async()`` and not ``maybe_awaitable()`` (PR by Thomas Grainger) +- Filled in argument and return types for all functions and methods + previously missing them + (PR by Thomas Grainger) **3.0.1** From 0119ad6ecb9f91d371bb41afda8ebd6ce47ef3b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alex=20Gr=C3=B6nholm?= Date: Wed, 19 May 2021 16:39:01 +0300 Subject: [PATCH 31/31] Adjusted formatting --- docs/versionhistory.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index e3eee7a1..d112e3a9 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -30,8 +30,7 @@ This library adheres to `Semantic Versioning 2.0 `_. or the type of the contained value (PR by Thomas Grainger) - Fixed a deprecation warning message to refer to ``maybe_async()`` and not ``maybe_awaitable()`` (PR by Thomas Grainger) -- Filled in argument and return types for all functions and methods - previously missing them +- Filled in argument and return types for all functions and methods previously missing them (PR by Thomas Grainger) **3.0.1**