Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(api): approximate quantiles #9881

Merged
merged 2 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT
[
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[1],
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[2],
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[3]
] AS `qs`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS) AS `qs`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT
approx_quantiles(`t0`.`double_col`, 2 IGNORE NULLS)[1] AS `qs`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
SELECT
[
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[2],
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[1],
approx_quantiles(`t0`.`double_col`, 4 IGNORE NULLS)[3]
] AS `qs`
FROM `functional_alltypes` AS `t0`
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SELECT
approx_quantiles(`t0`.`double_col`, 100000 IGNORE NULLS)[33333] AS `qs`
FROM `functional_alltypes` AS `t0`
16 changes: 16 additions & 0 deletions ibis/backends/bigquery/tests/unit/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,3 +677,19 @@ def test_time_from_hms_with_micros(snapshot):
literal = ibis.literal(datetime.time(12, 34, 56))
result = ibis.to_sql(literal, dialect="bigquery")
snapshot.assert_match(result, "no_micros.sql")


@pytest.mark.parametrize(
    "quantiles",
    [
        param(0.5, id="scalar"),
        param(1 / 3, id="tricky-scalar"),
        param([0.25, 0.5, 0.75], id="array"),
        param([0.5, 0.25, 0.75], id="shuffled-array"),
        param([0, 0.25, 0.5, 0.75, 1], id="complete-array"),
    ],
)
def test_approx_quantiles(alltypes, quantiles, snapshot):
    """Snapshot the BigQuery SQL compiled for `approx_quantile`.

    Each parametrized case (scalar, non-terminating scalar, ordered,
    shuffled and complete arrays of quantiles) is compared against its
    stored `out.sql` snapshot.
    """
    expr = alltypes.double_col.approx_quantile(quantiles).name("qs")
    sql = ibis.to_sql(expr, dialect="bigquery")
    snapshot.assert_match(sql, "out.sql")
2 changes: 1 addition & 1 deletion ibis/backends/clickhouse/tests/test_aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_reduction_invalid_where(alltypes, reduction):
),
(
lambda t, cond: t.int_col.approx_median(),
lambda df, cond: np.int32(df.int_col.median()),
lambda df, cond: df.int_col.median(),
),
(
lambda t, cond: t.double_col.min(),
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,7 @@ def execute_mode(op, **kw):


@translate.register(ops.Quantile)
@translate.register(ops.ApproxQuantile)
def execute_quantile(op, **kw):
arg = translate(op.arg, **kw)
quantile = translate(op.quantile, **kw)
Expand Down
37 changes: 37 additions & 0 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import decimal
import math
import re
from typing import TYPE_CHECKING, Any

Expand Down Expand Up @@ -392,6 +394,41 @@

return sge.GroupConcat(this=arg, separator=sep)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile approximate (multi-)quantiles via `approx_quantiles`.

    BigQuery syntax is `APPROX_QUANTILES(col, resolution)`, which returns
    an array of `resolution + 1` quantiles. The resolution is derived from
    the requested quantile fractions, and the requested entries are then
    picked out of (or, when every entry is requested in order, taken
    directly from) that array.

    Raises
    ------
    com.UnsupportedOperationError
        If the quantile operand is not a literal.
    """
    quantile_op = op.quantile
    if not isinstance(quantile_op, ops.Literal):
        raise com.UnsupportedOperationError(
            "quantile must be a literal in BigQuery"
        )

    # Express every requested quantile as an exact integer ratio; going
    # through `str` uses the float's printed decimal form rather than its
    # binary expansion.
    ratios = [
        decimal.Decimal(str(value)).as_integer_ratio()
        for value in util.promote_list(quantile_op.value)
    ]

    # The least common multiple of the denominators is the smallest
    # resolution at which every quantile lands on an array slot. To avoid
    # excessive resolution it is arbitrarily capped at 100,000 - these are
    # approximate quantiles anyway.
    denominators = [den for _, den in ratios]
    resolution = min(math.lcm(*denominators), 100_000)
    indices = [(num * resolution) // den for num, den in ratios]

    # Filtered-out rows become NULL, which `IGNORE NULLS` below discards.
    if where is not None:
        arg = self.if_(where, arg, NULL)

    if not op.arg.dtype.is_floating():
        arg = self.cast(arg, dt.float64)

    ignore_nulls = sge.IgnoreNulls(this=sge.convert(resolution))
    array = self.f.approx_quantiles(arg, ignore_nulls)

    # Scalar quantile: a single element of the quantiles array.
    if isinstance(op, ops.ApproxQuantile):
        return array[indices[0]]

    # Every slot requested, in order: the array can be returned as-is.
    if indices == [*range(resolution + 1)]:
        return array

    return sge.Array(expressions=[array[i] for i in indices])

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_FloorDivide(self, op, *, left, right):
    """Compile floor division as `floor(ieee_divide(left, right))` cast to `op.dtype`."""
    quotient = self.f.ieee_divide(left, right)
    floored = self.f.floor(quotient)
    return self.cast(floored, op.dtype)

Expand Down
22 changes: 14 additions & 8 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,18 +188,24 @@ def visit_CountStar(self, op, *, where, arg):
return self.f.countIf(where)
return sge.Count(this=STAR)

def _visit_quantile(self, func, arg, quantile, where):
    """Build a parameterized quantile aggregate call.

    Parameters
    ----------
    func
        Base name of the aggregate function (e.g. ``"quantile"``).
    arg
        Compiled expression being aggregated.
    quantile
        Compiled quantile level(s); promoted to a list and passed as the
        aggregate's parameters.
    where
        Optional compiled filter. When given, ``If`` is appended to the
        function name and the filter is passed as a second argument.
    """
    if where is None:
        name, params = func, [arg]
    else:
        name, params = f"{func}If", [arg, where]

    return sge.ParameterizedAgg(
        this=name,
        expressions=util.promote_list(quantile),
        params=params,
    )

def visit_Quantile(self, op, *, arg, quantile, where):
    """Compile a scalar quantile using `quantile`."""
    return self._visit_quantile("quantile", arg, quantile, where)

def visit_MultiQuantile(self, op, *, arg, quantile, where):
    """Compile multiple quantiles using `quantiles`."""
    return self._visit_quantile("quantiles", arg, quantile, where)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile a scalar approximate quantile using `quantileTDigest`."""
    return self._visit_quantile("quantileTDigest", arg, quantile, where)

def visit_ApproxMultiQuantile(self, op, *, arg, quantile, where):
    """Compile multiple approximate quantiles using `quantilesTDigest`."""
    return self._visit_quantile("quantilesTDigest", arg, quantile, where)

def visit_Correlation(self, op, *, left, right, how, where):
if how == "pop":
Expand Down
1 change: 1 addition & 0 deletions ibis/backends/sql/compilers/datafusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class DataFusionCompiler(SQLGlotCompiler):
)

SIMPLE_OPS = {
ops.ApproxQuantile: "approx_percentile_cont",
ops.ApproxMedian: "approx_median",
ops.ArrayRemove: "array_remove_all",
ops.BitAnd: "bit_and",
Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,13 @@ def visit_Quantile(self, op, *, arg, quantile, where):
def visit_MultiQuantile(self, op, *, arg, quantile, where):
    """Compile a multi-quantile by delegating unchanged to `visit_Quantile`."""
    compiled = self.visit_Quantile(op, arg=arg, quantile=quantile, where=where)
    return compiled

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile with the `approx_quantile` aggregate.

    Non-floating inputs are cast to float64 before aggregating; the
    optional filter is forwarded to the aggregate builder.
    """
    is_float = op.arg.dtype.is_floating()
    value = arg if is_float else self.cast(arg, dt.float64)
    return self.agg.approx_quantile(value, quantile, where=where)

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_HexDigest(self, op, *, arg, how):
if how in ("md5", "sha256"):
return getattr(self.f, how)(arg)
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/exasol.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ def visit_Quantile(self, op, *, arg, quantile, where):
expression=sge.Order(expressions=[sge.Ordered(this=arg)]),
)

visit_ApproxQuantile = visit_Quantile

def visit_TimestampTruncate(self, op, *, arg, unit):
short_name = unit.short
unit_mapping = {"W": "IW"}
Expand Down
8 changes: 8 additions & 0 deletions ibis/backends/sql/compilers/mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,14 @@ def visit_CountDistinct(self, op, *, arg, where):
arg = self.if_(where, arg, NULL)
return self.f.count(sge.Distinct(expressions=[arg]))

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile as `approx_percentile_cont` wrapped in `WithinGroup`.

    When a filter is given, rows failing it contribute NULL instead of
    their value.
    """
    target = arg if where is None else self.if_(where, arg, NULL)
    percentile = self.f.approx_percentile_cont(quantile)
    # NOTE(review): `nulls_first=True` presumably keeps sqlglot from emitting
    # an explicit NULLS ordering clause for this dialect - confirm.
    ordering = sge.Order(
        expressions=[sge.Ordered(this=target, nulls_first=True)]
    )
    return sge.WithinGroup(this=percentile, expression=ordering)

def visit_DayOfWeekIndex(self, op, *, arg):
    """Compile the day-of-week index as `datepart(weekday, arg) - 1`."""
    weekday = self.f.datepart(self.v.weekday, arg)
    return weekday - 1

Expand Down
9 changes: 9 additions & 0 deletions ibis/backends/sql/compilers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,15 @@ def visit_Quantile(self, op, *, arg, quantile, where):
)
return expr

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile as `approx_percentile` wrapped in `WithinGroup`.

    When a filter is given, the ordered expression is `if_(where, arg)`
    rather than the raw argument.
    """
    target = arg if where is None else self.if_(where, arg)
    percentile = self.f.approx_percentile(quantile)
    ordering = sge.Order(expressions=[sge.Ordered(this=target)])
    return sge.WithinGroup(this=percentile, expression=ordering)

def visit_CountDistinct(self, op, *, arg, where):
if where is not None:
arg = self.if_(where, arg)
Expand Down
2 changes: 2 additions & 0 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ def visit_Quantile(self, op, *, arg, quantile, where):
return expr

visit_MultiQuantile = visit_Quantile
visit_ApproxQuantile = visit_Quantile
visit_ApproxMultiQuantile = visit_Quantile

def visit_Correlation(self, op, *, left, right, how, where):
if how == "sample":
Expand Down
9 changes: 9 additions & 0 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,15 @@ def visit_Quantile(self, op, *, arg, quantile, where):

visit_MultiQuantile = visit_Quantile

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile with `approx_percentile`.

    Non-floating inputs are first cast to float64; when a filter is given,
    rows failing it are then replaced by NULL.
    """
    is_float = op.arg.dtype.is_floating()
    value = arg if is_float else self.cast(arg, dt.float64)
    masked = value if where is None else self.if_(where, value, NULL)
    return self.f.approx_percentile(masked, quantile)

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_Correlation(self, op, *, left, right, how, where):
if (left_type := op.left.dtype).is_boolean():
left = self.cast(left, dt.Int32(nullable=left_type.nullable))
Expand Down
15 changes: 14 additions & 1 deletion ibis/backends/sql/compilers/risingwave.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.backends.sql.compilers import PostgresCompiler
from ibis.backends.sql.compilers.base import ALL_OPERATIONS
from ibis.backends.sql.compilers.base import ALL_OPERATIONS, NULL
from ibis.backends.sql.datatypes import RisingWaveType
from ibis.backends.sql.dialects import RisingWave

Expand All @@ -22,6 +22,8 @@ class RisingWaveCompiler(PostgresCompiler):
ops.DateFromYMD,
ops.Mode,
ops.RandomUUID,
ops.MultiQuantile,
ops.ApproxMultiQuantile,
*(
op
for op in ALL_OPERATIONS
Expand Down Expand Up @@ -65,6 +67,17 @@ def visit_Correlation(self, op, *, left, right, how, where):
op, left=left, right=right, how=how, where=where
)

def visit_Quantile(self, op, *, arg, quantile, where):
    """Compile a quantile as `percentile_cont`/`percentile_disc` wrapped in `WithinGroup`.

    Numeric inputs use the `cont` variant, everything else the `disc`
    variant. When a filter is given, rows failing it contribute NULL.
    """
    target = arg if where is None else self.if_(where, arg, NULL)
    kind = "cont" if op.arg.dtype.is_numeric() else "disc"
    percentile = self.f[f"percentile_{kind}"](quantile)
    ordering = sge.Order(expressions=[sge.Ordered(this=target)])
    return sge.WithinGroup(this=percentile, expression=ordering)

visit_ApproxQuantile = visit_Quantile

def visit_TimestampTruncate(self, op, *, arg, unit):
unit_mapping = {
"Y": "year",
Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,12 @@
quantile = self.f.percentile_cont(quantile)
return sge.WithinGroup(this=quantile, expression=order_by)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile with `approx_percentile`.

    When a filter is given, rows failing it are replaced by NULL before
    aggregating.
    """
    value = arg if where is None else self.if_(where, arg, NULL)
    return self.f.approx_percentile(value, quantile)

Check warning on line 615 in ibis/backends/sql/compilers/snowflake.py

View check run for this annotation

Codecov / codecov/patch

ibis/backends/sql/compilers/snowflake.py#L615

Added line #L615 was not covered by tests
jcrist marked this conversation as resolved.
Show resolved Hide resolved

def visit_CountStar(self, op, *, arg, where):
if where is None:
return super().visit_CountStar(op, arg=arg, where=where)
Expand Down
7 changes: 7 additions & 0 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def visit_Correlation(self, op, *, left, right, how, where):

return self.agg.corr(left, right, where=where)

def visit_ApproxQuantile(self, op, *, arg, quantile, where):
    """Compile an approximate quantile with the `approx_quantile` aggregate.

    Inputs that are not already floating point are cast to float64; the
    optional filter is handed to the aggregate builder.
    """
    if op.arg.dtype.is_floating():
        value = arg
    else:
        value = self.cast(arg, dt.float64)
    return self.agg.approx_quantile(value, quantile, where=where)

visit_ApproxMultiQuantile = visit_ApproxQuantile

def visit_BitXor(self, op, *, arg, where):
a, b = map(sg.to_identifier, "ab")
input_fn = combine_fn = sge.Lambda(
Expand Down
Loading