Skip to content

Commit

Permalink
refactor(selectors): remove janky Predicate class and unify `Select…
Browse files Browse the repository at this point in the history
…or`s under a single interface (#9917)
  • Loading branch information
cpcloud authored Aug 27, 2024
1 parent fccd7ed commit c15a229
Show file tree
Hide file tree
Showing 4 changed files with 316 additions and 136 deletions.
86 changes: 84 additions & 2 deletions ibis/common/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,18 @@
import abc
from typing import TYPE_CHECKING

from ibis.common.bases import Abstract
from ibis.common.grounds import Concrete
from ibis.common.typing import VarTuple # noqa: TCH001

if TYPE_CHECKING:
from collections.abc import Sequence

import ibis.expr.types as ir


class Selector(Concrete):
"""A column selector."""
class Expandable(Abstract):
__slots__ = ()

@abc.abstractmethod
def expand(self, table: ir.Table) -> Sequence[ir.Value]:
Expand All @@ -29,3 +31,83 @@ def expand(self, table: ir.Table) -> Sequence[ir.Value]:
A sequence of value expressions that match the selector
"""


class Selector(Concrete, Expandable):
"""A column selector."""

@abc.abstractmethod
def expand_names(self, table: ir.Table) -> frozenset[str]:
"""Compute the set of column names that match the selector."""

def expand(self, table: ir.Table) -> Sequence[ir.Value]:
names = self.expand_names(table)
return list(map(table.__getitem__, filter(names.__contains__, table.columns)))

def __and__(self, other: Selector) -> Selector:
"""Compute the logical conjunction of two `Selector`s.
Parameters
----------
other
Another selector
"""
if not isinstance(other, Selector):
return NotImplemented
return And(self, other)

def __or__(self, other: Selector) -> Selector:
"""Compute the logical disjunction of two `Selector`s.
Parameters
----------
other
Another selector
"""
if not isinstance(other, Selector):
return NotImplemented
return Or(self, other)

def __invert__(self) -> Selector:
"""Compute the logical negation of a `Selector`."""
return Not(self)


class Or(Selector):
left: Selector
right: Selector

def expand_names(self, table: ir.Table) -> frozenset[str]:
return self.left.expand_names(table) | self.right.expand_names(table)


class And(Selector):
left: Selector
right: Selector

def expand_names(self, table: ir.Table) -> frozenset[str]:
return self.left.expand_names(table) & self.right.expand_names(table)


class Any(Selector):
selectors: VarTuple[Selector]

def expand_names(self, table: ir.Table) -> frozenset[str]:
names = (selector.expand_names(table) for selector in self.selectors)
return frozenset.union(*names)


class All(Selector):
selectors: VarTuple[Selector]

def expand_names(self, table: ir.Table) -> frozenset[str]:
names = (selector.expand_names(table) for selector in self.selectors)
return frozenset.intersection(*names)


class Not(Selector):
selector: Selector

def expand_names(self, table: ir.Table) -> frozenset[str]:
names = self.selector.expand_names(table)
return frozenset(col for col in table.columns if col not in names)
4 changes: 2 additions & 2 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import ibis.expr.schema as sch
from ibis import util
from ibis.common.deferred import Deferred, Resolver
from ibis.common.selectors import Selector
from ibis.common.selectors import Expandable, Selector
from ibis.expr.rewrites import DerefMap
from ibis.expr.types.core import Expr, _FixedTextJupyterMixin
from ibis.expr.types.generic import Value, literal
Expand Down Expand Up @@ -108,7 +108,7 @@ def bind(table: Table, value) -> Iterator[ir.Value]:
yield value.resolve(table)
elif isinstance(value, Resolver):
yield value.resolve({"_": table})
elif isinstance(value, Selector):
elif isinstance(value, Expandable):
yield from value.expand(table)
elif callable(value):
# rebind, otherwise the callable is required to return an expression
Expand Down
Loading

0 comments on commit c15a229

Please sign in to comment.