Skip to content

Commit

Permalink
chore: clean up OfType selector
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Aug 25, 2024
1 parent 1aa3a35 commit 6bd0a97
Showing 1 changed file with 30 additions and 17 deletions.
47 changes: 30 additions & 17 deletions ibis/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,15 @@ def numeric() -> Selector:
return of_type(dt.Numeric)


class OfType(Selector):
predicate: Callable[[dt.DataType], bool]

def expand_names(self, table: ir.Table) -> frozenset[str]:
return frozenset(
name for name, typ in table.schema().items() if self.predicate(typ)
)


@public
def of_type(dtype: dt.DataType | str | type[dt.DataType]) -> Selector:
"""Select columns of type `dtype`.
Expand Down Expand Up @@ -171,32 +180,36 @@ def of_type(dtype: dt.DataType | str | type[dt.DataType]) -> Selector:
[`numeric`](#ibis.selectors.numeric)
"""

abstract = {
"array": dt.Array,
"decimal": dt.Decimal,
"floating": dt.Floating,
"geospatial": dt.GeoSpatial,
"integer": dt.Integer,
"map": dt.Map,
"numeric": dt.Numeric,
"struct": dt.Struct,
"temporal": dt.Temporal,
}

if isinstance(dtype, str):
# A mapping of abstract or parametric types, to allow selecting all
# subclasses/parametrizations of these types, rather than only a
# specific instance.
abstract = {
"array": dt.Array,
"decimal": dt.Decimal,
"floating": dt.Floating,
"geospatial": dt.GeoSpatial,
"integer": dt.Integer,
"map": dt.Map,
"numeric": dt.Numeric,
"struct": dt.Struct,
"temporal": dt.Temporal,
}
if cls := abstract.get(dtype.lower()):
predicate = lambda col: isinstance(col.type(), cls)
if dtype_cls := abstract.get(dtype.lower()):
predicate = lambda typ, dtype_cls=dtype_cls: isinstance(typ, dtype_cls)
else:
dtype = dt.dtype(dtype)
predicate = lambda col: col.type() == dtype
predicate = lambda typ, dtype=dtype: typ == dtype

elif inspect.isclass(dtype) and issubclass(dtype, dt.DataType):
predicate = lambda col: isinstance(col.type(), dtype)
predicate = lambda typ, dtype_cls=dtype: isinstance(typ, dtype_cls)
else:
dtype = dt.dtype(dtype)
predicate = lambda col: col.type() == dtype
return where(predicate)
predicate = lambda typ, dtype=dtype: typ == dtype

return OfType(predicate)


class StartsWith(Selector):
Expand Down

0 comments on commit 6bd0a97

Please sign in to comment.