diff --git a/src/prefect/_internal/concurrency/api.py b/src/prefect/_internal/concurrency/api.py index bcfa6ae189db..f263e61b6def 100644 --- a/src/prefect/_internal/concurrency/api.py +++ b/src/prefect/_internal/concurrency/api.py @@ -6,50 +6,46 @@ import asyncio import concurrent.futures import contextlib -from typing import ( - Any, - Awaitable, - Callable, - ContextManager, - Iterable, - Optional, - TypeVar, - Union, -) +from collections.abc import Awaitable, Iterable +from contextlib import AbstractContextManager +from typing import Any, Callable, Optional, Union, cast -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, TypeAlias, TypeVar from prefect._internal.concurrency.threads import ( WorkerThread, get_global_loop, in_global_loop, ) -from prefect._internal.concurrency.waiters import ( - AsyncWaiter, - Call, - SyncWaiter, -) +from prefect._internal.concurrency.waiters import AsyncWaiter, Call, SyncWaiter P = ParamSpec("P") -T = TypeVar("T") +T = TypeVar("T", infer_variance=True) Future = Union[concurrent.futures.Future[T], asyncio.Future[T]] +_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[T, Awaitable[T]]] -def create_call(__fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Call[T]: + +def create_call( + __fn: _SyncOrAsyncCallable[P, T], *args: P.args, **kwargs: P.kwargs +) -> Call[T]: return Call[T].new(__fn, *args, **kwargs) -def _cast_to_call(call_like: Union[Callable[[], T], Call[T]]) -> Call[T]: +def cast_to_call( + call_like: Union["_SyncOrAsyncCallable[[], T]", Call[T]], +) -> Call[T]: if isinstance(call_like, Call): - return call_like + return cast(Call[T], call_like) else: return create_call(call_like) class _base(abc.ABC): - @abc.abstractstaticmethod + @staticmethod + @abc.abstractmethod def wait_for_call_in_loop_thread( - __call: Union[Callable[[], T], Call[T]], # type: ignore[reportGeneralTypeIssues] + __call: Union["_SyncOrAsyncCallable[[], Any]", Call[T]], timeout: Optional[float] = None, done_callbacks: Optional[Iterable[Call[Any]]] = None, ) -> T: @@ -60,9 +56,10 @@ def wait_for_call_in_loop_thread( """ raise NotImplementedError() - @abc.abstractstaticmethod + @staticmethod + @abc.abstractmethod def wait_for_call_in_new_thread( - __call: Union[Callable[[], T], Call[T]], # type: ignore[reportGeneralTypeIssues] + __call: Union["_SyncOrAsyncCallable[[], T]", Call[T]], timeout: Optional[float] = None, done_callbacks: Optional[Iterable[Call[Any]]] = None, ) -> T: @@ -75,14 +72,15 @@ def wait_for_call_in_new_thread( @staticmethod def call_soon_in_new_thread( - __call: Union[Callable[[], T], Call[T]], timeout: Optional[float] = None + __call: Union["_SyncOrAsyncCallable[[], T]", Call[T]], + timeout: Optional[float] = None, ) -> Call[T]: """ Schedule a call for execution in a new worker thread. Returns the submitted call. """ - call = _cast_to_call(__call) + call = cast_to_call(__call) runner = WorkerThread(run_once=True) call.set_timeout(timeout) runner.submit(call) @@ -90,7 +88,7 @@ def call_soon_in_new_thread( @staticmethod def call_soon_in_loop_thread( - __call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + __call: Union["_SyncOrAsyncCallable[[], T]", Call[T]], timeout: Optional[float] = None, ) -> Call[T]: """ @@ -98,7 +96,7 @@ def call_soon_in_loop_thread( Returns the submitted call. """ - call = _cast_to_call(__call) + call = cast_to_call(__call) runner = get_global_loop() call.set_timeout(timeout) runner.submit(call) @@ -117,7 +115,7 @@ def call_in_new_thread( @staticmethod def call_in_loop_thread( - __call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + __call: Union[Callable[[], Awaitable[T]], Call[T]], timeout: Optional[float] = None, ) -> T: """ @@ -131,12 +129,12 @@ def call_in_loop_thread( class from_async(_base): @staticmethod async def wait_for_call_in_loop_thread( - __call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + __call: Union[Callable[[], Awaitable[T]], Call[T]], timeout: Optional[float] = None, done_callbacks: Optional[Iterable[Call[Any]]] = None, - contexts: Optional[Iterable[ContextManager[Any]]] = None, - ) -> Awaitable[T]: - call = _cast_to_call(__call) + contexts: Optional[Iterable[AbstractContextManager[Any]]] = None, + ) -> T: + call = cast_to_call(__call) waiter = AsyncWaiter(call) for callback in done_callbacks or []: waiter.add_done_callback(callback) @@ -153,7 +151,7 @@ async def wait_for_call_in_new_thread( timeout: Optional[float] = None, done_callbacks: Optional[Iterable[Call[Any]]] = None, ) -> T: - call = _cast_to_call(__call) + call = cast_to_call(__call) waiter = AsyncWaiter(call=call) for callback in done_callbacks or []: waiter.add_done_callback(callback) @@ -170,7 +168,7 @@ def call_in_new_thread( @staticmethod def call_in_loop_thread( - __call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + __call: Union[Callable[[], Awaitable[T]], Call[T]], timeout: Optional[float] = None, ) -> Awaitable[T]: call = _base.call_soon_in_loop_thread(__call, timeout=timeout) @@ -182,13 +180,13 @@ class from_sync(_base): def wait_for_call_in_loop_thread( __call: Union[ Callable[[], Awaitable[T]], - Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + Call[T], ], timeout: Optional[float] = None, - done_callbacks: Optional[Iterable[Call]] = None, - contexts: Optional[Iterable[ContextManager]] = None, - ) -> Awaitable[T]: - call = _cast_to_call(__call) + done_callbacks: Optional[Iterable[Call[T]]] = None, + contexts: Optional[Iterable[AbstractContextManager[Any]]] = None, + ) -> T: + call = cast_to_call(__call) waiter = SyncWaiter(call) _base.call_soon_in_loop_thread(call, timeout=timeout) for callback in done_callbacks or []: @@ -203,9 +201,9 @@ def wait_for_call_in_loop_thread( def wait_for_call_in_new_thread( __call: Union[Callable[[], T], Call[T]], timeout: Optional[float] = None, - done_callbacks: Optional[Iterable[Call]] = None, - ) -> Call[T]: - call = _cast_to_call(__call) + done_callbacks: Optional[Iterable[Call[T]]] = None, + ) -> T: + call = cast_to_call(__call) waiter = SyncWaiter(call=call) for callback in done_callbacks or []: waiter.add_done_callback(callback) @@ -215,20 +213,21 @@ def wait_for_call_in_new_thread( @staticmethod def call_in_new_thread( - __call: Union[Callable[[], T], Call[T]], timeout: Optional[float] = None + __call: Union["_SyncOrAsyncCallable[[], T]", Call[T]], + timeout: Optional[float] = None, ) -> T: call = _base.call_soon_in_new_thread(__call, timeout=timeout) return call.result() @staticmethod def call_in_loop_thread( - __call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + __call: Union["_SyncOrAsyncCallable[[], T]", Call[T]], timeout: Optional[float] = None, - ) -> T: + ) -> Union[Awaitable[T], T]: if in_global_loop(): # Avoid deadlock where the call is submitted to the loop then the loop is # blocked waiting for the call - call = _cast_to_call(__call) + call = cast_to_call(__call) return call() call = _base.call_soon_in_loop_thread(__call, timeout=timeout) diff --git a/src/prefect/_internal/concurrency/calls.py b/src/prefect/_internal/concurrency/calls.py index 5e8b675bd23e..4a715ac90491 100644 --- a/src/prefect/_internal/concurrency/calls.py +++ b/src/prefect/_internal/concurrency/calls.py @@ -12,18 +12,20 @@ import inspect import threading import weakref +from collections.abc import Awaitable, Generator from concurrent.futures._base import ( CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, RUNNING, ) -from typing import Any, Awaitable, Callable, Dict, Generic, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, Union -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar, TypeVarTuple from prefect._internal.concurrency import logger from prefect._internal.concurrency.cancellation import ( + AsyncCancelScope, CancelledError, cancel_async_at, cancel_sync_at, @@ -31,9 +33,13 @@ ) from prefect._internal.concurrency.event_loop import get_running_loop -T = TypeVar("T") +T = TypeVar("T", infer_variance=True) +Ts = TypeVarTuple("Ts") P = ParamSpec("P") +_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[T, Awaitable[T]]] + + # Tracks the current call being executed. Note that storing the `Call` # object for an async call directly in the contextvar appears to create a # memory leak, despite the fact that we `reset` when leaving the context @@ -41,16 +47,16 @@ # we already have strong references to the `Call` objects in other places # and b) this is used for performance optimizations where we have fallback # behavior if this weakref is garbage collected. A fix for issue #10952. -current_call: contextvars.ContextVar["weakref.ref[Call]"] = ( # novm +current_call: contextvars.ContextVar["weakref.ref[Call[Any]]"] = ( # novm contextvars.ContextVar("current_call") ) # Create a strong reference to tasks to prevent destruction during execution errors -_ASYNC_TASK_REFS = set() +_ASYNC_TASK_REFS: set[asyncio.Task[None]] = set() @contextlib.contextmanager -def set_current_call(call: "Call"): +def set_current_call(call: "Call[Any]") -> Generator[None, Any, None]: token = current_call.set(weakref.ref(call)) try: yield @@ -58,7 +64,7 @@ def set_current_call(call: "Call"): current_call.reset(token) -class Future(concurrent.futures.Future): +class Future(concurrent.futures.Future[T]): """ Extension of `concurrent.futures.Future` with support for cancellation of running futures. @@ -70,7 +76,7 @@ def __init__(self, name: Optional[str] = None) -> None: super().__init__() self._cancel_scope = None self._deadline = None - self._cancel_callbacks = [] + self._cancel_callbacks: list[Callable[[], None]] = [] self._name = name self._timed_out = False @@ -79,7 +85,7 @@ def set_running_or_notify_cancel(self, timeout: Optional[float] = None): return super().set_running_or_notify_cancel() @contextlib.contextmanager - def enforce_async_deadline(self): + def enforce_async_deadline(self) -> Generator[AsyncCancelScope]: with cancel_async_at(self._deadline, name=self._name) as self._cancel_scope: for callback in self._cancel_callbacks: self._cancel_scope.add_cancel_callback(callback) @@ -92,7 +98,7 @@ def enforce_sync_deadline(self): self._cancel_scope.add_cancel_callback(callback) yield self._cancel_scope - def add_cancel_callback(self, callback: Callable[[], None]): + def add_cancel_callback(self, callback: Callable[[], Any]) -> None: """ Add a callback to be enforced on cancellation. @@ -113,7 +119,7 @@ def timedout(self) -> bool: with self._condition: return self._timed_out - def cancel(self): + def cancel(self) -> bool: """Cancel the future if possible. Returns True if the future was cancelled, False otherwise. A future cannot be @@ -147,7 +153,12 @@ def cancel(self): self._invoke_callbacks() return True - def result(self, timeout=None): + if TYPE_CHECKING: + + def __get_result(self) -> T: + ... + + def result(self, timeout: Optional[float] = None) -> T: """Return the result of the call that the future represents. Args: @@ -186,7 +197,9 @@ def result(self, timeout=None): # Break a reference cycle with the exception in self._exception self = None - def _invoke_callbacks(self): + _done_callbacks: list[Callable[[Self], object]] + + def _invoke_callbacks(self) -> None: """ Invoke our done callbacks and clean up cancel scopes and cancel callbacks. Fixes a memory leak that hung on to Call objects, @@ -206,7 +219,7 @@ def _invoke_callbacks(self): self._cancel_callbacks = [] if self._cancel_scope: - self._cancel_scope._callbacks = [] + setattr(self._cancel_scope, "_callbacks", []) self._cancel_scope = None @@ -216,16 +229,21 @@ class Call(Generic[T]): A deferred function call. """ - future: Future - fn: Callable[..., T] - args: Tuple - kwargs: Dict[str, Any] + future: Future[T] + fn: "_SyncOrAsyncCallable[..., T]" + args: tuple[Any, ...] + kwargs: dict[str, Any] context: contextvars.Context - timeout: float + timeout: Optional[float] runner: Optional["Portal"] = None @classmethod - def new(cls, __fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "Call[T]": + def new( + cls, + __fn: _SyncOrAsyncCallable[P, T], + *args: P.args, + **kwargs: P.kwargs, + ) -> Self: return cls( future=Future(name=getattr(__fn, "__name__", str(__fn))), fn=__fn, @@ -255,7 +273,7 @@ def set_runner(self, portal: "Portal") -> None: self.runner = portal - def run(self) -> Optional[Awaitable[T]]: + def run(self) -> Optional[Awaitable[None]]: """ Execute the call and place the result on the future. @@ -337,7 +355,7 @@ def timedout(self) -> bool: def cancel(self) -> bool: return self.future.cancel() - def _run_sync(self): + def _run_sync(self) -> Optional[Awaitable[T]]: cancel_scope = None try: with set_current_call(self): @@ -348,8 +366,8 @@ def _run_sync(self): # Forget this call's arguments in order to free up any memory # that may be referenced by them; after a call has happened, # there's no need to keep a reference to them - self.args = None - self.kwargs = None + with contextlib.suppress(AttributeError): + del self.args, self.kwargs # Return the coroutine for async execution if inspect.isawaitable(result): @@ -357,8 +375,10 @@ def _run_sync(self): except CancelledError: # Report cancellation + if TYPE_CHECKING: + assert cancel_scope is not None if cancel_scope.timedout(): - self.future._timed_out = True + setattr(self.future, "_timed_out", True) self.future.cancel() elif cancel_scope.cancelled(): self.future.cancel() @@ -374,8 +394,8 @@ def _run_sync(self): self.future.set_result(result) # noqa: F821 logger.debug("Finished call %r", self) # noqa: F821 - async def _run_async(self, coro): - cancel_scope = None + async def _run_async(self, coro: Awaitable[T]) -> None: + cancel_scope = result = None try: with set_current_call(self): with self.future.enforce_async_deadline() as cancel_scope: @@ -385,12 +405,14 @@ async def _run_async(self, coro): # Forget this call's arguments in order to free up any memory # that may be referenced by them; after a call has happened, # there's no need to keep a reference to them - self.args = None - self.kwargs = None + with contextlib.suppress(AttributeError): + del self.args, self.kwargs except CancelledError: # Report cancellation + if TYPE_CHECKING: + assert cancel_scope is not None if cancel_scope.timedout(): - self.future._timed_out = True + setattr(self.future, "_timed_out", True) self.future.cancel() elif cancel_scope.cancelled(): self.future.cancel() @@ -403,10 +425,11 @@ async def _run_async(self, coro): # Prevent reference cycle in `exc` del self else: + # F821 ignored because Ruff gets confused about the del self above. self.future.set_result(result) # noqa: F821 logger.debug("Finished async call %r", self) # noqa: F821 - def __call__(self) -> T: + def __call__(self) -> Union[T, Awaitable[T]]: """ Execute the call and return its result. @@ -417,7 +440,7 @@ def __call__(self) -> T: # Return an awaitable if in an async context if coro is not None: - async def run_and_return_result(): + async def run_and_return_result() -> T: await coro return self.result() @@ -428,8 +451,9 @@ async def run_and_return_result(): def __repr__(self) -> str: name = getattr(self.fn, "__name__", str(self.fn)) - args, kwargs = self.args, self.kwargs - if args is None or kwargs is None: + try: + args, kwargs = self.args, self.kwargs + except AttributeError: call_args = "" else: call_args = ", ".join( @@ -450,7 +474,7 @@ class Portal(abc.ABC): """ @abc.abstractmethod - def submit(self, call: "Call") -> "Call": + def submit(self, call: "Call[T]") -> "Call[T]": """ Submit a call to execute elsewhere. diff --git a/src/prefect/_internal/concurrency/event_loop.py b/src/prefect/_internal/concurrency/event_loop.py index b3a8c0026056..5b690d53f01b 100644 --- a/src/prefect/_internal/concurrency/event_loop.py +++ b/src/prefect/_internal/concurrency/event_loop.py @@ -5,7 +5,8 @@ import asyncio import concurrent.futures import functools -from typing import Awaitable, Callable, Coroutine, Optional, TypeVar +from collections.abc import Coroutine +from typing import Any, Callable, Optional, TypeVar from typing_extensions import ParamSpec @@ -13,7 +14,7 @@ T = TypeVar("T") -def get_running_loop() -> Optional[asyncio.BaseEventLoop]: +def get_running_loop() -> Optional[asyncio.AbstractEventLoop]: """ Get the current running loop. @@ -30,7 +31,7 @@ def call_soon_in_loop( __fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs, -) -> concurrent.futures.Future: +) -> concurrent.futures.Future[T]: """ Run a synchronous call in an event loop's thread from another thread. @@ -38,7 +39,7 @@ def call_soon_in_loop( Returns a future that can be used to retrieve the result of the call. """ - future = concurrent.futures.Future() + future: concurrent.futures.Future[T] = concurrent.futures.Future() @functools.wraps(__fn) def wrapper() -> None: @@ -62,8 +63,8 @@ def wrapper() -> None: async def run_coroutine_in_loop_from_async( - __loop: asyncio.AbstractEventLoop, __coro: Coroutine -) -> Awaitable: + __loop: asyncio.AbstractEventLoop, __coro: Coroutine[Any, Any, T] +) -> T: """ Run an asynchronous call in an event loop from an asynchronous context. diff --git a/src/prefect/_internal/concurrency/threads.py b/src/prefect/_internal/concurrency/threads.py index 301cc386d513..c3f5e4589428 100644 --- a/src/prefect/_internal/concurrency/threads.py +++ b/src/prefect/_internal/concurrency/threads.py @@ -8,7 +8,9 @@ import itertools import queue import threading -from typing import List, Optional +from typing import Any, Optional + +from typing_extensions import TypeVar from prefect._internal.concurrency import logger from prefect._internal.concurrency.calls import Call, Portal @@ -16,6 +18,8 @@ from prefect._internal.concurrency.event_loop import get_running_loop from prefect._internal.concurrency.primitives import Event +T = TypeVar("T", infer_variance=True) + class WorkerThread(Portal): """ @@ -33,7 +37,7 @@ def __init__( self.thread = threading.Thread( name=name, daemon=daemon, target=self._entrypoint ) - self._queue = queue.Queue() + self._queue: queue.Queue[Optional[Call[Any]]] = queue.Queue() self._run_once: bool = run_once self._started: bool = False self._submitted_count: int = 0 @@ -42,7 +46,7 @@ def __init__( if not daemon: atexit.register(self.shutdown) - def start(self): + def start(self) -> None: """ Start the worker thread. """ @@ -51,7 +55,7 @@ def start(self): self._started = True self.thread.start() - def submit(self, call: Call) -> Call: + def submit(self, call: Call[T]) -> Call[T]: if self._submitted_count > 0 and self._run_once: raise RuntimeError( "Worker configured to only run once. A call has already been submitted." @@ -83,7 +87,7 @@ def shutdown(self) -> None: def name(self) -> str: return self.thread.name - def _entrypoint(self): + def _entrypoint(self) -> None: """ Entrypoint for the thread. """ @@ -129,12 +133,14 @@ def __init__( self.thread = threading.Thread( name=name, daemon=daemon, target=self._entrypoint ) - self._ready_future = concurrent.futures.Future() + self._ready_future: concurrent.futures.Future[ + bool + ] = concurrent.futures.Future() self._loop: Optional[asyncio.AbstractEventLoop] = None self._shutdown_event: Event = Event() self._run_once: bool = run_once self._submitted_count: int = 0 - self._on_shutdown: List[Call] = [] + self._on_shutdown: list[Call[Any]] = [] self._lock = threading.Lock() if not daemon: @@ -149,7 +155,7 @@ def start(self): self.thread.start() self._ready_future.result() - def submit(self, call: Call) -> Call: + def submit(self, call: Call[T]) -> Call[T]: if self._loop is None: self.start() @@ -167,6 +173,7 @@ def submit(self, call: Call) -> Call: call.set_runner(self) # Submit the call to the event loop + assert self._loop is not None asyncio.run_coroutine_threadsafe(self._run_call(call), self._loop) self._submitted_count += 1 @@ -180,15 +187,16 @@ def shutdown(self) -> None: Shutdown the worker thread. Does not wait for the thread to stop. """ with self._lock: - if self._shutdown_event is None: - return - self._shutdown_event.set() @property def name(self) -> str: return self.thread.name + @property + def running(self) -> bool: + return not self._shutdown_event.is_set() + def _entrypoint(self): """ Entrypoint for the thread. @@ -218,12 +226,12 @@ async def _run_until_shutdown(self): # Empty the list to allow calls to be garbage collected. Issue #10338. self._on_shutdown = [] - async def _run_call(self, call: Call) -> None: + async def _run_call(self, call: Call[Any]) -> None: task = call.run() if task is not None: await task - def add_shutdown_call(self, call: Call) -> None: + def add_shutdown_call(self, call: Call[Any]) -> None: self._on_shutdown.append(call) def __enter__(self): @@ -235,9 +243,9 @@ def __exit__(self, *_): # the GLOBAL LOOP is used for background services, like logs -GLOBAL_LOOP: Optional[EventLoopThread] = None +_global_loop: Optional[EventLoopThread] = None # the RUN SYNC LOOP is used exclusively for running async functions in a sync context via asyncutils.run_sync -RUN_SYNC_LOOP: Optional[EventLoopThread] = None +_run_sync_loop: Optional[EventLoopThread] = None def get_global_loop() -> EventLoopThread: @@ -246,29 +254,29 @@ def get_global_loop() -> EventLoopThread: Creates a new one if there is not one available. """ - global GLOBAL_LOOP + global _global_loop # Create a new worker on first call or if the existing worker is dead if ( - GLOBAL_LOOP is None - or not GLOBAL_LOOP.thread.is_alive() - or GLOBAL_LOOP._shutdown_event.is_set() + _global_loop is None + or not _global_loop.thread.is_alive() + or not _global_loop.running ): - GLOBAL_LOOP = EventLoopThread(daemon=True, name="GlobalEventLoopThread") - GLOBAL_LOOP.start() + _global_loop = EventLoopThread(daemon=True, name="GlobalEventLoopThread") + _global_loop.start() - return GLOBAL_LOOP + return _global_loop def in_global_loop() -> bool: """ Check if called from the global loop. """ - if GLOBAL_LOOP is None: + if _global_loop is None: # Avoid creating a global loop if there isn't one return False - return get_global_loop()._loop == get_running_loop() + return getattr(get_global_loop(), "_loop") == get_running_loop() def get_run_sync_loop() -> EventLoopThread: @@ -277,29 +285,29 @@ def get_run_sync_loop() -> EventLoopThread: Creates a new one if there is not one available. """ - global RUN_SYNC_LOOP + global _run_sync_loop # Create a new worker on first call or if the existing worker is dead if ( - RUN_SYNC_LOOP is None - or not RUN_SYNC_LOOP.thread.is_alive() - or RUN_SYNC_LOOP._shutdown_event.is_set() + _run_sync_loop is None + or not _run_sync_loop.thread.is_alive() + or not _run_sync_loop.running ): - RUN_SYNC_LOOP = EventLoopThread(daemon=True, name="RunSyncEventLoopThread") - RUN_SYNC_LOOP.start() + _run_sync_loop = EventLoopThread(daemon=True, name="RunSyncEventLoopThread") + _run_sync_loop.start() - return RUN_SYNC_LOOP + return _run_sync_loop def in_run_sync_loop() -> bool: """ Check if called from the global loop. """ - if RUN_SYNC_LOOP is None: + if _run_sync_loop is None: # Avoid creating a global loop if there isn't one return False - return get_run_sync_loop()._loop == get_running_loop() + return getattr(get_run_sync_loop(), "_loop") == get_running_loop() def wait_for_global_loop_exit(timeout: Optional[float] = None) -> None: diff --git a/src/prefect/_internal/concurrency/waiters.py b/src/prefect/_internal/concurrency/waiters.py index 3925e3b25691..07522992100d 100644 --- a/src/prefect/_internal/concurrency/waiters.py +++ b/src/prefect/_internal/concurrency/waiters.py @@ -10,7 +10,8 @@ import queue import threading from collections import deque -from typing import Awaitable, Generic, List, Optional, TypeVar, Union +from collections.abc import Awaitable +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from weakref import WeakKeyDictionary import anyio @@ -24,12 +25,12 @@ # Waiters are stored in a stack for each thread -_WAITERS_BY_THREAD: "WeakKeyDictionary[threading.Thread, deque[Waiter]]" = ( +_WAITERS_BY_THREAD: "WeakKeyDictionary[threading.Thread, deque[Waiter[Any]]]" = ( WeakKeyDictionary() ) -def add_waiter_for_thread(waiter: "Waiter", thread: threading.Thread): +def add_waiter_for_thread(waiter: "Waiter[Any]", thread: threading.Thread) -> None: """ Add a waiter for a thread. """ @@ -62,7 +63,7 @@ def call_is_done(self) -> bool: return self._call.future.done() @abc.abstractmethod - def wait(self) -> Union[Awaitable[None], None]: + def wait(self) -> T: """ Wait for the call to finish. @@ -71,7 +72,7 @@ def wait(self) -> Union[Awaitable[None], None]: raise NotImplementedError() @abc.abstractmethod - def add_done_callback(self, callback: Call) -> Call: + def add_done_callback(self, callback: Call[Any]) -> None: """ Schedule a call to run when the waiter is done waiting. @@ -91,11 +92,11 @@ class SyncWaiter(Waiter[T]): def __init__(self, call: Call[T]) -> None: super().__init__(call=call) - self._queue: queue.Queue = queue.Queue() - self._done_callbacks = [] + self._queue: queue.Queue[Optional[Call[T]]] = queue.Queue() + self._done_callbacks: list[Call[Any]] = [] self._done_event = threading.Event() - def submit(self, call: Call): + def submit(self, call: Call[T]) -> Call[T]: """ Submit a callback to execute while waiting. """ @@ -109,7 +110,7 @@ def submit(self, call: Call): def _handle_waiting_callbacks(self): logger.debug("Waiter %r watching for callbacks", self) while True: - callback: Call = self._queue.get() + callback = self._queue.get() if callback is None: break @@ -130,13 +131,13 @@ def _handle_done_callbacks(self): if callback: callback.run() - def add_done_callback(self, callback: Call): + def add_done_callback(self, callback: Call[Any]) -> None: if self._done_event.is_set(): raise RuntimeError("Cannot add done callbacks to done waiters.") else: self._done_callbacks.append(callback) - def wait(self) -> T: + def wait(self) -> Call[T]: # Stop watching for work once the future is done self._call.future.add_done_callback(lambda _: self._queue.put_nowait(None)) self._call.future.add_done_callback(lambda _: self._done_event.set()) @@ -159,13 +160,13 @@ def __init__(self, call: Call[T]) -> None: # Delay instantiating loop and queue as there may not be a loop present yet self._loop: Optional[asyncio.AbstractEventLoop] = None - self._queue: Optional[asyncio.Queue] = None - self._early_submissions: List[Call] = [] - self._done_callbacks = [] + self._queue: Optional[asyncio.Queue[Optional[Call[T]]]] = None + self._early_submissions: list[Call[T]] = [] + self._done_callbacks: list[Call[Any]] = [] self._done_event = Event() self._done_waiting = False - def submit(self, call: Call): + def submit(self, call: Call[T]) -> Call[T]: """ Submit a callback to execute while waiting. """ @@ -180,11 +181,15 @@ def submit(self, call: Call): return call # We must put items in the queue from the event loop that owns it + if TYPE_CHECKING: + assert self._loop is not None call_soon_in_loop(self._loop, self._queue.put_nowait, call) return call - def _resubmit_early_submissions(self): - assert self._queue + def _resubmit_early_submissions(self) -> None: + if TYPE_CHECKING: + assert self._queue is not None + assert self._loop is not None for call in self._early_submissions: # We must put items in the queue from the event loop that owns it call_soon_in_loop(self._loop, self._queue.put_nowait, call) @@ -192,11 +197,11 @@ def _resubmit_early_submissions(self): async def _handle_waiting_callbacks(self): logger.debug("Waiter %r watching for callbacks", self) - tasks = [] + tasks: list[Awaitable[None]] = [] try: while True: - callback: Call = await self._queue.get() + callback = await self._queue.get() if callback is None: break @@ -228,12 +233,12 @@ async def _handle_done_callbacks(self): with anyio.CancelScope(shield=True): await self._run_done_callback(callback) - async def _run_done_callback(self, callback: Call): + async def _run_done_callback(self, callback: Call[Any]) -> None: coro = callback.run() if coro: await coro - def add_done_callback(self, callback: Call): + def add_done_callback(self, callback: Call[Any]) -> None: if self._done_event.is_set(): raise RuntimeError("Cannot add done callbacks to done waiters.") else: @@ -243,6 +248,8 @@ def _signal_stop_waiting(self): # Only send a `None` to the queue if the waiter is still blocked reading from # the queue. Otherwise, it's possible that the event loop is stopped. if not self._done_waiting: + assert self._loop is not None + assert self._queue is not None call_soon_in_loop(self._loop, self._queue.put_nowait, None) async def wait(self) -> Call[T]: diff --git a/src/prefect/_internal/pydantic/v1_schema.py b/src/prefect/_internal/pydantic/v1_schema.py index b94b0fc5973a..2fc409ae2625 100644 --- a/src/prefect/_internal/pydantic/v1_schema.py +++ b/src/prefect/_internal/pydantic/v1_schema.py @@ -6,7 +6,7 @@ from pydantic.v1 import BaseModel as V1BaseModel -def is_v1_model(v) -> bool: +def is_v1_model(v: typing.Any) -> bool: with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=pydantic.warnings.PydanticDeprecatedSince20 @@ -23,7 +23,7 @@ def is_v1_model(v) -> bool: return False -def is_v1_type(v) -> bool: +def is_v1_type(v: typing.Any) -> bool: if is_v1_model(v): return True diff --git a/src/prefect/_internal/pydantic/v2_schema.py b/src/prefect/_internal/pydantic/v2_schema.py index c9d7fa08b5e1..da22799c742b 100644 --- a/src/prefect/_internal/pydantic/v2_schema.py +++ b/src/prefect/_internal/pydantic/v2_schema.py @@ -16,7 +16,7 @@ from prefect._internal.pydantic.schemas import GenerateEmptySchemaForUserClasses -def is_v2_model(v) -> bool: +def is_v2_model(v: t.Any) -> bool: if isinstance(v, V2BaseModel): return True try: @@ -28,7 +28,7 @@ def is_v2_model(v) -> bool: return False -def is_v2_type(v) -> bool: +def is_v2_type(v: t.Any) -> bool: if is_v2_model(v): return True @@ -56,9 +56,9 @@ def process_v2_params( param: inspect.Parameter, *, position: int, - docstrings: t.Dict[str, str], - aliases: t.Dict, -) -> t.Tuple[str, t.Any, "pydantic.Field"]: + docstrings: dict[str, str], + aliases: dict[str, str], +) -> tuple[str, t.Any, t.Any]: """ Generate a sanitized name, type, and pydantic.Field for a given parameter. @@ -72,7 +72,7 @@ def process_v2_params( else: name = param.name - type_ = t.Any if param.annotation is inspect._empty else param.annotation + type_ = t.Any if param.annotation is inspect.Parameter.empty else param.annotation # Replace pendulum type annotations with our own so that they are pydantic compatible if type_ == pendulum.DateTime: @@ -95,12 +95,13 @@ def process_v2_params( def create_v2_schema( name_: str, model_cfg: t.Optional[ConfigDict] = None, - model_base: t.Optional[t.Type[V2BaseModel]] = None, - **model_fields, -): + model_base: t.Optional[type[V2BaseModel]] = None, + model_fields: t.Optional[dict[str, t.Any]] = None, +) -> dict[str, t.Any]: """ Create a pydantic v2 model and craft a v1 compatible schema from it. """ + model_fields = model_fields or {} model = create_model( name_, __config__=model_cfg, __base__=model_base, **model_fields ) diff --git a/src/prefect/cli/cloud/__init__.py b/src/prefect/cli/cloud/__init__.py index 5bc9e9ec02c6..f698fbd58160 100644 --- a/src/prefect/cli/cloud/__init__.py +++ b/src/prefect/cli/cloud/__init__.py @@ -41,7 +41,6 @@ ) from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect.utilities.collections import listrepr -from prefect.utilities.compat import raise_signal from pydantic import BaseModel diff --git a/src/prefect/exceptions.py b/src/prefect/exceptions.py index d30a52756cae..fe8b91645b32 100644 --- a/src/prefect/exceptions.py +++ b/src/prefect/exceptions.py @@ -38,7 +38,7 @@ def _trim_traceback( module_path in str(tb.tb_frame.f_globals.get("__file__", "")) for module_path in strip_paths ): - tb = tb.tb_next + curr_tb = curr_tb.tb_next return tb diff --git a/src/prefect/filesystems.py b/src/prefect/filesystems.py index 333665fd5679..97b7ee1e2e26 100644 --- a/src/prefect/filesystems.py +++ b/src/prefect/filesystems.py @@ -1,6 +1,7 @@ import abc import urllib.parse from pathlib import Path +from shutil import copytree from typing import Any, Dict, Optional import anyio @@ -13,7 +14,6 @@ ) from prefect.blocks.core import Block from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible -from prefect.utilities.compat import copytree from prefect.utilities.filesystem import filter_files from ._internal.compatibility.migration import getattr_migration @@ -158,7 +158,7 @@ async def get_directory( copytree(from_path, local_path, dirs_exist_ok=True, ignore=ignore_func) async def _get_ignore_func(self, local_path: str, ignore_file: str): - with open(ignore_file, "r") as f: + with open(ignore_file) as f: ignore_patterns = f.readlines() included_files = filter_files(root=local_path, ignore_patterns=ignore_patterns) @@ -348,7 +348,7 @@ async def put_directory( included_files = None if ignore_file: - with open(ignore_file, "r") as f: + with open(ignore_file) as f: ignore_patterns = f.readlines() included_files = filter_files( diff --git a/src/prefect/utilities/annotations.py b/src/prefect/utilities/annotations.py index 6d9bd73ed475..2e264f334d74 100644 --- a/src/prefect/utilities/annotations.py +++ b/src/prefect/utilities/annotations.py @@ -1,33 +1,40 @@ import warnings -from abc import ABC -from collections import namedtuple -from typing import Generic, TypeVar +from operator import itemgetter +from typing import Any, cast -T = TypeVar("T") +from typing_extensions import Self, TypeVar +T = TypeVar("T", infer_variance=True) -class BaseAnnotation( - namedtuple("BaseAnnotation", field_names="value"), ABC, Generic[T] -): + +class BaseAnnotation(tuple[T]): """ Base class for Prefect annotation types. - Inherits from `namedtuple` for unpacking support in another tools. + Inherits from `tuple` for unpacking support in other tools. """ + __slots__ = () + + def __new__(cls, value: T) -> Self: + return super().__new__(cls, (value,)) + + # use itemgetter to minimise overhead, just like namedtuple generated code would + value: T = cast(T, property(itemgetter(0))) + def unwrap(self) -> T: - return self.value + return self[0] - def rewrap(self, value: T) -> "BaseAnnotation[T]": + def rewrap(self, value: T) -> Self: return type(self)(value) - def __eq__(self, other: "BaseAnnotation[T]") -> bool: + def __eq__(self, other: Any) -> bool: if type(self) is not type(other): return False - return self.unwrap() == other.unwrap() + return super().__eq__(other) def __repr__(self) -> str: - return f"{type(self).__name__}({self.value!r})" + return f"{type(self).__name__}({self[0]!r})" class unmapped(BaseAnnotation[T]): @@ -38,9 +45,9 @@ class unmapped(BaseAnnotation[T]): operation instead of being split. """ - def __getitem__(self, _) -> T: + def __getitem__(self, _: object) -> T: # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] # Internally, this acts as an infinite array where all items are the same value - return self.unwrap() + return super().__getitem__(0) class allow_failure(BaseAnnotation[T]): @@ -87,14 +94,14 @@ def unquote(self) -> T: # Backwards compatibility stub for `Quote` class -class Quote(quote): - def __init__(self, expr): +class Quote(quote[T]): + def __new__(cls, expr: T) -> Self: warnings.warn( "Use of `Quote` is deprecated. Use `quote` instead.", DeprecationWarning, stacklevel=2, ) - super().__init__(expr) + return super().__new__(cls, expr) class NotSet: diff --git a/src/prefect/utilities/asyncutils.py b/src/prefect/utilities/asyncutils.py index ce5a0229b049..e2c1a2472eec 100644 --- a/src/prefect/utilities/asyncutils.py +++ b/src/prefect/utilities/asyncutils.py @@ -6,24 +6,12 @@ import inspect import threading import warnings -from concurrent.futures import ThreadPoolExecutor -from contextlib import asynccontextmanager -from contextvars import ContextVar, copy_context +from collections.abc import AsyncGenerator, Awaitable, Coroutine +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from contextvars import ContextVar from functools import partial, wraps -from typing import ( - Any, - AsyncGenerator, - Awaitable, - Callable, - Coroutine, - Dict, - List, - Optional, - TypeVar, - Union, - cast, - overload, -) +from logging import Logger +from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Union, overload from uuid import UUID, uuid4 import anyio @@ -31,9 +19,18 @@ import anyio.from_thread import anyio.to_thread import sniffio -from typing_extensions import Literal, ParamSpec, TypeGuard +from typing_extensions import ( + Literal, + ParamSpec, + Self, + TypeAlias, + TypeGuard, + TypeVar, + TypeVarTuple, + Unpack, +) -from prefect._internal.concurrency.api import _cast_to_call, from_sync +from prefect._internal.concurrency.api import cast_to_call, from_sync from prefect._internal.concurrency.threads import ( get_run_sync_loop, in_run_sync_loop, @@ -42,62 +39,65 @@ T = TypeVar("T") P = ParamSpec("P") -R = TypeVar("R") +R = TypeVar("R", infer_variance=True) F = TypeVar("F", bound=Callable[..., Any]) Async = Literal[True] Sync = Literal[False] A = TypeVar("A", Async, Sync, covariant=True) +PosArgsT = TypeVarTuple("PosArgsT") + +_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[R, Awaitable[R]]] # Global references to prevent garbage collection for `add_event_loop_shutdown_callback` -EVENT_LOOP_GC_REFS = {} +EVENT_LOOP_GC_REFS: dict[int, AsyncGenerator[None, Any]] = {} -PREFECT_THREAD_LIMITER: Optional[anyio.CapacityLimiter] = None RUNNING_IN_RUN_SYNC_LOOP_FLAG = ContextVar("running_in_run_sync_loop", default=False) RUNNING_ASYNC_FLAG = ContextVar("run_async", default=False) -BACKGROUND_TASKS: set[asyncio.Task] = set() -background_task_lock = threading.Lock() +BACKGROUND_TASKS: set[asyncio.Task[Any]] = set() +background_task_lock: threading.Lock = threading.Lock() # Thread-local storage to keep track of worker thread state _thread_local = threading.local() -logger = get_logger() +logger: Logger = get_logger() + + +_prefect_thread_limiter: Optional[anyio.CapacityLimiter] = None -def get_thread_limiter(): - global PREFECT_THREAD_LIMITER +def get_thread_limiter() -> anyio.CapacityLimiter: + global _prefect_thread_limiter - if PREFECT_THREAD_LIMITER is None: - PREFECT_THREAD_LIMITER = anyio.CapacityLimiter(250) + if _prefect_thread_limiter is None: + _prefect_thread_limiter = anyio.CapacityLimiter(250) - return PREFECT_THREAD_LIMITER + return _prefect_thread_limiter def is_async_fn( - func: Union[Callable[P, R], Callable[P, Awaitable[R]]], + func: _SyncOrAsyncCallable[P, R], ) -> TypeGuard[Callable[P, Awaitable[R]]]: """ Returns `True` if a function returns a coroutine. See https://github.com/microsoft/pyright/issues/2142 for an example use """ - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - + func = inspect.unwrap(func) return asyncio.iscoroutinefunction(func) -def is_async_gen_fn(func): +def is_async_gen_fn( + func: Callable[P, Any], +) -> TypeGuard[Callable[P, AsyncGenerator[Any, Any]]]: """ Returns `True` if a function is an async generator. """ - while hasattr(func, "__wrapped__"): - func = func.__wrapped__ - + func = inspect.unwrap(func) return inspect.isasyncgenfunction(func) -def create_task(coroutine: Coroutine) -> asyncio.Task: +def create_task(coroutine: Coroutine[Any, Any, R]) -> asyncio.Task[R]: """ Replacement for asyncio.create_task that will ensure that tasks aren't garbage collected before they complete. Allows for "fire and forget" @@ -123,68 +123,32 @@ def create_task(coroutine: Coroutine) -> asyncio.Task: return task -def _run_sync_in_new_thread(coroutine: Coroutine[Any, Any, T]) -> T: - """ - Note: this is an OLD implementation of `run_coro_as_sync` which liberally created - new threads and new loops. This works, but prevents sharing any objects - across coroutines, in particular httpx clients, which are very expensive to - instantiate. - - This is here for historical purposes and can be removed if/when it is no - longer needed for reference. - - --- - - Runs a coroutine from a synchronous context. A thread will be spawned to run - the event loop if necessary, which allows coroutines to run in environments - like Jupyter notebooks where the event loop runs on the main thread. - - Args: - coroutine: The coroutine to run. - - Returns: - The return value of the coroutine. - - Example: - Basic usage: ```python async def my_async_function(x: int) -> int: - return x + 1 - - run_sync(my_async_function(1)) ``` - """ +@overload +def run_coro_as_sync( + coroutine: Coroutine[Any, Any, R], + *, + force_new_thread: bool = ..., + wait_for_result: Literal[True] = ..., +) -> R: + ... - # ensure context variables are properly copied to the async frame - async def context_local_wrapper(): - """ - Wrapper that is submitted using copy_context().run to ensure - the RUNNING_ASYNC_FLAG mutations are tightly scoped to this coroutine's frame. - """ - token = RUNNING_ASYNC_FLAG.set(True) - try: - result = await coroutine - finally: - RUNNING_ASYNC_FLAG.reset(token) - return result - context = copy_context() - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - if loop and loop.is_running(): - with ThreadPoolExecutor() as executor: - future = executor.submit(context.run, asyncio.run, context_local_wrapper()) - result = cast(T, future.result()) - else: - result = context.run(asyncio.run, context_local_wrapper()) - return result +@overload +def run_coro_as_sync( + coroutine: Coroutine[Any, Any, R], + *, + force_new_thread: bool = ..., + wait_for_result: Literal[False] = False, +) -> R: + ... def run_coro_as_sync( - coroutine: Awaitable[R], + coroutine: Coroutine[Any, Any, R], + *, force_new_thread: bool = False, wait_for_result: bool = True, -) -> Union[R, None]: +) -> Optional[R]: """ Runs a coroutine from a synchronous context, as if it were a synchronous function. @@ -211,7 +175,7 @@ def run_coro_as_sync( The result of the coroutine if wait_for_result is True, otherwise None. """ - async def coroutine_wrapper() -> Union[R, None]: + async def coroutine_wrapper() -> Optional[R]: """ Set flags so that children (and grandchildren...) of this task know they are running in a new thread and do not try to run on the run_sync thread, which would cause a @@ -232,12 +196,13 @@ async def coroutine_wrapper() -> Union[R, None]: # that is running in the run_sync loop, we need to run this coroutine in a # new thread if in_run_sync_loop() or RUNNING_IN_RUN_SYNC_LOOP_FLAG.get() or force_new_thread: - return from_sync.call_in_new_thread(coroutine_wrapper) + result = from_sync.call_in_new_thread(coroutine_wrapper) + return result # otherwise, we can run the coroutine in the run_sync loop # and wait for the result else: - call = _cast_to_call(coroutine_wrapper) + call = cast_to_call(coroutine_wrapper) runner = get_run_sync_loop() runner.submit(call) try: @@ -250,8 +215,8 @@ async def coroutine_wrapper() -> Union[R, None]: async def run_sync_in_worker_thread( - __fn: Callable[..., T], *args: Any, **kwargs: Any -) -> T: + __fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs +) -> R: """ Runs a sync function in a new worker thread so that the main thread's event loop is not blocked. @@ -275,14 +240,14 @@ async def run_sync_in_worker_thread( RUNNING_ASYNC_FLAG.reset(token) -def call_with_mark(call): +def call_with_mark(call: Callable[..., R]) -> R: mark_as_worker_thread() return call() def run_async_from_worker_thread( - __fn: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any -) -> T: + __fn: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs +) -> R: """ Runs an async function in the main thread's event loop, blocking the worker thread until completion @@ -291,11 +256,13 @@ def run_async_from_worker_thread( return anyio.from_thread.run(call) -def run_async_in_new_loop(__fn: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any): +def run_async_in_new_loop( + __fn: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs +) -> R: return anyio.run(partial(__fn, *args, **kwargs)) -def mark_as_worker_thread(): +def mark_as_worker_thread() -> None: _thread_local.is_worker_thread = True @@ -313,23 +280,9 @@ def in_async_main_thread() -> bool: return not in_async_worker_thread() -@overload -def sync_compatible( - async_fn: Callable[..., Coroutine[Any, Any, R]], -) -> Callable[..., R]: - ... - - -@overload def sync_compatible( - async_fn: Callable[..., Coroutine[Any, Any, R]], -) -> Callable[..., Coroutine[Any, Any, R]]: - ... - - -def sync_compatible( - async_fn: Callable[..., Coroutine[Any, Any, R]], -) -> Callable[..., Union[R, Coroutine[Any, Any, R]]]: + async_fn: Callable[P, Coroutine[Any, Any, R]], +) -> Callable[P, Union[R, Coroutine[Any, Any, R]]]: """ Converts an async function into a dual async and sync function. @@ -394,7 +347,7 @@ async def ctx_call(): if _sync is True: return run_coro_as_sync(ctx_call()) - elif _sync is False or RUNNING_ASYNC_FLAG.get() or is_async: + elif RUNNING_ASYNC_FLAG.get() or is_async: return ctx_call() else: return run_coro_as_sync(ctx_call()) @@ -410,10 +363,24 @@ async def ctx_call(): return wrapper +@overload +def asyncnullcontext( + value: None = None, *args: Any, **kwargs: Any +) -> AbstractAsyncContextManager[None, None]: + ... + + +@overload +def asyncnullcontext( + value: R, *args: Any, **kwargs: Any +) -> AbstractAsyncContextManager[R, None]: + ... + + @asynccontextmanager async def asyncnullcontext( - value: Optional[Any] = None, *args: Any, **kwargs: Any -) -> AsyncGenerator[Any, None]: + value: Optional[R] = None, *args: Any, **kwargs: Any +) -> AsyncGenerator[Any, Optional[R]]: yield value @@ -429,7 +396,7 @@ def sync(__async_fn: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwarg "`sync` called from an asynchronous context; " "you should `await` the async function directly instead." ) - with anyio.start_blocking_portal() as portal: + with anyio.from_thread.start_blocking_portal() as portal: return portal.call(partial(__async_fn, *args, **kwargs)) elif in_async_worker_thread(): # In a sync context but we can access the event loop thread; send the async @@ -441,7 +408,9 @@ def sync(__async_fn: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwarg return run_async_in_new_loop(__async_fn, *args, **kwargs) -async def add_event_loop_shutdown_callback(coroutine_fn: Callable[[], Awaitable]): +async def add_event_loop_shutdown_callback( + coroutine_fn: Callable[[], Awaitable[Any]], +) -> None: """ Adds a callback to the given callable on event loop closure. The callable must be a coroutine function. It will be awaited when the current event loop is shutting @@ -457,7 +426,7 @@ async def add_event_loop_shutdown_callback(coroutine_fn: Callable[[], Awaitable] loop is about to close. """ - async def on_shutdown(key): + async def on_shutdown(key: int) -> AsyncGenerator[None, Any]: # It appears that EVENT_LOOP_GC_REFS is somehow being garbage collected early. # We hold a reference to it so as to preserve it, at least for the lifetime of # this coroutine. See the issue below for the initial report/discussion: @@ -496,7 +465,7 @@ class GatherTaskGroup(anyio.abc.TaskGroup): """ A task group that gathers results. - AnyIO does not include support `gather`. This class extends the `TaskGroup` + AnyIO does not include `gather` support. This class extends the `TaskGroup` interface to allow simple gathering. See https://github.com/agronholm/anyio/issues/100 @@ -505,21 +474,31 @@ class GatherTaskGroup(anyio.abc.TaskGroup): """ def __init__(self, task_group: anyio.abc.TaskGroup): - self._results: Dict[UUID, Any] = {} + self._results: dict[UUID, Any] = {} # The concrete task group implementation to use self._task_group: anyio.abc.TaskGroup = task_group - async def _run_and_store(self, key, fn, args): + async def _run_and_store( + self, + key: UUID, + fn: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], + ) -> None: self._results[key] = await fn(*args) - def start_soon(self, fn, *args) -> UUID: + def start_soon( # pyright: ignore[reportIncompatibleMethodOverride] + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], + name: object = None, + ) -> UUID: key = uuid4() # Put a placeholder in-case the result is retrieved earlier self._results[key] = GatherIncomplete - self._task_group.start_soon(self._run_and_store, key, fn, args) + self._task_group.start_soon(self._run_and_store, key, func, *args, name=name) return key - async def start(self, fn, *args): + async def start(self, func: object, *args: object, name: object = None) -> NoReturn: """ Since `start` returns the result of `task_status.started()` but here we must return the key instead, we just won't support this method for now. @@ -535,11 +514,11 @@ def get_result(self, key: UUID) -> Any: ) return result - async def __aenter__(self): + async def __aenter__(self) -> Self: await self._task_group.__aenter__() return self - async def __aexit__(self, *tb): + async def __aexit__(self, *tb: Any) -> Optional[bool]: try: retval = await self._task_group.__aexit__(*tb) return retval @@ -555,14 +534,14 @@ def create_gather_task_group() -> GatherTaskGroup: return GatherTaskGroup(anyio.create_task_group()) -async def gather(*calls: Callable[[], Coroutine[Any, Any, T]]) -> List[T]: +async def gather(*calls: Callable[[], Coroutine[Any, Any, T]]) -> list[T]: """ Run calls concurrently and gather their results. Unlike `asyncio.gather` this expects to receive _callables_ not _coroutines_. This matches `anyio` semantics. """ - keys = [] + keys: list[UUID] = [] async with create_gather_task_group() as tg: for call in calls: keys.append(tg.start_soon(call)) @@ -570,19 +549,23 @@ async def gather(*calls: Callable[[], Coroutine[Any, Any, T]]) -> List[T]: class LazySemaphore: - def __init__(self, initial_value_func): - self._semaphore = None + def __init__(self, initial_value_func: Callable[[], int]) -> None: + self._semaphore: Optional[asyncio.Semaphore] = None self._initial_value_func = initial_value_func - async def __aenter__(self): + async def __aenter__(self) -> asyncio.Semaphore: self._initialize_semaphore() + if TYPE_CHECKING: + assert self._semaphore is not None await self._semaphore.__aenter__() return self._semaphore - async def __aexit__(self, exc_type, exc, tb): - await self._semaphore.__aexit__(exc_type, exc, tb) + async def __aexit__(self, *args: Any) -> None: + if TYPE_CHECKING: + assert self._semaphore is not None + await self._semaphore.__aexit__(*args) - def _initialize_semaphore(self): + def _initialize_semaphore(self) -> None: if self._semaphore is None: initial_value = self._initial_value_func() self._semaphore = asyncio.Semaphore(initial_value) diff --git a/src/prefect/utilities/callables.py b/src/prefect/utilities/callables.py index 382f5cd6a224..decc4e59f362 100644 --- a/src/prefect/utilities/callables.py +++ b/src/prefect/utilities/callables.py @@ -6,14 +6,17 @@ import importlib.util import inspect import warnings +from collections import OrderedDict +from collections.abc import Iterable from functools import partial +from logging import Logger from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Optional, Union, cast import cloudpickle import pydantic from griffe import Docstring, DocstringSectionKind, Parser, parse -from typing_extensions import Literal +from typing_extensions import Literal, TypeVar from prefect._internal.pydantic.v1_schema import has_v1_type_as_param from prefect._internal.pydantic.v2_schema import ( @@ -32,15 +35,17 @@ from prefect.utilities.collections import isiterable from prefect.utilities.importtools import safe_load_namespace -logger = get_logger(__name__) +logger: Logger = get_logger(__name__) + +R = TypeVar("R", infer_variance=True) def get_call_parameters( - fn: Callable, - call_args: Tuple[Any, ...], - call_kwargs: Dict[str, Any], + fn: Callable[..., Any], + call_args: tuple[Any, ...], + call_kwargs: dict[str, Any], apply_defaults: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Bind a call to a function to get parameter/value mapping. Default values on the signature will be included if not overridden. @@ -57,7 +62,7 @@ def get_call_parameters( function """ if hasattr(fn, "__prefect_self__"): - call_args = (fn.__prefect_self__,) + call_args + call_args = (getattr(fn, "__prefect_self__"), *call_args) try: bound_signature = inspect.signature(fn).bind(*call_args, **call_kwargs) @@ -74,14 +79,14 @@ def get_call_parameters( def get_parameter_defaults( - fn: Callable, -) -> Dict[str, Any]: + fn: Callable[..., Any], +) -> dict[str, Any]: """ Get default parameter values for a callable. """ signature = inspect.signature(fn) - parameter_defaults = {} + parameter_defaults: dict[str, Any] = {} for name, param in signature.parameters.items(): if param.default is not signature.empty: @@ -91,8 +96,8 @@ def get_parameter_defaults( def explode_variadic_parameter( - fn: Callable, parameters: Dict[str, Any] -) -> Dict[str, Any]: + fn: Callable[..., Any], parameters: dict[str, Any] +) -> dict[str, Any]: """ Given a parameter dictionary, move any parameters stored in a variadic keyword argument parameter (i.e. **kwargs) into the top level. @@ -125,8 +130,8 @@ def foo(a, b, **kwargs): def collapse_variadic_parameters( - fn: Callable, parameters: Dict[str, Any] -) -> Dict[str, Any]: + fn: Callable[..., Any], parameters: dict[str, Any] +) -> dict[str, Any]: """ Given a parameter dictionary, move any parameters stored not present in the signature into the variadic keyword argument. @@ -151,50 +156,47 @@ def foo(a, b, **kwargs): missing_parameters = set(parameters.keys()) - set(signature_parameters.keys()) - if not variadic_key and missing_parameters: + if not missing_parameters: + # no missing parameters, return parameters unchanged + return parameters + + if not variadic_key: raise ValueError( f"Signature for {fn} does not include any variadic keyword argument " "but parameters were given that are not present in the signature." ) - if variadic_key and not missing_parameters: - # variadic key is present but no missing parameters, return parameters unchanged - return parameters - new_parameters = parameters.copy() - if variadic_key: - new_parameters[variadic_key] = {} - - for key in missing_parameters: - new_parameters[variadic_key][key] = new_parameters.pop(key) - + new_parameters[variadic_key] = { + key: new_parameters.pop(key) for key in missing_parameters + } return new_parameters def parameters_to_args_kwargs( - fn: Callable, - parameters: Dict[str, Any], -) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + fn: Callable[..., Any], + parameters: dict[str, Any], +) -> tuple[tuple[Any, ...], dict[str, Any]]: """ Convert a `parameters` dictionary to positional and keyword arguments The function _must_ have an identical signature to the original function or this will return an empty tuple and dict. """ - function_params = dict(inspect.signature(fn).parameters).keys() + function_params = inspect.signature(fn).parameters.keys() # Check for parameters that are not present in the function signature unknown_params = parameters.keys() - function_params if unknown_params: raise SignatureMismatchError.from_bad_params( - list(function_params), list(parameters.keys()) + list(function_params), list(parameters) ) bound_signature = inspect.signature(fn).bind_partial() - bound_signature.arguments = parameters + bound_signature.arguments = OrderedDict(parameters) return bound_signature.args, bound_signature.kwargs -def call_with_parameters(fn: Callable, parameters: Dict[str, Any]): +def call_with_parameters(fn: Callable[..., R], parameters: dict[str, Any]) -> R: """ Call a function with parameters extracted with `get_call_parameters` @@ -207,7 +209,7 @@ def call_with_parameters(fn: Callable, parameters: Dict[str, Any]): def cloudpickle_wrapped_call( - __fn: Callable, *args: Any, **kwargs: Any + __fn: Callable[..., Any], *args: Any, **kwargs: Any ) -> Callable[[], bytes]: """ Serializes a function call using cloudpickle then returns a callable which will @@ -221,7 +223,7 @@ def cloudpickle_wrapped_call( return partial(_run_serialized_call, payload) -def _run_serialized_call(payload) -> bytes: +def _run_serialized_call(payload: bytes) -> bytes: """ Defined at the top-level so it can be pickled by the Python pickler. Used by `cloudpickle_wrapped_call`. @@ -236,18 +238,18 @@ class ParameterSchema(pydantic.BaseModel): title: Literal["Parameters"] = "Parameters" type: Literal["object"] = "object" - properties: Dict[str, Any] = pydantic.Field(default_factory=dict) - required: List[str] = pydantic.Field(default_factory=list) - definitions: Dict[str, Any] = pydantic.Field(default_factory=dict) + properties: dict[str, Any] = pydantic.Field(default_factory=dict) + required: list[str] = pydantic.Field(default_factory=list) + definitions: dict[str, Any] = pydantic.Field(default_factory=dict) - def model_dump_for_openapi(self) -> Dict[str, Any]: + def model_dump_for_openapi(self) -> dict[str, Any]: result = self.model_dump(mode="python", exclude_none=True) if "required" in result and not result["required"]: del result["required"] return result -def parameter_docstrings(docstring: Optional[str]) -> Dict[str, str]: +def parameter_docstrings(docstring: Optional[str]) -> dict[str, str]: """ Given a docstring in Google docstring format, parse the parameter section and return a dictionary that maps parameter names to docstring. @@ -258,7 +260,7 @@ def parameter_docstrings(docstring: Optional[str]) -> Dict[str, str]: Returns: Mapping from parameter names to docstrings. """ - param_docstrings = {} + param_docstrings: dict[str, str] = {} if not docstring: return param_docstrings @@ -279,9 +281,9 @@ def process_v1_params( param: inspect.Parameter, *, position: int, - docstrings: Dict[str, str], - aliases: Dict, -) -> Tuple[str, Any, "pydantic.Field"]: + docstrings: dict[str, str], + aliases: dict[str, str], +) -> tuple[str, Any, Any]: # Pydantic model creation will fail if names collide with the BaseModel type if hasattr(pydantic.BaseModel, param.name): name = param.name + "__" @@ -289,13 +291,13 @@ def process_v1_params( else: name = param.name - type_ = Any if param.annotation is inspect._empty else param.annotation + type_ = Any if param.annotation is inspect.Parameter.empty else param.annotation with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=pydantic.warnings.PydanticDeprecatedSince20 ) - field = pydantic.Field( + field: Any = pydantic.Field( default=... if param.default is param.empty else param.default, title=param.name, description=docstrings.get(param.name, None), @@ -305,19 +307,22 @@ def process_v1_params( return name, type_, field -def create_v1_schema(name_: str, model_cfg, **model_fields): +def create_v1_schema( + name_: str, model_cfg: type[Any], model_fields: Optional[dict[str, Any]] = None +) -> dict[str, Any]: with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=pydantic.warnings.PydanticDeprecatedSince20 ) - model: "pydantic.BaseModel" = pydantic.create_model( + model_fields = model_fields or {} + model: type[pydantic.BaseModel] = pydantic.create_model( name_, __config__=model_cfg, **model_fields ) return model.schema(by_alias=True) -def parameter_schema(fn: Callable) -> ParameterSchema: +def parameter_schema(fn: Callable[..., Any]) -> ParameterSchema: """Given a function, generates an OpenAPI-compatible description of the function's arguments, including: - name @@ -378,7 +383,7 @@ def parameter_schema_from_entrypoint(entrypoint: str) -> ParameterSchema: def generate_parameter_schema( - signature: inspect.Signature, docstrings: Dict[str, str] + signature: inspect.Signature, docstrings: dict[str, str] ) -> ParameterSchema: """ Generate a parameter schema from a function signature and docstrings. @@ -394,22 +399,22 @@ def generate_parameter_schema( ParameterSchema: The parameter schema. """ - model_fields = {} - aliases = {} + model_fields: dict[str, Any] = {} + aliases: dict[str, str] = {} if not has_v1_type_as_param(signature): - create_schema = create_v2_schema + config = pydantic.ConfigDict(arbitrary_types_allowed=True) + + create_schema = partial(create_v2_schema, model_cfg=config) process_params = process_v2_params - config = pydantic.ConfigDict(arbitrary_types_allowed=True) else: - create_schema = create_v1_schema - process_params = process_v1_params class ModelConfig: arbitrary_types_allowed = True - config = ModelConfig + create_schema = partial(create_v1_schema, model_cfg=ModelConfig) + process_params = process_v1_params for position, param in enumerate(signature.parameters.values()): name, type_, field = process_params( @@ -418,24 +423,26 @@ class ModelConfig: # Generate a Pydantic model at each step so we can check if this parameter # type supports schema generation try: - create_schema("CheckParameter", model_cfg=config, **{name: (type_, field)}) + create_schema("CheckParameter", model_fields={name: (type_, field)}) except (ValueError, TypeError): # This field's type is not valid for schema creation, update it to `Any` type_ = Any model_fields[name] = (type_, field) # Generate the final model and schema - schema = create_schema("Parameters", model_cfg=config, **model_fields) + schema = create_schema("Parameters", model_fields=model_fields) return ParameterSchema(**schema) -def raise_for_reserved_arguments(fn: Callable, reserved_arguments: Iterable[str]): +def raise_for_reserved_arguments( + fn: Callable[..., Any], reserved_arguments: Iterable[str] +) -> None: """Raise a ReservedArgumentError if `fn` has any parameters that conflict with the names contained in `reserved_arguments`.""" - function_paremeters = inspect.signature(fn).parameters + function_parameters = inspect.signature(fn).parameters for argument in reserved_arguments: - if argument in function_paremeters: + if argument in function_parameters: raise ReservedArgumentError( f"{argument!r} is a reserved argument name and cannot be used." ) @@ -479,7 +486,7 @@ def _generate_signature_from_source( ) if func_def is None: raise ValueError(f"Function {func_name} not found in source code") - parameters = [] + parameters: list[inspect.Parameter] = [] # Handle annotations for positional only args e.g. def func(a, /, b, c) for arg in func_def.args.posonlyargs: @@ -642,8 +649,8 @@ def _get_docstring_from_source(source_code: str, func_name: str) -> Optional[str def expand_mapping_parameters( - func: Callable, parameters: Dict[str, Any] -) -> List[Dict[str, Any]]: + func: Callable[..., Any], parameters: dict[str, Any] +) -> list[dict[str, Any]]: """ Generates a list of call parameters to be used for individual calls in a mapping operation. @@ -653,29 +660,29 @@ def expand_mapping_parameters( parameters: A dictionary of parameters with iterables to be mapped over Returns: - List: A list of dictionaries to be used as parameters for each + list: A list of dictionaries to be used as parameters for each call in the mapping operation """ # Ensure that any parameters in kwargs are expanded before this check parameters = explode_variadic_parameter(func, parameters) - iterable_parameters = {} - static_parameters = {} - annotated_parameters = {} + iterable_parameters: dict[str, list[Any]] = {} + static_parameters: dict[str, Any] = {} + annotated_parameters: dict[str, Union[allow_failure[Any], quote[Any]]] = {} for key, val in parameters.items(): if isinstance(val, (allow_failure, quote)): # Unwrap annotated parameters to determine if they are iterable annotated_parameters[key] = val - val = val.unwrap() + val: Any = val.unwrap() if isinstance(val, unmapped): - static_parameters[key] = val.value + static_parameters[key] = cast(unmapped[Any], val).value elif isiterable(val): iterable_parameters[key] = list(val) else: static_parameters[key] = val - if not len(iterable_parameters): + if not iterable_parameters: raise MappingMissingIterable( "No iterable parameters were received. Parameters for map must " f"include at least one iterable. Parameters: {parameters}" @@ -693,7 +700,7 @@ def expand_mapping_parameters( map_length = list(lengths)[0] - call_parameters_list = [] + call_parameters_list: list[dict[str, Any]] = [] for i in range(map_length): call_parameters = {key: value[i] for key, value in iterable_parameters.items()} call_parameters.update({key: value for key, value in static_parameters.items()}) diff --git a/src/prefect/utilities/collections.py b/src/prefect/utilities/collections.py index 5b2ae453c67e..3588b6f48a73 100644 --- a/src/prefect/utilities/collections.py +++ b/src/prefect/utilities/collections.py @@ -6,33 +6,40 @@ import itertools import types import warnings -from collections import OrderedDict, defaultdict -from collections.abc import Iterator as IteratorABC -from collections.abc import Sequence -from dataclasses import fields, is_dataclass -from enum import Enum, auto -from typing import ( - Any, +from collections import OrderedDict +from collections.abc import ( Callable, - Dict, + Collection, Generator, Hashable, Iterable, - List, - Optional, + Iterator, + Sequence, Set, - Tuple, - Type, - TypeVar, +) +from dataclasses import fields, is_dataclass, replace +from enum import Enum, auto +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Optional, Union, cast, + overload, ) from unittest.mock import Mock import pydantic +from typing_extensions import TypeAlias, TypeVar # Quote moved to `prefect.utilities.annotations` but preserved here for compatibility -from prefect.utilities.annotations import BaseAnnotation, Quote, quote # noqa +from prefect.utilities.annotations import BaseAnnotation as BaseAnnotation +from prefect.utilities.annotations import Quote as Quote +from prefect.utilities.annotations import quote as quote + +if TYPE_CHECKING: + pass class AutoEnum(str, Enum): @@ -55,11 +62,12 @@ class MyEnum(AutoEnum): ``` """ - def _generate_next_value_(name, start, count, last_values): + @staticmethod + def _generate_next_value_(name: str, *_: object, **__: object) -> str: return name @staticmethod - def auto(): + def auto() -> str: """ Exposes `enum.auto()` to avoid requiring a second import to use `AutoEnum` """ @@ -70,12 +78,15 @@ def __repr__(self) -> str: KT = TypeVar("KT") -VT = TypeVar("VT") +VT = TypeVar("VT", infer_variance=True) +VT1 = TypeVar("VT1", infer_variance=True) +VT2 = TypeVar("VT2", infer_variance=True) +R = TypeVar("R", infer_variance=True) +NestedDict: TypeAlias = dict[KT, Union[VT, "NestedDict[KT, VT]"]] +HashableT = TypeVar("HashableT", bound=Hashable) -def dict_to_flatdict( - dct: Dict[KT, Union[Any, Dict[KT, Any]]], _parent: Tuple[KT, ...] = None -) -> Dict[Tuple[KT, ...], Any]: +def dict_to_flatdict(dct: NestedDict[KT, VT]) -> dict[tuple[KT, ...], VT]: """Converts a (nested) dictionary to a flattened representation. Each key of the flat dict will be a CompoundKey tuple containing the "chain of keys" @@ -83,28 +94,28 @@ def dict_to_flatdict( Args: dct (dict): The dictionary to flatten - _parent (Tuple, optional): The current parent for recursion Returns: A flattened dict of the same type as dct """ - typ = cast(Type[Dict[Tuple[KT, ...], Any]], type(dct)) - items: List[Tuple[Tuple[KT, ...], Any]] = [] - parent = _parent or tuple() - - for k, v in dct.items(): - k_parent = tuple(parent + (k,)) - # if v is a non-empty dict, recurse - if isinstance(v, dict) and v: - items.extend(dict_to_flatdict(v, _parent=k_parent).items()) - else: - items.append((k_parent, v)) - return typ(items) + def flatten( + dct: NestedDict[KT, VT], _parent: tuple[KT, ...] = () + ) -> Iterator[tuple[tuple[KT, ...], VT]]: + parent = _parent or () + for k, v in dct.items(): + k_parent = (*parent, k) + # if v is a non-empty dict, recurse + if isinstance(v, dict) and v: + yield from flatten(cast(NestedDict[KT, VT], v), _parent=k_parent) + else: + yield (k_parent, cast(VT, v)) -def flatdict_to_dict( - dct: Dict[Tuple[KT, ...], VT], -) -> Dict[KT, Union[VT, Dict[KT, VT]]]: + type_ = cast(type[dict[tuple[KT, ...], VT]], type(dct)) + return type_(flatten(dct)) + + +def flatdict_to_dict(dct: dict[tuple[KT, ...], VT]) -> NestedDict[KT, VT]: """Converts a flattened dictionary back to a nested dictionary. Args: @@ -114,16 +125,26 @@ def flatdict_to_dict( Returns A nested dict of the same type as dct """ - typ = type(dct) - result = cast(Dict[KT, Union[VT, Dict[KT, VT]]], typ()) + + type_ = cast(type[NestedDict[KT, VT]], type(dct)) + + def new(type_: type[NestedDict[KT, VT]] = type_) -> NestedDict[KT, VT]: + return type_() + + result = new() for key_tuple, value in dct.items(): - current_dict = result - for prefix_key in key_tuple[:-1]: + current = result + *prefix_keys, last_key = key_tuple + for prefix_key in prefix_keys: # Build nested dictionaries up for the current key tuple - # Use `setdefault` in case the nested dict has already been created - current_dict = current_dict.setdefault(prefix_key, typ()) # type: ignore + try: + current = cast(NestedDict[KT, VT], current[prefix_key]) + except KeyError: + new_dict = current[prefix_key] = new() + current = new_dict + # Set the value - current_dict[key_tuple[-1]] = value + current[last_key] = value return result @@ -148,9 +169,9 @@ def isiterable(obj: Any) -> bool: return not isinstance(obj, (str, bytes, io.IOBase)) -def ensure_iterable(obj: Union[T, Iterable[T]]) -> Iterable[T]: +def ensure_iterable(obj: Union[T, Iterable[T]]) -> Collection[T]: if isinstance(obj, Sequence) or isinstance(obj, Set): - return obj + return cast(Collection[T], obj) obj = cast(T, obj) # No longer in the iterable case return [obj] @@ -160,9 +181,9 @@ def listrepr(objs: Iterable[Any], sep: str = " ") -> str: def extract_instances( - objects: Iterable, - types: Union[Type[T], Tuple[Type[T], ...]] = object, -) -> Union[List[T], Dict[Type[T], T]]: + objects: Iterable[Any], + types: Union[type[T], tuple[type[T], ...]] = object, +) -> Union[list[T], dict[type[T], list[T]]]: """ Extract objects from a file and returns a dict of type -> instances @@ -174,26 +195,27 @@ def extract_instances( If a single type is given: a list of instances of that type If a tuple of types is given: a mapping of type to a list of instances """ - types = ensure_iterable(types) + types_collection = ensure_iterable(types) # Create a mapping of type -> instance from the exec values - ret = defaultdict(list) + ret: dict[type[T], list[Any]] = {} for o in objects: # We iterate here so that the key is the passed type rather than type(o) - for type_ in types: + for type_ in types_collection: if isinstance(o, type_): - ret[type_].append(o) + ret.setdefault(type_, []).append(o) - if len(types) == 1: - return ret[types[0]] + if len(types_collection) == 1: + [type_] = types_collection + return ret[type_] return ret def batched_iterable( iterable: Iterable[T], size: int -) -> Generator[Tuple[T, ...], None, None]: +) -> Generator[tuple[T, ...], None, None]: """ Yield batches of a certain size from an iterable @@ -221,15 +243,86 @@ class StopVisiting(BaseException): """ +@overload +def visit_collection( + expr: Any, + visit_fn: Callable[[Any, dict[str, VT]], Any], + *, + return_data: Literal[True] = ..., + max_depth: int = ..., + context: dict[str, VT] = ..., + remove_annotations: bool = ..., + _seen: Optional[set[int]] = ..., +) -> Any: + ... + + +@overload +def visit_collection( + expr: Any, + visit_fn: Callable[[Any], Any], + *, + return_data: Literal[True] = ..., + max_depth: int = ..., + context: None = None, + remove_annotations: bool = ..., + _seen: Optional[set[int]] = ..., +) -> Any: + ... + + +@overload +def visit_collection( + expr: Any, + visit_fn: Callable[[Any, dict[str, VT]], Any], + *, + return_data: bool = ..., + max_depth: int = ..., + context: dict[str, VT] = ..., + remove_annotations: bool = ..., + _seen: Optional[set[int]] = ..., +) -> Optional[Any]: + ... + + +@overload +def visit_collection( + expr: Any, + visit_fn: Callable[[Any], Any], + *, + return_data: bool = ..., + max_depth: int = ..., + context: None = None, + remove_annotations: bool = ..., + _seen: Optional[set[int]] = ..., +) -> Optional[Any]: + ... + + +@overload +def visit_collection( + expr: Any, + visit_fn: Callable[[Any, dict[str, VT]], Any], + *, + return_data: Literal[False] = False, + max_depth: int = ..., + context: dict[str, VT] = ..., + remove_annotations: bool = ..., + _seen: Optional[set[int]] = ..., +) -> None: + ... + + def visit_collection( expr: Any, - visit_fn: Union[Callable[[Any, Optional[dict]], Any], Callable[[Any], Any]], + visit_fn: Union[Callable[[Any, dict[str, VT]], Any], Callable[[Any], Any]], + *, return_data: bool = False, max_depth: int = -1, - context: Optional[dict] = None, + context: Optional[dict[str, VT]] = None, remove_annotations: bool = False, - _seen: Optional[Set[int]] = None, -) -> Any: + _seen: Optional[set[int]] = None, +) -> Optional[Any]: """ Visits and potentially transforms every element of an arbitrary Python collection. @@ -289,24 +382,39 @@ def visit_collection( if _seen is None: _seen = set() - def visit_nested(expr): - # Utility for a recursive call, preserving options and updating the depth. - return visit_collection( - expr, - visit_fn=visit_fn, - return_data=return_data, - remove_annotations=remove_annotations, - max_depth=max_depth - 1, - # Copy the context on nested calls so it does not "propagate up" - context=context.copy() if context is not None else None, - _seen=_seen, - ) - - def visit_expression(expr): - if context is not None: - return visit_fn(expr, context) - else: - return visit_fn(expr) + if context is not None: + _callback = cast(Callable[[Any, dict[str, VT]], Any], visit_fn) + + def visit_nested(expr: Any) -> Optional[Any]: + return visit_collection( + expr, + _callback, + return_data=return_data, + remove_annotations=remove_annotations, + max_depth=max_depth - 1, + # Copy the context on nested calls so it does not "propagate up" + context=context.copy(), + _seen=_seen, + ) + + def visit_expression(expr: Any) -> Any: + return _callback(expr, context) + else: + _callback = cast(Callable[[Any], Any], visit_fn) + + def visit_nested(expr: Any) -> Optional[Any]: + # Utility for a recursive call, preserving options and updating the depth. + return visit_collection( + expr, + _callback, + return_data=return_data, + remove_annotations=remove_annotations, + max_depth=max_depth - 1, + _seen=_seen, + ) + + def visit_expression(expr: Any) -> Any: + return _callback(expr) # --- 1. Visit every expression try: @@ -329,10 +437,6 @@ def visit_expression(expr): else: _seen.add(id(expr)) - # Get the expression type; treat iterators like lists - typ = list if isinstance(expr, IteratorABC) and isiterable(expr) else type(expr) - typ = cast(type, typ) # mypy treats this as 'object' otherwise and complains - # Then visit every item in the expression if it is a collection # presume that the result is the original expression. @@ -354,9 +458,10 @@ def visit_expression(expr): # --- Annotations (unmapped, quote, etc.) elif isinstance(expr, BaseAnnotation): + annotated = cast(BaseAnnotation[Any], expr) if context is not None: - context["annotation"] = expr - unwrapped = expr.unwrap() + context["annotation"] = cast(VT, annotated) + unwrapped = annotated.unwrap() value = visit_nested(unwrapped) if return_data: @@ -365,47 +470,49 @@ def visit_expression(expr): result = value # if the value was modified, rewrap it elif value is not unwrapped: - result = expr.rewrap(value) + result = annotated.rewrap(value) # otherwise return the expr # --- Sequences elif isinstance(expr, (list, tuple, set)): - items = [visit_nested(o) for o in expr] + seq = cast(Union[list[Any], tuple[Any], set[Any]], expr) + items = [visit_nested(o) for o in seq] if return_data: - modified = any(item is not orig for item, orig in zip(items, expr)) + modified = any(item is not orig for item, orig in zip(items, seq)) if modified: - result = typ(items) + result = type(seq)(items) # --- Dictionaries - elif typ in (dict, OrderedDict): - assert isinstance(expr, (dict, OrderedDict)) # typecheck assertion - items = [(visit_nested(k), visit_nested(v)) for k, v in expr.items()] + elif isinstance(expr, (dict, OrderedDict)): + mapping = cast(dict[Any, Any], expr) + items = [(visit_nested(k), visit_nested(v)) for k, v in mapping.items()] if return_data: modified = any( k1 is not k2 or v1 is not v2 - for (k1, v1), (k2, v2) in zip(items, expr.items()) + for (k1, v1), (k2, v2) in zip(items, mapping.items()) ) if modified: - result = typ(items) + result = type(mapping)(items) # --- Dataclasses elif is_dataclass(expr) and not isinstance(expr, type): - values = [visit_nested(getattr(expr, f.name)) for f in fields(expr)] + expr_fields = fields(expr) + values = [visit_nested(getattr(expr, f.name)) for f in expr_fields] if return_data: modified = any( - getattr(expr, f.name) is not v for f, v in zip(fields(expr), values) + getattr(expr, f.name) is not v for f, v in zip(expr_fields, values) ) if modified: - result = typ(**{f.name: v for f, v in zip(fields(expr), values)}) + result = replace( + expr, **{f.name: v for f, v in zip(expr_fields, values)} + ) # --- Pydantic models elif isinstance(expr, pydantic.BaseModel): - typ = cast(Type[pydantic.BaseModel], typ) - # when extra=allow, fields not in model_fields may be in model_fields_set model_fields = expr.model_fields_set.union(expr.model_fields.keys()) @@ -424,7 +531,7 @@ def visit_expression(expr): ) if modified: # Use construct to avoid validation and handle immutability - model_instance = typ.model_construct( + model_instance = expr.model_construct( _fields_set=expr.model_fields_set, **updated_data ) for private_attr in expr.__private_attributes__: @@ -435,7 +542,21 @@ def visit_expression(expr): return result -def remove_nested_keys(keys_to_remove: List[Hashable], obj): +@overload +def remove_nested_keys( + keys_to_remove: list[HashableT], obj: NestedDict[HashableT, VT] +) -> NestedDict[HashableT, VT]: + ... + + +@overload +def remove_nested_keys(keys_to_remove: list[HashableT], obj: Any) -> Any: + ... + + +def remove_nested_keys( + keys_to_remove: list[HashableT], obj: Union[NestedDict[HashableT, VT], Any] +) -> Union[NestedDict[HashableT, VT], Any]: """ Recurses a dictionary returns a copy without all keys that match an entry in `key_to_remove`. Return `obj` unchanged if not a dictionary. @@ -452,24 +573,56 @@ def remove_nested_keys(keys_to_remove: List[Hashable], obj): return obj return { key: remove_nested_keys(keys_to_remove, value) - for key, value in obj.items() + for key, value in cast(NestedDict[HashableT, VT], obj).items() if key not in keys_to_remove } +@overload +def distinct(iterable: Iterable[HashableT], key: None = None) -> Iterator[HashableT]: + ... + + +@overload +def distinct(iterable: Iterable[T], key: Callable[[T], Hashable]) -> Iterator[T]: + ... + + def distinct( - iterable: Iterable[T], - key: Callable[[T], Any] = (lambda i: i), -) -> Generator[T, None, None]: - seen: Set = set() + iterable: Iterable[Union[T, HashableT]], + key: Optional[Callable[[T], Hashable]] = None, +) -> Iterator[Union[T, HashableT]]: + def _key(__i: Any) -> Hashable: + return __i + + if key is not None: + _key = cast(Callable[[Any], Hashable], key) + + seen: set[Hashable] = set() for item in iterable: - if key(item) in seen: + if _key(item) in seen: continue - seen.add(key(item)) + seen.add(_key(item)) yield item -def get_from_dict(dct: Dict, keys: Union[str, List[str]], default: Any = None) -> Any: +@overload +def get_from_dict( + dct: NestedDict[str, VT], keys: Union[str, list[str]], default: None = None +) -> Optional[VT]: + ... + + +@overload +def get_from_dict( + dct: NestedDict[str, VT], keys: Union[str, list[str]], default: R +) -> Union[VT, R]: + ... + + +def get_from_dict( + dct: NestedDict[str, VT], keys: Union[str, list[str]], default: Optional[R] = None +) -> Union[VT, R, None]: """ Fetch a value from a nested dictionary or list using a sequence of keys. @@ -500,6 +653,7 @@ def get_from_dict(dct: Dict, keys: Union[str, List[str]], default: Any = None) - """ if isinstance(keys, str): keys = keys.replace("[", ".").replace("]", "").split(".") + value = dct try: for key in keys: try: @@ -509,13 +663,15 @@ def get_from_dict(dct: Dict, keys: Union[str, List[str]], default: Any = None) - # If it's not an int, use the key as-is # for dict lookup pass - dct = dct[key] - return dct + value = value[key] # type: ignore + return cast(VT, value) except (TypeError, KeyError, IndexError): return default -def set_in_dict(dct: Dict, keys: Union[str, List[str]], value: Any): +def set_in_dict( + dct: NestedDict[str, VT], keys: Union[str, list[str]], value: VT +) -> None: """ Sets a value in a nested dictionary using a sequence of keys. @@ -543,11 +699,13 @@ def set_in_dict(dct: Dict, keys: Union[str, List[str]], value: Any): raise TypeError(f"Key path exists and contains a non-dict value: {keys}") if k not in dct: dct[k] = {} - dct = dct[k] + dct = cast(NestedDict[str, VT], dct[k]) dct[keys[-1]] = value -def deep_merge(dct: Dict, merge: Dict): +def deep_merge( + dct: NestedDict[str, VT1], merge: NestedDict[str, VT2] +) -> NestedDict[str, Union[VT1, VT2]]: """ Recursively merges `merge` into `dct`. @@ -558,18 +716,21 @@ def deep_merge(dct: Dict, merge: Dict): Returns: A new dictionary with the merged contents. """ - result = dct.copy() # Start with keys and values from `dct` + result: dict[str, Any] = dct.copy() # Start with keys and values from `dct` for key, value in merge.items(): if key in result and isinstance(result[key], dict) and isinstance(value, dict): # If both values are dictionaries, merge them recursively - result[key] = deep_merge(result[key], value) + result[key] = deep_merge( + cast(NestedDict[str, VT1], result[key]), + cast(NestedDict[str, VT2], value), + ) else: # Otherwise, overwrite with the new value - result[key] = value + result[key] = cast(Union[VT2, NestedDict[str, VT2]], value) return result -def deep_merge_dicts(*dicts): +def deep_merge_dicts(*dicts: NestedDict[str, Any]) -> NestedDict[str, Any]: """ Recursively merges multiple dictionaries. @@ -579,7 +740,7 @@ def deep_merge_dicts(*dicts): Returns: A new dictionary with the merged contents. """ - result = {} + result: NestedDict[str, Any] = {} for dictionary in dicts: result = deep_merge(result, dictionary) return result diff --git a/src/prefect/utilities/compat.py b/src/prefect/utilities/compat.py index 3eadafb3edf3..6bf8f34c46ca 100644 --- a/src/prefect/utilities/compat.py +++ b/src/prefect/utilities/compat.py @@ -3,29 +3,21 @@ """ # Please organize additions to this file by version -import asyncio import sys -from shutil import copytree -from signal import raise_signal if sys.version_info < (3, 10): - import importlib_metadata - from importlib_metadata import EntryPoint, EntryPoints, entry_points + import importlib_metadata as importlib_metadata + from importlib_metadata import ( + EntryPoint as EntryPoint, + EntryPoints as EntryPoints, + entry_points as entry_points, + ) else: - import importlib.metadata as importlib_metadata - from importlib.metadata import EntryPoint, EntryPoints, entry_points + import importlib.metadata + from importlib.metadata import ( + EntryPoint as EntryPoint, + EntryPoints as EntryPoints, + entry_points as entry_points, + ) -if sys.version_info < (3, 9): - # https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread - - import functools - - async def asyncio_to_thread(fn, *args, **kwargs): - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, functools.partial(fn, *args, **kwargs)) - -else: - from asyncio import to_thread as asyncio_to_thread - -if sys.platform != "win32": - from asyncio import ThreadedChildWatcher + importlib_metadata = importlib.metadata diff --git a/src/prefect/utilities/context.py b/src/prefect/utilities/context.py index 3bd87f975f40..c475f3eea453 100644 --- a/src/prefect/utilities/context.py +++ b/src/prefect/utilities/context.py @@ -1,6 +1,7 @@ +from collections.abc import Generator from contextlib import contextmanager from contextvars import Context, ContextVar, Token -from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Optional, cast from uuid import UUID if TYPE_CHECKING: @@ -8,8 +9,8 @@ @contextmanager -def temporary_context(context: Context): - tokens: Dict[ContextVar, Token] = {} +def temporary_context(context: Context) -> Generator[None, Any, None]: + tokens: dict[ContextVar[Any], Token[Any]] = {} for key, value in context.items(): token = key.set(value) tokens[key] = token @@ -38,11 +39,11 @@ def get_flow_run_id() -> Optional[UUID]: return None -def get_task_and_flow_run_ids() -> Tuple[Optional[UUID], Optional[UUID]]: +def get_task_and_flow_run_ids() -> tuple[Optional[UUID], Optional[UUID]]: """ Get the task run and flow run ids from the context, if available. Returns: - Tuple[Optional[UUID], Optional[UUID]]: a tuple of the task run id and flow run id + tuple[Optional[UUID], Optional[UUID]]: a tuple of the task run id and flow run id """ return get_task_run_id(), get_flow_run_id() diff --git a/src/prefect/utilities/dispatch.py b/src/prefect/utilities/dispatch.py index 599a9afafda4..603c24aecbd1 100644 --- a/src/prefect/utilities/dispatch.py +++ b/src/prefect/utilities/dispatch.py @@ -23,28 +23,39 @@ class Foo(Base): import abc import inspect import warnings -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Literal, Optional, TypeVar, overload -T = TypeVar("T", bound=Type) +T = TypeVar("T", bound=type[Any]) -_TYPE_REGISTRIES: Dict[Type, Dict[str, Type]] = {} +_TYPE_REGISTRIES: dict[Any, dict[str, Any]] = {} -def get_registry_for_type(cls: T) -> Optional[Dict[str, T]]: +def get_registry_for_type(cls: T) -> Optional[dict[str, T]]: """ Get the first matching registry for a class or any of its base classes. If not found, `None` is returned. """ return next( - filter( - lambda registry: registry is not None, - (_TYPE_REGISTRIES.get(cls) for cls in cls.mro()), - ), + (reg for cls in cls.mro() if (reg := _TYPE_REGISTRIES.get(cls)) is not None), None, ) +@overload +def get_dispatch_key( + cls_or_instance: Any, allow_missing: Literal[False] = False +) -> str: + ... + + +@overload +def get_dispatch_key( + cls_or_instance: Any, allow_missing: Literal[True] = ... +) -> Optional[str]: + ... + + def get_dispatch_key( cls_or_instance: Any, allow_missing: bool = False ) -> Optional[str]: @@ -89,14 +100,14 @@ def get_dispatch_key( @classmethod -def _register_subclass_of_base_type(cls, **kwargs): +def _register_subclass_of_base_type(cls: type[Any], **kwargs: Any) -> None: if hasattr(cls, "__init_subclass_original__"): cls.__init_subclass_original__(**kwargs) elif hasattr(cls, "__pydantic_init_subclass_original__"): cls.__pydantic_init_subclass_original__(**kwargs) # Do not register abstract base classes - if abc.ABC in getattr(cls, "__bases__", []): + if abc.ABC in cls.__bases__: return register_type(cls) @@ -123,7 +134,7 @@ def register_base_type(cls: T) -> T: cls.__pydantic_init_subclass__ = _register_subclass_of_base_type else: cls.__init_subclass_original__ = getattr(cls, "__init_subclass__") - cls.__init_subclass__ = _register_subclass_of_base_type + setattr(cls, "__init_subclass__", _register_subclass_of_base_type) return cls @@ -190,7 +201,7 @@ def lookup_type(cls: T, dispatch_key: str) -> T: Look up a dispatch key in the type registry for the given class. """ # Get the first matching registry for the class or one of its bases - registry = get_registry_for_type(cls) + registry = get_registry_for_type(cls) or {} # Look up this type in the registry subcls = registry.get(dispatch_key) diff --git a/src/prefect/utilities/dockerutils.py b/src/prefect/utilities/dockerutils.py index 72575d81d3a0..8f38cf7fa786 100644 --- a/src/prefect/utilities/dockerutils.py +++ b/src/prefect/utilities/dockerutils.py @@ -3,22 +3,12 @@ import shutil import sys import warnings +from collections.abc import Generator, Iterable, Iterator from contextlib import contextmanager from pathlib import Path, PurePosixPath from tempfile import TemporaryDirectory from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Generator, - Iterable, - List, - Optional, - TextIO, - Tuple, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Optional, TextIO, Union, cast from urllib.parse import urlsplit import pendulum @@ -29,6 +19,12 @@ from prefect.utilities.importtools import lazy_import from prefect.utilities.slugify import slugify +if TYPE_CHECKING: + import docker + import docker.errors + from docker import DockerClient + from docker.models.images import Image + CONTAINER_LABELS = { "io.prefect.version": prefect.__version__, } @@ -102,10 +98,7 @@ def silence_docker_warnings() -> Generator[None, None, None]: # want to have those popping up in various modules and test suites. Instead, # consolidate the imports we need here, and expose them via this module. with silence_docker_warnings(): - if TYPE_CHECKING: - import docker - from docker import DockerClient - else: + if not TYPE_CHECKING: docker = lazy_import("docker") @@ -123,7 +116,8 @@ def docker_client() -> Generator["DockerClient", None, None]: "This error is often thrown because Docker is not running. Please ensure Docker is running." ) from exc finally: - client is not None and client.close() + if client is not None: + client.close() class BuildError(Exception): @@ -207,14 +201,14 @@ class ImageBuilder: base_directory: Path context: Optional[Path] platform: Optional[str] - dockerfile_lines: List[str] + dockerfile_lines: list[str] def __init__( self, base_image: str, - base_directory: Path = None, + base_directory: Optional[Path] = None, platform: Optional[str] = None, - context: Path = None, + context: Optional[Path] = None, ): """Create an ImageBuilder @@ -250,7 +244,7 @@ def __enter__(self) -> Self: return self def __exit__( - self, exc: Type[BaseException], value: BaseException, traceback: TracebackType + self, exc: type[BaseException], value: BaseException, traceback: TracebackType ) -> None: if not self.temporary_directory: return @@ -315,6 +309,7 @@ def build( Returns: The image ID """ + assert self.context is not None dockerfile_path: Path = self.context / "Dockerfile" with dockerfile_path.open("w") as dockerfile: @@ -436,9 +431,12 @@ def push_image( repository = f"{registry}/{name}" with docker_client() as client: - image: "docker.Image" = client.images.get(image_id) + image: "Image" = client.images.get(image_id) image.tag(repository, tag=tag) - events = client.api.push(repository, tag=tag, stream=True, decode=True) + events = cast( + Iterator[dict[str, Any]], + client.api.push(repository, tag=tag, stream=True, decode=True), + ) try: for event in events: if "status" in event: @@ -457,7 +455,7 @@ def push_image( return f"{repository}:{tag}" -def to_run_command(command: List[str]) -> str: +def to_run_command(command: list[str]) -> str: """ Convert a process-style list of command arguments to a single Dockerfile RUN instruction. @@ -481,7 +479,7 @@ def to_run_command(command: List[str]) -> str: return run_command -def parse_image_tag(name: str) -> Tuple[str, Optional[str]]: +def parse_image_tag(name: str) -> tuple[str, Optional[str]]: """ Parse Docker Image String @@ -519,7 +517,7 @@ def parse_image_tag(name: str) -> Tuple[str, Optional[str]]: return image_name, tag -def split_repository_path(repository_path: str) -> Tuple[Optional[str], str]: +def split_repository_path(repository_path: str) -> tuple[Optional[str], str]: """ Splits a Docker repository path into its namespace and repository components. @@ -580,7 +578,7 @@ def generate_default_dockerfile(context: Optional[Path] = None): """ if not context: context = Path.cwd() - lines = [] + lines: list[str] = [] base_image = get_prefect_image_name() lines.append(f"FROM {base_image}") dir_name = context.name diff --git a/src/prefect/utilities/engine.py b/src/prefect/utilities/engine.py index fd94104d3773..dc2712b37d82 100644 --- a/src/prefect/utilities/engine.py +++ b/src/prefect/utilities/engine.py @@ -4,18 +4,9 @@ import os import signal import time +from collections.abc import Callable, Iterable from functools import partial -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Iterable, - Optional, - Set, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union from uuid import UUID, uuid4 import anyio @@ -69,12 +60,12 @@ from prefect.client.orchestration import PrefectClient, SyncPrefectClient API_HEALTHCHECKS = {} -UNTRACKABLE_TYPES = {bool, type(None), type(...), type(NotImplemented)} +UNTRACKABLE_TYPES: set[type[Any]] = {bool, type(None), type(...), type(NotImplemented)} engine_logger = get_logger("engine") T = TypeVar("T") -async def collect_task_run_inputs(expr: Any, max_depth: int = -1) -> Set[TaskRunInput]: +async def collect_task_run_inputs(expr: Any, max_depth: int = -1) -> set[TaskRunInput]: """ This function recurses through an expression to generate a set of any discernible task run inputs it finds in the data structure. It produces a set of all inputs @@ -87,14 +78,11 @@ async def collect_task_run_inputs(expr: Any, max_depth: int = -1) -> Set[TaskRun """ # TODO: This function needs to be updated to detect parameters and constants - inputs = set() - futures = set() + inputs: set[TaskRunInput] = set() - def add_futures_and_states_to_inputs(obj): + def add_futures_and_states_to_inputs(obj: Any) -> None: if isinstance(obj, PrefectFuture): - # We need to wait for futures to be submitted before we can get the task - # run id but we want to do so asynchronously - futures.add(obj) + inputs.add(TaskRunResult(id=obj.task_run_id)) elif isinstance(obj, State): if obj.state_details.task_run_id: inputs.add(TaskRunResult(id=obj.state_details.task_run_id)) @@ -113,16 +101,12 @@ def add_futures_and_states_to_inputs(obj): max_depth=max_depth, ) - await asyncio.gather(*[future._wait_for_submission() for future in futures]) - for future in futures: - inputs.add(TaskRunResult(id=future.task_run.id)) - return inputs def collect_task_run_inputs_sync( expr: Any, future_cls: Any = PrefectFuture, max_depth: int = -1 -) -> Set[TaskRunInput]: +) -> set[TaskRunInput]: """ This function recurses through an expression to generate a set of any discernible task run inputs it finds in the data structure. It produces a set of all inputs @@ -135,9 +119,9 @@ def collect_task_run_inputs_sync( """ # TODO: This function needs to be updated to detect parameters and constants - inputs = set() + inputs: set[TaskRunInput] = set() - def add_futures_and_states_to_inputs(obj): + def add_futures_and_states_to_inputs(obj: Any) -> None: if isinstance(obj, future_cls) and hasattr(obj, "task_run_id"): inputs.add(TaskRunResult(id=obj.task_run_id)) elif isinstance(obj, State): @@ -162,12 +146,14 @@ def add_futures_and_states_to_inputs(obj): async def wait_for_task_runs_and_report_crashes( - task_run_futures: Iterable[PrefectFuture], client: "PrefectClient" + task_run_futures: Iterable[PrefectFuture[Any]], client: "PrefectClient" ) -> Literal[True]: crash_exceptions = [] # Gather states concurrently first - states = await gather(*(future._wait for future in task_run_futures)) + states: list[State[Any]] = await gather( + *(future._wait for future in task_run_futures) + ) for future, state in zip(task_run_futures, states): logger = task_run_logger(future.task_run) @@ -241,8 +227,8 @@ def cancel_flow_run(*args): async def resolve_inputs( - parameters: Dict[str, Any], return_data: bool = True, max_depth: int = -1 -) -> Dict[str, Any]: + parameters: dict[str, Any], return_data: bool = True, max_depth: int = -1 +) -> dict[str, Any]: """ Resolve any `Quote`, `PrefectFuture`, or `State` types nested in parameters into data. @@ -664,7 +650,7 @@ def should_log_prints(flow_or_task: Union[Flow, Task]) -> bool: return flow_or_task.log_prints -def _resolve_custom_flow_run_name(flow: Flow, parameters: Dict[str, Any]) -> str: +def _resolve_custom_flow_run_name(flow: Flow, parameters: dict[str, Any]) -> str: if callable(flow.flow_run_name): flow_run_name = flow.flow_run_name() if not isinstance(flow_run_name, str): @@ -683,7 +669,7 @@ def _resolve_custom_flow_run_name(flow: Flow, parameters: Dict[str, Any]) -> str return flow_run_name -def _resolve_custom_task_run_name(task: Task, parameters: Dict[str, Any]) -> str: +def _resolve_custom_task_run_name(task: Task, parameters: dict[str, Any]) -> str: if callable(task.task_run_name): sig = inspect.signature(task.task_run_name) @@ -884,8 +870,8 @@ def resolve_to_final_result(expr, context): def resolve_inputs_sync( - parameters: Dict[str, Any], return_data: bool = True, max_depth: int = -1 -) -> Dict[str, Any]: + parameters: dict[str, Any], return_data: bool = True, max_depth: int = -1 +) -> dict[str, Any]: """ Resolve any `Quote`, `PrefectFuture`, or `State` types nested in parameters into data. diff --git a/src/prefect/utilities/names.py b/src/prefect/utilities/names.py index 6be2b93ab707..baeb2b1b6475 100644 --- a/src/prefect/utilities/names.py +++ b/src/prefect/utilities/names.py @@ -1,6 +1,6 @@ from typing import Any -import coolname +import coolname # type: ignore # the version after coolname 2.2.0 should have stubs. OBFUSCATED_PREFIX = "****" @@ -42,7 +42,7 @@ def generate_slug(n_words: int) -> str: return "-".join(words) -def obfuscate(s: Any, show_tail=False) -> str: +def obfuscate(s: Any, show_tail: bool = False) -> str: """ Obfuscates any data type's string representation. See `obfuscate_string`. """ @@ -52,7 +52,7 @@ def obfuscate(s: Any, show_tail=False) -> str: return obfuscate_string(str(s), show_tail=show_tail) -def obfuscate_string(s: str, show_tail=False) -> str: +def obfuscate_string(s: str, show_tail: bool = False) -> str: """ Obfuscates a string by returning a new string of 8 characters. If the input string is longer than 10 characters and show_tail is True, then up to 4 of diff --git a/src/prefect/utilities/schema_tools/__init__.py b/src/prefect/utilities/schema_tools/__init__.py index 1e6e73fc372a..eca02737cca9 100644 --- a/src/prefect/utilities/schema_tools/__init__.py +++ b/src/prefect/utilities/schema_tools/__init__.py @@ -12,5 +12,6 @@ "HydrationError", "ValidationError", "hydrate", + "is_valid_schema", "validate", ] diff --git a/src/prefect/utilities/schema_tools/hydration.py b/src/prefect/utilities/schema_tools/hydration.py index 49bd1bc8b33d..91269b3b8b57 100644 --- a/src/prefect/utilities/schema_tools/hydration.py +++ b/src/prefect/utilities/schema_tools/hydration.py @@ -1,10 +1,12 @@ import json -from typing import Any, Callable, Dict, Optional +from abc import ABC, abstractmethod +from collections.abc import Callable +from typing import Any, Optional, cast import jinja2 from pydantic import BaseModel, Field from sqlalchemy.ext.asyncio import AsyncSession -from typing_extensions import TypeAlias +from typing_extensions import Self, TypeAlias, TypeIs from prefect.server.utilities.user_templates import ( TemplateSecurityError, @@ -15,14 +17,14 @@ class HydrationContext(BaseModel): - workspace_variables: Dict[ + workspace_variables: dict[ str, StrictVariableValue, ] = Field(default_factory=dict) render_workspace_variables: bool = Field(default=False) raise_on_error: bool = Field(default=False) render_jinja: bool = Field(default=False) - jinja_context: Dict[str, Any] = Field(default_factory=dict) + jinja_context: dict[str, Any] = Field(default_factory=dict) @classmethod async def build( @@ -31,7 +33,7 @@ async def build( raise_on_error: bool = False, render_jinja: bool = False, render_workspace_variables: bool = False, - ) -> "HydrationContext": + ) -> Self: from prefect.server.models.variables import read_variables if render_workspace_variables: @@ -51,14 +53,14 @@ async def build( ) -Handler: TypeAlias = Callable[[dict, HydrationContext], Any] +Handler: TypeAlias = Callable[[dict[str, Any], HydrationContext], Any] PrefectKind: TypeAlias = Optional[str] -_handlers: Dict[PrefectKind, Handler] = {} +_handlers: dict[PrefectKind, Handler] = {} class Placeholder: - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, type(self)) @property @@ -70,11 +72,11 @@ class RemoveValue(Placeholder): pass -def _remove_value(value) -> bool: +def _remove_value(value: Any) -> TypeIs[RemoveValue]: return isinstance(value, RemoveValue) -class HydrationError(Placeholder, Exception): +class HydrationError(Placeholder, Exception, ABC): def __init__(self, detail: Optional[str] = None): self.detail = detail @@ -83,47 +85,49 @@ def is_error(self) -> bool: return True @property - def message(self): + @abstractmethod + def message(self) -> str: raise NotImplementedError("Must be implemented by subclass") - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return isinstance(other, type(self)) and self.message == other.message - def __str__(self): + def __str__(self) -> str: return self.message class KeyNotFound(HydrationError): @property - def message(self): + def message(self) -> str: return f"Missing '{self.key}' key in __prefect object" @property + @abstractmethod def key(self) -> str: raise NotImplementedError("Must be implemented by subclass") class ValueNotFound(KeyNotFound): @property - def key(self): + def key(self) -> str: return "value" class TemplateNotFound(KeyNotFound): @property - def key(self): + def key(self) -> str: return "template" class VariableNameNotFound(KeyNotFound): @property - def key(self): + def key(self) -> str: return "variable_name" class InvalidJSON(HydrationError): @property - def message(self): + def message(self) -> str: message = "Invalid JSON" if self.detail: message += f": {self.detail}" @@ -132,7 +136,7 @@ def message(self): class InvalidJinja(HydrationError): @property - def message(self): + def message(self) -> str: message = "Invalid jinja" if self.detail: message += f": {self.detail}" @@ -146,29 +150,29 @@ def variable_name(self) -> str: return self.detail @property - def message(self): + def message(self) -> str: return f"Variable '{self.detail}' not found in workspace." class WorkspaceVariable(Placeholder): - def __init__(self, variable_name: str): + def __init__(self, variable_name: str) -> None: self.variable_name = variable_name - def __eq__(self, other): + def __eq__(self, other: Any) -> bool: return ( isinstance(other, type(self)) and self.variable_name == other.variable_name ) class ValidJinja(Placeholder): - def __init__(self, template: str): + def __init__(self, template: str) -> None: self.template = template - def __eq__(self, other): + def __eq__(self, other: Any): return isinstance(other, type(self)) and self.template == other.template -def handler(kind: PrefectKind) -> Callable: +def handler(kind: PrefectKind) -> Callable[[Handler], Handler]: def decorator(func: Handler) -> Handler: _handlers[kind] = func return func @@ -176,9 +180,9 @@ def decorator(func: Handler) -> Handler: return decorator -def call_handler(kind: PrefectKind, obj: dict, ctx: HydrationContext) -> Any: +def call_handler(kind: PrefectKind, obj: dict[str, Any], ctx: HydrationContext) -> Any: if kind not in _handlers: - return (obj or {}).get("value", None) + return obj.get("value", None) res = _handlers[kind](obj, ctx) if ctx.raise_on_error and isinstance(res, HydrationError): @@ -187,7 +191,7 @@ def call_handler(kind: PrefectKind, obj: dict, ctx: HydrationContext) -> Any: @handler("none") -def null_handler(obj: dict, ctx: HydrationContext): +def null_handler(obj: dict[str, Any], ctx: HydrationContext): if "value" in obj: # null handler is a pass through, so we want to continue to hydrate return _hydrate(obj["value"], ctx) @@ -196,7 +200,7 @@ def null_handler(obj: dict, ctx: HydrationContext): @handler("json") -def json_handler(obj: dict, ctx: HydrationContext): +def json_handler(obj: dict[str, Any], ctx: HydrationContext): if "value" in obj: if isinstance(obj["value"], dict): dehydrated_json = _hydrate(obj["value"], ctx) @@ -222,7 +226,7 @@ def json_handler(obj: dict, ctx: HydrationContext): @handler("jinja") -def jinja_handler(obj: dict, ctx: HydrationContext): +def jinja_handler(obj: dict[str, Any], ctx: HydrationContext) -> Any: if "template" in obj: if isinstance(obj["template"], dict): dehydrated_jinja = _hydrate(obj["template"], ctx) @@ -247,7 +251,7 @@ def jinja_handler(obj: dict, ctx: HydrationContext): @handler("workspace_variable") -def workspace_variable_handler(obj: dict, ctx: HydrationContext): +def workspace_variable_handler(obj: dict[str, Any], ctx: HydrationContext) -> Any: if "variable_name" in obj: if isinstance(obj["variable_name"], dict): dehydrated_variable = _hydrate(obj["variable_name"], ctx) @@ -277,35 +281,36 @@ def workspace_variable_handler(obj: dict, ctx: HydrationContext): return RemoveValue() -def hydrate(obj: dict, ctx: Optional[HydrationContext] = None): - res = _hydrate(obj, ctx) +def hydrate( + obj: dict[str, Any], ctx: Optional[HydrationContext] = None +) -> dict[str, Any]: + res: dict[str, Any] = _hydrate(obj, ctx) if _remove_value(res): - return {} + res = {} return res -def _hydrate(obj, ctx: Optional[HydrationContext] = None) -> Any: +def _hydrate(obj: Any, ctx: Optional[HydrationContext] = None) -> Any: if ctx is None: ctx = HydrationContext() - prefect_object = isinstance(obj, dict) and "__prefect_kind" in obj - - if prefect_object: - prefect_kind = obj.get("__prefect_kind") - return call_handler(prefect_kind, obj, ctx) + if isinstance(obj, dict) and "__prefect_kind" in obj: + obj_dict: dict[str, Any] = obj + prefect_kind = obj_dict["__prefect_kind"] + return call_handler(prefect_kind, obj_dict, ctx) else: if isinstance(obj, dict): return { key: hydrated_value - for key, value in obj.items() + for key, value in cast(dict[str, Any], obj).items() if not _remove_value(hydrated_value := _hydrate(value, ctx)) } elif isinstance(obj, list): return [ hydrated_element - for element in obj + for element in cast(list[Any], obj) if not _remove_value(hydrated_element := _hydrate(element, ctx)) ] else: diff --git a/src/prefect/utilities/schema_tools/validation.py b/src/prefect/utilities/schema_tools/validation.py index e94204331125..58c8298e7c2e 100644 --- a/src/prefect/utilities/schema_tools/validation.py +++ b/src/prefect/utilities/schema_tools/validation.py @@ -1,14 +1,19 @@ from collections import defaultdict, deque +from collections.abc import Callable, Iterable, Iterator from copy import deepcopy -from typing import Any, Dict, List +from typing import TYPE_CHECKING, Any, cast import jsonschema from jsonschema.exceptions import ValidationError as JSONSchemaValidationError from jsonschema.validators import Draft202012Validator, create +from referencing.jsonschema import ObjectSchema, Schema from prefect.utilities.collections import remove_nested_keys from prefect.utilities.schema_tools.hydration import HydrationError, Placeholder +if TYPE_CHECKING: + from jsonschema.validators import _Validator # type: ignore + class CircularSchemaRefError(Exception): pass @@ -21,12 +26,16 @@ class ValidationError(Exception): PLACEHOLDERS_VALIDATOR_NAME = "_placeholders" -def _build_validator(): - def _applicable_validators(schema): +def _build_validator() -> type["_Validator"]: + def _applicable_validators(schema: Schema) -> Iterable[tuple[str, Any]]: # the default implementation returns `schema.items()` - return {**schema, PLACEHOLDERS_VALIDATOR_NAME: None}.items() + assert not isinstance(schema, bool) + schema = {**schema, PLACEHOLDERS_VALIDATOR_NAME: None} + return schema.items() - def _placeholders(validator, _, instance, schema): + def _placeholders( + _validator: "_Validator", _property: object, instance: Any, _schema: Schema + ) -> Iterator[JSONSchemaValidationError]: if isinstance(instance, HydrationError): yield JSONSchemaValidationError(instance.message) @@ -43,7 +52,9 @@ def _placeholders(validator, _, instance, schema): version="prefect", type_checker=Draft202012Validator.TYPE_CHECKER, format_checker=Draft202012Validator.FORMAT_CHECKER, - id_of=Draft202012Validator.ID_OF, + id_of=cast( # the stub for create() is wrong here; id_of accepts (Schema) -> str | None + Callable[[Schema], str], Draft202012Validator.ID_OF + ), applicable_validators=_applicable_validators, ) @@ -51,24 +62,23 @@ def _placeholders(validator, _, instance, schema): _VALIDATOR = _build_validator() -def is_valid_schema(schema: Dict, preprocess: bool = True): +def is_valid_schema(schema: ObjectSchema, preprocess: bool = True) -> None: if preprocess: schema = preprocess_schema(schema) try: - if schema is not None: - _VALIDATOR.check_schema(schema, format_checker=_VALIDATOR.FORMAT_CHECKER) + _VALIDATOR.check_schema(schema, format_checker=_VALIDATOR.FORMAT_CHECKER) except jsonschema.SchemaError as exc: raise ValueError(f"Invalid schema: {exc.message}") from exc def validate( - obj: Dict, - schema: Dict, + obj: dict[str, Any], + schema: ObjectSchema, raise_on_error: bool = False, preprocess: bool = True, ignore_required: bool = False, allow_none_with_default: bool = False, -) -> List[JSONSchemaValidationError]: +) -> list[JSONSchemaValidationError]: if preprocess: schema = preprocess_schema(schema, allow_none_with_default) @@ -93,28 +103,27 @@ def validate( else: try: validator = _VALIDATOR(schema, format_checker=_VALIDATOR.FORMAT_CHECKER) - errors = list(validator.iter_errors(obj)) + errors = list(validator.iter_errors(obj)) # type: ignore except RecursionError: raise CircularSchemaRefError return errors -def is_valid( - obj: Dict, - schema: Dict, -) -> bool: +def is_valid(obj: dict[str, Any], schema: ObjectSchema) -> bool: errors = validate(obj, schema) - return len(errors) == 0 + return not errors -def prioritize_placeholder_errors(errors): - errors_by_path = defaultdict(list) +def prioritize_placeholder_errors( + errors: list[JSONSchemaValidationError], +) -> list[JSONSchemaValidationError]: + errors_by_path: dict[str, list[JSONSchemaValidationError]] = defaultdict(list) for error in errors: path_str = "->".join(str(p) for p in error.relative_path) errors_by_path[path_str].append(error) - filtered_errors = [] - for path, grouped_errors in errors_by_path.items(): + filtered_errors: list[JSONSchemaValidationError] = [] + for grouped_errors in errors_by_path.values(): placeholders_errors = [ error for error in grouped_errors @@ -129,8 +138,8 @@ def prioritize_placeholder_errors(errors): return filtered_errors -def build_error_obj(errors: List[JSONSchemaValidationError]) -> Dict: - error_response: Dict[str, Any] = {"errors": []} +def build_error_obj(errors: list[JSONSchemaValidationError]) -> dict[str, Any]: + error_response: dict[str, Any] = {"errors": []} # If multiple errors are present for the same path and one of them # is a placeholder error, we want only want to use the placeholder error. @@ -145,11 +154,11 @@ def build_error_obj(errors: List[JSONSchemaValidationError]) -> Dict: # Required errors should be moved one level down to the property # they're associated with, so we add an extra level to the path. - if error.validator == "required": - required_field = error.message.split(" ")[0].strip("'") + if error.validator == "required": # type: ignore + required_field = error.message.partition(" ")[0].strip("'") path.append(required_field) - current = error_response["errors"] + current: list[Any] = error_response["errors"] # error at the root, just append the error message if not path: @@ -163,10 +172,10 @@ def build_error_obj(errors: List[JSONSchemaValidationError]) -> Dict: else: for entry in current: if entry.get("index") == part: - current = entry["errors"] + current = cast(list[Any], entry["errors"]) break else: - new_entry = {"index": part, "errors": []} + new_entry: dict[str, Any] = {"index": part, "errors": []} current.append(new_entry) current = new_entry["errors"] else: @@ -182,7 +191,7 @@ def build_error_obj(errors: List[JSONSchemaValidationError]) -> Dict: current.append(new_entry) current = new_entry["errors"] - valid = len(error_response["errors"]) == 0 + valid = not bool(error_response["errors"]) error_response["valid"] = valid return error_response @@ -190,10 +199,10 @@ def build_error_obj(errors: List[JSONSchemaValidationError]) -> Dict: def _fix_null_typing( key: str, - schema: Dict, - required_fields: List[str], + schema: dict[str, Any], + required_fields: list[str], allow_none_with_default: bool = False, -): +) -> None: """ Pydantic V1 does not generate a valid Draft2020-12 schema for null types. """ @@ -207,7 +216,7 @@ def _fix_null_typing( del schema["type"] -def _fix_tuple_items(schema: Dict): +def _fix_tuple_items(schema: dict[str, Any]) -> None: """ Pydantic V1 does not generate a valid Draft2020-12 schema for tuples. """ @@ -216,13 +225,13 @@ def _fix_tuple_items(schema: Dict): and isinstance(schema["items"], list) and not schema.get("prefixItems") ): - schema["prefixItems"] = deepcopy(schema["items"]) + schema["prefixItems"] = deepcopy(cast(list[Any], schema["items"])) del schema["items"] def process_properties( - properties: Dict, - required_fields: List[str], + properties: dict[str, dict[str, Any]], + required_fields: list[str], allow_none_with_default: bool = False, ): for key, schema in properties.items(): @@ -235,9 +244,9 @@ def process_properties( def preprocess_schema( - schema: Dict, + schema: ObjectSchema, allow_none_with_default: bool = False, -): +) -> ObjectSchema: schema = deepcopy(schema) if "properties" in schema: @@ -247,7 +256,8 @@ def preprocess_schema( ) if "definitions" in schema: # Also process definitions for reused models - for definition in (schema["definitions"] or {}).values(): + definitions = cast(dict[str, Any], schema["definitions"]) + for definition in definitions.values(): if "properties" in definition: required_fields = definition.get("required", []) process_properties(