Skip to content

Commit

Permalink
feat(mssql): implement ops.StandardDev, ops.Variance
Browse files Browse the repository at this point in the history
  • Loading branch information
krzysztof-kwitt authored and cpcloud committed Mar 16, 2023
1 parent a8a92dd commit e322f1d
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 24 deletions.
6 changes: 3 additions & 3 deletions ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ class substr(GenericFunction):
inherit_cache = True


def variance_reduction(func_name):
suffix = {'sample': 'samp', 'pop': 'pop'}
def variance_reduction(func_name, suffix=None):
suffix = suffix or {'sample': '_samp', 'pop': '_pop'}

def variance_compiler(t, op):
arg = op.arg

if arg.output_dtype.is_boolean():
arg = ops.Cast(op.arg, to=dt.int32)

func = getattr(sa.func, f'{func_name}_{suffix[op.how]}')
func = getattr(sa.func, f'{func_name}{suffix[op.how]}')

if op.where is not None:
arg = ops.Where(op.where, arg, None)
Expand Down
4 changes: 3 additions & 1 deletion ibis/backends/mssql/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
sqlalchemy_window_functions_registry,
unary,
)
from ibis.backends.base.sql.alchemy.registry import substr
from ibis.backends.base.sql.alchemy.registry import substr, variance_reduction


def _reduction(func, cast_type='int32'):
Expand Down Expand Up @@ -158,6 +158,8 @@ def _timestamp_truncate(t, op):
ops.Log: fixed_arity(lambda x, p: sa.func.log(x, p), 2),
ops.Log2: fixed_arity(lambda x: sa.func.log(x, 2), 1),
ops.Log10: fixed_arity(lambda x: sa.func.log(x, 10), 1),
ops.StandardDev: variance_reduction('stdev', {'sample': '', 'pop': 'p'}),
ops.Variance: variance_reduction('var', {'sample': '', 'pop': 'p'}),
# timestamp methods
ops.TimestampNow: fixed_arity(sa.func.GETDATE, 0),
ops.ExtractYear: _extract('year'),
Expand Down
20 changes: 0 additions & 20 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,11 +483,6 @@ def mean_and_std(v):
["datafusion"],
raises=com.OperationNotDefinedError,
),
mark.notimpl(
["mssql"],
raises=sa.exc.OperationalError,
reason="'stddev_samp' is not a recognized built-in function name.",
),
mark.notimpl(
["druid"],
raises=sa.exc.ProgrammingError,
Expand All @@ -504,11 +499,6 @@ def mean_and_std(v):
["datafusion"],
raises=com.OperationNotDefinedError,
),
mark.broken(
["mssql"],
raises=sa.exc.OperationalError,
reason="'var_samp' is not a recognized built-in function name.",
),
mark.notimpl(
["druid"],
raises=sa.exc.ProgrammingError,
Expand All @@ -530,11 +520,6 @@ def mean_and_std(v):
raises=sa.exc.ProgrammingError,
reason="No match found for function signature stddev_pop(<NUMERIC>)",
),
mark.broken(
["mssql"],
raises=sa.exc.OperationalError,
reason="'stddev_pop' is not a recognized built-in function name.",
),
],
),
param(
Expand All @@ -551,11 +536,6 @@ def mean_and_std(v):
raises=sa.exc.ProgrammingError,
reason="No match found for function signature var_pop(<NUMERIC>)",
),
pytest.mark.notimpl(
["mssql"],
raises=sa.exc.OperationalError,
reason="'var_pop' is not a recognized built-in function name.",
),
],
),
param(
Expand Down

0 comments on commit e322f1d

Please sign in to comment.