[BUG] Fix actor pool initialization in ray client mode #3028

Merged 14 commits on Oct 21, 2024
18 changes: 16 additions & 2 deletions daft/context.py
@@ -32,6 +32,7 @@ class _RayRunnerConfig(_RunnerConfig):
name = "ray"
address: str | None
max_task_backlog: int | None
force_client_mode: bool


def _get_runner_config_from_env() -> _RunnerConfig:
@@ -43,10 +44,18 @@ def _get_runner_config_from_env() -> _RunnerConfig:
2. RayRunner: set DAFT_RUNNER=ray and optionally RAY_ADDRESS=ray://...
"""
runner_from_envvar = os.getenv("DAFT_RUNNER")

task_backlog_env = os.getenv("DAFT_DEVELOPER_RAY_MAX_TASK_BACKLOG")
task_backlog = int(task_backlog_env) if task_backlog_env is not None else None

use_thread_pool_env = os.getenv("DAFT_DEVELOPER_USE_THREAD_POOL")
use_thread_pool = bool(int(use_thread_pool_env)) if use_thread_pool_env is not None else None

ray_force_client_mode_env = os.getenv("DAFT_RAY_FORCE_CLIENT_MODE")
ray_force_client_mode = (
ray_force_client_mode_env.strip().lower() in ["1", "true"] if ray_force_client_mode_env else False
)

ray_is_initialized = False
in_ray_worker = False
try:
@@ -71,7 +80,8 @@ def _get_runner_config_from_env() -> _RunnerConfig:
ray_address = os.getenv("RAY_ADDRESS")
return _RayRunnerConfig(
address=ray_address,
max_task_backlog=int(task_backlog_env) if task_backlog_env else None,
max_task_backlog=task_backlog,
force_client_mode=ray_force_client_mode,
)
elif runner_from_envvar and runner_from_envvar.upper() == "PY":
return _PyRunnerConfig(use_thread_pool=use_thread_pool)
@@ -82,7 +92,8 @@ def _get_runner_config_from_env() -> _RunnerConfig:
elif ray_is_initialized and not in_ray_worker:
return _RayRunnerConfig(
address=None, # No address supplied, use the existing connection
max_task_backlog=int(task_backlog_env) if task_backlog_env else None,
max_task_backlog=task_backlog,
force_client_mode=ray_force_client_mode,
)

# Fall back on PyRunner
@@ -155,6 +166,7 @@ def _get_runner(self) -> Runner:
self._runner = RayRunner(
address=runner_config.address,
max_task_backlog=runner_config.max_task_backlog,
force_client_mode=runner_config.force_client_mode,
)
elif runner_config.name == "py":
from daft.runners.pyrunner import PyRunner
@@ -189,6 +201,7 @@ def set_runner_ray(
address: str | None = None,
noop_if_initialized: bool = False,
max_task_backlog: int | None = None,
force_client_mode: bool = False,
) -> DaftContext:
"""Set the runner for executing Daft dataframes to a Ray cluster

@@ -222,6 +235,7 @@ def set_runner_ray(
ctx._runner_config = _RayRunnerConfig(
address=address,
max_task_backlog=max_task_backlog,
force_client_mode=force_client_mode,
)
ctx._disallow_set_runner = True
return ctx
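For reference, a minimal usage sketch of the client-mode toggle added above. The address is a placeholder, and the two options are alternatives; the runner can only be set once per process.

import os

# Option 1: environment variables, read when Daft selects its runner.
# Per the parsing above, "1" and "true" both enable the flag.
os.environ["DAFT_RUNNER"] = "ray"
os.environ["DAFT_RAY_FORCE_CLIENT_MODE"] = "true"

# Option 2: the new keyword argument on set_runner_ray.
import daft.context

daft.context.set_runner_ray(
    address="ray://127.0.0.1:10001",  # placeholder Ray Client address
    force_client_mode=True,
)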
4 changes: 3 additions & 1 deletion daft/daft/__init__.pyi
@@ -1639,7 +1639,9 @@ class PhysicalPlanScheduler:
def repr_ascii(self, simple: bool) -> str: ...
def repr_mermaid(self, options: MermaidOptions) -> str: ...
def to_json_string(self) -> str: ...
def to_partition_tasks(self, psets: dict[str, list[PartitionT]]) -> physical_plan.InProgressPhysicalPlan: ...
def to_partition_tasks(
self, psets: dict[str, list[PartitionT]], actor_pool_manager: Any
) -> physical_plan.InProgressPhysicalPlan: ...
def run(self, psets: dict[str, list[PartitionT]]) -> Iterator[PyMicroPartition]: ...

class AdaptivePhysicalPlanScheduler:
3 changes: 1 addition & 2 deletions daft/dataframe/dataframe.py
@@ -38,8 +38,7 @@
from daft.errors import ExpressionTypeError
from daft.expressions import Expression, ExpressionsProjection, col, lit
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import PartitionCacheEntry, PartitionSet
from daft.runners.pyrunner import LocalPartitionSet
from daft.runners.partitioning import LocalPartitionSet, PartitionCacheEntry, PartitionSet
from daft.table import MicroPartition
from daft.viz import DataFrameDisplay

8 changes: 4 additions & 4 deletions daft/execution/native_executor.py
@@ -11,10 +11,10 @@
if TYPE_CHECKING:
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import (
LocalMaterializedResult,
MaterializedResult,
PartitionT,
)
from daft.runners.pyrunner import PyMaterializedResult


class NativeExecutor:
@@ -31,13 +31,13 @@ def run(
psets: dict[str, list[MaterializedResult[PartitionT]]],
daft_execution_config: PyDaftExecutionConfig,
results_buffer_size: int | None,
) -> Iterator[PyMaterializedResult]:
from daft.runners.pyrunner import PyMaterializedResult
) -> Iterator[LocalMaterializedResult]:
from daft.runners.partitioning import LocalMaterializedResult

psets_mp = {
part_id: [part.micropartition()._micropartition for part in parts] for part_id, parts in psets.items()
}
return (
PyMaterializedResult(MicroPartition._from_pymicropartition(part))
LocalMaterializedResult(MicroPartition._from_pymicropartition(part))
for part in self._executor.run(psets_mp, daft_execution_config, results_buffer_size)
)
31 changes: 30 additions & 1 deletion daft/execution/physical_plan.py
@@ -14,9 +14,11 @@
from __future__ import annotations

import collections
import contextlib
import itertools
import logging
import math
from abc import abstractmethod
from collections import deque
from typing import (
TYPE_CHECKING,
@@ -205,9 +207,36 @@ def pipeline_instruction(
)


class ActorPoolManager:
@abstractmethod
@contextlib.contextmanager
def actor_pool_context(
self,
name: str,
actor_resource_request: ResourceRequest,
task_resource_request: ResourceRequest,
num_actors: int,
projection: ExpressionsProjection,
) -> Iterator[str]:
"""Creates a pool of actors which can execute work, and yield a context in which the pool can be used.

Also yields a `str` ID which clients can use to refer to the actor pool when submitting tasks.

Note that attempting to do work outside this context will result in errors!

Args:
name: Name of the actor pool for debugging/observability
actor_resource_request: Requested amount of resources for each actor
task_resource_request: Requested amount of resources for each task run on the actors
num_actors: Number of actors to spin up
projection: Projection to be run on the incoming data (contains Stateful UDFs as well as other stateless expressions such as aliases)
"""
...


def actor_pool_project(
child_plan: InProgressPhysicalPlan[PartitionT],
projection: ExpressionsProjection,
actor_pool_manager: ActorPoolManager,
resource_request: execution_step.ResourceRequest,
num_actors: int,
) -> InProgressPhysicalPlan[PartitionT]:
@@ -238,7 +267,7 @@ def actor_pool_project(
num_gpus=resource_request.num_gpus, memory_bytes=resource_request.memory_bytes
)

with get_context().runner().actor_pool_context(
with actor_pool_manager.actor_pool_context(
actor_pool_name,
actor_resource_request,
task_resource_request,
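To make the new ActorPoolManager interface concrete, here is a minimal, hypothetical in-process implementation. The class name and pool bookkeeping are illustrative only, and the import paths are assumptions based on the modules touched in this diff, not Daft's actual runner code.

from __future__ import annotations

import contextlib
import uuid
from typing import Iterator

# Assumed import paths; adjust to the real module layout.
from daft.execution.physical_plan import ActorPoolManager
from daft.execution.execution_step import ResourceRequest
from daft.expressions import ExpressionsProjection


class InProcessActorPoolManager(ActorPoolManager):  # hypothetical
    """Toy pool manager: the 'actors' are just entries in a dict."""

    def __init__(self) -> None:
        self._pools: dict[str, list[ExpressionsProjection]] = {}

    @contextlib.contextmanager
    def actor_pool_context(
        self,
        name: str,
        actor_resource_request: ResourceRequest,
        task_resource_request: ResourceRequest,
        num_actors: int,
        projection: ExpressionsProjection,
    ) -> Iterator[str]:
        pool_id = f"{name}-{uuid.uuid4().hex[:8]}"
        # A real implementation would reserve actor_resource_request per actor
        # and pin `projection` (with its stateful UDFs) to each worker.
        self._pools[pool_id] = [projection] * num_actors
        try:
            yield pool_id  # callers submit tasks against this ID
        finally:
            # Tear the pool down on exit, so submitting work outside the
            # context fails, as the docstring warns.
            del self._pools[pool_id]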
2 changes: 2 additions & 0 deletions daft/execution/rust_physical_plan_shim.py
@@ -83,6 +83,7 @@ def project(
def actor_pool_project(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
projection: list[PyExpr],
actor_pool_manager: physical_plan.ActorPoolManager,
resource_request: ResourceRequest | None,
num_actors: int,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
@@ -93,6 +94,7 @@ def actor_pool_project(
return physical_plan.actor_pool_project(
child_plan=input,
projection=expr_projection,
actor_pool_manager=actor_pool_manager,
resource_request=resource_request,
num_actors=num_actors,
)
2 changes: 1 addition & 1 deletion daft/io/file_path.py
@@ -8,7 +8,7 @@
from daft.daft import IOConfig
from daft.dataframe import DataFrame
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.pyrunner import LocalPartitionSet
from daft.runners.partitioning import LocalPartitionSet
from daft.table import MicroPartition


11 changes: 9 additions & 2 deletions daft/plan_scheduler/physical_plan_scheduler.py
@@ -57,9 +57,16 @@ def to_json_string(self) -> str:
return self._scheduler.to_json_string()

def to_partition_tasks(
self, psets: dict[str, list[PartitionT]], results_buffer_size: int | None
self,
psets: dict[str, list[PartitionT]],
actor_pool_manager: physical_plan.ActorPoolManager,
results_buffer_size: int | None,
) -> physical_plan.MaterializedPhysicalPlan:
return iter(physical_plan.Materialize(self._scheduler.to_partition_tasks(psets), results_buffer_size))
return iter(
physical_plan.Materialize(
self._scheduler.to_partition_tasks(psets, actor_pool_manager), results_buffer_size
)
)

Contributor (inline comment on the new `actor_pool_manager` parameter): Nice


class AdaptivePhysicalPlanScheduler:
88 changes: 87 additions & 1 deletion daft/runners/partitioning.py
@@ -8,14 +8,14 @@
from uuid import uuid4

from daft.datatype import TimeUnit
from daft.table import MicroPartition

if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa

from daft.expressions.expressions import Expression
from daft.logical.schema import Schema
from daft.table import MicroPartition

PartID = int

@@ -271,6 +271,92 @@ def wait(self) -> None:
raise NotImplementedError()


class LocalPartitionSet(PartitionSet[MicroPartition]):
Contributor (inline comment): Is this a pure movement of the class from pyrunner.py? Why did we decide to move it?

Member (author): It was being imported by dataframe.py and ray_runner.py, which was causing some circular import issues. I figured that since it was being used outside of the pyrunner, I would move it out to simplify the dependency tree.

(A generic illustration of this kind of import cycle appears after the class definition below.)
_partitions: dict[PartID, MaterializedResult[MicroPartition]]

def __init__(self) -> None:
super().__init__()
self._partitions = {}

def items(self) -> list[tuple[PartID, MaterializedResult[MicroPartition]]]:
return sorted(self._partitions.items())

def _get_merged_micropartition(self) -> MicroPartition:
ids_and_partitions = self.items()
assert ids_and_partitions[0][0] == 0
assert ids_and_partitions[-1][0] + 1 == len(ids_and_partitions)
return MicroPartition.concat([part.partition() for id, part in ids_and_partitions])

def _get_preview_micropartitions(self, num_rows: int) -> list[MicroPartition]:
ids_and_partitions = self.items()
preview_parts = []
for _, mat_result in ids_and_partitions:
part: MicroPartition = mat_result.partition()
part_len = len(part)
if part_len >= num_rows: # if this part has enough rows, take what we need and break
preview_parts.append(part.slice(0, num_rows))
break
else: # otherwise, take the whole part and keep going
num_rows -= part_len
preview_parts.append(part)
return preview_parts

def get_partition(self, idx: PartID) -> MaterializedResult[MicroPartition]:
return self._partitions[idx]

def set_partition(self, idx: PartID, part: MaterializedResult[MicroPartition]) -> None:
self._partitions[idx] = part

def set_partition_from_table(self, idx: PartID, part: MicroPartition) -> None:
self._partitions[idx] = LocalMaterializedResult(part, PartitionMetadata.from_table(part))

def delete_partition(self, idx: PartID) -> None:
del self._partitions[idx]

def has_partition(self, idx: PartID) -> bool:
return idx in self._partitions

def __len__(self) -> int:
return sum(len(partition.partition()) for partition in self._partitions.values())

def size_bytes(self) -> int | None:
size_bytes_ = [partition.partition().size_bytes() for partition in self._partitions.values()]
size_bytes: list[int] = [size for size in size_bytes_ if size is not None]
if len(size_bytes) != len(size_bytes_):
return None
else:
return sum(size_bytes)

def num_partitions(self) -> int:
return len(self._partitions)

def wait(self) -> None:
pass
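
The import cycle mentioned in the review thread above, sketched generically. The first edge is grounded in the dataframe.py hunk earlier in this diff; the reverse edge is hypothetical, since Daft's exact module graph is not shown here.

# Before the move: dataframe.py pulled LocalPartitionSet out of pyrunner.py,
# while runner code in turn imports dataframe-adjacent modules, closing a cycle:
#
#   daft/dataframe/dataframe.py:
#       from daft.runners.pyrunner import LocalPartitionSet   # dataframe -> pyrunner
#   daft/runners/pyrunner.py:
#       from daft.dataframe import ...                        # pyrunner -> dataframe (hypothetical edge)
#
# After the move: both sides import from the leaf module
# daft/runners/partitioning.py, which imports from neither, so the
# dependency graph stays acyclic:
#
#   from daft.runners.partitioning import LocalPartitionSet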


@dataclass
class LocalMaterializedResult(MaterializedResult[MicroPartition]):
_partition: MicroPartition
_metadata: PartitionMetadata | None = None

def partition(self) -> MicroPartition:
return self._partition

def micropartition(self) -> MicroPartition:
return self._partition

def metadata(self) -> PartitionMetadata:
if self._metadata is None:
self._metadata = PartitionMetadata.from_table(self._partition)
return self._metadata

def cancel(self) -> None:
return None

def _noop(self, _: MicroPartition) -> None:
return None


@dataclass(eq=False, repr=False)
class PartitionCacheEntry:
key: str
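Finally, a small usage sketch of the relocated classes. MicroPartition.from_pydict is assumed here as a convenient in-memory constructor; the rest follows the methods defined above.

from daft.runners.partitioning import LocalPartitionSet
from daft.table import MicroPartition

# Assumption: MicroPartition.from_pydict builds an in-memory partition.
mp = MicroPartition.from_pydict({"x": [1, 2, 3]})

pset = LocalPartitionSet()
pset.set_partition_from_table(0, mp)  # wraps mp in a LocalMaterializedResult
assert pset.has_partition(0)
assert len(pset) == 3                        # __len__ sums rows across partitions
merged = pset._get_merged_micropartition()   # concat of parts, ordered by PartID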