From 6e7e4de53348085ebe4ecd38b96e7de62f613c83 Mon Sep 17 00:00:00 2001
From: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
Date: Sat, 3 Aug 2024 08:04:58 -0400
Subject: [PATCH] fix(snowflake): bring back `where` filter support in `group_concat`; fix `array_agg` ordering (#9758)

---
 ibis/backends/sql/compilers/snowflake.py | 26 ++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py
index 00af8fc01a27..23d73f384824 100644
--- a/ibis/backends/sql/compilers/snowflake.py
+++ b/ibis/backends/sql/compilers/snowflake.py
@@ -361,12 +361,27 @@ def visit_TimestampFromUNIX(self, op, *, arg, unit):
         timestamp_units_to_scale = {"s": 0, "ms": 3, "us": 6, "ns": 9}
         return self.f.to_timestamp(arg, timestamp_units_to_scale[unit.short])
 
+    def _array_collect(self, *, arg, where, order_by):
+        if where is not None:
+            arg = self.if_(where, arg, NULL)
+
+        out = self.f.array_agg(arg)
+
+        if order_by:
+            out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by))
+
+        return out
+
+    def visit_ArrayCollect(self, op, *, arg, where, order_by):
+        return self._array_collect(arg=arg, where=where, order_by=order_by)
+
     def visit_First(self, op, *, arg, where, order_by):
-        return self.f.get(self.agg.array_agg(arg, where=where, order_by=order_by), 0)
+        out = self._array_collect(arg=arg, where=where, order_by=order_by)
+        return self.f.get(out, 0)
 
     def visit_Last(self, op, *, arg, where, order_by):
-        expr = self.agg.array_agg(arg, where=where, order_by=order_by)
-        return self.f.get(expr, self.f.array_size(expr) - 1)
+        out = self._array_collect(arg=arg, where=where, order_by=order_by)
+        return self.f.get(out, self.f.array_size(out) - 1)
 
     def visit_GroupConcat(self, op, *, arg, where, sep, order_by):
         if where is not None:
@@ -377,7 +392,10 @@ def visit_GroupConcat(self, op, *, arg, where, sep, order_by):
         if order_by:
             out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by))
 
-        return out
+        if where is None:
+            return out
+
+        return self.if_(self.f.count_if(where) > 0, out, NULL)
 
     def visit_TimestampBucket(self, op, *, arg, interval, offset):
         if offset is not None: