Skip to content

Commit

Permalink
fix(rust, python): improve recursive casting of nested data (#6897)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Feb 15, 2023
1 parent 7b98ccb commit cfca325
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 18 deletions.
61 changes: 43 additions & 18 deletions polars/polars-core/src/chunked_array/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,27 @@ fn cast_impl(name: &str, chunks: &[ArrayRef], dtype: &DataType) -> PolarsResult<
cast_impl_inner(name, chunks, dtype, true)
}

#[cfg(feature = "dtype-struct")]
/// Cast a single (non-struct) column into a struct with the given `fields`.
///
/// The input `chunks` are cast to the dtype of the first field; every
/// remaining field is filled with nulls of the same length, so the resulting
/// struct has exactly `fields.len()` columns.
///
/// # Panics
///
/// Panics if `fields` is empty.
fn cast_single_to_struct(
    name: &str,
    chunks: &[ArrayRef],
    fields: &[Field],
) -> PolarsResult<Series> {
    let (first, rest) = fields.split_first().unwrap();

    // only the first field carries the actual data
    let casted = cast_impl_inner(&first.name, chunks, &first.dtype, true)?;
    let length = casted.len();

    let mut columns = Vec::with_capacity(fields.len());
    columns.push(casted);
    // all other fields are null columns matching the length of the first
    columns.extend(
        rest.iter()
            .map(|fld| Series::full_null(&fld.name, length, &fld.dtype)),
    );

    Ok(StructChunked::new_unchecked(name, &columns).into_series())
}

impl<T> ChunkedArray<T>
where
T: PolarsNumericType,
Expand All @@ -75,14 +96,7 @@ where
}
}
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => {
// cast to first field dtype
let fld = &fields[0];
let dtype = &fld.dtype;
let name = &fld.name;
let s = cast_impl_inner(name, &self.chunks, dtype, true)?;
Ok(StructChunked::new_unchecked(self.name(), &[s]).into_series())
}
DataType::Struct(fields) => cast_single_to_struct(self.name(), &self.chunks, fields),
_ => cast_impl_inner(self.name(), &self.chunks, data_type, checked).map(|mut s| {
// maintain sorted if data types remain signed
// this may still fail with overflow?
Expand Down Expand Up @@ -123,6 +137,8 @@ impl ChunkCast for Utf8Chunked {
let ca = builder.finish();
Ok(ca.into_series())
}
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => cast_single_to_struct(self.name(), &self.chunks, fields),
_ => cast_impl(self.name(), &self.chunks, data_type),
}
}
Expand All @@ -148,7 +164,11 @@ unsafe fn binary_to_utf8_unchecked(from: &BinaryArray<i64>) -> Utf8Array<i64> {
#[cfg(feature = "dtype-binary")]
impl ChunkCast for BinaryChunked {
fn cast(&self, data_type: &DataType) -> PolarsResult<Series> {
cast_impl(self.name(), &self.chunks, data_type)
match data_type {
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => cast_single_to_struct(self.name(), &self.chunks, fields),
_ => cast_impl(self.name(), &self.chunks, data_type),
}
}

fn cast_unchecked(&self, data_type: &DataType) -> PolarsResult<Series> {
Expand Down Expand Up @@ -177,12 +197,15 @@ fn boolean_to_utf8(ca: &BooleanChunked) -> Utf8Chunked {

impl ChunkCast for BooleanChunked {
fn cast(&self, data_type: &DataType) -> PolarsResult<Series> {
if matches!(data_type, DataType::Utf8) {
let mut ca = boolean_to_utf8(self);
ca.rename(self.name());
Ok(ca.into_series())
} else {
cast_impl(self.name(), &self.chunks, data_type)
match data_type {
DataType::Utf8 => {
let mut ca = boolean_to_utf8(self);
ca.rename(self.name());
Ok(ca.into_series())
}
#[cfg(feature = "dtype-struct")]
DataType::Struct(fields) => cast_single_to_struct(self.name(), &self.chunks, fields),
_ => cast_impl(self.name(), &self.chunks, data_type),
}
}

Expand Down Expand Up @@ -213,12 +236,12 @@ impl ChunkCast for ListChunked {
match (self.inner_dtype(), &**child_type) {
#[cfg(feature = "dtype-categorical")]
(Utf8, Categorical(_)) => {
let (arr, inner_dtype) = cast_list(self, child_type)?;
let (arr, child_type) = cast_list(self, child_type)?;
Ok(unsafe {
Series::from_chunks_and_dtype_unchecked(
self.name(),
vec![arr],
&List(Box::new(inner_dtype)),
&List(Box::new(child_type)),
)
})
}
Expand Down Expand Up @@ -251,14 +274,16 @@ impl ChunkCast for ListChunked {
}
}

// returns inner data type
// Returns inner data type. This is needed because a cast can instantiate the dtype inner
// values for instance with categoricals
fn cast_list(ca: &ListChunked, child_type: &DataType) -> PolarsResult<(ArrayRef, DataType)> {
let ca = ca.rechunk();
let arr = ca.downcast_iter().next().unwrap();
let s = Series::try_from(("", arr.values().clone())).unwrap();
let new_inner = s.cast(child_type)?;

let inner_dtype = new_inner.dtype().clone();
debug_assert_eq!(&inner_dtype, child_type);

let new_values = new_inner.array_ref(0).clone();

Expand Down
1 change: 1 addition & 0 deletions polars/polars-core/src/series/arithmetic/borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ pub fn _struct_arithmetic<F: FnMut(&Series, &Series) -> Series>(
let rhs = rhs.struct_().unwrap();
let s_fields = s.fields();
let rhs_fields = rhs.fields();

match (s_fields.len(), rhs_fields.len()) {
(_, 1) => {
let rhs = &rhs.fields()[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ fn is_datetime_arithmetic(type_left: &DataType, type_right: &DataType, op: Opera
)
}

#[cfg(feature = "dtype-struct")]
/// Returns `true` when `op` is an arithmetic operator and exactly one side is
/// a struct while the other side is numeric (in either order).
fn is_struct_numeric_arithmetic(type_left: &DataType, type_right: &DataType, op: Operator) -> bool {
    let numeric_op_struct = matches!(type_right, DataType::Struct(_)) && type_left.is_numeric();
    let struct_op_numeric = matches!(type_left, DataType::Struct(_)) && type_right.is_numeric();

    // `&&` binds tighter than `||`: without the explicit grouping the
    // struct-left/numeric-right case would match for *any* operator
    // (e.g. comparisons), not only arithmetic ones.
    op.is_arithmetic() && (numeric_op_struct || struct_op_numeric)
}

fn is_list_arithmetic(type_left: &DataType, type_right: &DataType, op: Operator) -> bool {
op.is_arithmetic()
&& matches!(
Expand Down Expand Up @@ -124,6 +132,55 @@ fn process_list_arithmetic(
}
}

#[cfg(feature = "dtype-struct")]
/// Rewrite `struct <op> numeric` (or `numeric <op> struct`) so that the
/// non-struct operand is cast to a struct holding only the first field of the
/// struct operand.
///
/// This avoids casting to the full supertype, which would fill the struct
/// with null fields. Returns `Ok(None)` when the struct has no fields. When
/// both sides are structs, the left one decides the target dtype.
fn process_struct_numeric_arithmetic(
    type_left: DataType,
    type_right: DataType,
    node_left: Node,
    node_right: Node,
    op: Operator,
    expr_arena: &mut Arena<AExpr>,
) -> PolarsResult<Option<AExpr>> {
    // figure out which side is the struct; the left side takes precedence
    let (struct_fields, struct_on_left) = match (&type_left, &type_right) {
        (DataType::Struct(fields), _) => (fields, true),
        (_, DataType::Struct(fields)) => (fields, false),
        _ => unreachable!(),
    };

    let first_field = match struct_fields.first() {
        Some(field) => field.clone(),
        None => return Ok(None),
    };
    let single_field_struct = DataType::Struct(vec![first_field]);

    // cast the non-struct operand, leave the struct operand untouched
    let (left, right) = if struct_on_left {
        let casted = expr_arena.add(AExpr::Cast {
            expr: node_right,
            data_type: single_field_struct,
            strict: false,
        });
        (node_left, casted)
    } else {
        let casted = expr_arena.add(AExpr::Cast {
            expr: node_left,
            data_type: single_field_struct,
            strict: false,
        });
        (casted, node_right)
    };

    Ok(Some(AExpr::BinaryExpr { left, op, right }))
}

pub(super) fn process_binary(
expr_arena: &mut Arena<AExpr>,
lp_arena: &Arena<ALogicalPlan>,
Expand Down Expand Up @@ -177,6 +234,17 @@ pub(super) fn process_binary(
);
}

#[cfg(feature = "dtype-struct")]
{
let is_struct_numeric_arithmetic =
is_struct_numeric_arithmetic(&type_left, &type_right, op);
if is_struct_numeric_arithmetic {
return process_struct_numeric_arithmetic(
type_left, type_right, node_left, node_right, op, expr_arena,
);
}
}

// All early return paths
if compare_cat_to_string
|| datetime_arithmetic
Expand Down
13 changes: 13 additions & 0 deletions py-polars/tests/unit/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -798,3 +798,16 @@ def test_struct_null_cast() -> None:
.select([pl.lit(None, dtype=pl.Null).cast(dtype, strict=True)])
.collect()
).to_dict(False) == {"literal": [{"a": None, "b": None, "c": None}]}


def test_nested_struct_in_lists_cast() -> None:
    """Round-trip a list-of-struct column whose struct holds a list of structs.

    The second row contains an empty inner list, so the inner struct dtype has
    to be taken from the first row.
    """
    data = {
        "node_groups": [
            [{"nodes": [{"id": 1, "is_started": True}]}],
            [{"nodes": []}],
        ]
    }
    expected = {
        "node_groups": [[{"nodes": [{"id": 1, "is_started": True}]}], [{"nodes": []}]]
    }

    result = pl.DataFrame(data).to_dict(False)
    assert result == expected

0 comments on commit cfca325

Please sign in to comment.