Skip to content

Commit

Permalink
fix: Struct outer nullabillity (#18156)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Aug 13, 2024
1 parent 1b5fa4c commit c8fde41
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 22 deletions.
2 changes: 1 addition & 1 deletion crates/polars-compute/src/if_then_else/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl<T: NotSimdPrimitive> IfThenElseKernel for PrimitiveArray<T> {
}
}

fn if_then_else_validity(
pub fn if_then_else_validity(
mask: &Bitmap,
if_true: Option<&Bitmap>,
if_false: Option<&Bitmap>,
Expand Down
6 changes: 4 additions & 2 deletions crates/polars-core/src/chunked_array/ops/full.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use arrow::bitmap::MutableBitmap;
use arrow::bitmap::{Bitmap, MutableBitmap};

use crate::chunked_array::builder::get_list_builder;
use crate::prelude::*;
Expand Down Expand Up @@ -189,7 +189,9 @@ impl ListChunked {
impl ChunkFullNull for StructChunked {
fn full_null(name: &str, length: usize) -> StructChunked {
let s = vec![Series::new_null("", length)];
StructChunked::from_series(name, &s).unwrap()
StructChunked::from_series(name, &s)
.unwrap()
.with_outer_validity(Some(Bitmap::new_zeroed(length)))
}
}

Expand Down
105 changes: 103 additions & 2 deletions crates/polars-core/src/chunked_array/ops/zip.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use arrow::bitmap::Bitmap;
use arrow::compute::utils::{combine_validities_and, combine_validities_and_not};
use polars_compute::if_then_else::IfThenElseKernel;
use polars_compute::if_then_else::{if_then_else_validity, IfThenElseKernel};

#[cfg(feature = "object")]
use crate::chunked_array::object::ObjectArray;
Expand Down Expand Up @@ -62,7 +62,7 @@ fn combine_validities_chunked<

impl<T> ChunkZip<T> for ChunkedArray<T>
where
T: PolarsDataType,
T: PolarsDataType<IsStruct = FalseT>,
T::Array: for<'a> IfThenElseKernel<Scalar<'a> = T::Physical<'a>>,
ChunkedArray<T>: ChunkExpandAtIndex<T>,
{
Expand Down Expand Up @@ -206,3 +206,104 @@ impl<T: PolarsObject> IfThenElseKernel for ObjectArray<T> {
.collect_arr()
}
}

#[cfg(feature = "dtype-struct")]
impl ChunkZip<StructType> for StructChunked {
fn zip_with(
&self,
mask: &BooleanChunked,
other: &ChunkedArray<StructType>,
) -> PolarsResult<ChunkedArray<StructType>> {
let (l, r, mask) = align_chunks_ternary(self, other, mask);

// Prepare the boolean arrays such that Null maps to false.
// This prevents every field doing that.
// # SAFETY
// We don't modify the length and update the null count.
let mut mask = mask.into_owned();
unsafe {
for arr in mask.downcast_iter_mut() {
let bm = bool_null_to_false(arr);
*arr = BooleanArray::from_data_default(bm, None);
}
mask.set_null_count(0);
}

// Zip all the fields.
let fields = l
.fields_as_series()
.iter()
.zip(r.fields_as_series())
.map(|(lhs, rhs)| lhs.zip_with_same_type(&mask, &rhs))
.collect::<PolarsResult<Vec<_>>>()?;

let mut out = StructChunked::from_series(self.name(), &fields)?;

// Zip the validities.
if (l.null_count + r.null_count) > 0 {
let validities = l
.chunks()
.iter()
.zip(r.chunks())
.map(|(l, r)| (l.validity(), r.validity()));

fn broadcast(v: Option<&Bitmap>, arr: &ArrayRef) -> Bitmap {
if v.unwrap().get(0).unwrap() {
Bitmap::new_with_value(true, arr.len())
} else {
Bitmap::new_zeroed(arr.len())
}
}

// # SAFETY
// We don't modify the length and update the null count.
unsafe {
for ((arr, (lv, rv)), mask) in out
.chunks_mut()
.iter_mut()
.zip(validities)
.zip(mask.downcast_iter())
{
// TODO! we can optimize this and use a kernel that is able to broadcast wo/ allocating.
let (lv, rv) = match (lv.map(|b| b.len()), rv.map(|b| b.len())) {
(Some(1), Some(1)) if arr.len() != 1 => {
let lv = broadcast(lv, arr);
let rv = broadcast(rv, arr);
(Some(lv), Some(rv))
},
(Some(a), Some(b)) if a == b => (lv.cloned(), rv.cloned()),
(Some(1), _) => {
let lv = broadcast(lv, arr);
(Some(lv), rv.cloned())
},
(_, Some(1)) => {
let rv = broadcast(rv, arr);
(lv.cloned(), Some(rv))
},
(None, Some(_)) | (Some(_), None) | (None, None) => {
(lv.cloned(), rv.cloned())
},
(Some(a), Some(b)) => {
polars_bail!(InvalidOperation: "got different sizes in 'zip' operation, got length: {a} and {b}")
},
};

// broadcast mask
let validity = if mask.len() != arr.len() && mask.len() == 1 {
if mask.get(0).unwrap() {
lv
} else {
rv
}
} else {
if_then_else_validity(mask.values(), lv.as_ref(), rv.as_ref())
};

*arr = arr.with_validity(validity);
}
}
out.compute_len();
}
Ok(out)
}
}
12 changes: 3 additions & 9 deletions crates/polars-core/src/series/implementations/struct__.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,9 @@ impl PrivateSeries for SeriesWrap<StructChunked> {

#[cfg(feature = "zip_with")]
fn zip_with_same_type(&self, mask: &BooleanChunked, other: &Series) -> PolarsResult<Series> {
let other = other.struct_()?;
let fields = self
.0
.fields_as_series()
.iter()
.zip(other.fields_as_series())
.map(|(lhs, rhs)| lhs.zip_with_same_type(mask, &rhs))
.collect::<PolarsResult<Vec<_>>>()?;
StructChunked::from_series(self.0.name(), &fields).map(|ca| ca.into_series())
self.0
.zip_with(mask, other.struct_()?)
.map(|ca| ca.into_series())
}

#[cfg(feature = "algorithm_group_by")]
Expand Down
13 changes: 10 additions & 3 deletions crates/polars-core/src/series/ops/null.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use arrow::bitmap::Bitmap;

#[cfg(feature = "object")]
use crate::chunked_array::object::registry::get_object_builder;
use crate::prelude::*;
Expand Down Expand Up @@ -53,9 +55,14 @@ impl Series {
.iter()
.map(|fld| Series::full_null(fld.name(), size, fld.data_type()))
.collect::<Vec<_>>();
StructChunked::from_series(name, &fields)
.unwrap()
.into_series()
let ca = StructChunked::from_series(name, &fields).unwrap();

if !fields.is_empty() {
ca.with_outer_validity(Some(Bitmap::new_zeroed(size)))
.into_series()
} else {
ca.into_series()
}
},
DataType::Null => Series::new_null(name, size),
DataType::Unknown(kind) => {
Expand Down
26 changes: 23 additions & 3 deletions py-polars/tests/unit/datatypes/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,15 +623,15 @@ def test_struct_categorical_5843() -> None:
def test_empty_struct() -> None:
# List<struct>
df = pl.DataFrame({"a": [[{}]]})
assert df.to_dict(as_series=False) == {"a": [[{"": None}]]}
assert df.to_dict(as_series=False) == {"a": [[None]]}

# Struct one not empty
df = pl.DataFrame({"a": [[{}, {"a": 10}]]})
assert df.to_dict(as_series=False) == {"a": [[{"a": None}, {"a": 10}]]}

# Empty struct
df = pl.DataFrame({"a": [{}]})
assert df.to_dict(as_series=False) == {"a": [{"": None}]}
assert df.to_dict(as_series=False) == {"a": [None]}


@pytest.mark.parametrize(
Expand Down Expand Up @@ -710,7 +710,7 @@ def test_struct_null_cast() -> None:
.lazy()
.select([pl.lit(None, dtype=pl.Null).cast(dtype, strict=True)])
.collect()
).to_dict(as_series=False) == {"literal": [{"a": None, "b": None, "c": None}]}
).to_dict(as_series=False) == {"literal": [None]}


def test_nested_struct_in_lists_cast() -> None:
Expand Down Expand Up @@ -976,3 +976,23 @@ def test_named_exprs() -> None:
res = df.select(pl.struct(schema=schema, b=pl.col("a")))
assert res.to_dict(as_series=False) == {"b": [{"b": 1}]}
assert res.schema["b"] == pl.Struct(schema)


def test_struct_outer_nullability_zip_18119() -> None:
df = pl.Series("int", [0, 1, 2, 3], dtype=pl.Int64).to_frame()
assert df.lazy().with_columns(
result=pl.when(pl.col("int") >= 1).then(
pl.struct(
a=pl.when(pl.col("int") % 2 == 1).then(True),
b=pl.when(pl.col("int") >= 2).then(False),
)
)
).collect().to_dict(as_series=False) == {
"int": [0, 1, 2, 3],
"result": [
None,
{"a": True, "b": None},
{"a": None, "b": False},
{"a": True, "b": False},
],
}
4 changes: 2 additions & 2 deletions py-polars/tests/unit/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ def map_expr(name: str) -> pl.Expr:
).to_dict(as_series=False) == {
"groups": [1, 2, 3, 4],
"out": [
{"sum": None, "count": None},
{"sum": None, "count": None},
None,
None,
{"sum": 1, "count": 1},
{"sum": 2, "count": 1},
],
Expand Down

0 comments on commit c8fde41

Please sign in to comment.