Skip to content

Commit

Permalink
feat(api): add approx_quantiles for computing approximate quantiles
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Aug 23, 2024
1 parent 9d218d1 commit dcdb7a7
Show file tree
Hide file tree
Showing 23 changed files with 298 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT
[
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[1],
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[2],
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[3]
] AS `qs`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS) AS `qs`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT
approx_quantiles(`t0`.`double_col`, 2 IGNORE NULLS)[1] AS `qs`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT
[
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[2],
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[1],
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[3]
] AS `qs`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT
approx_quantiles(`t0`.`double_col`, 100000 IGNORE NULLS)[33333] AS `qs`
FROM `functional_alltypes` AS `t0`
16 changes: 16 additions & 0 deletions ibis/backends/bigquery/tests/unit/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,3 +677,19 @@ def test_time_from_hms_with_micros(snapshot):
literal = ibis.literal(datetime.time(12, 34, 56))
result = ibis.to_sql(literal, dialect="bigquery")
snapshot.assert_match(result, "no_micros.sql")


@pytest.mark.parametrize(
    "quantiles",
    [
        param(0.5, id="scalar"),
        param(1 / 3, id="tricky-scalar"),
        param([0.25, 0.5, 0.75], id="array"),
        param([0.5, 0.25, 0.75], id="shuffled-array"),
        param([0, 0.25, 0.5, 0.75, 1], id="complete-array"),
    ],
)
def test_approx_quantiles(alltypes, quantiles, snapshot):
    """Snapshot the BigQuery SQL generated for scalar and array quantiles."""
    column = alltypes.double_col
    expr = column.approx_quantile(quantiles).name("qs")
    sql = ibis.to_sql(expr, dialect="bigquery")
    snapshot.assert_match(sql, "out.sql")
1 change: 1 addition & 0 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@ def execute_mode(op, **kw):


@translate.register(ops.Quantile)
@translate.register(ops.ApproxQuantile)
def execute_quantile(op, **kw):
arg = translate(op.arg, **kw)
quantile = translate(op.quantile, **kw)
Expand Down
37 changes: 37 additions & 0 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import decimal
import math
import re
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -392,6 +394,41 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by):

return sge.GroupConcat(this=arg, separator=sep)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate (multi-)quantile to ``APPROX_QUANTILES``.

    BigQuery's ``APPROX_QUANTILES(col, resolution)`` returns an array of
    ``resolution + 1`` quantiles, so the requested quantile values are
    translated into a resolution plus the array indices to pick out.

    Raises
    ------
    com.UnsupportedOperationError
        If the quantile argument is not a literal.
    """
    if not isinstance(op.quantile, ops.Literal):
        raise com.UnsupportedOperationError(
            "quantile must be a literal in BigQuery"
        )

    # Express every requested quantile as an exact integer ratio. Going
    # through ``str`` means e.g. 0.25 becomes 1/4 rather than the ratio of
    # its binary floating point representation.
    ratios = [
        decimal.Decimal(str(value)).as_integer_ratio()
        for value in util.promote_list(op.quantile.value)
    ]

    # The least common multiple of the denominators is the smallest
    # resolution at which every quantile lands on an integer index. It is
    # arbitrarily capped at 100,000 to avoid excessive resolution - the
    # results are approximate anyway.
    denominators = [den for _, den in ratios]
    resolution = min(math.lcm(*denominators), 100_000)
    positions = [(num * resolution) // den for num, den in ratios]

    if where is not None:
        arg = self.if_(where, arg, NULL)

    if not op.arg.dtype.is_floating():
        arg = self.cast(arg, dt.float64)

    buckets = self.f.approx_quantiles(
        arg, sge.IgnoreNulls(this=sge.convert(resolution))
    )

    # Scalar quantile: pick the single requested element.
    if isinstance(op, ops.ApproxQuantile):
        return buckets[positions[0]]

    # Multi-quantile covering every index in order: the raw array already
    # is the answer, so no restructuring is needed.
    if positions == list(range(resolution + 1)):
        return buckets

    return sge.Array(expressions=[buckets[pos] for pos in positions])

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_FloorDivide(self, op, *, left, right):
    """Compile floor division as ``floor(ieee_divide(left, right))``.

    The floored quotient is cast to the operation's output dtype.
    """
    quotient = self.f.ieee_divide(left, right)
    floored = self.f.floor(quotient)
    return self.cast(floored, op.dtype)

Expand Down
22 changes: 14 additions & 8 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,18 +188,24 @@ def visit_CountStar(self, op, *, where, arg):
return self.f.countIf(where)
return sge.Count(this=STAR)

def visit_Quantile(self, op, *, arg, quantile, where):
if where is None:
return self.agg.quantile(arg, quantile, where=where)

func = "quantile" + "s" * isinstance(op, ops.MultiQuantile)
def _visit_quantile(self, func, arg, quantile, where):
return sge.ParameterizedAgg(
this=f"{func}If",
this=f"{func}If" if where is not None else func,
expressions=util.promote_list(quantile),
params=[arg, where],
params=[arg, where] if where is not None else [arg],
)

visit_MultiQuantile = visit_Quantile
def visit_Quantile(self, op, *, arg, quantile, where):
    """Compile a quantile aggregation using the ``quantile`` function name."""
    func = "quantile"
    return self._visit_quantile(func, arg, quantile, where)

def visit_MultiQuantile(self, op, *, arg, quantile, where):
    """Compile a multi-quantile aggregation using the ``quantiles`` name."""
    func = "quantiles"
    return self._visit_quantile(func, arg, quantile, where)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile using the ``quantileTDigest`` name."""
    func = "quantileTDigest"
    return self._visit_quantile(func, arg, quantile, where)

def visit_ApproxMultiQuantile(self, op, *, arg, quantile, where):
    """Compile approximate multi-quantiles using ``quantilesTDigest``."""
    func = "quantilesTDigest"
    return self._visit_quantile(func, arg, quantile, where)

def visit_Correlation(self, op, *, left, right, how, where):
if how == "pop":
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class DataFusionCompiler(SQLGlotCompiler):
)

SIMPLE_OPS = {
ops.ApproxQuantile: "approx_percentile_cont",
ops.ApproxMedian: "approx_median",
ops.ArrayRemove: "array_remove_all",
ops.BitAnd: "bit_and",
Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,13 @@ def visit_Quantile(self, op, *, arg, quantile, where):
def visit_MultiQuantile(self, op, *, arg, quantile, where):
    """Compile a multi-quantile by reusing the single-quantile visitor."""
    compiled = self.visit_Quantile(op, arg=arg, quantile=quantile, where=where)
    return compiled

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile via ``approx_quantile``.

    Non-floating inputs are cast to float64 first; the optional ``where``
    filter is forwarded to the aggregate builder.
    """
    already_float = op.arg.dtype.is_floating()
    value = arg if already_float else self.cast(arg, dt.float64)
    return self.agg.approx_quantile(value, quantile, where=where)

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_HexDigest(self, op, *, arg, how):
if how in ("md5", "sha256"):
return getattr(self.f, how)(arg)
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ def visit_Quantile(self, op, *, arg, quantile, where):
expression=sge.Order(expressions=[sge.Ordered(this=arg)]),
)

visit_ApproxQuantile = visit_Quantile

def visit_TimestampTruncate(self, op, *, arg, unit):
short_name = unit.short
unit_mapping = {"W": "IW"}
Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,14 @@ def visit_CountDistinct(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return self.f.count(sge.Distinct(expressions=[arg]))

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile as ``approx_percentile_cont``.

    The result is an ``approx_percentile_cont(quantile) WITHIN GROUP
    (ORDER BY ...)`` expression. When ``where`` is given, rows that fail
    the predicate are replaced with NULL before aggregation.
    """
    filtered = arg if where is None else self.if_(where, arg, NULL)
    percentile = self.f.approx_percentile_cont(quantile)
    ordering = sge.Order(
        expressions=[sge.Ordered(this=filtered, nulls_first=True)]
    )
    return sge.WithinGroup(this=percentile, expression=ordering)

def visit_DayOfWeekIndex(self, op, *, arg):
    """Compile the day-of-week index as ``datepart(weekday, arg) - 1``."""
    weekday = self.f.datepart(self.v.weekday, arg)
    return weekday - 1

Expand Down
9 changes: 9 additions & 0 deletions ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,15 @@ def visit_Quantile(self, op, *, arg, quantile, where):
)
return expr

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile as ``approx_percentile``.

    The result is an ``approx_percentile(quantile) WITHIN GROUP
    (ORDER BY arg)`` expression. When ``where`` is given, the argument is
    wrapped in a two-argument ``if_`` so that only matching rows keep
    their value.
    """
    # NOTE(review): the two-argument ``if_`` presumably yields NULL for
    # non-matching rows - confirm against the base compiler.
    filtered = arg if where is None else self.if_(where, arg)
    percentile = self.f.approx_percentile(quantile)
    ordering = sge.Order(expressions=[sge.Ordered(this=filtered)])
    return sge.WithinGroup(this=percentile, expression=ordering)

def visit_CountDistinct(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg)
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ def visit_Quantile(self, op, *, arg, quantile, where):
return expr

visit_MultiQuantile = visit_Quantile
visit_ApproxQuantile = visit_Quantile
visit_ApproxMultiQuantile = visit_Quantile

def visit_Correlation(self, op, *, left, right, how, where):
if how == "sample":
Expand Down
9 changes: 9 additions & 0 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,15 @@ def visit_Quantile(self, op, *, arg, quantile, where):

visit_MultiQuantile = visit_Quantile

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile as ``approx_percentile(arg, q)``.

    Non-floating inputs are cast to float64, and when ``where`` is given
    the rows failing the predicate are replaced with NULL.
    """
    # Cast first, then filter, so the NULL branch wraps the cast value.
    as_float = arg if op.arg.dtype.is_floating() else self.cast(arg, dt.float64)
    filtered = as_float if where is None else self.if_(where, as_float, NULL)
    return self.f.approx_percentile(filtered, quantile)

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_Correlation(self, op, *, left, right, how, where):
if (left_type := op.left.dtype).is_boolean():
left = self.cast(left, dt.Int32(nullable=left_type.nullable))
Expand Down
15 changes: 14 additions & 1 deletion ibis/backends/sql/compilers/risingwave.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compilers import PostgresCompiler
from ibis.backends.sql.compilers.base import ALL_OPERATIONS
from ibis.backends.sql.compilers.base import ALL_OPERATIONS, NULL
from ibis.backends.sql.datatypes import RisingWaveType
from ibis.backends.sql.dialects import RisingWave

Expand All @@ -22,6 +22,8 @@ class RisingWaveCompiler(PostgresCompiler):
ops.DateFromYMD,
ops.Mode,
ops.RandomUUID,
ops.MultiQuantile,
ops.ApproxMultiQuantile,
*(
op
for op in ALL_OPERATIONS
Expand Down Expand Up @@ -65,6 +67,17 @@ def visit_Correlation(self, op, *, left, right, how, where):
op, left=left, right=right, how=how, where=where
)

def visit_Quantile(self, op, *, arg, quantile, where):
    """Compile a quantile as ``percentile_cont`` or ``percentile_disc``.

    Numeric inputs use ``percentile_cont`` and all other inputs use
    ``percentile_disc``. The result is a ``WITHIN GROUP (ORDER BY arg)``
    expression. When ``where`` is given, rows failing the predicate are
    replaced with NULL instead of being filtered.
    """
    filtered = arg if where is None else self.if_(where, arg, NULL)

    if op.arg.dtype.is_numeric():
        suffix = "cont"
    else:
        suffix = "disc"

    percentile = self.f[f"percentile_{suffix}"](quantile)
    ordering = sge.Order(expressions=[sge.Ordered(this=filtered)])
    return sge.WithinGroup(this=percentile, expression=ordering)

visit_ApproxQuantile = visit_Quantile

def visit_TimestampTruncate(self, op, *, arg, unit):
unit_mapping = {
"Y": "year",
Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,12 @@ def visit_Quantile(self, op, *, arg, quantile, where):
quantile = self.f.percentile_cont(quantile)
return sge.WithinGroup(this=quantile, expression=order_by)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile as ``approx_percentile(arg, q)``.

    When ``where`` is given, rows failing the predicate are replaced with
    NULL before aggregation.
    """
    filtered = arg if where is None else self.if_(where, arg, NULL)
    return self.f.approx_percentile(filtered, quantile)

def visit_CountStar(self, op, *, arg, where):
if where is None:
return super().visit_CountStar(op, arg=arg, where=where)
Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def visit_Correlation(self, op, *, left, right, how, where):

return self.agg.corr(left, right, where=where)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile via ``approx_quantile``.

    Inputs that are not already floating point are cast to float64; the
    optional ``where`` filter is passed through to the aggregate builder.
    """
    if op.arg.dtype.is_floating():
        value = arg
    else:
        value = self.cast(arg, dt.float64)
    return self.agg.approx_quantile(value, quantile, where=where)

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_BitXor(self, op, *, arg, where):
a, b = map(sg.to_identifier, "ab")
input_fn = combine_fn = sge.Lambda(
Expand Down
77 changes: 63 additions & 14 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,7 +825,7 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond):
raises=com.UnsupportedBackendType,
),
pytest.mark.notyet(
["snowflake"],
["snowflake", "risingwave"],
reason="backend doesn't implement array of quantiles as input",
raises=com.OperationNotDefinedError,
),
Expand All @@ -844,11 +844,6 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond):
reason="backend doesn't implement approximate quantiles yet",
raises=com.OperationNotDefinedError,
),
pytest.mark.notimpl(
["risingwave"],
reason="Invalid input syntax: direct arg in `percentile_cont` must be castable to float64",
raises=PsycoPg2InternalError,
),
],
),
],
Expand All @@ -862,14 +857,7 @@ def test_count_distinct_star(alltypes, df, ibis_cond, pandas_cond):
lambda t: t.string_col.isin(["1", "7"]),
id="is_in",
marks=[
pytest.mark.notimpl(
["datafusion"], raises=com.OperationNotDefinedError
),
pytest.mark.notimpl(
"risingwave",
raises=PsycoPg2InternalError,
reason="probably incorrect filter syntax but not sure",
),
pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)
],
),
],
Expand All @@ -888,6 +876,67 @@ def test_quantile(
assert pytest.approx(result) == expected


@pytest.mark.parametrize(
    "filtered",
    [
        False,
        param(
            True,
            marks=[
                pytest.mark.notyet(
                    ["datafusion"],
                    raises=Exception,
                    reason="datafusion 38.0.1 has a bug in FILTER handling that causes this test to fail",
                    strict=False,
                )
            ],
        ),
    ],
)
@pytest.mark.parametrize(
    "multi",
    [
        False,
        param(
            True,
            marks=[
                pytest.mark.notimpl(
                    ["datafusion", "oracle", "snowflake", "polars", "risingwave"],
                    raises=com.OperationNotDefinedError,
                    reason="multi-quantile not yet implemented",
                ),
                pytest.mark.notyet(
                    ["mssql", "exasol"],
                    raises=com.UnsupportedBackendType,
                    reason="array types not supported",
                ),
            ],
        ),
    ],
)
@pytest.mark.notyet(
    ["druid", "flink", "impala", "mysql", "sqlite"],
    raises=(com.OperationNotDefinedError, com.UnsupportedBackendType),
    reason="quantiles (approximate or otherwise) not supported",
)
def test_approx_quantile(con, filtered, multi):
    """Check approximate quantiles, scalar and multi, with and without a filter."""
    values = [0, 25, 25, 50, 75, 75, 100, 125, 125, 150, 175, 175, 200]
    table = ibis.memtable({"x": values})

    predicate = (table.x <= 100) if filtered else None
    quantiles = [0.25, 0.75] if multi else 0.25

    expr = table.x.approx_quantile(quantiles, where=predicate)
    result = con.execute(expr)

    if not multi:
        assert pd.api.types.is_float(result)
        expected = 25 if filtered else 50
    else:
        assert isinstance(result, list)
        assert all(pd.api.types.is_float(item) for item in result)
        expected = [25, 75] if filtered else [50, 150]

    # The tolerance is deliberately wide: the point is that the call is
    # valid and the filter is applied, not the accuracy of the estimate.
    assert result == pytest.approx(expected, abs=10)


@pytest.mark.parametrize(
("result_fn", "expected_fn"),
[
Expand Down
Loading

0 comments on commit dcdb7a7

Please sign in to comment.