Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix panic when using fold in certain situations #17114

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/dsl/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ pub trait ExprEvalExtension: IntoExpr + Sized {

this.apply(
func,
GetOutput::map_field(move |f| eval_field_to_dtype(f, &expr2, false)),
GetOutput::map_field(move |f| Ok(eval_field_to_dtype(f, &expr2, false))),
)
.with_fmt("expanding_eval")
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ pub trait ListNameSpaceExtension: IntoListNameSpace + Sized {
this.0
.map(
func,
GetOutput::map_field(move |f| eval_field_to_dtype(f, &expr2, true)),
GetOutput::map_field(move |f| Ok(eval_field_to_dtype(f, &expr2, true))),
)
.with_fmt("eval")
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ impl ArrayNameSpace {
Field::from_owned(name, inner.as_ref().clone())
})
.collect();
DataType::Struct(fields)
Ok(DataType::Struct(fields))
}),
)
.with_fmt("arr.to_struct")
Expand Down
53 changes: 34 additions & 19 deletions crates/polars-plan/src/dsl/expr_dyn_fn.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::fmt::Formatter;
use std::ops::Deref;

use polars_core::utils::get_supertype;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")]
Expand Down Expand Up @@ -245,15 +244,20 @@ where
}

pub trait FunctionOutputField: Send + Sync {
fn get_field(&self, input_schema: &Schema, cntxt: Context, fields: &[Field]) -> Field;
fn get_field(
&self,
input_schema: &Schema,
cntxt: Context,
fields: &[Field],
) -> PolarsResult<Field>;
}

pub type GetOutput = SpecialEq<Arc<dyn FunctionOutputField>>;

impl Default for GetOutput {
fn default() -> Self {
SpecialEq::new(Arc::new(
|_input_schema: &Schema, _cntxt: Context, fields: &[Field]| fields[0].clone(),
|_input_schema: &Schema, _cntxt: Context, fields: &[Field]| Ok(fields[0].clone()),
))
}
}
Expand All @@ -265,67 +269,78 @@ impl GetOutput {

pub fn from_type(dt: DataType) -> Self {
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
Field::new(flds[0].name(), dt.clone())
Ok(Field::new(flds[0].name(), dt.clone()))
}))
}

pub fn map_field<F: 'static + Fn(&Field) -> Field + Send + Sync>(f: F) -> Self {
pub fn map_field<F: 'static + Fn(&Field) -> PolarsResult<Field> + Send + Sync>(f: F) -> Self {
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
f(&flds[0])
}))
}

pub fn map_fields<F: 'static + Fn(&[Field]) -> Field + Send + Sync>(f: F) -> Self {
pub fn map_fields<F: 'static + Fn(&[Field]) -> PolarsResult<Field> + Send + Sync>(
f: F,
) -> Self {
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
f(flds)
}))
}

pub fn map_dtype<F: 'static + Fn(&DataType) -> DataType + Send + Sync>(f: F) -> Self {
pub fn map_dtype<F: 'static + Fn(&DataType) -> PolarsResult<DataType> + Send + Sync>(
f: F,
) -> Self {
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
let mut fld = flds[0].clone();
let new_type = f(fld.data_type());
let new_type = f(fld.data_type())?;
fld.coerce(new_type);
fld
Ok(fld)
}))
}

pub fn float_type() -> Self {
Self::map_dtype(|dt| match dt {
DataType::Float32 => DataType::Float32,
_ => DataType::Float64,
Self::map_dtype(|dt| {
Ok(match dt {
DataType::Float32 => DataType::Float32,
_ => DataType::Float64,
})
})
}

pub fn super_type() -> Self {
Self::map_dtypes(|dtypes| {
let mut st = dtypes[0].clone();
for dt in &dtypes[1..] {
st = get_supertype(&st, dt).unwrap();
st = try_get_supertype(&st, dt)?;
}
st
Ok(st)
})
}

pub fn map_dtypes<F>(f: F) -> Self
where
F: 'static + Fn(&[&DataType]) -> DataType + Send + Sync,
F: 'static + Fn(&[&DataType]) -> PolarsResult<DataType> + Send + Sync,
{
SpecialEq::new(Arc::new(move |_: &Schema, _: Context, flds: &[Field]| {
let mut fld = flds[0].clone();
let dtypes = flds.iter().map(|fld| fld.data_type()).collect::<Vec<_>>();
let new_type = f(&dtypes);
let new_type = f(&dtypes)?;
fld.coerce(new_type);
fld
Ok(fld)
}))
}
}

impl<F> FunctionOutputField for F
where
F: Fn(&Schema, Context, &[Field]) -> Field + Send + Sync,
F: Fn(&Schema, Context, &[Field]) -> PolarsResult<Field> + Send + Sync,
{
fn get_field(&self, input_schema: &Schema, cntxt: Context, fields: &[Field]) -> Field {
fn get_field(
&self,
input_schema: &Schema,
cntxt: Context,
fields: &[Field],
) -> PolarsResult<Field> {
self(input_schema, cntxt, fields)
}
}
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,15 @@ fn cum_fold_dtype() -> GetOutput {
for fld in &fields[1..] {
st = get_supertype(&st, &fld.dtype).unwrap();
}
Field::new(
Ok(Field::new(
&fields[0].name,
DataType::Struct(
fields
.iter()
.map(|fld| Field::new(fld.name(), st.clone()))
.collect(),
),
)
))
})
}

Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ impl ListNameSpace {
let out = out_dtype.read().unwrap();
match out.as_ref() {
// dtype already set
Some(dt) => dt.clone(),
Some(dt) => Ok(dt.clone()),
// dtype still unknown, set it
None => {
drop(out);
Expand All @@ -314,7 +314,7 @@ impl ListNameSpace {
let dt = DataType::Struct(fields);

*lock = Some(dt.clone());
dt
Ok(dt)
},
}
}),
Expand Down
24 changes: 13 additions & 11 deletions crates/polars-plan/src/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -799,13 +799,13 @@ impl Expr {
self.function_with_options(
move |s: Series| Some(s.product().map(|sc| sc.into_series(s.name()))).transpose(),
GetOutput::map_dtype(|dt| {
use DataType::*;
match dt {
Float32 => Float32,
Float64 => Float64,
UInt64 => UInt64,
_ => Int64,
}
use DataType as T;
Ok(match dt {
T::Float32 => T::Float32,
T::Float64 => T::Float64,
T::UInt64 => T::UInt64,
_ => T::Int64,
})
}),
options,
)
Expand Down Expand Up @@ -1468,10 +1468,12 @@ impl Expr {
Ok(Some(out))
}
},
GetOutput::map_field(|field| match field.data_type() {
DataType::Float64 => field.clone(),
DataType::Float32 => Field::new(field.name(), DataType::Float32),
_ => Field::new(field.name(), DataType::Float64),
GetOutput::map_field(|field| {
Ok(match field.data_type() {
DataType::Float64 => field.clone(),
DataType::Float32 => Field::new(field.name(), DataType::Float32),
_ => Field::new(field.name(), DataType::Float64),
})
}),
)
.with_fmt("rolling_map_float")
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl ExprNameNameSpace {
.iter()
.map(|fd| Field::new(&f(fd.name()), fd.data_type().clone()))
.collect();
DataType::Struct(fields)
Ok(DataType::Struct(fields))
},
_ => panic!("Only struct dtype is supported for `map_fields`."),
}),
Expand Down
32 changes: 18 additions & 14 deletions crates/polars-plan/src/dsl/python_udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,13 +216,15 @@ impl SeriesUdf for PythonUdfExpression {

fn get_output(&self) -> Option<GetOutput> {
let output_type = self.output_type.clone();
Some(GetOutput::map_field(move |fld| match output_type {
Some(ref dt) => Field::new(fld.name(), dt.clone()),
None => {
let mut fld = fld.clone();
fld.coerce(DataType::Unknown(Default::default()));
fld
},
Some(GetOutput::map_field(move |fld| {
Ok(match output_type {
Some(ref dt) => Field::new(fld.name(), dt.clone()),
None => {
let mut fld = fld.clone();
fld.coerce(DataType::Unknown(Default::default()));
fld
},
})
}))
}
}
Expand All @@ -239,13 +241,15 @@ impl Expr {

let returns_scalar = func.returns_scalar;
let return_dtype = func.output_type.clone();
let output_type = GetOutput::map_field(move |fld| match return_dtype {
Some(ref dt) => Field::new(fld.name(), dt.clone()),
None => {
let mut fld = fld.clone();
fld.coerce(DataType::Unknown(Default::default()));
fld
},
let output_type = GetOutput::map_field(move |fld| {
Ok(match return_dtype {
Some(ref dt) => Field::new(fld.name(), dt.clone()),
None => {
let mut fld = fld.clone();
fld.coerce(DataType::Unknown(Default::default()));
fld
},
})
});

Expr::AnonymousFunction {
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ impl AExpr {
let output_type = tmp.as_ref().unwrap_or(output_type);
let fields = func_args_to_fields(input, schema, arena, nested)?;
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str);
Ok(output_type.get_field(schema, Context::Default, &fields))
output_type.get_field(schema, Context::Default, &fields)
},
Function {
function, input, ..
Expand Down
8 changes: 5 additions & 3 deletions py-polars/src/map/lazy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,11 @@ pub fn map_mul(

let exprs = pyexpr.iter().map(|pe| pe.clone().inner).collect::<Vec<_>>();

let output_map = GetOutput::map_field(move |fld| match output_type {
Some(ref dt) => Field::new(fld.name(), dt.0.clone()),
None => fld.clone(),
let output_map = GetOutput::map_field(move |fld| {
Ok(match output_type {
Some(ref dt) => Field::new(fld.name(), dt.0.clone()),
None => fld.clone(),
})
});
if map_groups {
polars::lazy::dsl::apply_multiple(function, exprs, output_map, returns_scalar).into()
Expand Down
Loading