Skip to content

Commit

Permalink
feat(bigquery): add support for correlation
Browse files Browse the repository at this point in the history
  • Loading branch information
cpcloud committed Dec 25, 2022
1 parent 94152a3 commit 4df9f8b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 30 deletions.
53 changes: 29 additions & 24 deletions ibis/backends/bigquery/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,32 +380,36 @@ def compiles_approx(translator, op):
return f"APPROX_QUANTILES({translator.translate(arg)}, 2)[OFFSET(1)]"


def compiles_covar(translator, op):
left = op.left
right = op.right
where = op.where
how = op.how
def compiles_covar_corr(func):
def translate(translator, op):
left = op.left
right = op.right

if op.how == "sample":
how = "SAMP"
elif op.how == "pop":
how = "POP"
else:
raise ValueError(f"Covariance with how={how!r} is not supported.")
if (where := op.where) is not None:
left = ops.Where(where, left, None)
right = ops.Where(where, right, None)

if where is not None:
left = ops.Where(where, left, ibis.NA)
right = ops.Where(where, right, ibis.NA)
left = translator.translate(
ops.Cast(left, dt.int64) if left.output_dtype.is_boolean() else left
)
right = translator.translate(
ops.Cast(right, dt.int64) if right.output_dtype.is_boolean() else right
)
return f"{func}({left}, {right})"

left = translator.translate(
ops.Cast(left, dt.int64) if isinstance(left.output_dtype, dt.Boolean) else left
)
right = translator.translate(
ops.Cast(right, dt.int64)
if isinstance(right.output_dtype, dt.Boolean)
else right
)
return f"COVAR_{how}({left}, {right})"
return translate


def _covar(translator, op):
    """Compile a covariance operation to ``COVAR_POP`` or ``COVAR_SAMP``.

    Parameters
    ----------
    translator
        Expression translator, forwarded unchanged to the compiler built by
        ``compiles_covar_corr``.
    op
        Covariance operation. ``op.how`` selects the SQL function and must be
        exactly ``"pop"`` or ``"sample"``.

    Returns
    -------
    Whatever the compiler returned by ``compiles_covar_corr`` produces for
    ``COVAR_POP`` / ``COVAR_SAMP``.

    Raises
    ------
    ValueError
        If ``op.how`` is neither ``"pop"`` nor ``"sample"``.
    """
    how = op.how
    # Use an explicit mapping rather than slicing ``how[:4]``: the slice would
    # silently accept any string that merely starts with "samp".
    suffix = {"pop": "POP", "sample": "SAMP"}.get(how)
    if suffix is None:
        # Raise instead of ``assert`` so the check survives ``python -O`` and
        # matches the ValueError raised by ``_corr``.
        raise ValueError(f"Covariance with how={how!r} is not supported.")
    return compiles_covar_corr(f"COVAR_{suffix}")(translator, op)


def _corr(translator, op):
    """Compile a correlation operation to a ``CORR`` call.

    Parameters
    ----------
    translator
        Expression translator, forwarded unchanged to the compiler built by
        ``compiles_covar_corr``.
    op
        Correlation operation; its ``how`` attribute is checked before
        compiling.

    Returns
    -------
    Whatever the compiler returned by ``compiles_covar_corr("CORR")``
    produces for ``translator`` and ``op``.

    Raises
    ------
    ValueError
        If ``op.how`` equals ``"sample"``.
    """
    how = op.how
    # Guard clause: the "sample" variant is rejected up front.
    if how == "sample":
        raise ValueError(f"Correlation with how={how!r} is not supported.")

    compile_corr = compiles_covar_corr("CORR")
    return compile_corr(translator, op)


def bigquery_compile_any(translator, op):
Expand Down Expand Up @@ -474,7 +478,8 @@ def _array_agg(t, op):
ops.NotAll: bigquery_compile_notall,
# Math
ops.CMSMedian: compiles_approx,
ops.Covariance: compiles_covar,
ops.Covariance: _covar,
ops.Correlation: _corr,
ops.Divide: bigquery_compiles_divide,
ops.Floor: compiles_floor,
ops.Modulus: fixed_arity("MOD", 2),
Expand Down
8 changes: 2 additions & 6 deletions ibis/backends/tests/test_aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,9 +559,7 @@ def test_reduction_ops(
lambda t, where: t.G[where].corr(t.RBI[where]),
id='corr_pop',
marks=[
pytest.mark.notimpl(
["bigquery", "dask", "datafusion", "pandas", "polars"]
),
pytest.mark.notimpl(["dask", "datafusion", "pandas", "polars"]),
pytest.mark.notyet(
["clickhouse", "impala", "mysql", "pyspark", "sqlite"]
),
Expand Down Expand Up @@ -610,9 +608,7 @@ def test_reduction_ops(
lambda t, where: (t.G[where] > 34.0).corr(t.G[where] <= 34.0),
id='corr_pop_bool',
marks=[
pytest.mark.notimpl(
["bigquery", "dask", "datafusion", "pandas", "polars"]
),
pytest.mark.notimpl(["dask", "datafusion", "pandas", "polars"]),
pytest.mark.notyet(
["clickhouse", "impala", "mysql", "pyspark", "sqlite"]
),
Expand Down

0 comments on commit 4df9f8b

Please sign in to comment.