From 18bdb7e6957f83439d8f3670174177a904e788e6 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 11 Dec 2023 06:22:48 -0500 Subject: [PATCH] fix(datafusion): support computed group by when the aggregation is count distinct --- .../backends/datafusion/compiler/relations.py | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/ibis/backends/datafusion/compiler/relations.py b/ibis/backends/datafusion/compiler/relations.py index 108f7e465fab..e023e3fb15ac 100644 --- a/ibis/backends/datafusion/compiler/relations.py +++ b/ibis/backends/datafusion/compiler/relations.py @@ -139,13 +139,37 @@ def _limit(op: ops.Limit, *, table, n, offset, **_): def _aggregation( op: ops.Aggregation, *, table, metrics, by, having, predicates, sort_keys, **_ ): - selections = (by + metrics) or (STAR,) + if by: + # datafusion doesn't support count distinct aggregations alongside + # computed grouping keys so create a projection of the key and all + # existing columns first, followed by the usual group by + # + # analogous to a user calling mutate -> group_by + by_names = frozenset(b.alias_or_name for b in by) + cols = [ + sg.column( + name, + table=sg.to_identifier(table.alias_or_name, quoted=True), + quoted=True, + ) + for name in op.table.schema.keys() - by_names + ] + table = sg.select(*cols, *by).from_(table).subquery() + + # datafusion lower cases all column names internally unless quoted so + # quoted=True is required here for correctness + by_names_quoted = tuple( + sg.column(b.alias_or_name, table=getattr(b, "table", None), quoted=True) + for b in by + ) + selections = by_names_quoted + metrics + else: + selections = metrics or (STAR,) + sel = sg.select(*selections).from_(table) if by: - sel = sel.group_by( - *(key.this if isinstance(key, sg.exp.Alias) else key for key in by) - ) + sel = sel.group_by(*by_names_quoted) if predicates: sel = sel.where(*predicates)