diff --git a/src/prefect/_internal/concurrency/api.py b/src/prefect/_internal/concurrency/api.py index bcfa6ae189db..f263e61b6def 100644 --- a/src/prefect/_internal/concurrency/api.py +++ b/src/prefect/_internal/concurrency/api.py @@ -6,50 +6,46 @@ import asyncio import concurrent.futures import contextlib -from typing import ( - Any, - Awaitable, - Callable, - ContextManager, - Iterable, - Optional, - TypeVar, - Union, -) +from collections.abc import Awaitable, Iterable +from contextlib import AbstractContextManager +from typing import Any, Callable, Optional, Union, cast -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, TypeAlias, TypeVar from prefect._internal.concurrency.threads import ( WorkerThread, get_global_loop, in_global_loop, ) -from prefect._internal.concurrency.waiters import ( - AsyncWaiter, - Call, - SyncWaiter, -) +from prefect._internal.concurrency.waiters import AsyncWaiter, Call, SyncWaiter P = ParamSpec("P") -T = TypeVar("T") +T = TypeVar("T", infer_variance=True) Future = Union[concurrent.futures.Future[T], asyncio.Future[T]] +_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[T, Awaitable[T]]] -def create_call(__fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Call[T]: + +def create_call( + __fn: _SyncOrAsyncCallable[P, T], *args: P.args, **kwargs: P.kwargs +) -> Call[T]: return Call[T].new(__fn, *args, **kwargs) -def _cast_to_call(call_like: Union[Callable[[], T], Call[T]]) -> Call[T]: +def cast_to_call( + call_like: Union["_SyncOrAsyncCallable[[], T]", Call[T]], +) -> Call[T]: if isinstance(call_like, Call): - return call_like + return cast(Call[T], call_like) else: return create_call(call_like) class _base(abc.ABC): - @abc.abstractstaticmethod + @staticmethod + @abc.abstractmethod def wait_for_call_in_loop_thread( - __call: Union[Callable[[], T], Call[T]], # type: ignore[reportGeneralTypeIssues] + __call: Union["_SyncOrAsyncCallable[[], Any]", Call[T]], timeout: Optional[float] = None, done_callbacks: Optional[Iterable[Call[Any]]] = None, ) -> T: @@ -60,9 +56,10 @@ def wait_for_call_in_loop_thread( """ raise NotImplementedError() - @abc.abstractstaticmethod + @staticmethod + @abc.abstractmethod def wait_for_call_in_new_thread( - __call: Union[Callable[[], T], Call[T]], # type: ignore[reportGeneralTypeIssues] + __call: Union["_SyncOrAsyncCallable[[], T]", Call[T]], timeout: Optional[float] = None, done_callbacks: Optional[Iterable[Call[Any]]] = None, ) -> T: @@ -75,14 +72,15 @@ def wait_for_call_in_new_thread( @staticmethod def call_soon_in_new_thread( - __call: Union[Callable[[], T], Call[T]], timeout: Optional[float] = None + __call: Union["_SyncOrAsyncCallable[[], T]", Call[T]], + timeout: Optional[float] = None, ) -> Call[T]: """ Schedule a call for execution in a new worker thread. Returns the submitted call. """ - call = _cast_to_call(__call) + call = cast_to_call(__call) runner = WorkerThread(run_once=True) call.set_timeout(timeout) runner.submit(call) @@ -90,7 +88,7 @@ def call_soon_in_new_thread( @staticmethod def call_soon_in_loop_thread( - __call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + __call: Union["_SyncOrAsyncCallable[[], T]", Call[T]], timeout: Optional[float] = None, ) -> Call[T]: """ @@ -98,7 +96,7 @@ def call_soon_in_loop_thread( Returns the submitted call. 
""" - call = _cast_to_call(__call) + call = cast_to_call(__call) runner = get_global_loop() call.set_timeout(timeout) runner.submit(call) @@ -117,7 +115,7 @@ def call_in_new_thread( @staticmethod def call_in_loop_thread( - __call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + __call: Union[Callable[[], Awaitable[T]], Call[T]], timeout: Optional[float] = None, ) -> T: """ @@ -131,12 +129,12 @@ def call_in_loop_thread( class from_async(_base): @staticmethod async def wait_for_call_in_loop_thread( - __call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + __call: Union[Callable[[], Awaitable[T]], Call[T]], timeout: Optional[float] = None, done_callbacks: Optional[Iterable[Call[Any]]] = None, - contexts: Optional[Iterable[ContextManager[Any]]] = None, - ) -> Awaitable[T]: - call = _cast_to_call(__call) + contexts: Optional[Iterable[AbstractContextManager[Any]]] = None, + ) -> T: + call = cast_to_call(__call) waiter = AsyncWaiter(call) for callback in done_callbacks or []: waiter.add_done_callback(callback) @@ -153,7 +151,7 @@ async def wait_for_call_in_new_thread( timeout: Optional[float] = None, done_callbacks: Optional[Iterable[Call[Any]]] = None, ) -> T: - call = _cast_to_call(__call) + call = cast_to_call(__call) waiter = AsyncWaiter(call=call) for callback in done_callbacks or []: waiter.add_done_callback(callback) @@ -170,7 +168,7 @@ def call_in_new_thread( @staticmethod def call_in_loop_thread( - __call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + __call: Union[Callable[[], Awaitable[T]], Call[T]], timeout: Optional[float] = None, ) -> Awaitable[T]: call = _base.call_soon_in_loop_thread(__call, timeout=timeout) @@ -182,13 +180,13 @@ class from_sync(_base): def wait_for_call_in_loop_thread( __call: Union[ Callable[[], Awaitable[T]], - Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + Call[T], ], timeout: Optional[float] = None, - done_callbacks: Optional[Iterable[Call]] = None, - contexts: Optional[Iterable[ContextManager]] = None, - ) -> Awaitable[T]: - call = _cast_to_call(__call) + done_callbacks: Optional[Iterable[Call[T]]] = None, + contexts: Optional[Iterable[AbstractContextManager[Any]]] = None, + ) -> T: + call = cast_to_call(__call) waiter = SyncWaiter(call) _base.call_soon_in_loop_thread(call, timeout=timeout) for callback in done_callbacks or []: @@ -203,9 +201,9 @@ def wait_for_call_in_loop_thread( def wait_for_call_in_new_thread( __call: Union[Callable[[], T], Call[T]], timeout: Optional[float] = None, - done_callbacks: Optional[Iterable[Call]] = None, - ) -> Call[T]: - call = _cast_to_call(__call) + done_callbacks: Optional[Iterable[Call[T]]] = None, + ) -> T: + call = cast_to_call(__call) waiter = SyncWaiter(call=call) for callback in done_callbacks or []: waiter.add_done_callback(callback) @@ -215,20 +213,21 @@ def wait_for_call_in_new_thread( @staticmethod def call_in_new_thread( - __call: Union[Callable[[], T], Call[T]], timeout: Optional[float] = None + __call: Union["_SyncOrAsyncCallable[[], T]", Call[T]], + timeout: Optional[float] = None, ) -> T: call = _base.call_soon_in_new_thread(__call, timeout=timeout) return call.result() @staticmethod def call_in_loop_thread( - __call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]], + __call: Union["_SyncOrAsyncCallable[[], T]", Call[T]], timeout: Optional[float] = None, - ) -> T: + ) -> Union[Awaitable[T], T]: if in_global_loop(): # Avoid deadlock where the call is submitted to the loop then the loop is # blocked waiting for the call - call = _cast_to_call(__call) + call = 
cast_to_call(__call) return call() call = _base.call_soon_in_loop_thread(__call, timeout=timeout) diff --git a/src/prefect/_internal/concurrency/calls.py b/src/prefect/_internal/concurrency/calls.py index 5e8b675bd23e..4a715ac90491 100644 --- a/src/prefect/_internal/concurrency/calls.py +++ b/src/prefect/_internal/concurrency/calls.py @@ -12,18 +12,20 @@ import inspect import threading import weakref +from collections.abc import Awaitable, Generator from concurrent.futures._base import ( CANCELLED, CANCELLED_AND_NOTIFIED, FINISHED, RUNNING, ) -from typing import Any, Awaitable, Callable, Dict, Generic, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, Union -from typing_extensions import ParamSpec +from typing_extensions import ParamSpec, Self, TypeAlias, TypeVar, TypeVarTuple from prefect._internal.concurrency import logger from prefect._internal.concurrency.cancellation import ( + AsyncCancelScope, CancelledError, cancel_async_at, cancel_sync_at, @@ -31,9 +33,13 @@ ) from prefect._internal.concurrency.event_loop import get_running_loop -T = TypeVar("T") +T = TypeVar("T", infer_variance=True) +Ts = TypeVarTuple("Ts") P = ParamSpec("P") +_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[T, Awaitable[T]]] + + # Tracks the current call being executed. Note that storing the `Call` # object for an async call directly in the contextvar appears to create a # memory leak, despite the fact that we `reset` when leaving the context @@ -41,16 +47,16 @@ # we already have strong references to the `Call` objects in other places # and b) this is used for performance optimizations where we have fallback # behavior if this weakref is garbage collected. A fix for issue #10952. -current_call: contextvars.ContextVar["weakref.ref[Call]"] = ( # novm +current_call: contextvars.ContextVar["weakref.ref[Call[Any]]"] = ( # novm contextvars.ContextVar("current_call") ) # Create a strong reference to tasks to prevent destruction during execution errors -_ASYNC_TASK_REFS = set() +_ASYNC_TASK_REFS: set[asyncio.Task[None]] = set() @contextlib.contextmanager -def set_current_call(call: "Call"): +def set_current_call(call: "Call[Any]") -> Generator[None, Any, None]: token = current_call.set(weakref.ref(call)) try: yield @@ -58,7 +64,7 @@ def set_current_call(call: "Call"): current_call.reset(token) -class Future(concurrent.futures.Future): +class Future(concurrent.futures.Future[T]): """ Extension of `concurrent.futures.Future` with support for cancellation of running futures. @@ -70,7 +76,7 @@ def __init__(self, name: Optional[str] = None) -> None: super().__init__() self._cancel_scope = None self._deadline = None - self._cancel_callbacks = [] + self._cancel_callbacks: list[Callable[[], None]] = [] self._name = name self._timed_out = False @@ -79,7 +85,7 @@ def set_running_or_notify_cancel(self, timeout: Optional[float] = None): return super().set_running_or_notify_cancel() @contextlib.contextmanager - def enforce_async_deadline(self): + def enforce_async_deadline(self) -> Generator[AsyncCancelScope]: with cancel_async_at(self._deadline, name=self._name) as self._cancel_scope: for callback in self._cancel_callbacks: self._cancel_scope.add_cancel_callback(callback) @@ -92,7 +98,7 @@ def enforce_sync_deadline(self): self._cancel_scope.add_cancel_callback(callback) yield self._cancel_scope - def add_cancel_callback(self, callback: Callable[[], None]): + def add_cancel_callback(self, callback: Callable[[], Any]) -> None: """ Add a callback to be enforced on cancellation. 
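[Illustrative sketch — not part of the patch] How the typed helpers in api.py compose: create_call wraps a sync or async callable into a Call[T], cast_to_call (the renamed _cast_to_call) accepts either a bare callable or an existing Call, and from_sync.call_in_loop_thread runs the call on the global loop thread. The coroutine below is a hypothetical stand-in.

from prefect._internal.concurrency.api import cast_to_call, create_call, from_sync

async def fetch_number() -> int:  # hypothetical example coroutine
    return 42

call = create_call(fetch_number)                  # inferred as Call[int]
assert cast_to_call(call) is call                 # existing Call objects pass through
assert from_sync.call_in_loop_thread(call) == 42  # runs on the global loop thread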
@@ -113,7 +119,7 @@ def timedout(self) -> bool: with self._condition: return self._timed_out - def cancel(self): + def cancel(self) -> bool: """Cancel the future if possible. Returns True if the future was cancelled, False otherwise. A future cannot be @@ -147,7 +153,12 @@ def cancel(self): self._invoke_callbacks() return True - def result(self, timeout=None): + if TYPE_CHECKING: + + def __get_result(self) -> T: + ... + + def result(self, timeout: Optional[float] = None) -> T: """Return the result of the call that the future represents. Args: @@ -186,7 +197,9 @@ def result(self, timeout=None): # Break a reference cycle with the exception in self._exception self = None - def _invoke_callbacks(self): + _done_callbacks: list[Callable[[Self], object]] + + def _invoke_callbacks(self) -> None: """ Invoke our done callbacks and clean up cancel scopes and cancel callbacks. Fixes a memory leak that hung on to Call objects, @@ -206,7 +219,7 @@ def _invoke_callbacks(self): self._cancel_callbacks = [] if self._cancel_scope: - self._cancel_scope._callbacks = [] + setattr(self._cancel_scope, "_callbacks", []) self._cancel_scope = None @@ -216,16 +229,21 @@ class Call(Generic[T]): A deferred function call. """ - future: Future - fn: Callable[..., T] - args: Tuple - kwargs: Dict[str, Any] + future: Future[T] + fn: "_SyncOrAsyncCallable[..., T]" + args: tuple[Any, ...] + kwargs: dict[str, Any] context: contextvars.Context - timeout: float + timeout: Optional[float] runner: Optional["Portal"] = None @classmethod - def new(cls, __fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> "Call[T]": + def new( + cls, + __fn: _SyncOrAsyncCallable[P, T], + *args: P.args, + **kwargs: P.kwargs, + ) -> Self: return cls( future=Future(name=getattr(__fn, "__name__", str(__fn))), fn=__fn, @@ -255,7 +273,7 @@ def set_runner(self, portal: "Portal") -> None: self.runner = portal - def run(self) -> Optional[Awaitable[T]]: + def run(self) -> Optional[Awaitable[None]]: """ Execute the call and place the result on the future. 
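[Illustrative sketch — not part of the patch] With the fields typed, Call.new on a plain function infers Call[int] backed by a Future[int], and invoking the call directly in a sync context yields the result:

from prefect._internal.concurrency.calls import Call

def add(x: int, y: int) -> int:
    return x + y

call = Call.new(add, 1, 2)        # Call[int]; call.future is a Future[int]
assert call() == 3                # sync function, sync context: a plain int back
assert call.future.result() == 3  # the result also lands on the future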
@@ -337,7 +355,7 @@ def timedout(self) -> bool: def cancel(self) -> bool: return self.future.cancel() - def _run_sync(self): + def _run_sync(self) -> Optional[Awaitable[T]]: cancel_scope = None try: with set_current_call(self): @@ -348,8 +366,8 @@ def _run_sync(self): # Forget this call's arguments in order to free up any memory # that may be referenced by them; after a call has happened, # there's no need to keep a reference to them - self.args = None - self.kwargs = None + with contextlib.suppress(AttributeError): + del self.args, self.kwargs # Return the coroutine for async execution if inspect.isawaitable(result): @@ -357,8 +375,10 @@ def _run_sync(self): except CancelledError: # Report cancellation + if TYPE_CHECKING: + assert cancel_scope is not None if cancel_scope.timedout(): - self.future._timed_out = True + setattr(self.future, "_timed_out", True) self.future.cancel() elif cancel_scope.cancelled(): self.future.cancel() @@ -374,8 +394,8 @@ def _run_sync(self): self.future.set_result(result) # noqa: F821 logger.debug("Finished call %r", self) # noqa: F821 - async def _run_async(self, coro): - cancel_scope = None + async def _run_async(self, coro: Awaitable[T]) -> None: + cancel_scope = result = None try: with set_current_call(self): with self.future.enforce_async_deadline() as cancel_scope: @@ -385,12 +405,14 @@ async def _run_async(self, coro): # Forget this call's arguments in order to free up any memory # that may be referenced by them; after a call has happened, # there's no need to keep a reference to them - self.args = None - self.kwargs = None + with contextlib.suppress(AttributeError): + del self.args, self.kwargs except CancelledError: # Report cancellation + if TYPE_CHECKING: + assert cancel_scope is not None if cancel_scope.timedout(): - self.future._timed_out = True + setattr(self.future, "_timed_out", True) self.future.cancel() elif cancel_scope.cancelled(): self.future.cancel() @@ -403,10 +425,11 @@ async def _run_async(self, coro): # Prevent reference cycle in `exc` del self else: + # F821 ignored because Ruff gets confused about the del self above. self.future.set_result(result) # noqa: F821 logger.debug("Finished async call %r", self) # noqa: F821 - def __call__(self) -> T: + def __call__(self) -> Union[T, Awaitable[T]]: """ Execute the call and return its result. @@ -417,7 +440,7 @@ def __call__(self) -> T: # Return an awaitable if in an async context if coro is not None: - async def run_and_return_result(): + async def run_and_return_result() -> T: await coro return self.result() @@ -428,8 +451,9 @@ async def run_and_return_result(): def __repr__(self) -> str: name = getattr(self.fn, "__name__", str(self.fn)) - args, kwargs = self.args, self.kwargs - if args is None or kwargs is None: + try: + args, kwargs = self.args, self.kwargs + except AttributeError: call_args = "" else: call_args = ", ".join( @@ -450,7 +474,7 @@ class Portal(abc.ABC): """ @abc.abstractmethod - def submit(self, call: "Call") -> "Call": + def submit(self, call: "Call[T]") -> "Call[T]": """ Submit a call to execute elsewhere. 
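[Illustrative sketch — not part of the patch] The Union[T, Awaitable[T]] return type on Call.__call__ reflects that an async function hands back an awaitable, which callers in an async context must await:

import asyncio
from prefect._internal.concurrency.calls import Call

async def double(x: int) -> int:
    return x * 2

async def main() -> None:
    result = Call.new(double, 21)()
    if asyncio.iscoroutine(result):  # async fn: an awaitable comes back
        result = await result
    assert result == 42

asyncio.run(main())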
diff --git a/src/prefect/_internal/concurrency/cancellation.py b/src/prefect/_internal/concurrency/cancellation.py index 372f016d1216..57564ba4c9f5 100644 --- a/src/prefect/_internal/concurrency/cancellation.py +++ b/src/prefect/_internal/concurrency/cancellation.py @@ -12,14 +12,15 @@ import sys import threading import time -from typing import Callable, Dict, Optional, Type +from types import TracebackType +from typing import TYPE_CHECKING, Any, Callable, Optional, overload import anyio from prefect._internal.concurrency import logger from prefect._internal.concurrency.event_loop import get_running_loop -_THREAD_SHIELDS: Dict[threading.Thread, "ThreadShield"] = {} +_THREAD_SHIELDS: dict[threading.Thread, "ThreadShield"] = {} _THREAD_SHIELDS_LOCK = threading.Lock() @@ -42,14 +43,14 @@ def __init__(self, owner: threading.Thread): # Uses the Python implementation of the RLock instead of the C implementation # because we need to inspect `_count` directly to check if the lock is active # which is needed for delayed exception raising during alarms - self._lock = threading._RLock() + self._lock = threading._RLock() # type: ignore # yes, we want the private version self._exception = None self._owner = owner def __enter__(self) -> None: self._lock.__enter__() - def __exit__(self, *exc_info): + def __exit__(self, *exc_info: Any): retval = self._lock.__exit__(*exc_info) # Raise the exception if this is the last shield to exit in the owner thread @@ -65,14 +66,14 @@ def __exit__(self, *exc_info): return retval - def set_exception(self, exc: Exception): + def set_exception(self, exc: BaseException): self._exception = exc def active(self) -> bool: """ Returns true if the shield is active. """ - return self._lock._count > 0 + return getattr(self._lock, "_count") > 0 class CancelledError(asyncio.CancelledError): @@ -82,7 +83,7 @@ class CancelledError(asyncio.CancelledError): pass -def _get_thread_shield(thread) -> ThreadShield: +def _get_thread_shield(thread: threading.Thread) -> ThreadShield: with _THREAD_SHIELDS_LOCK: if thread not in _THREAD_SHIELDS: _THREAD_SHIELDS[thread] = ThreadShield(thread) @@ -139,7 +140,7 @@ def __init__( self._end_time = None self._timeout = timeout self._lock = threading.Lock() - self._callbacks = [] + self._callbacks: list[Callable[[], None]] = [] super().__init__() def __enter__(self): @@ -151,7 +152,9 @@ def __enter__(self): logger.debug("%r entered", self) return self - def __exit__(self, *_): + def __exit__( + self, exc_type: type[BaseException], exc_val: Exception, exc_tb: TracebackType + ) -> Optional[bool]: with self._lock: if not self._cancelled: self._completed = True @@ -195,7 +198,7 @@ def cancel(self, throw: bool = True) -> bool: throw the cancelled error. 
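[Illustrative sketch — not part of the patch, and assuming main-thread POSIX behavior] An expired deadline inside cancel_sync_after surfaces as this module's CancelledError:

import time
from prefect._internal.concurrency.cancellation import CancelledError, cancel_sync_after

try:
    with cancel_sync_after(timeout=0.1, name="demo-scope"):
        time.sleep(1)  # interrupted roughly 0.1s in (SIGALRM on the main thread)
except CancelledError:
    print("scope timed out")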
""" with self._lock: - if not self.started: + if not self._started: raise RuntimeError("Scope has not been entered.") if self._completed: @@ -247,7 +250,6 @@ def __init__( self, name: Optional[str] = None, timeout: Optional[float] = None ) -> None: super().__init__(name=name, timeout=timeout) - self.loop = None def __enter__(self): self.loop = asyncio.get_running_loop() @@ -262,7 +264,9 @@ def __enter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, exc_type: type[BaseException], exc_val: Exception, exc_tb: TracebackType + ) -> bool: if self._anyio_scope.cancel_called: # Mark as cancelled self.cancel(throw=False) @@ -310,7 +314,7 @@ def __init__( super().__init__(name, timeout) self.reason = reason or "null cancel scope" - def cancel(self): + def cancel(self, throw: bool = True) -> bool: logger.warning("%r cannot cancel %s.", self, self.reason) return False @@ -355,7 +359,7 @@ def __enter__(self): return self - def _sigalarm_to_error(self, *args): + def _sigalarm_to_error(self, *args: object) -> None: logger.debug("%r captured alarm raising as cancelled error", self) if self.cancel(throw=False): shield = _get_thread_shield(threading.main_thread()) @@ -365,11 +369,13 @@ def _sigalarm_to_error(self, *args): else: raise CancelledError() - def __exit__(self, *_): + def __exit__(self, *_: Any) -> Optional[bool]: retval = super().__exit__(*_) if self.timeout is not None: # Restore the previous timer + if TYPE_CHECKING: + assert self._previous_timer is not None signal.setitimer(signal.ITIMER_REAL, *self._previous_timer) # Restore the previous signal handler @@ -417,7 +423,7 @@ def __enter__(self): return self - def __exit__(self, *_): + def __exit__(self, *_: Any) -> Optional[bool]: retval = super().__exit__(*_) self._event.set() if self._enforcer_thread: @@ -466,7 +472,17 @@ def cancel(self, throw: bool = True): return True -def get_deadline(timeout: Optional[float]): +@overload +def get_deadline(timeout: float) -> float: + ... + + +@overload +def get_deadline(timeout: None) -> None: + ... + + +def get_deadline(timeout: Optional[float]) -> Optional[float]: """ Compute an deadline given a timeout. @@ -572,7 +588,7 @@ def cancel_sync_after(timeout: Optional[float], name: Optional[str] = None): yield scope -def _send_exception_to_thread(thread: threading.Thread, exc_type: Type[BaseException]): +def _send_exception_to_thread(thread: threading.Thread, exc_type: type[BaseException]): """ Raise an exception in a thread. diff --git a/src/prefect/_internal/concurrency/event_loop.py b/src/prefect/_internal/concurrency/event_loop.py index b3a8c0026056..5b690d53f01b 100644 --- a/src/prefect/_internal/concurrency/event_loop.py +++ b/src/prefect/_internal/concurrency/event_loop.py @@ -5,7 +5,8 @@ import asyncio import concurrent.futures import functools -from typing import Awaitable, Callable, Coroutine, Optional, TypeVar +from collections.abc import Coroutine +from typing import Any, Callable, Optional, TypeVar from typing_extensions import ParamSpec @@ -13,7 +14,7 @@ T = TypeVar("T") -def get_running_loop() -> Optional[asyncio.BaseEventLoop]: +def get_running_loop() -> Optional[asyncio.AbstractEventLoop]: """ Get the current running loop. @@ -30,7 +31,7 @@ def call_soon_in_loop( __fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs, -) -> concurrent.futures.Future: +) -> concurrent.futures.Future[T]: """ Run a synchronous call in an event loop's thread from another thread. 
@@ -38,7 +39,7 @@ def call_soon_in_loop( Returns a future that can be used to retrieve the result of the call. """ - future = concurrent.futures.Future() + future: concurrent.futures.Future[T] = concurrent.futures.Future() @functools.wraps(__fn) def wrapper() -> None: @@ -62,8 +63,8 @@ def wrapper() -> None: async def run_coroutine_in_loop_from_async( - __loop: asyncio.AbstractEventLoop, __coro: Coroutine -) -> Awaitable: + __loop: asyncio.AbstractEventLoop, __coro: Coroutine[Any, Any, T] +) -> T: """ Run an asynchronous call in an event loop from an asynchronous context. diff --git a/src/prefect/_internal/concurrency/threads.py b/src/prefect/_internal/concurrency/threads.py index 301cc386d513..c3f5e4589428 100644 --- a/src/prefect/_internal/concurrency/threads.py +++ b/src/prefect/_internal/concurrency/threads.py @@ -8,7 +8,9 @@ import itertools import queue import threading -from typing import List, Optional +from typing import Any, Optional + +from typing_extensions import TypeVar from prefect._internal.concurrency import logger from prefect._internal.concurrency.calls import Call, Portal @@ -16,6 +18,8 @@ from prefect._internal.concurrency.event_loop import get_running_loop from prefect._internal.concurrency.primitives import Event +T = TypeVar("T", infer_variance=True) + class WorkerThread(Portal): """ @@ -33,7 +37,7 @@ def __init__( self.thread = threading.Thread( name=name, daemon=daemon, target=self._entrypoint ) - self._queue = queue.Queue() + self._queue: queue.Queue[Optional[Call[Any]]] = queue.Queue() self._run_once: bool = run_once self._started: bool = False self._submitted_count: int = 0 @@ -42,7 +46,7 @@ def __init__( if not daemon: atexit.register(self.shutdown) - def start(self): + def start(self) -> None: """ Start the worker thread. """ @@ -51,7 +55,7 @@ def start(self): self._started = True self.thread.start() - def submit(self, call: Call) -> Call: + def submit(self, call: Call[T]) -> Call[T]: if self._submitted_count > 0 and self._run_once: raise RuntimeError( "Worker configured to only run once. A call has already been submitted." @@ -83,7 +87,7 @@ def shutdown(self) -> None: def name(self) -> str: return self.thread.name - def _entrypoint(self): + def _entrypoint(self) -> None: """ Entrypoint for the thread. """ @@ -129,12 +133,14 @@ def __init__( self.thread = threading.Thread( name=name, daemon=daemon, target=self._entrypoint ) - self._ready_future = concurrent.futures.Future() + self._ready_future: concurrent.futures.Future[ + bool + ] = concurrent.futures.Future() self._loop: Optional[asyncio.AbstractEventLoop] = None self._shutdown_event: Event = Event() self._run_once: bool = run_once self._submitted_count: int = 0 - self._on_shutdown: List[Call] = [] + self._on_shutdown: list[Call[Any]] = [] self._lock = threading.Lock() if not daemon: @@ -149,7 +155,7 @@ def start(self): self.thread.start() self._ready_future.result() - def submit(self, call: Call) -> Call: + def submit(self, call: Call[T]) -> Call[T]: if self._loop is None: self.start() @@ -167,6 +173,7 @@ def submit(self, call: Call) -> Call: call.set_runner(self) # Submit the call to the event loop + assert self._loop is not None asyncio.run_coroutine_threadsafe(self._run_call(call), self._loop) self._submitted_count += 1 @@ -180,15 +187,16 @@ def shutdown(self) -> None: Shutdown the worker thread. Does not wait for the thread to stop. 
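[Illustrative sketch — not part of the patch, assuming submit starts the thread on first use, as in the current implementation] WorkerThread is a Portal, and the Call[T]-preserving submit keeps the result type:

from prefect._internal.concurrency.calls import Call
from prefect._internal.concurrency.threads import WorkerThread

worker = WorkerThread(run_once=True)
call = worker.submit(Call.new(sum, [1, 2, 3]))  # Call[int]
assert call.result() == 6  # blocks until the worker has run the call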
""" with self._lock: - if self._shutdown_event is None: - return - self._shutdown_event.set() @property def name(self) -> str: return self.thread.name + @property + def running(self) -> bool: + return not self._shutdown_event.is_set() + def _entrypoint(self): """ Entrypoint for the thread. @@ -218,12 +226,12 @@ async def _run_until_shutdown(self): # Empty the list to allow calls to be garbage collected. Issue #10338. self._on_shutdown = [] - async def _run_call(self, call: Call) -> None: + async def _run_call(self, call: Call[Any]) -> None: task = call.run() if task is not None: await task - def add_shutdown_call(self, call: Call) -> None: + def add_shutdown_call(self, call: Call[Any]) -> None: self._on_shutdown.append(call) def __enter__(self): @@ -235,9 +243,9 @@ def __exit__(self, *_): # the GLOBAL LOOP is used for background services, like logs -GLOBAL_LOOP: Optional[EventLoopThread] = None +_global_loop: Optional[EventLoopThread] = None # the RUN SYNC LOOP is used exclusively for running async functions in a sync context via asyncutils.run_sync -RUN_SYNC_LOOP: Optional[EventLoopThread] = None +_run_sync_loop: Optional[EventLoopThread] = None def get_global_loop() -> EventLoopThread: @@ -246,29 +254,29 @@ def get_global_loop() -> EventLoopThread: Creates a new one if there is not one available. """ - global GLOBAL_LOOP + global _global_loop # Create a new worker on first call or if the existing worker is dead if ( - GLOBAL_LOOP is None - or not GLOBAL_LOOP.thread.is_alive() - or GLOBAL_LOOP._shutdown_event.is_set() + _global_loop is None + or not _global_loop.thread.is_alive() + or not _global_loop.running ): - GLOBAL_LOOP = EventLoopThread(daemon=True, name="GlobalEventLoopThread") - GLOBAL_LOOP.start() + _global_loop = EventLoopThread(daemon=True, name="GlobalEventLoopThread") + _global_loop.start() - return GLOBAL_LOOP + return _global_loop def in_global_loop() -> bool: """ Check if called from the global loop. """ - if GLOBAL_LOOP is None: + if _global_loop is None: # Avoid creating a global loop if there isn't one return False - return get_global_loop()._loop == get_running_loop() + return getattr(get_global_loop(), "_loop") == get_running_loop() def get_run_sync_loop() -> EventLoopThread: @@ -277,29 +285,29 @@ def get_run_sync_loop() -> EventLoopThread: Creates a new one if there is not one available. """ - global RUN_SYNC_LOOP + global _run_sync_loop # Create a new worker on first call or if the existing worker is dead if ( - RUN_SYNC_LOOP is None - or not RUN_SYNC_LOOP.thread.is_alive() - or RUN_SYNC_LOOP._shutdown_event.is_set() + _run_sync_loop is None + or not _run_sync_loop.thread.is_alive() + or not _run_sync_loop.running ): - RUN_SYNC_LOOP = EventLoopThread(daemon=True, name="RunSyncEventLoopThread") - RUN_SYNC_LOOP.start() + _run_sync_loop = EventLoopThread(daemon=True, name="RunSyncEventLoopThread") + _run_sync_loop.start() - return RUN_SYNC_LOOP + return _run_sync_loop def in_run_sync_loop() -> bool: """ Check if called from the global loop. 
""" - if RUN_SYNC_LOOP is None: + if _run_sync_loop is None: # Avoid creating a global loop if there isn't one return False - return get_run_sync_loop()._loop == get_running_loop() + return getattr(get_run_sync_loop(), "_loop") == get_running_loop() def wait_for_global_loop_exit(timeout: Optional[float] = None) -> None: diff --git a/src/prefect/_internal/concurrency/waiters.py b/src/prefect/_internal/concurrency/waiters.py index 3925e3b25691..07522992100d 100644 --- a/src/prefect/_internal/concurrency/waiters.py +++ b/src/prefect/_internal/concurrency/waiters.py @@ -10,7 +10,8 @@ import queue import threading from collections import deque -from typing import Awaitable, Generic, List, Optional, TypeVar, Union +from collections.abc import Awaitable +from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from weakref import WeakKeyDictionary import anyio @@ -24,12 +25,12 @@ # Waiters are stored in a stack for each thread -_WAITERS_BY_THREAD: "WeakKeyDictionary[threading.Thread, deque[Waiter]]" = ( +_WAITERS_BY_THREAD: "WeakKeyDictionary[threading.Thread, deque[Waiter[Any]]]" = ( WeakKeyDictionary() ) -def add_waiter_for_thread(waiter: "Waiter", thread: threading.Thread): +def add_waiter_for_thread(waiter: "Waiter[Any]", thread: threading.Thread) -> None: """ Add a waiter for a thread. """ @@ -62,7 +63,7 @@ def call_is_done(self) -> bool: return self._call.future.done() @abc.abstractmethod - def wait(self) -> Union[Awaitable[None], None]: + def wait(self) -> T: """ Wait for the call to finish. @@ -71,7 +72,7 @@ def wait(self) -> Union[Awaitable[None], None]: raise NotImplementedError() @abc.abstractmethod - def add_done_callback(self, callback: Call) -> Call: + def add_done_callback(self, callback: Call[Any]) -> None: """ Schedule a call to run when the waiter is done waiting. @@ -91,11 +92,11 @@ class SyncWaiter(Waiter[T]): def __init__(self, call: Call[T]) -> None: super().__init__(call=call) - self._queue: queue.Queue = queue.Queue() - self._done_callbacks = [] + self._queue: queue.Queue[Optional[Call[T]]] = queue.Queue() + self._done_callbacks: list[Call[Any]] = [] self._done_event = threading.Event() - def submit(self, call: Call): + def submit(self, call: Call[T]) -> Call[T]: """ Submit a callback to execute while waiting. 
""" @@ -109,7 +110,7 @@ def submit(self, call: Call): def _handle_waiting_callbacks(self): logger.debug("Waiter %r watching for callbacks", self) while True: - callback: Call = self._queue.get() + callback = self._queue.get() if callback is None: break @@ -130,13 +131,13 @@ def _handle_done_callbacks(self): if callback: callback.run() - def add_done_callback(self, callback: Call): + def add_done_callback(self, callback: Call[Any]) -> None: if self._done_event.is_set(): raise RuntimeError("Cannot add done callbacks to done waiters.") else: self._done_callbacks.append(callback) - def wait(self) -> T: + def wait(self) -> Call[T]: # Stop watching for work once the future is done self._call.future.add_done_callback(lambda _: self._queue.put_nowait(None)) self._call.future.add_done_callback(lambda _: self._done_event.set()) @@ -159,13 +160,13 @@ def __init__(self, call: Call[T]) -> None: # Delay instantiating loop and queue as there may not be a loop present yet self._loop: Optional[asyncio.AbstractEventLoop] = None - self._queue: Optional[asyncio.Queue] = None - self._early_submissions: List[Call] = [] - self._done_callbacks = [] + self._queue: Optional[asyncio.Queue[Optional[Call[T]]]] = None + self._early_submissions: list[Call[T]] = [] + self._done_callbacks: list[Call[Any]] = [] self._done_event = Event() self._done_waiting = False - def submit(self, call: Call): + def submit(self, call: Call[T]) -> Call[T]: """ Submit a callback to execute while waiting. """ @@ -180,11 +181,15 @@ def submit(self, call: Call): return call # We must put items in the queue from the event loop that owns it + if TYPE_CHECKING: + assert self._loop is not None call_soon_in_loop(self._loop, self._queue.put_nowait, call) return call - def _resubmit_early_submissions(self): - assert self._queue + def _resubmit_early_submissions(self) -> None: + if TYPE_CHECKING: + assert self._queue is not None + assert self._loop is not None for call in self._early_submissions: # We must put items in the queue from the event loop that owns it call_soon_in_loop(self._loop, self._queue.put_nowait, call) @@ -192,11 +197,11 @@ def _resubmit_early_submissions(self): async def _handle_waiting_callbacks(self): logger.debug("Waiter %r watching for callbacks", self) - tasks = [] + tasks: list[Awaitable[None]] = [] try: while True: - callback: Call = await self._queue.get() + callback = await self._queue.get() if callback is None: break @@ -228,12 +233,12 @@ async def _handle_done_callbacks(self): with anyio.CancelScope(shield=True): await self._run_done_callback(callback) - async def _run_done_callback(self, callback: Call): + async def _run_done_callback(self, callback: Call[Any]) -> None: coro = callback.run() if coro: await coro - def add_done_callback(self, callback: Call): + def add_done_callback(self, callback: Call[Any]) -> None: if self._done_event.is_set(): raise RuntimeError("Cannot add done callbacks to done waiters.") else: @@ -243,6 +248,8 @@ def _signal_stop_waiting(self): # Only send a `None` to the queue if the waiter is still blocked reading from # the queue. Otherwise, it's possible that the event loop is stopped. 
if not self._done_waiting: + assert self._loop is not None + assert self._queue is not None call_soon_in_loop(self._loop, self._queue.put_nowait, None) async def wait(self) -> Call[T]: diff --git a/src/prefect/_internal/pydantic/v1_schema.py b/src/prefect/_internal/pydantic/v1_schema.py index b94b0fc5973a..2fc409ae2625 100644 --- a/src/prefect/_internal/pydantic/v1_schema.py +++ b/src/prefect/_internal/pydantic/v1_schema.py @@ -6,7 +6,7 @@ from pydantic.v1 import BaseModel as V1BaseModel -def is_v1_model(v) -> bool: +def is_v1_model(v: typing.Any) -> bool: with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=pydantic.warnings.PydanticDeprecatedSince20 @@ -23,7 +23,7 @@ def is_v1_model(v) -> bool: return False -def is_v1_type(v) -> bool: +def is_v1_type(v: typing.Any) -> bool: if is_v1_model(v): return True diff --git a/src/prefect/_internal/pydantic/v2_schema.py b/src/prefect/_internal/pydantic/v2_schema.py index c9d7fa08b5e1..da22799c742b 100644 --- a/src/prefect/_internal/pydantic/v2_schema.py +++ b/src/prefect/_internal/pydantic/v2_schema.py @@ -16,7 +16,7 @@ from prefect._internal.pydantic.schemas import GenerateEmptySchemaForUserClasses -def is_v2_model(v) -> bool: +def is_v2_model(v: t.Any) -> bool: if isinstance(v, V2BaseModel): return True try: @@ -28,7 +28,7 @@ def is_v2_model(v) -> bool: return False -def is_v2_type(v) -> bool: +def is_v2_type(v: t.Any) -> bool: if is_v2_model(v): return True @@ -56,9 +56,9 @@ def process_v2_params( param: inspect.Parameter, *, position: int, - docstrings: t.Dict[str, str], - aliases: t.Dict, -) -> t.Tuple[str, t.Any, "pydantic.Field"]: + docstrings: dict[str, str], + aliases: dict[str, str], +) -> tuple[str, t.Any, t.Any]: """ Generate a sanitized name, type, and pydantic.Field for a given parameter. @@ -72,7 +72,7 @@ def process_v2_params( else: name = param.name - type_ = t.Any if param.annotation is inspect._empty else param.annotation + type_ = t.Any if param.annotation is inspect.Parameter.empty else param.annotation # Replace pendulum type annotations with our own so that they are pydantic compatible if type_ == pendulum.DateTime: @@ -95,12 +95,13 @@ def process_v2_params( def create_v2_schema( name_: str, model_cfg: t.Optional[ConfigDict] = None, - model_base: t.Optional[t.Type[V2BaseModel]] = None, - **model_fields, -): + model_base: t.Optional[type[V2BaseModel]] = None, + model_fields: t.Optional[dict[str, t.Any]] = None, +) -> dict[str, t.Any]: """ Create a pydantic v2 model and craft a v1 compatible schema from it. 
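[Illustrative sketch — not part of the patch] With create_v2_schema taking an explicit optional model_fields dict instead of **model_fields, field names can no longer collide with the name_, model_cfg, or model_base parameters; field tuples follow pydantic's create_model convention:

from prefect._internal.pydantic.v2_schema import create_v2_schema

schema = create_v2_schema(
    "ExampleModel",
    model_fields={"x": (int, 42), "name_": (str, "ok")},  # "name_" no longer clashes
)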
""" + model_fields = model_fields or {} model = create_model( name_, __config__=model_cfg, __base__=model_base, **model_fields ) diff --git a/src/prefect/cli/cloud/__init__.py b/src/prefect/cli/cloud/__init__.py index 5bc9e9ec02c6..f698fbd58160 100644 --- a/src/prefect/cli/cloud/__init__.py +++ b/src/prefect/cli/cloud/__init__.py @@ -41,7 +41,6 @@ ) from prefect.utilities.asyncutils import run_sync_in_worker_thread from prefect.utilities.collections import listrepr -from prefect.utilities.compat import raise_signal from pydantic import BaseModel diff --git a/src/prefect/context.py b/src/prefect/context.py index 69c14ce4fdb0..b82d214d4aff 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -355,14 +355,14 @@ class EngineContext(RunContext): persist_result: bool = Field(default_factory=get_default_persist_setting) # Counter for task calls allowing unique - task_run_dynamic_keys: Dict[str, int] = Field(default_factory=dict) + task_run_dynamic_keys: Dict[str, Union[str, int]] = Field(default_factory=dict) # Counter for flow pauses observed_flow_pauses: Dict[str, int] = Field(default_factory=dict) # Tracking for result from task runs in this flow run for dependency tracking # Holds the ID of the object returned by the task run and task run state - task_run_results: Mapping[int, State] = Field(default_factory=dict) + task_run_results: dict[int, State] = Field(default_factory=dict) # Events worker to emit events events: Optional[EventsWorker] = None diff --git a/src/prefect/deployments/flow_runs.py b/src/prefect/deployments/flow_runs.py index 8c66b5d87bf9..07971410410e 100644 --- a/src/prefect/deployments/flow_runs.py +++ b/src/prefect/deployments/flow_runs.py @@ -113,10 +113,8 @@ async def run_deployment( task_run_ctx = TaskRunContext.get() if as_subflow and (flow_run_ctx or task_run_ctx): # TODO: this logic can likely be simplified by using `Task.create_run` - from prefect.utilities.engine import ( - _dynamic_key_for_task_run, - collect_task_run_inputs, - ) + from prefect.utilities._engine import dynamic_key_for_task_run + from prefect.utilities.engine import collect_task_run_inputs # This was called from a flow. Link the flow run as a subflow. 
task_inputs = { @@ -143,7 +141,7 @@ async def run_deployment( else task_run_ctx.task_run.flow_run_id ) dynamic_key = ( - _dynamic_key_for_task_run(flow_run_ctx, dummy_task) + dynamic_key_for_task_run(flow_run_ctx, dummy_task) if flow_run_ctx else task_run_ctx.task_run.dynamic_key ) diff --git a/src/prefect/events/utilities.py b/src/prefect/events/utilities.py index 106f479090e6..6995e96dced8 100644 --- a/src/prefect/events/utilities.py +++ b/src/prefect/events/utilities.py @@ -3,7 +3,6 @@ from uuid import UUID import pendulum -from pydantic_extra_types.pendulum_dt import DateTime from .clients import ( AssertingEventsClient, @@ -20,7 +19,7 @@ def emit_event( event: str, resource: Dict[str, str], - occurred: Optional[DateTime] = None, + occurred: Optional[pendulum.DateTime] = None, related: Optional[Union[List[Dict[str, str]], List[RelatedResource]]] = None, payload: Optional[Dict[str, Any]] = None, id: Optional[UUID] = None, diff --git a/src/prefect/filesystems.py b/src/prefect/filesystems.py index 333665fd5679..97b7ee1e2e26 100644 --- a/src/prefect/filesystems.py +++ b/src/prefect/filesystems.py @@ -1,6 +1,7 @@ import abc import urllib.parse from pathlib import Path +from shutil import copytree from typing import Any, Dict, Optional import anyio @@ -13,7 +14,6 @@ ) from prefect.blocks.core import Block from prefect.utilities.asyncutils import run_sync_in_worker_thread, sync_compatible -from prefect.utilities.compat import copytree from prefect.utilities.filesystem import filter_files from ._internal.compatibility.migration import getattr_migration @@ -158,7 +158,7 @@ async def get_directory( copytree(from_path, local_path, dirs_exist_ok=True, ignore=ignore_func) async def _get_ignore_func(self, local_path: str, ignore_file: str): - with open(ignore_file, "r") as f: + with open(ignore_file) as f: ignore_patterns = f.readlines() included_files = filter_files(root=local_path, ignore_patterns=ignore_patterns) @@ -348,7 +348,7 @@ async def put_directory( included_files = None if ignore_file: - with open(ignore_file, "r") as f: + with open(ignore_file) as f: ignore_patterns = f.readlines() included_files = filter_files( diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index 43bf55e21c41..c37154a09cdf 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py @@ -74,6 +74,7 @@ ) from prefect.telemetry.run_telemetry import OTELSetter from prefect.types import KeyValueLabels +from prefect.utilities._engine import get_hook_name, resolve_custom_flow_run_name from prefect.utilities.annotations import NotSet from prefect.utilities.asyncutils import run_coro_as_sync from prefect.utilities.callables import ( @@ -83,8 +84,6 @@ ) from prefect.utilities.collections import visit_collection from prefect.utilities.engine import ( - _get_hook_name, - _resolve_custom_flow_run_name, capture_sigterm, link_state_to_result, propose_state, @@ -572,7 +571,7 @@ def call_hooks(self, state: Optional[State] = None): hooks = None for hook in hooks or []: - hook_name = _get_hook_name(hook) + hook_name = get_hook_name(hook) try: self.logger.info( @@ -635,7 +634,7 @@ def setup_run_context(self, client: Optional[SyncPrefectClient] = None): # update the flow run name if necessary if not self._flow_run_name_set and self.flow.flow_run_name: - flow_run_name = _resolve_custom_flow_run_name( + flow_run_name = resolve_custom_flow_run_name( flow=self.flow, parameters=self.parameters ) self.client.set_flow_run_name( @@ -1146,7 +1145,7 @@ async def call_hooks(self, state: Optional[State] = None): hooks 
= None for hook in hooks or []: - hook_name = _get_hook_name(hook) + hook_name = get_hook_name(hook) try: self.logger.info( @@ -1209,7 +1208,7 @@ async def setup_run_context(self, client: Optional[PrefectClient] = None): # update the flow run name if necessary if not self._flow_run_name_set and self.flow.flow_run_name: - flow_run_name = _resolve_custom_flow_run_name( + flow_run_name = resolve_custom_flow_run_name( flow=self.flow, parameters=self.parameters ) await self.client.set_flow_run_name( diff --git a/src/prefect/runner/submit.py b/src/prefect/runner/submit.py index f57d9ccc10cf..ec42a4029a79 100644 --- a/src/prefect/runner/submit.py +++ b/src/prefect/runner/submit.py @@ -42,11 +42,8 @@ async def _submit_flow_to_runner( Returns: A `FlowRun` object representing the flow run that was submitted. """ - from prefect.utilities.engine import ( - _dynamic_key_for_task_run, - collect_task_run_inputs, - resolve_inputs, - ) + from prefect.utilities._engine import dynamic_key_for_task_run + from prefect.utilities.engine import collect_task_run_inputs, resolve_inputs async with get_client() as client: if not retry_failed_submissions: @@ -67,7 +64,7 @@ async def _submit_flow_to_runner( parent_flow_run_context.flow_run.id if parent_flow_run_context else None ), dynamic_key=( - _dynamic_key_for_task_run(parent_flow_run_context, dummy_task) + dynamic_key_for_task_run(parent_flow_run_context, dummy_task) if parent_flow_run_context else str(uuid.uuid4()) ), diff --git a/src/prefect/task_engine.py b/src/prefect/task_engine.py index fa879d4265d9..a07f2108c95a 100644 --- a/src/prefect/task_engine.py +++ b/src/prefect/task_engine.py @@ -82,12 +82,12 @@ ) from prefect.telemetry.run_telemetry import RunTelemetry from prefect.transactions import IsolationLevel, Transaction, transaction +from prefect.utilities._engine import get_hook_name from prefect.utilities.annotations import NotSet from prefect.utilities.asyncutils import run_coro_as_sync from prefect.utilities.callables import call_with_parameters, parameters_to_args_kwargs from prefect.utilities.collections import visit_collection from prefect.utilities.engine import ( - _get_hook_name, emit_task_run_state_change_event, link_state_to_result, resolve_to_final_result, @@ -196,11 +196,11 @@ def _resolve_parameters(self): self.parameters = resolved_parameters def _set_custom_task_run_name(self): - from prefect.utilities.engine import _resolve_custom_task_run_name + from prefect.utilities._engine import resolve_custom_task_run_name # update the task run name if necessary if not self._task_name_set and self.task.task_run_name: - task_run_name = _resolve_custom_task_run_name( + task_run_name = resolve_custom_task_run_name( task=self.task, parameters=self.parameters or {} ) @@ -354,7 +354,7 @@ def call_hooks(self, state: Optional[State] = None): hooks = None for hook in hooks or []: - hook_name = _get_hook_name(hook) + hook_name = get_hook_name(hook) try: self.logger.info( @@ -891,7 +891,7 @@ async def call_hooks(self, state: Optional[State] = None): hooks = None for hook in hooks or []: - hook_name = _get_hook_name(hook) + hook_name = get_hook_name(hook) try: self.logger.info( diff --git a/src/prefect/tasks.py b/src/prefect/tasks.py index 784deec2813d..fec3d53383f0 100644 --- a/src/prefect/tasks.py +++ b/src/prefect/tasks.py @@ -21,6 +21,7 @@ List, NoReturn, Optional, + Protocol, Set, Tuple, Type, @@ -31,7 +32,7 @@ ) from uuid import UUID, uuid4 -from typing_extensions import Literal, ParamSpec +from typing_extensions import Literal, ParamSpec, Self, TypeIs 
import prefect.states from prefect.cache_policies import DEFAULT, NONE, CachePolicy @@ -223,6 +224,16 @@ def _generate_task_key(fn: Callable[..., Any]) -> str: return f"{qualname}-{code_hash}" +class TaskRunNameCallbackWithParameters(Protocol): + @classmethod + def is_callback_with_parameters(cls, callable: Callable[..., str]) -> TypeIs[Self]: + sig = inspect.signature(callable) + return "parameters" in sig.parameters + + def __call__(self, parameters: dict[str, Any]) -> str: + ... + + class Task(Generic[P, R]): """ A Prefect task definition. @@ -311,7 +322,7 @@ def __init__( ] = None, cache_expiration: Optional[datetime.timedelta] = None, task_run_name: Optional[ - Union[Callable[[], str], Callable[[Dict[str, Any]], str], str] + Union[Callable[[], str], TaskRunNameCallbackWithParameters, str] ] = None, retries: Optional[int] = None, retry_delay_seconds: Optional[ @@ -534,7 +545,9 @@ def with_options( Callable[["TaskRunContext", Dict[str, Any]], Optional[str]] ] = None, task_run_name: Optional[ - Union[Callable[[], str], Callable[[Dict[str, Any]], str], str, Type[NotSet]] + Union[ + Callable[[], str], TaskRunNameCallbackWithParameters, str, Type[NotSet] + ] ] = NotSet, cache_expiration: Optional[datetime.timedelta] = None, retries: Union[int, Type[NotSet]] = NotSet, @@ -732,10 +745,8 @@ async def create_run( extra_task_inputs: Optional[Dict[str, Set[TaskRunInput]]] = None, deferred: bool = False, ) -> TaskRun: - from prefect.utilities.engine import ( - _dynamic_key_for_task_run, - collect_task_run_inputs_sync, - ) + from prefect.utilities._engine import dynamic_key_for_task_run + from prefect.utilities.engine import collect_task_run_inputs_sync if flow_run_context is None: flow_run_context = FlowRunContext.get() @@ -751,7 +762,7 @@ async def create_run( dynamic_key = f"{self.task_key}-{str(uuid4().hex)}" task_run_name = self.name else: - dynamic_key = _dynamic_key_for_task_run( + dynamic_key = dynamic_key_for_task_run( context=flow_run_context, task=self ) task_run_name = f"{self.name}-{dynamic_key}" @@ -835,10 +846,8 @@ async def create_local_run( extra_task_inputs: Optional[Dict[str, Set[TaskRunInput]]] = None, deferred: bool = False, ) -> TaskRun: - from prefect.utilities.engine import ( - _dynamic_key_for_task_run, - collect_task_run_inputs_sync, - ) + from prefect.utilities._engine import dynamic_key_for_task_run + from prefect.utilities.engine import collect_task_run_inputs_sync if flow_run_context is None: flow_run_context = FlowRunContext.get() @@ -854,7 +863,7 @@ async def create_local_run( dynamic_key = f"{self.task_key}-{str(uuid4().hex)}" task_run_name = self.name else: - dynamic_key = _dynamic_key_for_task_run( + dynamic_key = dynamic_key_for_task_run( context=flow_run_context, task=self, stable=False ) task_run_name = f"{self.name}-{dynamic_key[:3]}" @@ -1588,7 +1597,7 @@ def task( ] = None, cache_expiration: Optional[datetime.timedelta] = None, task_run_name: Optional[ - Union[Callable[[], str], Callable[[Dict[str, Any]], str], str] + Union[Callable[[], str], TaskRunNameCallbackWithParameters, str] ] = None, retries: int = 0, retry_delay_seconds: Union[ @@ -1629,7 +1638,7 @@ def task( ] = None, cache_expiration: Optional[datetime.timedelta] = None, task_run_name: Optional[ - Union[Callable[[], str], Callable[[Dict[str, Any]], str], str] + Union[Callable[[], str], TaskRunNameCallbackWithParameters, str] ] = None, retries: Optional[int] = None, retry_delay_seconds: Union[ diff --git a/src/prefect/transactions.py b/src/prefect/transactions.py index f274547820f7..a426999fc2c0 
100644 --- a/src/prefect/transactions.py +++ b/src/prefect/transactions.py @@ -32,11 +32,9 @@ ResultStore, get_result_store, ) +from prefect.utilities._engine import get_hook_name from prefect.utilities.annotations import NotSet from prefect.utilities.collections import AutoEnum -from prefect.utilities.engine import ( - _get_hook_name, # pyright: ignore[reportPrivateUsage] -) class IsolationLevel(AutoEnum): @@ -357,7 +355,7 @@ def commit(self) -> bool: return False def run_hook(self, hook: Callable[..., Any], hook_type: str) -> None: - hook_name = _get_hook_name(hook) + hook_name = get_hook_name(hook) # Undocumented way to disable logging for a hook. Subject to change. should_log = getattr(hook, "log_on_run", True) diff --git a/src/prefect/utilities/_engine.py b/src/prefect/utilities/_engine.py new file mode 100644 index 000000000000..c3a99676dcc4 --- /dev/null +++ b/src/prefect/utilities/_engine.py @@ -0,0 +1,96 @@ +"""Internal engine utilities""" + + +from collections.abc import Callable +from functools import partial +from typing import TYPE_CHECKING, Any, Union +from uuid import uuid4 + +from prefect.context import FlowRunContext +from prefect.flows import Flow +from prefect.tasks import Task, TaskRunNameCallbackWithParameters + + +def dynamic_key_for_task_run( + context: FlowRunContext, task: "Task[..., Any]", stable: bool = True +) -> Union[int, str]: + if ( + stable is False or context.detached + ): # this task is running on remote infrastructure + return str(uuid4()) + elif context.flow_run is None: # this is an autonomous task run + context.task_run_dynamic_keys[task.task_key] = getattr( + task, "dynamic_key", str(uuid4()) + ) + + elif task.task_key not in context.task_run_dynamic_keys: + context.task_run_dynamic_keys[task.task_key] = 0 + else: + dynamic_key = context.task_run_dynamic_keys[task.task_key] + if TYPE_CHECKING: + assert isinstance(dynamic_key, int) + context.task_run_dynamic_keys[task.task_key] = dynamic_key + 1 + + return context.task_run_dynamic_keys[task.task_key] + + +def resolve_custom_flow_run_name( + flow: "Flow[..., Any]", parameters: dict[str, Any] +) -> str: + if callable(flow.flow_run_name): + flow_run_name = flow.flow_run_name() + if not TYPE_CHECKING: + if not isinstance(flow_run_name, str): + raise TypeError( + f"Callable {flow.flow_run_name} for 'flow_run_name' returned type" + f" {type(flow_run_name).__name__} but a string is required." + ) + elif isinstance(flow.flow_run_name, str): + flow_run_name = flow.flow_run_name.format(**parameters) + else: + raise TypeError( + "Expected string or callable for 'flow_run_name'; got" + f" {type(flow.flow_run_name).__name__} instead." + ) + + return flow_run_name + + +def resolve_custom_task_run_name( + task: "Task[..., Any]", parameters: dict[str, Any] +) -> str: + if callable(task.task_run_name): + # If the callable accepts a 'parameters' kwarg, pass the entire parameters dict + if TaskRunNameCallbackWithParameters.is_callback_with_parameters( + task.task_run_name + ): + task_run_name = task.task_run_name(parameters=parameters) + else: + # If it doesn't expect parameters, call it without arguments + task_run_name = task.task_run_name() + + if not TYPE_CHECKING: + if not isinstance(task_run_name, str): + raise TypeError( + f"Callable {task.task_run_name} for 'task_run_name' returned type" + f" {type(task_run_name).__name__} but a string is required." 
+ ) + elif isinstance(task.task_run_name, str): + task_run_name = task.task_run_name.format(**parameters) + else: + raise TypeError( + "Expected string or callable for 'task_run_name'; got" + f" {type(task.task_run_name).__name__} instead." + ) + + return task_run_name + + +def get_hook_name(hook: Callable[..., Any]) -> str: + return ( + hook.__name__ + if hasattr(hook, "__name__") + else ( + hook.func.__name__ if isinstance(hook, partial) else hook.__class__.__name__ + ) + ) diff --git a/src/prefect/utilities/annotations.py b/src/prefect/utilities/annotations.py index 6d9bd73ed475..2e264f334d74 100644 --- a/src/prefect/utilities/annotations.py +++ b/src/prefect/utilities/annotations.py @@ -1,33 +1,40 @@ import warnings -from abc import ABC -from collections import namedtuple -from typing import Generic, TypeVar +from operator import itemgetter +from typing import Any, cast -T = TypeVar("T") +from typing_extensions import Self, TypeVar +T = TypeVar("T", infer_variance=True) -class BaseAnnotation( - namedtuple("BaseAnnotation", field_names="value"), ABC, Generic[T] -): + +class BaseAnnotation(tuple[T]): """ Base class for Prefect annotation types. - Inherits from `namedtuple` for unpacking support in another tools. + Inherits from `tuple` for unpacking support in other tools. """ + __slots__ = () + + def __new__(cls, value: T) -> Self: + return super().__new__(cls, (value,)) + + # use itemgetter to minimise overhead, just like namedtuple generated code would + value: T = cast(T, property(itemgetter(0))) + def unwrap(self) -> T: - return self.value + return self[0] - def rewrap(self, value: T) -> "BaseAnnotation[T]": + def rewrap(self, value: T) -> Self: return type(self)(value) - def __eq__(self, other: "BaseAnnotation[T]") -> bool: + def __eq__(self, other: Any) -> bool: if type(self) is not type(other): return False - return self.unwrap() == other.unwrap() + return super().__eq__(other) def __repr__(self) -> str: - return f"{type(self).__name__}({self.value!r})" + return f"{type(self).__name__}({self[0]!r})" class unmapped(BaseAnnotation[T]): @@ -38,9 +45,9 @@ class unmapped(BaseAnnotation[T]): operation instead of being split. """ - def __getitem__(self, _) -> T: + def __getitem__(self, _: object) -> T: # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] # Internally, this acts as an infinite array where all items are the same value - return self.unwrap() + return super().__getitem__(0) class allow_failure(BaseAnnotation[T]): @@ -87,14 +94,14 @@ def unquote(self) -> T: # Backwards compatibility stub for `Quote` class -class Quote(quote): - def __init__(self, expr): +class Quote(quote[T]): + def __new__(cls, expr: T) -> Self: warnings.warn( "Use of `Quote` is deprecated. 
Use `quote` instead.", DeprecationWarning, stacklevel=2, ) - super().__init__(expr) + return super().__new__(cls, expr) class NotSet: diff --git a/src/prefect/utilities/asyncutils.py b/src/prefect/utilities/asyncutils.py index ce5a0229b049..e2c1a2472eec 100644 --- a/src/prefect/utilities/asyncutils.py +++ b/src/prefect/utilities/asyncutils.py @@ -6,24 +6,12 @@ import inspect import threading import warnings -from concurrent.futures import ThreadPoolExecutor -from contextlib import asynccontextmanager -from contextvars import ContextVar, copy_context +from collections.abc import AsyncGenerator, Awaitable, Coroutine +from contextlib import AbstractAsyncContextManager, asynccontextmanager +from contextvars import ContextVar from functools import partial, wraps -from typing import ( - Any, - AsyncGenerator, - Awaitable, - Callable, - Coroutine, - Dict, - List, - Optional, - TypeVar, - Union, - cast, - overload, -) +from logging import Logger +from typing import TYPE_CHECKING, Any, Callable, NoReturn, Optional, Union, overload from uuid import UUID, uuid4 import anyio @@ -31,9 +19,18 @@ import anyio.from_thread import anyio.to_thread import sniffio -from typing_extensions import Literal, ParamSpec, TypeGuard +from typing_extensions import ( + Literal, + ParamSpec, + Self, + TypeAlias, + TypeGuard, + TypeVar, + TypeVarTuple, + Unpack, +) -from prefect._internal.concurrency.api import _cast_to_call, from_sync +from prefect._internal.concurrency.api import cast_to_call, from_sync from prefect._internal.concurrency.threads import ( get_run_sync_loop, in_run_sync_loop, @@ -42,62 +39,65 @@ T = TypeVar("T") P = ParamSpec("P") -R = TypeVar("R") +R = TypeVar("R", infer_variance=True) F = TypeVar("F", bound=Callable[..., Any]) Async = Literal[True] Sync = Literal[False] A = TypeVar("A", Async, Sync, covariant=True) +PosArgsT = TypeVarTuple("PosArgsT") + +_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[R, Awaitable[R]]] # Global references to prevent garbage collection for `add_event_loop_shutdown_callback` -EVENT_LOOP_GC_REFS = {} +EVENT_LOOP_GC_REFS: dict[int, AsyncGenerator[None, Any]] = {} -PREFECT_THREAD_LIMITER: Optional[anyio.CapacityLimiter] = None RUNNING_IN_RUN_SYNC_LOOP_FLAG = ContextVar("running_in_run_sync_loop", default=False) RUNNING_ASYNC_FLAG = ContextVar("run_async", default=False) -BACKGROUND_TASKS: set[asyncio.Task] = set() -background_task_lock = threading.Lock() +BACKGROUND_TASKS: set[asyncio.Task[Any]] = set() +background_task_lock: threading.Lock = threading.Lock() # Thread-local storage to keep track of worker thread state _thread_local = threading.local() -logger = get_logger() +logger: Logger = get_logger() + + +_prefect_thread_limiter: Optional[anyio.CapacityLimiter] = None -def get_thread_limiter(): - global PREFECT_THREAD_LIMITER +def get_thread_limiter() -> anyio.CapacityLimiter: + global _prefect_thread_limiter - if PREFECT_THREAD_LIMITER is None: - PREFECT_THREAD_LIMITER = anyio.CapacityLimiter(250) + if _prefect_thread_limiter is None: + _prefect_thread_limiter = anyio.CapacityLimiter(250) - return PREFECT_THREAD_LIMITER + return _prefect_thread_limiter def is_async_fn( - func: Union[Callable[P, R], Callable[P, Awaitable[R]]], + func: _SyncOrAsyncCallable[P, R], ) -> TypeGuard[Callable[P, Awaitable[R]]]: """ Returns `True` if a function returns a coroutine. 
    See https://github.com/microsoft/pyright/issues/2142 for an example use
    """
-    while hasattr(func, "__wrapped__"):
-        func = func.__wrapped__
-
+    func = inspect.unwrap(func)
     return asyncio.iscoroutinefunction(func)


-def is_async_gen_fn(func):
+def is_async_gen_fn(
+    func: Callable[P, Any],
+) -> TypeGuard[Callable[P, AsyncGenerator[Any, Any]]]:
     """
     Returns `True` if a function is an async generator.
     """
-    while hasattr(func, "__wrapped__"):
-        func = func.__wrapped__
-
+    func = inspect.unwrap(func)
     return inspect.isasyncgenfunction(func)


-def create_task(coroutine: Coroutine) -> asyncio.Task:
+def create_task(coroutine: Coroutine[Any, Any, R]) -> asyncio.Task[R]:
     """
     Replacement for asyncio.create_task that will ensure that tasks aren't
     garbage collected before they complete. Allows for "fire and forget"
@@ -123,68 +123,32 @@ def create_task(coroutine: Coroutine) -> asyncio.Task:
     return task


-def _run_sync_in_new_thread(coroutine: Coroutine[Any, Any, T]) -> T:
-    """
-    Note: this is an OLD implementation of `run_coro_as_sync` which liberally created
-    new threads and new loops. This works, but prevents sharing any objects
-    across coroutines, in particular httpx clients, which are very expensive to
-    instantiate.
-
-    This is here for historical purposes and can be removed if/when it is no
-    longer needed for reference.
-
-    ---
-
-    Runs a coroutine from a synchronous context. A thread will be spawned to run
-    the event loop if necessary, which allows coroutines to run in environments
-    like Jupyter notebooks where the event loop runs on the main thread.
-
-    Args:
-        coroutine: The coroutine to run.
-
-    Returns:
-        The return value of the coroutine.
-
-    Example:
-        Basic usage:
-        ```python
-        async def my_async_function(x: int) -> int:
-            return x + 1
-
-        run_sync(my_async_function(1))
-        ```
-    """
+@overload
+def run_coro_as_sync(
+    coroutine: Coroutine[Any, Any, R],
+    *,
+    force_new_thread: bool = ...,
+    wait_for_result: Literal[True] = ...,
+) -> R:
+    ...

-    # ensure context variables are properly copied to the async frame
-    async def context_local_wrapper():
-        """
-        Wrapper that is submitted using copy_context().run to ensure
-        the RUNNING_ASYNC_FLAG mutations are tightly scoped to this coroutine's frame.
-        """
-        token = RUNNING_ASYNC_FLAG.set(True)
-        try:
-            result = await coroutine
-        finally:
-            RUNNING_ASYNC_FLAG.reset(token)
-        return result

-    context = copy_context()
-    try:
-        loop = asyncio.get_running_loop()
-    except RuntimeError:
-        loop = None
-
-    if loop and loop.is_running():
-        with ThreadPoolExecutor() as executor:
-            future = executor.submit(context.run, asyncio.run, context_local_wrapper())
-            result = cast(T, future.result())
-    else:
-        result = context.run(asyncio.run, context_local_wrapper())
-    return result
+@overload
+def run_coro_as_sync(
+    coroutine: Coroutine[Any, Any, R],
+    *,
+    force_new_thread: bool = ...,
+    wait_for_result: Literal[False] = False,
+) -> None:
+    ...


 def run_coro_as_sync(
-    coroutine: Awaitable[R],
+    coroutine: Coroutine[Any, Any, R],
+    *,
     force_new_thread: bool = False,
     wait_for_result: bool = True,
-) -> Union[R, None]:
+) -> Optional[R]:
     """
     Runs a coroutine from a synchronous context, as if it were a synchronous
     function.
@@ -211,7 +175,7 @@ def run_coro_as_sync(
         The result of the coroutine if wait_for_result is True, otherwise None.
     """

-    async def coroutine_wrapper() -> Union[R, None]:
+    async def coroutine_wrapper() -> Optional[R]:
         """
         Set flags so that children (and grandchildren...)
of this task know they are running in a new thread and do not try to run on the run_sync thread, which would cause a @@ -232,12 +196,13 @@ async def coroutine_wrapper() -> Union[R, None]: # that is running in the run_sync loop, we need to run this coroutine in a # new thread if in_run_sync_loop() or RUNNING_IN_RUN_SYNC_LOOP_FLAG.get() or force_new_thread: - return from_sync.call_in_new_thread(coroutine_wrapper) + result = from_sync.call_in_new_thread(coroutine_wrapper) + return result # otherwise, we can run the coroutine in the run_sync loop # and wait for the result else: - call = _cast_to_call(coroutine_wrapper) + call = cast_to_call(coroutine_wrapper) runner = get_run_sync_loop() runner.submit(call) try: @@ -250,8 +215,8 @@ async def coroutine_wrapper() -> Union[R, None]: async def run_sync_in_worker_thread( - __fn: Callable[..., T], *args: Any, **kwargs: Any -) -> T: + __fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs +) -> R: """ Runs a sync function in a new worker thread so that the main thread's event loop is not blocked. @@ -275,14 +240,14 @@ async def run_sync_in_worker_thread( RUNNING_ASYNC_FLAG.reset(token) -def call_with_mark(call): +def call_with_mark(call: Callable[..., R]) -> R: mark_as_worker_thread() return call() def run_async_from_worker_thread( - __fn: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any -) -> T: + __fn: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs +) -> R: """ Runs an async function in the main thread's event loop, blocking the worker thread until completion @@ -291,11 +256,13 @@ def run_async_from_worker_thread( return anyio.from_thread.run(call) -def run_async_in_new_loop(__fn: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any): +def run_async_in_new_loop( + __fn: Callable[P, Awaitable[R]], *args: P.args, **kwargs: P.kwargs +) -> R: return anyio.run(partial(__fn, *args, **kwargs)) -def mark_as_worker_thread(): +def mark_as_worker_thread() -> None: _thread_local.is_worker_thread = True @@ -313,23 +280,9 @@ def in_async_main_thread() -> bool: return not in_async_worker_thread() -@overload -def sync_compatible( - async_fn: Callable[..., Coroutine[Any, Any, R]], -) -> Callable[..., R]: - ... - - -@overload def sync_compatible( - async_fn: Callable[..., Coroutine[Any, Any, R]], -) -> Callable[..., Coroutine[Any, Any, R]]: - ... - - -def sync_compatible( - async_fn: Callable[..., Coroutine[Any, Any, R]], -) -> Callable[..., Union[R, Coroutine[Any, Any, R]]]: + async_fn: Callable[P, Coroutine[Any, Any, R]], +) -> Callable[P, Union[R, Coroutine[Any, Any, R]]]: """ Converts an async function into a dual async and sync function. @@ -394,7 +347,7 @@ async def ctx_call(): if _sync is True: return run_coro_as_sync(ctx_call()) - elif _sync is False or RUNNING_ASYNC_FLAG.get() or is_async: + elif RUNNING_ASYNC_FLAG.get() or is_async: return ctx_call() else: return run_coro_as_sync(ctx_call()) @@ -410,10 +363,24 @@ async def ctx_call(): return wrapper +@overload +def asyncnullcontext( + value: None = None, *args: Any, **kwargs: Any +) -> AbstractAsyncContextManager[None, None]: + ... + + +@overload +def asyncnullcontext( + value: R, *args: Any, **kwargs: Any +) -> AbstractAsyncContextManager[R, None]: + ... 
+ + @asynccontextmanager async def asyncnullcontext( - value: Optional[Any] = None, *args: Any, **kwargs: Any -) -> AsyncGenerator[Any, None]: + value: Optional[R] = None, *args: Any, **kwargs: Any +) -> AsyncGenerator[Any, Optional[R]]: yield value @@ -429,7 +396,7 @@ def sync(__async_fn: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwarg "`sync` called from an asynchronous context; " "you should `await` the async function directly instead." ) - with anyio.start_blocking_portal() as portal: + with anyio.from_thread.start_blocking_portal() as portal: return portal.call(partial(__async_fn, *args, **kwargs)) elif in_async_worker_thread(): # In a sync context but we can access the event loop thread; send the async @@ -441,7 +408,9 @@ def sync(__async_fn: Callable[P, Awaitable[T]], *args: P.args, **kwargs: P.kwarg return run_async_in_new_loop(__async_fn, *args, **kwargs) -async def add_event_loop_shutdown_callback(coroutine_fn: Callable[[], Awaitable]): +async def add_event_loop_shutdown_callback( + coroutine_fn: Callable[[], Awaitable[Any]], +) -> None: """ Adds a callback to the given callable on event loop closure. The callable must be a coroutine function. It will be awaited when the current event loop is shutting @@ -457,7 +426,7 @@ async def add_event_loop_shutdown_callback(coroutine_fn: Callable[[], Awaitable] loop is about to close. """ - async def on_shutdown(key): + async def on_shutdown(key: int) -> AsyncGenerator[None, Any]: # It appears that EVENT_LOOP_GC_REFS is somehow being garbage collected early. # We hold a reference to it so as to preserve it, at least for the lifetime of # this coroutine. See the issue below for the initial report/discussion: @@ -496,7 +465,7 @@ class GatherTaskGroup(anyio.abc.TaskGroup): """ A task group that gathers results. - AnyIO does not include support `gather`. This class extends the `TaskGroup` + AnyIO does not include `gather` support. This class extends the `TaskGroup` interface to allow simple gathering. See https://github.com/agronholm/anyio/issues/100 @@ -505,21 +474,31 @@ class GatherTaskGroup(anyio.abc.TaskGroup): """ def __init__(self, task_group: anyio.abc.TaskGroup): - self._results: Dict[UUID, Any] = {} + self._results: dict[UUID, Any] = {} # The concrete task group implementation to use self._task_group: anyio.abc.TaskGroup = task_group - async def _run_and_store(self, key, fn, args): + async def _run_and_store( + self, + key: UUID, + fn: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], + ) -> None: self._results[key] = await fn(*args) - def start_soon(self, fn, *args) -> UUID: + def start_soon( # pyright: ignore[reportIncompatibleMethodOverride] + self, + func: Callable[[Unpack[PosArgsT]], Awaitable[Any]], + *args: Unpack[PosArgsT], + name: object = None, + ) -> UUID: key = uuid4() # Put a placeholder in-case the result is retrieved earlier self._results[key] = GatherIncomplete - self._task_group.start_soon(self._run_and_store, key, fn, args) + self._task_group.start_soon(self._run_and_store, key, func, *args, name=name) return key - async def start(self, fn, *args): + async def start(self, func: object, *args: object, name: object = None) -> NoReturn: """ Since `start` returns the result of `task_status.started()` but here we must return the key instead, we just won't support this method for now. 
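With the `Unpack[PosArgsT]` annotations above, a type checker can match the extra positional arguments of `start_soon` against the submitted function's parameters, and each submission hands back a `UUID` key for the `get_result` accessor shown below. A minimal usage sketch (illustrative only, not part of this diff; it assumes the `create_gather_task_group` factory defined later in this module, and `double` is a made-up coroutine):

```python
import asyncio

from prefect.utilities.asyncutils import create_gather_task_group


async def double(x: int) -> int:
    # stand-in work; the *args passed to start_soon are checked against
    # this signature via Unpack[PosArgsT]
    return x * 2


async def main() -> list[int]:
    async with create_gather_task_group() as tg:
        # each start_soon returns a UUID key rather than None
        keys = [tg.start_soon(double, i) for i in range(3)]
    # results are safe to read once the task group has exited
    return [tg.get_result(key) for key in keys]


assert asyncio.run(main()) == [0, 2, 4]
```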
@@ -535,11 +514,11 @@ def get_result(self, key: UUID) -> Any: ) return result - async def __aenter__(self): + async def __aenter__(self) -> Self: await self._task_group.__aenter__() return self - async def __aexit__(self, *tb): + async def __aexit__(self, *tb: Any) -> Optional[bool]: try: retval = await self._task_group.__aexit__(*tb) return retval @@ -555,14 +534,14 @@ def create_gather_task_group() -> GatherTaskGroup: return GatherTaskGroup(anyio.create_task_group()) -async def gather(*calls: Callable[[], Coroutine[Any, Any, T]]) -> List[T]: +async def gather(*calls: Callable[[], Coroutine[Any, Any, T]]) -> list[T]: """ Run calls concurrently and gather their results. Unlike `asyncio.gather` this expects to receive _callables_ not _coroutines_. This matches `anyio` semantics. """ - keys = [] + keys: list[UUID] = [] async with create_gather_task_group() as tg: for call in calls: keys.append(tg.start_soon(call)) @@ -570,19 +549,23 @@ async def gather(*calls: Callable[[], Coroutine[Any, Any, T]]) -> List[T]: class LazySemaphore: - def __init__(self, initial_value_func): - self._semaphore = None + def __init__(self, initial_value_func: Callable[[], int]) -> None: + self._semaphore: Optional[asyncio.Semaphore] = None self._initial_value_func = initial_value_func - async def __aenter__(self): + async def __aenter__(self) -> asyncio.Semaphore: self._initialize_semaphore() + if TYPE_CHECKING: + assert self._semaphore is not None await self._semaphore.__aenter__() return self._semaphore - async def __aexit__(self, exc_type, exc, tb): - await self._semaphore.__aexit__(exc_type, exc, tb) + async def __aexit__(self, *args: Any) -> None: + if TYPE_CHECKING: + assert self._semaphore is not None + await self._semaphore.__aexit__(*args) - def _initialize_semaphore(self): + def _initialize_semaphore(self) -> None: if self._semaphore is None: initial_value = self._initial_value_func() self._semaphore = asyncio.Semaphore(initial_value) diff --git a/src/prefect/utilities/callables.py b/src/prefect/utilities/callables.py index 382f5cd6a224..9489ec25f061 100644 --- a/src/prefect/utilities/callables.py +++ b/src/prefect/utilities/callables.py @@ -6,14 +6,17 @@ import importlib.util import inspect import warnings +from collections import OrderedDict +from collections.abc import Iterable from functools import partial +from logging import Logger from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Callable, Optional, Union, cast -import cloudpickle +import cloudpickle # type: ignore # no stubs available import pydantic from griffe import Docstring, DocstringSectionKind, Parser, parse -from typing_extensions import Literal +from typing_extensions import Literal, TypeVar from prefect._internal.pydantic.v1_schema import has_v1_type_as_param from prefect._internal.pydantic.v2_schema import ( @@ -32,15 +35,17 @@ from prefect.utilities.collections import isiterable from prefect.utilities.importtools import safe_load_namespace -logger = get_logger(__name__) +logger: Logger = get_logger(__name__) + +R = TypeVar("R", infer_variance=True) def get_call_parameters( - fn: Callable, - call_args: Tuple[Any, ...], - call_kwargs: Dict[str, Any], + fn: Callable[..., Any], + call_args: tuple[Any, ...], + call_kwargs: dict[str, Any], apply_defaults: bool = True, -) -> Dict[str, Any]: +) -> dict[str, Any]: """ Bind a call to a function to get parameter/value mapping. Default values on the signature will be included if not overridden. 
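For reference, a sketch of the binding behaviour described above (`greet` is a made-up example function; the import path is the module this hunk patches):

```python
from prefect.utilities.callables import get_call_parameters


def greet(name: str, punctuation: str = "!") -> str:
    return f"Hello, {name}{punctuation}"


# positional and keyword call arguments are bound to parameter names,
# and defaults from the signature are filled in (apply_defaults=True)
params = get_call_parameters(greet, ("world",), {})
assert params == {"name": "world", "punctuation": "!"}
```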
@@ -57,7 +62,7 @@ def get_call_parameters( function """ if hasattr(fn, "__prefect_self__"): - call_args = (fn.__prefect_self__,) + call_args + call_args = (getattr(fn, "__prefect_self__"), *call_args) try: bound_signature = inspect.signature(fn).bind(*call_args, **call_kwargs) @@ -74,14 +79,14 @@ def get_call_parameters( def get_parameter_defaults( - fn: Callable, -) -> Dict[str, Any]: + fn: Callable[..., Any], +) -> dict[str, Any]: """ Get default parameter values for a callable. """ signature = inspect.signature(fn) - parameter_defaults = {} + parameter_defaults: dict[str, Any] = {} for name, param in signature.parameters.items(): if param.default is not signature.empty: @@ -91,8 +96,8 @@ def get_parameter_defaults( def explode_variadic_parameter( - fn: Callable, parameters: Dict[str, Any] -) -> Dict[str, Any]: + fn: Callable[..., Any], parameters: dict[str, Any] +) -> dict[str, Any]: """ Given a parameter dictionary, move any parameters stored in a variadic keyword argument parameter (i.e. **kwargs) into the top level. @@ -125,8 +130,8 @@ def foo(a, b, **kwargs): def collapse_variadic_parameters( - fn: Callable, parameters: Dict[str, Any] -) -> Dict[str, Any]: + fn: Callable[..., Any], parameters: dict[str, Any] +) -> dict[str, Any]: """ Given a parameter dictionary, move any parameters stored not present in the signature into the variadic keyword argument. @@ -151,50 +156,47 @@ def foo(a, b, **kwargs): missing_parameters = set(parameters.keys()) - set(signature_parameters.keys()) - if not variadic_key and missing_parameters: + if not missing_parameters: + # no missing parameters, return parameters unchanged + return parameters + + if not variadic_key: raise ValueError( f"Signature for {fn} does not include any variadic keyword argument " "but parameters were given that are not present in the signature." ) - if variadic_key and not missing_parameters: - # variadic key is present but no missing parameters, return parameters unchanged - return parameters - new_parameters = parameters.copy() - if variadic_key: - new_parameters[variadic_key] = {} - - for key in missing_parameters: - new_parameters[variadic_key][key] = new_parameters.pop(key) - + new_parameters[variadic_key] = { + key: new_parameters.pop(key) for key in missing_parameters + } return new_parameters def parameters_to_args_kwargs( - fn: Callable, - parameters: Dict[str, Any], -) -> Tuple[Tuple[Any, ...], Dict[str, Any]]: + fn: Callable[..., Any], + parameters: dict[str, Any], +) -> tuple[tuple[Any, ...], dict[str, Any]]: """ Convert a `parameters` dictionary to positional and keyword arguments The function _must_ have an identical signature to the original function or this will return an empty tuple and dict. 
""" - function_params = dict(inspect.signature(fn).parameters).keys() + function_params = inspect.signature(fn).parameters.keys() # Check for parameters that are not present in the function signature unknown_params = parameters.keys() - function_params if unknown_params: raise SignatureMismatchError.from_bad_params( - list(function_params), list(parameters.keys()) + list(function_params), list(parameters) ) bound_signature = inspect.signature(fn).bind_partial() - bound_signature.arguments = parameters + bound_signature.arguments = OrderedDict(parameters) return bound_signature.args, bound_signature.kwargs -def call_with_parameters(fn: Callable, parameters: Dict[str, Any]): +def call_with_parameters(fn: Callable[..., R], parameters: dict[str, Any]) -> R: """ Call a function with parameters extracted with `get_call_parameters` @@ -207,7 +209,7 @@ def call_with_parameters(fn: Callable, parameters: Dict[str, Any]): def cloudpickle_wrapped_call( - __fn: Callable, *args: Any, **kwargs: Any + __fn: Callable[..., Any], *args: Any, **kwargs: Any ) -> Callable[[], bytes]: """ Serializes a function call using cloudpickle then returns a callable which will @@ -217,18 +219,18 @@ def cloudpickle_wrapped_call( built-in pickler (e.g. `anyio.to_process` and `multiprocessing`) but may require a wider range of pickling support. """ - payload = cloudpickle.dumps((__fn, args, kwargs)) + payload = cloudpickle.dumps((__fn, args, kwargs)) # type: ignore # no stubs available return partial(_run_serialized_call, payload) -def _run_serialized_call(payload) -> bytes: +def _run_serialized_call(payload: bytes) -> bytes: """ Defined at the top-level so it can be pickled by the Python pickler. Used by `cloudpickle_wrapped_call`. """ fn, args, kwargs = cloudpickle.loads(payload) retval = fn(*args, **kwargs) - return cloudpickle.dumps(retval) + return cloudpickle.dumps(retval) # type: ignore # no stubs available class ParameterSchema(pydantic.BaseModel): @@ -236,18 +238,18 @@ class ParameterSchema(pydantic.BaseModel): title: Literal["Parameters"] = "Parameters" type: Literal["object"] = "object" - properties: Dict[str, Any] = pydantic.Field(default_factory=dict) - required: List[str] = pydantic.Field(default_factory=list) - definitions: Dict[str, Any] = pydantic.Field(default_factory=dict) + properties: dict[str, Any] = pydantic.Field(default_factory=dict) + required: list[str] = pydantic.Field(default_factory=list) + definitions: dict[str, Any] = pydantic.Field(default_factory=dict) - def model_dump_for_openapi(self) -> Dict[str, Any]: + def model_dump_for_openapi(self) -> dict[str, Any]: result = self.model_dump(mode="python", exclude_none=True) if "required" in result and not result["required"]: del result["required"] return result -def parameter_docstrings(docstring: Optional[str]) -> Dict[str, str]: +def parameter_docstrings(docstring: Optional[str]) -> dict[str, str]: """ Given a docstring in Google docstring format, parse the parameter section and return a dictionary that maps parameter names to docstring. @@ -258,7 +260,7 @@ def parameter_docstrings(docstring: Optional[str]) -> Dict[str, str]: Returns: Mapping from parameter names to docstrings. 
""" - param_docstrings = {} + param_docstrings: dict[str, str] = {} if not docstring: return param_docstrings @@ -279,9 +281,9 @@ def process_v1_params( param: inspect.Parameter, *, position: int, - docstrings: Dict[str, str], - aliases: Dict, -) -> Tuple[str, Any, "pydantic.Field"]: + docstrings: dict[str, str], + aliases: dict[str, str], +) -> tuple[str, Any, Any]: # Pydantic model creation will fail if names collide with the BaseModel type if hasattr(pydantic.BaseModel, param.name): name = param.name + "__" @@ -289,13 +291,13 @@ def process_v1_params( else: name = param.name - type_ = Any if param.annotation is inspect._empty else param.annotation + type_ = Any if param.annotation is inspect.Parameter.empty else param.annotation with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=pydantic.warnings.PydanticDeprecatedSince20 ) - field = pydantic.Field( + field: Any = pydantic.Field( # type: ignore # this uses the v1 signature, not v2 default=... if param.default is param.empty else param.default, title=param.name, description=docstrings.get(param.name, None), @@ -305,19 +307,24 @@ def process_v1_params( return name, type_, field -def create_v1_schema(name_: str, model_cfg, **model_fields): +def create_v1_schema( + name_: str, model_cfg: type[Any], model_fields: Optional[dict[str, Any]] = None +) -> dict[str, Any]: with warnings.catch_warnings(): warnings.filterwarnings( "ignore", category=pydantic.warnings.PydanticDeprecatedSince20 ) - model: "pydantic.BaseModel" = pydantic.create_model( - name_, __config__=model_cfg, **model_fields + model_fields = model_fields or {} + model: type[pydantic.BaseModel] = pydantic.create_model( # type: ignore # this uses the v1 signature, not v2 + name_, + __config__=model_cfg, # type: ignore # this uses the v1 signature, not v2 + **model_fields, ) - return model.schema(by_alias=True) + return model.schema(by_alias=True) # type: ignore # this uses the v1 signature, not v2 -def parameter_schema(fn: Callable) -> ParameterSchema: +def parameter_schema(fn: Callable[..., Any]) -> ParameterSchema: """Given a function, generates an OpenAPI-compatible description of the function's arguments, including: - name @@ -378,7 +385,7 @@ def parameter_schema_from_entrypoint(entrypoint: str) -> ParameterSchema: def generate_parameter_schema( - signature: inspect.Signature, docstrings: Dict[str, str] + signature: inspect.Signature, docstrings: dict[str, str] ) -> ParameterSchema: """ Generate a parameter schema from a function signature and docstrings. @@ -394,22 +401,22 @@ def generate_parameter_schema( ParameterSchema: The parameter schema. 
""" - model_fields = {} - aliases = {} + model_fields: dict[str, Any] = {} + aliases: dict[str, str] = {} if not has_v1_type_as_param(signature): - create_schema = create_v2_schema + config = pydantic.ConfigDict(arbitrary_types_allowed=True) + + create_schema = partial(create_v2_schema, model_cfg=config) process_params = process_v2_params - config = pydantic.ConfigDict(arbitrary_types_allowed=True) else: - create_schema = create_v1_schema - process_params = process_v1_params class ModelConfig: arbitrary_types_allowed = True - config = ModelConfig + create_schema = partial(create_v1_schema, model_cfg=ModelConfig) + process_params = process_v1_params for position, param in enumerate(signature.parameters.values()): name, type_, field = process_params( @@ -418,24 +425,26 @@ class ModelConfig: # Generate a Pydantic model at each step so we can check if this parameter # type supports schema generation try: - create_schema("CheckParameter", model_cfg=config, **{name: (type_, field)}) + create_schema("CheckParameter", model_fields={name: (type_, field)}) except (ValueError, TypeError): # This field's type is not valid for schema creation, update it to `Any` type_ = Any model_fields[name] = (type_, field) # Generate the final model and schema - schema = create_schema("Parameters", model_cfg=config, **model_fields) + schema = create_schema("Parameters", model_fields=model_fields) return ParameterSchema(**schema) -def raise_for_reserved_arguments(fn: Callable, reserved_arguments: Iterable[str]): +def raise_for_reserved_arguments( + fn: Callable[..., Any], reserved_arguments: Iterable[str] +) -> None: """Raise a ReservedArgumentError if `fn` has any parameters that conflict with the names contained in `reserved_arguments`.""" - function_paremeters = inspect.signature(fn).parameters + function_parameters = inspect.signature(fn).parameters for argument in reserved_arguments: - if argument in function_paremeters: + if argument in function_parameters: raise ReservedArgumentError( f"{argument!r} is a reserved argument name and cannot be used." ) @@ -479,7 +488,7 @@ def _generate_signature_from_source( ) if func_def is None: raise ValueError(f"Function {func_name} not found in source code") - parameters = [] + parameters: list[inspect.Parameter] = [] # Handle annotations for positional only args e.g. def func(a, /, b, c) for arg in func_def.args.posonlyargs: @@ -642,8 +651,8 @@ def _get_docstring_from_source(source_code: str, func_name: str) -> Optional[str def expand_mapping_parameters( - func: Callable, parameters: Dict[str, Any] -) -> List[Dict[str, Any]]: + func: Callable[..., Any], parameters: dict[str, Any] +) -> list[dict[str, Any]]: """ Generates a list of call parameters to be used for individual calls in a mapping operation. 
@@ -653,29 +662,29 @@ def expand_mapping_parameters( parameters: A dictionary of parameters with iterables to be mapped over Returns: - List: A list of dictionaries to be used as parameters for each + list: A list of dictionaries to be used as parameters for each call in the mapping operation """ # Ensure that any parameters in kwargs are expanded before this check parameters = explode_variadic_parameter(func, parameters) - iterable_parameters = {} - static_parameters = {} - annotated_parameters = {} + iterable_parameters: dict[str, list[Any]] = {} + static_parameters: dict[str, Any] = {} + annotated_parameters: dict[str, Union[allow_failure[Any], quote[Any]]] = {} for key, val in parameters.items(): if isinstance(val, (allow_failure, quote)): # Unwrap annotated parameters to determine if they are iterable annotated_parameters[key] = val - val = val.unwrap() + val: Any = val.unwrap() if isinstance(val, unmapped): - static_parameters[key] = val.value + static_parameters[key] = cast(unmapped[Any], val).value elif isiterable(val): iterable_parameters[key] = list(val) else: static_parameters[key] = val - if not len(iterable_parameters): + if not iterable_parameters: raise MappingMissingIterable( "No iterable parameters were received. Parameters for map must " f"include at least one iterable. Parameters: {parameters}" @@ -693,7 +702,7 @@ def expand_mapping_parameters( map_length = list(lengths)[0] - call_parameters_list = [] + call_parameters_list: list[dict[str, Any]] = [] for i in range(map_length): call_parameters = {key: value[i] for key, value in iterable_parameters.items()} call_parameters.update({key: value for key, value in static_parameters.items()}) diff --git a/src/prefect/utilities/collections.py b/src/prefect/utilities/collections.py index 5b2ae453c67e..3588b6f48a73 100644 --- a/src/prefect/utilities/collections.py +++ b/src/prefect/utilities/collections.py @@ -6,33 +6,40 @@ import itertools import types import warnings -from collections import OrderedDict, defaultdict -from collections.abc import Iterator as IteratorABC -from collections.abc import Sequence -from dataclasses import fields, is_dataclass -from enum import Enum, auto -from typing import ( - Any, +from collections import OrderedDict +from collections.abc import ( Callable, - Dict, + Collection, Generator, Hashable, Iterable, - List, - Optional, + Iterator, + Sequence, Set, - Tuple, - Type, - TypeVar, +) +from dataclasses import fields, is_dataclass, replace +from enum import Enum, auto +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Optional, Union, cast, + overload, ) from unittest.mock import Mock import pydantic +from typing_extensions import TypeAlias, TypeVar # Quote moved to `prefect.utilities.annotations` but preserved here for compatibility -from prefect.utilities.annotations import BaseAnnotation, Quote, quote # noqa +from prefect.utilities.annotations import BaseAnnotation as BaseAnnotation +from prefect.utilities.annotations import Quote as Quote +from prefect.utilities.annotations import quote as quote + +if TYPE_CHECKING: + pass class AutoEnum(str, Enum): @@ -55,11 +62,12 @@ class MyEnum(AutoEnum): ``` """ - def _generate_next_value_(name, start, count, last_values): + @staticmethod + def _generate_next_value_(name: str, *_: object, **__: object) -> str: return name @staticmethod - def auto(): + def auto() -> str: """ Exposes `enum.auto()` to avoid requiring a second import to use `AutoEnum` """ @@ -70,12 +78,15 @@ def __repr__(self) -> str: KT = TypeVar("KT") -VT = TypeVar("VT") +VT = 
TypeVar("VT", infer_variance=True) +VT1 = TypeVar("VT1", infer_variance=True) +VT2 = TypeVar("VT2", infer_variance=True) +R = TypeVar("R", infer_variance=True) +NestedDict: TypeAlias = dict[KT, Union[VT, "NestedDict[KT, VT]"]] +HashableT = TypeVar("HashableT", bound=Hashable) -def dict_to_flatdict( - dct: Dict[KT, Union[Any, Dict[KT, Any]]], _parent: Tuple[KT, ...] = None -) -> Dict[Tuple[KT, ...], Any]: +def dict_to_flatdict(dct: NestedDict[KT, VT]) -> dict[tuple[KT, ...], VT]: """Converts a (nested) dictionary to a flattened representation. Each key of the flat dict will be a CompoundKey tuple containing the "chain of keys" @@ -83,28 +94,28 @@ def dict_to_flatdict( Args: dct (dict): The dictionary to flatten - _parent (Tuple, optional): The current parent for recursion Returns: A flattened dict of the same type as dct """ - typ = cast(Type[Dict[Tuple[KT, ...], Any]], type(dct)) - items: List[Tuple[Tuple[KT, ...], Any]] = [] - parent = _parent or tuple() - - for k, v in dct.items(): - k_parent = tuple(parent + (k,)) - # if v is a non-empty dict, recurse - if isinstance(v, dict) and v: - items.extend(dict_to_flatdict(v, _parent=k_parent).items()) - else: - items.append((k_parent, v)) - return typ(items) + def flatten( + dct: NestedDict[KT, VT], _parent: tuple[KT, ...] = () + ) -> Iterator[tuple[tuple[KT, ...], VT]]: + parent = _parent or () + for k, v in dct.items(): + k_parent = (*parent, k) + # if v is a non-empty dict, recurse + if isinstance(v, dict) and v: + yield from flatten(cast(NestedDict[KT, VT], v), _parent=k_parent) + else: + yield (k_parent, cast(VT, v)) -def flatdict_to_dict( - dct: Dict[Tuple[KT, ...], VT], -) -> Dict[KT, Union[VT, Dict[KT, VT]]]: + type_ = cast(type[dict[tuple[KT, ...], VT]], type(dct)) + return type_(flatten(dct)) + + +def flatdict_to_dict(dct: dict[tuple[KT, ...], VT]) -> NestedDict[KT, VT]: """Converts a flattened dictionary back to a nested dictionary. 
Args: @@ -114,16 +125,26 @@ def flatdict_to_dict( Returns A nested dict of the same type as dct """ - typ = type(dct) - result = cast(Dict[KT, Union[VT, Dict[KT, VT]]], typ()) + + type_ = cast(type[NestedDict[KT, VT]], type(dct)) + + def new(type_: type[NestedDict[KT, VT]] = type_) -> NestedDict[KT, VT]: + return type_() + + result = new() for key_tuple, value in dct.items(): - current_dict = result - for prefix_key in key_tuple[:-1]: + current = result + *prefix_keys, last_key = key_tuple + for prefix_key in prefix_keys: # Build nested dictionaries up for the current key tuple - # Use `setdefault` in case the nested dict has already been created - current_dict = current_dict.setdefault(prefix_key, typ()) # type: ignore + try: + current = cast(NestedDict[KT, VT], current[prefix_key]) + except KeyError: + new_dict = current[prefix_key] = new() + current = new_dict + # Set the value - current_dict[key_tuple[-1]] = value + current[last_key] = value return result @@ -148,9 +169,9 @@ def isiterable(obj: Any) -> bool: return not isinstance(obj, (str, bytes, io.IOBase)) -def ensure_iterable(obj: Union[T, Iterable[T]]) -> Iterable[T]: +def ensure_iterable(obj: Union[T, Iterable[T]]) -> Collection[T]: if isinstance(obj, Sequence) or isinstance(obj, Set): - return obj + return cast(Collection[T], obj) obj = cast(T, obj) # No longer in the iterable case return [obj] @@ -160,9 +181,9 @@ def listrepr(objs: Iterable[Any], sep: str = " ") -> str: def extract_instances( - objects: Iterable, - types: Union[Type[T], Tuple[Type[T], ...]] = object, -) -> Union[List[T], Dict[Type[T], T]]: + objects: Iterable[Any], + types: Union[type[T], tuple[type[T], ...]] = object, +) -> Union[list[T], dict[type[T], list[T]]]: """ Extract objects from a file and returns a dict of type -> instances @@ -174,26 +195,27 @@ def extract_instances( If a single type is given: a list of instances of that type If a tuple of types is given: a mapping of type to a list of instances """ - types = ensure_iterable(types) + types_collection = ensure_iterable(types) # Create a mapping of type -> instance from the exec values - ret = defaultdict(list) + ret: dict[type[T], list[Any]] = {} for o in objects: # We iterate here so that the key is the passed type rather than type(o) - for type_ in types: + for type_ in types_collection: if isinstance(o, type_): - ret[type_].append(o) + ret.setdefault(type_, []).append(o) - if len(types) == 1: - return ret[types[0]] + if len(types_collection) == 1: + [type_] = types_collection + return ret[type_] return ret def batched_iterable( iterable: Iterable[T], size: int -) -> Generator[Tuple[T, ...], None, None]: +) -> Generator[tuple[T, ...], None, None]: """ Yield batches of a certain size from an iterable @@ -221,15 +243,86 @@ class StopVisiting(BaseException): """ +@overload +def visit_collection( + expr: Any, + visit_fn: Callable[[Any, dict[str, VT]], Any], + *, + return_data: Literal[True] = ..., + max_depth: int = ..., + context: dict[str, VT] = ..., + remove_annotations: bool = ..., + _seen: Optional[set[int]] = ..., +) -> Any: + ... + + +@overload +def visit_collection( + expr: Any, + visit_fn: Callable[[Any], Any], + *, + return_data: Literal[True] = ..., + max_depth: int = ..., + context: None = None, + remove_annotations: bool = ..., + _seen: Optional[set[int]] = ..., +) -> Any: + ... 
+ + +@overload +def visit_collection( + expr: Any, + visit_fn: Callable[[Any, dict[str, VT]], Any], + *, + return_data: bool = ..., + max_depth: int = ..., + context: dict[str, VT] = ..., + remove_annotations: bool = ..., + _seen: Optional[set[int]] = ..., +) -> Optional[Any]: + ... + + +@overload +def visit_collection( + expr: Any, + visit_fn: Callable[[Any], Any], + *, + return_data: bool = ..., + max_depth: int = ..., + context: None = None, + remove_annotations: bool = ..., + _seen: Optional[set[int]] = ..., +) -> Optional[Any]: + ... + + +@overload +def visit_collection( + expr: Any, + visit_fn: Callable[[Any, dict[str, VT]], Any], + *, + return_data: Literal[False] = False, + max_depth: int = ..., + context: dict[str, VT] = ..., + remove_annotations: bool = ..., + _seen: Optional[set[int]] = ..., +) -> None: + ... + + def visit_collection( expr: Any, - visit_fn: Union[Callable[[Any, Optional[dict]], Any], Callable[[Any], Any]], + visit_fn: Union[Callable[[Any, dict[str, VT]], Any], Callable[[Any], Any]], + *, return_data: bool = False, max_depth: int = -1, - context: Optional[dict] = None, + context: Optional[dict[str, VT]] = None, remove_annotations: bool = False, - _seen: Optional[Set[int]] = None, -) -> Any: + _seen: Optional[set[int]] = None, +) -> Optional[Any]: """ Visits and potentially transforms every element of an arbitrary Python collection. @@ -289,24 +382,39 @@ def visit_collection( if _seen is None: _seen = set() - def visit_nested(expr): - # Utility for a recursive call, preserving options and updating the depth. - return visit_collection( - expr, - visit_fn=visit_fn, - return_data=return_data, - remove_annotations=remove_annotations, - max_depth=max_depth - 1, - # Copy the context on nested calls so it does not "propagate up" - context=context.copy() if context is not None else None, - _seen=_seen, - ) - - def visit_expression(expr): - if context is not None: - return visit_fn(expr, context) - else: - return visit_fn(expr) + if context is not None: + _callback = cast(Callable[[Any, dict[str, VT]], Any], visit_fn) + + def visit_nested(expr: Any) -> Optional[Any]: + return visit_collection( + expr, + _callback, + return_data=return_data, + remove_annotations=remove_annotations, + max_depth=max_depth - 1, + # Copy the context on nested calls so it does not "propagate up" + context=context.copy(), + _seen=_seen, + ) + + def visit_expression(expr: Any) -> Any: + return _callback(expr, context) + else: + _callback = cast(Callable[[Any], Any], visit_fn) + + def visit_nested(expr: Any) -> Optional[Any]: + # Utility for a recursive call, preserving options and updating the depth. + return visit_collection( + expr, + _callback, + return_data=return_data, + remove_annotations=remove_annotations, + max_depth=max_depth - 1, + _seen=_seen, + ) + + def visit_expression(expr: Any) -> Any: + return _callback(expr) # --- 1. Visit every expression try: @@ -329,10 +437,6 @@ def visit_expression(expr): else: _seen.add(id(expr)) - # Get the expression type; treat iterators like lists - typ = list if isinstance(expr, IteratorABC) and isiterable(expr) else type(expr) - typ = cast(type, typ) # mypy treats this as 'object' otherwise and complains - # Then visit every item in the expression if it is a collection # presume that the result is the original expression. @@ -354,9 +458,10 @@ def visit_expression(expr): # --- Annotations (unmapped, quote, etc.) 
     elif isinstance(expr, BaseAnnotation):
+        annotated = cast(BaseAnnotation[Any], expr)
         if context is not None:
-            context["annotation"] = expr
-        unwrapped = expr.unwrap()
+            context["annotation"] = cast(VT, annotated)
+        unwrapped = annotated.unwrap()
         value = visit_nested(unwrapped)

         if return_data:
@@ -365,47 +470,49 @@ def visit_expression(expr):
                 result = value
             # if the value was modified, rewrap it
             elif value is not unwrapped:
-                result = expr.rewrap(value)
+                result = annotated.rewrap(value)
             # otherwise return the expr

     # --- Sequences
     elif isinstance(expr, (list, tuple, set)):
-        items = [visit_nested(o) for o in expr]
+        seq = cast(Union[list[Any], tuple[Any, ...], set[Any]], expr)
+        items = [visit_nested(o) for o in seq]
         if return_data:
-            modified = any(item is not orig for item, orig in zip(items, expr))
+            modified = any(item is not orig for item, orig in zip(items, seq))
             if modified:
-                result = typ(items)
+                result = type(seq)(items)

     # --- Dictionaries
-    elif typ in (dict, OrderedDict):
-        assert isinstance(expr, (dict, OrderedDict))  # typecheck assertion
-        items = [(visit_nested(k), visit_nested(v)) for k, v in expr.items()]
+    elif isinstance(expr, (dict, OrderedDict)):
+        mapping = cast(dict[Any, Any], expr)
+        items = [(visit_nested(k), visit_nested(v)) for k, v in mapping.items()]
         if return_data:
             modified = any(
                 k1 is not k2 or v1 is not v2
-                for (k1, v1), (k2, v2) in zip(items, expr.items())
+                for (k1, v1), (k2, v2) in zip(items, mapping.items())
             )
             if modified:
-                result = typ(items)
+                result = type(mapping)(items)

     # --- Dataclasses
     elif is_dataclass(expr) and not isinstance(expr, type):
-        values = [visit_nested(getattr(expr, f.name)) for f in fields(expr)]
+        expr_fields = fields(expr)
+        values = [visit_nested(getattr(expr, f.name)) for f in expr_fields]
         if return_data:
             modified = any(
-                getattr(expr, f.name) is not v for f, v in zip(fields(expr), values)
+                getattr(expr, f.name) is not v for f, v in zip(expr_fields, values)
             )
             if modified:
-                result = typ(**{f.name: v for f, v in zip(fields(expr), values)})
+                result = replace(
+                    expr, **{f.name: v for f, v in zip(expr_fields, values)}
+                )

     # --- Pydantic models
     elif isinstance(expr, pydantic.BaseModel):
-        typ = cast(Type[pydantic.BaseModel], typ)
-
         # when extra=allow, fields not in model_fields may be in model_fields_set
         model_fields = expr.model_fields_set.union(expr.model_fields.keys())
@@ -424,7 +531,7 @@ def visit_expression(expr):
             )
             if modified:
                 # Use construct to avoid validation and handle immutability
-                model_instance = typ.model_construct(
+                model_instance = expr.model_construct(
                     _fields_set=expr.model_fields_set, **updated_data
                 )
                 for private_attr in expr.__private_attributes__:
@@ -435,7 +542,21 @@ def visit_expression(expr):
     return result


-def remove_nested_keys(keys_to_remove: List[Hashable], obj):
+@overload
+def remove_nested_keys(
+    keys_to_remove: list[HashableT], obj: NestedDict[HashableT, VT]
+) -> NestedDict[HashableT, VT]:
+    ...
+
+
+@overload
+def remove_nested_keys(keys_to_remove: list[HashableT], obj: Any) -> Any:
+    ...
+
+
+def remove_nested_keys(
+    keys_to_remove: list[HashableT], obj: Union[NestedDict[HashableT, VT], Any]
+) -> Union[NestedDict[HashableT, VT], Any]:
     """
     Recurses a dictionary and returns a copy without any keys that match an
     entry in `keys_to_remove`. Return `obj` unchanged if not a dictionary.
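The behaviour the docstring describes, as a short sketch (illustrative only; the import path is the module this hunk patches):

```python
from prefect.utilities.collections import remove_nested_keys

config = {"name": "demo", "auth": {"token": "abc", "retries": 3}}

# matching keys are dropped at every nesting level and a new dict is returned
assert remove_nested_keys(["token"], config) == {
    "name": "demo",
    "auth": {"retries": 3},
}

# non-dict inputs pass through unchanged
assert remove_nested_keys(["token"], "not-a-dict") == "not-a-dict"
```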
@@ -452,24 +573,56 @@ def remove_nested_keys(keys_to_remove: List[Hashable], obj): return obj return { key: remove_nested_keys(keys_to_remove, value) - for key, value in obj.items() + for key, value in cast(NestedDict[HashableT, VT], obj).items() if key not in keys_to_remove } +@overload +def distinct(iterable: Iterable[HashableT], key: None = None) -> Iterator[HashableT]: + ... + + +@overload +def distinct(iterable: Iterable[T], key: Callable[[T], Hashable]) -> Iterator[T]: + ... + + def distinct( - iterable: Iterable[T], - key: Callable[[T], Any] = (lambda i: i), -) -> Generator[T, None, None]: - seen: Set = set() + iterable: Iterable[Union[T, HashableT]], + key: Optional[Callable[[T], Hashable]] = None, +) -> Iterator[Union[T, HashableT]]: + def _key(__i: Any) -> Hashable: + return __i + + if key is not None: + _key = cast(Callable[[Any], Hashable], key) + + seen: set[Hashable] = set() for item in iterable: - if key(item) in seen: + if _key(item) in seen: continue - seen.add(key(item)) + seen.add(_key(item)) yield item -def get_from_dict(dct: Dict, keys: Union[str, List[str]], default: Any = None) -> Any: +@overload +def get_from_dict( + dct: NestedDict[str, VT], keys: Union[str, list[str]], default: None = None +) -> Optional[VT]: + ... + + +@overload +def get_from_dict( + dct: NestedDict[str, VT], keys: Union[str, list[str]], default: R +) -> Union[VT, R]: + ... + + +def get_from_dict( + dct: NestedDict[str, VT], keys: Union[str, list[str]], default: Optional[R] = None +) -> Union[VT, R, None]: """ Fetch a value from a nested dictionary or list using a sequence of keys. @@ -500,6 +653,7 @@ def get_from_dict(dct: Dict, keys: Union[str, List[str]], default: Any = None) - """ if isinstance(keys, str): keys = keys.replace("[", ".").replace("]", "").split(".") + value = dct try: for key in keys: try: @@ -509,13 +663,15 @@ def get_from_dict(dct: Dict, keys: Union[str, List[str]], default: Any = None) - # If it's not an int, use the key as-is # for dict lookup pass - dct = dct[key] - return dct + value = value[key] # type: ignore + return cast(VT, value) except (TypeError, KeyError, IndexError): return default -def set_in_dict(dct: Dict, keys: Union[str, List[str]], value: Any): +def set_in_dict( + dct: NestedDict[str, VT], keys: Union[str, list[str]], value: VT +) -> None: """ Sets a value in a nested dictionary using a sequence of keys. @@ -543,11 +699,13 @@ def set_in_dict(dct: Dict, keys: Union[str, List[str]], value: Any): raise TypeError(f"Key path exists and contains a non-dict value: {keys}") if k not in dct: dct[k] = {} - dct = dct[k] + dct = cast(NestedDict[str, VT], dct[k]) dct[keys[-1]] = value -def deep_merge(dct: Dict, merge: Dict): +def deep_merge( + dct: NestedDict[str, VT1], merge: NestedDict[str, VT2] +) -> NestedDict[str, Union[VT1, VT2]]: """ Recursively merges `merge` into `dct`. @@ -558,18 +716,21 @@ def deep_merge(dct: Dict, merge: Dict): Returns: A new dictionary with the merged contents. 
""" - result = dct.copy() # Start with keys and values from `dct` + result: dict[str, Any] = dct.copy() # Start with keys and values from `dct` for key, value in merge.items(): if key in result and isinstance(result[key], dict) and isinstance(value, dict): # If both values are dictionaries, merge them recursively - result[key] = deep_merge(result[key], value) + result[key] = deep_merge( + cast(NestedDict[str, VT1], result[key]), + cast(NestedDict[str, VT2], value), + ) else: # Otherwise, overwrite with the new value - result[key] = value + result[key] = cast(Union[VT2, NestedDict[str, VT2]], value) return result -def deep_merge_dicts(*dicts): +def deep_merge_dicts(*dicts: NestedDict[str, Any]) -> NestedDict[str, Any]: """ Recursively merges multiple dictionaries. @@ -579,7 +740,7 @@ def deep_merge_dicts(*dicts): Returns: A new dictionary with the merged contents. """ - result = {} + result: NestedDict[str, Any] = {} for dictionary in dicts: result = deep_merge(result, dictionary) return result diff --git a/src/prefect/utilities/compat.py b/src/prefect/utilities/compat.py index 3eadafb3edf3..6bf8f34c46ca 100644 --- a/src/prefect/utilities/compat.py +++ b/src/prefect/utilities/compat.py @@ -3,29 +3,21 @@ """ # Please organize additions to this file by version -import asyncio import sys -from shutil import copytree -from signal import raise_signal if sys.version_info < (3, 10): - import importlib_metadata - from importlib_metadata import EntryPoint, EntryPoints, entry_points + import importlib_metadata as importlib_metadata + from importlib_metadata import ( + EntryPoint as EntryPoint, + EntryPoints as EntryPoints, + entry_points as entry_points, + ) else: - import importlib.metadata as importlib_metadata - from importlib.metadata import EntryPoint, EntryPoints, entry_points + import importlib.metadata + from importlib.metadata import ( + EntryPoint as EntryPoint, + EntryPoints as EntryPoints, + entry_points as entry_points, + ) -if sys.version_info < (3, 9): - # https://docs.python.org/3/library/asyncio-task.html#asyncio.to_thread - - import functools - - async def asyncio_to_thread(fn, *args, **kwargs): - loop = asyncio.get_running_loop() - return await loop.run_in_executor(None, functools.partial(fn, *args, **kwargs)) - -else: - from asyncio import to_thread as asyncio_to_thread - -if sys.platform != "win32": - from asyncio import ThreadedChildWatcher + importlib_metadata = importlib.metadata diff --git a/src/prefect/utilities/context.py b/src/prefect/utilities/context.py index 3bd87f975f40..c475f3eea453 100644 --- a/src/prefect/utilities/context.py +++ b/src/prefect/utilities/context.py @@ -1,6 +1,7 @@ +from collections.abc import Generator from contextlib import contextmanager from contextvars import Context, ContextVar, Token -from typing import TYPE_CHECKING, Dict, Optional, Tuple, cast +from typing import TYPE_CHECKING, Any, Optional, cast from uuid import UUID if TYPE_CHECKING: @@ -8,8 +9,8 @@ @contextmanager -def temporary_context(context: Context): - tokens: Dict[ContextVar, Token] = {} +def temporary_context(context: Context) -> Generator[None, Any, None]: + tokens: dict[ContextVar[Any], Token[Any]] = {} for key, value in context.items(): token = key.set(value) tokens[key] = token @@ -38,11 +39,11 @@ def get_flow_run_id() -> Optional[UUID]: return None -def get_task_and_flow_run_ids() -> Tuple[Optional[UUID], Optional[UUID]]: +def get_task_and_flow_run_ids() -> tuple[Optional[UUID], Optional[UUID]]: """ Get the task run and flow run ids from the context, if available. 
Returns: - Tuple[Optional[UUID], Optional[UUID]]: a tuple of the task run id and flow run id + tuple[Optional[UUID], Optional[UUID]]: a tuple of the task run id and flow run id """ return get_task_run_id(), get_flow_run_id() diff --git a/src/prefect/utilities/dispatch.py b/src/prefect/utilities/dispatch.py index 599a9afafda4..603c24aecbd1 100644 --- a/src/prefect/utilities/dispatch.py +++ b/src/prefect/utilities/dispatch.py @@ -23,28 +23,39 @@ class Foo(Base): import abc import inspect import warnings -from typing import Any, Dict, Optional, Type, TypeVar +from typing import Any, Literal, Optional, TypeVar, overload -T = TypeVar("T", bound=Type) +T = TypeVar("T", bound=type[Any]) -_TYPE_REGISTRIES: Dict[Type, Dict[str, Type]] = {} +_TYPE_REGISTRIES: dict[Any, dict[str, Any]] = {} -def get_registry_for_type(cls: T) -> Optional[Dict[str, T]]: +def get_registry_for_type(cls: T) -> Optional[dict[str, T]]: """ Get the first matching registry for a class or any of its base classes. If not found, `None` is returned. """ return next( - filter( - lambda registry: registry is not None, - (_TYPE_REGISTRIES.get(cls) for cls in cls.mro()), - ), + (reg for cls in cls.mro() if (reg := _TYPE_REGISTRIES.get(cls)) is not None), None, ) +@overload +def get_dispatch_key( + cls_or_instance: Any, allow_missing: Literal[False] = False +) -> str: + ... + + +@overload +def get_dispatch_key( + cls_or_instance: Any, allow_missing: Literal[True] = ... +) -> Optional[str]: + ... + + def get_dispatch_key( cls_or_instance: Any, allow_missing: bool = False ) -> Optional[str]: @@ -89,14 +100,14 @@ def get_dispatch_key( @classmethod -def _register_subclass_of_base_type(cls, **kwargs): +def _register_subclass_of_base_type(cls: type[Any], **kwargs: Any) -> None: if hasattr(cls, "__init_subclass_original__"): cls.__init_subclass_original__(**kwargs) elif hasattr(cls, "__pydantic_init_subclass_original__"): cls.__pydantic_init_subclass_original__(**kwargs) # Do not register abstract base classes - if abc.ABC in getattr(cls, "__bases__", []): + if abc.ABC in cls.__bases__: return register_type(cls) @@ -123,7 +134,7 @@ def register_base_type(cls: T) -> T: cls.__pydantic_init_subclass__ = _register_subclass_of_base_type else: cls.__init_subclass_original__ = getattr(cls, "__init_subclass__") - cls.__init_subclass__ = _register_subclass_of_base_type + setattr(cls, "__init_subclass__", _register_subclass_of_base_type) return cls @@ -190,7 +201,7 @@ def lookup_type(cls: T, dispatch_key: str) -> T: Look up a dispatch key in the type registry for the given class. 
""" # Get the first matching registry for the class or one of its bases - registry = get_registry_for_type(cls) + registry = get_registry_for_type(cls) or {} # Look up this type in the registry subcls = registry.get(dispatch_key) diff --git a/src/prefect/utilities/dockerutils.py b/src/prefect/utilities/dockerutils.py index 72575d81d3a0..eb6bd18b024b 100644 --- a/src/prefect/utilities/dockerutils.py +++ b/src/prefect/utilities/dockerutils.py @@ -3,22 +3,12 @@ import shutil import sys import warnings +from collections.abc import Generator, Iterable, Iterator from contextlib import contextmanager from pathlib import Path, PurePosixPath from tempfile import TemporaryDirectory from types import TracebackType -from typing import ( - TYPE_CHECKING, - Any, - Generator, - Iterable, - List, - Optional, - TextIO, - Tuple, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Optional, TextIO, Union, cast from urllib.parse import urlsplit import pendulum @@ -29,6 +19,12 @@ from prefect.utilities.importtools import lazy_import from prefect.utilities.slugify import slugify +if TYPE_CHECKING: + import docker + import docker.errors + from docker import DockerClient + from docker.models.images import Image + CONTAINER_LABELS = { "io.prefect.version": prefect.__version__, } @@ -102,10 +98,7 @@ def silence_docker_warnings() -> Generator[None, None, None]: # want to have those popping up in various modules and test suites. Instead, # consolidate the imports we need here, and expose them via this module. with silence_docker_warnings(): - if TYPE_CHECKING: - import docker - from docker import DockerClient - else: + if not TYPE_CHECKING: docker = lazy_import("docker") @@ -123,7 +116,8 @@ def docker_client() -> Generator["DockerClient", None, None]: "This error is often thrown because Docker is not running. Please ensure Docker is running." 
) from exc finally: - client is not None and client.close() + if client is not None: + client.close() # type: ignore # typing stub is not complete class BuildError(Exception): @@ -207,14 +201,15 @@ class ImageBuilder: base_directory: Path context: Optional[Path] platform: Optional[str] - dockerfile_lines: List[str] + dockerfile_lines: list[str] + temporary_directory: Optional[TemporaryDirectory[str]] def __init__( self, base_image: str, - base_directory: Path = None, + base_directory: Optional[Path] = None, platform: Optional[str] = None, - context: Path = None, + context: Optional[Path] = None, ): """Create an ImageBuilder @@ -250,7 +245,7 @@ def __enter__(self) -> Self: return self def __exit__( - self, exc: Type[BaseException], value: BaseException, traceback: TracebackType + self, exc: type[BaseException], value: BaseException, traceback: TracebackType ) -> None: if not self.temporary_directory: return @@ -267,7 +262,9 @@ def add_lines(self, lines: Iterable[str]) -> None: """Add lines to this image's Dockerfile""" self.dockerfile_lines.extend(lines) - def copy(self, source: Union[str, Path], destination: Union[str, PurePosixPath]): + def copy( + self, source: Union[str, Path], destination: Union[str, PurePosixPath] + ) -> None: """Copy a file to this image""" if not self.context: raise Exception("No context available") @@ -291,7 +288,7 @@ def copy(self, source: Union[str, Path], destination: Union[str, PurePosixPath]) self.add_line(f"COPY {source} {destination}") - def write_text(self, text: str, destination: Union[str, PurePosixPath]): + def write_text(self, text: str, destination: Union[str, PurePosixPath]) -> None: if not self.context: raise Exception("No context available") @@ -315,6 +312,7 @@ def build( Returns: The image ID """ + assert self.context is not None dockerfile_path: Path = self.context / "Dockerfile" with dockerfile_path.open("w") as dockerfile: @@ -436,9 +434,12 @@ def push_image( repository = f"{registry}/{name}" with docker_client() as client: - image: "docker.Image" = client.images.get(image_id) - image.tag(repository, tag=tag) - events = client.api.push(repository, tag=tag, stream=True, decode=True) + image: "Image" = client.images.get(image_id) + image.tag(repository, tag=tag) # type: ignore # typing stub is not complete + events = cast( + Iterator[dict[str, Any]], + client.api.push(repository, tag=tag, stream=True, decode=True), # type: ignore # typing stub is not complete + ) try: for event in events: if "status" in event: @@ -452,12 +453,12 @@ def push_image( elif "error" in event: raise PushError(event["error"]) finally: - client.api.remove_image(f"{repository}:{tag}", noprune=True) + client.api.remove_image(f"{repository}:{tag}", noprune=True) # type: ignore # typing stub is not complete return f"{repository}:{tag}" -def to_run_command(command: List[str]) -> str: +def to_run_command(command: list[str]) -> str: """ Convert a process-style list of command arguments to a single Dockerfile RUN instruction. 
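Roughly, the conversion looks like this (a hedged sketch; the exact quoting of arguments that contain spaces is left to the implementation, so the output comment below is approximate):

```python
from prefect.utilities.dockerutils import to_run_command

# a process-style argv list becomes a single Dockerfile RUN instruction
print(to_run_command(["pip", "install", "prefect"]))
# expected output, approximately: RUN pip install prefect
```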
@@ -481,7 +482,7 @@ def to_run_command(command: List[str]) -> str: return run_command -def parse_image_tag(name: str) -> Tuple[str, Optional[str]]: +def parse_image_tag(name: str) -> tuple[str, Optional[str]]: """ Parse Docker Image String @@ -519,7 +520,7 @@ def parse_image_tag(name: str) -> Tuple[str, Optional[str]]: return image_name, tag -def split_repository_path(repository_path: str) -> Tuple[Optional[str], str]: +def split_repository_path(repository_path: str) -> tuple[Optional[str], str]: """ Splits a Docker repository path into its namespace and repository components. @@ -550,7 +551,7 @@ def split_repository_path(repository_path: str) -> Tuple[Optional[str], str]: return namespace, repository -def format_outlier_version_name(version: str): +def format_outlier_version_name(version: str) -> str: """ Formats outlier docker version names to pass `packaging.version.parse` validation - Current cases are simple, but creates stub for more complicated formatting if eventually needed. @@ -580,7 +581,7 @@ def generate_default_dockerfile(context: Optional[Path] = None): """ if not context: context = Path.cwd() - lines = [] + lines: list[str] = [] base_image = get_prefect_image_name() lines.append(f"FROM {base_image}") dir_name = context.name diff --git a/src/prefect/utilities/engine.py b/src/prefect/utilities/engine.py index 4cd3bba35a0c..7551c0d5091f 100644 --- a/src/prefect/utilities/engine.py +++ b/src/prefect/utilities/engine.py @@ -1,40 +1,39 @@ import asyncio import contextlib -import inspect import os import signal import time +from collections.abc import Awaitable, Callable, Generator from functools import partial +from logging import Logger from typing import ( TYPE_CHECKING, Any, - Callable, - Dict, - Iterable, + NoReturn, Optional, - Set, TypeVar, Union, + cast, ) -from uuid import UUID, uuid4 +from uuid import UUID import anyio -from typing_extensions import Literal +from typing_extensions import TypeIs import prefect import prefect.context +import prefect.exceptions import prefect.plugins from prefect._internal.concurrency.cancellation import get_deadline from prefect.client.schemas import OrchestrationResult, TaskRun -from prefect.client.schemas.objects import ( - StateType, - TaskRunInput, - TaskRunResult, -) -from prefect.client.schemas.responses import SetStateStatus -from prefect.context import ( - FlowRunContext, +from prefect.client.schemas.objects import TaskRunInput, TaskRunResult +from prefect.client.schemas.responses import ( + SetStateStatus, + StateAbortDetails, + StateRejectDetails, + StateWaitDetails, ) +from prefect.context import FlowRunContext from prefect.events import Event, emit_event from prefect.exceptions import ( Pause, @@ -44,37 +43,26 @@ ) from prefect.flows import Flow from prefect.futures import PrefectFuture -from prefect.logging.loggers import ( - get_logger, - task_run_logger, -) +from prefect.logging.loggers import get_logger from prefect.results import BaseResult, ResultRecord, should_persist_result -from prefect.settings import ( - PREFECT_LOGGING_LOG_PRINTS, -) -from prefect.states import ( - State, - get_state_exception, -) +from prefect.settings import PREFECT_LOGGING_LOG_PRINTS +from prefect.states import State from prefect.tasks import Task from prefect.utilities.annotations import allow_failure, quote -from prefect.utilities.asyncutils import ( - gather, - run_coro_as_sync, -) +from prefect.utilities.asyncutils import run_coro_as_sync from prefect.utilities.collections import StopVisiting, visit_collection from prefect.utilities.text 
diff --git a/src/prefect/utilities/engine.py b/src/prefect/utilities/engine.py
index 4cd3bba35a0c..7551c0d5091f 100644
--- a/src/prefect/utilities/engine.py
+++ b/src/prefect/utilities/engine.py
@@ -1,40 +1,39 @@
 import asyncio
 import contextlib
-import inspect
 import os
 import signal
 import time
+from collections.abc import Awaitable, Callable, Generator
 from functools import partial
+from logging import Logger
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
-    Dict,
-    Iterable,
+    NoReturn,
     Optional,
-    Set,
     TypeVar,
     Union,
+    cast,
 )
-from uuid import UUID, uuid4
+from uuid import UUID

 import anyio
-from typing_extensions import Literal
+from typing_extensions import TypeIs

 import prefect
 import prefect.context
+import prefect.exceptions
 import prefect.plugins
 from prefect._internal.concurrency.cancellation import get_deadline
 from prefect.client.schemas import OrchestrationResult, TaskRun
-from prefect.client.schemas.objects import (
-    StateType,
-    TaskRunInput,
-    TaskRunResult,
-)
-from prefect.client.schemas.responses import SetStateStatus
-from prefect.context import (
-    FlowRunContext,
+from prefect.client.schemas.objects import TaskRunInput, TaskRunResult
+from prefect.client.schemas.responses import (
+    SetStateStatus,
+    StateAbortDetails,
+    StateRejectDetails,
+    StateWaitDetails,
 )
+from prefect.context import FlowRunContext
 from prefect.events import Event, emit_event
 from prefect.exceptions import (
     Pause,
@@ -44,37 +43,26 @@
 )
 from prefect.flows import Flow
 from prefect.futures import PrefectFuture
-from prefect.logging.loggers import (
-    get_logger,
-    task_run_logger,
-)
+from prefect.logging.loggers import get_logger
 from prefect.results import BaseResult, ResultRecord, should_persist_result
-from prefect.settings import (
-    PREFECT_LOGGING_LOG_PRINTS,
-)
-from prefect.states import (
-    State,
-    get_state_exception,
-)
+from prefect.settings import PREFECT_LOGGING_LOG_PRINTS
+from prefect.states import State
 from prefect.tasks import Task
 from prefect.utilities.annotations import allow_failure, quote
-from prefect.utilities.asyncutils import (
-    gather,
-    run_coro_as_sync,
-)
+from prefect.utilities.asyncutils import run_coro_as_sync
 from prefect.utilities.collections import StopVisiting, visit_collection
 from prefect.utilities.text import truncated_to

 if TYPE_CHECKING:
     from prefect.client.orchestration import PrefectClient, SyncPrefectClient

-API_HEALTHCHECKS = {}
-UNTRACKABLE_TYPES = {bool, type(None), type(...), type(NotImplemented)}
-engine_logger = get_logger("engine")
+API_HEALTHCHECKS: dict[str, float] = {}
+UNTRACKABLE_TYPES: set[type[Any]] = {bool, type(None), type(...), type(NotImplemented)}
+engine_logger: Logger = get_logger("engine")
 T = TypeVar("T")


-async def collect_task_run_inputs(expr: Any, max_depth: int = -1) -> Set[TaskRunInput]:
+async def collect_task_run_inputs(expr: Any, max_depth: int = -1) -> set[TaskRunInput]:
     """
     This function recurses through an expression to generate a set of any discernible
     task run inputs it finds in the data structure. It produces a set of all inputs
@@ -87,14 +75,11 @@
     """
     # TODO: This function needs to be updated to detect parameters and constants

-    inputs = set()
-    futures = set()
+    inputs: set[TaskRunInput] = set()

-    def add_futures_and_states_to_inputs(obj):
+    def add_futures_and_states_to_inputs(obj: Any) -> None:
         if isinstance(obj, PrefectFuture):
-            # We need to wait for futures to be submitted before we can get the task
-            # run id but we want to do so asynchronously
-            futures.add(obj)
+            inputs.add(TaskRunResult(id=obj.task_run_id))
         elif isinstance(obj, State):
             if obj.state_details.task_run_id:
                 inputs.add(TaskRunResult(id=obj.state_details.task_run_id))
@@ -113,16 +98,12 @@ def add_futures_and_states_to_inputs(obj):
         max_depth=max_depth,
     )

-    await asyncio.gather(*[future._wait_for_submission() for future in futures])
-    for future in futures:
-        inputs.add(TaskRunResult(id=future.task_run.id))
-
     return inputs


 def collect_task_run_inputs_sync(
     expr: Any, future_cls: Any = PrefectFuture, max_depth: int = -1
-) -> Set[TaskRunInput]:
+) -> set[TaskRunInput]:
     """
     This function recurses through an expression to generate a set of any discernible
     task run inputs it finds in the data structure. It produces a set of all inputs
@@ -135,9 +116,9 @@ def collect_task_run_inputs_sync(
     """
     # TODO: This function needs to be updated to detect parameters and constants

-    inputs = set()
+    inputs: set[TaskRunInput] = set()

-    def add_futures_and_states_to_inputs(obj):
+    def add_futures_and_states_to_inputs(obj: Any) -> None:
         if isinstance(obj, future_cls) and hasattr(obj, "task_run_id"):
             inputs.add(TaskRunResult(id=obj.task_run_id))
         elif isinstance(obj, State):
@@ -161,58 +142,9 @@ def add_futures_and_states_to_inputs(obj):
     return inputs


-async def wait_for_task_runs_and_report_crashes(
-    task_run_futures: Iterable[PrefectFuture], client: "PrefectClient"
-) -> Literal[True]:
-    crash_exceptions = []
-
-    # Gather states concurrently first
-    states = await gather(*(future._wait for future in task_run_futures))
-
-    for future, state in zip(task_run_futures, states):
-        logger = task_run_logger(future.task_run)
-
-        if not state.type == StateType.CRASHED:
-            continue
-
-        # We use this utility instead of `state.result` for type checking
-        exception = await get_state_exception(state)
-
-        task_run = await client.read_task_run(future.task_run.id)
-        if not task_run.state.is_crashed():
-            logger.info(f"Crash detected! {state.message}")
-            logger.debug("Crash details:", exc_info=exception)
-
-            # Update the state of the task run
-            result = await client.set_task_run_state(
-                task_run_id=future.task_run.id, state=state, force=True
-            )
-            if result.status == SetStateStatus.ACCEPT:
-                engine_logger.debug(
-                    f"Reported crashed task run {future.name!r} successfully."
-                )
-            else:
-                engine_logger.warning(
-                    f"Failed to report crashed task run {future.name!r}. "
-                    f"Orchestrator did not accept state: {result!r}"
-                )
-        else:
-            # Populate the state details on the local state
-            future._final_state.state_details = task_run.state.state_details
-
-        crash_exceptions.append(exception)
-
-    # Now that we've finished reporting crashed tasks, reraise any exit exceptions
-    for exception in crash_exceptions:
-        if isinstance(exception, (KeyboardInterrupt, SystemExit)):
-            raise exception
-
-    return True
-
-
 @contextlib.contextmanager
-def capture_sigterm():
-    def cancel_flow_run(*args):
+def capture_sigterm() -> Generator[None, Any, None]:
+    def cancel_flow_run(*args: object) -> NoReturn:
         raise TerminationSignal(signal=signal.SIGTERM)

     original_term_handler = None
@@ -241,8 +173,8 @@ def cancel_flow_run(*args):


 async def resolve_inputs(
-    parameters: Dict[str, Any], return_data: bool = True, max_depth: int = -1
-) -> Dict[str, Any]:
+    parameters: dict[str, Any], return_data: bool = True, max_depth: int = -1
+) -> dict[str, Any]:
     """
     Resolve any `Quote`, `PrefectFuture`, or `State` types nested in parameters into
     data.
@@ -254,24 +186,26 @@
         UpstreamTaskError: If any of the upstream states are not `COMPLETED`
     """

-    futures = set()
-    states = set()
-    result_by_state = {}
+    futures: set[PrefectFuture[Any]] = set()
+    states: set[State[Any]] = set()
+    result_by_state: dict[State[Any], Any] = {}

     if not parameters:
         return {}

-    def collect_futures_and_states(expr, context):
+    def collect_futures_and_states(expr: Any, context: dict[str, Any]) -> Any:
         # Expressions inside quotes should not be traversed
         if isinstance(context.get("annotation"), quote):
             raise StopVisiting()

         if isinstance(expr, PrefectFuture):
-            futures.add(expr)
+            fut: PrefectFuture[Any] = expr
+            futures.add(fut)
         if isinstance(expr, State):
-            states.add(expr)
+            state: State[Any] = expr
+            states.add(state)

-        return expr
+        return cast(Any, expr)

     visit_collection(
         parameters,
@@ -281,32 +215,27 @@ def collect_futures_and_states(expr, context):
         context={},
     )

-    # Wait for all futures so we do not block when we retrieve the state in `resolve_input`
-    states.update(await asyncio.gather(*[future._wait() for future in futures]))
-
     # Only retrieve the result if requested as it may be expensive
     if return_data:
         finished_states = [state for state in states if state.is_final()]

-        state_results = await asyncio.gather(
-            *[
-                state.result(raise_on_failure=False, fetch=True)
-                for state in finished_states
-            ]
-        )
+        state_results = [
+            state.result(raise_on_failure=False, fetch=True)
+            for state in finished_states
+        ]

         for state, result in zip(finished_states, state_results):
             result_by_state[state] = result

-    def resolve_input(expr, context):
-        state = None
+    def resolve_input(expr: Any, context: dict[str, Any]) -> Any:
+        state: Optional[State[Any]] = None

         # Expressions inside quotes should not be modified
         if isinstance(context.get("annotation"), quote):
             raise StopVisiting()

         if isinstance(expr, PrefectFuture):
-            state = expr._final_state
+            state = expr.state
         elif isinstance(expr, State):
             state = expr
         else:
@@ -329,7 +258,7 @@ def resolve_input(expr, context):

         return result_by_state.get(state)

-    resolved_parameters = {}
+    resolved_parameters: dict[str, Any] = {}
     for parameter, value in parameters.items():
         try:
             resolved_parameters[parameter] = visit_collection(
@@ -353,13 +282,21 @@ def resolve_input(expr, context):
     return resolved_parameters


+def _is_base_result(data: Any) -> TypeIs[BaseResult[Any]]:
+    return isinstance(data, BaseResult)
+
+
+def _is_result_record(data: Any) -> TypeIs[ResultRecord[Any]]:
+    return isinstance(data, ResultRecord)
+
+
 async def propose_state(
     client: "PrefectClient",
-    state: State[object],
+    state: State[Any],
     force: bool = False,
     task_run_id: Optional[UUID] = None,
     flow_run_id: Optional[UUID] = None,
-) -> State[object]:
+) -> State[Any]:
     """
     Propose a new state for a flow run or task run, invoking Prefect orchestration
     logic.
@@ -396,11 +333,12 @@

     # Handle task and sub-flow tracing
     if state.is_final():
-        if isinstance(state.data, BaseResult) and state.data.has_cached_object():
+        result: Any
+        if _is_base_result(state.data) and state.data.has_cached_object():
             # Avoid fetching the result unless it is cached, otherwise we defeat
             # the purpose of disabling `cache_result_in_memory`
-            result = await state.result(raise_on_failure=False, fetch=True)
-        elif isinstance(state.data, ResultRecord):
+            result = state.result(raise_on_failure=False, fetch=True)
+        elif _is_result_record(state.data):
             result = state.data.result
         else:
             result = state.data
@@ -409,9 +347,13 @@

     # Handle repeated WAITs in a loop instead of recursively, to avoid
     # reaching max recursion depth in extreme cases.
-    async def set_state_and_handle_waits(set_state_func) -> OrchestrationResult:
+    async def set_state_and_handle_waits(
+        set_state_func: Callable[[], Awaitable[OrchestrationResult[Any]]],
+    ) -> OrchestrationResult[Any]:
         response = await set_state_func()
         while response.status == SetStateStatus.WAIT:
+            if TYPE_CHECKING:
+                assert isinstance(response.details, StateWaitDetails)
             engine_logger.debug(
                 f"Received wait instruction for {response.details.delay_seconds}s: "
                 f"{response.details.reason}"
@@ -436,6 +378,8 @@ async def set_state_and_handle_waits(set_state_func) -> OrchestrationResult:
     # Parse the response to return the new state
     if response.status == SetStateStatus.ACCEPT:
         # Update the state with the details if provided
+        if TYPE_CHECKING:
+            assert response.state is not None
         state.id = response.state.id
         state.timestamp = response.state.timestamp
         if response.state.state_details:
@@ -443,9 +387,16 @@ async def set_state_and_handle_waits(set_state_func) -> OrchestrationResult:
         return state

     elif response.status == SetStateStatus.ABORT:
+        if TYPE_CHECKING:
+            assert isinstance(response.details, StateAbortDetails)
+
         raise prefect.exceptions.Abort(response.details.reason)

     elif response.status == SetStateStatus.REJECT:
+        if TYPE_CHECKING:
+            assert response.state is not None
+            assert isinstance(response.details, StateRejectDetails)
+
         if response.state.is_paused():
             raise Pause(response.details.reason, state=response.state)
         return response.state
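# Aside: why the `_is_base_result` / `_is_result_record` helpers above return
# `TypeIs[...]` rather than a bare bool. A minimal sketch with a stand-in
# class; after a True result, the type checker narrows the argument's type:

from typing_extensions import TypeIs

class Boxed:
    def __init__(self, value: object) -> None:
        self.value = value

def is_boxed(data: object) -> TypeIs[Boxed]:
    return isinstance(data, Boxed)

def unwrap(data: object) -> object:
    if is_boxed(data):
        return data.value  # checker knows `data` is `Boxed` here
    return data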
@@ -458,11 +409,11 @@ def set_state_and_handle_waits(set_state_func) -> OrchestrationResult:

 def propose_state_sync(
     client: "SyncPrefectClient",
-    state: State[object],
+    state: State[Any],
     force: bool = False,
     task_run_id: Optional[UUID] = None,
     flow_run_id: Optional[UUID] = None,
-) -> State[object]:
+) -> State[Any]:
     """
     Propose a new state for a flow run or task run, invoking Prefect orchestration
     logic.
@@ -499,13 +450,13 @@ def propose_state_sync(

     # Handle task and sub-flow tracing
     if state.is_final():
-        if isinstance(state.data, BaseResult) and state.data.has_cached_object():
+        if _is_base_result(state.data) and state.data.has_cached_object():
             # Avoid fetching the result unless it is cached, otherwise we defeat
             # the purpose of disabling `cache_result_in_memory`
             result = state.result(raise_on_failure=False, fetch=True)
             if asyncio.iscoroutine(result):
                 result = run_coro_as_sync(result)
-        elif isinstance(state.data, ResultRecord):
+        elif _is_result_record(state.data):
             result = state.data.result
         else:
             result = state.data
@@ -514,9 +465,13 @@

     # Handle repeated WAITs in a loop instead of recursively, to avoid
     # reaching max recursion depth in extreme cases.
-    def set_state_and_handle_waits(set_state_func) -> OrchestrationResult:
+    def set_state_and_handle_waits(
+        set_state_func: Callable[[], OrchestrationResult[Any]],
+    ) -> OrchestrationResult[Any]:
         response = set_state_func()
         while response.status == SetStateStatus.WAIT:
+            if TYPE_CHECKING:
+                assert isinstance(response.details, StateWaitDetails)
             engine_logger.debug(
                 f"Received wait instruction for {response.details.delay_seconds}s: "
                 f"{response.details.reason}"
@@ -540,6 +495,8 @@ def set_state_and_handle_waits(set_state_func) -> OrchestrationResult:

     # Parse the response to return the new state
     if response.status == SetStateStatus.ACCEPT:
+        if TYPE_CHECKING:
+            assert response.state is not None
         # Update the state with the details if provided
         state.id = response.state.id
         state.timestamp = response.state.timestamp
@@ -548,9 +505,14 @@ def set_state_and_handle_waits(set_state_func) -> OrchestrationResult:
         return state

     elif response.status == SetStateStatus.ABORT:
+        if TYPE_CHECKING:
+            assert isinstance(response.details, StateAbortDetails)
         raise prefect.exceptions.Abort(response.details.reason)

     elif response.status == SetStateStatus.REJECT:
+        if TYPE_CHECKING:
+            assert response.state is not None
+            assert isinstance(response.details, StateRejectDetails)
         if response.state.is_paused():
             raise Pause(response.details.reason, state=response.state)
         return response.state
@@ -561,26 +523,6 @@ def set_state_and_handle_waits(set_state_func) -> OrchestrationResult:
     )


-def _dynamic_key_for_task_run(
-    context: FlowRunContext, task: Task, stable: bool = True
-) -> Union[int, str]:
-    if (
-        stable is False or context.detached
-    ):  # this task is running on remote infrastructure
-        return str(uuid4())
-    elif context.flow_run is None:  # this is an autonomous task run
-        context.task_run_dynamic_keys[task.task_key] = getattr(
-            task, "dynamic_key", str(uuid4())
-        )
-
-    elif task.task_key not in context.task_run_dynamic_keys:
-        context.task_run_dynamic_keys[task.task_key] = 0
-    else:
-        context.task_run_dynamic_keys[task.task_key] += 1
-
-    return context.task_run_dynamic_keys[task.task_key]
-
-
 def get_state_for_result(obj: Any) -> Optional[State]:
     """
     Get the state related to a result object.
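# Aside: the `if TYPE_CHECKING: assert ...` blocks sprinkled through these
# functions narrow types for the checker only; the block never executes, so
# there is no runtime cost and no AssertionError. A minimal sketch:

from typing import TYPE_CHECKING, Optional

def shout(value: Optional[str]) -> str:
    if TYPE_CHECKING:
        assert value is not None  # checker-only narrowing; skipped at runtime
    return value.upper()  # accepted by the checker; a None still fails at runtime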
@@ -631,28 +573,29 @@ def link_state_to_result(state: State, result: Any) -> None:
     # Holding large user objects in memory can cause memory bloat
     linked_state = state.model_copy(update={"data": None})

-    def link_if_trackable(obj: Any) -> None:
-        """Track connection between a task run result and its associated state if it has a unique ID.
+    if flow_run_context:

-        We cannot track booleans, Ellipsis, None, NotImplemented, or the integers from -5 to 256
-        because they are singletons.
+        def link_if_trackable(obj: Any) -> None:
+            """Track connection between a task run result and its associated state if it has a unique ID.

-        This function will mutate the State if the object is an untrackable type by setting the value
-        for `State.state_details.untrackable_result` to `True`.
+            We cannot track booleans, Ellipsis, None, NotImplemented, or the integers from -5 to 256
+            because they are singletons.

-        """
-        if (type(obj) in UNTRACKABLE_TYPES) or (
-            isinstance(obj, int) and (-5 <= obj <= 256)
-        ):
-            state.state_details.untrackable_result = True
-            return
-        flow_run_context.task_run_results[id(obj)] = linked_state
+            This function will mutate the State if the object is an untrackable type by setting the value
+            for `State.state_details.untrackable_result` to `True`.
+
+            """
+            if (type(obj) in UNTRACKABLE_TYPES) or (
+                isinstance(obj, int) and (-5 <= obj <= 256)
+            ):
+                state.state_details.untrackable_result = True
+                return
+            flow_run_context.task_run_results[id(obj)] = linked_state

-    if flow_run_context:
         visit_collection(expr=result, visit_fn=link_if_trackable, max_depth=1)


-def should_log_prints(flow_or_task: Union[Flow, Task]) -> bool:
+def should_log_prints(flow_or_task: Union["Flow[..., Any]", "Task[..., Any]"]) -> bool:
     flow_run_context = FlowRunContext.get()

     if flow_or_task.log_prints is None:
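# Aside: a quick demonstration of why `link_if_trackable` refuses small ints
# and other singletons as cache keys; CPython interns them, so `id()` would
# collide across unrelated results (an implementation detail, but a stable one):

a = 7
b = int("7")
print(id(a) == id(b))  # True: small ints share one cached object

x = 10**6
y = int("1000000")
print(id(x) == id(y))  # typically False: large ints are distinct objects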
@@ -664,63 +607,7 @@
     return flow_or_task.log_prints


-def _resolve_custom_flow_run_name(flow: Flow, parameters: Dict[str, Any]) -> str:
-    if callable(flow.flow_run_name):
-        flow_run_name = flow.flow_run_name()
-        if not isinstance(flow_run_name, str):
-            raise TypeError(
-                f"Callable {flow.flow_run_name} for 'flow_run_name' returned type"
-                f" {type(flow_run_name).__name__} but a string is required."
-            )
-    elif isinstance(flow.flow_run_name, str):
-        flow_run_name = flow.flow_run_name.format(**parameters)
-    else:
-        raise TypeError(
-            "Expected string or callable for 'flow_run_name'; got"
-            f" {type(flow.flow_run_name).__name__} instead."
-        )
-
-    return flow_run_name
-
-
-def _resolve_custom_task_run_name(task: Task, parameters: Dict[str, Any]) -> str:
-    if callable(task.task_run_name):
-        sig = inspect.signature(task.task_run_name)
-
-        # If the callable accepts a 'parameters' kwarg, pass the entire parameters dict
-        if "parameters" in sig.parameters:
-            task_run_name = task.task_run_name(parameters=parameters)
-        else:
-            # If it doesn't expect parameters, call it without arguments
-            task_run_name = task.task_run_name()
-
-        if not isinstance(task_run_name, str):
-            raise TypeError(
-                f"Callable {task.task_run_name} for 'task_run_name' returned type"
-                f" {type(task_run_name).__name__} but a string is required."
-            )
-    elif isinstance(task.task_run_name, str):
-        task_run_name = task.task_run_name.format(**parameters)
-    else:
-        raise TypeError(
-            "Expected string or callable for 'task_run_name'; got"
-            f" {type(task.task_run_name).__name__} instead."
-        )
-
-    return task_run_name
-
-
-def _get_hook_name(hook: Callable[..., Any]) -> str:  # pyright: ignore[reportUnusedFunction]
-    return (
-        hook.__name__
-        if hasattr(hook, "__name__")
-        else (
-            hook.func.__name__ if isinstance(hook, partial) else hook.__class__.__name__
-        )
-    )
-
-
-async def check_api_reachable(client: "PrefectClient", fail_message: str):
+async def check_api_reachable(client: "PrefectClient", fail_message: str) -> None:
     # Do not perform a healthcheck if it exists and is not expired
     api_url = str(client.api_url)
     if api_url in API_HEALTHCHECKS:
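# Aside: a sketch of the expiring-healthcheck bookkeeping used above, with a
# hypothetical helper name; `API_HEALTHCHECKS` maps an API URL to a deadline
# after which the check must be repeated:

import time

_healthchecks: dict[str, float] = {}

def should_check(api_url: str, ttl_seconds: float = 600.0) -> bool:
    deadline = _healthchecks.get(api_url)
    if deadline is not None and time.monotonic() < deadline:
        return False  # previous healthcheck is still fresh
    _healthchecks[api_url] = time.monotonic() + ttl_seconds
    return True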
@@ -740,15 +627,15 @@
 def emit_task_run_state_change_event(
     task_run: TaskRun,
-    initial_state: Optional[State],
-    validated_state: State,
+    initial_state: Optional[State[Any]],
+    validated_state: State[Any],
     follows: Optional[Event] = None,
-) -> Event:
+) -> Optional[Event]:
     state_message_truncation_length = 100_000

-    if isinstance(validated_state.data, ResultRecord) and should_persist_result():
+    if _is_result_record(validated_state.data) and should_persist_result():
         data = validated_state.data.metadata.model_dump(mode="json")
-    elif isinstance(validated_state.data, BaseResult):
+    elif _is_base_result(validated_state.data):
         data = validated_state.data.model_dump(mode="json")
     else:
         data = None
@@ -830,20 +717,20 @@
     )


-def resolve_to_final_result(expr, context):
+def resolve_to_final_result(expr: Any, context: dict[str, Any]) -> Any:
     """
     Resolve any `PrefectFuture`, or `State` types nested in parameters into data.

     Designed to be use with `visit_collection`.
     """
-    state = None
+    state: Optional[State[Any]] = None

     # Expressions inside quotes should not be modified
     if isinstance(context.get("annotation"), quote):
         raise StopVisiting()

     if isinstance(expr, PrefectFuture):
-        upstream_task_run = context.get("current_task_run")
-        upstream_task = context.get("current_task")
+        upstream_task_run: Optional[TaskRun] = context.get("current_task_run")
+        upstream_task: Optional["Task[..., Any]"] = context.get("current_task")
         if (
             upstream_task
             and upstream_task_run
@@ -877,15 +764,15 @@ def resolve_to_final_result(expr, context):
             " 'COMPLETED' state."
         )

-    _result = state.result(raise_on_failure=False, fetch=True)
-    if asyncio.iscoroutine(_result):
-        _result = run_coro_as_sync(_result)
-    return _result
+    result = state.result(raise_on_failure=False, fetch=True)
+    if asyncio.iscoroutine(result):
+        result = run_coro_as_sync(result)
+    return result


 def resolve_inputs_sync(
-    parameters: Dict[str, Any], return_data: bool = True, max_depth: int = -1
-) -> Dict[str, Any]:
+    parameters: dict[str, Any], return_data: bool = True, max_depth: int = -1
+) -> dict[str, Any]:
     """
     Resolve any `Quote`, `PrefectFuture`, or `State` types nested in parameters into
     data.
@@ -900,7 +787,7 @@

     if not parameters:
         return {}

-    resolved_parameters = {}
+    resolved_parameters: dict[str, Any] = {}
     for parameter, value in parameters.items():
         try:
             resolved_parameters[parameter] = visit_collection(
diff --git a/src/prefect/utilities/filesystem.py b/src/prefect/utilities/filesystem.py
index 68a87a410039..1301a1178c91 100644
--- a/src/prefect/utilities/filesystem.py
+++ b/src/prefect/utilities/filesystem.py
@@ -5,14 +5,16 @@
 import os
 import pathlib
 import threading
+from collections.abc import Iterable
 from contextlib import contextmanager
 from pathlib import Path, PureWindowsPath
-from typing import Optional, Union, cast
+from typing import AnyStr, Optional, Union, cast

-import fsspec
+# fsspec has no stubs, see https://github.com/fsspec/filesystem_spec/issues/625
+import fsspec  # type: ignore
 import pathspec
-from fsspec.core import OpenFile
-from fsspec.implementations.local import LocalFileSystem
+from fsspec.core import OpenFile  # type: ignore
+from fsspec.implementations.local import LocalFileSystem  # type: ignore

 import prefect

@@ -33,8 +35,10 @@ def create_default_ignore_file(path: str) -> bool:


 def filter_files(
-    root: str = ".", ignore_patterns: Optional[list] = None, include_dirs: bool = True
-) -> set:
+    root: str = ".",
+    ignore_patterns: Optional[Iterable[AnyStr]] = None,
+    include_dirs: bool = True,
+) -> set[str]:
     """
     This function accepts a root directory path and a list of file patterns to ignore, and returns
     a list of files that excludes those that should be ignored.
@@ -51,7 +55,7 @@ def filter_files(
     return included_files


-chdir_lock = threading.Lock()
+chdir_lock: threading.Lock = threading.Lock()


 def _normalize_path(path: Union[str, Path]) -> str:
@@ -103,33 +107,32 @@ def tmpchdir(path: str):
 def filename(path: str) -> str:
     """Extract the file name from a path with remote file system support"""
     try:
-        of: OpenFile = cast(OpenFile, fsspec.open(path))
-        sep = of.fs.sep
+        of: OpenFile = cast(OpenFile, fsspec.open(path))  # type: ignore # no typing stubs available
+        sep = cast(str, of.fs.sep)  # type: ignore # no typing stubs available
     except (ImportError, AttributeError):
         sep = "\\" if "\\" in path else "/"
     return path.split(sep)[-1]


-def is_local_path(path: Union[str, pathlib.Path, OpenFile]):
+def is_local_path(path: Union[str, pathlib.Path, OpenFile]) -> bool:
     """Check if the given path points to a local or remote file system"""
     if isinstance(path, str):
         try:
-            of = fsspec.open(path)
+            of = cast(OpenFile, fsspec.open(path))  # type: ignore # no typing stubs available
         except ImportError:
             # The path is a remote file system that uses a lib that is not installed
             return False
     elif isinstance(path, pathlib.Path):
         return True
-    elif isinstance(path, OpenFile):
-        of = path
     else:
-        raise TypeError(f"Invalid path of type {type(path).__name__!r}")
+        of = path

     return isinstance(of.fs, LocalFileSystem)


 def to_display_path(
-    path: Union[pathlib.Path, str], relative_to: Union[pathlib.Path, str] = None
+    path: Union[pathlib.Path, str],
+    relative_to: Optional[Union[pathlib.Path, str]] = None,
 ) -> str:
     """
     Convert a path to a displayable path. The absolute path or relative path to the
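# Aside: `filter_files` delegates pattern matching to pathspec's gitwildmatch
# dialect; a hedged usage sketch of that dependency:

import pathspec

spec = pathspec.PathSpec.from_lines("gitwildmatch", ["*.pyc", "__pycache__/"])
print(spec.match_file("app/module.pyc"))  # True -> would be ignored
print(spec.match_file("app/module.py"))   # False -> would be kept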
diff --git a/src/prefect/utilities/hashing.py b/src/prefect/utilities/hashing.py
index b31a60609164..f131e4898314 100644
--- a/src/prefect/utilities/hashing.py
+++ b/src/prefect/utilities/hashing.py
@@ -1,18 +1,14 @@
 import hashlib
-import sys
 from functools import partial
 from pathlib import Path
 from typing import Any, Callable, Optional, Union

-import cloudpickle
+import cloudpickle  # type: ignore # no stubs available

 from prefect.exceptions import HashError
 from prefect.serializers import JSONSerializer

-if sys.version_info[:2] >= (3, 9):
-    _md5 = partial(hashlib.md5, usedforsecurity=False)
-else:
-    _md5 = hashlib.md5
+_md5 = partial(hashlib.md5, usedforsecurity=False)


 def stable_hash(*args: Union[str, bytes], hash_algo: Callable[..., Any] = _md5) -> str:
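# Aside: dropping the version gate is safe because hashlib has accepted
# `usedforsecurity` since Python 3.9, and passing False keeps md5 usable for
# non-cryptographic hashing on FIPS-restricted builds:

import hashlib

digest = hashlib.md5(b"prefect", usedforsecurity=False).hexdigest()
print(digest)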
diff --git a/src/prefect/utilities/importtools.py b/src/prefect/utilities/importtools.py
index 7cbfce9d51de..d22deee762cb 100644
--- a/src/prefect/utilities/importtools.py
+++ b/src/prefect/utilities/importtools.py
@@ -5,20 +5,23 @@
 import runpy
 import sys
 import warnings
+from collections.abc import Iterable, Sequence
 from importlib.abc import Loader, MetaPathFinder
 from importlib.machinery import ModuleSpec
+from io import TextIOWrapper
+from logging import Logger
 from pathlib import Path
 from tempfile import NamedTemporaryFile
 from types import ModuleType
-from typing import Any, Callable, Dict, Iterable, NamedTuple, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, NamedTuple, Optional, Union

-import fsspec
+import fsspec  # type: ignore # no typing stubs available

 from prefect.exceptions import ScriptError
 from prefect.logging.loggers import get_logger
 from prefect.utilities.filesystem import filename, is_local_path, tmpchdir

-logger = get_logger(__name__)
+logger: Logger = get_logger(__name__)


 def to_qualified_name(obj: Any) -> str:
@@ -70,7 +73,9 @@ def from_qualified_name(name: str) -> Any:
     return getattr(module, attr_name)


-def objects_from_script(path: str, text: Union[str, bytes] = None) -> Dict[str, Any]:
+def objects_from_script(
+    path: str, text: Optional[Union[str, bytes]] = None
+) -> dict[str, Any]:
     """
     Run a python script and return all the global variables
@@ -97,7 +102,7 @@ def objects_from_script(path: str, text: Union[str, bytes] = None) -> Dict[str,
         ScriptError: if the script raises an exception during execution
     """

-    def run_script(run_path: str):
+    def run_script(run_path: str) -> dict[str, Any]:
         # Cast to an absolute path before changing directories to ensure relative paths
         # are not broken
         abs_run_path = os.path.abspath(run_path)
@@ -120,7 +125,9 @@ def run_script(run_path: str):
     else:
         if not is_local_path(path):
             # Remote paths need to be local to run
-            with fsspec.open(path) as f:
+            with fsspec.open(path) as f:  # type: ignore # no typing stubs available
+                if TYPE_CHECKING:
+                    assert isinstance(f, TextIOWrapper)
                 contents = f.read()
             return objects_from_script(path, contents)
         else:
@@ -156,6 +163,10 @@ def load_script_as_module(path: str) -> ModuleType:
         # Support explicit relative imports i.e. `from .foo import bar`
         submodule_search_locations=[parent_path, working_directory],
     )
+    if TYPE_CHECKING:
+        assert spec is not None
+        assert spec.loader is not None
+
     module = importlib.util.module_from_spec(spec)
     sys.modules["__prefect_loader__"] = module

@@ -189,7 +200,7 @@ def load_module(module_name: str) -> ModuleType:
             sys.path.remove(working_directory)


-def import_object(import_path: str):
+def import_object(import_path: str) -> Any:
     """
     Load an object from an import path.

@@ -228,22 +239,20 @@ class DelayedImportErrorModule(ModuleType):
     [1]: https://github.com/scientific-python/lazy_loader
     """

-    def __init__(self, error_message, help_message, *args, **kwargs):
+    def __init__(self, error_message: str, help_message: Optional[str] = None) -> None:
         self.__error_message = error_message
-        self.__help_message = (
-            help_message or "Import errors for this module are only reported when used."
-        )
-        super().__init__(*args, **kwargs)
+        if not help_message:
+            help_message = "Import errors for this module are only reported when used."
+        super().__init__("DelayedImportErrorModule", help_message)

-    def __getattr__(self, attr):
-        if attr in ("__class__", "__file__", "__help_message"):
-            super().__getattr__(attr)
-        else:
-            raise ModuleNotFoundError(self.__error_message)
+    def __getattr__(self, attr: str) -> Any:
+        if attr == "__file__":  # not set but should result in an attribute error?
+            return super().__getattr__(attr)
+        raise ModuleNotFoundError(self.__error_message)


 def lazy_import(
-    name: str, error_on_import: bool = False, help_message: str = ""
+    name: str, error_on_import: bool = False, help_message: Optional[str] = None
 ) -> ModuleType:
     """
     Create a lazily-imported module to use in place of the module of the given name.
@@ -282,13 +291,13 @@ def lazy_import(
         if error_on_import:
             raise ModuleNotFoundError(import_error_message)

-        return DelayedImportErrorModule(
-            import_error_message, help_message, "DelayedImportErrorModule"
-        )
+        return DelayedImportErrorModule(import_error_message, help_message)

     module = importlib.util.module_from_spec(spec)
     sys.modules[name] = module
+    if TYPE_CHECKING:
+        assert spec.loader is not None
     loader = importlib.util.LazyLoader(spec.loader)
     loader.exec_module(module)
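# Aside: a hedged usage sketch of `lazy_import`; no import work happens until
# an attribute is touched, and a missing module only errors at that point:

from prefect.utilities.importtools import lazy_import

numpy = lazy_import("numpy")  # cheap: returns a lazily-loaded module object
# value = numpy.pi            # first attribute access triggers the real
#                             # import, or ModuleNotFoundError if absent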
@@ -317,13 +326,13 @@ def __init__(self, aliases: Iterable[AliasedModuleDefinition]):

         Aliases apply to all modules nested within an alias.
         """
-        self.aliases = aliases
+        self.aliases: list[AliasedModuleDefinition] = list(aliases)

     def find_spec(
         self,
         fullname: str,
-        path=None,
-        target=None,
+        path: Optional[Sequence[str]] = None,
+        target: Optional[ModuleType] = None,
     ) -> Optional[ModuleSpec]:
         """
         The fullname is the imported path, e.g. "foo.bar". If there is an alias "phi"
@@ -334,6 +343,7 @@ def find_spec(
             if fullname.startswith(alias):
                 # Retrieve the spec of the real module
                 real_spec = importlib.util.find_spec(fullname.replace(alias, real, 1))
+                assert real_spec is not None
                 # Create a new spec for the alias
                 return ModuleSpec(
                     fullname,
@@ -354,7 +364,7 @@ def __init__(
         self.callback = callback
         self.real_spec = real_spec

-    def exec_module(self, _: ModuleType) -> None:
+    def exec_module(self, module: ModuleType) -> None:
         root_module = importlib.import_module(self.real_spec.name)
         if self.callback is not None:
             self.callback(self.alias)
@@ -363,7 +373,7 @@ def exec_module(self, _: ModuleType) -> None:

 def safe_load_namespace(
     source_code: str, filepath: Optional[str] = None
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     """
     Safely load a namespace from source code, optionally handling relative imports.
@@ -380,7 +390,7 @@ def safe_load_namespace(
     """
     parsed_code = ast.parse(source_code)

-    namespace: Dict[str, Any] = {"__name__": "prefect_safe_namespace_loader"}
+    namespace: dict[str, Any] = {"__name__": "prefect_safe_namespace_loader"}

     # Remove the body of the if __name__ == "__main__": block
     new_body = [node for node in parsed_code.body if not _is_main_block(node)]
@@ -427,6 +437,9 @@ def safe_load_namespace(
                 try:
                     if node.level > 0:
                         # For relative imports, use the parent package to inform the import
+                        if TYPE_CHECKING:
+                            assert temp_module is not None
+                            assert temp_module.__package__ is not None
                         package_parts = temp_module.__package__.split(".")
                         if len(package_parts) < node.level:
                             raise ImportError(
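# Aside: a hedged usage sketch of `safe_load_namespace`; definitions land in
# the returned namespace while the `__main__` block is stripped first:

from prefect.utilities.importtools import safe_load_namespace

source = '''
import math

def area(radius: float) -> float:
    return math.pi * radius**2

if __name__ == "__main__":
    print("stripped before execution")
'''
namespace = safe_load_namespace(source)
print(namespace["area"](2.0))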
diff --git a/src/prefect/utilities/names.py b/src/prefect/utilities/names.py
index 6be2b93ab707..baeb2b1b6475 100644
--- a/src/prefect/utilities/names.py
+++ b/src/prefect/utilities/names.py
@@ -1,6 +1,6 @@
 from typing import Any

-import coolname
+import coolname  # type: ignore # the version after coolname 2.2.0 should have stubs.

 OBFUSCATED_PREFIX = "****"

@@ -42,7 +42,7 @@ def generate_slug(n_words: int) -> str:
     return "-".join(words)


-def obfuscate(s: Any, show_tail=False) -> str:
+def obfuscate(s: Any, show_tail: bool = False) -> str:
     """
     Obfuscates any data type's string representation. See `obfuscate_string`.
     """
@@ -52,7 +52,7 @@ def obfuscate(s: Any, show_tail=False) -> str:
     return obfuscate_string(str(s), show_tail=show_tail)


-def obfuscate_string(s: str, show_tail=False) -> str:
+def obfuscate_string(s: str, show_tail: bool = False) -> str:
     """
     Obfuscates a string by returning a new string of 8 characters. If the
     input string is longer than 10 characters and show_tail is True, then up to 4 of
diff --git a/src/prefect/utilities/processutils.py b/src/prefect/utilities/processutils.py
index aeb1a37e83ed..7951fac4cae5 100644
--- a/src/prefect/utilities/processutils.py
+++ b/src/prefect/utilities/processutils.py
@@ -4,28 +4,35 @@
 import subprocess
 import sys
 import threading
+from collections.abc import AsyncGenerator, Mapping
 from contextlib import asynccontextmanager
 from dataclasses import dataclass
 from functools import partial
+from types import FrameType
 from typing import (
     IO,
+    TYPE_CHECKING,
     Any,
+    AnyStr,
     Callable,
-    List,
-    Mapping,
     Optional,
-    Sequence,
     TextIO,
-    Tuple,
     Union,
+    cast,
+    overload,
 )

 import anyio
 import anyio.abc
 from anyio.streams.text import TextReceiveStream, TextSendStream
+from typing_extensions import TypeAlias, TypeVar

-TextSink = Union[anyio.AsyncFile, TextIO, TextSendStream]
+if TYPE_CHECKING:
+    from _typeshed import StrOrBytesPath
+
+TextSink: TypeAlias = Union[anyio.AsyncFile[AnyStr], TextIO, TextSendStream]
+PrintFn: TypeAlias = Callable[[str], object]
+T = TypeVar("T", infer_variance=True)

 if sys.platform == "win32":
     from ctypes import WINFUNCTYPE, c_int, c_uint, windll
@@ -33,7 +40,7 @@
     _windows_process_group_pids = set()

     @WINFUNCTYPE(c_int, c_uint)
-    def _win32_ctrl_handler(dwCtrlType):
+    def _win32_ctrl_handler(dwCtrlType: object) -> int:
         """
         A callback function for handling CTRL events cleanly on Windows. When called,
         this function will terminate all running win32 subprocesses the current
@@ -125,16 +132,16 @@ def stderr(self) -> Union[anyio.abc.ByteReceiveStream, None]:
         return self._stderr

     async def _open_anyio_process(
-        command: Union[str, bytes, Sequence[Union[str, bytes]]],
+        command: Union[str, bytes, list["StrOrBytesPath"]],
         *,
         stdin: Union[int, IO[Any], None] = None,
         stdout: Union[int, IO[Any], None] = None,
         stderr: Union[int, IO[Any], None] = None,
-        cwd: Union[str, bytes, os.PathLike, None] = None,
-        env: Union[Mapping[str, str], None] = None,
+        cwd: Optional["StrOrBytesPath"] = None,
+        env: Optional[Mapping[str, str]] = None,
         start_new_session: bool = False,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> Process:
        """
        Open a subprocess and return a `Process` object.
@@ -179,7 +186,9 @@ async def _open_anyio_process(


 @asynccontextmanager
-async def open_process(command: List[str], **kwargs):
+async def open_process(
+    command: list[str], **kwargs: Any
+) -> AsyncGenerator[anyio.abc.Process, Any]:
     """
     Like `anyio.open_process` but with:
     - Support for Windows command joining
@@ -189,11 +198,12 @@ async def open_process(command: List[str], **kwargs):
     # Passing a string to open_process is equivalent to shell=True which is
     # generally necessary for Unix-like commands on Windows but otherwise should
     # be avoided
-    if not isinstance(command, list):
-        raise TypeError(
-            "The command passed to open process must be a list. You passed the command"
-            f"'{command}', which is type '{type(command)}'."
-        )
+    if not TYPE_CHECKING:
+        if not isinstance(command, list):
+            raise TypeError(
+                "The command passed to open process must be a list. You passed the command"
+                f"'{command}', which is type '{type(command)}'."
+            )

     if sys.platform == "win32":
         command = " ".join(command)
@@ -222,7 +232,7 @@ async def open_process(command: List[str], **kwargs):
     finally:
         try:
             process.terminate()
-            if win32_process_group:
+            if sys.platform == "win32" and win32_process_group:
                 _windows_process_group_pids.remove(process.pid)

         except OSError:
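# Aside: a hedged usage sketch of the `task_status` plumbing added below; when
# a task group starts `run_process`, the optional `task_status_handler`
# decides what value `tg.start()` resolves to, defaulting to the child PID:

import anyio
from prefect.utilities.processutils import run_process

async def main() -> None:
    async with anyio.create_task_group() as tg:
        pid = await tg.start(run_process, ["python", "-c", "print('hi')"])
        print("subprocess pid:", pid)

anyio.run(main)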
@@ -236,13 +246,58 @@ async def open_process(command: List[str], **kwargs):
         await process.aclose()


+@overload
+async def run_process(
+    command: list[str],
+    *,
+    stream_output: Union[
+        bool, tuple[Optional[TextSink[str]], Optional[TextSink[str]]]
+    ] = ...,
+    task_status: anyio.abc.TaskStatus[T] = ...,
+    task_status_handler: Callable[[anyio.abc.Process], T] = ...,
+    **kwargs: Any,
+) -> anyio.abc.Process:
+    ...
+
+
+@overload
+async def run_process(
+    command: list[str],
+    *,
+    stream_output: Union[
+        bool, tuple[Optional[TextSink[str]], Optional[TextSink[str]]]
+    ] = ...,
+    task_status: Optional[anyio.abc.TaskStatus[int]] = ...,
+    task_status_handler: None = None,
+    **kwargs: Any,
+) -> anyio.abc.Process:
+    ...
+
+
+@overload
+async def run_process(
+    command: list[str],
+    *,
+    stream_output: Union[
+        bool, tuple[Optional[TextSink[str]], Optional[TextSink[str]]]
+    ] = False,
+    task_status: Optional[anyio.abc.TaskStatus[T]] = None,
+    task_status_handler: Optional[Callable[[anyio.abc.Process], T]] = None,
+    **kwargs: Any,
+) -> anyio.abc.Process:
+    ...
+
+
 async def run_process(
-    command: List[str],
-    stream_output: Union[bool, Tuple[Optional[TextSink], Optional[TextSink]]] = False,
-    task_status: Optional[anyio.abc.TaskStatus] = None,
-    task_status_handler: Optional[Callable[[anyio.abc.Process], Any]] = None,
-    **kwargs,
-):
+    command: list[str],
+    *,
+    stream_output: Union[
+        bool, tuple[Optional[TextSink[str]], Optional[TextSink[str]]]
+    ] = False,
+    task_status: Optional[anyio.abc.TaskStatus[T]] = None,
+    task_status_handler: Optional[Callable[[anyio.abc.Process], T]] = None,
+    **kwargs: Any,
+) -> anyio.abc.Process:
     """
     Like `anyio.run_process` but with:

@@ -262,12 +317,10 @@ async def run_process(
         **kwargs,
     ) as process:
         if task_status is not None:
-            if not task_status_handler:
-
-                def task_status_handler(process):
-                    return process.pid
-
-            task_status.started(task_status_handler(process))
+            value: T = cast(T, process.pid)
+            if task_status_handler:
+                value = task_status_handler(process)
+            task_status.started(value)

         if stream_output:
             await consume_process_output(
@@ -280,31 +333,36 @@ def task_status_handler(process):


 async def consume_process_output(
-    process,
-    stdout_sink: Optional[TextSink] = None,
-    stderr_sink: Optional[TextSink] = None,
-):
+    process: anyio.abc.Process,
+    stdout_sink: Optional[TextSink[str]] = None,
+    stderr_sink: Optional[TextSink[str]] = None,
+) -> None:
     async with anyio.create_task_group() as tg:
-        tg.start_soon(
-            stream_text,
-            TextReceiveStream(process.stdout),
-            stdout_sink,
-        )
-        tg.start_soon(
-            stream_text,
-            TextReceiveStream(process.stderr),
-            stderr_sink,
-        )
+        if process.stdout is not None:
+            tg.start_soon(
+                stream_text,
+                TextReceiveStream(process.stdout),
+                stdout_sink,
+            )
+        if process.stderr is not None:
+            tg.start_soon(
+                stream_text,
+                TextReceiveStream(process.stderr),
+                stderr_sink,
+            )


-async def stream_text(source: TextReceiveStream, *sinks: TextSink):
+async def stream_text(
+    source: TextReceiveStream, *sinks: Optional[TextSink[str]]
+) -> None:
     wrapped_sinks = [
         (
-            anyio.wrap_file(sink)
+            anyio.wrap_file(cast(IO[str], sink))
             if hasattr(sink, "write") and hasattr(sink, "flush")
             else sink
         )
         for sink in sinks
+        if sink is not None
     ]

     async for item in source:
         for sink in wrapped_sinks:
@@ -313,30 +371,32 @@ async def stream_text(source: TextReceiveStream, *sinks: TextSink):
             elif isinstance(sink, anyio.AsyncFile):
                 await sink.write(item)
                 await sink.flush()
-            elif sink is None:
-                pass  # Consume the item but perform no action
-            else:
-                raise TypeError(f"Unsupported sink type {type(sink).__name__}")


-def _register_signal(signum: int, handler: Callable):
+def _register_signal(
+    signum: int,
+    handler: Optional[
+        Union[Callable[[int, Optional[FrameType]], Any], int, signal.Handlers]
+    ],
+) -> None:
     if threading.current_thread() is threading.main_thread():
         signal.signal(signum, handler)


 def forward_signal_handler(
-    pid: int, signum: int, *signums: int, process_name: str, print_fn: Callable
-):
+    pid: int, signum: int, *signums: int, process_name: str, print_fn: PrintFn
+) -> None:
     """Forward subsequent signum events (e.g. interrupts) to respective signums."""
     current_signal, future_signals = signums[0], signums[1:]

     # avoid RecursionError when setting up a direct signal forward to the same signal for the main pid
+    original_handler = None
     avoid_infinite_recursion = signum == current_signal and pid == os.getpid()
     if avoid_infinite_recursion:
         # store the vanilla handler so it can be temporarily restored below
         original_handler = signal.getsignal(current_signal)

-    def handler(*args):
+    def handler(*arg: Any) -> None:
         print_fn(
             f"Received {getattr(signum, 'name', signum)}. "
             f"Sending {getattr(current_signal, 'name', current_signal)} to"
@@ -358,7 +418,9 @@ def handler(*args):
     _register_signal(signum, handler)


-def setup_signal_handlers_server(pid: int, process_name: str, print_fn: Callable):
+def setup_signal_handlers_server(
+    pid: int, process_name: str, print_fn: PrintFn
+) -> None:
     """Handle interrupts of the server gracefully."""
     setup_handler = partial(
         forward_signal_handler, pid, process_name=process_name, print_fn=print_fn
@@ -375,7 +437,7 @@ def setup_signal_handlers_server(pid: int, process_name: str, print_fn: Callable):
     setup_handler(signal.SIGTERM, signal.SIGTERM, signal.SIGKILL)


-def setup_signal_handlers_agent(pid: int, process_name: str, print_fn: Callable):
+def setup_signal_handlers_agent(pid: int, process_name: str, print_fn: PrintFn) -> None:
     """Handle interrupts of the agent gracefully."""
     setup_handler = partial(
         forward_signal_handler, pid, process_name=process_name, print_fn=print_fn
@@ -393,7 +455,9 @@ def setup_signal_handlers_agent(pid: int, process_name: str, print_fn: Callable):
     setup_handler(signal.SIGTERM, signal.SIGINT, signal.SIGKILL)


-def setup_signal_handlers_worker(pid: int, process_name: str, print_fn: Callable):
+def setup_signal_handlers_worker(
+    pid: int, process_name: str, print_fn: PrintFn
+) -> None:
     """Handle interrupts of workers gracefully."""
     setup_handler = partial(
         forward_signal_handler, pid, process_name=process_name, print_fn=print_fn
diff --git a/src/prefect/utilities/pydantic.py b/src/prefect/utilities/pydantic.py
index 8f09afaaa4e1..6086931fbbbd 100644
--- a/src/prefect/utilities/pydantic.py
+++ b/src/prefect/utilities/pydantic.py
@@ -1,18 +1,18 @@
-from functools import partial
 from typing import (
     Any,
     Callable,
-    Dict,
     Generic,
     Optional,
-    Type,
     TypeVar,
+    Union,
     cast,
     get_origin,
     overload,
 )

-from jsonpatch import JsonPatch as JsonPatchBase
+from jsonpatch import (  # type: ignore # no typing stubs available, see https://github.com/stefankoegl/python-json-patch/issues/158
+    JsonPatch as JsonPatchBase,
+)
 from pydantic import (
     BaseModel,
     GetJsonSchemaHandler,
@@ -33,7 +33,7 @@
 T = TypeVar("T", bound=Any)


-def _reduce_model(model: BaseModel):
+def _reduce_model(self: BaseModel) -> tuple[Any, ...]:
     """
     Helper for serializing a cythonized model with cloudpickle.

@@ -43,31 +43,33 @@ def _reduce_model(model: BaseModel):
     return (
         _unreduce_model,
         (
-            to_qualified_name(type(model)),
-            model.model_dump_json(**getattr(model, "__reduce_kwargs__", {})),
+            to_qualified_name(type(self)),
+            self.model_dump_json(**getattr(self, "__reduce_kwargs__", {})),
         ),
     )


-def _unreduce_model(model_name, json):
+def _unreduce_model(model_name: str, json: str) -> Any:
     """Helper for restoring model after serialization"""
     model = from_qualified_name(model_name)
     return model.model_validate_json(json)


 @overload
-def add_cloudpickle_reduction(__model_cls: Type[M]) -> Type[M]:
+def add_cloudpickle_reduction(__model_cls: type[M]) -> type[M]:
     ...


 @overload
 def add_cloudpickle_reduction(
-    **kwargs: Any,
-) -> Callable[[Type[M]], Type[M]]:
+    __model_cls: None = None, **kwargs: Any
+) -> Callable[[type[M]], type[M]]:
     ...


-def add_cloudpickle_reduction(__model_cls: Optional[Type[M]] = None, **kwargs: Any):
+def add_cloudpickle_reduction(
+    __model_cls: Optional[type[M]] = None, **kwargs: Any
+) -> Union[type[M], Callable[[type[M]], type[M]]]:
     """
     Adds a `__reducer__` to the given class that ensures it is cloudpickle compatible.

@@ -85,25 +87,22 @@ def add_cloudpickle_reduction(__model_cls: Optional[Type[M]] = None, **kwargs: A
     """
     if __model_cls:
         __model_cls.__reduce__ = _reduce_model
-        __model_cls.__reduce_kwargs__ = kwargs
+        setattr(__model_cls, "__reduce_kwargs__", kwargs)
         return __model_cls
-    else:
-        return cast(
-            Callable[[Type[M]], Type[M]],
-            partial(
-                add_cloudpickle_reduction,
-                **kwargs,
-            ),
-        )
+
+    def reducer_with_kwargs(__model_cls: type[M]) -> type[M]:
+        return add_cloudpickle_reduction(__model_cls, **kwargs)
+
+    return reducer_with_kwargs
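# Aside: the `__reduce__` round trip the decorator above arranges, sketched on
# a plain class: pickle records (callable, args) and calls it to rebuild:

import pickle

class Point:
    def __init__(self, x: int, y: int) -> None:
        self.x, self.y = x, y

    def __reduce__(self) -> "tuple[Any, ...]":
        return (Point, (self.x, self.y))

from typing import Any

restored = pickle.loads(pickle.dumps(Point(1, 2)))
print(restored.x, restored.y)  # 1 2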
-def get_class_fields_only(model: Type[BaseModel]) -> set:
+def get_class_fields_only(model: type[BaseModel]) -> set[str]:
     """
     Gets all the field names defined on the model class but not any parent classes.
     Any fields that are on the parent but redefined on the subclass are included.
     """
     subclass_class_fields = set(model.__annotations__.keys())
-    parent_class_fields = set()
+    parent_class_fields: set[str] = set()

     for base in model.__class__.__bases__:
         if issubclass(base, BaseModel):
@@ -114,7 +113,7 @@
     )


-def add_type_dispatch(model_cls: Type[M]) -> Type[M]:
+def add_type_dispatch(model_cls: type[M]) -> type[M]:
     """
     Extend a Pydantic model to add a 'type' field that is used as a discriminator field
     to dynamically determine the subtype that when deserializing models.
@@ -149,7 +148,7 @@
     elif not defines_dispatch_key and defines_type_field:
         field_type_annotation = model_cls.model_fields["type"].annotation
-        if field_type_annotation != str:
+        if field_type_annotation != str and field_type_annotation is not None:
             raise TypeError(
                 f"Model class {model_cls.__name__!r} defines a 'type' field with "
                 f"type {field_type_annotation.__name__!r} but it must be 'str'."
@@ -157,10 +156,10 @@ def add_type_dispatch(model_cls: Type[M]) -> Type[M]:

         # Set the dispatch key to retrieve the value from the type field
         @classmethod
-        def dispatch_key_from_type_field(cls):
+        def dispatch_key_from_type_field(cls: type[M]) -> str:
             return cls.model_fields["type"].default

-        model_cls.__dispatch_key__ = dispatch_key_from_type_field
+        setattr(model_cls, "__dispatch_key__", dispatch_key_from_type_field)

     else:
         raise ValueError(
@@ -171,7 +170,7 @@ def dispatch_key_from_type_field(cls):
     cls_init = model_cls.__init__
     cls_new = model_cls.__new__

-    def __init__(__pydantic_self__, **data: Any) -> None:
+    def __init__(__pydantic_self__: M, **data: Any) -> None:
         type_string = (
             get_dispatch_key(__pydantic_self__)
             if type(__pydantic_self__) != model_cls
@@ -180,12 +179,16 @@ def __init__(__pydantic_self__, **data: Any) -> None:
         data.setdefault("type", type_string)
         cls_init(__pydantic_self__, **data)

-    def __new__(cls: Type[M], **kwargs: Any) -> M:
+    def __new__(cls: type[M], **kwargs: Any) -> M:
         if "type" in kwargs:
             try:
                 subcls = lookup_type(cls, dispatch_key=kwargs["type"])
             except KeyError as exc:
-                raise ValidationError(errors=[exc], model=cls)
+                raise ValidationError.from_exception_data(
+                    title=cls.__name__,
+                    line_errors=[{"type": str(exc), "input": kwargs["type"]}],
+                    input_type="python",
+                )
             return cls_new(subcls)
         else:
             return cls_new(cls)
@@ -221,7 +224,7 @@ class PartialModel(Generic[M]):
         >>> model = partial_model.finalize(z=3.0)
     """

-    def __init__(self, __model_cls: Type[M], **kwargs: Any) -> None:
+    def __init__(self, __model_cls: type[M], **kwargs: Any) -> None:
         self.fields = kwargs
         # Set fields first to avoid issues if `fields` is also set on the `model_cls`
         # in our custom `setattr` implementation.
@@ -236,11 +239,11 @@ def finalize(self, **kwargs: Any) -> M:
             self.raise_if_not_in_model(name)
         return self.model_cls(**self.fields, **kwargs)

-    def raise_if_already_set(self, name):
+    def raise_if_already_set(self, name: str) -> None:
         if name in self.fields:
             raise ValueError(f"Field {name!r} has already been set.")

-    def raise_if_not_in_model(self, name):
+    def raise_if_not_in_model(self, name: str) -> None:
         if name not in self.model_cls.model_fields:
             raise ValueError(f"Field {name!r} is not present in the model.")
@@ -290,7 +293,7 @@ def __get_pydantic_json_schema__(


 def custom_pydantic_encoder(
-    type_encoders: Optional[Dict[Any, Callable[[Type[Any]], Any]]], obj: Any
+    type_encoders: dict[Any, Callable[[type[Any]], Any]], obj: Any
 ) -> Any:
     # Check the class type and its superclasses for a matching encoder
     for base in obj.__class__.__mro__[:-1]:
@@ -359,8 +362,10 @@ class ExampleModel(BaseModel):
     """
     adapter = TypeAdapter(type_)

-    if get_origin(type_) is list and isinstance(data, dict):
-        data = next(iter(data.values()))
+    origin: Optional[Any] = get_origin(type_)
+    if origin is list and isinstance(data, dict):
+        values_dict: dict[Any, Any] = data
+        data = next(iter(values_dict.values()))

     parser: Callable[[Any], T] = getattr(adapter, f"validate_{mode}")
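# Aside: a standalone sketch of the 'type'-field dispatch idea (hypothetical
# registry, not prefect's `lookup_type`): the discriminator picks the subclass
# to instantiate:

from typing import Any

_registry: "dict[str, type]" = {}

def register(cls: type) -> type:
    _registry[cls.__name__.lower()] = cls
    return cls

@register
class Block:
    pass

@register
class SecretBlock(Block):
    pass

def from_payload(payload: "dict[str, Any]") -> Block:
    return _registry[payload["type"]]()  # dispatch on the 'type' key

print(type(from_payload({"type": "secretblock"})).__name__)  # SecretBlock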
diff --git a/src/prefect/utilities/render_swagger.py b/src/prefect/utilities/render_swagger.py
index ac02ee985fc6..82008ed3192f 100644
--- a/src/prefect/utilities/render_swagger.py
+++ b/src/prefect/utilities/render_swagger.py
@@ -8,10 +8,13 @@
 import string
 import urllib.parse
 from pathlib import Path
+from typing import Any, Optional, cast
 from xml.sax.saxutils import escape

 import mkdocs.plugins
-from mkdocs.structure.files import File
+from mkdocs.config.defaults import MkDocsConfig
+from mkdocs.structure.files import File, Files
+from mkdocs.structure.pages import Page

 USAGE_MSG = (
     "Usage: '!!swagger <filename>!!' or '!!swagger-http <url>!!'. "
@@ -50,7 +53,7 @@
 TOKEN_HTTP = re.compile(r"!!swagger-http(?: (?P<url>https?://[^\s]+))?!!")


-def swagger_lib(config) -> dict:
+def swagger_lib(config: MkDocsConfig) -> dict[str, Any]:
     """
     Provides the actual swagger library used
     """
@@ -59,11 +62,14 @@
         "js": "https://unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js",
     }

-    extra_javascript = config.get("extra_javascript", [])
-    extra_css = config.get("extra_css", [])
+    extra_javascript = config.extra_javascript
+    extra_css = cast(list[str], config.extra_css)
     for lib in extra_javascript:
-        if os.path.basename(urllib.parse.urlparse(lib).path) == "swagger-ui-bundle.js":
-            lib_swagger["js"] = lib
+        if (
+            os.path.basename(urllib.parse.urlparse(str(lib)).path)
+            == "swagger-ui-bundle.js"
+        ):
+            lib_swagger["js"] = str(lib)
             break

     for css in extra_css:
@@ -73,8 +79,10 @@ def swagger_lib(config) -> dict:
     return lib_swagger


-class SwaggerPlugin(mkdocs.plugins.BasePlugin):
-    def on_page_markdown(self, markdown, page, config, files):
+class SwaggerPlugin(mkdocs.plugins.BasePlugin[Any]):
+    def on_page_markdown(
+        self, markdown: str, /, *, page: Page, config: MkDocsConfig, files: Files
+    ) -> Optional[str]:
         is_http = False
         match = TOKEN.search(markdown)

@@ -88,7 +96,7 @@ def on_page_markdown(self, markdown, page, config, files):
         pre_token = markdown[: match.start()]
         post_token = markdown[match.end() :]

-        def _error(message):
+        def _error(message: str) -> str:
             return (
                 pre_token
                 + escape(ERROR_TEMPLATE.substitute(error=message))
@@ -103,8 +111,10 @@ def _error(message):
         if is_http:
             url = path
         else:
+            base = page.file.abs_src_path
+            assert base is not None
             try:
-                api_file = Path(page.file.abs_src_path).with_name(path)
+                api_file = Path(base).with_name(path)
             except ValueError as exc:
                 return _error(f"Invalid path. {exc.args[0]}")

@@ -114,7 +124,7 @@ def _error(message):
             src_dir = api_file.parent
             dest_dir = Path(page.file.abs_dest_path).parent

-            new_file = File(api_file.name, src_dir, dest_dir, False)
+            new_file = File(api_file.name, str(src_dir), str(dest_dir), False)
             files.append(new_file)
             url = Path(new_file.abs_dest_path).name

@@ -129,4 +139,4 @@ def _error(message):
         )

         # If multiple swaggers exist.
-        return self.on_page_markdown(markdown, page, config, files)
+        return self.on_page_markdown(markdown, page=page, config=config, files=files)
diff --git a/src/prefect/utilities/schema_tools/__init__.py b/src/prefect/utilities/schema_tools/__init__.py
index 1e6e73fc372a..bfd382af2b5f 100644
--- a/src/prefect/utilities/schema_tools/__init__.py
+++ b/src/prefect/utilities/schema_tools/__init__.py
@@ -2,8 +2,8 @@
 from .validation import (
     CircularSchemaRefError,
     ValidationError,
-    validate,
     is_valid_schema,
+    validate,
 )

 __all__ = [
@@ -12,5 +12,6 @@
     "HydrationError",
     "ValidationError",
     "hydrate",
+    "is_valid_schema",
     "validate",
 ]
diff --git a/src/prefect/utilities/schema_tools/hydration.py b/src/prefect/utilities/schema_tools/hydration.py
index 49bd1bc8b33d..3cbb8e97804d 100644
--- a/src/prefect/utilities/schema_tools/hydration.py
+++ b/src/prefect/utilities/schema_tools/hydration.py
@@ -1,10 +1,12 @@
 import json
-from typing import Any, Callable, Dict, Optional
+from abc import ABC, abstractmethod
+from collections.abc import Callable, Sequence
+from typing import Any, Optional, cast

 import jinja2
 from pydantic import BaseModel, Field
 from sqlalchemy.ext.asyncio import AsyncSession
-from typing_extensions import TypeAlias
+from typing_extensions import Self, TypeAlias, TypeIs

 from prefect.server.utilities.user_templates import (
     TemplateSecurityError,
@@ -15,14 +17,14 @@


 class HydrationContext(BaseModel):
-    workspace_variables: Dict[
+    workspace_variables: dict[
         str,
         StrictVariableValue,
     ] = Field(default_factory=dict)
     render_workspace_variables: bool = Field(default=False)
     raise_on_error: bool = Field(default=False)
     render_jinja: bool = Field(default=False)
-    jinja_context: Dict[str, Any] = Field(default_factory=dict)
+    jinja_context: dict[str, Any] = Field(default_factory=dict)

     @classmethod
     async def build(
@@ -31,9 +33,11 @@
         raise_on_error: bool = False,
         render_jinja: bool = False,
         render_workspace_variables: bool = False,
-    ) -> "HydrationContext":
+    ) -> Self:
+        from prefect.server.database.orm_models import Variable
         from prefect.server.models.variables import read_variables

+        variables: Sequence[Variable]
         if render_workspace_variables:
             variables = await read_variables(
                 session=session,
@@ -51,14 +55,14 @@
     )


-Handler: TypeAlias = Callable[[dict, HydrationContext], Any]
+Handler: TypeAlias = Callable[[dict[str, Any], HydrationContext], Any]
 PrefectKind: TypeAlias = Optional[str]

-_handlers: Dict[PrefectKind, Handler] = {}
+_handlers: dict[PrefectKind, Handler] = {}


 class Placeholder:
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return isinstance(other, type(self))

     @property
@@ -70,11 +74,11 @@ class RemoveValue(Placeholder):
     pass


-def _remove_value(value) -> bool:
+def _remove_value(value: Any) -> TypeIs[RemoveValue]:
     return isinstance(value, RemoveValue)


-class HydrationError(Placeholder, Exception):
+class HydrationError(Placeholder, Exception, ABC):
     def __init__(self, detail: Optional[str] = None):
         self.detail = detail

@@ -83,47 +87,49 @@ def is_error(self) -> bool:
         return True

     @property
-    def message(self):
+    @abstractmethod
+    def message(self) -> str:
         raise NotImplementedError("Must be implemented by subclass")

-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return isinstance(other, type(self)) and self.message == other.message

-    def __str__(self):
+    def __str__(self) -> str:
         return self.message


 class KeyNotFound(HydrationError):
     @property
-    def message(self):
+    def message(self) -> str:
         return f"Missing '{self.key}' key in __prefect object"

     @property
+    @abstractmethod
     def key(self) -> str:
         raise NotImplementedError("Must be implemented by subclass")


 class ValueNotFound(KeyNotFound):
     @property
-    def key(self):
+    def key(self) -> str:
         return "value"


 class TemplateNotFound(KeyNotFound):
     @property
-    def key(self):
+    def key(self) -> str:
         return "template"


 class VariableNameNotFound(KeyNotFound):
     @property
-    def key(self):
+    def key(self) -> str:
         return "variable_name"


 class InvalidJSON(HydrationError):
     @property
-    def message(self):
+    def message(self) -> str:
         message = "Invalid JSON"
         if self.detail:
             message += f": {self.detail}"
@@ -132,7 +138,7 @@ def message(self):

 class InvalidJinja(HydrationError):
     @property
-    def message(self):
+    def message(self) -> str:
         message = "Invalid jinja"
         if self.detail:
             message += f": {self.detail}"
@@ -146,29 +152,29 @@
         return self.detail

     @property
-    def message(self):
+    def message(self) -> str:
         return f"Variable '{self.detail}' not found in workspace."


 class WorkspaceVariable(Placeholder):
-    def __init__(self, variable_name: str):
+    def __init__(self, variable_name: str) -> None:
         self.variable_name = variable_name

-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return (
             isinstance(other, type(self)) and self.variable_name == other.variable_name
         )


 class ValidJinja(Placeholder):
-    def __init__(self, template: str):
+    def __init__(self, template: str) -> None:
         self.template = template

-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         return isinstance(other, type(self)) and self.template == other.template


-def handler(kind: PrefectKind) -> Callable:
+def handler(kind: PrefectKind) -> Callable[[Handler], Handler]:
     def decorator(func: Handler) -> Handler:
         _handlers[kind] = func
         return func
@@ -176,9 +182,9 @@ def decorator(func: Handler) -> Handler:
     return decorator


-def call_handler(kind: PrefectKind, obj: dict, ctx: HydrationContext) -> Any:
+def call_handler(kind: PrefectKind, obj: dict[str, Any], ctx: HydrationContext) -> Any:
     if kind not in _handlers:
-        return (obj or {}).get("value", None)
+        return obj.get("value", None)

     res = _handlers[kind](obj, ctx)
     if ctx.raise_on_error and isinstance(res, HydrationError):
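# Aside: a self-contained sketch of the handler-registry pattern above
# (hypothetical names; the real handlers also receive a HydrationContext):

import json
from typing import Any, Callable

Handler = Callable[["dict[str, Any]"], Any]
registry: "dict[str, Handler]" = {}

def register_kind(kind: str) -> Callable[[Handler], Handler]:
    def decorator(fn: Handler) -> Handler:
        registry[kind] = fn
        return fn
    return decorator

@register_kind("json")
def decode_json(obj: "dict[str, Any]") -> Any:
    return json.loads(obj["value"])

def dispatch(obj: "dict[str, Any]") -> Any:
    kind = obj.get("__prefect_kind")
    if isinstance(kind, str) and kind in registry:
        return registry[kind](obj)
    return obj.get("value")

print(dispatch({"__prefect_kind": "json", "value": '{"a": 1}'}))  # {'a': 1}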
dict): dehydrated_variable = _hydrate(obj["variable_name"], ctx) @@ -259,7 +265,7 @@ def workspace_variable_handler(obj: dict, ctx: HydrationContext): return dehydrated_variable if not ctx.render_workspace_variables: - return WorkspaceVariable(variable_name=obj["variable_name"]) + return WorkspaceVariable(variable_name=dehydrated_variable) if dehydrated_variable in ctx.workspace_variables: return ctx.workspace_variables[dehydrated_variable] @@ -277,35 +283,36 @@ def workspace_variable_handler(obj: dict, ctx: HydrationContext): return RemoveValue() -def hydrate(obj: dict, ctx: Optional[HydrationContext] = None): - res = _hydrate(obj, ctx) +def hydrate( + obj: dict[str, Any], ctx: Optional[HydrationContext] = None +) -> dict[str, Any]: + res: dict[str, Any] = _hydrate(obj, ctx) if _remove_value(res): - return {} + res = {} return res -def _hydrate(obj, ctx: Optional[HydrationContext] = None) -> Any: +def _hydrate(obj: Any, ctx: Optional[HydrationContext] = None) -> Any: if ctx is None: ctx = HydrationContext() - prefect_object = isinstance(obj, dict) and "__prefect_kind" in obj - - if prefect_object: - prefect_kind = obj.get("__prefect_kind") - return call_handler(prefect_kind, obj, ctx) + if isinstance(obj, dict) and "__prefect_kind" in obj: + obj_dict: dict[str, Any] = obj + prefect_kind = obj_dict["__prefect_kind"] + return call_handler(prefect_kind, obj_dict, ctx) else: if isinstance(obj, dict): return { key: hydrated_value - for key, value in obj.items() + for key, value in cast(dict[str, Any], obj).items() if not _remove_value(hydrated_value := _hydrate(value, ctx)) } elif isinstance(obj, list): return [ hydrated_element - for element in obj + for element in cast(list[Any], obj) if not _remove_value(hydrated_element := _hydrate(element, ctx)) ] else: diff --git a/src/prefect/utilities/schema_tools/validation.py b/src/prefect/utilities/schema_tools/validation.py index e94204331125..1dfdd7a3607b 100644 --- a/src/prefect/utilities/schema_tools/validation.py +++ b/src/prefect/utilities/schema_tools/validation.py @@ -1,14 +1,19 @@ from collections import defaultdict, deque +from collections.abc import Callable, Iterable, Iterator from copy import deepcopy -from typing import Any, Dict, List +from typing import TYPE_CHECKING, Any, cast import jsonschema from jsonschema.exceptions import ValidationError as JSONSchemaValidationError from jsonschema.validators import Draft202012Validator, create +from referencing.jsonschema import ObjectSchema, Schema from prefect.utilities.collections import remove_nested_keys from prefect.utilities.schema_tools.hydration import HydrationError, Placeholder +if TYPE_CHECKING: + from jsonschema.validators import _Validator # type: ignore + class CircularSchemaRefError(Exception): pass @@ -21,12 +26,16 @@ class ValidationError(Exception): PLACEHOLDERS_VALIDATOR_NAME = "_placeholders" -def _build_validator(): - def _applicable_validators(schema): +def _build_validator() -> type["_Validator"]: + def _applicable_validators(schema: Schema) -> Iterable[tuple[str, Any]]: # the default implementation returns `schema.items()` - return {**schema, PLACEHOLDERS_VALIDATOR_NAME: None}.items() + assert not isinstance(schema, bool) + schema = {**schema, PLACEHOLDERS_VALIDATOR_NAME: None} + return schema.items() - def _placeholders(validator, _, instance, schema): + def _placeholders( + _validator: "_Validator", _property: object, instance: Any, _schema: Schema + ) -> Iterator[JSONSchemaValidationError]: if isinstance(instance, HydrationError): yield 
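# Illustrative sketch (not part of the diff): how the handlers above are
# dispatched when a "__prefect_kind" object is hydrated. The input literal is
# hypothetical; only `hydrate` and `HydrationContext` from
# src/prefect/utilities/schema_tools/hydration.py are assumed.
from prefect.utilities.schema_tools.hydration import HydrationContext, hydrate

ctx = HydrationContext(raise_on_error=False)
# "__prefect_kind": "json" routes to json_handler, which parses the string value;
# a malformed value comes back as an InvalidJSON placeholder instead of raising.
hydrated = hydrate({"size": {"__prefect_kind": "json", "value": "[1, 2, 3]"}}, ctx)
assert hydrated == {"size": [1, 2, 3]}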
diff --git a/src/prefect/utilities/schema_tools/validation.py b/src/prefect/utilities/schema_tools/validation.py
index e94204331125..1dfdd7a3607b 100644
--- a/src/prefect/utilities/schema_tools/validation.py
+++ b/src/prefect/utilities/schema_tools/validation.py
@@ -1,14 +1,19 @@
 from collections import defaultdict, deque
+from collections.abc import Callable, Iterable, Iterator
 from copy import deepcopy
-from typing import Any, Dict, List
+from typing import TYPE_CHECKING, Any, cast

 import jsonschema
 from jsonschema.exceptions import ValidationError as JSONSchemaValidationError
 from jsonschema.validators import Draft202012Validator, create
+from referencing.jsonschema import ObjectSchema, Schema

 from prefect.utilities.collections import remove_nested_keys
 from prefect.utilities.schema_tools.hydration import HydrationError, Placeholder

+if TYPE_CHECKING:
+    from jsonschema.validators import _Validator  # type: ignore
+

 class CircularSchemaRefError(Exception):
     pass
@@ -21,12 +26,16 @@ class ValidationError(Exception):
 PLACEHOLDERS_VALIDATOR_NAME = "_placeholders"


-def _build_validator():
-    def _applicable_validators(schema):
+def _build_validator() -> type["_Validator"]:
+    def _applicable_validators(schema: Schema) -> Iterable[tuple[str, Any]]:
         # the default implementation returns `schema.items()`
-        return {**schema, PLACEHOLDERS_VALIDATOR_NAME: None}.items()
+        assert not isinstance(schema, bool)
+        schema = {**schema, PLACEHOLDERS_VALIDATOR_NAME: None}
+        return schema.items()

-    def _placeholders(validator, _, instance, schema):
+    def _placeholders(
+        _validator: "_Validator", _property: object, instance: Any, _schema: Schema
+    ) -> Iterator[JSONSchemaValidationError]:
         if isinstance(instance, HydrationError):
             yield JSONSchemaValidationError(instance.message)
@@ -43,7 +52,9 @@ def _placeholders(validator, _, instance, schema):
         version="prefect",
         type_checker=Draft202012Validator.TYPE_CHECKER,
         format_checker=Draft202012Validator.FORMAT_CHECKER,
-        id_of=Draft202012Validator.ID_OF,
+        id_of=cast(  # the stub for create() is wrong here; id_of accepts (Schema) -> str | None
+            Callable[[Schema], str], Draft202012Validator.ID_OF
+        ),
         applicable_validators=_applicable_validators,
     )

@@ -51,24 +62,23 @@ def _placeholders(validator, _, instance, schema):
 _VALIDATOR = _build_validator()


-def is_valid_schema(schema: Dict, preprocess: bool = True):
+def is_valid_schema(schema: ObjectSchema, preprocess: bool = True) -> None:
     if preprocess:
         schema = preprocess_schema(schema)
     try:
-        if schema is not None:
-            _VALIDATOR.check_schema(schema, format_checker=_VALIDATOR.FORMAT_CHECKER)
+        _VALIDATOR.check_schema(schema, format_checker=_VALIDATOR.FORMAT_CHECKER)
     except jsonschema.SchemaError as exc:
         raise ValueError(f"Invalid schema: {exc.message}") from exc


 def validate(
-    obj: Dict,
-    schema: Dict,
+    obj: dict[str, Any],
+    schema: ObjectSchema,
     raise_on_error: bool = False,
     preprocess: bool = True,
     ignore_required: bool = False,
     allow_none_with_default: bool = False,
-) -> List[JSONSchemaValidationError]:
+) -> list[JSONSchemaValidationError]:
     if preprocess:
         schema = preprocess_schema(schema, allow_none_with_default)

@@ -93,32 +103,31 @@ def validate(
     else:
         try:
             validator = _VALIDATOR(schema, format_checker=_VALIDATOR.FORMAT_CHECKER)
-            errors = list(validator.iter_errors(obj))
+            errors = list(validator.iter_errors(obj))  # type: ignore
         except RecursionError:
             raise CircularSchemaRefError
     return errors


-def is_valid(
-    obj: Dict,
-    schema: Dict,
-) -> bool:
+def is_valid(obj: dict[str, Any], schema: ObjectSchema) -> bool:
     errors = validate(obj, schema)
-    return len(errors) == 0
+    return not errors


-def prioritize_placeholder_errors(errors):
-    errors_by_path = defaultdict(list)
+def prioritize_placeholder_errors(
+    errors: list[JSONSchemaValidationError],
+) -> list[JSONSchemaValidationError]:
+    errors_by_path: dict[str, list[JSONSchemaValidationError]] = defaultdict(list)
     for error in errors:
         path_str = "->".join(str(p) for p in error.relative_path)
         errors_by_path[path_str].append(error)

-    filtered_errors = []
-    for path, grouped_errors in errors_by_path.items():
+    filtered_errors: list[JSONSchemaValidationError] = []
+    for grouped_errors in errors_by_path.values():
         placeholders_errors = [
             error
             for error in grouped_errors
-            if error.validator == PLACEHOLDERS_VALIDATOR_NAME
+            if error.validator == PLACEHOLDERS_VALIDATOR_NAME  # type: ignore # typing stubs are incomplete
         ]

         if placeholders_errors:
@@ -129,8 +138,8 @@ def prioritize_placeholder_errors(errors):
     return filtered_errors


-def build_error_obj(errors: List[JSONSchemaValidationError]) -> Dict:
-    error_response: Dict[str, Any] = {"errors": []}
+def build_error_obj(errors: list[JSONSchemaValidationError]) -> dict[str, Any]:
+    error_response: dict[str, Any] = {"errors": []}

     # If multiple errors are present for the same path and one of them
     # is a placeholder error, we only want to use the placeholder error.
@@ -145,11 +154,11 @@ def build_error_obj(errors: List[JSONSchemaValidationError]) -> Dict:
         # Required errors should be moved one level down to the property
         # they're associated with, so we add an extra level to the path.
-        if error.validator == "required":
-            required_field = error.message.split(" ")[0].strip("'")
+        if error.validator == "required":  # type: ignore
+            required_field = error.message.partition(" ")[0].strip("'")
             path.append(required_field)

-        current = error_response["errors"]
+        current: list[Any] = error_response["errors"]

         # error at the root, just append the error message
         if not path:
@@ -163,10 +172,10 @@ def build_error_obj(errors: List[JSONSchemaValidationError]) -> Dict:
             else:
                 for entry in current:
                     if entry.get("index") == part:
-                        current = entry["errors"]
+                        current = cast(list[Any], entry["errors"])
                         break
                 else:
-                    new_entry = {"index": part, "errors": []}
+                    new_entry: dict[str, Any] = {"index": part, "errors": []}
                     current.append(new_entry)
                     current = new_entry["errors"]
         else:
@@ -182,7 +191,7 @@ def build_error_obj(errors: List[JSONSchemaValidationError]) -> Dict:
                     current.append(new_entry)
                     current = new_entry["errors"]

-    valid = len(error_response["errors"]) == 0
+    valid = not bool(error_response["errors"])
     error_response["valid"] = valid

     return error_response
@@ -190,10 +199,10 @@ def build_error_obj(errors: List[JSONSchemaValidationError]) -> Dict:

 def _fix_null_typing(
     key: str,
-    schema: Dict,
-    required_fields: List[str],
+    schema: dict[str, Any],
+    required_fields: list[str],
     allow_none_with_default: bool = False,
-):
+) -> None:
     """
     Pydantic V1 does not generate a valid Draft2020-12 schema for null types.
     """
@@ -207,7 +216,7 @@ def _fix_null_typing(
         del schema["type"]


-def _fix_tuple_items(schema: Dict):
+def _fix_tuple_items(schema: dict[str, Any]) -> None:
     """
     Pydantic V1 does not generate a valid Draft2020-12 schema for tuples.
     """
@@ -216,15 +225,15 @@ def _fix_tuple_items(schema: Dict):
         and isinstance(schema["items"], list)
         and not schema.get("prefixItems")
     ):
-        schema["prefixItems"] = deepcopy(schema["items"])
+        schema["prefixItems"] = deepcopy(cast(list[Any], schema["items"]))
         del schema["items"]


 def process_properties(
-    properties: Dict,
-    required_fields: List[str],
+    properties: dict[str, dict[str, Any]],
+    required_fields: list[str],
     allow_none_with_default: bool = False,
-):
+) -> None:
     for key, schema in properties.items():
         _fix_null_typing(key, schema, required_fields, allow_none_with_default)
         _fix_tuple_items(schema)
@@ -235,9 +244,9 @@ def process_properties(

 def preprocess_schema(
-    schema: Dict,
+    schema: ObjectSchema,
     allow_none_with_default: bool = False,
-):
+) -> ObjectSchema:
     schema = deepcopy(schema)

     if "properties" in schema:
@@ -247,7 +256,8 @@ def preprocess_schema(
         )

     if "definitions" in schema:
         # Also process definitions for reused models
-        for definition in (schema["definitions"] or {}).values():
+        definitions = cast(dict[str, Any], schema["definitions"])
+        for definition in definitions.values():
             if "properties" in definition:
                 required_fields = definition.get("required", [])
                 process_properties(
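# Illustrative sketch (not part of the diff): the validate()/build_error_obj()
# flow above on a minimal, hypothetical schema. The exact shape of the error
# object is an assumption inferred from the code paths shown here.
from prefect.utilities.schema_tools.validation import build_error_obj, validate

schema = {
    "type": "object",
    "properties": {"x": {"type": "integer"}},
    "required": ["x"],
}
errors = validate({}, schema)  # one "required" error for the missing "x"
report = build_error_obj(errors)
# Roughly: {"errors": [{"property": "x", "errors": ["'x' is a required property"]}],
#           "valid": False} -- the required error is nested under its property.
assert report["valid"] is False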
diff --git a/src/prefect/utilities/services.py b/src/prefect/utilities/services.py
index 60dd0cf61c8e..de183c73c555 100644
--- a/src/prefect/utilities/services.py
+++ b/src/prefect/utilities/services.py
@@ -1,9 +1,10 @@
-import sys
 import threading
 from collections import deque
+from collections.abc import Coroutine
+from logging import Logger
 from traceback import format_exception
 from types import TracebackType
-from typing import Callable, Coroutine, Deque, Optional, Tuple
+from typing import Any, Callable, Optional
 from wsgiref.simple_server import WSGIServer

 import anyio
@@ -14,11 +15,11 @@
 from prefect.utilities.collections import distinct
 from prefect.utilities.math import clamped_poisson_interval

-logger = get_logger("utilities.services.critical_service_loop")
+logger: Logger = get_logger("utilities.services.critical_service_loop")


 async def critical_service_loop(
-    workload: Callable[..., Coroutine],
+    workload: Callable[..., Coroutine[Any, Any, Any]],
     interval: float,
     memory: int = 10,
     consecutive: int = 3,
@@ -26,7 +27,7 @@ async def critical_service_loop(
     printer: Callable[..., None] = print,
     run_once: bool = False,
     jitter_range: Optional[float] = None,
-):
+) -> None:
    """
     Runs the given `workload` function on the specified `interval`, while being
     forgiving of intermittent issues like temporary HTTP errors.  If more than a certain
@@ -50,8 +51,8 @@ async def critical_service_loop(
             between `interval * (1 - range) < rv < interval * (1 + range)`
     """

-    track_record: Deque[bool] = deque([True] * consecutive, maxlen=consecutive)
-    failures: Deque[Tuple[Exception, TracebackType]] = deque(maxlen=memory)
+    track_record: deque[bool] = deque([True] * consecutive, maxlen=consecutive)
+    failures: deque[tuple[Exception, Optional[TracebackType]]] = deque(maxlen=memory)
     backoff_count = 0

     while True:
@@ -78,7 +79,7 @@ async def critical_service_loop(
             # or Prefect Cloud is having an outage (which will be covered by the
             # exception clause below)
             track_record.append(False)
-            failures.append((exc, sys.exc_info()[-1]))
+            failures.append((exc, exc.__traceback__))
             logger.debug(
                 f"Run of {workload!r} failed with TransportError", exc_info=exc
             )
@@ -88,7 +89,7 @@ async def critical_service_loop(
             # likely to be temporary and transient. Don't quit over these unless
             # it is prolonged.
             track_record.append(False)
-            failures.append((exc, sys.exc_info()[-1]))
+            failures.append((exc, exc.__traceback__))
             logger.debug(
                 f"Run of {workload!r} failed with HTTPStatusError", exc_info=exc
             )
@@ -155,10 +156,10 @@ async def critical_service_loop(
         await anyio.sleep(sleep)


-_metrics_server: Optional[Tuple[WSGIServer, threading.Thread]] = None
+_metrics_server: Optional[tuple[WSGIServer, threading.Thread]] = None


-def start_client_metrics_server():
+def start_client_metrics_server() -> None:
     """Start the process-wide Prometheus metrics server for client metrics (if enabled
     with `PREFECT_CLIENT_METRICS_ENABLED`) on the port `PREFECT_CLIENT_METRICS_PORT`."""
     if not PREFECT_CLIENT_METRICS_ENABLED:
@@ -173,7 +174,7 @@ def start_client_metrics_server():
     _metrics_server = start_http_server(port=PREFECT_CLIENT_METRICS_PORT.value())


-def stop_client_metrics_server():
+def stop_client_metrics_server() -> None:
     """Stop the process-wide Prometheus metrics server for client metrics, if it
     has previously been started"""
     global _metrics_server
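# Illustrative sketch (not part of the diff): driving critical_service_loop with
# a workload. `check_api_health` is a hypothetical coroutine; only the loop's
# signature from src/prefect/utilities/services.py is assumed.
from functools import partial

import anyio

from prefect.utilities.services import critical_service_loop


async def check_api_health() -> None:
    """Hypothetical workload, e.g. an httpx call that may fail intermittently."""


# The loop forgives isolated failures and only raises after `consecutive`
# failures in a row, keeping the last `memory` exceptions for its report.
anyio.run(partial(critical_service_loop, check_api_health, interval=10, run_once=True))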
diff --git a/src/prefect/utilities/templating.py b/src/prefect/utilities/templating.py
index 3e46337ba8f7..6e597a309786 100644
--- a/src/prefect/utilities/templating.py
+++ b/src/prefect/utilities/templating.py
@@ -1,17 +1,7 @@
 import enum
 import os
 import re
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    NamedTuple,
-    Optional,
-    Set,
-    Type,
-    TypeVar,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, NamedTuple, Optional, TypeVar, Union, cast

 from prefect.client.utilities import inject_client
 from prefect.utilities.annotations import NotSet
@@ -21,7 +11,7 @@
     from prefect.client.orchestration import PrefectClient


-T = TypeVar("T", str, int, float, bool, dict, list, None)
+T = TypeVar("T", str, int, float, bool, dict[Any, Any], list[Any], None)

 PLACEHOLDER_CAPTURE_REGEX = re.compile(r"({{\s*([\w\.\-\[\]$]+)\s*}})")
 BLOCK_DOCUMENT_PLACEHOLDER_PREFIX = "prefect.blocks."
@@ -62,7 +52,7 @@ def determine_placeholder_type(name: str) -> PlaceholderType:
     return PlaceholderType.STANDARD


-def find_placeholders(template: T) -> Set[Placeholder]:
+def find_placeholders(template: T) -> set[Placeholder]:
     """
     Finds all placeholders in a template.

@@ -72,8 +62,9 @@ def find_placeholders(template: T) -> Set[Placeholder]:
     Returns:
         A set of all placeholders in the template
     """
+    seed: set[Placeholder] = set()
     if isinstance(template, (int, float, bool)):
-        return set()
+        return seed
     if isinstance(template, str):
         result = PLACEHOLDER_CAPTURE_REGEX.findall(template)
         return {
@@ -81,18 +72,16 @@ def find_placeholders(template: T) -> Set[Placeholder]:
             for full_match, name in result
         }
     elif isinstance(template, dict):
-        return set().union(
-            *[find_placeholders(value) for key, value in template.items()]
-        )
+        return seed.union(*[find_placeholders(value) for value in template.values()])
     elif isinstance(template, list):
-        return set().union(*[find_placeholders(item) for item in template])
+        return seed.union(*[find_placeholders(item) for item in template])
     else:
         raise ValueError(f"Unexpected type: {type(template)}")


 def apply_values(
-    template: T, values: Dict[str, Any], remove_notset: bool = True
-) -> Union[T, Type[NotSet]]:
+    template: T, values: dict[str, Any], remove_notset: bool = True
+) -> Union[T, type[NotSet]]:
     """
     Replaces placeholders in a template with values from a supplied dictionary.

@@ -120,7 +109,7 @@ def apply_values(
     Returns:
         The template with the values applied
     """
-    if isinstance(template, (int, float, bool, type(NotSet), type(None))):
+    if template in (NotSet, None) or isinstance(template, (int, float)):
         return template
     if isinstance(template, str):
         placeholders = find_placeholders(template)
@@ -155,7 +144,7 @@ def apply_values(
         return template

     elif isinstance(template, dict):
-        updated_template = {}
+        updated_template: dict[str, Any] = {}
         for key, value in template.items():
             updated_value = apply_values(value, values, remove_notset=remove_notset)
             if updated_value is not NotSet:
@@ -163,22 +152,22 @@ def apply_values(
                 updated_template[key] = updated_value
             elif not remove_notset:
                 updated_template[key] = value

-        return updated_template
+        return cast(T, updated_template)
     elif isinstance(template, list):
-        updated_list = []
+        updated_list: list[Any] = []
         for value in template:
             updated_value = apply_values(value, values, remove_notset=remove_notset)
             if updated_value is not NotSet:
                 updated_list.append(updated_value)
-        return updated_list
+        return cast(T, updated_list)
     else:
         raise ValueError(f"Unexpected template type {type(template).__name__!r}")


 @inject_client
 async def resolve_block_document_references(
-    template: T, client: "PrefectClient" = None
-) -> Union[T, Dict[str, Any]]:
+    template: T, client: Optional["PrefectClient"] = None
+) -> Union[T, dict[str, Any]]:
     """
     Resolve block document references in a template by replacing each reference with
     the data of the block document.
@@ -242,12 +231,17 @@ async def resolve_block_document_references(
     Returns:
         The template with block documents resolved
     """
+    if TYPE_CHECKING:
+        # The @inject_client decorator takes care of providing the client, but
+        # the function signature must mark it as optional to callers.
+        assert client is not None
+
     if isinstance(template, dict):
         block_document_id = template.get("$ref", {}).get("block_document_id")
         if block_document_id:
             block_document = await client.read_block_document(block_document_id)
             return block_document.data
-        updated_template = {}
+        updated_template: dict[str, Any] = {}
         for key, value in template.items():
             updated_value = await resolve_block_document_references(
                 value, client=client
@@ -265,7 +259,7 @@ async def resolve_block_document_references(
             placeholder.type is PlaceholderType.BLOCK_DOCUMENT
             for placeholder in placeholders
         )
-        if len(placeholders) == 0 or not has_block_document_placeholder:
+        if not (placeholders and has_block_document_placeholder):
             return template
         elif (
             len(placeholders) == 1
@@ -274,31 +268,32 @@ async def resolve_block_document_references(
         ):
             # value_keypath will be a list containing a dot path if additional
             # attributes are accessed and an empty list otherwise.
-            block_type_slug, block_document_name, *value_keypath = (
-                list(placeholders)[0]
-                .name.replace(BLOCK_DOCUMENT_PLACEHOLDER_PREFIX, "")
-                .split(".", 2)
-            )
+            [placeholder] = placeholders
+            parts = placeholder.name.replace(
+                BLOCK_DOCUMENT_PLACEHOLDER_PREFIX, ""
+            ).split(".", 2)
+            block_type_slug, block_document_name, *value_keypath = parts
             block_document = await client.read_block_document_by_name(
                 name=block_document_name, block_type_slug=block_type_slug
             )
-            value = block_document.data
+            data = block_document.data
+            value: Union[T, dict[str, Any]] = data

             # resolving system blocks to their data for backwards compatibility
-            if len(value) == 1 and "value" in value:
+            if len(data) == 1 and "value" in data:
                 # only resolve the value if the keypath is not already pointing to "value"
-                if len(value_keypath) == 0 or value_keypath[0][:5] != "value":
-                    value = value["value"]
+                if not (value_keypath and value_keypath[0].startswith("value")):
+                    data = value = value["value"]

             # resolving keypath/block attributes
-            if len(value_keypath) > 0:
-                value_keypath: str = value_keypath[0]
-                value = get_from_dict(value, value_keypath, default=NotSet)
-                if value is NotSet:
+            if value_keypath:
+                from_dict: Any = get_from_dict(data, value_keypath[0], default=NotSet)
+                if from_dict is NotSet:
                     raise ValueError(
                         f"Invalid template: {template!r}. Could not resolve the"
                         " keypath in the block document data."
                     )
+                value = from_dict

             return value
     else:
@@ -311,7 +306,7 @@ async def resolve_block_document_references(


 @inject_client
-async def resolve_variables(template: T, client: Optional["PrefectClient"] = None):
+async def resolve_variables(template: T, client: Optional["PrefectClient"] = None) -> T:
     """
     Resolve variables in a template by replacing each variable placeholder with the value
     of the variable.
@@ -326,6 +321,11 @@ async def resolve_variables(template: T, client: Optional["PrefectClient"] = Non
     Returns:
         The template with variables resolved
     """
+    if TYPE_CHECKING:
+        # The @inject_client decorator takes care of providing the client, but
+        # the function signature must mark it as optional to callers.
+        assert client is not None
+
     if isinstance(template, str):
         placeholders = find_placeholders(template)
         has_variable_placeholder = any(
@@ -346,7 +346,7 @@ async def resolve_variables(template: T, client: Optional["PrefectClient"] = Non
             if variable is None:
                 return ""
             else:
-                return variable.value
+                return cast(T, variable.value)
         else:
             for full_match, name, placeholder_type in placeholders:
                 if placeholder_type is PlaceholderType.VARIABLE:
@@ -355,7 +355,7 @@ async def resolve_variables(template: T, client: Optional["PrefectClient"] = Non
                     if variable is None:
                         template = template.replace(full_match, "")
                     else:
-                        template = template.replace(full_match, variable.value)
+                        template = template.replace(full_match, str(variable.value))
             return template
     elif isinstance(template, dict):
         return {
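# Illustrative sketch (not part of the diff): placeholder discovery and
# substitution with the helpers above; the template literal is hypothetical.
from prefect.utilities.templating import apply_values, find_placeholders

template = {"image": "{{ image }}", "cpu": "{{ cpu }}"}
assert {p.name for p in find_placeholders(template)} == {"image", "cpu"}
# With remove_notset=True (the default), keys whose placeholders have no
# matching value are dropped rather than left unrendered.
assert apply_values(template, {"image": "ubuntu:latest"}) == {"image": "ubuntu:latest"}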
diff --git a/src/prefect/utilities/text.py b/src/prefect/utilities/text.py
index 14637c72f49b..3f37d2c719d2 100644
--- a/src/prefect/utilities/text.py
+++ b/src/prefect/utilities/text.py
@@ -1,5 +1,6 @@
 import difflib
-from typing import Iterable, Optional
+from collections.abc import Iterable
+from typing import Optional


 def truncated_to(length: int, value: Optional[str]) -> str:
diff --git a/src/prefect/utilities/timeout.py b/src/prefect/utilities/timeout.py
index 0074a4d337ae..596ad5c92568 100644
--- a/src/prefect/utilities/timeout.py
+++ b/src/prefect/utilities/timeout.py
@@ -1,6 +1,6 @@
 from asyncio import CancelledError
 from contextlib import contextmanager
-from typing import Optional, Type
+from typing import Optional

 from prefect._internal.concurrency.cancellation import (
     cancel_async_after,
@@ -8,7 +8,7 @@
 )


-def fail_if_not_timeout_error(timeout_exc_type: Type[Exception]) -> None:
+def fail_if_not_timeout_error(timeout_exc_type: type[Exception]) -> None:
     if not issubclass(timeout_exc_type, TimeoutError):
         raise ValueError(
             "The `timeout_exc_type` argument must be a subclass of `TimeoutError`."
@@ -17,7 +17,7 @@ def fail_if_not_timeout_error(timeout_exc_type: Type[Exception]) -> None:

 @contextmanager
 def timeout_async(
-    seconds: Optional[float] = None, timeout_exc_type: Type[TimeoutError] = TimeoutError
+    seconds: Optional[float] = None, timeout_exc_type: type[TimeoutError] = TimeoutError
 ):
     fail_if_not_timeout_error(timeout_exc_type)

@@ -34,7 +34,7 @@ def timeout_async(

 @contextmanager
 def timeout(
-    seconds: Optional[float] = None, timeout_exc_type: Type[TimeoutError] = TimeoutError
+    seconds: Optional[float] = None, timeout_exc_type: type[TimeoutError] = TimeoutError
 ):
     fail_if_not_timeout_error(timeout_exc_type)
diff --git a/src/prefect/utilities/urls.py b/src/prefect/utilities/urls.py
index 7b99f645b648..eadaaa106426 100644
--- a/src/prefect/utilities/urls.py
+++ b/src/prefect/utilities/urls.py
@@ -2,6 +2,7 @@
 import ipaddress
 import socket
 import urllib.parse
+from logging import Logger
 from string import Formatter
 from typing import TYPE_CHECKING, Any, Literal, Optional, Union
 from urllib.parse import urlparse
@@ -19,7 +20,7 @@
     from prefect.futures import PrefectFuture
     from prefect.variables import Variable

-logger = get_logger("utilities.urls")
+logger: Logger = get_logger("utilities.urls")

 # The following objects are excluded from UI URL generation because we lack a
 # directly-addressable URL:
@@ -64,7 +65,7 @@
 RUN_TYPES = {"flow-run", "task-run"}


-def validate_restricted_url(url: str):
+def validate_restricted_url(url: str) -> None:
     """
     Validate that the provided URL is safe for outbound requests.  This prevents
     attacks like SSRF (Server Side Request Forgery), where an attacker can make
@@ -123,7 +124,7 @@ def convert_class_to_name(obj: Any) -> str:

 def url_for(
     obj: Union[
-        "PrefectFuture",
+        "PrefectFuture[Any]",
         "Block",
         "Variable",
         "Automation",
@@ -163,6 +164,7 @@ def url_for(
         url_for("flow-run", obj_id="123e4567-e89b-12d3-a456-426614174000")
     """
     from prefect.blocks.core import Block
+    from prefect.client.schemas.objects import WorkPool
     from prefect.events.schemas.automations import Automation
     from prefect.events.schemas.events import ReceivedEvent, Resource
     from prefect.futures import PrefectFuture
@@ -228,8 +230,10 @@ def url_for(
     elif name == "block":
         # Blocks are client-side objects whose API representation is a
         # BlockDocument.
-        obj_id = obj._block_document_id
+        obj_id = getattr(obj, "_block_document_id")
     elif name in ("variable", "work-pool"):
+        if TYPE_CHECKING:
+            assert isinstance(obj, (Variable, WorkPool))
         obj_id = obj.name
     elif isinstance(obj, Resource):
         obj_id = obj.id.rpartition(".")[2]
@@ -244,6 +248,7 @@ def url_for(
     url_format = (
         UI_URL_FORMATS.get(name) if url_type == "ui" else API_URL_FORMATS.get(name)
     )
+    assert url_format is not None

     if isinstance(obj, ReceivedEvent):
         url = url_format.format(
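# Illustrative sketch (not part of the diff): resolving a UI link with url_for,
# mirroring the docstring example above. The UUID is a placeholder, and the
# resulting URL depends on PREFECT_UI_URL and the registered URL formats.
from prefect.utilities.urls import url_for

link = url_for("flow-run", obj_id="123e4567-e89b-12d3-a456-426614174000")
print(link)  # e.g. "<PREFECT_UI_URL>/runs/flow-run/123e4567-..." (shape may vary)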
diff --git a/src/prefect/utilities/visualization.py b/src/prefect/utilities/visualization.py
index 29349ac1c006..b149fa42806e 100644
--- a/src/prefect/utilities/visualization.py
+++ b/src/prefect/utilities/visualization.py
@@ -2,10 +2,12 @@
 Utilities for working with Flow.visualize()
 """

+from collections.abc import Coroutine
 from functools import partial
-from typing import Any, List, Optional
+from typing import Any, Literal, Optional, Union, overload

-import graphviz
+import graphviz  # type: ignore  # no typing stubs available
+from typing_extensions import Self

 from prefect._internal.concurrency.api import from_async

@@ -30,16 +32,36 @@ class GraphvizExecutableNotFoundError(Exception):
     pass


-def get_task_viz_tracker():
+def get_task_viz_tracker() -> Optional["TaskVizTracker"]:
     return TaskVizTrackerState.current


+@overload
+def track_viz_task(
+    is_async: Literal[True],
+    task_name: str,
+    parameters: dict[str, Any],
+    viz_return_value: Optional[Any] = None,
+) -> Coroutine[Any, Any, Any]:
+    ...
+
+
+@overload
+def track_viz_task(
+    is_async: Literal[False],
+    task_name: str,
+    parameters: dict[str, Any],
+    viz_return_value: Optional[Any] = None,
+) -> Any:
+    ...
+
+
 def track_viz_task(
     is_async: bool,
     task_name: str,
     parameters: dict[str, Any],
     viz_return_value: Optional[Any] = None,
-):
+) -> Union[Coroutine[Any, Any, Any], Any]:
     """Return a result if sync otherwise return a coroutine that returns the result"""
     if is_async:
         return from_async.wait_for_call_in_loop_thread(
@@ -85,10 +107,10 @@ class VizTask:
     def __init__(
         self,
         name: str,
-        upstream_tasks: Optional[List["VizTask"]] = None,
+        upstream_tasks: Optional[list["VizTask"]] = None,
     ):
         self.name = name
-        self.upstream_tasks = upstream_tasks if upstream_tasks else []
+        self.upstream_tasks: list[VizTask] = upstream_tasks if upstream_tasks else []


 class TaskVizTracker:
@@ -97,7 +119,7 @@ def __init__(self):
         self.dynamic_task_counter: dict[str, int] = {}
         self.object_id_to_task: dict[int, VizTask] = {}

-    def add_task(self, task: VizTask):
+    def add_task(self, task: VizTask) -> None:
         if task.name not in self.dynamic_task_counter:
             self.dynamic_task_counter[task.name] = 0
         else:
@@ -106,11 +128,11 @@ def add_task(self, task: VizTask):
             task.name = f"{task.name}-{self.dynamic_task_counter[task.name]}"
         self.tasks.append(task)

-    def __enter__(self):
+    def __enter__(self) -> Self:
         TaskVizTrackerState.current = self
         return self

-    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any):
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         TaskVizTrackerState.current = None

     def link_viz_return_value_to_viz_task(
@@ -129,7 +151,7 @@ def link_viz_return_value_to_viz_task(
         self.object_id_to_task[id(viz_return_value)] = viz_task


-def build_task_dependencies(task_run_tracker: TaskVizTracker):
+def build_task_dependencies(task_run_tracker: TaskVizTracker) -> graphviz.Digraph:
     """
     Constructs a Graphviz directed graph object that represents the dependencies
     between tasks in the given TaskVizTracker.
@@ -166,7 +188,7 @@ def build_task_dependencies(task_run_tracker: TaskVizTracker):
     )


-def visualize_task_dependencies(graph: graphviz.Digraph, flow_run_name: str):
+def visualize_task_dependencies(graph: graphviz.Digraph, flow_run_name: str) -> None:
     """
     Renders and displays a Graphviz directed graph representing task dependencies.
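# Illustrative sketch (not part of the diff): how the tracker and graph builder
# above fit together. Task names are hypothetical; graphviz must be installed.
from prefect.utilities.visualization import (
    TaskVizTracker,
    VizTask,
    build_task_dependencies,
)

with TaskVizTracker() as tracker:
    upstream = VizTask("extract")
    tracker.add_task(upstream)
    tracker.add_task(VizTask("load", upstream_tasks=[upstream]))

graph = build_task_dependencies(tracker)  # graphviz.Digraph with an extract->load edge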