From ee8dbabfa210257102010d76f7ae7f0cd4d01e05 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 20 Mar 2023 07:47:04 -0400 Subject: [PATCH] feat(postgres): implement `ops.Arbitrary` --- ibis/backends/postgres/__init__.py | 32 +++++++++++++++++++++++++ ibis/backends/postgres/registry.py | 10 ++++++++ ibis/backends/tests/test_aggregation.py | 9 +++---- 3 files changed, 47 insertions(+), 4 deletions(-) diff --git a/ibis/backends/postgres/__init__.py b/ibis/backends/postgres/__init__.py index 54d70e4d3846..06eacea39533 100644 --- a/ibis/backends/postgres/__init__.py +++ b/ibis/backends/postgres/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +import contextlib from typing import TYPE_CHECKING, Iterable, Literal import sqlalchemy as sa @@ -115,6 +116,37 @@ def do_connect( alchemy_url, connect_args=connect_args, poolclass=sa.pool.StaticPool ) + # define first/last aggs for ops.Arbitrary + # + # ignore exceptions so the rest of ibis still works: a user may not + # have permissions to define functions and/or aggregates + with engine.begin() as con, contextlib.suppress(Exception): + # adapted from https://wiki.postgresql.org/wiki/First/last_%28aggregate%29 + con.exec_driver_sql( + """\ +CREATE OR REPLACE FUNCTION public._ibis_first_agg (anyelement, anyelement) +RETURNS anyelement +LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE AS +'SELECT $1'; + +CREATE OR REPLACE AGGREGATE public._ibis_first (anyelement) ( + SFUNC = public._ibis_first_agg, + STYPE = anyelement, + PARALLEL = safe +); + +CREATE OR REPLACE FUNCTION public._ibis_last_agg (anyelement, anyelement) +RETURNS anyelement +LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE AS +'SELECT $2'; + +CREATE OR REPLACE AGGREGATE public._ibis_last (anyelement) ( + SFUNC = public._ibis_last_agg, + STYPE = anyelement, + PARALLEL = safe +);""" + ) + @sa.event.listens_for(engine, "connect") def connect(dbapi_connection, connection_record): with dbapi_connection.cursor() as cur: 
diff --git a/ibis/backends/postgres/registry.py b/ibis/backends/postgres/registry.py index ea2a42bd78c5..d65c341890d1 100644 --- a/ibis/backends/postgres/registry.py +++ b/ibis/backends/postgres/registry.py @@ -479,6 +479,15 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str: return translate +def _arbitrary(t, op): + if (how := op.how) == "heavy": + raise com.UnsupportedOperationError( + f"postgres backend doesn't support how={how!r} for the arbitrary() aggregate" + ) + func = getattr(sa.func.public, f"_ibis_{op.how}") + return t._reduction(func, op) + + operation_registry.update( { ops.Literal: _literal, @@ -629,5 +638,6 @@ def translate(t, op: ops.ArgMin | ops.ArgMax) -> str: ops.LStrip: unary(lambda arg: sa.func.ltrim(arg, string.whitespace)), ops.RStrip: unary(lambda arg: sa.func.rtrim(arg, string.whitespace)), ops.StartsWith: fixed_arity(lambda arg, prefix: arg.op("^@")(prefix), 2), + ops.Arbitrary: _arbitrary, } ) diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 10bfcd1bc972..0bde3f741293 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -553,7 +553,6 @@ def mean_and_std(v): marks=pytest.mark.notimpl( [ 'impala', - 'postgres', 'mysql', 'sqlite', 'polars', @@ -571,7 +570,6 @@ def mean_and_std(v): marks=pytest.mark.notimpl( [ 'impala', - 'postgres', 'mysql', 'sqlite', 'polars', @@ -590,7 +588,6 @@ def mean_and_std(v): pytest.mark.notimpl( [ 'impala', - 'postgres', 'mysql', 'sqlite', 'polars', @@ -629,7 +626,6 @@ def mean_and_std(v): "impala", "mysql", "pandas", - "postgres", "sqlite", "polars", "mssql", @@ -637,6 +633,11 @@ def mean_and_std(v): ], raises=com.OperationNotDefinedError, ), + pytest.mark.notimpl( + ["postgres"], + raises=com.UnsupportedOperationError, + reason="how='heavy' not supported in the postgres backend", + ), pytest.mark.notimpl( ["duckdb"], raises=com.UnsupportedOperationError,