Summary of this patch (free-form text ahead of the first "diff --git" header;
patch tools such as `git apply` skip it):

* crates/polars-core/src/frame/mod.rs, DataFrame::sum_horizontal
    - Null handling is factored into an `apply_null_strategy` closure. Under
      NullStrategy::Ignore, a series whose has_validity() is true is returned
      with fill_null(FillNullStrategy::Zero) applied; in every other case the
      series is cloned unchanged. `sum_fn` now calls it for both operands.
    - Columns whose dtype is DataType::Null are filtered out before summing,
      and the match is driven by the number of remaining columns.
    - No columns at all still yields Ok(None). If columns exist but every one
      of them has the Null dtype, the first column is returned as-is.
    - A single remaining column now goes through apply_null_strategy instead
      of being returned as a plain clone.
    - The parallel reduction iterates the filtered list (into_par_iter)
      rather than self.columns.

* py-polars/tests/unit/functions/aggregation/test_horizontal.py
    - test_sum_null_dtype: sums two integer columns together with an all-None
      column, a None column with an integer column, and a None column with
      itself.
    - test_sum_single_col: sum_horizontal over one column containing None
      values expects those values to come back as 0.

NOTE(review): the generic arguments in `PolarsResult<Option<Series>>`,
`PolarsResult<Series>` and `collect::<Vec<_>>()` were restored from how the
values are used inside the hunk; confirm against the upstream commit.

diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs
index a8290a9c01c1..c3ad0f780f6f 100644
--- a/crates/polars-core/src/frame/mod.rs
+++ b/crates/polars-core/src/frame/mod.rs
@@ -2464,32 +2464,47 @@ impl DataFrame {
 
     /// Aggregate the column horizontally to their sum values.
     pub fn sum_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Series>> {
-        let sum_fn =
-            |acc: &Series, s: &Series, null_strategy: NullStrategy| -> PolarsResult<Series> {
-                let mut acc = acc.clone();
-                let mut s = s.clone();
+        let apply_null_strategy =
+            |s: &Series, null_strategy: NullStrategy| -> PolarsResult<Series> {
                 if let NullStrategy::Ignore = null_strategy {
                     // if has nulls
-                    if acc.has_validity() {
-                        acc = acc.fill_null(FillNullStrategy::Zero)?;
-                    }
                     if s.has_validity() {
-                        s = s.fill_null(FillNullStrategy::Zero)?;
+                        return s.fill_null(FillNullStrategy::Zero);
                     }
                 }
+                Ok(s.clone())
+            };
+
+        let sum_fn =
+            |acc: &Series, s: &Series, null_strategy: NullStrategy| -> PolarsResult<Series> {
+                let acc: Series = apply_null_strategy(acc, null_strategy)?;
+                let s = apply_null_strategy(s, null_strategy)?;
                 Ok(&acc + &s)
             };
 
-        match self.columns.len() {
-            0 => Ok(None),
-            1 => Ok(Some(self.columns[0].clone())),
-            2 => sum_fn(&self.columns[0], &self.columns[1], null_strategy).map(Some),
+        let non_null_cols = self
+            .columns
+            .iter()
+            .filter(|x| x.dtype() != &DataType::Null)
+            .collect::<Vec<_>>();
+
+        match non_null_cols.len() {
+            0 => {
+                if self.columns.is_empty() {
+                    Ok(None)
+                } else {
+                    // all columns are null dtype, so result is null dtype
+                    Ok(Some(self.columns[0].clone()))
+                }
+            },
+            1 => Ok(Some(apply_null_strategy(non_null_cols[0], null_strategy)?)),
+            2 => sum_fn(non_null_cols[0], non_null_cols[1], null_strategy).map(Some),
             _ => {
                 // the try_reduce_with is a bit slower in parallelism,
                 // but I don't think it matters here as we parallelize over columns, not over elements
                 POOL.install(|| {
-                    self.columns
-                        .par_iter()
+                    non_null_cols
+                        .into_par_iter()
                         .map(|s| Ok(Cow::Borrowed(s)))
                         .try_reduce_with(|l, r| sum_fn(&l, &r, null_strategy).map(Cow::Owned))
                         // we can unwrap the option, because we are certain there is a column
diff --git a/py-polars/tests/unit/functions/aggregation/test_horizontal.py b/py-polars/tests/unit/functions/aggregation/test_horizontal.py
index cc88e506919e..d5af947cb73f 100644
--- a/py-polars/tests/unit/functions/aggregation/test_horizontal.py
+++ b/py-polars/tests/unit/functions/aggregation/test_horizontal.py
@@ -248,6 +248,41 @@ def test_str_sum_horizontal() -> None:
     assert_series_equal(out["A"], pl.Series("A", ["af", "bg", "h", "c", ""]))
 
 
+def test_sum_null_dtype() -> None:
+    df = pl.DataFrame(
+        {
+            "A": [5, None, 3, 2, 1],
+            "B": [5, 3, None, 2, 1],
+            "C": [None, None, None, None, None],
+        }
+    )
+
+    assert_series_equal(
+        df.select(pl.sum_horizontal("A", "B", "C")).to_series(),
+        pl.Series("A", [10, 3, 3, 4, 2]),
+    )
+    assert_series_equal(
+        df.select(pl.sum_horizontal("C", "B")).to_series(),
+        pl.Series("C", [5, 3, 0, 2, 1]),
+    )
+    assert_series_equal(
+        df.select(pl.sum_horizontal("C", "C")).to_series(),
+        pl.Series("C", [None, None, None, None, None]),
+    )
+
+
+def test_sum_single_col() -> None:
+    df = pl.DataFrame(
+        {
+            "A": [5, None, 3, None, 1],
+        }
+    )
+
+    assert_series_equal(
+        df.select(pl.sum_horizontal("A")).to_series(), pl.Series("A", [5, 0, 3, 0, 1])
+    )
+
+
 def test_cum_sum_horizontal() -> None:
     df = pl.DataFrame(
         {