From 9bcdf77ddafce75f0e5d8714d01dde81ed0b90f2 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Wed, 22 Mar 2023 07:01:44 -0400 Subject: [PATCH] feat(sqlite): implement `ops.Arbitrary` --- ibis/backends/sqlite/registry.py | 10 +++++++++ ibis/backends/sqlite/udf.py | 27 +++++++++++++++++++++++++ ibis/backends/tests/test_aggregation.py | 3 --- 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/ibis/backends/sqlite/registry.py b/ibis/backends/sqlite/registry.py index 41b244977f74..9ab55e22d47e 100644 --- a/ibis/backends/sqlite/registry.py +++ b/ibis/backends/sqlite/registry.py @@ -199,6 +199,15 @@ def _literal(t, op): return base_literal(t, op) +def _arbitrary(t, op): + if (how := op.how) == "heavy": + raise com.OperationNotDefinedError( + "how='heavy' not implemented for the SQLite backend" + ) + + return reduction(getattr(sa.func, f"_ibis_sqlite_arbitrary_{how}"))(t, op) + + operation_registry.update( { # TODO(kszucs): don't dispatch on op.arg since that should be always an @@ -322,5 +331,6 @@ def _literal(t, op): ops.RandomScalar: fixed_arity( lambda: 0.5 + sa.func.random() / sa.cast(-1 << 64, sa.REAL), 0 ), + ops.Arbitrary: _arbitrary, } ) diff --git a/ibis/backends/sqlite/udf.py b/ibis/backends/sqlite/udf.py index 84b1670dad39..56df54012361 100644 --- a/ibis/backends/sqlite/udf.py +++ b/ibis/backends/sqlite/udf.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc import functools import inspect import math @@ -350,6 +351,32 @@ def __init__(self): super().__init__(operator.xor) +class _ibis_sqlite_arbitrary(abc.ABC): + def __init__(self) -> None: + self.value = None + + @abc.abstractmethod + def step(self, value): + ... + + def finalize(self) -> int | None: + return self.value + + +@udaf +class _ibis_sqlite_arbitrary_first(_ibis_sqlite_arbitrary): + def step(self, value): + if self.value is None: + self.value = value + + +@udaf +class _ibis_sqlite_arbitrary_last(_ibis_sqlite_arbitrary): + def step(self, value): + if value is not None: + self.value = value + + def _number_of_arguments(callable): signature = inspect.signature(callable) parameters = signature.parameters.values() diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index e2ce2c1d125f..936099e7d915 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -554,7 +554,6 @@ def mean_and_std(v): [ 'impala', 'mysql', - 'sqlite', 'polars', 'datafusion', "mssql", @@ -571,7 +570,6 @@ def mean_and_std(v): [ 'impala', 'mysql', - 'sqlite', 'polars', 'datafusion', "mssql", @@ -589,7 +587,6 @@ def mean_and_std(v): [ 'impala', 'mysql', - 'sqlite', 'polars', 'datafusion', "mssql",