Skip to content

Commit

Permalink
feat(sqlite): implement ops.Arbitrary
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Mar 22, 2023
1 parent c816f00 commit 9bcdf77
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 3 deletions.
10 changes: 10 additions & 0 deletions ibis/backends/sqlite/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ def _literal(t, op):
return base_literal(t, op)


def _arbitrary(t, op):
if (how := op.how) == "heavy":
raise com.OperationNotDefinedError(
"how='heavy' not implemented for the SQLite backend"
)

return reduction(getattr(sa.func, f"_ibis_sqlite_arbitrary_{how}"))(t, op)


operation_registry.update(
{
# TODO(kszucs): don't dispatch on op.arg since that should be always an
Expand Down Expand Up @@ -322,5 +331,6 @@ def _literal(t, op):
ops.RandomScalar: fixed_arity(
lambda: 0.5 + sa.func.random() / sa.cast(-1 << 64, sa.REAL), 0
),
ops.Arbitrary: _arbitrary,
}
)
27 changes: 27 additions & 0 deletions ibis/backends/sqlite/udf.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import abc
import functools
import inspect
import math
Expand Down Expand Up @@ -350,6 +351,32 @@ def __init__(self):
super().__init__(operator.xor)


class _ibis_sqlite_arbitrary(abc.ABC):
def __init__(self) -> None:
self.value = None

@abc.abstractmethod
def step(self, value):
...

def finalize(self) -> int | None:
return self.value


@udaf
class _ibis_sqlite_arbitrary_first(_ibis_sqlite_arbitrary):
def step(self, value):
if self.value is None:
self.value = value


@udaf
class _ibis_sqlite_arbitrary_last(_ibis_sqlite_arbitrary):
def step(self, value):
if value is not None:
self.value = value


def _number_of_arguments(callable):
signature = inspect.signature(callable)
parameters = signature.parameters.values()
Expand Down
3 changes: 0 additions & 3 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,6 @@ def mean_and_std(v):
[
'impala',
'mysql',
'sqlite',
'polars',
'datafusion',
"mssql",
Expand All @@ -571,7 +570,6 @@ def mean_and_std(v):
[
'impala',
'mysql',
'sqlite',
'polars',
'datafusion',
"mssql",
Expand All @@ -589,7 +587,6 @@ def mean_and_std(v):
[
'impala',
'mysql',
'sqlite',
'polars',
'datafusion',
"mssql",
Expand Down

0 comments on commit 9bcdf77

Please sign in to comment.