From 5fbce482239cd9d7495c93902253107d843602dd Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Thu, 18 Jan 2024 05:13:35 +0800 Subject: [PATCH] feat: Impl `count_matches` for array namespace (#13675) Co-authored-by: Stijn de Gooijer --- crates/polars-lazy/Cargo.toml | 2 +- crates/polars-ops/Cargo.toml | 1 + .../src/chunked_array/array/count.rs | 45 +++++++++++++++++++ .../polars-ops/src/chunked_array/array/mod.rs | 2 + .../src/chunked_array/array/namespace.rs | 8 ++++ crates/polars-plan/Cargo.toml | 1 + crates/polars-plan/src/dsl/array.rs | 18 ++++++++ .../src/dsl/function_expr/array.rs | 21 +++++++++ .../polars-plan/src/dsl/function_expr/list.rs | 4 +- crates/polars/Cargo.toml | 1 + py-polars/Cargo.toml | 2 + .../source/reference/expressions/array.rst | 1 + .../docs/source/reference/series/array.rst | 1 + py-polars/polars/expr/array.py | 31 ++++++++++++- py-polars/polars/series/array.py | 24 +++++++++- py-polars/polars/series/list.py | 5 +-- py-polars/src/expr/array.rs | 5 +++ py-polars/tests/unit/datatypes/test_list.py | 3 -- .../tests/unit/namespaces/array/test_array.py | 17 +++++++ 19 files changed, 181 insertions(+), 11 deletions(-) create mode 100644 crates/polars-ops/src/chunked_array/array/count.rs diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 04d03ff42a42..639d6431e39f 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -81,7 +81,7 @@ sign = ["polars-plan/sign"] timezones = ["polars-plan/timezones"] list_gather = ["polars-ops/list_gather", "polars-plan/list_gather"] list_count = ["polars-ops/list_count", "polars-plan/list_count"] - +array_count = ["polars-ops/array_count", "polars-plan/array_count", "dtype-array"] true_div = ["polars-plan/true_div"] extract_jsonpath = ["polars-plan/extract_jsonpath", "polars-ops/extract_jsonpath"] diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index cb5e247a5087..40f98833c1d2 100644 --- a/crates/polars-ops/Cargo.toml +++ 
b/crates/polars-ops/Cargo.toml @@ -110,6 +110,7 @@ chunked_ids = ["polars-core/chunked_ids"] asof_join = ["polars-core/asof_join"] semi_anti_join = [] array_any_all = ["dtype-array"] +array_count = ["dtype-array"] list_gather = [] list_sets = [] list_any_all = [] diff --git a/crates/polars-ops/src/chunked_array/array/count.rs b/crates/polars-ops/src/chunked_array/array/count.rs new file mode 100644 index 000000000000..249301b189df --- /dev/null +++ b/crates/polars-ops/src/chunked_array/array/count.rs @@ -0,0 +1,45 @@ +use arrow::array::{Array, BooleanArray}; +use arrow::bitmap::utils::count_zeros; +use arrow::bitmap::Bitmap; +use arrow::legacy::utils::CustomIterTools; +use polars_core::prelude::arity::unary_mut_with_options; + +use super::*; + +pub fn array_count_matches(ca: &ArrayChunked, value: AnyValue) -> PolarsResult<Series> { + let value = Series::new("", [value]); + + let ca = ca.apply_to_inner(&|s| { + ChunkCompare::<&Series>::equal_missing(&s, &value).map(|ca| ca.into_series()) + })?; + let out = count_boolean_bits(&ca); + Ok(out.into_series()) +} + +fn count_boolean_bits(ca: &ArrayChunked) -> IdxCa { + unary_mut_with_options(ca, |arr| { + let inner_arr = arr.values(); + let mask = inner_arr.as_any().downcast_ref::<BooleanArray>().unwrap(); + assert_eq!(mask.null_count(), 0); + let out = count_bits_set(mask.values(), arr.len(), arr.size()); + IdxArr::from_data_default(out.into(), arr.validity().cloned()) + }) +} + +fn count_bits_set(values: &Bitmap, len: usize, width: usize) -> Vec<IdxSize> { + // Fast path where all bits are either set or unset.
+ if values.unset_bits() == values.len() { + return vec![0 as IdxSize; len]; + } else if values.unset_bits() == 0 { + return vec![width as IdxSize; len]; + } + + let (bits, bitmap_offset, _) = values.as_slice(); + + (0..len) + .map(|i| { + let set_ones = width - count_zeros(bits, bitmap_offset + i * width, width); + set_ones as IdxSize + }) + .collect_trusted() +} diff --git a/crates/polars-ops/src/chunked_array/array/mod.rs b/crates/polars-ops/src/chunked_array/array/mod.rs index 453559baf392..601a781ce0a5 100644 --- a/crates/polars-ops/src/chunked_array/array/mod.rs +++ b/crates/polars-ops/src/chunked_array/array/mod.rs @@ -1,5 +1,7 @@ #[cfg(feature = "array_any_all")] mod any_all; +#[cfg(feature = "array_count")] +mod count; mod get; mod join; mod min_max; diff --git a/crates/polars-ops/src/chunked_array/array/namespace.rs b/crates/polars-ops/src/chunked_array/array/namespace.rs index c4fac7ac6214..561d8f80f2b2 100644 --- a/crates/polars-ops/src/chunked_array/array/namespace.rs +++ b/crates/polars-ops/src/chunked_array/array/namespace.rs @@ -1,5 +1,7 @@ use super::min_max::AggType; use super::*; +#[cfg(feature = "array_count")] +use crate::chunked_array::array::count::array_count_matches; use crate::chunked_array::array::sum_mean::sum_with_nulls; #[cfg(feature = "array_any_all")] use crate::prelude::array::any_all::{array_all, array_any}; @@ -104,6 +106,12 @@ pub trait ArrayNameSpace: AsArray { let ca = self.as_array(); array_join(ca, separator).map(|ok| ok.into_series()) } + + #[cfg(feature = "array_count")] + fn array_count_matches(&self, element: AnyValue) -> PolarsResult<Series> { + let ca = self.as_array(); + array_count_matches(ca, element) + } } impl ArrayNameSpace for ArrayChunked {} diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 1b546a776656..e7831edaf5a2 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -80,6 +80,7 @@ object = ["polars-core/object"] date_offset = ["polars-time", "chrono"]
list_gather = ["polars-ops/list_gather"] list_count = ["polars-ops/list_count"] +array_count = ["polars-ops/array_count", "dtype-array"] trigonometry = [] sign = [] timezones = ["chrono-tz", "polars-time/timezones", "polars-core/timezones", "regex"] diff --git a/crates/polars-plan/src/dsl/array.rs b/crates/polars-plan/src/dsl/array.rs index b8b49299c702..a981d04b1519 100644 --- a/crates/polars-plan/src/dsl/array.rs +++ b/crates/polars-plan/src/dsl/array.rs @@ -114,4 +114,22 @@ impl ArrayNameSpace { false, ) } + + #[cfg(feature = "array_count")] + /// Count how often the value produced by ``element`` occurs. + pub fn count_matches<E: Into<Expr>>(self, element: E) -> Expr { + let other = element.into(); + + self.0 + .map_many_private( + FunctionExpr::ArrayExpr(ArrayFunction::CountMatches), + &[other], + false, + false, + ) + .with_function_options(|mut options| { + options.input_wildcard_expansion = true; + options + }) + } } diff --git a/crates/polars-plan/src/dsl/function_expr/array.rs b/crates/polars-plan/src/dsl/function_expr/array.rs index 2bd9626b53f7..1e36e32bf9dd 100644 --- a/crates/polars-plan/src/dsl/function_expr/array.rs +++ b/crates/polars-plan/src/dsl/function_expr/array.rs @@ -23,6 +23,8 @@ pub enum ArrayFunction { Join, #[cfg(feature = "is_in")] Contains, + #[cfg(feature = "array_count")] + CountMatches, } impl ArrayFunction { @@ -42,6 +44,8 @@ impl ArrayFunction { Join => mapper.with_dtype(DataType::String), #[cfg(feature = "is_in")] Contains => mapper.with_dtype(DataType::Boolean), + #[cfg(feature = "array_count")] + CountMatches => mapper.with_dtype(IDX_DTYPE), } } } @@ -75,6 +79,8 @@ impl Display for ArrayFunction { Join => "join", #[cfg(feature = "is_in")] Contains => "contains", + #[cfg(feature = "array_count")] + CountMatches => "count_matches", }; write!(f, "arr.{name}") } @@ -101,6 +107,8 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn SeriesUdf>> { Join => map_as_slice!(join), #[cfg(feature = "is_in")] Contains => map_as_slice!(contains), + #[cfg(feature = "array_count")] + CountMatches
=> map_as_slice!(count_matches), } } } @@ -177,3 +185,16 @@ pub(super) fn contains(s: &[Series]) -> PolarsResult<Series> { let item = &s[1]; Ok(is_in(item, array)?.with_name(array.name()).into_series()) } + +#[cfg(feature = "array_count")] +pub(super) fn count_matches(args: &[Series]) -> PolarsResult<Series> { + let s = &args[0]; + let element = &args[1]; + polars_ensure!( + element.len() == 1, + ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}", + element.len() + ); + let ca = s.array()?; + ca.array_count_matches(element.get(0).unwrap()) +} diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 302cfcb6fed0..544a867a05d4 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -128,7 +128,7 @@ impl Display for ListFunction { #[cfg(feature = "list_gather")] Gather(_) => "gather", #[cfg(feature = "list_count")] - CountMatches => "count", + CountMatches => "count_matches", Sum => "sum", Min => "min", Max => "max", @@ -459,7 +459,7 @@ pub(super) fn count_matches(args: &[Series]) -> PolarsResult<Series> { let element = &args[1]; polars_ensure!( element.len() == 1, - ComputeError: "argument expression in `arr.count` must produce exactly one element, got {}", + ComputeError: "argument expression in `list.count_matches` must produce exactly one element, got {}", element.len() ); let ca = s.list()?; diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 396457769b26..13bae210ae94 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -155,6 +155,7 @@ is_unique = ["polars-lazy?/is_unique", "polars-ops/is_unique"] regex = ["polars-lazy?/regex"] list_any_all = ["polars-lazy?/list_any_all"] list_count = ["polars-ops/list_count", "polars-lazy?/list_count"] +array_count = ["polars-ops/array_count", "polars-lazy?/array_count", "dtype-array"] list_drop_nulls = ["polars-lazy?/list_drop_nulls"]
list_eval = ["polars-lazy?/list_eval"] list_gather = ["polars-ops/list_gather", "polars-lazy?/list_gather"] diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 61a1be46193d..a6d8002ab558 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -137,6 +137,7 @@ cse = ["polars/cse"] merge_sorted = ["polars/merge_sorted"] list_gather = ["polars/list_gather"] list_count = ["polars/list_count"] +array_count = ["polars/array_count", "polars/dtype-array"] binary_encoding = ["polars/binary_encoding"] list_sets = ["polars-lazy/list_sets"] list_any_all = ["polars/list_any_all"] @@ -163,6 +164,7 @@ dtypes = [ operations = [ "array_any_all", + "array_count", "is_in", "repeat_by", "trigonometry", diff --git a/py-polars/docs/source/reference/expressions/array.rst b/py-polars/docs/source/reference/expressions/array.rst index 28f20f9ed2bf..98ca6304841e 100644 --- a/py-polars/docs/source/reference/expressions/array.rst +++ b/py-polars/docs/source/reference/expressions/array.rst @@ -25,3 +25,4 @@ The following methods are available under the `expr.arr` attribute. Expr.arr.last Expr.arr.join Expr.arr.contains + Expr.arr.count_matches diff --git a/py-polars/docs/source/reference/series/array.rst b/py-polars/docs/source/reference/series/array.rst index 1efe58127d64..e5534ac06e74 100644 --- a/py-polars/docs/source/reference/series/array.rst +++ b/py-polars/docs/source/reference/series/array.rst @@ -25,3 +25,4 @@ The following methods are available under the `Series.arr` attribute. 
Series.arr.last Series.arr.join Series.arr.contains + Series.arr.count_matches \ No newline at end of file diff --git a/py-polars/polars/expr/array.py b/py-polars/polars/expr/array.py index 6e89560c9b9b..249a685b636a 100644 --- a/py-polars/polars/expr/array.py +++ b/py-polars/polars/expr/array.py @@ -9,7 +9,7 @@ from datetime import date, datetime, time from polars import Expr - from polars.type_aliases import IntoExprColumn + from polars.type_aliases import IntoExpr, IntoExprColumn class ExprArrayNameSpace: @@ -509,3 +509,32 @@ def contains( """ item = parse_as_expression(item, str_as_lit=True) return wrap_expr(self._pyexpr.arr_contains(item)) + + def count_matches(self, element: IntoExpr) -> Expr: + """ + Count how often the value produced by `element` occurs. + + Parameters + ---------- + element + An expression that produces a single value + + Examples + -------- + >>> df = pl.DataFrame( + ... {"a": [[1, 2], [1, 1], [2, 2]]}, schema={"a": pl.Array(pl.Int64, 2)} + ... ) + >>> df.with_columns(number_of_twos=pl.col("a").arr.count_matches(2)) + shape: (3, 2) + ┌───────────────┬────────────────┐ + │ a ┆ number_of_twos │ + │ --- ┆ --- │ + │ array[i64, 2] ┆ u32 │ + ╞═══════════════╪════════════════╡ + │ [1, 2] ┆ 1 │ + │ [1, 1] ┆ 0 │ + │ [2, 2] ┆ 2 │ + └───────────────┴────────────────┘ + """ + element = parse_as_expression(element, str_as_lit=True) + return wrap_expr(self._pyexpr.arr_count_matches(element)) diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index 06e2ddf801ea..00597a719c83 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -9,7 +9,7 @@ from polars import Series from polars.polars import PySeries - from polars.type_aliases import IntoExprColumn + from polars.type_aliases import IntoExpr, IntoExprColumn @expr_dispatch @@ -406,3 +406,25 @@ def contains( ] """ + + def count_matches(self, element: IntoExpr) -> Series: + """ + Count how often the value produced by `element` occurs. 
+ + Parameters + ---------- + element + An expression that produces a single value + + Examples + -------- + >>> s = pl.Series("a", [[1, 2, 3], [2, 2, 2]], dtype=pl.Array(pl.Int64, 3)) + >>> s.arr.count_matches(2) + shape: (2,) + Series: 'a' [u32] + [ + 1 + 3 + ] + + """ diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index f889032fe978..56a52064838b 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -16,6 +16,7 @@ from polars import Expr, Series from polars.polars import PySeries from polars.type_aliases import ( + IntoExpr, IntoExprColumn, NullBehavior, ToStructStrategy, @@ -694,9 +695,7 @@ def explode(self) -> Series: ] """ - def count_matches( - self, element: float | str | bool | int | date | datetime | time | Expr - ) -> Expr: + def count_matches(self, element: IntoExpr) -> Series: """ Count how often the value produced by `element` occurs. diff --git a/py-polars/src/expr/array.rs b/py-polars/src/expr/array.rs index 382b072f4043..70a395df5fba 100644 --- a/py-polars/src/expr/array.rs +++ b/py-polars/src/expr/array.rs @@ -73,4 +73,9 @@ impl PyExpr { fn arr_contains(&self, other: PyExpr) -> Self { self.inner.clone().arr().contains(other.inner).into() } + + #[cfg(feature = "array_count")] + fn arr_count_matches(&self, expr: PyExpr) -> Self { + self.inner.clone().arr().count_matches(expr.inner).into() + } } diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index 623aa7081dc6..021a8696a406 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -304,9 +304,6 @@ def test_list_count_matches() -> None: assert pl.DataFrame({"listcol": [[], [1], [1, 2, 3, 2], [1, 2, 1], [4, 4]]}).select( pl.col("listcol").list.count_matches(2).alias("number_of_twos") ).to_dict(as_series=False) == {"number_of_twos": [0, 0, 2, 1, 0]} - assert pl.DataFrame({"listcol": [[], [1], [1, 2, 3, 2], [1, 2, 1], [4, 4]]}).select( 
- pl.col("listcol").list.count_matches(2).alias("number_of_twos") - ).to_dict(as_series=False) == {"number_of_twos": [0, 0, 2, 1, 0]} def test_list_sum_and_dtypes() -> None: diff --git a/py-polars/tests/unit/namespaces/array/test_array.py b/py-polars/tests/unit/namespaces/array/test_array.py index c9c22507eb6a..c579a0882e24 100644 --- a/py-polars/tests/unit/namespaces/array/test_array.py +++ b/py-polars/tests/unit/namespaces/array/test_array.py @@ -280,3 +280,20 @@ def test_array_contains_literal( out = df.select(contains=pl.col("array").arr.contains(data)).to_series() expected_series = pl.Series("contains", expected) assert_series_equal(out, expected_series) + + +@pytest.mark.parametrize( + ("arr", "data", "expected", "dtype"), + [ + ([[1, 2], [3, None], None], 1, [1, 0, None], pl.Int64), + ([[True, False], [True, None], None], True, [1, 1, None], pl.Boolean), + ([["a", "b"], ["c", None], None], "a", [1, 0, None], pl.String), + ([[b"a", b"b"], [b"c", None], None], b"a", [1, 0, None], pl.Binary), + ], +) +def test_array_count_matches( + arr: list[list[Any] | None], data: Any, expected: list[Any], dtype: pl.DataType +) -> None: + df = pl.DataFrame({"arr": arr}, schema={"arr": pl.Array(dtype, 2)}) + out = df.select(count_matches=pl.col("arr").arr.count_matches(data)) + assert out.to_dict(as_series=False) == {"count_matches": expected}