Skip to content

Commit

Permalink
feat(array): implement min, max, any, all, sum, mean (#9704)
Browse files Browse the repository at this point in the history
Co-authored-by: Phillip Cloud <417981+cpcloud@users.noreply.github.com>
  • Loading branch information
NickCrews and cpcloud authored Jul 30, 2024
1 parent fd61f2c commit 793efbc
Show file tree
Hide file tree
Showing 14 changed files with 717 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .codespellrc
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[codespell]
# local codespell matches `./docs`, pre-commit codespell matches `docs`
skip = *.lock,.direnv,.git,./docs/_freeze,./docs/_output/**,./docs/_inv/**,docs/_freeze/**,*.svg,*.css,*.html,*.js,ibis/backends/tests/tpc/queries/duckdb/ds/44.sql
ignore-regex = \b(i[if]f|I[IF]F|AFE)\b
ignore-regex = \b(i[if]f|I[IF]F|AFE|alls)\b
builtin = clear,rare,names
ignore-words-list = tim,notin,ang
2 changes: 2 additions & 0 deletions docs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ quartodoc:
package: ibis.expr.types.numeric
- name: BooleanValue
package: ibis.expr.types.logical
- name: BooleanColumn
package: ibis.expr.types.logical
- name: and_
dynamic: true
signature_name: full
Expand Down
40 changes: 40 additions & 0 deletions ibis/backends/polars/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1350,3 +1350,43 @@ def execute_timestamp_range(op, **kw):
def execute_drop_columns(op, **kw):
parent = translate(op.parent, **kw)
return parent.drop(op.columns_to_drop)


@translate.register(ops.ArraySum)
def execute_array_agg(op, **kw):
arg = translate(op.arg, **kw)
# workaround polars annoying sum([]) == 0 behavior
#
# the polars behavior is consistent with math, but inconsistent
# with every other query engine every built.
no_nulls = arg.list.drop_nulls()
return pl.when(no_nulls.list.len() == 0).then(None).otherwise(no_nulls.list.sum())


@translate.register(ops.ArrayMean)
def execute_array_mean(op, **kw):
return translate(op.arg, **kw).list.mean()


@translate.register(ops.ArrayMin)
def execute_array_min(op, **kw):
return translate(op.arg, **kw).list.min()


@translate.register(ops.ArrayMax)
def execute_array_max(op, **kw):
return translate(op.arg, **kw).list.max()


@translate.register(ops.ArrayAny)
def execute_array_any(op, **kw):
arg = translate(op.arg, **kw)
no_nulls = arg.list.drop_nulls()
return pl.when(no_nulls.list.len() == 0).then(None).otherwise(no_nulls.list.any())


@translate.register(ops.ArrayAll)
def execute_array_all(op, **kw):
arg = translate(op.arg, **kw)
no_nulls = arg.list.drop_nulls()
return pl.when(no_nulls.list.len() == 0).then(None).otherwise(no_nulls.list.all())
64 changes: 64 additions & 0 deletions ibis/backends/snowflake/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,70 @@
"returns": "ARRAY",
"source": """return Array(count).fill(value).flat();""",
},
"ibis_udfs.public.array_sum": {
"inputs": {"array": "ARRAY"},
"returns": "DOUBLE",
"source": """\
let total = 0.0;
let allNull = true;
for (val of array) {
if (val !== null) {
total += val;
allNull = false;
}
}
return !allNull ? total : null;""",
},
"ibis_udfs.public.array_avg": {
"inputs": {"array": "ARRAY"},
"returns": "DOUBLE",
"source": """\
let count = 0;
let total = 0.0;
for (val of array) {
if (val !== null) {
total += val;
++count;
}
}
return count !== 0 ? total / count : null;""",
},
"ibis_udfs.public.array_any": {
"inputs": {"array": "ARRAY"},
"returns": "BOOLEAN",
"source": """\
let count = 0;
for (val of array) {
if (val === true) {
return true;
} else if (val === false) {
++count;
}
}
return count !== 0 ? false : null;""",
},
"ibis_udfs.public.array_all": {
"inputs": {"array": "ARRAY"},
"returns": "BOOLEAN",
"source": """\
let count = 0;
for (val of array) {
if (val === false) {
return false;
} else if (val === true) {
++count;
}
}
return count !== 0 ? true : null;""",
},
}


Expand Down
26 changes: 26 additions & 0 deletions ibis/backends/sql/compilers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,3 +774,29 @@ def visit_TimestampBucket(self, op, *, arg, interval, offset):
origin = self.f.anon[f"{funcname}_add"](origin, offset)

return func(arg, interval, origin)

def _array_reduction(self, *, arg, reduction):
name = sg.to_identifier(util.gen_name(f"bq_arr_{reduction}"))
return (
sg.select(self.f[reduction](name))
.from_(self._unnest(arg, as_=name))
.subquery()
)

def visit_ArrayMin(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="min")

def visit_ArrayMax(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="max")

def visit_ArraySum(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="sum")

def visit_ArrayMean(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="avg")

def visit_ArrayAny(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="logical_or")

def visit_ArrayAll(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="logical_and")
21 changes: 21 additions & 0 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,3 +737,24 @@ def _cleanup_names(
value.as_(self._gen_valid_name(name), quoted=quoted, copy=False)
for name, value in exprs.items()
)

def _array_reduction(self, arg):
x = sg.to_identifier("x", quoted=self.quoted)
not_null = sge.Lambda(this=x.is_(sg.not_(NULL)), expressions=[x])
return self.f.arrayFilter(not_null, arg)

def visit_ArrayMin(self, op, *, arg):
return self.f.arrayReduce("min", self._array_reduction(arg))

visit_ArrayAll = visit_ArrayMin

def visit_ArrayMax(self, op, *, arg):
return self.f.arrayReduce("max", self._array_reduction(arg))

visit_ArrayAny = visit_ArrayMax

def visit_ArraySum(self, op, *, arg):
return self.f.arrayReduce("sum", self._array_reduction(arg))

def visit_ArrayMean(self, op, *, arg):
return self.f.arrayReduce("avg", self._array_reduction(arg))
6 changes: 6 additions & 0 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ class DuckDBCompiler(SQLGlotCompiler):
SIMPLE_OPS = {
ops.Arbitrary: "any_value",
ops.ArrayPosition: "list_indexof",
ops.ArrayMin: "list_min",
ops.ArrayMax: "list_max",
ops.ArrayAny: "list_bool_or",
ops.ArrayAll: "list_bool_and",
ops.ArraySum: "list_sum",
ops.ArrayMean: "list_avg",
ops.BitAnd: "bit_and",
ops.BitOr: "bit_or",
ops.BitXor: "bit_xor",
Expand Down
30 changes: 30 additions & 0 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,3 +670,33 @@ def visit_TableUnnest(
join_type="CROSS" if not keep_empty else "LEFT",
)
)

def _unnest(self, expression, *, as_):
alias = sge.TableAlias(columns=[sg.to_identifier(as_)])
return sge.Unnest(expressions=[expression], alias=alias)

def _array_reduction(self, *, arg, reduction):
name = sg.to_identifier(gen_name(f"pg_arr_{reduction}"))
return (
sg.select(self.f[reduction](name))
.from_(self._unnest(arg, as_=name))
.subquery()
)

def visit_ArrayMin(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="min")

def visit_ArrayMax(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="max")

def visit_ArraySum(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="sum")

def visit_ArrayMean(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="avg")

def visit_ArrayAny(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="bool_or")

def visit_ArrayAll(self, op, *, arg):
return self._array_reduction(arg=arg, reduction="bool_and")
44 changes: 44 additions & 0 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import calendar
import itertools
import operator
import re

import sqlglot as sg
Expand Down Expand Up @@ -68,6 +69,10 @@ class PySparkCompiler(SQLGlotCompiler):
ops.ArrayRemove: "array_remove",
ops.ArraySort: "array_sort",
ops.ArrayUnion: "array_union",
ops.ArrayMin: "array_min",
ops.ArrayMax: "array_max",
ops.ArrayAll: "array_min",
ops.ArrayAny: "array_max",
ops.EndsWith: "endswith",
ops.Hash: "hash",
ops.Log10: "log10",
Expand Down Expand Up @@ -589,3 +594,42 @@ def _format_window_interval(self, expression):
this = expression.this.this # avoid quoting the interval as a string literal

return f"{this}{unit}"

def _array_reduction(self, *, dtype, arg, output):
quoted = self.quoted
dot = lambda a, f: sge.Dot.build((a, sge.to_identifier(f, quoted=quoted)))
state_dtype = dt.Struct({"sum": dtype, "count": dt.int64})
initial_state = self.cast(
sge.Struct.from_arg_list([sge.convert(0), sge.convert(0)]), state_dtype
)

s = sg.to_identifier("s", quoted=quoted)
x = sg.to_identifier("x", quoted=quoted)

s_sum = dot(s, "sum")
s_count = dot(s, "count")

input_fn_body = self.cast(
sge.Struct.from_arg_list(
[
x + self.f.coalesce(s_sum, 0),
s_count + self.if_(x.is_(sg.not_(NULL)), 1, 0),
]
),
state_dtype,
)
input_fn = sge.Lambda(this=input_fn_body, expressions=[s, x])

output_fn_body = self.if_(s_count > 0, output(s_sum, s_count), NULL)
return self.f.aggregate(
arg,
initial_state,
input_fn,
sge.Lambda(this=output_fn_body, expressions=[s]),
)

def visit_ArraySum(self, op, *, arg):
return self._array_reduction(dtype=op.dtype, arg=arg, output=lambda sum, _: sum)

def visit_ArrayMean(self, op, *, arg):
return self._array_reduction(dtype=op.dtype, arg=arg, output=operator.truediv)
18 changes: 18 additions & 0 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,3 +727,21 @@ def visit_TableUnnest(
.from_(parent)
.join(unnest, join_type="CROSS" if not keep_empty else "LEFT")
)

def visit_ArrayMin(self, op, *, arg):
return self.cast(self.f.array_min(self.f.array_compact(arg)), op.dtype)

def visit_ArrayMax(self, op, *, arg):
return self.cast(self.f.array_max(self.f.array_compact(arg)), op.dtype)

def visit_ArrayAny(self, op, *, arg):
return self.f.udf.array_any(arg)

def visit_ArrayAll(self, op, *, arg):
return self.f.udf.array_all(arg)

def visit_ArraySum(self, op, *, arg):
return self.cast(self.f.udf.array_sum(arg), op.dtype)

def visit_ArrayMean(self, op, *, arg):
return self.cast(self.f.udf.array_avg(arg), op.dtype)
Loading

0 comments on commit 793efbc

Please sign in to comment.