Skip to content

Commit

Permalink
fix(dask): don't call compute when executing cov/corr
Browse files Browse the repository at this point in the history
BREAKING CHANGE: the dask backend no longer supports `cov`/`corr` with `how="pop"`.
  • Loading branch information
jcrist committed Apr 18, 2024
1 parent 08a33e9 commit a876c47
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 23 deletions.
40 changes: 17 additions & 23 deletions ibis/backends/dask/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
PandasWindowFunction,
plan,
)
from ibis.common.exceptions import UnboundExpressionError
from ibis.common.exceptions import UnboundExpressionError, UnsupportedOperationError
from ibis.formats.pandas import PandasData, PandasType
from ibis.util import gen_name

Expand Down Expand Up @@ -213,37 +213,31 @@ def agg(df):

@classmethod
def visit(cls, op: ops.Correlation, left, right, where, how):
if where is None:
if how == "pop":
raise UnsupportedOperationError(
"Dask doesn't support `corr` with `how='pop'`"
)

def agg(df):
return df[left.name].corr(df[right.name])
else:
def agg(df):
if where is not None:
df = df.where(df[where.name])

def agg(df):
mask = df[where.name]
lhs = df[left.name][mask].compute()
rhs = df[right.name][mask].compute()
return lhs.corr(rhs)
return df[left.name].corr(df[right.name])

return agg

@classmethod
def visit(cls, op: ops.Covariance, left, right, where, how):
# TODO(kszucs): raise a warning about triggering compute()?
ddof = {"pop": 0, "sample": 1}[how]
if where is None:
if how == "pop":
raise UnsupportedOperationError(
"Dask doesn't support `cov` with `how='pop'`"
)

def agg(df):
lhs = df[left.name].compute()
rhs = df[right.name].compute()
return lhs.cov(rhs, ddof=ddof)
else:
def agg(df):
if where is not None:
df = df.where(df[where.name])

def agg(df):
mask = df[where.name]
lhs = df[left.name][mask].compute()
rhs = df[right.name][mask].compute()
return lhs.cov(rhs, ddof=ddof)
return df[left.name].cov(df[right.name])

return agg

Expand Down
20 changes: 20 additions & 0 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,11 @@ def test_quantile(
lambda t, where: t.G[where].cov(t.RBI[where], ddof=0),
id="covar_pop",
marks=[
pytest.mark.notyet(
["dask"],
reason="dask doesn't support `cov(ddof=0)` yet",
raises=com.UnsupportedOperationError,
),
pytest.mark.notimpl(
["polars", "druid"],
raises=com.OperationNotDefinedError,
Expand Down Expand Up @@ -914,6 +919,11 @@ def test_quantile(
lambda t, where: t.G[where].corr(t.RBI[where]),
id="corr_pop",
marks=[
pytest.mark.notyet(
["dask"],
raises=com.UnsupportedOperationError,
reason="dask doesn't support `corr(ddof=0)` yet",
),
pytest.mark.notimpl(
["druid"],
raises=com.OperationNotDefinedError,
Expand Down Expand Up @@ -978,6 +988,11 @@ def test_quantile(
lambda t, where: (t.G[where] > 34.0).cov(t.G[where] <= 34.0, ddof=0),
id="covar_pop_bool",
marks=[
pytest.mark.notyet(
["dask"],
raises=com.UnsupportedOperationError,
reason="dask doesn't support `cov(ddof=0)` yet",
),
pytest.mark.notimpl(
["polars", "druid"],
raises=com.OperationNotDefinedError,
Expand All @@ -1002,6 +1017,11 @@ def test_quantile(
lambda t, where: (t.G[where] > 34.0).corr(t.G[where] <= 34.0),
id="corr_pop_bool",
marks=[
pytest.mark.notyet(
["dask"],
raises=com.UnsupportedOperationError,
reason="dask doesn't support `corr(ddof=0)` yet",
),
pytest.mark.notimpl(
["druid"],
raises=com.OperationNotDefinedError,
Expand Down

0 comments on commit a876c47

Please sign in to comment.