diff --git a/src/integrations/prefect-ray/prefect_ray/task_runners.py b/src/integrations/prefect-ray/prefect_ray/task_runners.py index c8bcd9a62621..d8bef2ac5dcc 100644 --- a/src/integrations/prefect-ray/prefect_ray/task_runners.py +++ b/src/integrations/prefect-ray/prefect_ray/task_runners.py @@ -70,22 +70,23 @@ def count_to(highest_number): #9 ``` """ +from __future__ import annotations import asyncio # noqa: I001 from typing import ( TYPE_CHECKING, Any, + Callable, Coroutine, - Dict, Iterable, Optional, - Set, TypeVar, overload, ) -from typing_extensions import ParamSpec from uuid import UUID, uuid4 +from typing_extensions import ParamSpec, Self + from prefect.client.schemas.objects import TaskRunInput from prefect.context import serialize_context from prefect.futures import PrefectFuture, PrefectFutureList, PrefectWrappedFuture @@ -108,12 +109,12 @@ def count_to(highest_number): P = ParamSpec("P") T = TypeVar("T") -F = TypeVar("F", bound=PrefectFuture) +F = TypeVar("F", bound=PrefectFuture[Any]) R = TypeVar("R") class PrefectRayFuture(PrefectWrappedFuture[R, "ray.ObjectRef"]): - def wait(self, timeout: Optional[float] = None) -> None: + def wait(self, timeout: float | None = None) -> None: try: result = ray.get(self.wrapped_future, timeout=timeout) except ray.exceptions.GetTimeoutError: @@ -125,7 +126,7 @@ def wait(self, timeout: Optional[float] = None) -> None: def result( self, - timeout: Optional[float] = None, + timeout: float | None = None, raise_on_failure: bool = True, ) -> R: if not self._final_state: @@ -150,10 +151,10 @@ def result( _result = run_coro_as_sync(_result) return _result - def add_done_callback(self, fn): + def add_done_callback(self, fn: Callable[["PrefectRayFuture[R]"], Any]): if not self._final_state: - def call_with_self(future): + def call_with_self(future: "PrefectRayFuture[R]"): """Call the callback with self as the argument, this is necessary to ensure we remove the future from the pending set""" fn(self) @@ -162,25 +163,27 @@ def call_with_self(future): fn(self) def __del__(self): + # If we already have a final state, skip if self._final_state: return + try: ray.get(self.wrapped_future, timeout=0) - return except ray.exceptions.GetTimeoutError: pass + # logging in __del__ can also fail at shutdown try: local_logger = get_run_logger() except Exception: local_logger = logger local_logger.warning( "A future was garbage collected before it resolved." - " Please call `.wait()` or `.result()` on futures to ensure they resolve.", + " Please call `.wait()` or `.result()` on futures to ensure they resolve." ) -class RayTaskRunner(TaskRunner[PrefectRayFuture]): +class RayTaskRunner(TaskRunner[PrefectRayFuture[R]]): """ A parallel task_runner that submits tasks to `ray`. By default, a temporary Ray cluster is created for the duration of the flow run. @@ -209,7 +212,7 @@ def my_flow(): def __init__( self, address: Optional[str] = None, - init_kwargs: Optional[Dict] = None, + init_kwargs: dict[str, Any] | None = None, ): # Store settings self.address = address @@ -243,28 +246,28 @@ def __eq__(self, other: object) -> bool: def submit( self, task: "Task[P, Coroutine[Any, Any, R]]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + dependencies: dict[str, set[TaskRunInput]] | None = None, ) -> PrefectRayFuture[R]: ... @overload def submit( self, - task: "Task[Any, R]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + task: "Task[P, R]", + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + dependencies: dict[str, set[TaskRunInput]] | None = None, ) -> PrefectRayFuture[R]: ... def submit( self, - task: Task, - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + task: Task[P, R], + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + dependencies: dict[str, set[TaskRunInput]] | None = None, ): if not self._started: raise RuntimeError( @@ -296,14 +299,14 @@ def submit( context=context, ) ) - return PrefectRayFuture(task_run_id=task_run_id, wrapped_future=object_ref) + return PrefectRayFuture[R](task_run_id=task_run_id, wrapped_future=object_ref) @overload def map( self, task: "Task[P, Coroutine[Any, Any, R]]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, ) -> PrefectFutureList[PrefectRayFuture[R]]: ... @@ -311,25 +314,25 @@ def map( def map( self, task: "Task[Any, R]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, ) -> PrefectFutureList[PrefectRayFuture[R]]: ... def map( self, - task: "Task", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - ): + task: "Task[P, R]", + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + ) -> PrefectFutureList[PrefectRayFuture[R]]: return super().map(task, parameters, wait_for) - def _exchange_prefect_for_ray_futures(self, kwargs_prefect_futures): + def _exchange_prefect_for_ray_futures(self, kwargs_prefect_futures: dict[str, Any]): """Exchanges Prefect futures for Ray futures.""" - upstream_ray_obj_refs = [] + upstream_ray_obj_refs: list[Any] = [] - def exchange_prefect_for_ray_future(expr): + def exchange_prefect_for_ray_future(expr: Any): """Exchanges Prefect future for Ray future.""" if isinstance(expr, PrefectRayFuture): ray_future = expr.wrapped_future @@ -347,14 +350,14 @@ def exchange_prefect_for_ray_future(expr): @staticmethod def _run_prefect_task( - *upstream_ray_obj_refs, - task: Task, + *upstream_ray_obj_refs: Any, + task: Task[P, R], task_run_id: UUID, - context: Dict[str, Any], - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, - ): + context: dict[str, Any], + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + dependencies: dict[str, set[TaskRunInput]] | None = None, + ) -> Any: """Resolves Ray futures before calling the actual Prefect task function. Passing upstream_ray_obj_refs directly as args enables Ray to wait for @@ -364,7 +367,7 @@ def _run_prefect_task( """ # Resolve Ray futures to ensure that the task function receives the actual values - def resolve_ray_future(expr): + def resolve_ray_future(expr: Any): """Resolves Ray future.""" if isinstance(expr, ray.ObjectRef): return ray.get(expr) @@ -374,7 +377,7 @@ def resolve_ray_future(expr): parameters, visit_fn=resolve_ray_future, return_data=True ) - run_task_kwargs = { + run_task_kwargs: dict[str, Any] = { "task": task, "task_run_id": task_run_id, "parameters": parameters, @@ -391,7 +394,7 @@ def resolve_ray_future(expr): else: return run_task_sync(**run_task_kwargs) - def __enter__(self): + def __enter__(self) -> Self: super().__enter__() if ray.is_initialized(): @@ -424,7 +427,7 @@ def __enter__(self): return self - def __exit__(self, *exc_info): + def __exit__(self, *exc_info: Any): """ Shuts down the driver/cluster. """ diff --git a/src/integrations/prefect-ray/tests/test_task_runners.py b/src/integrations/prefect-ray/tests/test_task_runners.py index 06d963a107d0..7e5642f623e8 100644 --- a/src/integrations/prefect-ray/tests/test_task_runners.py +++ b/src/integrations/prefect-ray/tests/test_task_runners.py @@ -56,7 +56,20 @@ def machine_ray_instance(): """ Starts a ray instance for the current machine """ + # First ensure any existing Ray processes are stopped try: + subprocess.run( + ["ray", "stop"], + check=True, + capture_output=True, + cwd=str(prefect.__development_base_path__), + ) + except subprocess.CalledProcessError: + # It's okay if ray stop fails - it might not be running + pass + + try: + # Start Ray with clean session subprocess.check_output( [ "ray", @@ -70,9 +83,18 @@ def machine_ray_instance(): ) yield "ray://127.0.0.1:10001" except subprocess.CalledProcessError as exc: - pytest.fail(f"Failed to start ray: {exc.stderr}") + pytest.fail(f"Failed to start ray: {exc.stderr or exc}") finally: - subprocess.run(["ray", "stop"]) + # Always try to stop Ray in the cleanup + try: + subprocess.run( + ["ray", "stop"], + check=True, + capture_output=True, + cwd=str(prefect.__development_base_path__), + ) + except subprocess.CalledProcessError: + pass # Best effort cleanup @pytest.fixture diff --git a/src/prefect/context.py b/src/prefect/context.py index 31298b0b1b43..6e460ff2f2fa 100644 --- a/src/prefect/context.py +++ b/src/prefect/context.py @@ -52,7 +52,6 @@ def serialize_context() -> dict[str, Any]: """ Serialize the current context for use in a remote execution environment. """ - flow_run_context = EngineContext.get() task_run_context = TaskRunContext.get() tags_context = TagsContext.get() @@ -71,6 +70,21 @@ def hydrated_context( serialized_context: Optional[dict[str, Any]] = None, client: Union[PrefectClient, SyncPrefectClient, None] = None, ) -> Generator[None, Any, None]: + # We need to rebuild the models because we might be hydrating in a remote + # environment where the models are not available. + # TODO: Remove this once we have fixed our circular imports and we don't need to rebuild models any more. + from prefect.flows import Flow + from prefect.results import ResultRecordMetadata + from prefect.tasks import Task + + _types: dict[str, Any] = dict( + Flow=Flow, + Task=Task, + ResultRecordMetadata=ResultRecordMetadata, + ) + FlowRunContext.model_rebuild(_types_namespace=_types) + TaskRunContext.model_rebuild(_types_namespace=_types) + with ExitStack() as stack: if serialized_context: # Set up settings context diff --git a/src/prefect/task_runners.py b/src/prefect/task_runners.py index 497e34c1fbf5..9e13c873d9d3 100644 --- a/src/prefect/task_runners.py +++ b/src/prefect/task_runners.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc import asyncio import sys @@ -14,7 +16,6 @@ Iterable, List, Optional, - Set, overload, ) @@ -44,7 +45,7 @@ P = ParamSpec("P") T = TypeVar("T") R = TypeVar("R") -F = TypeVar("F", bound=PrefectFuture, default=PrefectConcurrentFuture) +F = TypeVar("F", bound=PrefectFuture[Any], default=PrefectConcurrentFuture[Any]) class TaskRunner(abc.ABC, Generic[F]): @@ -76,10 +77,10 @@ def duplicate(self) -> Self: @abc.abstractmethod def submit( self, - task: "Task", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + task: "Task[P, R]", + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + dependencies: dict[str, set[TaskRunInput]] | None = None, ) -> F: """ Submit a task to the task run engine. @@ -98,7 +99,7 @@ def submit( def map( self, task: "Task[P, R]", - parameters: Dict[str, Any], + parameters: dict[str, Any], wait_for: Optional[Iterable[PrefectFuture[R]]] = None, ) -> PrefectFutureList[F]: """ @@ -138,9 +139,9 @@ def map( # Ensure that any parameters in kwargs are expanded before this check parameters = explode_variadic_parameter(task.fn, parameters) - iterable_parameters = {} - static_parameters = {} - annotated_parameters = {} + iterable_parameters: dict[str, Any] = {} + static_parameters: dict[str, Any] = {} + annotated_parameters: dict[str, Any] = {} for key, val in parameters.items(): if isinstance(val, (allow_failure, quote)): # Unwrap annotated parameters to determine if they are iterable @@ -172,9 +173,9 @@ def map( map_length = list(lengths)[0] - futures: List[PrefectFuture] = [] + futures: List[PrefectFuture[Any]] = [] for i in range(map_length): - call_parameters = { + call_parameters: dict[str, Any] = { key: value[i] for key, value in iterable_parameters.items() } call_parameters.update( @@ -212,12 +213,12 @@ def __enter__(self): self._started = True return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): self.logger.debug("Stopping task runner") self._started = False -class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture]): +class ThreadPoolTaskRunner(TaskRunner[PrefectConcurrentFuture[Any]]): def __init__(self, max_workers: Optional[int] = None): super().__init__() self._executor: Optional[ThreadPoolExecutor] = None @@ -235,9 +236,9 @@ def duplicate(self) -> "ThreadPoolTaskRunner": def submit( self, task: "Task[P, Coroutine[Any, Any, R]]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + dependencies: dict[str, set[TaskRunInput]] | None = None, ) -> PrefectConcurrentFuture[R]: ... @@ -245,19 +246,19 @@ def submit( def submit( self, task: "Task[Any, R]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + dependencies: dict[str, set[TaskRunInput]] | None = None, ) -> PrefectConcurrentFuture[R]: ... def submit( self, - task: "Task", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, - ): + task: "Task[P, R]", + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + dependencies: dict[str, set[TaskRunInput]] | None = None, + ) -> PrefectConcurrentFuture[R]: """ Submit a task to the task run engine running in a separate thread. @@ -289,7 +290,7 @@ def submit( else: self.logger.debug(f"Submitting task {task.name} to thread pool executor...") - submit_kwargs = dict( + submit_kwargs: dict[str, Any] = dict( task=task, task_run_id=task_run_id, parameters=parameters, @@ -322,8 +323,8 @@ def submit( def map( self, task: "Task[P, Coroutine[Any, Any, R]]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, ) -> PrefectFutureList[PrefectConcurrentFuture[R]]: ... @@ -331,17 +332,17 @@ def map( def map( self, task: "Task[Any, R]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, ) -> PrefectFutureList[PrefectConcurrentFuture[R]]: ... def map( self, - task: "Task", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - ): + task: "Task[P, R]", + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + ) -> PrefectFutureList[PrefectConcurrentFuture[R]]: return super().map(task, parameters, wait_for) def cancel_all(self): @@ -358,7 +359,7 @@ def __enter__(self): self._executor = ThreadPoolExecutor(max_workers=self._max_workers) return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any): self.cancel_all() if self._executor is not None: self._executor.shutdown(cancel_futures=True) @@ -375,7 +376,7 @@ def __eq__(self, value: object) -> bool: ConcurrentTaskRunner = ThreadPoolTaskRunner -class PrefectTaskRunner(TaskRunner[PrefectDistributedFuture]): +class PrefectTaskRunner(TaskRunner[PrefectDistributedFuture[R]]): def __init__(self): super().__init__() @@ -386,9 +387,9 @@ def duplicate(self) -> "PrefectTaskRunner": def submit( self, task: "Task[P, Coroutine[Any, Any, R]]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + dependencies: dict[str, set[TaskRunInput]] | None = None, ) -> PrefectDistributedFuture[R]: ... @@ -396,19 +397,19 @@ def submit( def submit( self, task: "Task[Any, R]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + dependencies: dict[str, set[TaskRunInput]] | None = None, ) -> PrefectDistributedFuture[R]: ... def submit( self, - task: "Task", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None, - ): + task: "Task[P, R]", + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + dependencies: dict[str, set[TaskRunInput]] | None = None, + ) -> PrefectDistributedFuture[R]: """ Submit a task to the task run engine running in a separate thread. @@ -443,8 +444,8 @@ def submit( def map( self, task: "Task[P, Coroutine[Any, Any, R]]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, ) -> PrefectFutureList[PrefectDistributedFuture[R]]: ... @@ -452,15 +453,15 @@ def map( def map( self, task: "Task[Any, R]", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, ) -> PrefectFutureList[PrefectDistributedFuture[R]]: ... def map( self, - task: "Task", - parameters: Dict[str, Any], - wait_for: Optional[Iterable[PrefectFuture]] = None, - ): + task: "Task[P, R]", + parameters: dict[str, Any], + wait_for: Iterable[PrefectFuture[Any]] | None = None, + ) -> PrefectFutureList[PrefectDistributedFuture[R]]: return super().map(task, parameters, wait_for) diff --git a/src/prefect/telemetry/services.py b/src/prefect/telemetry/services.py index a9825094f0e0..e2fcd928af69 100644 --- a/src/prefect/telemetry/services.py +++ b/src/prefect/telemetry/services.py @@ -53,6 +53,8 @@ def __init__(self, endpoint: str, headers: tuple[tuple[str, str]]): def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult: for item in spans: + if self._stopped: + break self.send(item) return SpanExportResult.SUCCESS @@ -65,4 +67,6 @@ def __init__(self, endpoint: str, headers: tuple[tuple[str, str]]) -> None: def export(self, batch: Sequence[LogData]) -> None: for item in batch: + if self._stopped: + break self.send(item)