Skip to content

Commit

Permalink
feat(api): more selectors
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Feb 27, 2023
1 parent b1b947a commit 5844304
Show file tree
Hide file tree
Showing 9 changed files with 1,012 additions and 78 deletions.
1 change: 1 addition & 0 deletions docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
* [Backend Operations Matrix](backends/support_matrix.md)
* [Releases](release_notes.md)
* Blog
* [Maximizing Productivity with Selectors](blog/selectors.md)
* [Ibis + DuckDB + Substrait](blog/ibis_substrait_to_duckdb.md)
* [Ibis v4.0.0](blog/ibis-version-4.0.0-release.md)
* [Analyzing Ibis's CI Data with Ibis](blog/rendered/ci-analysis.ipynb)
Expand Down
448 changes: 448 additions & 0 deletions docs/blog/selectors.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ibis/common/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def from_callable(cls, fn, validators=None, return_validator=None):

if validators is None:
validators = {}
elif isinstance(validators, list):
elif isinstance(validators, (list, tuple)):
# create a mapping of parameter name to validator
validators = dict(zip(sig.parameters.keys(), validators))
elif not isinstance(validators, dict):
Expand Down
3 changes: 0 additions & 3 deletions ibis/expr/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,9 +212,6 @@ def name_at_position(self, i: int) -> str:
>>> sch.name_at_position(1)
'b'
"""
upper = len(self.names) - 1
if not 0 <= i <= upper:
raise ValueError(f'Column index must be between 0 and {upper:d}, inclusive')
return self.names[i]

def apply_to(self, df: pd.DataFrame) -> pd.DataFrame:
Expand Down
203 changes: 184 additions & 19 deletions ibis/expr/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,31 +42,47 @@

from __future__ import annotations

import abc
import functools
import inspect
import operator
import re
from typing import Callable, Iterable, Sequence
from typing import Callable, Iterable, Mapping, Optional, Sequence, Union

from public import public
from typing_extensions import Annotated

import ibis.expr.datatypes as dt
import ibis.expr.rules as rlz
import ibis.expr.types as ir
from ibis import util
from ibis.common.annotations import attribute
from ibis.common.grounds import Concrete, Singleton
from ibis.common.validators import Coercible
from ibis.expr.deferred import Deferred


class Selector:
class Selector(Concrete):
"""A column selector."""

__slots__ = ("predicate",)

def __init__(self, predicate: Callable[[ir.Value], bool]) -> None:
"""Construct a `Selector` with `predicate`.
@abc.abstractmethod
def expand(self, table: ir.Table) -> Sequence[ir.Value]:
"""Expand `table` into a sequence of value expressions.
Parameters
----------
predicate
A callable that accepts an ibis value expression and returns a `bool`.
table
An ibis table expression
Returns
-------
Sequence[Value]
A sequence of value expressions
"""
self.predicate = predicate


class Predicate(Selector):
predicate: Callable[[ir.Value], bool]

def expand(self, table: ir.Table) -> Sequence[ir.Value]:
"""Evaluate `self.predicate` on every column of `table`.
Expand All @@ -78,7 +94,7 @@ def expand(self, table: ir.Table) -> Sequence[ir.Value]:
"""
return [col for column in table.columns if self.predicate(col := table[column])]

def __and__(self, other: Selector) -> Selector:
def __and__(self, other: Selector) -> Predicate:
"""Compute the conjunction of two `Selectors`.
Parameters
Expand All @@ -88,7 +104,7 @@ def __and__(self, other: Selector) -> Selector:
"""
return self.__class__(lambda col: self.predicate(col) and other.predicate(col))

def __or__(self, other: Selector) -> Selector:
def __or__(self, other: Selector) -> Predicate:
"""Compute the disjunction of two `Selectors`.
Parameters
Expand All @@ -98,13 +114,13 @@ def __or__(self, other: Selector) -> Selector:
"""
return self.__class__(lambda col: self.predicate(col) or other.predicate(col))

def __invert__(self) -> Selector:
def __invert__(self) -> Predicate:
"""Compute the logical negation of two `Selectors`."""
return self.__class__(lambda col: not self.predicate(col))


@public
def where(predicate: Callable[[ir.Value], bool]) -> Selector:
def where(predicate: Callable[[ir.Value], bool]) -> Predicate:
"""Return columns that satisfy `predicate`.
Use this selector when one of the other selectors does not meet your needs.
Expand All @@ -125,11 +141,11 @@ def where(predicate: Callable[[ir.Value], bool]) -> Selector:
selections:
a: r0.a
"""
return Selector(predicate)
return Predicate(predicate=predicate)


@public
def numeric() -> Selector:
def numeric() -> Predicate:
"""Return numeric columns.
Examples
Expand Down Expand Up @@ -159,7 +175,7 @@ def numeric() -> Selector:


@public
def of_type(dtype: dt.DataType | str | type[dt.DataType]) -> Selector:
def of_type(dtype: dt.DataType | str | type[dt.DataType]) -> Predicate:
"""Select columns of type `dtype`.
Parameters
Expand Down Expand Up @@ -194,7 +210,7 @@ def of_type(dtype: dt.DataType | str | type[dt.DataType]) -> Selector:


@public
def startswith(prefixes: str | tuple[str, ...]) -> Selector:
def startswith(prefixes: str | tuple[str, ...]) -> Predicate:
"""Select columns whose name starts with one of `prefixes`.
Parameters
Expand All @@ -215,7 +231,7 @@ def startswith(prefixes: str | tuple[str, ...]) -> Selector:


@public
def endswith(suffixes: str | tuple[str, ...]) -> Selector:
def endswith(suffixes: str | tuple[str, ...]) -> Predicate:
"""Select columns whose name ends with one of `suffixes`.
Parameters
Expand All @@ -233,7 +249,7 @@ def endswith(suffixes: str | tuple[str, ...]) -> Selector:
@public
def contains(
needles: str | tuple[str, ...], how: Callable[[Iterable[bool]], bool] = any
) -> Selector:
) -> Predicate:
"""Return columns whose name contains `needles`.
Parameters
Expand Down Expand Up @@ -284,3 +300,152 @@ def matches(regex: str | re.Pattern) -> Selector:
"""
pattern = re.compile(regex)
return where(lambda col: pattern.search(col.get_name()) is not None)


@public
def any_of(*predicates: Predicate) -> Predicate:
"""Include columns satisfying any of `predicates`."""
return functools.reduce(operator.or_, predicates)


@public
def all_of(*predicates: Predicate) -> Predicate:
"""Include columns satisfying all of `predicates`."""
return functools.reduce(operator.and_, predicates)


@public
def c(*names: str) -> Predicate:
"""Select specific column names."""
names = frozenset(names)
return where(lambda col: col.get_name() in names)


class Across(Selector):
selector: Selector
funcs: Union[
Deferred,
Callable[[ir.Value], ir.Value],
util.frozendict[Optional[str], Union[Deferred, Callable[[ir.Value], ir.Value]]],
]
names: Union[str, Callable[[str, Optional[str]], str]]

def expand(self, table: ir.Table) -> Sequence[ir.Value]:
expanded = []

names = self.names
cols = self.selector.expand(table)
for func_name, func in self.funcs.items():
resolve = func.resolve if isinstance(func, Deferred) else func
expanded.extend(
col.name(
names(orig_col.get_name(), func_name)
if callable(names)
else names.format(col=orig_col.get_name(), fn=func_name)
)
for col, orig_col in zip(map(resolve, cols), cols)
)

return expanded


@public
def across(selector, func, names=None) -> Across:
if names is None:
names = lambda col, fn: "_".join(filter(None, (col, fn)))
funcs = util.frozendict(func if isinstance(func, Mapping) else {None: func})
return Across(selector=selector, funcs=funcs, names=names)


class IfAnyAll(Selector):
selector: Selector
predicate: Union[Deferred, Callable[[ir.Value], ir.BooleanValue]]
summarizer: Callable[[ir.BooleanValue, ir.BooleanValue], ir.BooleanValue]

def expand(self, table: ir.Table) -> Sequence[ir.Value]:
func = self.predicate
resolve = func.resolve if isinstance(func, Deferred) else func
return [
functools.reduce(self.summarizer, map(resolve, self.selector.expand(table)))
]


@public
def if_any(selector: Selector, predicate: Deferred | Callable) -> IfAnyAll:
return IfAnyAll(selector=selector, predicate=predicate, summarizer=operator.or_)


@public
def if_all(selector: Selector, predicate: Deferred | Callable) -> IfAnyAll:
return IfAnyAll(selector=selector, predicate=predicate, summarizer=operator.and_)


class HashableSlice(Concrete, Coercible):
slice: slice

@classmethod
def __coerce__(cls, slice):
return cls(slice)

@property
def start(self):
return self.slice.start

@property
def stop(self):
return self.slice.stop

@property
def step(self):
return self.slice.step

@attribute.default
def __precomputed_hash__(self) -> int:
return hash((self.__class__, (self.start, self.stop, self.step)))


class RangeSelector(Selector):
key: Union[str, int, Annotated[slice, rlz.coerced_to(HashableSlice)]]

def expand(self, table: ir.Table) -> Sequence[ir.Value]:
key = self.key

if isinstance(key, (str, int)):
return [table[key]]

start = key.start or 0
stop = key.stop or len(table.columns)
step = key.step or 1

schema = table.schema()

if isinstance(start, str):
start = schema._name_locs[start]

if isinstance(stop, str):
stop = schema._name_locs[stop]

return [table[i] for i in range(start, stop, step)]


class Sliceable(Singleton):
def __getitem__(self, key: str | int | slice):
return RangeSelector(key=key)


r = Sliceable()


@public
def first() -> Selector:
return r[0]


@public
def last() -> Selector:
return r[-1]


@public
def all() -> Predicate:
return r[:]
Loading

0 comments on commit 5844304

Please sign in to comment.