Skip to content

Commit

Permalink
feat(api): add selectors for easier selection of columns
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Jan 24, 2023
1 parent ff34c7b commit 306bc88
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 1 deletion.
2 changes: 2 additions & 0 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis.backends.base import BaseBackend, connect
from ibis.expr import selectors
from ibis.expr.decompile import decompile
from ibis.expr.deferred import Deferred
from ibis.expr.schema import Schema
Expand Down Expand Up @@ -146,6 +147,7 @@
'rows_with_max_lookback',
'schema',
'Schema',
'selectors',
'sequence',
'set_backend',
'show_sql',
Expand Down
120 changes: 120 additions & 0 deletions ibis/expr/selectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
"""Construct column selectors."""

from __future__ import annotations

import re
from typing import Callable, Iterable, Sequence

import ibis.expr.datatypes as dt
import ibis.expr.types as ir
from ibis import util


class Selector:
    """A column selector: a predicate that picks columns out of a table.

    Selectors compose with `&` (both must match), `|` (either may match) and
    `~` (must not match); every operator returns a new selector.
    """

    __slots__ = ("predicate",)

    def __init__(self, predicate: Callable[[ir.Column], bool]) -> None:
        """Construct a `Selector` with `predicate`.

        Parameters
        ----------
        predicate
            Callable invoked with a column; a truthy result selects it.
        """
        self.predicate = predicate

    def expand(self, table: ir.Table) -> Sequence[ir.Column]:
        """Evaluate `self.predicate` on every column of `table`.

        Parameters
        ----------
        table
            Table whose columns are tested, in `table.columns` order.

        Returns
        -------
        Sequence[ir.Column]
            The columns for which the predicate returned a truthy value.
        """
        return [col for name in table.columns if self.predicate(col := table[name])]

    def __and__(self, other: Selector) -> Selector:
        """Compute the conjunction of two `Selectors`."""
        if not isinstance(other, Selector):
            # Fail at composition time (TypeError) instead of much later with
            # an AttributeError raised from inside the combined predicate.
            return NotImplemented
        return self.__class__(lambda col: self.predicate(col) and other.predicate(col))

    def __or__(self, other: Selector) -> Selector:
        """Compute the disjunction of two `Selectors`."""
        if not isinstance(other, Selector):
            # Same reasoning as `__and__`: reject foreign operands eagerly.
            return NotImplemented
        return self.__class__(lambda col: self.predicate(col) or other.predicate(col))

    def __invert__(self) -> Selector:
        """Compute the logical negation of a `Selector`."""
        return self.__class__(lambda col: not self.predicate(col))


def where(predicate: Callable[[ir.Value], bool]) -> Selector:
    """Return columns that satisfy `predicate`.

    Parameters
    ----------
    predicate
        Callable that receives a column expression and returns whether the
        column should be selected.

    Returns
    -------
    Selector
        A selector wrapping `predicate`.

    Examples
    --------
    >>> import ibis
    >>> from ibis import selectors as s
    >>> t = ibis.table(dict(a="float32"), name="t")
    >>> t.select(s.where(lambda col: col.get_name() == "a")).columns
    ['a']
    """
    return Selector(predicate)


def numeric() -> Selector:
    """Return numeric columns.

    Returns
    -------
    Selector
        A selector matching every column whose type reports `is_numeric()`.

    Examples
    --------
    >>> import ibis
    >>> from ibis import selectors as s
    >>> t = ibis.table(dict(a="int", b="string", c="array<string>"), name="t")
    >>> t.select(s.numeric()).columns  # `a` has integer type, so it's numeric
    ['a']
    """
    return Selector(lambda col: col.type().is_numeric())


def of_type(dtype: dt.DataType | str | type[dt.DataType]) -> Selector:
    """Select columns of type `dtype`.

    Parameters
    ----------
    dtype
        Either a class, in which case a column matches when its type is an
        instance of that class, or anything accepted by `dt.dtype` (such as
        a type string or a datatype instance), in which case a column matches
        when its type compares equal to the converted datatype.
    """
    if isinstance(dtype, type):

        def predicate(col):
            # Class given: match the whole family of types.
            return isinstance(col.type(), dtype)

    else:
        expected = dt.dtype(dtype)

        def predicate(col):
            # Concrete datatype given: require an exact match.
            return col.type() == expected

    return where(predicate)


def startswith(prefixes: str | Iterable[str]) -> Selector:
    """Select columns whose name starts with one of `prefixes`.

    Parameters
    ----------
    prefixes
        A single prefix, or an iterable of prefixes of which at least one
        must match.
    """
    # `str.startswith` only accepts a `str` or a `tuple` of `str`. Normalize
    # once up front so lists and other iterables work too, and so a bad
    # argument fails here rather than when the selector is expanded.
    if not isinstance(prefixes, str):
        prefixes = tuple(prefixes)
    return where(lambda col, prefixes=prefixes: col.get_name().startswith(prefixes))


def endswith(suffixes: str | Iterable[str]) -> Selector:
    """Select columns whose name ends with one of `suffixes`.

    Parameters
    ----------
    suffixes
        A single suffix, or an iterable of suffixes of which at least one
        must match.
    """
    # `str.endswith` only accepts a `str` or a `tuple` of `str`. Normalize
    # once up front so lists and other iterables work too, and so a bad
    # argument fails here rather than when the selector is expanded.
    if not isinstance(suffixes, str):
        suffixes = tuple(suffixes)
    return where(lambda col, suffixes=suffixes: col.get_name().endswith(suffixes))


def contains(
    needles: str | tuple[str, ...], how: Callable[[Iterable[bool]], bool] = any
) -> Selector:
    """Return columns whose name contains `needles`.

    Parameters
    ----------
    needles
        One substring, or a tuple of substrings, to look for in column names.
    how
        Reducer applied to the per-needle membership results. The default,
        `any`, selects a column containing at least one needle; passing
        `all` requires every needle to be present.
    """
    # The needle list does not depend on the column being tested, so build it
    # once here instead of on every predicate call.
    needle_list = util.promote_list(needles)

    def predicate(
        col: ir.Column,
        needles: list[str] = needle_list,
        how: Callable[[Iterable[bool]], bool] = how,
    ) -> bool:
        name = col.get_name()
        return how(needle in name for needle in needles)

    return where(predicate)


def matches(regex: str | re.Pattern) -> Selector:
    """Return columns matching the regular expression `regex`.

    Parameters
    ----------
    regex
        Pattern string or compiled pattern. A column is selected when the
        pattern is found anywhere in its name (`search` semantics).
    """
    compiled = re.compile(regex)

    def name_matches(col, pattern=compiled):
        found = pattern.search(col.get_name())
        return found is not None

    return where(name_matches)
6 changes: 5 additions & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import ibis.expr.operations as ops
from ibis import util
from ibis.expr.deferred import Deferred
from ibis.expr.selectors import Selector
from ibis.expr.types.core import Expr

if TYPE_CHECKING:
Expand Down Expand Up @@ -718,7 +719,10 @@ def select(

exprs = list(
itertools.chain(
itertools.chain.from_iterable(map(util.promote_list, exprs)),
itertools.chain.from_iterable(
util.promote_list(e.expand(self) if isinstance(e, Selector) else e)
for e in exprs
),
(
self._ensure_expr(expr).name(name)
for name, expr in named_exprs.items()
Expand Down
92 changes: 92 additions & 0 deletions ibis/tests/expr/test_selectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import re

import pytest

import ibis
import ibis.common.exceptions as exc
import ibis.expr.datatypes as dt
from ibis import selectors as s


@pytest.fixture
def t():
    """Unbound table `t` mixing numeric, string and nested column types."""
    schema = {
        "a": "int",
        "b": "string",
        "c": "array<string>",
        "d": "struct<a: array<map<string, array<float>>>>",
        "e": "float",
        "f": "decimal(3, 1)",
        "g": "array<array<map<float, float>>>",
        "ga": "string",
    }
    return ibis.table(schema, name="t")


@pytest.mark.parametrize(
    "sel",
    [
        pytest.param(s.where(lambda _: False), id="false"),
        pytest.param(s.startswith("X"), id="startswith"),
        pytest.param(s.endswith("🙂"), id="endswith"),
    ],
)
def test_empty_selection(t, sel):
    """Selecting with a selector that matches no column raises `IbisError`."""
    with pytest.raises(exc.IbisError):
        t.select(sel)


def test_where(t):
    """An always-true predicate selects every column of the table."""
    result = t.select(s.where(lambda _: True))
    expected = t.select(*t.columns)
    assert result.equals(expected)


def test_numeric(t):
    """`numeric()` picks the int, float and decimal columns."""
    result = t.select(s.numeric())
    expected = t.select("a", "e", "f")
    assert result.equals(expected)


@pytest.mark.parametrize(
    ("obj", "expected"),
    [
        pytest.param(dt.Array, ("c", "g"), id="type"),
        pytest.param("float", ("e",), id="string"),
        pytest.param(dt.Decimal(3, 1), ("f",), id="instance"),
    ],
)
def test_of_type(t, obj, expected):
    """`of_type` accepts a datatype class, a type string or an instance."""
    result = t.select(s.of_type(obj))
    assert result.equals(t.select(*expected))


@pytest.mark.parametrize(
    ("prefixes", "expected"),
    [
        pytest.param("a", ("a",), id="string"),
        pytest.param(("a", "e"), ("a", "e"), id="tuple"),
    ],
)
def test_startswith(t, prefixes, expected):
    """`startswith` accepts a single prefix or a tuple of prefixes."""
    result = t.select(s.startswith(prefixes))
    assert result.equals(t.select(*expected))


def test_endswith(t):
    """`endswith` with a tuple selects names ending in any of the suffixes."""
    result = t.select(s.endswith(("a", "d")))
    expected = t.select("a", "d", "ga")
    assert result.equals(expected)


def test_contains(t):
    """`contains` selects every column whose name includes the needle."""
    result = t.select(s.contains("a"))
    expected = t.select("a", "ga")
    assert result.equals(expected)


@pytest.mark.parametrize(
    ("rx", "expected"),
    [("e|f", ("e", "f")), (re.compile("e|f"), ("e", "f"))],
    ids=["string", "pattern"],
)
def test_matches(t, rx, expected):
    """`matches` accepts both a pattern string and a compiled pattern."""
    # Unpack the expected names like the sibling tests do, so the comparison
    # does not depend on `select` flattening a tuple argument.
    assert t.select(s.matches(rx)).equals(t.select(*expected))


def test_compose_or(t):
    """`|` selects columns matched by either selector."""
    either = s.contains("a") | s.startswith("d")
    expected = t.select("a", "d", "ga")
    assert t.select(either).equals(expected)


def test_compose_and(t):
    """`&` selects only columns matched by both selectors."""
    both = s.contains("a") & s.contains("g")
    expected = t.select("ga")
    assert t.select(both).equals(expected)


def test_compose_not(t):
    """`~` selects the columns the wrapped selector rejects."""
    non_numeric = ~s.numeric()
    expected = t.select("b", "c", "d", "g", "ga")
    assert t.select(non_numeric).equals(expected)

0 comments on commit 306bc88

Please sign in to comment.