Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[typing] prefect.utilities #16298

Merged
merged 1 commit into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 46 additions & 47 deletions src/prefect/_internal/concurrency/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,50 +6,46 @@
import asyncio
import concurrent.futures
import contextlib
from typing import (
Any,
Awaitable,
Callable,
ContextManager,
Iterable,
Optional,
TypeVar,
Union,
)
from collections.abc import Awaitable, Iterable
from contextlib import AbstractContextManager
from typing import Any, Callable, Optional, Union, cast

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, TypeAlias, TypeVar

from prefect._internal.concurrency.threads import (
WorkerThread,
get_global_loop,
in_global_loop,
)
from prefect._internal.concurrency.waiters import (
AsyncWaiter,
Call,
SyncWaiter,
)
from prefect._internal.concurrency.waiters import AsyncWaiter, Call, SyncWaiter

P = ParamSpec("P")
T = TypeVar("T")
T = TypeVar("T", infer_variance=True)
Future = Union[concurrent.futures.Future[T], asyncio.Future[T]]

_SyncOrAsyncCallable: TypeAlias = Callable[P, Union[T, Awaitable[T]]]

def create_call(__fn: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> Call[T]:

def create_call(
__fn: _SyncOrAsyncCallable[P, T], *args: P.args, **kwargs: P.kwargs
) -> Call[T]:
return Call[T].new(__fn, *args, **kwargs)


def _cast_to_call(call_like: Union[Callable[[], T], Call[T]]) -> Call[T]:
def cast_to_call(
call_like: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
) -> Call[T]:
if isinstance(call_like, Call):
return call_like
return cast(Call[T], call_like)
else:
return create_call(call_like)


class _base(abc.ABC):
@abc.abstractstaticmethod
@staticmethod
@abc.abstractmethod
def wait_for_call_in_loop_thread(
__call: Union[Callable[[], T], Call[T]], # type: ignore[reportGeneralTypeIssues]
__call: Union["_SyncOrAsyncCallable[[], Any]", Call[T]],
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call[Any]]] = None,
) -> T:
Expand All @@ -60,9 +56,10 @@ def wait_for_call_in_loop_thread(
"""
raise NotImplementedError()

@abc.abstractstaticmethod
@staticmethod
@abc.abstractmethod
def wait_for_call_in_new_thread(
__call: Union[Callable[[], T], Call[T]], # type: ignore[reportGeneralTypeIssues]
__call: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call[Any]]] = None,
) -> T:
Expand All @@ -75,30 +72,31 @@ def wait_for_call_in_new_thread(

@staticmethod
def call_soon_in_new_thread(
__call: Union[Callable[[], T], Call[T]], timeout: Optional[float] = None
__call: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
timeout: Optional[float] = None,
) -> Call[T]:
"""
Schedule a call for execution in a new worker thread.

Returns the submitted call.
"""
call = _cast_to_call(__call)
call = cast_to_call(__call)
runner = WorkerThread(run_once=True)
call.set_timeout(timeout)
runner.submit(call)
return call

@staticmethod
def call_soon_in_loop_thread(
__call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
__call: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
timeout: Optional[float] = None,
) -> Call[T]:
"""
Schedule a call for execution in the global event loop thread.

Returns the submitted call.
"""
call = _cast_to_call(__call)
call = cast_to_call(__call)
runner = get_global_loop()
call.set_timeout(timeout)
runner.submit(call)
Expand All @@ -117,7 +115,7 @@ def call_in_new_thread(

@staticmethod
def call_in_loop_thread(
__call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
__call: Union[Callable[[], Awaitable[T]], Call[T]],
timeout: Optional[float] = None,
) -> T:
"""
Expand All @@ -131,12 +129,12 @@ def call_in_loop_thread(
class from_async(_base):
@staticmethod
async def wait_for_call_in_loop_thread(
__call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
__call: Union[Callable[[], Awaitable[T]], Call[T]],
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call[Any]]] = None,
contexts: Optional[Iterable[ContextManager[Any]]] = None,
) -> Awaitable[T]:
call = _cast_to_call(__call)
contexts: Optional[Iterable[AbstractContextManager[Any]]] = None,
) -> T:
call = cast_to_call(__call)
waiter = AsyncWaiter(call)
for callback in done_callbacks or []:
waiter.add_done_callback(callback)
Expand All @@ -153,7 +151,7 @@ async def wait_for_call_in_new_thread(
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call[Any]]] = None,
) -> T:
call = _cast_to_call(__call)
call = cast_to_call(__call)
waiter = AsyncWaiter(call=call)
for callback in done_callbacks or []:
waiter.add_done_callback(callback)
Expand All @@ -170,7 +168,7 @@ def call_in_new_thread(

@staticmethod
def call_in_loop_thread(
__call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
__call: Union[Callable[[], Awaitable[T]], Call[T]],
timeout: Optional[float] = None,
) -> Awaitable[T]:
call = _base.call_soon_in_loop_thread(__call, timeout=timeout)
Expand All @@ -182,13 +180,13 @@ class from_sync(_base):
def wait_for_call_in_loop_thread(
__call: Union[
Callable[[], Awaitable[T]],
Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
Call[T],
],
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call]] = None,
contexts: Optional[Iterable[ContextManager]] = None,
) -> Awaitable[T]:
call = _cast_to_call(__call)
done_callbacks: Optional[Iterable[Call[T]]] = None,
contexts: Optional[Iterable[AbstractContextManager[Any]]] = None,
) -> T:
call = cast_to_call(__call)
waiter = SyncWaiter(call)
_base.call_soon_in_loop_thread(call, timeout=timeout)
for callback in done_callbacks or []:
Expand All @@ -203,9 +201,9 @@ def wait_for_call_in_loop_thread(
def wait_for_call_in_new_thread(
__call: Union[Callable[[], T], Call[T]],
timeout: Optional[float] = None,
done_callbacks: Optional[Iterable[Call]] = None,
) -> Call[T]:
call = _cast_to_call(__call)
done_callbacks: Optional[Iterable[Call[T]]] = None,
) -> T:
call = cast_to_call(__call)
waiter = SyncWaiter(call=call)
for callback in done_callbacks or []:
waiter.add_done_callback(callback)
Expand All @@ -215,20 +213,21 @@ def wait_for_call_in_new_thread(

@staticmethod
def call_in_new_thread(
__call: Union[Callable[[], T], Call[T]], timeout: Optional[float] = None
__call: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
timeout: Optional[float] = None,
) -> T:
call = _base.call_soon_in_new_thread(__call, timeout=timeout)
return call.result()

@staticmethod
def call_in_loop_thread(
__call: Union[Callable[[], Awaitable[T]], Call[Awaitable[T]]],
__call: Union["_SyncOrAsyncCallable[[], T]", Call[T]],
timeout: Optional[float] = None,
) -> T:
) -> Union[Awaitable[T], T]:
if in_global_loop():
# Avoid deadlock where the call is submitted to the loop then the loop is
# blocked waiting for the call
call = _cast_to_call(__call)
call = cast_to_call(__call)
return call()

call = _base.call_soon_in_loop_thread(__call, timeout=timeout)
Expand Down
Loading
Loading