Skip to content

Commit

Permalink
Fix strictness check in single case
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jun 20, 2024
1 parent 62463f8 commit 2dc0d1f
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
17 changes: 13 additions & 4 deletions crates/polars-ops/src/series/ops/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ fn replace_by_single(
/// Fast path for replacing by a single value in strict mode
fn replace_by_single_strict(s: &Series, old: &Series, new: &Series) -> PolarsResult<Series> {
let mask = get_replacement_mask(s, old)?;
ensure_all_replaced(&mask, s, old.null_count() > 0)?;
ensure_all_replaced(&mask, s, old.null_count() > 0, true)?;

let mut out = new.new_from_index(0, s.len());

Expand Down Expand Up @@ -224,7 +224,7 @@ fn replace_by_multiple_strict(s: &Series, old: Series, new: Series) -> PolarsRes
.unwrap()
.bool()
.unwrap();
ensure_all_replaced(mask, s, old_has_null)?;
ensure_all_replaced(mask, s, old_has_null, false)?;

Ok(replaced.clone())
}
Expand Down Expand Up @@ -254,12 +254,21 @@ fn validate_new(new: &Series, old: &Series) -> PolarsResult<()> {
}

/// Ensure that all values were replaced.
fn ensure_all_replaced(mask: &BooleanChunked, s: &Series, old_has_null: bool) -> PolarsResult<()> {
let all_replaced = if old_has_null {
fn ensure_all_replaced(
mask: &BooleanChunked,
s: &Series,
old_has_null: bool,
check_all: bool,
) -> PolarsResult<()> {
let nulls_check = if old_has_null {
mask.null_count() == 0
} else {
mask.null_count() == s.null_count()
};
// Checking booleans is only relevant for the 'replace_by_single' path.
let bools_check = !check_all || mask.all();

let all_replaced = bools_check && nulls_check;
polars_ensure!(
all_replaced,
InvalidOperation: "incomplete mapping specified for `replace_strict`\n\nHint: Pass a `default` value to set unmapped values."
Expand Down
42 changes: 42 additions & 0 deletions py-polars/tests/unit/operations/test_replace_strict.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,48 @@ def test_replace_strict_incomplete_mapping() -> None:
lf.select(pl.col("a").replace_strict({2: 200, 3: 300})).collect()


def test_replace_strict_incomplete_mapping_null_raises() -> None:
s = pl.Series("a", [1, 2, 2, None, None])
with pytest.raises(InvalidOperationError):
s.replace_strict({1: 10})


def test_replace_strict_mapping_null_not_specified() -> None:
s = pl.Series("a", [1, 2, 2, None, None])

result = s.replace_strict({1: 10, 2: 20})

expected = pl.Series("a", [10, 20, 20, None, None])
assert_series_equal(result, expected)


def test_replace_strict_mapping_null_specified() -> None:
s = pl.Series("a", [1, 2, 2, None, None])

result = s.replace_strict({1: 10, 2: 20, None: 0})

expected = pl.Series("a", [10, 20, 20, 0, 0])
assert_series_equal(result, expected)


def test_replace_strict_mapping_null_replace_by_null() -> None:
s = pl.Series("a", [1, 2, 2, None])

result = s.replace_strict({1: 10, 2: None, None: 0})

expected = pl.Series("a", [10, None, None, 0])
assert_series_equal(result, expected)


def test_replace_strict_mapping_null_with_default() -> None:
s = pl.Series("a", [1, 2, 2, None, None])

result = s.replace_strict({1: 10}, default=0)

expected = pl.Series("a", [10, 0, 0, 0, 0])
assert_series_equal(result, expected)


def test_replace_strict_empty() -> None:
lf = pl.LazyFrame({"a": [None, None]})
result = lf.select(pl.col("a").replace_strict({}))
Expand Down

0 comments on commit 2dc0d1f

Please sign in to comment.