-
Notifications
You must be signed in to change notification settings - Fork 609
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(api): add selectors for easier selection of columns
- Loading branch information
Showing
4 changed files
with
219 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
"""Construct column selectors.""" | ||
|
||
from __future__ import annotations | ||
|
||
import re | ||
from typing import Callable, Iterable, Sequence | ||
|
||
import ibis.expr.datatypes as dt | ||
import ibis.expr.types as ir | ||
from ibis import util | ||
|
||
|
||
class Selector: | ||
__slots__ = ("predicate",) | ||
|
||
def __init__(self, predicate: Callable[[ir.Column], bool]) -> None: | ||
"""Construct a `Selector` with `predicate`.""" | ||
self.predicate = predicate | ||
|
||
def expand(self, table: ir.Table) -> Sequence[ir.Column]: | ||
"""Evaluate `self.predicate` on every column of `table`.""" | ||
return [col for column in table.columns if self.predicate(col := table[column])] | ||
|
||
def __and__(self, other: Selector) -> Selector: | ||
"""Compute the conjunction of two `Selectors`.""" | ||
return self.__class__(lambda col: self.predicate(col) and other.predicate(col)) | ||
|
||
def __or__(self, other: Selector) -> Selector: | ||
"""Compute the disjunction of two `Selectors`.""" | ||
return self.__class__(lambda col: self.predicate(col) or other.predicate(col)) | ||
|
||
def __invert__(self) -> Selector: | ||
"""Compute the logical negation of two `Selectors`.""" | ||
return self.__class__(lambda col: not self.predicate(col)) | ||
|
||
|
||
def where(predicate: Callable[[ir.Value], bool]) -> Selector: | ||
"""Return columns that satisfy `predicate`. | ||
Examples | ||
-------- | ||
>>> t = ibis.table(dict(a="float32"), name="t") | ||
>>> t.select(s.where(lambda col: col.get_name() == "a")) | ||
r0 := UnboundTable: t | ||
a float32 | ||
Selection[r0] | ||
selections: | ||
a: r0.a | ||
""" | ||
return Selector(predicate) | ||
|
||
|
||
def numeric() -> Selector: | ||
"""Return numeric columns. | ||
Examples | ||
-------- | ||
>>> import ibis.selectors as s | ||
>>> t = ibis.table(dict(a="int", b="string", c="array<string>"), name="t") | ||
>>> t | ||
r0 := UnboundTable: t | ||
a int64 | ||
b string | ||
c array<string> | ||
>>> t.select(s.numeric()) # `a` has integer type, so it's numeric | ||
r0 := UnboundTable: t | ||
a int64 | ||
b string | ||
c array<string> | ||
Selection[r0] | ||
selections: | ||
a: r0.a | ||
""" | ||
return Selector(lambda col: col.type().is_numeric()) | ||
|
||
|
||
def of_type(dtype: dt.DataType | str | type[dt.DataType]) -> Selector: | ||
"""Select columns of type `dtype`.""" | ||
if isinstance(dtype, type): | ||
predicate = lambda col, dtype=dtype: isinstance(col.type(), dtype) | ||
else: | ||
dtype = dt.dtype(dtype) | ||
predicate = lambda col, dtype=dtype: col.type() == dtype | ||
return where(predicate) | ||
|
||
|
||
def startswith(prefixes: str | tuple[str, ...]) -> Selector: | ||
"""Select columns whose name starts with one of `prefixes`.""" | ||
return where(lambda col, prefixes=prefixes: col.get_name().startswith(prefixes)) | ||
|
||
|
||
def endswith(suffixes: str | tuple[str, ...]) -> Selector: | ||
"""Select columns whose name ends with one of `suffixes`.""" | ||
return where(lambda col, suffixes=suffixes: col.get_name().endswith(suffixes)) | ||
|
||
|
||
def contains( | ||
needles: str | tuple[str, ...], how: Callable[[Iterable[bool]], bool] = any | ||
) -> Selector: | ||
"""Return columns whose name contains `needles`.""" | ||
|
||
def predicate( | ||
col: ir.Column, | ||
needles: str | tuple[str, ...] = needles, | ||
how: Callable[[Iterable[bool]], bool] = how, | ||
) -> bool: | ||
name = col.get_name() | ||
return how(needle in name for needle in util.promote_list(needles)) | ||
|
||
return where(predicate) | ||
|
||
|
||
def matches(regex: str | re.Pattern) -> Selector: | ||
"""Return columns matching the regular expression `regex`.""" | ||
pattern = re.compile(regex) | ||
return where( | ||
lambda col, pattern=pattern: pattern.search(col.get_name()) is not None | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
import re | ||
|
||
import pytest | ||
|
||
import ibis | ||
import ibis.common.exceptions as exc | ||
import ibis.expr.datatypes as dt | ||
from ibis import selectors as s | ||
|
||
|
||
@pytest.fixture | ||
def t(): | ||
return ibis.table( | ||
dict( | ||
a="int", | ||
b="string", | ||
c="array<string>", | ||
d="struct<a: array<map<string, array<float>>>>", | ||
e="float", | ||
f="decimal(3, 1)", | ||
g="array<array<map<float, float>>>", | ||
ga="string", | ||
), | ||
name="t", | ||
) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"sel", | ||
[s.where(lambda _: False), s.startswith("X"), s.endswith("🙂")], | ||
ids=["false", "startswith", "endswith"], | ||
) | ||
def test_empty_selection(t, sel): | ||
with pytest.raises(exc.IbisError): | ||
t.select(sel) | ||
|
||
|
||
def test_where(t): | ||
assert t.select(s.where(lambda _: True)).equals(t.select(*t.columns)) | ||
|
||
|
||
def test_numeric(t): | ||
assert t.select(s.numeric()).equals(t.select("a", "e", "f")) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("obj", "expected"), | ||
[(dt.Array, ("c", "g")), ("float", ("e",)), (dt.Decimal(3, 1), ("f",))], | ||
ids=["type", "string", "instance"], | ||
) | ||
def test_of_type(t, obj, expected): | ||
assert t.select(s.of_type(obj)).equals(t.select(*expected)) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("prefixes", "expected"), | ||
[("a", ("a",)), (("a", "e"), ("a", "e"))], | ||
ids=["string", "tuple"], | ||
) | ||
def test_startswith(t, prefixes, expected): | ||
assert t.select(s.startswith(prefixes)).equals(t.select(*expected)) | ||
|
||
|
||
def test_endswith(t): | ||
assert t.select(s.endswith(("a", "d"))).equals(t.select("a", "d", "ga")) | ||
|
||
|
||
def test_contains(t): | ||
assert t.select(s.contains("a")).equals(t.select("a", "ga")) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
("rx", "expected"), | ||
[("e|f", ("e", "f")), (re.compile("e|f"), ("e", "f"))], | ||
ids=["string", "pattern"], | ||
) | ||
def test_matches(t, rx, expected): | ||
assert t.select(s.matches(rx)).equals(t.select(expected)) | ||
|
||
|
||
def test_compose_or(t): | ||
assert t.select(s.contains("a") | s.startswith("d")).equals( | ||
t.select("a", "d", "ga") | ||
) | ||
|
||
|
||
def test_compose_and(t): | ||
assert t.select(s.contains("a") & s.contains("g")).equals(t.select("ga")) | ||
|
||
|
||
def test_compose_not(t): | ||
assert t.select(~s.numeric()).equals(t.select("b", "c", "d", "g", "ga")) |