From c8fde41b0ee81cb3d08bf3ec22faa02243b4e789 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Tue, 13 Aug 2024 13:40:55 +0200 Subject: [PATCH] fix: Struct outer nullabillity (#18156) --- crates/polars-compute/src/if_then_else/mod.rs | 2 +- .../polars-core/src/chunked_array/ops/full.rs | 6 +- .../polars-core/src/chunked_array/ops/zip.rs | 105 +++++++++++++++++- .../src/series/implementations/struct__.rs | 12 +- crates/polars-core/src/series/ops/null.rs | 13 ++- py-polars/tests/unit/datatypes/test_struct.py | 26 ++++- py-polars/tests/unit/test_queries.py | 4 +- 7 files changed, 146 insertions(+), 22 deletions(-) diff --git a/crates/polars-compute/src/if_then_else/mod.rs b/crates/polars-compute/src/if_then_else/mod.rs index c6c752483330..8265422fb9de 100644 --- a/crates/polars-compute/src/if_then_else/mod.rs +++ b/crates/polars-compute/src/if_then_else/mod.rs @@ -100,7 +100,7 @@ impl IfThenElseKernel for PrimitiveArray { } } -fn if_then_else_validity( +pub fn if_then_else_validity( mask: &Bitmap, if_true: Option<&Bitmap>, if_false: Option<&Bitmap>, diff --git a/crates/polars-core/src/chunked_array/ops/full.rs b/crates/polars-core/src/chunked_array/ops/full.rs index fce46375666c..790e7a23e6ed 100644 --- a/crates/polars-core/src/chunked_array/ops/full.rs +++ b/crates/polars-core/src/chunked_array/ops/full.rs @@ -1,4 +1,4 @@ -use arrow::bitmap::MutableBitmap; +use arrow::bitmap::{Bitmap, MutableBitmap}; use crate::chunked_array::builder::get_list_builder; use crate::prelude::*; @@ -189,7 +189,9 @@ impl ListChunked { impl ChunkFullNull for StructChunked { fn full_null(name: &str, length: usize) -> StructChunked { let s = vec![Series::new_null("", length)]; - StructChunked::from_series(name, &s).unwrap() + StructChunked::from_series(name, &s) + .unwrap() + .with_outer_validity(Some(Bitmap::new_zeroed(length))) } } diff --git a/crates/polars-core/src/chunked_array/ops/zip.rs b/crates/polars-core/src/chunked_array/ops/zip.rs index 8319c81d9c3c..cf85266581e7 100644 --- a/crates/polars-core/src/chunked_array/ops/zip.rs +++ b/crates/polars-core/src/chunked_array/ops/zip.rs @@ -1,6 +1,6 @@ use arrow::bitmap::Bitmap; use arrow::compute::utils::{combine_validities_and, combine_validities_and_not}; -use polars_compute::if_then_else::IfThenElseKernel; +use polars_compute::if_then_else::{if_then_else_validity, IfThenElseKernel}; #[cfg(feature = "object")] use crate::chunked_array::object::ObjectArray; @@ -62,7 +62,7 @@ fn combine_validities_chunked< impl ChunkZip for ChunkedArray where - T: PolarsDataType, + T: PolarsDataType, T::Array: for<'a> IfThenElseKernel = T::Physical<'a>>, ChunkedArray: ChunkExpandAtIndex, { @@ -206,3 +206,104 @@ impl IfThenElseKernel for ObjectArray { .collect_arr() } } + +#[cfg(feature = "dtype-struct")] +impl ChunkZip for StructChunked { + fn zip_with( + &self, + mask: &BooleanChunked, + other: &ChunkedArray, + ) -> PolarsResult> { + let (l, r, mask) = align_chunks_ternary(self, other, mask); + + // Prepare the boolean arrays such that Null maps to false. + // This prevents every field doing that. + // # SAFETY + // We don't modify the length and update the null count. + let mut mask = mask.into_owned(); + unsafe { + for arr in mask.downcast_iter_mut() { + let bm = bool_null_to_false(arr); + *arr = BooleanArray::from_data_default(bm, None); + } + mask.set_null_count(0); + } + + // Zip all the fields. + let fields = l + .fields_as_series() + .iter() + .zip(r.fields_as_series()) + .map(|(lhs, rhs)| lhs.zip_with_same_type(&mask, &rhs)) + .collect::>>()?; + + let mut out = StructChunked::from_series(self.name(), &fields)?; + + // Zip the validities. + if (l.null_count + r.null_count) > 0 { + let validities = l + .chunks() + .iter() + .zip(r.chunks()) + .map(|(l, r)| (l.validity(), r.validity())); + + fn broadcast(v: Option<&Bitmap>, arr: &ArrayRef) -> Bitmap { + if v.unwrap().get(0).unwrap() { + Bitmap::new_with_value(true, arr.len()) + } else { + Bitmap::new_zeroed(arr.len()) + } + } + + // # SAFETY + // We don't modify the length and update the null count. + unsafe { + for ((arr, (lv, rv)), mask) in out + .chunks_mut() + .iter_mut() + .zip(validities) + .zip(mask.downcast_iter()) + { + // TODO! we can optimize this and use a kernel that is able to broadcast wo/ allocating. + let (lv, rv) = match (lv.map(|b| b.len()), rv.map(|b| b.len())) { + (Some(1), Some(1)) if arr.len() != 1 => { + let lv = broadcast(lv, arr); + let rv = broadcast(rv, arr); + (Some(lv), Some(rv)) + }, + (Some(a), Some(b)) if a == b => (lv.cloned(), rv.cloned()), + (Some(1), _) => { + let lv = broadcast(lv, arr); + (Some(lv), rv.cloned()) + }, + (_, Some(1)) => { + let rv = broadcast(rv, arr); + (lv.cloned(), Some(rv)) + }, + (None, Some(_)) | (Some(_), None) | (None, None) => { + (lv.cloned(), rv.cloned()) + }, + (Some(a), Some(b)) => { + polars_bail!(InvalidOperation: "got different sizes in 'zip' operation, got length: {a} and {b}") + }, + }; + + // broadcast mask + let validity = if mask.len() != arr.len() && mask.len() == 1 { + if mask.get(0).unwrap() { + lv + } else { + rv + } + } else { + if_then_else_validity(mask.values(), lv.as_ref(), rv.as_ref()) + }; + + *arr = arr.with_validity(validity); + } + } + out.compute_len(); + } + Ok(out) + } +} diff --git a/crates/polars-core/src/series/implementations/struct__.rs b/crates/polars-core/src/series/implementations/struct__.rs index a6c775a4245d..9c565bf43b49 100644 --- a/crates/polars-core/src/series/implementations/struct__.rs +++ b/crates/polars-core/src/series/implementations/struct__.rs @@ -56,15 +56,9 @@ impl PrivateSeries for SeriesWrap { #[cfg(feature = "zip_with")] fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult { - let other = other.struct_()?; - let fields = self - .0 - .fields_as_series() - .iter() - .zip(other.fields_as_series()) - .map(|(lhs, rhs)| lhs.zip_with_same_type(mask, &rhs)) - .collect::>>()?; - StructChunked::from_series(self.0.name(), &fields).map(|ca| ca.into_series()) + self.0 + .zip_with(mask, other.struct_()?) + .map(|ca| ca.into_series()) } #[cfg(feature = "algorithm_group_by")] diff --git a/crates/polars-core/src/series/ops/null.rs b/crates/polars-core/src/series/ops/null.rs index d13ce699cbad..0f46af8065bb 100644 --- a/crates/polars-core/src/series/ops/null.rs +++ b/crates/polars-core/src/series/ops/null.rs @@ -1,3 +1,5 @@ +use arrow::bitmap::Bitmap; + #[cfg(feature = "object")] use crate::chunked_array::object::registry::get_object_builder; use crate::prelude::*; @@ -53,9 +55,14 @@ impl Series { .iter() .map(|fld| Series::full_null(fld.name(), size, fld.data_type())) .collect::>(); - StructChunked::from_series(name, &fields) - .unwrap() - .into_series() + let ca = StructChunked::from_series(name, &fields).unwrap(); + + if !fields.is_empty() { + ca.with_outer_validity(Some(Bitmap::new_zeroed(size))) + .into_series() + } else { + ca.into_series() + } }, DataType::Null => Series::new_null(name, size), DataType::Unknown(kind) => { diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index 265cc71d07c4..1351af61d582 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -623,7 +623,7 @@ def test_struct_categorical_5843() -> None: def test_empty_struct() -> None: # List df = pl.DataFrame({"a": [[{}]]}) - assert df.to_dict(as_series=False) == {"a": [[{"": None}]]} + assert df.to_dict(as_series=False) == {"a": [[None]]} # Struct one not empty df = pl.DataFrame({"a": [[{}, {"a": 10}]]}) @@ -631,7 +631,7 @@ def test_empty_struct() -> None: # Empty struct df = pl.DataFrame({"a": [{}]}) - assert df.to_dict(as_series=False) == {"a": [{"": None}]} + assert df.to_dict(as_series=False) == {"a": [None]} @pytest.mark.parametrize( @@ -710,7 +710,7 @@ def test_struct_null_cast() -> None: .lazy() .select([pl.lit(None, dtype=pl.Null).cast(dtype, strict=True)]) .collect() - ).to_dict(as_series=False) == {"literal": [{"a": None, "b": None, "c": None}]} + ).to_dict(as_series=False) == {"literal": [None]} def test_nested_struct_in_lists_cast() -> None: @@ -976,3 +976,23 @@ def test_named_exprs() -> None: res = df.select(pl.struct(schema=schema, b=pl.col("a"))) assert res.to_dict(as_series=False) == {"b": [{"b": 1}]} assert res.schema["b"] == pl.Struct(schema) + + +def test_struct_outer_nullability_zip_18119() -> None: + df = pl.Series("int", [0, 1, 2, 3], dtype=pl.Int64).to_frame() + assert df.lazy().with_columns( + result=pl.when(pl.col("int") >= 1).then( + pl.struct( + a=pl.when(pl.col("int") % 2 == 1).then(True), + b=pl.when(pl.col("int") >= 2).then(False), + ) + ) + ).collect().to_dict(as_series=False) == { + "int": [0, 1, 2, 3], + "result": [ + None, + {"a": True, "b": None}, + {"a": None, "b": False}, + {"a": True, "b": False}, + ], + } diff --git a/py-polars/tests/unit/test_queries.py b/py-polars/tests/unit/test_queries.py index 5ddf46531840..4c50a183af49 100644 --- a/py-polars/tests/unit/test_queries.py +++ b/py-polars/tests/unit/test_queries.py @@ -241,8 +241,8 @@ def map_expr(name: str) -> pl.Expr: ).to_dict(as_series=False) == { "groups": [1, 2, 3, 4], "out": [ - {"sum": None, "count": None}, - {"sum": None, "count": None}, + None, + None, {"sum": 1, "count": 1}, {"sum": 2, "count": 1}, ],