Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Impl count_matches for array namespace #13675

Merged
merged 3 commits into from
Jan 17, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crates/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ sign = ["polars-plan/sign"]
timezones = ["polars-plan/timezones"]
list_gather = ["polars-ops/list_gather", "polars-plan/list_gather"]
list_count = ["polars-ops/list_count", "polars-plan/list_count"]

array_count = ["polars-ops/array_count", "polars-plan/array_count", "dtype-array"]
true_div = ["polars-plan/true_div"]
extract_jsonpath = ["polars-plan/extract_jsonpath", "polars-ops/extract_jsonpath"]

Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ chunked_ids = ["polars-core/chunked_ids"]
asof_join = ["polars-core/asof_join"]
semi_anti_join = []
array_any_all = ["dtype-array"]
array_count = ["dtype-array"]
list_gather = []
list_sets = []
list_any_all = []
Expand Down
45 changes: 45 additions & 0 deletions crates/polars-ops/src/chunked_array/array/count.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use arrow::array::{Array, BooleanArray};
use arrow::bitmap::utils::count_zeros;
use arrow::bitmap::Bitmap;
use arrow::legacy::utils::CustomIterTools;
use polars_core::prelude::arity::unary_mut_with_options;

use super::*;

/// Counts, for every sub-array in `ca`, how many inner elements match `value`.
///
/// The scalar is wrapped in a one-element [`Series`] and compared against the inner
/// values with `equal_missing`; the resulting boolean mask is then reduced to one
/// count per sub-array by `count_boolean_bits`.
///
/// # Errors
///
/// Propagates any error produced by the inner comparison or by `apply_to_inner`.
pub fn array_count_matches(ca: &ArrayChunked, value: AnyValue) -> PolarsResult<Series> {
    // A length-1 Series lets the scalar take part in a Series-vs-Series comparison.
    let needle = Series::new("", [value]);

    let matches = ca.apply_to_inner(&|inner| {
        let mask = ChunkCompare::<&Series>::equal_missing(&inner, &needle)?;
        Ok(mask.into_series())
    })?;

    Ok(count_boolean_bits(&matches).into_series())
}

/// Reduces an array column whose inner values are booleans to the number of set
/// bits per sub-array.
///
/// The validity of each input chunk is cloned onto the output chunk.
///
/// # Panics
///
/// Panics if the inner values of a chunk are not a `BooleanArray`, or if that
/// boolean array reports a non-zero null count.
fn count_boolean_bits(ca: &ArrayChunked) -> IdxCa {
    unary_mut_with_options(ca, |chunk| {
        let mask = chunk
            .values()
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        // The counting below looks only at the value bits, so nulls must not be present.
        assert_eq!(mask.null_count(), 0);

        let counts = count_bits_set(mask.values(), chunk.len(), chunk.size());
        IdxArr::from_data_default(counts.into(), chunk.validity().cloned())
    })
}

/// Counts the set bits in each of `len` consecutive windows of `width` bits.
///
/// Window `i` starts at bit `i * width` of `values` (on top of the bitmap's own
/// offset). The result holds one count per window.
fn count_bits_set(values: &Bitmap, len: usize, width: usize) -> Vec<IdxSize> {
    // Fast paths: a uniform bitmap gives the same count for every window.
    let unset = values.unset_bits();
    if unset == values.len() {
        return vec![0 as IdxSize; len];
    }
    if unset == 0 {
        return vec![width as IdxSize; len];
    }

    let (bytes, offset, _) = values.as_slice();

    (0..len)
        .map(|row| {
            // Ones in the window = window size minus the zeros found in it.
            let zeros = count_zeros(bytes, offset + row * width, width);
            (width - zeros) as IdxSize
        })
        .collect_trusted()
}
2 changes: 2 additions & 0 deletions crates/polars-ops/src/chunked_array/array/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#[cfg(feature = "array_any_all")]
mod any_all;
#[cfg(feature = "array_count")]
mod count;
mod get;
mod join;
mod min_max;
Expand Down
8 changes: 8 additions & 0 deletions crates/polars-ops/src/chunked_array/array/namespace.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use super::min_max::AggType;
use super::*;
#[cfg(feature = "array_count")]
use crate::chunked_array::array::count::array_count_matches;
use crate::chunked_array::array::sum_mean::sum_with_nulls;
#[cfg(feature = "array_any_all")]
use crate::prelude::array::any_all::{array_all, array_any};
Expand Down Expand Up @@ -104,6 +106,12 @@ pub trait ArrayNameSpace: AsArray {
let ca = self.as_array();
array_join(ca, separator).map(|ok| ok.into_series())
}

/// Count, per sub-array, how many elements match `element`.
///
/// Delegates to the free function `array_count_matches` on the underlying array.
#[cfg(feature = "array_count")]
fn array_count_matches(&self, element: AnyValue) -> PolarsResult<Series> {
    array_count_matches(self.as_array(), element)
}
}

impl ArrayNameSpace for ArrayChunked {}
1 change: 1 addition & 0 deletions crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ object = ["polars-core/object"]
date_offset = ["polars-time", "chrono"]
list_gather = ["polars-ops/list_gather"]
list_count = ["polars-ops/list_count"]
array_count = ["polars-ops/array_count", "dtype-array"]
trigonometry = []
sign = []
timezones = ["chrono-tz", "polars-time/timezones", "polars-core/timezones", "regex"]
Expand Down
18 changes: 18 additions & 0 deletions crates/polars-plan/src/dsl/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,4 +114,22 @@ impl ArrayNameSpace {
false,
)
}

/// Count how often the value produced by `element` occurs.
///
/// Builds an `ArrayFunction::CountMatches` function expression that takes
/// `element` as its extra input, with `input_wildcard_expansion` enabled.
#[cfg(feature = "array_count")]
pub fn count_matches<E: Into<Expr>>(self, element: E) -> Expr {
    let counted = self.0.map_many_private(
        FunctionExpr::ArrayExpr(ArrayFunction::CountMatches),
        &[element.into()],
        false,
        false,
    );

    counted.with_function_options(|mut opts| {
        opts.input_wildcard_expansion = true;
        opts
    })
}
}
21 changes: 21 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ pub enum ArrayFunction {
Join,
#[cfg(feature = "is_in")]
Contains,
#[cfg(feature = "array_count")]
CountMatches,
}

impl ArrayFunction {
Expand All @@ -42,6 +44,8 @@ impl ArrayFunction {
Join => mapper.with_dtype(DataType::String),
#[cfg(feature = "is_in")]
Contains => mapper.with_dtype(DataType::Boolean),
#[cfg(feature = "array_count")]
CountMatches => mapper.with_dtype(IDX_DTYPE),
}
}
}
Expand Down Expand Up @@ -75,6 +79,8 @@ impl Display for ArrayFunction {
Join => "join",
#[cfg(feature = "is_in")]
Contains => "contains",
#[cfg(feature = "array_count")]
CountMatches => "count_matches",
};
write!(f, "arr.{name}")
}
Expand All @@ -101,6 +107,8 @@ impl From<ArrayFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
Join => map_as_slice!(join),
#[cfg(feature = "is_in")]
Contains => map_as_slice!(contains),
#[cfg(feature = "array_count")]
CountMatches => map_as_slice!(count_matches),
}
}
}
Expand Down Expand Up @@ -177,3 +185,16 @@ pub(super) fn contains(s: &[Series]) -> PolarsResult<Series> {
let item = &s[1];
Ok(is_in(item, array)?.with_name(array.name()).into_series())
}

/// Implementation behind `arr.count_matches`.
///
/// `args[0]` is the array column and `args[1]` is the element to look for.
///
/// # Errors
///
/// Returns a `ComputeError` when the element series does not hold exactly one
/// value, and propagates errors from `Series::array` and the counting kernel.
///
/// # Panics
///
/// Panics if fewer than two series are supplied.
#[cfg(feature = "array_count")]
pub(super) fn count_matches(args: &[Series]) -> PolarsResult<Series> {
    let (s, element) = (&args[0], &args[1]);
    let n_elements = element.len();
    polars_ensure!(
        n_elements == 1,
        ComputeError: "argument expression in `arr.count_matches` must produce exactly one element, got {}",
        n_elements
    );
    // The length check above guarantees index 0 exists.
    let needle = element.get(0).unwrap();
    s.array()?.array_count_matches(needle)
}
4 changes: 2 additions & 2 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ impl Display for ListFunction {
#[cfg(feature = "list_gather")]
Gather(_) => "gather",
#[cfg(feature = "list_count")]
CountMatches => "count",
CountMatches => "count_matches",
Sum => "sum",
Min => "min",
Max => "max",
Expand Down Expand Up @@ -459,7 +459,7 @@ pub(super) fn count_matches(args: &[Series]) -> PolarsResult<Series> {
let element = &args[1];
polars_ensure!(
element.len() == 1,
ComputeError: "argument expression in `arr.count` must produce exactly one element, got {}",
ComputeError: "argument expression in `list.count_matches` must produce exactly one element, got {}",
element.len()
);
let ca = s.list()?;
Expand Down
1 change: 1 addition & 0 deletions crates/polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ is_unique = ["polars-lazy?/is_unique", "polars-ops/is_unique"]
regex = ["polars-lazy?/regex"]
list_any_all = ["polars-lazy?/list_any_all"]
list_count = ["polars-ops/list_count", "polars-lazy?/list_count"]
array_count = ["polars-ops/array_count", "polars-lazy?/array_count", "dtype-array"]
list_drop_nulls = ["polars-lazy?/list_drop_nulls"]
list_eval = ["polars-lazy?/list_eval"]
list_gather = ["polars-ops/list_gather", "polars-lazy?/list_gather"]
Expand Down
2 changes: 2 additions & 0 deletions py-polars/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ cse = ["polars/cse"]
merge_sorted = ["polars/merge_sorted"]
list_gather = ["polars/list_gather"]
list_count = ["polars/list_count"]
array_count = ["polars/array_count", "polars/dtype-array"]
binary_encoding = ["polars/binary_encoding"]
list_sets = ["polars-lazy/list_sets"]
list_any_all = ["polars/list_any_all"]
Expand All @@ -163,6 +164,7 @@ dtypes = [

operations = [
"array_any_all",
"array_count",
"is_in",
"repeat_by",
"trigonometry",
Expand Down
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/expressions/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ The following methods are available under the `expr.arr` attribute.
Expr.arr.last
Expr.arr.join
Expr.arr.contains
Expr.arr.count_matches
1 change: 1 addition & 0 deletions py-polars/docs/source/reference/series/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ The following methods are available under the `Series.arr` attribute.
Series.arr.last
Series.arr.join
Series.arr.contains
Series.arr.count_matches
31 changes: 30 additions & 1 deletion py-polars/polars/expr/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from datetime import date, datetime, time

from polars import Expr
from polars.type_aliases import IntoExprColumn
from polars.type_aliases import IntoExpr, IntoExprColumn


class ExprArrayNameSpace:
Expand Down Expand Up @@ -507,3 +507,32 @@ def contains(
"""
item = parse_as_expression(item, str_as_lit=True)
return wrap_expr(self._pyexpr.arr_contains(item))

def count_matches(self, element: IntoExpr) -> Expr:
    """
    Count the occurrences of the value produced by `element` in each sub-array.

    Parameters
    ----------
    element
        An expression that produces a single value

    Examples
    --------
    >>> df = pl.DataFrame(
    ...     {"a": [[1, 2], [1, 1], [2, 2]]}, schema={"a": pl.Array(pl.Int64, 2)}
    ... )
    >>> df.with_columns(number_of_twos=pl.col("a").arr.count_matches(2))
    shape: (3, 2)
    ┌───────────────┬────────────────┐
    │ a             ┆ number_of_twos │
    │ ---           ┆ ---            │
    │ array[i64, 2] ┆ u32            │
    ╞═══════════════╪════════════════╡
    │ [1, 2]        ┆ 1              │
    │ [1, 1]        ┆ 0              │
    │ [2, 2]        ┆ 2              │
    └───────────────┴────────────────┘
    """
    # Plain strings are treated as literals, not column names.
    element_pyexpr = parse_as_expression(element, str_as_lit=True)
    counted = self._pyexpr.arr_count_matches(element_pyexpr)
    return wrap_expr(counted)
24 changes: 23 additions & 1 deletion py-polars/polars/series/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from polars import Series
from polars.polars import PySeries
from polars.type_aliases import IntoExprColumn
from polars.type_aliases import IntoExpr, IntoExprColumn


@expr_dispatch
Expand Down Expand Up @@ -404,3 +404,25 @@ def contains(
]

"""

def count_matches(self, element: IntoExpr) -> Series:
    """
    Count how often the value produced by `element` occurs.

    Parameters
    ----------
    element
        An expression that produces a single value

    Returns
    -------
    Series
        One count per sub-array (``u32`` in the example below).

    Examples
    --------
    >>> s = pl.Series("a", [[1, 2, 3], [2, 2, 2]], dtype=pl.Array(pl.Int64, 3))
    >>> s.arr.count_matches(2)
    shape: (2,)
    Series: 'a' [u32]
    [
        1
        3
    ]

    """
5 changes: 2 additions & 3 deletions py-polars/polars/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from polars import Expr, Series
from polars.polars import PySeries
from polars.type_aliases import (
IntoExpr,
IntoExprColumn,
NullBehavior,
ToStructStrategy,
Expand Down Expand Up @@ -692,9 +693,7 @@ def explode(self) -> Series:
]
"""

def count_matches(
self, element: float | str | bool | int | date | datetime | time | Expr
) -> Expr:
def count_matches(self, element: IntoExpr) -> Series:
"""
Count how often the value produced by `element` occurs.

Expand Down
5 changes: 5 additions & 0 deletions py-polars/src/expr/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,4 +72,9 @@ impl PyExpr {
fn arr_contains(&self, other: PyExpr) -> Self {
    // Work on a clone of the wrapped expression; `self` is only borrowed.
    let expr = self.inner.clone();
    expr.arr().contains(other.inner).into()
}

#[cfg(feature = "array_count")]
fn arr_count_matches(&self, expr: PyExpr) -> Self {
    // Work on a clone of the wrapped expression; `self` is only borrowed.
    let this = self.inner.clone();
    this.arr().count_matches(expr.inner).into()
}
}
3 changes: 0 additions & 3 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,9 +304,6 @@ def test_list_count_matches() -> None:
assert pl.DataFrame({"listcol": [[], [1], [1, 2, 3, 2], [1, 2, 1], [4, 4]]}).select(
pl.col("listcol").list.count_matches(2).alias("number_of_twos")
).to_dict(as_series=False) == {"number_of_twos": [0, 0, 2, 1, 0]}
assert pl.DataFrame({"listcol": [[], [1], [1, 2, 3, 2], [1, 2, 1], [4, 4]]}).select(
stinodego marked this conversation as resolved.
Show resolved Hide resolved
pl.col("listcol").list.count_matches(2).alias("number_of_twos")
).to_dict(as_series=False) == {"number_of_twos": [0, 0, 2, 1, 0]}


def test_list_sum_and_dtypes() -> None:
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/namespaces/array/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,3 +269,21 @@ def test_array_contains_literal(
out = df.select(contains=pl.col("array").arr.contains(data)).to_series()
expected_series = pl.Series("contains", expected)
assert_series_equal(out, expected_series)


@pytest.mark.parametrize(
    ("arr", "data", "expected", "dtype"),
    [
        ([[1, 2], [3, None], None], 1, [1, 0, None], pl.Int64),
        ([[True, False], [True, None], None], True, [1, 1, None], pl.Boolean),
        ([["a", "b"], ["c", None], None], "a", [1, 0, None], pl.String),
        ([[b"a", b"b"], [b"c", None], None], b"a", [1, 0, None], pl.Binary),
    ],
)
def test_array_count_matches(
    arr: list[list[Any] | None], data: Any, expected: list[Any], dtype: pl.DataType
) -> None:
    # Each case covers a matching row, a row with a null element (which does not
    # match the literal) and a null row (which yields a null count).
    df = pl.DataFrame({"arr": arr}, schema={"arr": pl.Array(dtype, 2)})
    out = df.select(count_matches=pl.col("arr").arr.count_matches(data))
    assert out.to_dict(as_series=False) == {"count_matches": expected}
Loading