Skip to content

Commit

Permalink
refactor(selectors): get rid of predicate jank
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 25, 2024
1 parent 949fbea commit 1aa3a35
Show file tree
Hide file tree
Showing 2 changed files with 250 additions and 124 deletions.
73 changes: 72 additions & 1 deletion ibis/common/selectors.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from __future__ import annotations

import abc
from collections.abc import Sequence
from typing import TYPE_CHECKING

from ibis.common.grounds import Concrete
from ibis.common.typing import VarTuple # noqa: TCH001

if TYPE_CHECKING:
from collections.abc import Sequence
Expand All @@ -14,7 +16,6 @@
class Selector(Concrete):
"""A column selector."""

@abc.abstractmethod
def expand(self, table: ir.Table) -> Sequence[ir.Value]:
"""Expand `table` into value expressions that match the selector.
Expand All @@ -29,3 +30,73 @@ def expand(self, table: ir.Table) -> Sequence[ir.Value]:
A sequence of value expressions that match the selector
"""
names = self.expand_names(table)
return list(map(table.__getitem__, filter(names.__contains__, table.columns)))

@abc.abstractmethod
def expand_names(self, table: ir.Table) -> frozenset[str]:
"""Compute the set of column names that match the selector."""

def __and__(self, other: Selector) -> Selector:
"""Compute the logical conjunction of two `Selector`s.
Parameters
----------
other
Another selector
"""
return And(self, other)

def __or__(self, other: Selector) -> Selector:
"""Compute the logical disjunction of two `Selector`s.
Parameters
----------
other
Another selector
"""
return Or(self, other)

def __invert__(self) -> Selector:
"""Compute the logical negation of a `Selector`."""
return Not(self)


class Or(Selector):
left: Selector
right: Selector

def expand_names(self, table: ir.Table) -> frozenset[str]:
return self.left.expand_names(table) | self.right.expand_names(table)


class And(Selector):
left: Selector
right: Selector

def expand_names(self, table: ir.Table) -> frozenset[str]:
return self.left.expand_names(table) & self.right.expand_names(table)


class Any(Selector):
selectors: VarTuple[Selector]

def expand_names(self, table: ir.Table) -> frozenset[str]:
names = (selector.expand_names(table) for selector in self.selectors)
return frozenset.union(*names)


class All(Selector):
selectors: VarTuple[Selector]

def expand_names(self, table: ir.Table) -> frozenset[str]:
names = (selector.expand_names(table) for selector in self.selectors)
return frozenset.intersection(*names)


class Not(Selector):
selector: Selector

def expand_names(self, table: ir.Table) -> frozenset[str]:
names = self.selector.expand_names(table)
return frozenset(col for col in table.columns if col not in names)
Loading

0 comments on commit 1aa3a35

Please sign in to comment.