diff --git a/ibis/selectors.py b/ibis/selectors.py index e2b39540b1c6..b7764931a5d2 100644 --- a/ibis/selectors.py +++ b/ibis/selectors.py @@ -399,22 +399,25 @@ def expand(self, table: ir.Table) -> Sequence[ir.Value]: @public def across( - selector: Selector, + selector: Selector | Iterable[str] | str, func: Deferred | Callable[[ir.Value], ir.Value] | Mapping[str | None, Deferred | Callable[[ir.Value], ir.Value]], names: str | Callable[[str, str | None], str] | None = None, ) -> Across: - """Applies the same data transformation function across multiple columns. + """Applies data transformations across multiple columns. Parameters ---------- selector - An expression that selects columns on which the transformation function will be applied. + An expression that selects columns on which the transformation function + will be applied, an iterable of `str` column names or a single `str` + column name. func A function (or a dictionary of functions) to use to transform the data. names - A lambda function or a format string to name the columns created by the transformation function. + A lambda function or a format string to name the columns created by the + transformation function. Returns ------- @@ -455,6 +458,8 @@ def across( if names is None: names = lambda col, fn: "_".join(filter(None, (col, fn))) funcs = frozendict(func if isinstance(func, Mapping) else {None: func}) + if not isinstance(selector, Selector): + selector = c(*util.promote_list(selector)) return Across(selector=selector, funcs=funcs, names=names) diff --git a/ibis/tests/expr/test_selectors.py b/ibis/tests/expr/test_selectors.py index 46d7dca98f1e..2955886d7ed9 100644 --- a/ibis/tests/expr/test_selectors.py +++ b/ibis/tests/expr/test_selectors.py @@ -301,6 +301,18 @@ def test_across_group_by_agg_with_grouped_selectors(penguins, expr_func): assert expr.equals(expected) +def test_across_list(penguins): + expr = penguins.agg(s.across(["species", "island"], lambda c: c.count())) + expected = penguins.agg(species=_.species.count(), island=_.island.count()) + assert expr.equals(expected) + + +def test_across_str(penguins): + expr = penguins.agg(s.across("species", lambda c: c.count())) + expected = penguins.agg(species=_.species.count()) + assert expr.equals(expected) + + def test_if_all(penguins): expr = penguins.filter(s.if_all(s.numeric() & ~s.c("year"), _ > 5)) expected = penguins.filter(