Skip to content

Commit

Permalink
feat(api): exact median
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and gforsyth committed May 1, 2023
1 parent 5a6c8ca commit c53031c
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 18 deletions.
1 change: 1 addition & 0 deletions ibis/backends/clickhouse/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,6 +1016,7 @@ def formatter(op, **kw):
ops.RandomScalar: "randCanonical",
# Unary aggregates
ops.ApproxMedian: "median",
ops.Median: "quantileExactExclusive",
# TODO: there is also a `uniq` function which is the
# recommended way to approximate cardinality
ops.ApproxCountDistinct: "uniqHLL12",
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/duckdb/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def _map_merge(t, op):
ops.MapValues: _map_values,
ops.MapMerge: _map_merge,
ops.Hash: unary(sa.func.hash),
ops.Median: reduction(sa.func.median),
}
)

Expand Down
1 change: 1 addition & 0 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,7 @@ def struct_column(op):
ops.Sum: 'sum',
ops.Variance: 'var',
ops.CountDistinct: 'n_unique',
ops.Median: 'median',
}

for reduction in _reductions.keys():
Expand Down
5 changes: 3 additions & 2 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def _quantile(t, op):
)


def _approx_median(t, op):
def _median(t, op):
arg = op.arg
if (where := op.where) is not None:
arg = ops.Where(where, arg, None)
Expand Down Expand Up @@ -667,7 +667,8 @@ def _unnest(t, op):
ops.Correlation: _corr,
ops.BitwiseXor: _bitwise_op("#"),
ops.Mode: _mode,
ops.ApproxMedian: _approx_median,
ops.ApproxMedian: _median,
ops.Median: _median,
ops.Quantile: _quantile,
ops.MultiQuantile: _quantile,
ops.TimestampNow: lambda t, op: sa.literal_column(
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/snowflake/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def _group_concat(t, op):
ops.GroupConcat: _group_concat,
ops.Hash: unary(sa.func.hash),
ops.ApproxMedian: reduction(lambda x: sa.func.approx_percentile(x, 0.5)),
ops.Median: reduction(sa.func.median),
}
)

Expand Down
52 changes: 36 additions & 16 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def mean_udf(s):
"datafusion",
"impala",
"mysql",
"pyspark",
"mssql",
"pyspark",
"trino",
"druid",
],
Expand Down Expand Up @@ -437,20 +437,25 @@ def mean_and_std(v):
lambda t, where: (t.int_col % 3).mode(where=where),
lambda t, where: (t.int_col % 3)[where].mode().iloc[0],
id='mode',
marks=pytest.mark.notyet(
[
"bigquery",
"clickhouse",
"datafusion",
"impala",
"mysql",
"pyspark",
"mssql",
"trino",
"druid",
],
raises=com.OperationNotDefinedError,
),
marks=[
pytest.mark.notyet(
[
"bigquery",
"clickhouse",
"datafusion",
"impala",
"mysql",
"pyspark",
"mssql",
"trino",
"druid",
],
raises=com.OperationNotDefinedError,
),
pytest.mark.xfail_version(
pyspark=["pyspark<3.4.0"], raises=AttributeError
),
],
),
param(
lambda t, where: t.double_col.argmin(t.int_col, where=where),
Expand Down Expand Up @@ -1031,14 +1036,29 @@ def test_corr_cov(
@pytest.mark.broken(
["dask", "pandas"],
raises=AttributeError,
reason="'Series' object has no attribute 'approxmedian'",
reason="'Series' object has no attribute 'approx_median'",
)
def test_approx_median(alltypes):
    """Smoke test: ``approx_median`` executes and produces a Python float."""
    value = alltypes.double_col.approx_median().execute()
    assert isinstance(value, float)


@pytest.mark.notimpl(
    ["datafusion", "druid", "sqlite"], raises=com.OperationNotDefinedError
)
@pytest.mark.notyet(
    ["impala", "mysql", "mssql", "druid", "pyspark", "trino"],
    raises=com.OperationNotDefinedError,
)
@pytest.mark.notyet(["dask"], raises=NotImplementedError)
def test_median(alltypes, df):
    """Exact median computed by the backend agrees with pandas' median."""
    backend_result = alltypes.double_col.median().execute()
    pandas_result = df.double_col.median()
    assert backend_result == pandas_result


@mark.parametrize(
('result_fn', 'expected_fn'),
[
Expand Down
9 changes: 9 additions & 0 deletions ibis/expr/operations/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,15 @@ def output_dtype(self):
return dt.higher_precedence(dtype, dt.float64)


@public
class Median(Filterable, Reduction):
    """Exact median reduction over a numeric column.

    Distinct from ``ApproxMedian``: backends are expected to map this
    operation to an exact median implementation rather than an
    approximation.
    """

    # Argument must be a column expression of numeric type.
    arg = rlz.column(rlz.numeric)

    @attribute.default
    def output_dtype(self):
        # Promote to at least float64 so that integer inputs still
        # produce a floating-point median.
        return dt.higher_precedence(self.arg.output_dtype, dt.float64)


@public
class Quantile(Filterable, Reduction):
arg = rlz.any
Expand Down
16 changes: 16 additions & 0 deletions ibis/expr/types/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,22 @@ class NumericScalar(Scalar, NumericValue):

@public
class NumericColumn(Column, NumericValue):
def median(self, where: ir.BooleanValue | None = None) -> NumericScalar:
    """Return the median of the column.

    Parameters
    ----------
    where
        Optional boolean expression. If given, only the values where
        `where` evaluates to true will be considered for the median.

    Returns
    -------
    NumericScalar
        Median of the column
    """
    return ops.Median(self, where=where).to_expr()

def quantile(
self,
quantile: Sequence[NumericValue | float],
Expand Down

0 comments on commit c53031c

Please sign in to comment.