diff --git a/ibis/expr/types/relations.py b/ibis/expr/types/relations.py index b9374b2b6a80..39a7dbdabc97 100644 --- a/ibis/expr/types/relations.py +++ b/ibis/expr/types/relations.py @@ -209,12 +209,6 @@ def __polars_result__(self, df: pl.DataFrame) -> Any: return PolarsData.convert_table(df, self.schema()) - def _bind_reduction_filter(self, where): - if where is None or not isinstance(where, Deferred): - return where - - return where.resolve(self) - def bind(self, *args, **kwargs): # allow the first argument to be either a dictionary or a list of values if len(args) == 1: @@ -2527,9 +2521,9 @@ def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: >>> t.nunique(t.a != "foo") 1 """ - return ops.CountDistinctStar( - self, where=self._bind_reduction_filter(where) - ).to_expr() + if where is not None: + (where,) = bind(self, where) + return ops.CountDistinctStar(self, where=where).to_expr() def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: """Compute the number of rows in the table. @@ -2566,7 +2560,9 @@ def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: >>> type(t.count()) """ - return ops.CountStar(self, where=self._bind_reduction_filter(where)).to_expr() + if where is not None: + (where,) = bind(self, where) + return ops.CountStar(self, where=where).to_expr() def dropna( self, diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index 12d6c3d4ba04..29b2e01637b1 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -629,10 +629,26 @@ def test_invalid_slice(table, step): table[:5:step] -def test_table_count(table): - result = table.count() +@pytest.mark.parametrize( + "method, op_cls", [("count", ops.CountStar), ("nunique", ops.CountDistinctStar)] +) +def test_table_count_nunique(table, method, op_cls): + def f(t, **kwargs): + return getattr(t, method)(**kwargs) + + result = f(table) assert isinstance(result, ir.IntegerScalar) - assert isinstance(result.op(), ops.CountStar) + assert isinstance(result.op(), op_cls) + assert result.op().where is None + + r1 = f(table, where=table.h) + r2 = f(table, where="h") + r3 = f(table, where=_.h) + r4 = f(table, where=lambda t: t.h) + assert r1.equals(r2) + assert r1.equals(r3) + assert r1.equals(r4) + assert r1.op().where.equals(table.h.op()) def test_len_raises_expression_error(table):