Skip to content

Commit

Permalink
feat(bigquery): implement argmin and argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Dec 25, 2022
1 parent 4df9f8b commit 40c5f0d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
15 changes: 15 additions & 0 deletions ibis/backends/bigquery/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import base64
import datetime
from typing import Literal

import numpy as np
from multipledispatch import Dispatcher
Expand Down Expand Up @@ -464,6 +465,18 @@ def _array_agg(t, op):
return f"ARRAY_AGG({t.translate(arg)} IGNORE NULLS)"


def _arg_min_max(sort_dir: Literal["ASC", "DESC"]):
def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
arg = op.arg
if (where := op.where) is not None:
arg = ops.Where(where, arg, None)
arg = t.translate(arg)
key = t.translate(op.key)
return f"ARRAY_AGG({arg} IGNORE NULLS ORDER BY {key} {sort_dir} LIMIT 1)[SAFE_OFFSET(0)]"

return translate


OPERATION_REGISTRY = {
**operation_registry,
# Literal
Expand Down Expand Up @@ -587,6 +600,8 @@ def _array_agg(t, op):
ops.FloorDivide: _floor_divide,
ops.IsNan: _is_nan,
ops.IsInf: _is_inf,
ops.ArgMin: _arg_min_max("ASC"),
ops.ArgMax: _arg_min_max("DESC"),
}

_invalid_operations = {
Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def mean_udf(s):
]

argidx_not_grouped_marks = [
"bigquery",
"datafusion",
"impala",
"mysql",
Expand Down Expand Up @@ -305,7 +304,6 @@ def mean_and_std(v):
id='argmin',
marks=pytest.mark.notyet(
[
"bigquery",
"impala",
"mysql",
"postgres",
Expand All @@ -324,7 +322,6 @@ def mean_and_std(v):
id='argmax',
marks=pytest.mark.notyet(
[
"bigquery",
"impala",
"mysql",
"postgres",
Expand Down

0 comments on commit 40c5f0d

Please sign in to comment.