Skip to content

Commit

Permalink
feat(postgres): implement argmin/argmax
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Feb 2, 2023
1 parent 8b998a5 commit 82668ec
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
18 changes: 18 additions & 0 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,22 @@ def variance_compiler(t, op):
return variance_compiler


def _arg_min_max(sort_func):
def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
arg = t.translate(op.arg)
key = t.translate(op.key)

conditions = [arg != sa.null(), key != sa.null()]

agg = sa.func.array_agg(pg.aggregate_order_by(arg, sort_func(key)))

if (where := op.where) is not None:
conditions.append(t.translate(where))
return agg.filter(sa.and_(*conditions))[1]

return translate


operation_registry.update(
{
ops.Literal: _literal,
Expand Down Expand Up @@ -587,5 +603,7 @@ def variance_compiler(t, op):
ops.MapMerge: fixed_arity(operator.add, 2),
ops.MapLength: unary(lambda arg: sa.func.cardinality(arg.keys())),
ops.Map: fixed_arity(pg.hstore, 2),
ops.ArgMin: _arg_min_max(sa.asc),
ops.ArgMax: _arg_min_max(sa.desc),
}
)
21 changes: 2 additions & 19 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ def mean_udf(s):
"datafusion",
"impala",
"mysql",
"postgres",
"sqlite",
"polars",
"mssql",
Expand Down Expand Up @@ -311,31 +310,15 @@ def mean_and_std(v):
lambda t, where: t.double_col[where].iloc[t.int_col[where].argmin()],
id='argmin',
marks=pytest.mark.notyet(
[
"impala",
"mysql",
"postgres",
"sqlite",
"polars",
"datafusion",
"mssql",
]
["impala", "mysql", "sqlite", "polars", "datafusion", "mssql"]
),
),
param(
lambda t, where: t.double_col.argmax(t.int_col, where=where),
lambda t, where: t.double_col[where].iloc[t.int_col[where].argmax()],
id='argmax',
marks=pytest.mark.notyet(
[
"impala",
"mysql",
"postgres",
"sqlite",
"polars",
"datafusion",
"mssql",
]
["impala", "mysql", "sqlite", "polars", "datafusion", "mssql"]
),
),
param(
Expand Down

0 comments on commit 82668ec

Please sign in to comment.