Skip to content

Commit

Permalink
feat(postgres): implement ops.Arbitrary
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud authored and kszucs committed Mar 22, 2023
1 parent 9a19302 commit ee8dbab
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
32 changes: 32 additions & 0 deletions ibis/backends/postgres/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import contextlib
from typing import TYPE_CHECKING, Iterable, Literal

import sqlalchemy as sa
Expand Down Expand Up @@ -115,6 +116,37 @@ def do_connect(
alchemy_url, connect_args=connect_args, poolclass=sa.pool.StaticPool
)

# define first/last aggs for ops.Arbitrary
#
# ignore exceptions so the rest of ibis still works: a user may not
# have permissions to define funtions and/or aggregates
with engine.begin() as con, contextlib.suppress(Exception):
# adapted from https://wiki.postgresql.org/wiki/First/last_%28aggregate%29
con.exec_driver_sql(
"""\
CREATE OR REPLACE FUNCTION public._ibis_first_agg (anyelement, anyelement)
RETURNS anyelement
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE AS
'SELECT $1';
CREATE OR REPLACE AGGREGATE public._ibis_first (anyelement) (
SFUNC = public._ibis_first_agg,
STYPE = anyelement,
PARALLEL = safe
);
CREATE OR REPLACE FUNCTION public._ibis_last_agg (anyelement, anyelement)
RETURNS anyelement
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE AS
'SELECT $2';
CREATE OR REPLACE AGGREGATE public._ibis_last (anyelement) (
SFUNC = public._ibis_last_agg,
STYPE = anyelement,
PARALLEL = safe
);"""
)

@sa.event.listens_for(engine, "connect")
def connect(dbapi_connection, connection_record):
with dbapi_connection.cursor() as cur:
Expand Down
10 changes: 10 additions & 0 deletions ibis/backends/postgres/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,15 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
return translate


def _arbitrary(t, op):
if (how := op.how) == "heavy":
raise com.UnsupportedOperationError(
f"postgres backend doesn't support how={how!r} for the arbitrary() aggregate"
)
func = getattr(sa.func.public, f"_ibis_{op.how}")
return t._reduction(func, op)


operation_registry.update(
{
ops.Literal: _literal,
Expand Down Expand Up @@ -629,5 +638,6 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str:
ops.LStrip: unary(lambda arg: sa.func.ltrim(arg, string.whitespace)),
ops.RStrip: unary(lambda arg: sa.func.rtrim(arg, string.whitespace)),
ops.StartsWith: fixed_arity(lambda arg, prefix: arg.op("^@")(prefix), 2),
ops.Arbitrary: _arbitrary,
}
)
9 changes: 5 additions & 4 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,6 @@ def mean_and_std(v):
marks=pytest.mark.notimpl(
[
'impala',
'postgres',
'mysql',
'sqlite',
'polars',
Expand All @@ -571,7 +570,6 @@ def mean_and_std(v):
marks=pytest.mark.notimpl(
[
'impala',
'postgres',
'mysql',
'sqlite',
'polars',
Expand All @@ -590,7 +588,6 @@ def mean_and_std(v):
pytest.mark.notimpl(
[
'impala',
'postgres',
'mysql',
'sqlite',
'polars',
Expand Down Expand Up @@ -629,14 +626,18 @@ def mean_and_std(v):
"impala",
"mysql",
"pandas",
"postgres",
"sqlite",
"polars",
"mssql",
"druid",
],
raises=com.OperationNotDefinedError,
),
pytest.mark.notimpl(
["postgres"],
raises=com.UnsupportedOperationError,
reason="how='heavy' not supported in the postgres backend",
),
pytest.mark.notimpl(
["duckdb"],
raises=com.UnsupportedOperationError,
Expand Down

0 comments on commit ee8dbab

Please sign in to comment.