Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] (ACTORS-1) Add DAFT_ENABLE_ACTOR_POOL_PROJECTIONS=1 feature flag and specifying concurrency #2668

Merged
merged 4 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion daft/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ class DaftContext:
_daft_execution_config: PyDaftExecutionConfig = PyDaftExecutionConfig.from_env()

# Non-execution calls (e.g. creation of a dataframe, logical plan building etc) directly reference values in this config
_daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig()
_daft_planning_config: PyDaftPlanningConfig = PyDaftPlanningConfig.from_env()

_runner_config: _RunnerConfig | None = None
_disallow_set_runner: bool = False
Expand Down
6 changes: 6 additions & 0 deletions daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,7 @@ def stateful_udf(
resource_request: ResourceRequest | None,
init_args: tuple[tuple[Any, ...], dict[str, Any]] | None,
batch_size: int | None,
concurrency: int | None,
) -> PyExpr: ...
def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
Expand Down Expand Up @@ -1767,12 +1768,17 @@ class PyDaftExecutionConfig:
def enable_native_executor(self) -> bool: ...

# Type stub for the planning config class implemented in the native extension.
class PyDaftPlanningConfig:
# Alternate constructor; the native implementation builds the config from environment variables.
@staticmethod
def from_env() -> PyDaftPlanningConfig: ...
# Returns a PyDaftPlanningConfig; every override is optional and defaults to None.
# NOTE(review): presumably a None argument keeps the current value — confirm against the native implementation.
def with_config_values(
self,
default_io_config: IOConfig | None = None,
enable_actor_pool_projections: bool | None = None,
) -> PyDaftPlanningConfig: ...
# Read-only accessors for the individual config values.
@property
def default_io_config(self) -> IOConfig: ...
@property
def enable_actor_pool_projections(self) -> bool: ...

def build_type() -> str: ...
def version() -> str: ...
Expand Down
2 changes: 2 additions & 0 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ def stateful_udf(
resource_request: ResourceRequest | None,
init_args: tuple[tuple[Any, ...], dict[builtins.str, Any]] | None,
batch_size: int | None,
concurrency: int | None,
) -> Expression:
return Expression._from_pyexpr(
_stateful_udf(
Expand All @@ -264,6 +265,7 @@ def stateful_udf(
resource_request,
init_args,
batch_size,
concurrency,
)
)

Expand Down
39 changes: 37 additions & 2 deletions daft/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Union

from daft.context import get_context
from daft.daft import PyDataType, ResourceRequest
from daft.datatype import DataType
from daft.expressions import Expression
Expand Down Expand Up @@ -198,9 +199,9 @@
memory_bytes: int | None = _UnsetMarker,
batch_size: int | None = _UnsetMarker,
) -> UDF:
"""Replace the resource requests for running each instance of your stateless UDF.
"""Replace the resource requests for running each instance of your UDF.

For instance, if your stateless UDF requires 4 CPUs to run, you can configure it like so:
For instance, if your UDF requires 4 CPUs to run, you can configure it like so:

>>> import daft
>>>
Expand Down Expand Up @@ -309,6 +310,7 @@
cls: type
return_dtype: DataType
init_args: tuple[tuple[Any, ...], dict[str, Any]] | None = None
concurrency: int | None = None

def __post_init__(self):
"""Analogous to the @functools.wraps(self.cls) pattern
Expand All @@ -319,6 +321,17 @@
functools.update_wrapper(self, self.cls)

def __call__(self, *args, **kwargs) -> Expression:
# Validate that the UDF has a concurrency set, if running with actor pool projections
if get_context().daft_planning_config.enable_actor_pool_projections:
if self.concurrency is None:
raise ValueError(

Check warning on line 327 in daft/udf.py

View check run for this annotation

Codecov / codecov/patch

daft/udf.py#L326-L327

Added lines #L326 - L327 were not covered by tests
"Cannot call StatefulUDF without supplying a concurrency argument. Daft needs to know how many instances of your StatefulUDF to run concurrently. Please parametrize your UDF using `.with_concurrency(N)` before invoking it!"
)
elif self.concurrency is not None:
raise ValueError(

Check warning on line 331 in daft/udf.py

View check run for this annotation

Codecov / codecov/patch

daft/udf.py#L331

Added line #L331 was not covered by tests
"StatefulUDF cannot be run with concurrency specified without the experimental DAFT_ENABLE_ACTOR_POOL_PROJECTIONS=1 flag set."
)

# Validate that initialization arguments are provided if the __init__ signature indicates that there are
# parameters without defaults
init_sig = inspect.signature(self.cls.__init__) # type: ignore
Expand All @@ -341,8 +354,29 @@
resource_request=self.resource_request,
init_args=self.init_args,
batch_size=self.batch_size,
concurrency=self.concurrency,
)

def with_concurrency(self, concurrency: int) -> StatefulUDF:
    """Override the concurrency of this StatefulUDF, which tells Daft how many instances of your StatefulUDF to run concurrently.

    Args:
        concurrency: Number of concurrently running instances. Must be at least 1.

    Returns:
        A copy of this StatefulUDF with ``concurrency`` replaced; the original is left untouched.

    Raises:
        ValueError: If ``concurrency`` is ``None`` or smaller than 1.

    Example:

    >>> import daft
    >>>
    >>> @daft.udf(return_dtype=daft.DataType.string(), num_gpus=1)
    ... class MyUDFThatNeedsAGPU:
    ...     def __init__(self, text=" world"):
    ...         self.text = text
    ...
    ...     def __call__(self, data):
    ...         return [x + self.text for x in data.to_pylist()]
    >>>
    >>> # New UDF that will have 8 concurrent running instances (will require 8 total GPUs)
    >>> MyUDFThatNeedsAGPU_8_concurrency = MyUDFThatNeedsAGPU.with_concurrency(8)
    """
    # The native layer stores concurrency as an unsigned integer, so reject
    # unusable values here with a clear message instead of failing later when
    # the expression is built or the actor pool is sized.
    if concurrency is None or concurrency < 1:
        raise ValueError(
            f"StatefulUDF concurrency must be an integer greater than or equal to 1, but received: {concurrency!r}"
        )
    return dataclasses.replace(self, concurrency=concurrency)

Check warning on line 378 in daft/udf.py

View check run for this annotation

Codecov / codecov/patch

daft/udf.py#L378

Added line #L378 was not covered by tests

def with_init_args(self, *args, **kwargs) -> StatefulUDF:
"""Replace initialization arguments for the Stateful UDF when calling __init__ at runtime
on each instance of the UDF.
Expand Down Expand Up @@ -411,6 +445,7 @@
num_gpus: float | None = None,
memory_bytes: int | None = None,
batch_size: int | None = None,
_concurrency: int | None = None,
jaychia marked this conversation as resolved.
Show resolved Hide resolved
) -> Callable[[UserProvidedPythonFunction | type], StatelessUDF | StatefulUDF]:
"""Decorator to convert a Python function into a UDF

Expand Down
15 changes: 15 additions & 0 deletions src/common/daft-config/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,21 @@
#[derive(Clone, Serialize, Deserialize, Default)]
pub struct DaftPlanningConfig {
pub default_io_config: IOConfig,
pub enable_actor_pool_projections: bool,
}

impl DaftPlanningConfig {
pub fn from_env() -> Self {
let mut cfg = Self::default();

let enable_actor_pool_projections_env_var_name = "DAFT_ENABLE_ACTOR_POOL_PROJECTIONS";
if let Ok(val) = std::env::var(enable_actor_pool_projections_env_var_name)
&& matches!(val.trim().to_lowercase().as_str(), "1" | "true")
{
cfg.enable_actor_pool_projections = true;

Check warning on line 27 in src/common/daft-config/src/lib.rs

View check run for this annotation

Codecov / codecov/patch

src/common/daft-config/src/lib.rs#L25-L27

Added lines #L25 - L27 were not covered by tests
}
cfg
}
}

/// Configurations for Daft to use during the execution of a Dataframe
Expand Down
12 changes: 12 additions & 0 deletions src/common/daft-config/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ impl PyDaftPlanningConfig {
PyDaftPlanningConfig::default()
}

#[staticmethod]
pub fn from_env() -> Self {
PyDaftPlanningConfig {
config: Arc::new(DaftPlanningConfig::from_env()),
}
}

fn with_config_values(
&mut self,
default_io_config: Option<PyIOConfig>,
Expand All @@ -41,6 +48,11 @@ impl PyDaftPlanningConfig {
})
}

/// Reports the `enable_actor_pool_projections` value of the wrapped planning config.
#[getter(enable_actor_pool_projections)]
fn enable_actor_pool_projections(&self) -> PyResult<bool> {
    let enabled = self.config.enable_actor_pool_projections;
    Ok(enabled)
}

fn __reduce__(&self, py: Python) -> PyResult<(PyObject, (Vec<u8>,))> {
let bin_data = bincode::serialize(self.config.as_ref())
.expect("DaftPlanningConfig should be serializable to bytes");
Expand Down
31 changes: 31 additions & 0 deletions src/daft-dsl/src/functions/python/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
#[cfg(feature = "python")]
pub init_args: Option<pyobj_serde::PyObjectWrapper>,
pub batch_size: Option<usize>,
pub concurrency: Option<usize>,
}

#[cfg(feature = "python")]
Expand Down Expand Up @@ -98,6 +99,7 @@
}

#[cfg(feature = "python")]
#[allow(clippy::too_many_arguments)]
pub fn stateful_udf(
name: &str,
py_stateful_partial_func: pyo3::PyObject,
Expand All @@ -106,6 +108,7 @@
resource_request: Option<ResourceRequest>,
init_args: Option<pyo3::PyObject>,
batch_size: Option<usize>,
concurrency: Option<usize>,
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
Expand All @@ -116,6 +119,7 @@
resource_request,
init_args: init_args.map(pyobj_serde::PyObjectWrapper),
batch_size,
concurrency,
})),
inputs: expressions.into(),
})
Expand All @@ -128,6 +132,7 @@
return_dtype: DataType,
resource_request: Option<ResourceRequest>,
batch_size: Option<usize>,
concurrency: Option<usize>,

Check warning on line 135 in src/daft-dsl/src/functions/python/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/python/mod.rs#L135

Added line #L135 was not covered by tests
) -> DaftResult<Expr> {
Ok(Expr::Function {
func: super::FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF {
Expand All @@ -136,6 +141,7 @@
return_dtype,
resource_request,
batch_size,
concurrency,

Check warning on line 144 in src/daft-dsl/src/functions/python/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/python/mod.rs#L144

Added line #L144 was not covered by tests
})),
inputs: expressions.into(),
})
Expand Down Expand Up @@ -188,3 +194,28 @@
))
}
}

/// Gets the concurrency from the first StatefulUDF encountered in a given slice of expressions
///
/// NOTE: This function panics if no StatefulUDF is found
pub fn get_concurrency(exprs: &[ExprRef]) -> usize {
let mut projection_concurrency = None;
for expr in exprs.iter() {
let mut found_stateful_udf = false;
expr.apply(|e| match e.as_ref() {

Check warning on line 205 in src/daft-dsl/src/functions/python/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/python/mod.rs#L201-L205

Added lines #L201 - L205 were not covered by tests
Expr::Function {
func: FunctionExpr::Python(PythonUDF::Stateful(StatefulPythonUDF{concurrency, ..})),
..
} => {
found_stateful_udf = true;
projection_concurrency = Some(concurrency.expect("Should have concurrency specified"));
Ok(common_treenode::TreeNodeRecursion::Stop)

Check warning on line 212 in src/daft-dsl/src/functions/python/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/python/mod.rs#L207-L212

Added lines #L207 - L212 were not covered by tests
}
_ => Ok(common_treenode::TreeNodeRecursion::Continue),
}).unwrap();
if found_stateful_udf {
break;
}

Check warning on line 218 in src/daft-dsl/src/functions/python/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/python/mod.rs#L214-L218

Added lines #L214 - L218 were not covered by tests
}
projection_concurrency.expect("get_concurrency expects one StatefulUDF")
}

Check warning on line 221 in src/daft-dsl/src/functions/python/mod.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/functions/python/mod.rs#L220-L221

Added lines #L220 - L221 were not covered by tests
2 changes: 2 additions & 0 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ pub fn stateful_udf(
resource_request: Option<ResourceRequest>,
init_args: Option<&PyAny>,
batch_size: Option<usize>,
concurrency: Option<usize>,
) -> PyResult<PyExpr> {
use crate::functions::python::stateful_udf;

Expand All @@ -224,6 +225,7 @@ pub fn stateful_udf(
resource_request,
init_args,
batch_size,
concurrency,
)?
.into(),
})
Expand Down
16 changes: 7 additions & 9 deletions src/daft-plan/src/logical_ops/actor_pool_project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
use daft_core::schema::{Schema, SchemaRef};
use daft_dsl::{
functions::{
python::{get_resource_request, PythonUDF, StatefulPythonUDF},
python::{get_concurrency, get_resource_request, PythonUDF, StatefulPythonUDF},
FunctionExpr,
},
resolve_exprs, Expr, ExprRef,
Expand All @@ -24,30 +24,28 @@
pub input: Arc<LogicalPlan>,
pub projection: Vec<ExprRef>,
pub projected_schema: SchemaRef,
pub num_actors: usize,
}

impl ActorPoolProject {
pub(crate) fn try_new(
input: Arc<LogicalPlan>,
projection: Vec<ExprRef>,
num_actors: usize,
) -> Result<Self> {
pub(crate) fn try_new(input: Arc<LogicalPlan>, projection: Vec<ExprRef>) -> Result<Self> {
let (projection, fields) =
resolve_exprs(projection, input.schema().as_ref()).context(CreationSnafu)?;
let projected_schema = Schema::new(fields).context(CreationSnafu)?.into();
Ok(ActorPoolProject {
input,
projection,
projected_schema,
num_actors,
})
}

/// Returns the resource request derived from this node's projection expressions, if any.
pub fn resource_request(&self) -> Option<ResourceRequest> {
    let exprs = self.projection.as_slice();
    get_resource_request(exprs)
}

/// Returns the concurrency derived from this node's projection expressions.
///
/// Delegates to `get_concurrency`, which panics when the projection holds no StatefulUDF.
pub fn concurrency(&self) -> usize {
    let exprs = self.projection.as_slice();
    get_concurrency(exprs)
}

Check warning on line 47 in src/daft-plan/src/logical_ops/actor_pool_project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/actor_pool_project.rs#L45-L47

Added lines #L45 - L47 were not covered by tests

pub fn multiline_display(&self) -> Vec<String> {
let mut res = vec![];
res.push("ActorPoolProject:".to_string());
Expand Down Expand Up @@ -80,7 +78,7 @@
})
.join(", ")
));
res.push(format!("Num actors = {}", self.num_actors,));
res.push(format!("Concurrency = {}", self.concurrency()));

Check warning on line 81 in src/daft-plan/src/logical_ops/actor_pool_project.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_ops/actor_pool_project.rs#L81

Added line #L81 was not covered by tests
if let Some(resource_request) = self.resource_request() {
let multiline_display = resource_request.multiline_display();
res.push(format!(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,6 @@ impl PushDownProjection {
LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
upstream_actor_pool_projection.input.clone(),
pruned_upstream_projections,
upstream_actor_pool_projection.num_actors,
)?)
.arced()
};
Expand Down Expand Up @@ -819,6 +818,7 @@ mod tests {
return_dtype: DataType::Utf8,
resource_request: Some(ResourceRequest::default_cpu()),
batch_size: None,
concurrency: Some(8),
})),
inputs: vec![col("c")],
}
Expand All @@ -828,7 +828,6 @@ mod tests {
let actor_pool_project = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
scan_node.clone(),
vec![col("a"), col("b"), mock_stateful_udf.alias("udf_results")],
8,
)?)
.arced();
let project = LogicalPlan::Project(Project::try_new(
Expand All @@ -840,7 +839,6 @@ mod tests {
let expected_actor_pool_project = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
scan_node.clone(),
vec![mock_stateful_udf.alias("udf_results")],
8,
)?)
.arced();

Expand Down Expand Up @@ -872,6 +870,7 @@ mod tests {
return_dtype: DataType::Utf8,
resource_request: Some(ResourceRequest::default_cpu()),
batch_size: None,
concurrency: Some(8),
})),
inputs: vec![col("c")],
}
Expand All @@ -881,7 +880,6 @@ mod tests {
let actor_pool_project = LogicalPlan::ActorPoolProject(ActorPoolProject::try_new(
scan_node.clone(),
vec![col("a"), col("b"), mock_stateful_udf.alias("udf_results")],
8,
)?)
.arced();
let project =
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@
Self::Project(Project { projection, .. }) => Self::Project(Project::try_new(
input.clone(), projection.clone(),
).unwrap()),
Self::ActorPoolProject(ActorPoolProject {projection, num_actors, ..}) => Self::ActorPoolProject(ActorPoolProject::try_new(input.clone(), projection.clone(), *num_actors).unwrap()),
Self::ActorPoolProject(ActorPoolProject {projection, ..}) => Self::ActorPoolProject(ActorPoolProject::try_new(input.clone(), projection.clone()).unwrap()),

Check warning on line 202 in src/daft-plan/src/logical_plan.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-plan/src/logical_plan.rs#L202

Added line #L202 was not covered by tests
Self::Filter(Filter { predicate, .. }) => Self::Filter(Filter::try_new(input.clone(), predicate.clone()).unwrap()),
Self::Limit(Limit { limit, eager, .. }) => Self::Limit(Limit::new(input.clone(), *limit, *eager)),
Self::Explode(Explode { to_explode, .. }) => Self::Explode(Explode::try_new(input.clone(), to_explode.clone()).unwrap()),
Expand Down
Loading
Loading