From edc42be059db53e6a3af74e1eb10bf4ce7d25582 Mon Sep 17 00:00:00 2001 From: Daniel Mesejo Date: Tue, 31 Oct 2023 08:54:03 +0100 Subject: [PATCH] feat(datafusion): add corr and covar --- ibis/backends/datafusion/compiler/values.py | 31 +++++++++++++++++++++ ibis/backends/tests/test_aggregation.py | 12 ++++---- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/ibis/backends/datafusion/compiler/values.py b/ibis/backends/datafusion/compiler/values.py index 32a76c4abcb1..0447e36da547 100644 --- a/ibis/backends/datafusion/compiler/values.py +++ b/ibis/backends/datafusion/compiler/values.py @@ -733,3 +733,34 @@ def array_concat(op, *, arg, **_): @translate_val.register(ops.ArrayPosition) def array_position(op, *, arg, other, **_): return F.coalesce(F.array_position(arg, other), 0) + + +@translate_val.register(ops.Covariance) +def covariance(op, *, left, right, how, where, **_): + x = op.left + if x.dtype.is_boolean(): + left = cast(left, dt.float64) + + y = op.right + if y.dtype.is_boolean(): + right = cast(right, dt.float64) + + if how == "sample": + return agg["covar_samp"](left, right, where=where) + elif how == "pop": + return agg["covar_pop"](left, right, where=where) + else: + raise ValueError(f"Unrecognized how = `{how}` value") + + +@translate_val.register(ops.Correlation) +def correlation(op, *, left, right, where, **_): + x = op.left + if x.dtype.is_boolean(): + left = cast(left, dt.float64) + + y = op.right + if y.dtype.is_boolean(): + right = cast(right, dt.float64) + + return agg["corr"](left, right, where=where) diff --git a/ibis/backends/tests/test_aggregation.py b/ibis/backends/tests/test_aggregation.py index 308ce55795dc..03fdabdd5dfd 100644 --- a/ibis/backends/tests/test_aggregation.py +++ b/ibis/backends/tests/test_aggregation.py @@ -902,7 +902,7 @@ def test_quantile( id="covar_pop", marks=[ pytest.mark.notimpl( - ["dask", "datafusion", "pandas", "polars", "druid"], + ["dask", "pandas", "polars", "druid"], raises=com.OperationNotDefinedError, ), pytest.mark.notyet( @@ -916,7 +916,7 @@ def test_quantile( id="covar_samp", marks=[ pytest.mark.notimpl( - ["dask", "datafusion", "pandas", "polars", "druid"], + ["dask", "pandas", "polars", "druid"], raises=com.OperationNotDefinedError, ), pytest.mark.notyet( @@ -931,7 +931,7 @@ def test_quantile( id="corr_pop", marks=[ pytest.mark.notimpl( - ["dask", "datafusion", "pandas", "druid"], + ["dask", "pandas", "druid"], raises=com.OperationNotDefinedError, ), pytest.mark.notyet( @@ -956,7 +956,7 @@ def test_quantile( id="corr_samp", marks=[ pytest.mark.notimpl( - ["dask", "datafusion", "pandas", "druid"], + ["dask", "pandas", "druid"], raises=com.OperationNotDefinedError, ), pytest.mark.notyet( @@ -985,7 +985,7 @@ def test_quantile( id="covar_pop_bool", marks=[ pytest.mark.notimpl( - ["dask", "datafusion", "pandas", "polars", "druid"], + ["dask", "pandas", "polars", "druid"], raises=com.OperationNotDefinedError, ), pytest.mark.notyet( @@ -1003,7 +1003,7 @@ def test_quantile( id="corr_pop_bool", marks=[ pytest.mark.notimpl( - ["dask", "datafusion", "pandas", "druid"], + ["dask", "pandas", "druid"], raises=com.OperationNotDefinedError, ), pytest.mark.notyet(