Skip to content

Commit

Permalink
feat(api): support wider range of types in where arg to table reduc…
Browse files Browse the repository at this point in the history
…tions
  • Loading branch information
jcrist committed May 24, 2024
1 parent 10afc98 commit 7aba385
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
16 changes: 6 additions & 10 deletions ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,6 @@ def __polars_result__(self, df: pl.DataFrame) -> Any:

return PolarsData.convert_table(df, self.schema())

def _bind_reduction_filter(self, where):
if where is None or not isinstance(where, Deferred):
return where

return where.resolve(self)

def bind(self, *args, **kwargs):
# allow the first argument to be either a dictionary or a list of values
if len(args) == 1:
Expand Down Expand Up @@ -2527,9 +2521,9 @@ def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
>>> t.nunique(t.a != "foo")
1
"""
return ops.CountDistinctStar(
self, where=self._bind_reduction_filter(where)
).to_expr()
if where is not None:
(where,) = bind(self, where)
return ops.CountDistinctStar(self, where=where).to_expr()

def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
"""Compute the number of rows in the table.
Expand Down Expand Up @@ -2566,7 +2560,9 @@ def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar:
>>> type(t.count())
<class 'ibis.expr.types.numeric.IntegerScalar'>
"""
return ops.CountStar(self, where=self._bind_reduction_filter(where)).to_expr()
if where is not None:
(where,) = bind(self, where)
return ops.CountStar(self, where=where).to_expr()

def dropna(
self,
Expand Down
22 changes: 19 additions & 3 deletions ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,10 +629,26 @@ def test_invalid_slice(table, step):
table[:5:step]


def test_table_count(table):
result = table.count()
@pytest.mark.parametrize(
"method, op_cls", [("count", ops.CountStar), ("nunique", ops.CountDistinctStar)]
)
def test_table_count_nunique(table, method, op_cls):
def f(t, **kwargs):
return getattr(t, method)(**kwargs)

result = f(table)
assert isinstance(result, ir.IntegerScalar)
assert isinstance(result.op(), ops.CountStar)
assert isinstance(result.op(), op_cls)
assert result.op().where is None

r1 = f(table, where=table.h)
r2 = f(table, where="h")
r3 = f(table, where=_.h)
r4 = f(table, where=lambda t: t.h)
assert r1.equals(r2)
assert r1.equals(r3)
assert r1.equals(r4)
assert r1.op().where.equals(table.h.op())


def test_len_raises_expression_error(table):
Expand Down

0 comments on commit 7aba385

Please sign in to comment.