From 5e11529d02a8f858353f4df958924d20a5cb2add Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 9 Aug 2023 09:45:53 -0400 Subject: [PATCH] feat(ux): promote lists of strings to `any_of` selectors --- ibis/expr/types/relations.py | 14 ++++++-------- ibis/selectors.py | 17 +++++++++++------ ibis/tests/expr/test_selectors.py | 15 +++++++++++++++ 3 files changed, 32 insertions(+), 14 deletions(-) diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index b221df4c0925..a89d8e35a656 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -1069,8 +1069,7 @@ def distinct( ) return ops.Distinct(self).to_expr() - if not isinstance(on, s.Selector): - on = s.c(*util.promote_list(on)) + on = s._to_selector(on) if keep is None: having = lambda t: t.count() == 1 @@ -1936,8 +1935,7 @@ def drop(self, *fields: str | Selector) -> Table: ): raise KeyError(f"Fields not in table: {sorted(missing_fields)}") - sels = (s.c(f) if isinstance(f, str) else f for f in fields) - return self.select(~s.any_of(*sels)) + return self.select(~s._to_selector(fields)) def filter( self, @@ -3111,7 +3109,7 @@ def pivot_longer( """ import ibis.selectors as s - pivot_sel = s.c(col) if isinstance(col, str) else col + pivot_sel = s._to_selector(col) pivot_cols = pivot_sel.expand(self) if not pivot_cols: @@ -3334,7 +3332,7 @@ def pivot_wider( >>> us_rent_income.pivot_wider( ... names_from="variable", ... names_sep=".", - ... values_from=s.c("estimate", "moe"), + ... values_from=("estimate", "moe"), ... ) ┏━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━┓ ┃ geoid ┃ name ┃ estimate.income ┃ moe.income ┃ … ┃ @@ -3540,8 +3538,8 @@ def pivot_wider( orig_names_from = util.promote_list(names_from) - names_from = s.any_of(*map(s._to_selector, orig_names_from)) - values_from = s.any_of(*map(s._to_selector, util.promote_list(values_from))) + names_from = s._to_selector(orig_names_from) + values_from = s._to_selector(values_from) if id_cols is None: id_cols = ~(names_from | values_from) diff --git a/ibis/selectors.py b/ibis/selectors.py index e73c5b031a8d..508da06700c8 100644 --- a/ibis/selectors.py +++ b/ibis/selectors.py @@ -352,15 +352,15 @@ def matches(regex: str | re.Pattern) -> Selector: @public -def any_of(*predicates: Predicate) -> Predicate: +def any_of(*predicates: str | Predicate) -> Predicate: """Include columns satisfying any of `predicates`.""" - return functools.reduce(operator.or_, predicates) + return functools.reduce(operator.or_, map(_to_selector, predicates)) @public -def all_of(*predicates: Predicate) -> Predicate: +def all_of(*predicates: str | Predicate) -> Predicate: """Include columns satisfying all of `predicates`.""" - return functools.reduce(operator.and_, predicates) + return functools.reduce(operator.and_, map(_to_selector, predicates)) @public @@ -654,6 +654,11 @@ def all() -> Predicate: return r[:] -def _to_selector(obj: str | Selector) -> Selector: +def _to_selector(obj: str | Selector | Sequence[str | Selector]) -> Selector: """Convert an object to a `Selector`.""" - return c(obj) if isinstance(obj, str) else obj + if isinstance(obj, Selector): + return obj + elif isinstance(obj, str): + return c(obj) + else: + return any_of(*obj) diff --git a/ibis/tests/expr/test_selectors.py b/ibis/tests/expr/test_selectors.py index 72df2adb9e7e..4081301dee08 100644 --- a/ibis/tests/expr/test_selectors.py +++ b/ibis/tests/expr/test_selectors.py @@ -432,7 +432,22 @@ def test_all_of(penguins): assert expr.equals(expected) +def test_all_of_string_list(penguins): + # a bit silly, but robust nonetheless + expr = penguins.select(s.all_of("year", "year")) + expected = penguins.select("year") + assert expr.equals(expected) + + def test_any_of(penguins): expr = penguins.select(s.any_of(s.startswith("bill"), s.c("year"))) expected = penguins.select("bill_length_mm", "bill_depth_mm", "year") assert expr.equals(expected) + + +def test_any_of_string_list(penguins): + expr = penguins.select(s.any_of("year", "body_mass_g", s.matches("length"))) + expected = penguins.select( + "bill_length_mm", "flipper_length_mm", "body_mass_g", "year" + ) + assert expr.equals(expected)