Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

call FlowRunContext.model_rebuild() in hydrate_context #16628

Merged
merged 5 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 49 additions & 46 deletions src/integrations/prefect-ray/prefect_ray/task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,23 @@ def count_to(highest_number):
#9
```
"""
from __future__ import annotations

import asyncio # noqa: I001
from typing import (
TYPE_CHECKING,
Any,
Callable,
Coroutine,
Dict,
Iterable,
Optional,
Set,
TypeVar,
overload,
)
from typing_extensions import ParamSpec
from uuid import UUID, uuid4

from typing_extensions import ParamSpec, Self

from prefect.client.schemas.objects import TaskRunInput
from prefect.context import serialize_context
from prefect.futures import PrefectFuture, PrefectFutureList, PrefectWrappedFuture
Expand All @@ -108,12 +109,12 @@ def count_to(highest_number):

P = ParamSpec("P")
T = TypeVar("T")
F = TypeVar("F", bound=PrefectFuture)
F = TypeVar("F", bound=PrefectFuture[Any])
R = TypeVar("R")


class PrefectRayFuture(PrefectWrappedFuture[R, "ray.ObjectRef"]):
def wait(self, timeout: Optional[float] = None) -> None:
def wait(self, timeout: float | None = None) -> None:
try:
result = ray.get(self.wrapped_future, timeout=timeout)
except ray.exceptions.GetTimeoutError:
Expand All @@ -125,7 +126,7 @@ def wait(self, timeout: Optional[float] = None) -> None:

def result(
self,
timeout: Optional[float] = None,
timeout: float | None = None,
raise_on_failure: bool = True,
) -> R:
if not self._final_state:
Expand All @@ -150,10 +151,10 @@ def result(
_result = run_coro_as_sync(_result)
return _result

def add_done_callback(self, fn):
def add_done_callback(self, fn: Callable[["PrefectRayFuture[R]"], Any]):
if not self._final_state:

def call_with_self(future):
def call_with_self(future: "PrefectRayFuture[R]"):
"""Call the callback with self as the argument, this is necessary to ensure we remove the future from the pending set"""
fn(self)

Expand All @@ -162,25 +163,27 @@ def call_with_self(future):
fn(self)

def __del__(self):
    """Finalizer: emit a warning if this future is garbage collected
    before it resolved.

    Runs during GC (possibly at interpreter shutdown), so every step is
    best-effort and guarded against failure.
    """
    # If we already have a final state, skip
    if self._final_state:
        return

    try:
        # Non-blocking probe (timeout=0): if the Ray result is already
        # available the future effectively resolved, so no warning is needed.
        ray.get(self.wrapped_future, timeout=0)
        return
    except ray.exceptions.GetTimeoutError:
        # Result not ready — the future is genuinely unresolved; fall
        # through and warn the user below.
        pass

    # logging in __del__ can also fail at shutdown
    try:
        local_logger = get_run_logger()
    except Exception:
        # No active run context (or logging unavailable) — fall back to
        # the module-level logger.
        local_logger = logger
    local_logger.warning(
        "A future was garbage collected before it resolved."
        " Please call `.wait()` or `.result()` on futures to ensure they resolve."
    )


class RayTaskRunner(TaskRunner[PrefectRayFuture]):
class RayTaskRunner(TaskRunner[PrefectRayFuture[R]]):
"""
A parallel task_runner that submits tasks to `ray`.
By default, a temporary Ray cluster is created for the duration of the flow run.
Expand Down Expand Up @@ -209,7 +212,7 @@ def my_flow():
def __init__(
self,
address: Optional[str] = None,
init_kwargs: Optional[Dict] = None,
init_kwargs: dict[str, Any] | None = None,
):
# Store settings
self.address = address
Expand Down Expand Up @@ -243,28 +246,28 @@ def __eq__(self, other: object) -> bool:
def submit(
self,
task: "Task[P, Coroutine[Any, Any, R]]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
parameters: dict[str, Any],
wait_for: Iterable[PrefectFuture[Any]] | None = None,
dependencies: dict[str, set[TaskRunInput]] | None = None,
) -> PrefectRayFuture[R]:
...

@overload
def submit(
self,
task: "Task[Any, R]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
task: "Task[P, R]",
parameters: dict[str, Any],
wait_for: Iterable[PrefectFuture[Any]] | None = None,
dependencies: dict[str, set[TaskRunInput]] | None = None,
) -> PrefectRayFuture[R]:
...

def submit(
self,
task: Task,
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
task: Task[P, R],
parameters: dict[str, Any],
wait_for: Iterable[PrefectFuture[Any]] | None = None,
dependencies: dict[str, set[TaskRunInput]] | None = None,
):
if not self._started:
raise RuntimeError(
Expand Down Expand Up @@ -296,40 +299,40 @@ def submit(
context=context,
)
)
return PrefectRayFuture(task_run_id=task_run_id, wrapped_future=object_ref)
return PrefectRayFuture[R](task_run_id=task_run_id, wrapped_future=object_ref)

@overload
def map(
self,
task: "Task[P, Coroutine[Any, Any, R]]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
parameters: dict[str, Any],
wait_for: Iterable[PrefectFuture[Any]] | None = None,
) -> PrefectFutureList[PrefectRayFuture[R]]:
...

@overload
def map(
self,
task: "Task[Any, R]",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
parameters: dict[str, Any],
wait_for: Iterable[PrefectFuture[Any]] | None = None,
) -> PrefectFutureList[PrefectRayFuture[R]]:
...

def map(
self,
task: "Task",
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
):
task: "Task[P, R]",
parameters: dict[str, Any],
wait_for: Iterable[PrefectFuture[Any]] | None = None,
) -> PrefectFutureList[PrefectRayFuture[R]]:
return super().map(task, parameters, wait_for)

def _exchange_prefect_for_ray_futures(self, kwargs_prefect_futures):
def _exchange_prefect_for_ray_futures(self, kwargs_prefect_futures: dict[str, Any]):
"""Exchanges Prefect futures for Ray futures."""

upstream_ray_obj_refs = []
upstream_ray_obj_refs: list[Any] = []

def exchange_prefect_for_ray_future(expr):
def exchange_prefect_for_ray_future(expr: Any):
"""Exchanges Prefect future for Ray future."""
if isinstance(expr, PrefectRayFuture):
ray_future = expr.wrapped_future
Expand All @@ -347,14 +350,14 @@ def exchange_prefect_for_ray_future(expr):

@staticmethod
def _run_prefect_task(
*upstream_ray_obj_refs,
task: Task,
*upstream_ray_obj_refs: Any,
task: Task[P, R],
task_run_id: UUID,
context: Dict[str, Any],
parameters: Dict[str, Any],
wait_for: Optional[Iterable[PrefectFuture]] = None,
dependencies: Optional[Dict[str, Set[TaskRunInput]]] = None,
):
context: dict[str, Any],
parameters: dict[str, Any],
wait_for: Iterable[PrefectFuture[Any]] | None = None,
dependencies: dict[str, set[TaskRunInput]] | None = None,
) -> Any:
"""Resolves Ray futures before calling the actual Prefect task function.

Passing upstream_ray_obj_refs directly as args enables Ray to wait for
Expand All @@ -364,7 +367,7 @@ def _run_prefect_task(
"""

# Resolve Ray futures to ensure that the task function receives the actual values
def resolve_ray_future(expr):
def resolve_ray_future(expr: Any):
"""Resolves Ray future."""
if isinstance(expr, ray.ObjectRef):
return ray.get(expr)
Expand All @@ -374,7 +377,7 @@ def resolve_ray_future(expr):
parameters, visit_fn=resolve_ray_future, return_data=True
)

run_task_kwargs = {
run_task_kwargs: dict[str, Any] = {
"task": task,
"task_run_id": task_run_id,
"parameters": parameters,
Expand All @@ -391,7 +394,7 @@ def resolve_ray_future(expr):
else:
return run_task_sync(**run_task_kwargs)

def __enter__(self):
def __enter__(self) -> Self:
super().__enter__()

if ray.is_initialized():
Expand Down Expand Up @@ -424,7 +427,7 @@ def __enter__(self):

return self

def __exit__(self, *exc_info):
def __exit__(self, *exc_info: Any):
"""
Shuts down the driver/cluster.
"""
Expand Down
26 changes: 24 additions & 2 deletions src/integrations/prefect-ray/tests/test_task_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,20 @@ def machine_ray_instance():
"""
Starts a ray instance for the current machine
"""
# First ensure any existing Ray processes are stopped
try:
subprocess.run(
["ray", "stop"],
check=True,
capture_output=True,
cwd=str(prefect.__development_base_path__),
)
except subprocess.CalledProcessError:
# It's okay if ray stop fails - it might not be running
pass

try:
# Start Ray with clean session
subprocess.check_output(
[
"ray",
Expand All @@ -70,9 +83,18 @@ def machine_ray_instance():
)
yield "ray://127.0.0.1:10001"
except subprocess.CalledProcessError as exc:
pytest.fail(f"Failed to start ray: {exc.stderr}")
pytest.fail(f"Failed to start ray: {exc.stderr or exc}")
finally:
subprocess.run(["ray", "stop"])
# Always try to stop Ray in the cleanup
try:
subprocess.run(
["ray", "stop"],
check=True,
capture_output=True,
cwd=str(prefect.__development_base_path__),
)
except subprocess.CalledProcessError:
pass # Best effort cleanup


@pytest.fixture
Expand Down
16 changes: 15 additions & 1 deletion src/prefect/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def serialize_context() -> dict[str, Any]:
"""
Serialize the current context for use in a remote execution environment.
"""

flow_run_context = EngineContext.get()
task_run_context = TaskRunContext.get()
tags_context = TagsContext.get()
Expand All @@ -71,6 +70,21 @@ def hydrated_context(
serialized_context: Optional[dict[str, Any]] = None,
client: Union[PrefectClient, SyncPrefectClient, None] = None,
) -> Generator[None, Any, None]:
# We need to rebuild the models because we might be hydrating in a remote
# environment where the models are not available.
# TODO: Remove this once we have fixed our circular imports and we don't need to rebuild models any more.
from prefect.flows import Flow
from prefect.results import ResultRecordMetadata
from prefect.tasks import Task

_types: dict[str, Any] = dict(
Flow=Flow,
Task=Task,
ResultRecordMetadata=ResultRecordMetadata,
)
FlowRunContext.model_rebuild(_types_namespace=_types)
TaskRunContext.model_rebuild(_types_namespace=_types)

with ExitStack() as stack:
if serialized_context:
# Set up settings context
Expand Down
Loading
Loading