diff --git a/ibis/expr/api.py b/ibis/expr/api.py index e3a05fce6f13..5b7a2232094f 100644 --- a/ibis/expr/api.py +++ b/ibis/expr/api.py @@ -21,6 +21,7 @@ import ibis.expr.schema as sch import ibis.expr.types as ir from ibis.backends.base import BaseBackend, connect +from ibis.expr import selectors from ibis.expr.decompile import decompile from ibis.expr.deferred import Deferred from ibis.expr.schema import Schema @@ -146,6 +147,7 @@ 'rows_with_max_lookback', 'schema', 'Schema', + 'selectors', 'sequence', 'set_backend', 'show_sql', diff --git a/ibis/expr/selectors.py b/ibis/expr/selectors.py new file mode 100644 index 000000000000..606f16bc7008 --- /dev/null +++ b/ibis/expr/selectors.py @@ -0,0 +1,120 @@ +"""Construct column selectors.""" + +from __future__ import annotations + +import re +from typing import Callable, Iterable, Sequence + +import ibis.expr.datatypes as dt +import ibis.expr.types as ir +from ibis import util + + +class Selector: + __slots__ = ("predicate",) + + def __init__(self, predicate: Callable[[ir.Column], bool]) -> None: + """Construct a `Selector` with `predicate`.""" + self.predicate = predicate + + def expand(self, table: ir.Table) -> Sequence[ir.Column]: + """Evaluate `self.predicate` on every column of `table`.""" + return [col for column in table.columns if self.predicate(col := table[column])] + + def __and__(self, other: Selector) -> Selector: + """Compute the conjunction of two `Selectors`.""" + return self.__class__(lambda col: self.predicate(col) and other.predicate(col)) + + def __or__(self, other: Selector) -> Selector: + """Compute the disjunction of two `Selectors`.""" + return self.__class__(lambda col: self.predicate(col) or other.predicate(col)) + + def __invert__(self) -> Selector: + """Compute the logical negation of two `Selectors`.""" + return self.__class__(lambda col: not self.predicate(col)) + + +def where(predicate: Callable[[ir.Value], bool]) -> Selector: + """Return columns that satisfy `predicate`. + + Examples + -------- + >>> t = ibis.table(dict(a="float32"), name="t") + >>> t.select(s.where(lambda col: col.get_name() == "a")) + r0 := UnboundTable: t + a float32 + + Selection[r0] + selections: + a: r0.a + """ + return Selector(predicate) + + +def numeric() -> Selector: + """Return numeric columns. + + Examples + -------- + >>> import ibis.selectors as s + >>> t = ibis.table(dict(a="int", b="string", c="array"), name="t") + >>> t + r0 := UnboundTable: t + a int64 + b string + c array + >>> t.select(s.numeric()) # `a` has integer type, so it's numeric + r0 := UnboundTable: t + a int64 + b string + c array + + Selection[r0] + selections: + a: r0.a + """ + return Selector(lambda col: col.type().is_numeric()) + + +def of_type(dtype: dt.DataType | str | type[dt.DataType]) -> Selector: + """Select columns of type `dtype`.""" + if isinstance(dtype, type): + predicate = lambda col, dtype=dtype: isinstance(col.type(), dtype) + else: + dtype = dt.dtype(dtype) + predicate = lambda col, dtype=dtype: col.type() == dtype + return where(predicate) + + +def startswith(prefixes: str | tuple[str, ...]) -> Selector: + """Select columns whose name starts with one of `prefixes`.""" + return where(lambda col, prefixes=prefixes: col.get_name().startswith(prefixes)) + + +def endswith(suffixes: str | tuple[str, ...]) -> Selector: + """Select columns whose name ends with one of `suffixes`.""" + return where(lambda col, suffixes=suffixes: col.get_name().endswith(suffixes)) + + +def contains( + needles: str | tuple[str, ...], how: Callable[[Iterable[bool]], bool] = any +) -> Selector: + """Return columns whose name contains `needles`.""" + + def predicate( + col: ir.Column, + needles: str | tuple[str, ...] = needles, + how: Callable[[Iterable[bool]], bool] = how, + ) -> bool: + name = col.get_name() + return how(needle in name for needle in util.promote_list(needles)) + + return where(predicate) + + +def matches(regex: str | re.Pattern) -> Selector: + """Return columns matching the regular expression `regex`.""" + pattern = re.compile(regex) + return where( + lambda col, pattern=pattern: pattern.search(col.get_name()) is not None + ) diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index bf2cfdd213f9..38f3ae2b6a62 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -26,6 +26,7 @@ import ibis.expr.operations as ops from ibis import util from ibis.expr.deferred import Deferred +from ibis.expr.selectors import Selector from ibis.expr.types.core import Expr if TYPE_CHECKING: @@ -718,7 +719,10 @@ def select( exprs = list( itertools.chain( - itertools.chain.from_iterable(map(util.promote_list, exprs)), + itertools.chain.from_iterable( + util.promote_list(e.expand(self) if isinstance(e, Selector) else e) + for e in exprs + ), ( self._ensure_expr(expr).name(name) for name, expr in named_exprs.items() diff --git a/ibis/tests/expr/test_selectors.py b/ibis/tests/expr/test_selectors.py new file mode 100644 index 000000000000..92c854c9650f --- /dev/null +++ b/ibis/tests/expr/test_selectors.py @@ -0,0 +1,92 @@ +import re + +import pytest + +import ibis +import ibis.common.exceptions as exc +import ibis.expr.datatypes as dt +from ibis import selectors as s + + +@pytest.fixture +def t(): + return ibis.table( + dict( + a="int", + b="string", + c="array", + d="struct>>>", + e="float", + f="decimal(3, 1)", + g="array>>", + ga="string", + ), + name="t", + ) + + +@pytest.mark.parametrize( + "sel", + [s.where(lambda _: False), s.startswith("X"), s.endswith("🙂")], + ids=["false", "startswith", "endswith"], +) +def test_empty_selection(t, sel): + with pytest.raises(exc.IbisError): + t.select(sel) + + +def test_where(t): + assert t.select(s.where(lambda _: True)).equals(t.select(*t.columns)) + + +def test_numeric(t): + assert t.select(s.numeric()).equals(t.select("a", "e", "f")) + + +@pytest.mark.parametrize( + ("obj", "expected"), + [(dt.Array, ("c", "g")), ("float", ("e",)), (dt.Decimal(3, 1), ("f",))], + ids=["type", "string", "instance"], +) +def test_of_type(t, obj, expected): + assert t.select(s.of_type(obj)).equals(t.select(*expected)) + + +@pytest.mark.parametrize( + ("prefixes", "expected"), + [("a", ("a",)), (("a", "e"), ("a", "e"))], + ids=["string", "tuple"], +) +def test_startswith(t, prefixes, expected): + assert t.select(s.startswith(prefixes)).equals(t.select(*expected)) + + +def test_endswith(t): + assert t.select(s.endswith(("a", "d"))).equals(t.select("a", "d", "ga")) + + +def test_contains(t): + assert t.select(s.contains("a")).equals(t.select("a", "ga")) + + +@pytest.mark.parametrize( + ("rx", "expected"), + [("e|f", ("e", "f")), (re.compile("e|f"), ("e", "f"))], + ids=["string", "pattern"], +) +def test_matches(t, rx, expected): + assert t.select(s.matches(rx)).equals(t.select(expected)) + + +def test_compose_or(t): + assert t.select(s.contains("a") | s.startswith("d")).equals( + t.select("a", "d", "ga") + ) + + +def test_compose_and(t): + assert t.select(s.contains("a") & s.contains("g")).equals(t.select("ga")) + + +def test_compose_not(t): + assert t.select(~s.numeric()).equals(t.select("b", "c", "d", "g", "ga"))