Skip to content

Commit

Permalink
refactor: Improve type-coercion (#15879)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Apr 25, 2024
1 parent fa37e57 commit 49873b9
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 71 deletions.
49 changes: 28 additions & 21 deletions crates/polars-core/src/datatypes/any_value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -490,20 +490,19 @@ impl<'a> AnyValue<'a> {

/// Cast `AnyValue` to the provided data type and return a new `AnyValue` with type `dtype`,
/// if possible.
///
pub fn strict_cast(&self, dtype: &'a DataType) -> PolarsResult<AnyValue<'a>> {
pub fn strict_cast(&self, dtype: &'a DataType) -> Option<AnyValue<'a>> {
let new_av = match (self, dtype) {
// to numeric
(av, DataType::UInt8) => AnyValue::UInt8(av.try_extract::<u8>()?),
(av, DataType::UInt16) => AnyValue::UInt16(av.try_extract::<u16>()?),
(av, DataType::UInt32) => AnyValue::UInt32(av.try_extract::<u32>()?),
(av, DataType::UInt64) => AnyValue::UInt64(av.try_extract::<u64>()?),
(av, DataType::Int8) => AnyValue::Int8(av.try_extract::<i8>()?),
(av, DataType::Int16) => AnyValue::Int16(av.try_extract::<i16>()?),
(av, DataType::Int32) => AnyValue::Int32(av.try_extract::<i32>()?),
(av, DataType::Int64) => AnyValue::Int64(av.try_extract::<i64>()?),
(av, DataType::Float32) => AnyValue::Float32(av.try_extract::<f32>()?),
(av, DataType::Float64) => AnyValue::Float64(av.try_extract::<f64>()?),
(av, DataType::UInt8) => AnyValue::UInt8(av.extract::<u8>()?),
(av, DataType::UInt16) => AnyValue::UInt16(av.extract::<u16>()?),
(av, DataType::UInt32) => AnyValue::UInt32(av.extract::<u32>()?),
(av, DataType::UInt64) => AnyValue::UInt64(av.extract::<u64>()?),
(av, DataType::Int8) => AnyValue::Int8(av.extract::<i8>()?),
(av, DataType::Int16) => AnyValue::Int16(av.extract::<i16>()?),
(av, DataType::Int32) => AnyValue::Int32(av.extract::<i32>()?),
(av, DataType::Int64) => AnyValue::Int64(av.extract::<i64>()?),
(av, DataType::Float32) => AnyValue::Float32(av.extract::<f32>()?),
(av, DataType::Float64) => AnyValue::Float64(av.extract::<f64>()?),

// to boolean
(AnyValue::UInt8(v), DataType::Boolean) => AnyValue::Boolean(*v != u8::default()),
Expand All @@ -519,7 +518,7 @@ impl<'a> AnyValue<'a> {

// to string
(av, DataType::String) => {
AnyValue::StringOwned(format_smartstring!("{}", av.try_extract::<i64>()?))
AnyValue::StringOwned(format_smartstring!("{}", av.extract::<i64>()?))
},

// to binary
Expand All @@ -528,7 +527,7 @@ impl<'a> AnyValue<'a> {
// to datetime
#[cfg(feature = "dtype-datetime")]
(av, DataType::Datetime(tu, tz)) if av.is_numeric() => {
AnyValue::Datetime(av.try_extract::<i64>()?, *tu, tz)
AnyValue::Datetime(av.extract::<i64>()?, *tu, tz)
},
#[cfg(all(feature = "dtype-datetime", feature = "dtype-date"))]
(AnyValue::Date(v), DataType::Datetime(tu, _)) => AnyValue::Datetime(
Expand Down Expand Up @@ -557,7 +556,7 @@ impl<'a> AnyValue<'a> {

// to date
#[cfg(feature = "dtype-date")]
(av, DataType::Date) if av.is_numeric() => AnyValue::Date(av.try_extract::<i32>()?),
(av, DataType::Date) if av.is_numeric() => AnyValue::Date(av.extract::<i32>()?),
#[cfg(all(feature = "dtype-date", feature = "dtype-datetime"))]
(AnyValue::Datetime(v, tu, _), DataType::Date) => AnyValue::Date(match tu {
TimeUnit::Nanoseconds => *v / NS_IN_DAY,
Expand All @@ -567,7 +566,7 @@ impl<'a> AnyValue<'a> {

// to time
#[cfg(feature = "dtype-time")]
(av, DataType::Time) if av.is_numeric() => AnyValue::Time(av.try_extract::<i64>()?),
(av, DataType::Time) if av.is_numeric() => AnyValue::Time(av.extract::<i64>()?),
#[cfg(all(feature = "dtype-time", feature = "dtype-datetime"))]
(AnyValue::Datetime(v, tu, _), DataType::Time) => AnyValue::Time(match tu {
TimeUnit::Nanoseconds => *v % NS_IN_DAY,
Expand All @@ -578,7 +577,7 @@ impl<'a> AnyValue<'a> {
// to duration
#[cfg(feature = "dtype-duration")]
(av, DataType::Duration(tu)) if av.is_numeric() => {
AnyValue::Duration(av.try_extract::<i64>()?, *tu)
AnyValue::Duration(av.extract::<i64>()?, *tu)
},
#[cfg(all(feature = "dtype-duration", feature = "dtype-time"))]
(AnyValue::Time(v), DataType::Duration(tu)) => AnyValue::Duration(
Expand Down Expand Up @@ -607,15 +606,23 @@ impl<'a> AnyValue<'a> {
// to self
(av, dtype) if av.dtype() == *dtype => self.clone(),

av => polars_bail!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", av, dtype),
_ => return None,
};
Ok(new_av)
Some(new_av)
}

/// Cast `AnyValue` to the provided data type and return a new `AnyValue` with type `dtype`,
/// if possible.
pub fn try_strict_cast(&self, dtype: &'a DataType) -> PolarsResult<AnyValue<'a>> {
self.strict_cast(dtype).ok_or_else(
|| polars_err!(ComputeError: "cannot cast any-value {:?} to dtype '{}'", self, dtype),
)
}

pub fn cast(&self, dtype: &'a DataType) -> AnyValue<'a> {
match self.strict_cast(dtype) {
Ok(av) => av,
Err(_) => AnyValue::Null,
Some(av) => av,
None => AnyValue::Null,
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,12 @@ pub fn to_alp_impl(
lp_arena: &mut Arena<IR>,
expr_arena: &mut Arena<AExpr>,
convert: &mut ConversionOpt,
name: &str,
) -> PolarsResult<Node> {
let lp_node = lp_arena.add(lp);
convert.coerce_types(expr_arena, lp_arena, lp_node)?;
convert
.coerce_types(expr_arena, lp_arena, lp_node)
.map_err(|e| e.context(format!("'{name}' failed").into()))?;

Ok(lp_node)
}
Expand Down Expand Up @@ -175,10 +178,9 @@ pub fn to_alp_impl(
let predicate = to_expr_ir(predicate, expr_arena);

convert.push_scratch(predicate.node(), expr_arena);
let lp_node = lp_arena.add(IR::Filter { input, predicate });
convert.coerce_types(expr_arena, lp_arena, lp_node)?;

return Ok(lp_node);
let lp = IR::Filter { input, predicate };
return run_conversion(lp, lp_arena, expr_arena, convert, "filter");
},
DslPlan::Slice { input, offset, len } => {
let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert)
Expand Down Expand Up @@ -225,7 +227,7 @@ pub fn to_alp_impl(
options,
};

return run_conversion(lp, lp_arena, expr_arena, convert);
return run_conversion(lp, lp_arena, expr_arena, convert, "select");
},
DslPlan::Sort {
input,
Expand All @@ -246,7 +248,7 @@ pub fn to_alp_impl(
sort_options,
};

return run_conversion(lp, lp_arena, expr_arena, convert);
return run_conversion(lp, lp_arena, expr_arena, convert, "sort");
},
DslPlan::Cache {
input,
Expand Down Expand Up @@ -295,7 +297,7 @@ pub fn to_alp_impl(
options,
};

return run_conversion(lp, lp_arena, expr_arena, convert);
return run_conversion(lp, lp_arena, expr_arena, convert, "group_by");
},
DslPlan::Join {
input_left,
Expand Down Expand Up @@ -339,7 +341,7 @@ pub fn to_alp_impl(
right_on,
options,
};
return run_conversion(lp, lp_arena, expr_arena, convert);
return run_conversion(lp, lp_arena, expr_arena, convert, "join");
},
DslPlan::HStack {
input,
Expand All @@ -358,7 +360,7 @@ pub fn to_alp_impl(
schema,
options,
};
return run_conversion(lp, lp_arena, expr_arena, convert);
return run_conversion(lp, lp_arena, expr_arena, convert, "with_columns");
},
DslPlan::Distinct { input, options } => {
let input = to_alp_impl(owned(input), expr_arena, lp_arena, convert)
Expand Down Expand Up @@ -397,7 +399,7 @@ pub fn to_alp_impl(
..Default::default()
},
};
return run_conversion(lp, lp_arena, expr_arena, convert);
return run_conversion(lp, lp_arena, expr_arena, convert, "fill_nan");
},
DslFunction::Drop(to_drop) => {
let mut output_schema =
Expand Down Expand Up @@ -499,7 +501,7 @@ pub fn to_alp_impl(
..Default::default()
},
};
return run_conversion(lp, lp_arena, expr_arena, convert);
return run_conversion(lp, lp_arena, expr_arena, convert, "stats");
},
_ => {
let function = function.into_function_node(&input_schema)?;
Expand Down
51 changes: 23 additions & 28 deletions crates/polars-plan/src/logical_plan/conversion/stack_opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,39 +49,34 @@ impl ConversionOpt {
lp_arena: &Arena<IR>,
current_node: Node,
) -> PolarsResult<()> {
let mut changed = true;
while changed {
changed = false;
// Different from the stack-opt in the optimizer phase, this does a single pass until fixed point per expression.

// process the expressions on the stack and apply optimizations.
while let Some(current_expr_node) = self.scratch.pop() {
{
let expr = unsafe { expr_arena.get_unchecked(current_expr_node) };
if expr.is_leaf() {
continue;
}
// process the expressions on the stack and apply optimizations.
while let Some(current_expr_node) = self.scratch.pop() {
{
let expr = unsafe { expr_arena.get_unchecked(current_expr_node) };
if expr.is_leaf() {
continue;
}
if let Some(rule) = &mut self.simplify {
while let Some(x) =
rule.optimize_expr(expr_arena, current_expr_node, lp_arena, current_node)?
{
expr_arena.replace(current_expr_node, x);
changed = true;
}
}
if let Some(rule) = &mut self.simplify {
while let Some(x) =
rule.optimize_expr(expr_arena, current_expr_node, lp_arena, current_node)?
{
expr_arena.replace(current_expr_node, x);
}
if let Some(rule) = &mut self.coerce {
while let Some(x) =
rule.optimize_expr(expr_arena, current_expr_node, lp_arena, current_node)?
{
expr_arena.replace(current_expr_node, x);
changed = true;
}
}
if let Some(rule) = &mut self.coerce {
while let Some(x) =
rule.optimize_expr(expr_arena, current_expr_node, lp_arena, current_node)?
{
expr_arena.replace(current_expr_node, x);
}

let expr = unsafe { expr_arena.get_unchecked(current_expr_node) };
// traverse subexpressions and add to the stack
expr.nodes(&mut self.scratch)
}

let expr = unsafe { expr_arena.get_unchecked(current_expr_node) };
// traverse subexpressions and add to the stack
expr.nodes(&mut self.scratch)
}

Ok(())
Expand Down
13 changes: 6 additions & 7 deletions crates/polars-plan/src/logical_plan/optimizer/stack_opt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,18 @@ impl StackOptimizer {
) -> PolarsResult<Node> {
let mut changed = true;

let mut plans = Vec::with_capacity(32);

// nodes of expressions and lp node from which the expressions are a member of
let mut exprs = Vec::with_capacity(32);
// Nodes of expressions and lp node from which the expressions are a member of.
let mut plans = vec![];
let mut exprs = vec![];
let mut scratch = vec![];

// run loop until reaching fixed point
// Run loop until reaching fixed point.
while changed {
// recurse into sub plans and expressions and apply rules
// Recurse into sub plans and expressions and apply rules.
changed = false;
plans.push(lp_top);
while let Some(current_node) = plans.pop() {
// apply rules
// Apply rules
for rule in rules.iter_mut() {
// keep iterating over same rule
while let Some(x) = rule.optimize_plan(lp_arena, expr_arena, current_node) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ fn modify_supertype(
_ => {},
}

// TODO! This must be removed and dealt properly with dynamic str.
use DataType::*;
match (type_left, type_right, left, right) {
// if the we compare a categorical to a literal string we want to cast the literal to categorical
Expand Down Expand Up @@ -555,12 +556,12 @@ fn inline_or_prune_cast(
LiteralValue::Series(SpecialEq::new(s))
},
LiteralValue::StrCat(s) => {
let av = AnyValue::String(s).strict_cast(dtype).ok();
let av = AnyValue::String(s).strict_cast(dtype);
return Ok(av.map(|av| AExpr::Literal(av.try_into().unwrap())));
},
lv @ (LiteralValue::Int(_) | LiteralValue::Float(_)) => {
let av = lv.to_any_value().ok_or_else(|| polars_err!(InvalidOperation: "literal value: {:?} too large for Polars", lv))?;
let av = av.strict_cast(dtype).ok();
let av = av.strict_cast(dtype);
return Ok(av.map(|av| AExpr::Literal(av.try_into().unwrap())));
},
LiteralValue::Null => match dtype {
Expand Down Expand Up @@ -595,8 +596,8 @@ fn inline_or_prune_cast(
(av, _) => {
let out = {
match av.strict_cast(dtype) {
Ok(out) => out,
Err(_) => return Ok(None),
Some(out) => out,
None => return Ok(None),
}
};
out.try_into()?
Expand Down

0 comments on commit 49873b9

Please sign in to comment.