Skip to content

Commit

Permalink
fix(snowflake): manually construct quantile calls with WITHIN GROUP (
Browse files Browse the repository at this point in the history
  • Loading branch information
gforsyth authored Apr 12, 2024
1 parent fd8858d commit 261a544
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
15 changes: 14 additions & 1 deletion ibis/backends/snowflake/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,20 @@ def visit_Quantile(self, op, *, arg, quantile, where):
# the constant into an expression
if where is not None:
arg = self.if_(where, arg, NULL)
return self.f.percentile_cont(arg, quantile)

# The Snowflake SQLGlot dialect rewrites calls to `percentile_cont` to
# include WITHIN GROUP (ORDER BY ...)
# as per https://docs.snowflake.com/en/sql-reference/functions/percentile_cont
# using the rule `add_within_group_for_percentiles`
#
# When `compile` is called with copy=False and the expression contains
# more than one quantile, the rewrite rule fails on the second pass
# because of some mutation made during the first pass. To avoid this
# error, we build the expression with the WITHIN GROUP clause already
# included and skip the (now unneeded) rewrite rule.
order_by = sge.Order(expressions=[sge.Ordered(this=arg)])
quantile = self.f.percentile_cont(quantile)
return sge.WithinGroup(this=quantile, expression=order_by)

def visit_CountStar(self, op, *, arg, where):
if where is None:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
SELECT
PERCENTILE_CONT(0.25) WITHIN GROUP (ORDER BY
"t0"."ROW_COUNT") AS "quantile_0_25",
PERCENTILE_CONT(0.75) WITHIN GROUP (ORDER BY
"t0"."ROW_COUNT") AS "quantile_0_75"
FROM "t" AS "t0"
16 changes: 16 additions & 0 deletions ibis/backends/snowflake/tests/test_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from __future__ import annotations

import ibis
from ibis import _


def test_more_than_one_quantile(snapshot):
    """Snapshot the Snowflake SQL for two quantiles computed in one aggregation.

    Builds an unbound table ``t`` with a single integer column ``ROW_COUNT``,
    aggregates the 0.25 and 0.75 quantiles of that column in the same
    expression, renders it with the ``snowflake`` dialect, and compares the
    resulting SQL against the stored ``two_quantiles.sql`` snapshot.

    Parameters
    ----------
    snapshot
        Object providing ``assert_match(value, name)``; presumably the
        snapshot-testing pytest fixture -- confirm against the test config.
    """
    t = ibis.table(name="t", schema={"ROW_COUNT": "int"})

    # Both quantiles are taken over the same (deferred) column reference.
    row_count = _.ROW_COUNT
    two_quantiles = t.aggregate(
        quantile_0_25=row_count.quantile(0.25),
        quantile_0_75=row_count.quantile(0.75),
    )

    snapshot.assert_match(
        ibis.to_sql(two_quantiles, dialect="snowflake"),
        "two_quantiles.sql",
    )

0 comments on commit 261a544

Please sign in to comment.