Skip to content

Commit

Permalink
fix(rust, python): treat null columns as zero in sum_horizontal (po…
Browse files Browse the repository at this point in the history
  • Loading branch information
edavisau authored and r-brink committed Jan 24, 2024
1 parent e4c1e05 commit fc3761d
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 14 deletions.
43 changes: 29 additions & 14 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2464,32 +2464,47 @@ impl DataFrame {

/// Aggregate the column horizontally to their sum values.
pub fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Series>> {
let sum_fn =
|acc: &Series, s: &Series, null_strategy: NullStrategy| -> PolarsResult<Series> {
let mut acc = acc.clone();
let mut s = s.clone();
let apply_null_strategy =
|s: &Series, null_strategy: NullStrategy| -> PolarsResult<Series> {
if let NullStrategy::Ignore = null_strategy {
// if has nulls
if acc.has_validity() {
acc = acc.fill_null(FillNullStrategy::Zero)?;
}
if s.has_validity() {
s = s.fill_null(FillNullStrategy::Zero)?;
return s.fill_null(FillNullStrategy::Zero);
}
}
Ok(s.clone())
};

let sum_fn =
|acc: &Series, s: &Series, null_strategy: NullStrategy| -> PolarsResult<Series> {
let acc: Series = apply_null_strategy(acc, null_strategy)?;
let s = apply_null_strategy(s, null_strategy)?;
Ok(&acc + &s)
};

match self.columns.len() {
0 => Ok(None),
1 => Ok(Some(self.columns[0].clone())),
2 => sum_fn(&self.columns[0], &self.columns[1], null_strategy).map(Some),
let non_null_cols = self
.columns
.iter()
.filter(|x| x.dtype() != &DataType::Null)
.collect::<Vec<_>>();

match non_null_cols.len() {
0 => {
if self.columns.is_empty() {
Ok(None)
} else {
// all columns are null dtype, so result is null dtype
Ok(Some(self.columns[0].clone()))
}
},
1 => Ok(Some(apply_null_strategy(non_null_cols[0], null_strategy)?)),
2 => sum_fn(non_null_cols[0], non_null_cols[1], null_strategy).map(Some),
_ => {
// the try_reduce_with is a bit slower in parallelism,
// but I don't think it matters here as we parallelize over columns, not over elements
POOL.install(|| {
self.columns
.par_iter()
non_null_cols
.into_par_iter()
.map(|s| Ok(Cow::Borrowed(s)))
.try_reduce_with(|l, r| sum_fn(&l, &r, null_strategy).map(Cow::Owned))
// we can unwrap the option, because we are certain there is a column
Expand Down
35 changes: 35 additions & 0 deletions py-polars/tests/unit/functions/aggregation/test_horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,41 @@ def test_str_sum_horizontal() -> None:
assert_series_equal(out["A"], pl.Series("A", ["af", "bg", "h", "c", ""]))


def test_sum_null_dtype() -> None:
df = pl.DataFrame(
{
"A": [5, None, 3, 2, 1],
"B": [5, 3, None, 2, 1],
"C": [None, None, None, None, None],
}
)

assert_series_equal(
df.select(pl.sum_horizontal("A", "B", "C")).to_series(),
pl.Series("A", [10, 3, 3, 4, 2]),
)
assert_series_equal(
df.select(pl.sum_horizontal("C", "B")).to_series(),
pl.Series("C", [5, 3, 0, 2, 1]),
)
assert_series_equal(
df.select(pl.sum_horizontal("C", "C")).to_series(),
pl.Series("C", [None, None, None, None, None]),
)


def test_sum_single_col() -> None:
df = pl.DataFrame(
{
"A": [5, None, 3, None, 1],
}
)

assert_series_equal(
df.select(pl.sum_horizontal("A")).to_series(), pl.Series("A", [5, 0, 3, 0, 1])
)


def test_cum_sum_horizontal() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit fc3761d

Please sign in to comment.