Skip to content

Commit

Permalink
feat(datafusion): add corr and covar
Browse files Browse the repository at this point in the history
  • Loading branch information
mesejo authored and cpcloud committed Oct 31, 2023
1 parent 31f3497 commit edc42be
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 6 deletions.
31 changes: 31 additions & 0 deletions ibis/backends/datafusion/compiler/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,3 +733,34 @@ def array_concat(op, *, arg, **_):
@translate_val.register(ops.ArrayPosition)
def array_position(op, *, arg, other, **_):
return F.coalesce(F.array_position(arg, other), 0)


@translate_val.register(ops.Covariance)
def covariance(op, *, left, right, how, where, **_):
x = op.left
if x.dtype.is_boolean():
left = cast(left, dt.float64)

y = op.right
if y.dtype.is_boolean():
right = cast(right, dt.float64)

if how == "sample":
return agg["covar_samp"](left, right, where=where)
elif how == "pop":
return agg["covar_pop"](left, right, where=where)
else:
raise ValueError(f"Unrecognized how = `{how}` value")


@translate_val.register(ops.Correlation)
def correlation(op, *, left, right, where, **_):
x = op.left
if x.dtype.is_boolean():
left = cast(left, dt.float64)

y = op.right
if y.dtype.is_boolean():
right = cast(right, dt.float64)

return agg["corr"](left, right, where=where)
12 changes: 6 additions & 6 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ def test_quantile(
id="covar_pop",
marks=[
pytest.mark.notimpl(
["dask", "datafusion", "pandas", "polars", "druid"],
["dask", "pandas", "polars", "druid"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notyet(
Expand All @@ -916,7 +916,7 @@ def test_quantile(
id="covar_samp",
marks=[
pytest.mark.notimpl(
["dask", "datafusion", "pandas", "polars", "druid"],
["dask", "pandas", "polars", "druid"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notyet(
Expand All @@ -931,7 +931,7 @@ def test_quantile(
id="corr_pop",
marks=[
pytest.mark.notimpl(
["dask", "datafusion", "pandas", "druid"],
["dask", "pandas", "druid"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notyet(
Expand All @@ -956,7 +956,7 @@ def test_quantile(
id="corr_samp",
marks=[
pytest.mark.notimpl(
["dask", "datafusion", "pandas", "druid"],
["dask", "pandas", "druid"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notyet(
Expand Down Expand Up @@ -985,7 +985,7 @@ def test_quantile(
id="covar_pop_bool",
marks=[
pytest.mark.notimpl(
["dask", "datafusion", "pandas", "polars", "druid"],
["dask", "pandas", "polars", "druid"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notyet(
Expand All @@ -1003,7 +1003,7 @@ def test_quantile(
id="corr_pop_bool",
marks=[
pytest.mark.notimpl(
["dask", "datafusion", "pandas", "druid"],
["dask", "pandas", "druid"],
raises=com.OperationNotDefinedError,
),
pytest.mark.notyet(
Expand Down

0 comments on commit edc42be

Please sign in to comment.