feat(pyspark): implement covariance and correlation
cpcloud authored and kszucs committed Jun 16, 2022
1 parent 335f6ba commit ae818fb
Showing 2 changed files with 38 additions and 4 deletions.
38 changes: 38 additions & 0 deletions ibis/backends/pyspark/compiler.py
@@ -690,6 +690,44 @@ def compile_variance(t, expr, scope, timecontext, context=None, **kwargs):
)


@compiles(ops.Covariance)
def compile_covariance(t, expr, scope, timecontext, context=None, **kwargs):
op = expr.op()
how = op.how

fn = {"sample": F.covar_samp, "pop": F.covar_pop}[how]

pyspark_double_type = ibis_dtype_to_spark_dtype(dtypes.double)
expr = op.__class__(
left=op.left.cast(pyspark_double_type),
right=op.right.cast(pyspark_double_type),
how=how,
where=op.where,
).to_expr()
return compile_aggregator(
t, expr, scope, timecontext, fn=fn, context=context
)


@compiles(ops.Correlation)
def compile_correlation(t, expr, scope, timecontext, context=None, **kwargs):
op = expr.op()

if (how := op.how) == "pop":
raise ValueError("PySpark only implements sample correlation")

pyspark_double_type = ibis_dtype_to_spark_dtype(dtypes.double)
expr = op.__class__(
left=op.left.cast(pyspark_double_type),
right=op.right.cast(pyspark_double_type),
how=how,
where=op.where,
).to_expr()
return compile_aggregator(
t, expr, scope, timecontext, fn=F.corr, context=context
)


@compiles(ops.Arbitrary)
def compile_arbitrary(t, expr, scope, timecontext, context=None, **kwargs):
how = expr.op().how
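With these two compile rules in place, covariance and correlation expressions built through the ibis API are dispatched to PySpark's covar_samp/covar_pop and corr aggregates. A minimal usage sketch follows; the SparkSession, table name, and column names are hypothetical and not part of this commit, and it assumes the cov/corr column methods that construct ops.Covariance and ops.Correlation:

import ibis

# `spark` is assumed to be an existing pyspark.sql.SparkSession.
con = ibis.pyspark.connect(spark)
t = con.table("measurements")  # hypothetical table with numeric columns x and y

expr = t.aggregate(
    cov_sample=t.x.cov(t.y),          # how="sample" -> F.covar_samp
    cov_pop=t.x.cov(t.y, how="pop"),  # how="pop"    -> F.covar_pop
    corr=t.x.corr(t.y),               # sample correlation -> F.corr
)
print(expr.execute())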
4 changes: 0 additions & 4 deletions ibis/backends/tests/test_aggregation.py
@@ -200,8 +200,6 @@ def test_aggregate_grouped(
"impala",
"mysql",
"pandas",
"postgres",
"pyspark",
"sqlite",
]
)
@@ -220,8 +218,6 @@ def test_aggregate_grouped(
"impala",
"mysql",
"pandas",
"postgres",
"pyspark",
"sqlite",
]
)
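For reference, the Spark SQL aggregate functions these rules map to (F.covar_samp, F.covar_pop, F.corr) can be exercised directly on a DataFrame; a minimal sketch with illustrative data:

from pyspark.sql import SparkSession
import pyspark.sql.functions as F

spark = SparkSession.builder.appName("cov-corr-sketch").getOrCreate()
df = spark.createDataFrame([(1.0, 2.0), (2.0, 4.0), (3.0, 5.5)], ["x", "y"])

# covar_samp/covar_pop correspond to how="sample"/how="pop" on ops.Covariance;
# corr is sample correlation only, which is why compile_correlation rejects
# how="pop" with a ValueError.
df.agg(
    F.covar_samp("x", "y").alias("cov_sample"),
    F.covar_pop("x", "y").alias("cov_pop"),
    F.corr("x", "y").alias("corr"),
).show()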
