Skip to content

Commit

Permalink
fix: fix broadcasting in binary elementwise apply
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Nov 28, 2023
1 parent 38d016b commit df1b21d
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 3 deletions.
188 changes: 186 additions & 2 deletions crates/polars-core/src/chunked_array/ops/arity.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::error::Error;

use arrow::array::Array;
use arrow::legacy::utils::combine_validities_and;
use arrow::array::{new_null_array, Array};
use arrow::legacy::utils::{combine_validities_and, CustomIterTools};

use crate::datatypes::{ArrayCollectIterExt, ArrayFromIter, StaticArray};
use crate::prelude::{ChunkedArray, PolarsDataType};
Expand All @@ -27,6 +27,58 @@ impl<A1, A2, R, T: FnMut(A1, A2) -> R> BinaryFnMut<A1, A2> for T {
type Ret = R;
}

/// Zips two or more iterators together, broadcasting every iterator whose
/// length is exactly 1, and hands the zipped iterator to the macro `$func`.
///
/// Usage: `broadcast_apply!(func_macro, iter_a, iter_b [, iter_c ...])`
///
/// * `$func` is the *name of a macro* that is invoked as `$func!(ident)`, where
///   `ident` is bound to the final zipped iterator. With more than two inputs
///   the items are nested to the left: `((a, b), c)`.
/// * Every iterator argument must be a `mut` binding that provides `.len()`
///   and `.next()`; its items must be `Clone` because a length-1 input is
///   replaced by `std::iter::repeat` of its single item.
///
/// Precondition: at least one iterator must have a length != 1. If every input
/// has length 1, all of them are turned into endless `repeat` iterators and
/// whatever `$func` does with the result will never terminate.
macro_rules! broadcast_apply {
    // Terminal step: combine the accumulated iterator with the last input and
    // pass the result to `$func`.
    (inner $func:ident, $iter_prev:ident, $iter_curr:ident) => {
        {
            let prev = $iter_prev;

            if $iter_curr.len() == 1 {
                // Length-1 input: repeat its only item alongside `prev`.
                let curr = prev.zip(std::iter::repeat($iter_curr.next().unwrap()));

                $func!(curr)
            } else {
                let curr = prev.zip($iter_curr);

                $func!(curr)
            }
        }
    };

    // Recursive step: combine the accumulated iterator with the next input,
    // then continue with the remaining inputs.
    (inner $func:ident, $iter_prev:ident, $iter_curr:ident $(, $more:ident)+) => {
        {
            let prev = $iter_prev;

            if $iter_curr.len() == 1 {
                let curr = prev.zip(std::iter::repeat($iter_curr.next().unwrap()));

                broadcast_apply!(inner $func, curr, $($more),*)
            } else {
                let curr = prev.zip($iter_curr);

                broadcast_apply!(inner $func, curr, $($more),*)
            }
        }
    };

    // Entry point: broadcast the first input if needed, then fold in the rest.
    ($func:ident, $iter_first:ident $(, $more:ident)+) => {
        {
            if $iter_first.len() == 1 {
                // Shadow the caller's binding with the endless repeat of its single item.
                let $iter_first = std::iter::repeat($iter_first.next().unwrap());
                broadcast_apply!(inner $func, $iter_first, $($more),*)
            } else {
                broadcast_apply!(inner $func, $iter_first, $($more),*)
            }
        }
    };
}

#[inline]
pub fn binary_elementwise<T, U, V, F>(
lhs: &ChunkedArray<T>,
Expand All @@ -42,6 +94,35 @@ where
<F as BinaryFnMut<Option<T::Physical<'a>>, Option<U::Physical<'a>>>>::Ret,
>,
{
// Broadcast path: taken only when the two inputs differ in length. One side is
// then expected to have length 1; this is verified by debug assertions only.
if lhs.len() != rhs.len() {
let broadcast_to = lhs.len().max(rhs.len());

debug_assert!(lhs.len() == 1 || lhs.len() == broadcast_to);
debug_assert!(rhs.len() == 1 || rhs.len() == broadcast_to);

// Flattens every chunk of `$ca` into a single iterator of per-element items.
macro_rules! ca_to_iter {
($ca:ident) => {
// SAFETY (presumed): `trust_my_length` is given `$ca.len()` as the exact
// number of items; this relies on the chunk lengths summing to that
// value. TODO confirm against `CustomIterTools::trust_my_length`.
unsafe {
$ca.downcast_iter()
.flat_map(|arr| arr.iter())
.trust_my_length($ca.len())
}
};
}

// Callback macro for `broadcast_apply!`: applies `op` to each zipped pair and
// collects the results into an array.
macro_rules! collect_func {
($x:ident) => {
$x.map(|(a, b)| op(a, b)).collect_arr()
};
}

// `mut` is required: `broadcast_apply!` calls `.next()` on a length-1 iterator.
let mut a = ca_to_iter!(lhs);
let mut b = ca_to_iter!(rhs);
let out: V::Array = broadcast_apply!(collect_func, a, b);

// The result is a single chunk that takes its name from `lhs`.
return ChunkedArray::with_chunk(lhs.name(), out);
}

let (lhs, rhs) = align_chunks_binary(lhs, rhs);
let iter = lhs
.downcast_iter()
Expand Down Expand Up @@ -119,6 +200,35 @@ where
F: for<'a> FnMut(Option<T::Physical<'a>>, Option<U::Physical<'a>>) -> Result<Option<K>, E>,
V::Array: ArrayFromIter<Option<K>>,
{
// Broadcast path: taken only when the two inputs differ in length. One side is
// then expected to have length 1; this is verified by debug assertions only.
if lhs.len() != rhs.len() {
let broadcast_to = lhs.len().max(rhs.len());

debug_assert!(lhs.len() == 1 || lhs.len() == broadcast_to);
debug_assert!(rhs.len() == 1 || rhs.len() == broadcast_to);

// Flattens every chunk of `$ca` into a single iterator of per-element items.
macro_rules! ca_to_iter {
($ca:ident) => {
// SAFETY (presumed): `trust_my_length` is given `$ca.len()` as the exact
// number of items; this relies on the chunk lengths summing to that
// value. TODO confirm against `CustomIterTools::trust_my_length`.
unsafe {
$ca.downcast_iter()
.flat_map(|arr| arr.iter())
.trust_my_length($ca.len())
}
};
}

// Callback macro for `broadcast_apply!`: applies the fallible `op` to each
// zipped pair and collects into a `Result`, which is unwrapped with `?` below.
macro_rules! collect_func {
($x:ident) => {
$x.map(|(a, b)| op(a, b)).try_collect_arr()
};
}

// `mut` is required: `broadcast_apply!` calls `.next()` on a length-1 iterator.
let mut a = ca_to_iter!(lhs);
let mut b = ca_to_iter!(rhs);
let out: V::Array = broadcast_apply!(collect_func, a, b)?;

// The result is a single chunk that takes its name from `lhs`.
return Ok(ChunkedArray::with_chunk(lhs.name(), out));
}

let (lhs, rhs) = align_chunks_binary(lhs, rhs);
let iter = lhs
.downcast_iter()
Expand Down Expand Up @@ -146,6 +256,43 @@ where
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> K,
V::Array: ArrayFromIter<K>,
{
// Broadcast path: taken only when the two inputs differ in length. One side is
// then expected to have length 1; this is verified by debug assertions only.
if lhs.len() != rhs.len() {
let broadcast_to = lhs.len().max(rhs.len());

debug_assert!(lhs.len() == 1 || lhs.len() == broadcast_to);
debug_assert!(rhs.len() == 1 || rhs.len() == broadcast_to);

// Short-circuit: if the length-1 side is a null, `op` is never called and a
// null array of the broadcast length is returned instead.
if (lhs.len() == 1 && !lhs.downcast_iter().next().unwrap().is_valid(0))
|| (rhs.len() == 1 && !rhs.downcast_iter().next().unwrap().is_valid(0))
{
let arr = &*new_null_array(V::get_dtype().to_arrow(), broadcast_to);
// NOTE(review): `ptr::read` makes a bitwise copy of the array while the
// boxed original borrowed on the previous line is still owned and will be
// dropped at the end of this block. That looks like two owners of the same
// buffers (double drop). It also assumes the boxed array's concrete type is
// exactly `V::Array`. Confirm soundness or replace with a checked downcast.
let arr = unsafe { std::ptr::read(arr as *const dyn Array as *const V::Array) };
return ChunkedArray::with_chunk(lhs.name(), arr);
}

// Flattens every chunk of `$ca` into a single iterator over its values.
// NOTE(review): `values_iter()` presumably yields raw values without
// consulting validity, so nulls in the longer input would not be carried into
// the output of this path; verify against the equal-length path.
macro_rules! ca_to_iter {
($ca:ident) => {
// SAFETY (presumed): `trust_my_length` is given `$ca.len()` as the exact
// number of items; this relies on the chunk lengths summing to that
// value. TODO confirm against `CustomIterTools::trust_my_length`.
unsafe {
$ca.downcast_iter()
.flat_map(|arr| arr.values_iter())
.trust_my_length($ca.len())
}
};
}

// Callback macro for `broadcast_apply!`: applies `op` to each zipped pair and
// collects the results into an array.
macro_rules! collect_func {
($x:ident) => {
$x.map(|(a, b)| op(a, b)).collect_arr()
};
}

// `mut` is required: `broadcast_apply!` calls `.next()` on a length-1 iterator.
let mut a = ca_to_iter!(lhs);
let mut b = ca_to_iter!(rhs);
let out: V::Array = broadcast_apply!(collect_func, a, b);

// The result is a single chunk that takes its name from `lhs`.
return ChunkedArray::with_chunk(lhs.name(), out);
}

let (lhs, rhs) = align_chunks_binary(lhs, rhs);

let iter = lhs
Expand Down Expand Up @@ -178,6 +325,43 @@ where
F: for<'a> FnMut(T::Physical<'a>, U::Physical<'a>) -> Result<K, E>,
V::Array: ArrayFromIter<K>,
{
// Broadcast path: taken only when the two inputs differ in length. One side is
// then expected to have length 1; this is verified by debug assertions only.
if lhs.len() != rhs.len() {
let broadcast_to = lhs.len().max(rhs.len());

debug_assert!(lhs.len() == 1 || lhs.len() == broadcast_to);
debug_assert!(rhs.len() == 1 || rhs.len() == broadcast_to);

// Short-circuit: if the length-1 side is a null, `op` is never called and a
// null array of the broadcast length is returned instead.
if (lhs.len() == 1 && !lhs.downcast_iter().next().unwrap().is_valid(0))
|| (rhs.len() == 1 && !rhs.downcast_iter().next().unwrap().is_valid(0))
{
let arr = &*new_null_array(V::get_dtype().to_arrow(), broadcast_to);
// NOTE(review): `ptr::read` makes a bitwise copy of the array while the
// boxed original borrowed on the previous line is still owned and will be
// dropped at the end of this block. That looks like two owners of the same
// buffers (double drop). It also assumes the boxed array's concrete type is
// exactly `V::Array`. Confirm soundness or replace with a checked downcast.
let arr = unsafe { std::ptr::read(arr as *const dyn Array as *const V::Array) };
return Ok(ChunkedArray::with_chunk(lhs.name(), arr));
}

// Flattens every chunk of `$ca` into a single iterator over its values.
// NOTE(review): `values_iter()` presumably yields raw values without
// consulting validity, so nulls in the longer input would not be carried into
// the output of this path; verify against the equal-length path.
macro_rules! ca_to_iter {
($ca:ident) => {
// SAFETY (presumed): `trust_my_length` is given `$ca.len()` as the exact
// number of items; this relies on the chunk lengths summing to that
// value. TODO confirm against `CustomIterTools::trust_my_length`.
unsafe {
$ca.downcast_iter()
.flat_map(|arr| arr.values_iter())
.trust_my_length($ca.len())
}
};
}

// Callback macro for `broadcast_apply!`: applies the fallible `op` to each
// zipped pair and collects into a `Result`, which is unwrapped with `?` below.
macro_rules! collect_func {
($x:ident) => {
$x.map(|(a, b)| op(a, b)).try_collect_arr()
};
}

// `mut` is required: `broadcast_apply!` calls `.next()` on a length-1 iterator.
let mut a = ca_to_iter!(lhs);
let mut b = ca_to_iter!(rhs);
let out: V::Array = broadcast_apply!(collect_func, a, b)?;

// The result is a single chunk that takes its name from `lhs`.
return Ok(ChunkedArray::with_chunk(lhs.name(), out));
}

let (lhs, rhs) = align_chunks_binary(lhs, rhs);
let iter = lhs
.downcast_iter()
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ pub struct Flat;
///
/// The StaticArray and dtype return must be correct.
pub unsafe trait PolarsDataType: Send + Sync + Sized {
type Physical<'a>: std::fmt::Debug;
type Physical<'a>: std::fmt::Debug + Clone;
type ZeroablePhysical<'a>: Zeroable + From<Self::Physical<'a>>;
type Array: for<'a> StaticArray<
ValueT<'a> = Self::Physical<'a>,
Expand Down
32 changes: 32 additions & 0 deletions py-polars/tests/unit/test_arity.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,35 @@ def test_when_then_broadcast_nulls_12665() -> None:
assert df.select(
when=pl.when(pl.col("val") > pl.col("threshold")).then(1).otherwise(0),
).to_dict(as_series=False) == {"when": [0, 0, 0, 1]}


def test_broadcast_string_ops_12632() -> None:
    # A length-1 string expression (a literal, or a column cut down to one row)
    # must broadcast against the full "name" column in binary string operations.
    frame = pl.DataFrame(
        [
            {"name": "COMPANY A", "id": 1},
            {"name": "COMPANY B", "id": 2},
            {"name": "COMPANY C", "id": 3},
        ]
    )
    # Only the first row equals "COMPANY A".
    only_first_row = [True, False, False]
    length_one_exprs = (pl.lit("COMPANY A"), pl.col("name").head(1))

    for needs_broadcast in length_one_exprs:
        for literal in (True, False):
            contains = needs_broadcast.str.contains(pl.col("name"), literal=literal)
            assert frame.select(contains).to_series().to_list() == only_first_row

        starts = frame.select(needs_broadcast.str.starts_with(pl.col("name")))
        assert starts.to_series().to_list() == only_first_row

        ends = frame.select(needs_broadcast.str.ends_with(pl.col("name")))
        assert ends.to_series().to_list() == only_first_row

        # The strip variants are only checked for the broadcast output height.
        stripped = frame.select(needs_broadcast.str.strip_chars(pl.col("name")))
        assert stripped.height == 3

        stripped_start = frame.select(
            needs_broadcast.str.strip_chars_start(pl.col("name"))
        )
        assert stripped_start.height == 3

        stripped_end = frame.select(
            needs_broadcast.str.strip_chars_end(pl.col("name"))
        )
        assert stripped_end.height == 3

0 comments on commit df1b21d

Please sign in to comment.