Skip to content

Commit

Permalink
feat: Implement arithmetic operations for Null columns (#14107)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Jan 30, 2024
1 parent a648478 commit badc110
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
27 changes: 27 additions & 0 deletions crates/polars-core/src/series/implementations/null.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,23 @@ impl PrivateSeries for NullChunked {
ExplodeByOffsets::explode_by_offsets(self, offsets)
}

fn subtract(&self, _rhs: &Series) -> PolarsResult<Series> {
null_arithmetic(self, _rhs, "subtract")
}

fn add_to(&self, _rhs: &Series) -> PolarsResult<Series> {
null_arithmetic(self, _rhs, "add_to")
}
fn multiply(&self, _rhs: &Series) -> PolarsResult<Series> {
null_arithmetic(self, _rhs, "multiply")
}
fn divide(&self, _rhs: &Series) -> PolarsResult<Series> {
null_arithmetic(self, _rhs, "divide")
}
fn remainder(&self, _rhs: &Series) -> PolarsResult<Series> {
null_arithmetic(self, _rhs, "remainder")
}

#[cfg(feature = "algorithm_group_by")]
fn group_tuples(&self, _multithreaded: bool, _sorted: bool) -> PolarsResult<GroupsProxy> {
Ok(if self.is_empty() {
Expand All @@ -98,6 +115,16 @@ impl PrivateSeries for NullChunked {
}
}

fn null_arithmetic(lhs: &NullChunked, rhs: &Series, op: &str) -> PolarsResult<Series> {
let output_len = match (lhs.len(), rhs.len()) {
(1, len_r) => len_r,
(len_l, 1) => len_l,
(len_l, len_r) if len_l == len_r => len_l,
_ => polars_bail!(ComputeError: "Cannot {:?} two series of different lengths.", op),
};
Ok(NullChunked::new(lhs.name().into(), output_len).into_series())
}

impl SeriesTrait for NullChunked {
fn name(&self) -> &str {
self.name.as_ref()
Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/operations/arithmetic/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,3 +270,26 @@ def test_operator_arithmetic_with_nulls(op: Any) -> None:

assert_frame_equal(df_expected, op(df, None))
assert_series_equal(s_expected, op(s, None))


@pytest.mark.parametrize(
"op",
[
operator.add,
operator.mod,
operator.mul,
operator.sub,
],
)
def test_null_column_arithmetic(op: Any) -> None:
df = pl.DataFrame({"a": [None, None], "b": [None, None]})
expected_df = pl.DataFrame({"a": [None, None]})

output_df = df.select(op(pl.col("a"), pl.col("b")))
assert_frame_equal(expected_df, output_df)
# test broadcast right
output_df = df.select(op(pl.col("a"), pl.Series([None])))
assert_frame_equal(expected_df, output_df)
# test broadcast left
output_df = df.select(op(pl.Series("a", [None]), pl.col("a")))
assert_frame_equal(expected_df, output_df)

0 comments on commit badc110

Please sign in to comment.