Skip to content

Commit

Permalink
[CHORE] Add check for stateful UDF outside of project (#2771)
Browse files Browse the repository at this point in the history
Add a parameter to `resolve_expr` and its related functions to allow for
stateful UDFs in the expression. This is set to false everywhere except
for projects.

Additionally, `df['column name']` used to call `resolve_exprs` to check
for validity and returned its output, which would be problematic if
there were wildcards in the expression. Now, I've created a function
that only does the validity check, and `DataFrame.__getitem__` would
just return `col(name)`, which will actually be resolved later in the
builder.
  • Loading branch information
kevinzwang authored Sep 4, 2024
1 parent a97d871 commit c5a4adc
Show file tree
Hide file tree
Showing 18 changed files with 137 additions and 79 deletions.
2 changes: 1 addition & 1 deletion daft/daft.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1202,7 +1202,7 @@ def stateful_udf(
batch_size: int | None,
concurrency: int | None,
) -> PyExpr: ...
def resolve_expr(expr: PyExpr, schema: PySchema) -> tuple[PyExpr, PyField]: ...
def check_column_name_validity(name: str, schema: PySchema): ...
def hash(expr: PyExpr, seed: Any | None = None) -> PyExpr: ...
def cosine_distance(expr: PyExpr, other: PyExpr) -> PyExpr: ...
def url_download(
Expand Down
11 changes: 4 additions & 7 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from daft.api_annotations import DataframePublicAPI
from daft.context import get_context
from daft.convert import InputListType
from daft.daft import FileFormat, IOConfig, JoinStrategy, JoinType, resolve_expr
from daft.daft import FileFormat, IOConfig, JoinStrategy, JoinType, check_column_name_validity
from daft.dataframe.preview import DataFramePreview
from daft.datatype import DataType
from daft.errors import ExpressionTypeError
Expand Down Expand Up @@ -1088,12 +1088,9 @@ def __getitem__(self, item: Union[slice, int, str, Iterable[Union[str, int]]]) -
return result
elif isinstance(item, str):
schema = self._builder.schema()
if (item == "*" or item.endswith(".*")) and item not in schema.column_names():
# does not account for weird column names
# like if struct "a" has a field named "*", then a.* will wrongly fail
raise ValueError("Wildcard expressions are not supported in DataFrame.__getitem__")
expr, _ = resolve_expr(col(item)._expr, schema._schema)
return Expression._from_pyexpr(expr)
check_column_name_validity(item, schema._schema)

return col(item)
elif isinstance(item, Iterable):
schema = self._builder.schema()

Expand Down
31 changes: 20 additions & 11 deletions src/daft-dsl/src/expr.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use common_hashable_float_wrapper::FloatWrapper;
use common_treenode::TreeNode;
use daft_core::{
count_mode::CountMode,
datatypes::{try_mean_supertype, try_sum_supertype, DataType, Field, FieldID},
Expand All @@ -9,7 +10,9 @@ use itertools::Itertools;

use crate::{
functions::{
function_display, function_semantic_id, scalar_function_semantic_id,
function_display, function_semantic_id,
python::PythonUDF,
scalar_function_semantic_id,
sketch::{HashableVecPercentiles, SketchExpr},
struct_::StructExpr,
FunctionEvaluator, ScalarFunction,
Expand Down Expand Up @@ -965,16 +968,6 @@ impl Expr {
_ => None,
}
}

pub fn has_agg(&self) -> bool {
use Expr::*;

match self {
Agg(_) => true,
Column(_) | Literal(_) => false,
_ => self.children().into_iter().any(|e| e.has_agg()),
}
}
}

impl Display for Expr {
Expand Down Expand Up @@ -1122,6 +1115,22 @@ pub fn is_partition_compatible(a: &[ExprRef], b: &[ExprRef]) -> bool {
a == b
}

/// Returns `true` if any node in the expression tree is an aggregation (`Expr::Agg`).
///
/// Uses `TreeNode::exists`, which short-circuits on the first matching node.
pub fn has_agg(expr: &ExprRef) -> bool {
    expr.exists(|e| matches!(e.as_ref(), Expr::Agg(_)))
}

/// Returns `true` if any node in the expression tree is a stateful Python UDF call.
///
/// Expression resolution uses this to reject stateful UDFs outside of
/// projections (see `resolve_expr`), since they are only allowed there.
pub fn has_stateful_udf(expr: &ExprRef) -> bool {
    expr.exists(|e| {
        // `exists` short-circuits as soon as one stateful UDF node is found.
        matches!(
            e.as_ref(),
            Expr::Function {
                func: FunctionExpr::Python(PythonUDF::Stateful(_)),
                ..
            }
        )
    })
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
13 changes: 7 additions & 6 deletions src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@ pub mod python;
mod resolve_expr;
mod treenode;
pub use common_treenode;
pub use expr::binary_op;
pub use expr::col;
pub use expr::is_partition_compatible;
pub use expr::{AggExpr, ApproxPercentileParams, Expr, ExprRef, Operator, SketchType};
pub use expr::{
binary_op, col, has_agg, has_stateful_udf, is_partition_compatible, AggExpr,
ApproxPercentileParams, Expr, ExprRef, Operator, SketchType,
};
pub use lit::{lit, null_lit, Literal, LiteralValue};
#[cfg(feature = "python")]
use pyo3::prelude::*;
pub use resolve_expr::{
resolve_aggexprs, resolve_exprs, resolve_single_aggexpr, resolve_single_expr,
check_column_name_validity, resolve_aggexprs, resolve_exprs, resolve_single_aggexpr,
resolve_single_expr,
};

#[cfg(feature = "python")]
Expand All @@ -39,7 +40,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
parent.add_wrapped(wrap_pyfunction!(python::stateless_udf))?;
parent.add_wrapped(wrap_pyfunction!(python::stateful_udf))?;
parent.add_wrapped(wrap_pyfunction!(python::eq))?;
parent.add_wrapped(wrap_pyfunction!(python::resolve_expr))?;
parent.add_wrapped(wrap_pyfunction!(python::check_column_name_validity))?;

Ok(())
}
5 changes: 2 additions & 3 deletions src/daft-dsl/src/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,8 @@ pub fn eq(expr1: &PyExpr, expr2: &PyExpr) -> PyResult<bool> {
}

#[pyfunction]
pub fn resolve_expr(expr: &PyExpr, schema: &PySchema) -> PyResult<(PyExpr, PyField)> {
let (resolved_expr, field) = crate::resolve_single_expr(expr.expr.clone(), &schema.schema)?;
Ok((resolved_expr.into(), field.into()))
/// Python binding: validates that `name` (which may contain wildcards) refers
/// to column(s) present in `schema`, without resolving a full expression.
pub fn check_column_name_validity(name: &str, schema: &PySchema) -> PyResult<()> {
    // `?` converts the crate's DaftError into a PyErr for the Python caller.
    Ok(crate::check_column_name_validity(name, &schema.schema)?)
}

#[derive(FromPyObject)]
Expand Down
108 changes: 85 additions & 23 deletions src/daft-dsl/src/resolve_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use daft_core::{
schema::Schema,
};

use crate::{col, AggExpr, ApproxPercentileParams, Expr, ExprRef};
use crate::{col, expr::has_agg, has_stateful_udf, AggExpr, ApproxPercentileParams, Expr, ExprRef};

use common_error::{DaftError, DaftResult};

Expand Down Expand Up @@ -262,14 +262,29 @@ fn extract_agg_expr(expr: &Expr) -> DaftResult<AggExpr> {
}

/// Resolves and validates the expression with a schema, returning the new expression and its field.
/// Specifically, makes sure the expression does not contain aggregations or stateful UDFs when they are not allowed,
/// and resolves struct accessors and wildcards.
/// May return multiple expressions if the expr contains a wildcard.
fn resolve_expr(expr: ExprRef, schema: &Schema) -> DaftResult<Vec<ExprRef>> {
///
/// TODO: Use a builder pattern for this functionality
fn resolve_expr(
expr: ExprRef,
schema: &Schema,
allow_stateful_udf: bool,
) -> DaftResult<Vec<ExprRef>> {
// TODO(Kevin): Support aggregation expressions everywhere
if expr.has_agg() {
if has_agg(&expr) {
return Err(DaftError::ValueError(format!(
"Aggregation expressions are currently only allowed in agg and pivot: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383",
)));
}

if !allow_stateful_udf && has_stateful_udf(&expr) {
return Err(DaftError::ValueError(format!(
"Stateful UDFs are only allowed in projections: {expr}"
)));
}

let struct_expr_map = calculate_struct_expr_map(schema);
expand_wildcards(expr, schema, &struct_expr_map)?
.into_iter()
Expand All @@ -278,8 +293,12 @@ fn resolve_expr(expr: ExprRef, schema: &Schema) -> DaftResult<Vec<ExprRef>> {
}

// Resolve a single expression, erroring if any kind of expansion happens.
pub fn resolve_single_expr(expr: ExprRef, schema: &Schema) -> DaftResult<(ExprRef, Field)> {
let resolved_exprs = resolve_expr(expr.clone(), schema)?;
pub fn resolve_single_expr(
expr: ExprRef,
schema: &Schema,
allow_stateful_udf: bool,
) -> DaftResult<(ExprRef, Field)> {
let resolved_exprs = resolve_expr(expr.clone(), schema, allow_stateful_udf)?;
match resolved_exprs.as_slice() {
[resolved_expr] => Ok((resolved_expr.clone(), resolved_expr.to_field(schema)?)),
_ => Err(DaftError::ValueError(format!(
Expand All @@ -293,37 +312,54 @@ pub fn resolve_single_expr(expr: ExprRef, schema: &Schema) -> DaftResult<(ExprRe
pub fn resolve_exprs(
exprs: Vec<ExprRef>,
schema: &Schema,
allow_stateful_udf: bool,
) -> DaftResult<(Vec<ExprRef>, Vec<Field>)> {
// can't flat map because we need to deal with errors
let resolved_exprs: DaftResult<Vec<Vec<ExprRef>>> =
exprs.into_iter().map(|e| resolve_expr(e, schema)).collect();
let resolved_exprs: DaftResult<Vec<Vec<ExprRef>>> = exprs
.into_iter()
.map(|e| resolve_expr(e, schema, allow_stateful_udf))
.collect();
let resolved_exprs: Vec<ExprRef> = resolved_exprs?.into_iter().flatten().collect();
let resolved_fields: DaftResult<Vec<Field>> =
resolved_exprs.iter().map(|e| e.to_field(schema)).collect();
Ok((resolved_exprs, resolved_fields?))
}

/// Resolves and validates the expression with a schema, returning the extracted aggregation expression and its field.
/// Specifically, makes sure the expression does not contain nested aggregations or stateful UDFs,
/// and resolves struct accessors and wildcards.
/// May return multiple expressions if the expr contains a wildcard.
///
/// TODO: Use a builder pattern for this functionality
fn resolve_aggexpr(expr: ExprRef, schema: &Schema) -> DaftResult<Vec<AggExpr>> {
let struct_expr_map = calculate_struct_expr_map(schema);
expand_wildcards(expr, schema, &struct_expr_map)?.into_iter().map(|expr| {
let agg_expr = extract_agg_expr(&expr)?;
let has_nested_agg = extract_agg_expr(&expr)?.children().iter().any(has_agg);

let has_nested_agg = agg_expr.children().iter().any(|e| e.has_agg());
if has_nested_agg {
return Err(DaftError::ValueError(format!(
"Nested aggregation expressions are not supported: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));
}

if has_nested_agg {
return Err(DaftError::ValueError(format!(
"Nested aggregation expressions are not supported: {expr}\nIf you would like to have this feature, please see https://github.com/Eventual-Inc/Daft/issues/1979#issue-2170913383"
)));
}
if has_stateful_udf(&expr) {
return Err(DaftError::ValueError(format!(
"Stateful UDFs are only allowed in projections: {expr}"
)));
}

let resolved_children = agg_expr
.children()
.into_iter()
.map(|e| transform_struct_gets(e, &struct_expr_map))
.collect::<DaftResult<Vec<_>>>()?;
Ok(agg_expr.with_new_children(resolved_children))
}).collect()
let struct_expr_map = calculate_struct_expr_map(schema);
expand_wildcards(expr, schema, &struct_expr_map)?
.into_iter()
.map(|expr| {
let agg_expr = extract_agg_expr(&expr)?;

let resolved_children = agg_expr
.children()
.into_iter()
.map(|e| transform_struct_gets(e, &struct_expr_map))
.collect::<DaftResult<Vec<_>>>()?;
Ok(agg_expr.with_new_children(resolved_children))
})
.collect()
}

pub fn resolve_single_aggexpr(expr: ExprRef, schema: &Schema) -> DaftResult<(AggExpr, Field)> {
Expand Down Expand Up @@ -353,6 +389,32 @@ pub fn resolve_aggexprs(
Ok((resolved_exprs, resolved_fields?))
}

/// Checks that `name` refers to valid column(s) in `schema` without building
/// or resolving an expression.
///
/// A `name` containing `*` is treated as a wildcard pattern and expanded
/// against the schema; every expanded name (or the plain name itself) must be
/// present in the schema's struct-accessor map.
///
/// # Errors
/// Returns a `ValueError` if the wildcard pattern cannot be matched, or if
/// any candidate name is not found in the schema.
pub fn check_column_name_validity(name: &str, schema: &Schema) -> DaftResult<()> {
    let struct_expr_map = calculate_struct_expr_map(schema);

    // Wildcard patterns expand to all matching names; a plain name is checked as-is.
    let candidates = if name.contains('*') {
        get_wildcard_matches(name, schema, &struct_expr_map).map_err(|_| {
            DaftError::ValueError(format!(
                "Error matching wildcard `{name}` in schema: {schema}"
            ))
        })?
    } else {
        vec![name.into()]
    };

    // Report the first candidate that is missing from the schema, if any.
    match candidates.iter().find(|n| !struct_expr_map.contains_key(*n)) {
        Some(missing) => Err(DaftError::ValueError(format!(
            "Column `{missing}` not found in schema: {schema}"
        ))),
        None => Ok(()),
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/actor_pool_project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ pub struct ActorPoolProject {
impl ActorPoolProject {
pub(crate) fn try_new(input: Arc<LogicalPlan>, projection: Vec<ExprRef>) -> Result<Self> {
let (projection, fields) =
resolve_exprs(projection, input.schema().as_ref()).context(CreationSnafu)?;
resolve_exprs(projection, input.schema().as_ref(), true).context(CreationSnafu)?;

let num_stateful_udf_exprs: usize = projection
.iter()
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ impl Aggregate {
) -> logical_plan::Result<Self> {
let upstream_schema = input.schema();
let (groupby, groupby_fields) =
resolve_exprs(groupby, &upstream_schema).context(CreationSnafu)?;
resolve_exprs(groupby, &upstream_schema, false).context(CreationSnafu)?;
let (aggregations, aggregation_fields) =
resolve_aggexprs(aggregations, &upstream_schema).context(CreationSnafu)?;

Expand Down
3 changes: 2 additions & 1 deletion src/daft-plan/src/logical_ops/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ impl Explode {
) -> logical_plan::Result<Self> {
let upstream_schema = input.schema();

let (to_explode, _) = resolve_exprs(to_explode, &upstream_schema).context(CreationSnafu)?;
let (to_explode, _) =
resolve_exprs(to_explode, &upstream_schema, false).context(CreationSnafu)?;

let explode_exprs = to_explode
.iter()
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub struct Filter {
impl Filter {
pub(crate) fn try_new(input: Arc<LogicalPlan>, predicate: ExprRef) -> Result<Self> {
let (predicate, field) =
resolve_single_expr(predicate, &input.schema()).context(CreationSnafu)?;
resolve_single_expr(predicate, &input.schema(), false).context(CreationSnafu)?;

if !matches!(field.dtype, DataType::Boolean) {
return Err(DaftError::ValueError(format!(
Expand Down
4 changes: 2 additions & 2 deletions src/daft-plan/src/logical_ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ impl Join {
join_strategy: Option<JoinStrategy>,
) -> logical_plan::Result<Self> {
let (left_on, left_fields) =
resolve_exprs(left_on, &left.schema()).context(CreationSnafu)?;
resolve_exprs(left_on, &left.schema(), false).context(CreationSnafu)?;
let (right_on, right_fields) =
resolve_exprs(right_on, &right.schema()).context(CreationSnafu)?;
resolve_exprs(right_on, &right.schema(), false).context(CreationSnafu)?;

for (on_exprs, on_fields) in [(&left_on, left_fields), (&right_on, right_fields)] {
let on_schema = Schema::new(on_fields).context(CreationSnafu)?;
Expand Down
6 changes: 3 additions & 3 deletions src/daft-plan/src/logical_ops/pivot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ impl Pivot {
) -> logical_plan::Result<Self> {
let upstream_schema = input.schema();
let (group_by, group_by_fields) =
resolve_exprs(group_by, &upstream_schema).context(CreationSnafu)?;
resolve_exprs(group_by, &upstream_schema, false).context(CreationSnafu)?;
let (pivot_column, _) =
resolve_single_expr(pivot_column, &upstream_schema).context(CreationSnafu)?;
resolve_single_expr(pivot_column, &upstream_schema, false).context(CreationSnafu)?;
let (value_column, value_col_field) =
resolve_single_expr(value_column, &upstream_schema).context(CreationSnafu)?;
resolve_single_expr(value_column, &upstream_schema, false).context(CreationSnafu)?;
let (aggregation, _) =
resolve_single_aggexpr(aggregation, &upstream_schema).context(CreationSnafu)?;

Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub struct Project {
impl Project {
pub(crate) fn try_new(input: Arc<LogicalPlan>, projection: Vec<ExprRef>) -> Result<Self> {
let (projection, fields) =
resolve_exprs(projection, &input.schema()).context(CreationSnafu)?;
resolve_exprs(projection, &input.schema(), true).context(CreationSnafu)?;

// Factor the projection and see if there are any substitutions to factor out.
let (factored_input, factored_projection) =
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/repartition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl Repartition {
) -> DaftResult<Self> {
let repartition_spec = match repartition_spec {
RepartitionSpec::Hash(HashRepartitionConfig { num_partitions, by }) => {
let (resolved_by, _) = resolve_exprs(by, &input.schema())?;
let (resolved_by, _) = resolve_exprs(by, &input.schema(), false)?;
RepartitionSpec::Hash(HashRepartitionConfig {
num_partitions,
by: resolved_by,
Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl Sink {
let resolved_partition_cols = partition_cols
.clone()
.map(|cols| {
resolve_exprs(cols, &schema).map(|(resolved_cols, _)| resolved_cols)
resolve_exprs(cols, &schema, false).map(|(resolved_cols, _)| resolved_cols)
})
.transpose()?;

Expand Down
2 changes: 1 addition & 1 deletion src/daft-plan/src/logical_ops/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl Sort {
}

let (sort_by, sort_by_fields) =
resolve_exprs(sort_by, &input.schema()).context(CreationSnafu)?;
resolve_exprs(sort_by, &input.schema(), false).context(CreationSnafu)?;

let sort_by_resolved_schema = Schema::new(sort_by_fields).context(CreationSnafu)?;

Expand Down
Loading

0 comments on commit c5a4adc

Please sign in to comment.