Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CHORE] Add check for stateful UDF outside of project #2771

Merged
merged 5 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ def stateful_udf(
batch_size: int | None,
concurrency: int | None,
) -> PyExpr: ...
def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
def check_column_name_validity(name: str, schema: PySchema): ...
def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
def cosine_distance(expr: PyExpr, other: PyExpr) -> PyExpr: ...
def url_download(
Expand Down
11 changes: 4 additions & 7 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from daft.api_annotations import DataframePublicAPI
from daft.context import get_context
from daft.convert import InputListType
from daft.daft import FileFormat, IOConfig, JoinStrategy, JoinType, resolve_expr
from daft.daft import FileFormat, IOConfig, JoinStrategy, JoinType, check_column_name_validity
from daft.dataframe.preview import DataFramePreview
from daft.datatype import DataType
from daft.errors import ExpressionTypeError
Expand Down Expand Up @@ -1088,12 +1088,9 @@ def __getitem__(self, item: Union[slice, int, str, Iterable[Union[str, int]]]) -
return result
elif isinstance(item, str):
schema = self._builder.schema()
if (item == "*" or item.endswith(".*")) and item not in schema.column_names():
# does not account for weird column names
# like if struct "a" has a field named "*", then a.* will wrongly fail
raise ValueError("Wildcard expressions are not supported in DataFrame.__getitem__")
expr, _ = resolve_expr(col(item)._expr, schema._schema)
return Expression._from_pyexpr(expr)
check_column_name_validity(item, schema._schema)

return col(item)
elif isinstance(item, Iterable):
schema = self._builder.schema()

Expand Down
31 changes: 20 additions & 11 deletions src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use common_hashable_float_wrapper::FloatWrapper;
use common_treenode::TreeNode;
use daft_core::{
count_mode::CountMode,
datatypes::{try_mean_supertype, try_sum_supertype, DataType, Field, FieldID},
Expand All @@ -9,7 +10,9 @@ use itertools::Itertools;

use crate::{
functions::{
function_display, function_semantic_id, scalar_function_semantic_id,
function_display, function_semantic_id,
python::PythonUDF,
scalar_function_semantic_id,
sketch::{HashableVecPercentiles, SketchExpr},
struct_::StructExpr,
FunctionEvaluator, ScalarFunction,
Expand Down Expand Up @@ -965,16 +968,6 @@ impl Expr {
_ => None,
}
}

pub fn has_agg(&self) -> bool {
use Expr::*;

match self {
Agg(_) => true,
Column(_) | Literal(_) => false,
_ => self.children().into_iter().any(|e| e.has_agg()),
}
}
}

impl Display for Expr {
Expand Down Expand Up @@ -1122,6 +1115,22 @@ pub fn is_partition_compatible(a: &[ExprRef], b: &[ExprRef]) -> bool {
a == b
}

pub fn has_agg(expr: &ExprRef) -> bool {
expr.exists(|e| matches!(e.as_ref(), Expr::Agg(_)))
}

pub fn has_stateful_udf(expr: &ExprRef) -> bool {
expr.exists(|e| {
matches!(
e.as_ref(),
Expr::Function {
func: FunctionExpr::Python(PythonUDF::Stateful(_)),
..
}
)
})
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
13 changes: 7 additions & 6 deletions src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ pub mod python;
mod resolve_expr;
mod treenode;
pub use common_treenode;
pub use expr::binary_op;
pub use expr::col;
pub use expr::is_partition_compatible;
pub use expr::{AggExpr, ApproxPercentileParams, Expr, ExprRef, Operator, SketchType};
pub use expr::{
binary_op, col, has_agg, has_stateful_udf, is_partition_compatible, AggExpr,
ApproxPercentileParams, Expr, ExprRef, Operator, SketchType,
};
pub use lit::{lit, null_lit, Literal, LiteralValue};
#[cfg(feature = "python")]
use pyo3::prelude::*;
pub use resolve_expr::{
resolve_aggexprs, resolve_exprs, resolve_single_aggexpr, resolve_single_expr,
check_column_name_validity, resolve_aggexprs, resolve_exprs, resolve_single_aggexpr,
resolve_single_expr,
};

#[cfg(feature = "python")]
Expand All @@ -39,7 +40,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_wrapped(wrap_pyfunction!(python::stateless_udf))?;
parent.add_wrapped(wrap_pyfunction!(python::stateful_udf))?;
parent.add_wrapped(wrap_pyfunction!(python::eq))?;
parent.add_wrapped(wrap_pyfunction!(python::resolve_expr))?;
parent.add_wrapped(wrap_pyfunction!(python::check_column_name_validity))?;

Ok(())
}
5 changes: 2 additions & 3 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,8 @@ pub fn eq(expr1: &PyExpr, expr2: &PyExpr) -> PyResult<bool> {
}

#[pyfunction]
pub fn resolve_expr(expr: &PyExpr, schema: &PySchema) -> PyResult<(PyExpr, PyField)> {
let (resolved_expr, field) = crate::resolve_single_expr(expr.expr.clone(), &schema.schema)?;
Ok((resolved_expr.into(), field.into()))
pub fn check_column_name_validity(name: &str, schema: &PySchema) -> PyResult<()> {
Ok(crate::check_column_name_validity(name, &schema.schema)?)
}

#[derive(FromPyObject)]
Expand Down
108 changes: 85 additions & 23 deletions src/daft-dsl/src/resolve_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
schema::Schema,
};

use crate::{col, AggExpr, ApproxPercentileParams, Expr, ExprRef};
use crate::{col, expr::has_agg, has_stateful_udf, AggExpr, ApproxPercentileParams, Expr, ExprRef};

use common_error::{DaftError, DaftResult};

Expand Down Expand Up @@ -262,14 +262,29 @@
}

/// Resolves and validates the expression with a schema, returning the new expression and its field.
/// Specifically, makes sure the expression does not contain aggregations or stateful UDFs when they are not allowed,
/// and resolves struct accessors and wildcards.
/// May return multiple expressions if the expr contains a wildcard.
fn resolve_expr(expr: ExprRef, schema: &Schema) -> DaftResult<Vec<ExprRef>> {
///
/// TODO: Use a builder pattern for this functionality
fn resolve_expr(
kevinzwang marked this conversation as resolved.
Show resolved Hide resolved
expr: ExprRef,
schema: &Schema,
allow_stateful_udf: bool,
) -> DaftResult<Vec<ExprRef>> {
// TODO(Kevin): Support aggregation expressions everywhere
if expr.has_agg() {
if has_agg(&expr) {
return Err(DaftError::ValueError(format!(
"Aggregation expressions are currently only allowed in agg and pivot: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383",
)));
}

if !allow_stateful_udf && has_stateful_udf(&expr) {
return Err(DaftError::ValueError(format!(
"Stateful UDFs are only allowed in projections: {expr}"
)));

Check warning on line 285 in src/daft-dsl/src/resolve_expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/resolve_expr.rs#L283-L285

Added lines #L283 - L285 were not covered by tests
}

let struct_expr_map = calculate_struct_expr_map(schema);
expand_wildcards(expr, schema, &struct_expr_map)?
.into_iter()
Expand All @@ -278,8 +293,12 @@
}

// Resolve a single expression, erroring if any kind of expansion happens.
pub fn resolve_single_expr(expr: ExprRef, schema: &Schema) -> DaftResult<(ExprRef, Field)> {
let resolved_exprs = resolve_expr(expr.clone(), schema)?;
pub fn resolve_single_expr(
expr: ExprRef,
schema: &Schema,
allow_stateful_udf: bool,
) -> DaftResult<(ExprRef, Field)> {
let resolved_exprs = resolve_expr(expr.clone(), schema, allow_stateful_udf)?;
match resolved_exprs.as_slice() {
[resolved_expr] => Ok((resolved_expr.clone(), resolved_expr.to_field(schema)?)),
_ => Err(DaftError::ValueError(format!(
Expand All @@ -293,37 +312,54 @@
pub fn resolve_exprs(
exprs: Vec<ExprRef>,
schema: &Schema,
allow_stateful_udf: bool,
) -> DaftResult<(Vec<ExprRef>, Vec<Field>)> {
// can't flat map because we need to deal with errors
let resolved_exprs: DaftResult<Vec<Vec<ExprRef>>> =
exprs.into_iter().map(|e| resolve_expr(e, schema)).collect();
let resolved_exprs: DaftResult<Vec<Vec<ExprRef>>> = exprs
.into_iter()
.map(|e| resolve_expr(e, schema, allow_stateful_udf))
.collect();
let resolved_exprs: Vec<ExprRef> = resolved_exprs?.into_iter().flatten().collect();
let resolved_fields: DaftResult<Vec<Field>> =
resolved_exprs.iter().map(|e| e.to_field(schema)).collect();
Ok((resolved_exprs, resolved_fields?))
}

/// Resolves and validates the expression with a schema, returning the extracted aggregation expression and its field.
/// Specifically, makes sure the expression does not contain aggregationsnested or stateful UDFs,
/// and resolves struct accessors and wildcards.
/// May return multiple expressions if the expr contains a wildcard.
///
/// TODO: Use a builder pattern for this functionality
fn resolve_aggexpr(expr: ExprRef, schema: &Schema) -> DaftResult<Vec<AggExpr>> {
let struct_expr_map = calculate_struct_expr_map(schema);
expand_wildcards(expr, schema, &struct_expr_map)?.into_iter().map(|expr| {
let agg_expr = extract_agg_expr(&expr)?;
let has_nested_agg = extract_agg_expr(&expr)?.children().iter().any(has_agg);

let has_nested_agg = agg_expr.children().iter().any(|e| e.has_agg());
if has_nested_agg {
return Err(DaftError::ValueError(format!(
"Nested aggregation expressions are not supported: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));

Check warning on line 340 in src/daft-dsl/src/resolve_expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/resolve_expr.rs#L338-L340

Added lines #L338 - L340 were not covered by tests
}

if has_nested_agg {
return Err(DaftError::ValueError(format!(
"Nested aggregation expressions are not supported: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));
}
if has_stateful_udf(&expr) {
return Err(DaftError::ValueError(format!(
"Stateful UDFs are only allowed in projections: {expr}"
)));

Check warning on line 346 in src/daft-dsl/src/resolve_expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/resolve_expr.rs#L344-L346

Added lines #L344 - L346 were not covered by tests
}

let resolved_children = agg_expr
.children()
.into_iter()
.map(|e| transform_struct_gets(e, &struct_expr_map))
.collect::<DaftResult<Vec<_>>>()?;
Ok(agg_expr.with_new_children(resolved_children))
}).collect()
let struct_expr_map = calculate_struct_expr_map(schema);
expand_wildcards(expr, schema, &struct_expr_map)?
.into_iter()
.map(|expr| {
let agg_expr = extract_agg_expr(&expr)?;

let resolved_children = agg_expr
.children()
.into_iter()
.map(|e| transform_struct_gets(e, &struct_expr_map))
.collect::<DaftResult<Vec<_>>>()?;
Ok(agg_expr.with_new_children(resolved_children))
})
.collect()
}

pub fn resolve_single_aggexpr(expr: ExprRef, schema: &Schema) -> DaftResult<(AggExpr, Field)> {
Expand Down Expand Up @@ -353,6 +389,32 @@
Ok((resolved_exprs, resolved_fields?))
}

pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> {
let struct_expr_map = calculate_struct_expr_map(schema);

let names = if name.contains('*') {
if let Ok(names) = get_wildcard_matches(name, schema, &struct_expr_map) {
names

Check warning on line 397 in src/daft-dsl/src/resolve_expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/resolve_expr.rs#L396-L397

Added lines #L396 - L397 were not covered by tests
} else {
return Err(DaftError::ValueError(format!(
"Error matching wildcard `{name}` in schema: {schema}"
)));

Check warning on line 401 in src/daft-dsl/src/resolve_expr.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-dsl/src/resolve_expr.rs#L399-L401

Added lines #L399 - L401 were not covered by tests
}
} else {
vec![name.into()]
};

for n in names {
if !struct_expr_map.contains_key(&n) {
return Err(DaftError::ValueError(format!(
"Column `{n}` not found in schema: {schema}"
)));
}
}

Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/actor_pool_project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub struct ActorPoolProject {
impl ActorPoolProject {
pub(crate) fn try_new(input: Arc<LogicalPlan>, projection: Vec<ExprRef>) -> Result<Self> {
let (projection, fields) =
resolve_exprs(projection, input.schema().as_ref()).context(CreationSnafu)?;
resolve_exprs(projection, input.schema().as_ref(), true).context(CreationSnafu)?;
kevinzwang marked this conversation as resolved.
Show resolved Hide resolved

let num_stateful_udf_exprs: usize = projection
.iter()
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl Aggregate {
) -> logical_plan::Result<Self> {
let upstream_schema = input.schema();
let (groupby, groupby_fields) =
resolve_exprs(groupby, &upstream_schema).context(CreationSnafu)?;
resolve_exprs(groupby, &upstream_schema, false).context(CreationSnafu)?;
let (aggregations, aggregation_fields) =
resolve_aggexprs(aggregations, &upstream_schema).context(CreationSnafu)?;

Expand Down
3 changes: 2 additions & 1 deletion src/daft-plan/src/logical_ops/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ impl Explode {
) -> logical_plan::Result<Self> {
let upstream_schema = input.schema();

let (to_explode, _) = resolve_exprs(to_explode, &upstream_schema).context(CreationSnafu)?;
let (to_explode, _) =
resolve_exprs(to_explode, &upstream_schema, false).context(CreationSnafu)?;

let explode_exprs = to_explode
.iter()
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct Filter {
impl Filter {
pub(crate) fn try_new(input: Arc<LogicalPlan>, predicate: ExprRef) -> Result<Self> {
let (predicate, field) =
resolve_single_expr(predicate, &input.schema()).context(CreationSnafu)?;
resolve_single_expr(predicate, &input.schema(), false).context(CreationSnafu)?;

if !matches!(field.dtype, DataType::Boolean) {
return Err(DaftError::ValueError(format!(
Expand Down
4 changes: 2 additions & 2 deletions src/daft-plan/src/logical_ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ impl Join {
join_strategy: Option<JoinStrategy>,
) -> logical_plan::Result<Self> {
let (left_on, left_fields) =
resolve_exprs(left_on, &left.schema()).context(CreationSnafu)?;
resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, right_fields) =
resolve_exprs(right_on, &right.schema()).context(CreationSnafu)?;
resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?;

for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] {
let on_schema = Schema::new(on_fields).context(CreationSnafu)?;
Expand Down
6 changes: 3 additions & 3 deletions src/daft-plan/src/logical_ops/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ impl Pivot {
) -> logical_plan::Result<Self> {
let upstream_schema = input.schema();
let (group_by, group_by_fields) =
resolve_exprs(group_by, &upstream_schema).context(CreationSnafu)?;
resolve_exprs(group_by, &upstream_schema, false).context(CreationSnafu)?;
let (pivot_column, _) =
resolve_single_expr(pivot_column, &upstream_schema).context(CreationSnafu)?;
resolve_single_expr(pivot_column, &upstream_schema, false).context(CreationSnafu)?;
let (value_column, value_col_field) =
resolve_single_expr(value_column, &upstream_schema).context(CreationSnafu)?;
resolve_single_expr(value_column, &upstream_schema, false).context(CreationSnafu)?;
let (aggregation, _) =
resolve_single_aggexpr(aggregation, &upstream_schema).context(CreationSnafu)?;

Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub struct Project {
impl Project {
pub(crate) fn try_new(input: Arc<LogicalPlan>, projection: Vec<ExprRef>) -> Result<Self> {
let (projection, fields) =
resolve_exprs(projection, &input.schema()).context(CreationSnafu)?;
resolve_exprs(projection, &input.schema(), true).context(CreationSnafu)?;

// Factor the projection and see if there are any substitutions to factor out.
let (factored_input, factored_projection) =
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/repartition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl Repartition {
) -> DaftResult<Self> {
let repartition_spec = match repartition_spec {
RepartitionSpec::Hash(HashRepartitionConfig { num_partitions, by }) => {
let (resolved_by, _) = resolve_exprs(by, &input.schema())?;
let (resolved_by, _) = resolve_exprs(by, &input.schema(), false)?;
RepartitionSpec::Hash(HashRepartitionConfig {
num_partitions,
by: resolved_by,
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl Sink {
let resolved_partition_cols = partition_cols
.clone()
.map(|cols| {
resolve_exprs(cols, &schema).map(|(resolved_cols, _)| resolved_cols)
resolve_exprs(cols, &schema, false).map(|(resolved_cols, _)| resolved_cols)
})
.transpose()?;

Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl Sort {
}

let (sort_by, sort_by_fields) =
resolve_exprs(sort_by, &input.schema()).context(CreationSnafu)?;
resolve_exprs(sort_by, &input.schema(), false).context(CreationSnafu)?;

let sort_by_resolved_schema = Schema::new(sort_by_fields).context(CreationSnafu)?;

Expand Down
Loading
Loading