Skip to content

Commit

Permalink
Simplify Expr::map_children (#9876)
Browse files Browse the repository at this point in the history
* add map_until_stop_and_collect macro

* fix clippy

* simplify

* Update datafusion/common/src/tree_node.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* add documentation

* fix macro

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
peter-toth and alamb authored Apr 3, 2024
1 parent daf182d commit 2f55003
Show file tree
Hide file tree
Showing 2 changed files with 171 additions and 137 deletions.
82 changes: 68 additions & 14 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,20 @@ impl<T> Transformed<T> {
}
}

/// Transformation helper to process tree nodes that are siblings.
/// Transformation helper to process a sequence of iterable tree nodes that are siblings.
pub trait TransformedIterator: Iterator {
/// Apples `f` to each item in this iterator
///
/// Visits all items in the iterator unless
/// `f` returns an error or `f` returns TreeNodeRecursion::stop.
///
/// # Returns
/// Error if `f` returns an error
///
/// Ok(Transformed) such that:
/// 1. `transformed` is true if any return from `f` had transformed true
/// 2. `data` from the last invocation of `f`
/// 3. `tnr` from the last invocation of `f` or `Continue` if the iterator is empty
fn map_until_stop_and_collect<
F: FnMut(Self::Item) -> Result<Transformed<Self::Item>>,
>(
Expand All @@ -551,22 +563,64 @@ impl<I: Iterator> TransformedIterator for I {
) -> Result<Transformed<Vec<Self::Item>>> {
let mut tnr = TreeNodeRecursion::Continue;
let mut transformed = false;
let data = self
.map(|item| match tnr {
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
f(item).map(|result| {
tnr = result.tnr;
transformed |= result.transformed;
result.data
})
}
TreeNodeRecursion::Stop => Ok(item),
})
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::new(data, transformed, tnr))
self.map(|item| match tnr {
TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
f(item).map(|result| {
tnr = result.tnr;
transformed |= result.transformed;
result.data
})
}
TreeNodeRecursion::Stop => Ok(item),
})
.collect::<Result<Vec<_>>>()
.map(|data| Transformed::new(data, transformed, tnr))
}
}

/// Transformation helper to process a heterogeneous sequence of tree node containing
/// expressions.
/// This macro is very similar to [TransformedIterator::map_until_stop_and_collect] to
/// process nodes that are siblings, but it accepts an initial transformation (`F0`) and
/// a sequence of pairs. Each pair is made of an expression (`EXPR`) and its
/// transformation (`F`).
///
/// The macro builds up a tuple that contains `Transformed.data` result of `F0` as the
/// first element and further elements from the sequence of pairs. An element from a pair
/// is either the value of `EXPR` or the `Transformed.data` result of `F`, depending on
/// the `Transformed.tnr` result of previous `F`s (`F0` initially).
///
/// # Returns
/// Error if any of the transformations returns an error
///
/// Ok(Transformed<(data0, ..., dataN)>) such that:
/// 1. `transformed` is true if any of the transformations had transformed true
/// 2. `(data0, ..., dataN)`, where `data0` is the `Transformed.data` from `F0` and
/// `data1` ... `dataN` are from either `EXPR` or the `Transformed.data` of `F`
/// 3. `tnr` from `F0` or the last invocation of `F`
#[macro_export]
macro_rules! map_until_stop_and_collect {
($F0:expr, $($EXPR:expr, $F:expr),*) => {{
$F0.and_then(|Transformed { data: data0, mut transformed, mut tnr }| {
let all_datas = (
data0,
$(
if tnr == TreeNodeRecursion::Continue || tnr == TreeNodeRecursion::Jump {
$F.map(|result| {
tnr = result.tnr;
transformed |= result.transformed;
result.data
})?
} else {
$EXPR
},
)*
);
Ok(Transformed::new(all_datas, transformed, tnr))
})
}}
}

/// Transformation helper to access [`Transformed`] fields in a [`Result`] easily.
pub trait TransformedResult<T> {
fn data(self) -> Result<T>;
Expand Down
226 changes: 103 additions & 123 deletions datafusion/expr/src/tree_node/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ use crate::{Expr, GetFieldAccess};
use datafusion_common::tree_node::{
Transformed, TransformedIterator, TreeNode, TreeNodeRecursion,
};
use datafusion_common::{handle_visit_recursion, internal_err, Result};
use datafusion_common::{
handle_visit_recursion, internal_err, map_until_stop_and_collect, Result,
};

impl TreeNode for Expr {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
Expand Down Expand Up @@ -167,58 +169,55 @@ impl TreeNode for Expr {
Expr::InSubquery(InSubquery::new(be, subquery, negated))
}),
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
transform_box(left, &mut f)?
.update_data(|new_left| (new_left, right))
.try_transform_node(|(new_left, right)| {
Ok(transform_box(right, &mut f)?
.update_data(|new_right| (new_left, new_right)))
})?
.update_data(|(new_left, new_right)| {
Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
})
map_until_stop_and_collect!(
transform_box(left, &mut f),
right,
transform_box(right, &mut f)
)?
.update_data(|(new_left, new_right)| {
Expr::BinaryExpr(BinaryExpr::new(new_left, op, new_right))
})
}
Expr::Like(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}) => transform_box(expr, &mut f)?
.update_data(|new_expr| (new_expr, pattern))
.try_transform_node(|(new_expr, pattern)| {
Ok(transform_box(pattern, &mut f)?
.update_data(|new_pattern| (new_expr, new_pattern)))
})?
.update_data(|(new_expr, new_pattern)| {
Expr::Like(Like::new(
negated,
new_expr,
new_pattern,
escape_char,
case_insensitive,
))
}),
}) => map_until_stop_and_collect!(
transform_box(expr, &mut f),
pattern,
transform_box(pattern, &mut f)
)?
.update_data(|(new_expr, new_pattern)| {
Expr::Like(Like::new(
negated,
new_expr,
new_pattern,
escape_char,
case_insensitive,
))
}),
Expr::SimilarTo(Like {
negated,
expr,
pattern,
escape_char,
case_insensitive,
}) => transform_box(expr, &mut f)?
.update_data(|new_expr| (new_expr, pattern))
.try_transform_node(|(new_expr, pattern)| {
Ok(transform_box(pattern, &mut f)?
.update_data(|new_pattern| (new_expr, new_pattern)))
})?
.update_data(|(new_expr, new_pattern)| {
Expr::SimilarTo(Like::new(
negated,
new_expr,
new_pattern,
escape_char,
case_insensitive,
))
}),
}) => map_until_stop_and_collect!(
transform_box(expr, &mut f),
pattern,
transform_box(pattern, &mut f)
)?
.update_data(|(new_expr, new_pattern)| {
Expr::SimilarTo(Like::new(
negated,
new_expr,
new_pattern,
escape_char,
case_insensitive,
))
}),
Expr::Not(expr) => transform_box(expr, &mut f)?.update_data(Expr::Not),
Expr::IsNotNull(expr) => {
transform_box(expr, &mut f)?.update_data(Expr::IsNotNull)
Expand Down Expand Up @@ -248,48 +247,38 @@ impl TreeNode for Expr {
negated,
low,
high,
}) => transform_box(expr, &mut f)?
.update_data(|new_expr| (new_expr, low, high))
.try_transform_node(|(new_expr, low, high)| {
Ok(transform_box(low, &mut f)?
.update_data(|new_low| (new_expr, new_low, high)))
})?
.try_transform_node(|(new_expr, new_low, high)| {
Ok(transform_box(high, &mut f)?
.update_data(|new_high| (new_expr, new_low, new_high)))
})?
.update_data(|(new_expr, new_low, new_high)| {
Expr::Between(Between::new(new_expr, negated, new_low, new_high))
}),
}) => map_until_stop_and_collect!(
transform_box(expr, &mut f),
low,
transform_box(low, &mut f),
high,
transform_box(high, &mut f)
)?
.update_data(|(new_expr, new_low, new_high)| {
Expr::Between(Between::new(new_expr, negated, new_low, new_high))
}),
Expr::Case(Case {
expr,
when_then_expr,
else_expr,
}) => transform_option_box(expr, &mut f)?
.update_data(|new_expr| (new_expr, when_then_expr, else_expr))
.try_transform_node(|(new_expr, when_then_expr, else_expr)| {
Ok(when_then_expr
.into_iter()
.map_until_stop_and_collect(|(when, then)| {
transform_box(when, &mut f)?
.update_data(|new_when| (new_when, then))
.try_transform_node(|(new_when, then)| {
Ok(transform_box(then, &mut f)?
.update_data(|new_then| (new_when, new_then)))
})
})?
.update_data(|new_when_then_expr| {
(new_expr, new_when_then_expr, else_expr)
}))
})?
.try_transform_node(|(new_expr, new_when_then_expr, else_expr)| {
Ok(transform_option_box(else_expr, &mut f)?.update_data(
|new_else_expr| (new_expr, new_when_then_expr, new_else_expr),
))
})?
.update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
}),
}) => map_until_stop_and_collect!(
transform_option_box(expr, &mut f),
when_then_expr,
when_then_expr
.into_iter()
.map_until_stop_and_collect(|(when, then)| {
map_until_stop_and_collect!(
transform_box(when, &mut f),
then,
transform_box(then, &mut f)
)
}),
else_expr,
transform_option_box(else_expr, &mut f)
)?
.update_data(|(new_expr, new_when_then_expr, new_else_expr)| {
Expr::Case(Case::new(new_expr, new_when_then_expr, new_else_expr))
}),
Expr::Cast(Cast { expr, data_type }) => transform_box(expr, &mut f)?
.update_data(|be| Expr::Cast(Cast::new(be, data_type))),
Expr::TryCast(TryCast { expr, data_type }) => transform_box(expr, &mut f)?
Expand Down Expand Up @@ -320,48 +309,39 @@ impl TreeNode for Expr {
order_by,
window_frame,
null_treatment,
}) => transform_vec(args, &mut f)?
.update_data(|new_args| (new_args, partition_by, order_by))
.try_transform_node(|(new_args, partition_by, order_by)| {
Ok(transform_vec(partition_by, &mut f)?.update_data(
|new_partition_by| (new_args, new_partition_by, order_by),
))
})?
.try_transform_node(|(new_args, new_partition_by, order_by)| {
Ok(
transform_vec(order_by, &mut f)?.update_data(|new_order_by| {
(new_args, new_partition_by, new_order_by)
}),
)
})?
.update_data(|(new_args, new_partition_by, new_order_by)| {
Expr::WindowFunction(WindowFunction::new(
fun,
new_args,
new_partition_by,
new_order_by,
window_frame,
null_treatment,
))
}),
}) => map_until_stop_and_collect!(
transform_vec(args, &mut f),
partition_by,
transform_vec(partition_by, &mut f),
order_by,
transform_vec(order_by, &mut f)
)?
.update_data(|(new_args, new_partition_by, new_order_by)| {
Expr::WindowFunction(WindowFunction::new(
fun,
new_args,
new_partition_by,
new_order_by,
window_frame,
null_treatment,
))
}),
Expr::AggregateFunction(AggregateFunction {
args,
func_def,
distinct,
filter,
order_by,
null_treatment,
}) => transform_vec(args, &mut f)?
.update_data(|new_args| (new_args, filter, order_by))
.try_transform_node(|(new_args, filter, order_by)| {
Ok(transform_option_box(filter, &mut f)?
.update_data(|new_filter| (new_args, new_filter, order_by)))
})?
.try_transform_node(|(new_args, new_filter, order_by)| {
Ok(transform_option_vec(order_by, &mut f)?
.update_data(|new_order_by| (new_args, new_filter, new_order_by)))
})?
.map_data(|(new_args, new_filter, new_order_by)| match func_def {
}) => map_until_stop_and_collect!(
transform_vec(args, &mut f),
filter,
transform_option_box(filter, &mut f),
order_by,
transform_option_vec(order_by, &mut f)
)?
.map_data(
|(new_args, new_filter, new_order_by)| match func_def {
AggregateFunctionDefinition::BuiltIn(fun) => {
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun,
Expand All @@ -385,7 +365,8 @@ impl TreeNode for Expr {
AggregateFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be resolved.")
}
})?,
},
)?,
Expr::GroupingSet(grouping_set) => match grouping_set {
GroupingSet::Rollup(exprs) => transform_vec(exprs, &mut f)?
.update_data(|ve| Expr::GroupingSet(GroupingSet::Rollup(ve))),
Expand All @@ -402,15 +383,14 @@ impl TreeNode for Expr {
expr,
list,
negated,
}) => transform_box(expr, &mut f)?
.update_data(|new_expr| (new_expr, list))
.try_transform_node(|(new_expr, list)| {
Ok(transform_vec(list, &mut f)?
.update_data(|new_list| (new_expr, new_list)))
})?
.update_data(|(new_expr, new_list)| {
Expr::InList(InList::new(new_expr, new_list, negated))
}),
}) => map_until_stop_and_collect!(
transform_box(expr, &mut f),
list,
transform_vec(list, &mut f)
)?
.update_data(|(new_expr, new_list)| {
Expr::InList(InList::new(new_expr, new_list, negated))
}),
Expr::GetIndexedField(GetIndexedField { expr, field }) => {
transform_box(expr, &mut f)?.update_data(|be| {
Expr::GetIndexedField(GetIndexedField::new(be, field))
Expand Down

0 comments on commit 2f55003

Please sign in to comment.