diff --git a/ibis/expr/types/generic.py b/ibis/expr/types/generic.py index 6b4b66273c18..bd4f04ee556e 100644 --- a/ibis/expr/types/generic.py +++ b/ibis/expr/types/generic.py @@ -1074,9 +1074,7 @@ def collect(self, where: ir.BooleanValue | None = None) -> ir.ArrayScalar: │ b │ [4, 5] │ └────────┴──────────────────────┘ """ - return ops.ArrayCollect( - self, where=self._bind_reduction_filter(where) - ).to_expr() + return ops.ArrayCollect(self, where=self._bind_to_parent_table(where)).to_expr() def identical_to(self, other: Value) -> ir.BooleanValue: """Return whether this expression is identical to other. @@ -1153,7 +1151,7 @@ def group_concat( '39.1: 36.7' """ return ops.GroupConcat( - self, sep=sep, where=self._bind_reduction_filter(where) + self, sep=sep, where=self._bind_to_parent_table(where) ).to_expr() def __hash__(self) -> int: @@ -1485,24 +1483,34 @@ def as_table(self) -> ir.Table: "base table references to a projection" ) - def _bind_reduction_filter(self, where): - rels = self.op().relations - if isinstance(where, Deferred): - if len(rels) == 0: - raise com.IbisInputError( - "Unable to bind deferred expression to a table because " - "the expression doesn't depend on any tables" - ) - elif len(rels) == 1: - (table,) = rels - return where.resolve(table.to_expr()) - else: + def _bind_to_parent_table(self, value) -> Value | None: + """Bind an expr to the parent table of `self`.""" + if value is None: + return None + if isinstance(value, (Deferred, str)) or callable(value): + op = self.op() + if len(op.relations) != 1: + # TODO: I don't think this line can ever be hit by a valid + # expression, since it would require a column expression to + # directly depend on multiple tables. Currently some invalid + # expressions (like t1.a.argmin(t2.b)) aren't caught at + # construction time though, so we keep the check in for now. raise com.RelationError( - "Cannot bind deferred expression to a table because the " - "expression depends on multiple tables" + f"Unable to bind `{value!r}` - the current expression" + f"depends on multiple tables." ) - else: - return where + table = next(iter(op.relations)).to_expr() + + if isinstance(value, str): + return table[value] + elif isinstance(value, Deferred): + return value.resolve(table) + else: + value = value(table) + + if not isinstance(value, Value): + return literal(value) + return value def __deferred_repr__(self): return f"" @@ -1545,7 +1553,7 @@ def approx_nunique( 55 """ return ops.ApproxCountDistinct( - self, where=self._bind_reduction_filter(where) + self, where=self._bind_to_parent_table(where) ).to_expr() def approx_median( @@ -1585,9 +1593,7 @@ def approx_median( >>> t.body_mass_g.approx_median(where=t.species == "Chinstrap") 3700 """ - return ops.ApproxMedian( - self, where=self._bind_reduction_filter(where) - ).to_expr() + return ops.ApproxMedian(self, where=self._bind_to_parent_table(where)).to_expr() def mode(self, where: ir.BooleanValue | None = None) -> Scalar: """Return the mode of a column. @@ -1612,7 +1618,7 @@ def mode(self, where: ir.BooleanValue | None = None) -> Scalar: >>> t.body_mass_g.mode(where=(t.species == "Gentoo") & (t.sex == "male")) 5550 """ - return ops.Mode(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.Mode(self, where=self._bind_to_parent_table(where)).to_expr() def max(self, where: ir.BooleanValue | None = None) -> Scalar: """Return the maximum of a column. @@ -1637,7 +1643,7 @@ def max(self, where: ir.BooleanValue | None = None) -> Scalar: >>> t.body_mass_g.max(where=t.species == "Chinstrap") 4800 """ - return ops.Max(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.Max(self, where=self._bind_to_parent_table(where)).to_expr() def min(self, where: ir.BooleanValue | None = None) -> Scalar: """Return the minimum of a column. @@ -1662,7 +1668,7 @@ def min(self, where: ir.BooleanValue | None = None) -> Scalar: >>> t.body_mass_g.min(where=t.species == "Adelie") 2850 """ - return ops.Min(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.Min(self, where=self._bind_to_parent_table(where)).to_expr() def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar: """Return the value of `self` that maximizes `key`. @@ -1690,7 +1696,7 @@ def argmax(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar: 'Chinstrap' """ return ops.ArgMax( - self, key=key, where=self._bind_reduction_filter(where) + self, key=key, where=self._bind_to_parent_table(where) ).to_expr() def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar: @@ -1720,7 +1726,7 @@ def argmin(self, key: ir.Value, where: ir.BooleanValue | None = None) -> Scalar: 'Adelie' """ return ops.ArgMin( - self, key=key, where=self._bind_reduction_filter(where) + self, key=key, where=self._bind_to_parent_table(where) ).to_expr() def median(self, where: ir.BooleanValue | None = None) -> Scalar: @@ -1776,7 +1782,7 @@ def median(self, where: ir.BooleanValue | None = None) -> Scalar: │ Torgersen │ Adelie │ └───────────┴────────────────┘ """ - return ops.Median(self, where=where).to_expr() + return ops.Median(self, where=self._bind_to_parent_table(where)).to_expr() def quantile( self, @@ -1846,7 +1852,7 @@ def quantile( op = ops.MultiQuantile else: op = ops.Quantile - return op(self, quantile, where=where).to_expr() + return op(self, quantile, where=self._bind_to_parent_table(where)).to_expr() def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: """Compute the number of distinct rows in an expression. @@ -1872,7 +1878,7 @@ def nunique(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: 55 """ return ops.CountDistinct( - self, where=self._bind_reduction_filter(where) + self, where=self._bind_to_parent_table(where) ).to_expr() def topk( @@ -1938,7 +1944,7 @@ def arbitrary( removed_in="10.0", instead="call `first` or `last` explicitly", ) - return ops.Arbitrary(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.Arbitrary(self, where=self._bind_to_parent_table(where)).to_expr() def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: """Compute the number of rows in an expression. @@ -1953,7 +1959,7 @@ def count(self, where: ir.BooleanValue | None = None) -> ir.IntegerScalar: IntegerScalar Number of elements in an expression """ - return ops.Count(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.Count(self, where=self._bind_to_parent_table(where)).to_expr() def value_counts(self) -> ir.Table: """Compute a frequency table. @@ -2022,7 +2028,7 @@ def first(self, where: ir.BooleanValue | None = None) -> Value: >>> t.chars.first(where=t.chars != "a") 'b' """ - return ops.First(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.First(self, where=self._bind_to_parent_table(where)).to_expr() def last(self, where: ir.BooleanValue | None = None) -> Value: """Return the last value of a column. @@ -2048,7 +2054,7 @@ def last(self, where: ir.BooleanValue | None = None) -> Value: >>> t.chars.last(where=t.chars != "d") 'c' """ - return ops.Last(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.Last(self, where=self._bind_to_parent_table(where)).to_expr() def rank(self) -> ir.IntegerColumn: """Compute position of first element within each equal-value group in sorted order. diff --git a/ibis/expr/types/logical.py b/ibis/expr/types/logical.py index a0c289df3bf5..79262fb6c008 100644 --- a/ibis/expr/types/logical.py +++ b/ibis/expr/types/logical.py @@ -342,7 +342,7 @@ def resolve_exists_subquery(outer): if len(parents) == 2: return Deferred(Call(resolve_exists_subquery, _)) elif len(parents) == 1: - op = ops.Any(self, where=self._bind_reduction_filter(where)) + op = ops.Any(self, where=self._bind_to_parent_table(where)) else: raise NotImplementedError( f'Cannot compute "any" for expression of type {type(self)} ' @@ -407,7 +407,7 @@ def all(self, where: BooleanValue | None = None) -> BooleanScalar: False """ - return ops.All(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.All(self, where=self._bind_to_parent_table(where)).to_expr() def notall(self, where: BooleanValue | None = None) -> BooleanScalar: """Return whether not all elements are `True`. diff --git a/ibis/expr/types/numeric.py b/ibis/expr/types/numeric.py index 56d5c2dea178..91609dcc2999 100644 --- a/ibis/expr/types/numeric.py +++ b/ibis/expr/types/numeric.py @@ -787,7 +787,7 @@ def std( Standard deviation of `arg` """ return ops.StandardDev( - self, how=how, where=self._bind_reduction_filter(where) + self, how=how, where=self._bind_to_parent_table(where) ).to_expr() def var( @@ -810,7 +810,7 @@ def var( Standard deviation of `arg` """ return ops.Variance( - self, how=how, where=self._bind_reduction_filter(where) + self, how=how, where=self._bind_to_parent_table(where) ).to_expr() def corr( @@ -836,7 +836,7 @@ def corr( The correlation of `left` and `right` """ return ops.Correlation( - self, right, how=how, where=self._bind_reduction_filter(where) + self, right, how=how, where=self._bind_to_parent_table(where) ).to_expr() def cov( @@ -862,7 +862,7 @@ def cov( The covariance of `self` and `right` """ return ops.Covariance( - self, right, how=how, where=self._bind_reduction_filter(where) + self, right, how=how, where=self._bind_to_parent_table(where) ).to_expr() def mean( @@ -883,7 +883,7 @@ def mean( """ # TODO(kszucs): remove the alias from the reduction method in favor # of default name generated by ops.Value operations - return ops.Mean(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.Mean(self, where=self._bind_to_parent_table(where)).to_expr() def cummean(self, *, where=None, group_by=None, order_by=None) -> NumericColumn: """Return the cumulative mean of the input.""" @@ -907,7 +907,7 @@ def sum( NumericScalar The sum of the input expression """ - return ops.Sum(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.Sum(self, where=self._bind_to_parent_table(where)).to_expr() def cumsum(self, *, where=None, group_by=None, order_by=None) -> NumericColumn: """Return the cumulative sum of the input.""" @@ -1161,15 +1161,15 @@ class IntegerScalar(NumericScalar, IntegerValue): class IntegerColumn(NumericColumn, IntegerValue): def bit_and(self, where: ir.BooleanValue | None = None) -> IntegerScalar: """Aggregate the column using the bitwise and operator.""" - return ops.BitAnd(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.BitAnd(self, where=self._bind_to_parent_table(where)).to_expr() def bit_or(self, where: ir.BooleanValue | None = None) -> IntegerScalar: """Aggregate the column using the bitwise or operator.""" - return ops.BitOr(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.BitOr(self, where=self._bind_to_parent_table(where)).to_expr() def bit_xor(self, where: ir.BooleanValue | None = None) -> IntegerScalar: """Aggregate the column using the bitwise exclusive or operator.""" - return ops.BitXor(self, where=self._bind_reduction_filter(where)).to_expr() + return ops.BitXor(self, where=self._bind_to_parent_table(where)).to_expr() @public diff --git a/ibis/tests/expr/test_aggregation.py b/ibis/tests/expr/test_aggregation.py new file mode 100644 index 000000000000..0192c27997aa --- /dev/null +++ b/ibis/tests/expr/test_aggregation.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import pytest + +import ibis +from ibis import _ + + +@pytest.fixture +def table(): + return ibis.table( + {"ints": "int", "floats": "float", "bools": "bool", "strings": "string"} + ) + + +@pytest.mark.parametrize( + "func", + [ + pytest.param(lambda t, **kws: t.strings.arbitrary(**kws), id="arbitrary"), + pytest.param(lambda t, **kws: t.strings.collect(**kws), id="collect"), + pytest.param(lambda t, **kws: t.strings.group_concat(**kws), id="group_concat"), + pytest.param( + lambda t, **kws: t.strings.approx_nunique(**kws), id="approx_nunique" + ), + pytest.param( + lambda t, **kws: t.strings.approx_median(**kws), id="approx_median" + ), + pytest.param(lambda t, **kws: t.strings.mode(**kws), id="mode"), + pytest.param(lambda t, **kws: t.strings.max(**kws), id="max"), + pytest.param(lambda t, **kws: t.strings.min(**kws), id="min"), + pytest.param(lambda t, **kws: t.strings.argmax(t.ints, **kws), id="argmax"), + pytest.param(lambda t, **kws: t.strings.argmin(t.ints, **kws), id="argmin"), + pytest.param(lambda t, **kws: t.strings.median(**kws), id="median"), + pytest.param(lambda t, **kws: t.strings.quantile(0.25, **kws), id="quantile"), + pytest.param( + lambda t, **kws: t.strings.quantile([0.25, 0.75], **kws), + id="multi-quantile", + ), + pytest.param(lambda t, **kws: t.strings.nunique(**kws), id="nunique"), + pytest.param(lambda t, **kws: t.strings.count(**kws), id="count"), + pytest.param(lambda t, **kws: t.strings.first(**kws), id="first"), + pytest.param(lambda t, **kws: t.strings.last(**kws), id="last"), + pytest.param(lambda t, **kws: t.ints.std(**kws), id="std"), + pytest.param(lambda t, **kws: t.ints.var(**kws), id="var"), + pytest.param(lambda t, **kws: t.ints.mean(**kws), id="mean"), + pytest.param(lambda t, **kws: t.ints.sum(**kws), id="sum"), + pytest.param(lambda t, **kws: t.ints.corr(t.floats, **kws), id="corr"), + pytest.param(lambda t, **kws: t.ints.cov(t.floats, **kws), id="cov"), + pytest.param(lambda t, **kws: t.ints.bit_and(**kws), id="bit_and"), + pytest.param(lambda t, **kws: t.ints.bit_xor(**kws), id="bit_xor"), + pytest.param(lambda t, **kws: t.ints.bit_or(**kws), id="bit_or"), + pytest.param(lambda t, **kws: t.bools.any(**kws), id="any"), + pytest.param(lambda t, **kws: t.bools.all(**kws), id="all"), + pytest.param(lambda t, **kws: t.count(**kws), id="table-count"), + pytest.param(lambda t, **kws: t.nunique(**kws), id="table-nunique"), + ], +) +def test_aggregation_where(table, func): + # No where + op = func(table).op() + assert op.where is None + + # Literal where + op = func(table, where=False).op() + assert op.where.equals(ibis.literal(False).op()) + + # Various ways to spell the same column expression + r1 = func(table, where=table.bools) + r2 = func(table, where=_.bools) + r3 = func(table, where=lambda t: t.bools) + r4 = func(table, where="bools") + assert r1.equals(r2) + assert r1.equals(r3) + assert r1.equals(r4) + assert r1.op().where.equals(table.bools.op())