feat(pyspark): add option to treat nan as null in aggregations

cpcloud committed Aug 13, 2022
1 parent 012ad76 commit bf47250

Showing 5 changed files with 41 additions and 4 deletions.
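In effect, this commit adds a pyspark.treat_nan_as_null option: when enabled, the PySpark compiler wraps floating-point columns and literals in F.nanvl(..., NULL), so NaN values behave like SQL NULLs and are skipped by aggregations. A minimal usage sketch (the column name `x` is hypothetical; `null_table` is borrowed from the test suite below):

    import ibis
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    con = ibis.pyspark.connect(spark)

    # Opt in to NULL-like treatment of NaN in floating point expressions.
    ibis.options.pyspark.treat_nan_as_null = True

    # NaNs in `x` are now compiled to NULL and ignored by the mean,
    # matching the semantics of the other SQL backends.
    t = con.table("null_table")
    print(t.x.mean().execute())
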
8 changes: 8 additions & 0 deletions ibis/backends/pyspark/__init__.py
@@ -5,6 +5,7 @@
 
 import pandas as pd
 import pyspark
+from pydantic import Field
 from pyspark.sql import DataFrame
 from pyspark.sql.column import Column
 
@@ -13,6 +14,7 @@
 import ibis.expr.operations as ops
 
 import ibis.common.exceptions as com
+import ibis.config
 import ibis.expr.schema as sch
 import ibis.expr.types as types
 import ibis.util as util
@@ -88,6 +90,12 @@ class Backend(BaseSQLBackend):
     table_class = PySparkDatabaseTable
     table_expr_class = PySparkTable
 
+    class Options(ibis.config.BaseModel):
+        treat_nan_as_null: bool = Field(
+            default=False,
+            description="Treat NaNs in floating point expressions as NULL.",
+        )
+
     def do_connect(self, session: pyspark.sql.SparkSession) -> None:
         """Create a PySpark `Backend` for use with Ibis.
17 changes: 17 additions & 0 deletions ibis/backends/pyspark/compiler.py
@@ -25,6 +25,7 @@
     combine_time_context,
     filter_by_time_context,
 )
+from ibis.config import options
 from ibis.expr.timecontext import adjust_context
 from ibis.util import frozendict, guid
 
@@ -215,7 +216,22 @@ def compile_sort_key(t, expr, scope, timecontext, **kwargs):
         return col.desc()
 
 
+def compile_nan_as_null(compile_func):
+    @functools.wraps(compile_func)
+    def wrapper(t, expr, *args, **kwargs):
+        compiled = compile_func(t, expr, *args, **kwargs)
+        if options.pyspark.treat_nan_as_null and isinstance(
+            expr.type(), dtypes.Floating
+        ):
+            return F.nanvl(compiled, F.lit(None))
+        else:
+            return compiled
+
+    return wrapper
+
+
 @compiles(ops.TableColumn)
+@compile_nan_as_null
 def compile_column(t, expr, scope, timecontext, **kwargs):
     op = expr.op()
     table = t.translate(op.table, scope, timecontext)
@@ -357,6 +373,7 @@ def compile_subtract(t, expr, scope, timecontext, **kwargs):
 
 
 @compiles(ops.Literal)
+@compile_nan_as_null
 def compile_literal(t, expr, scope, timecontext, raw=False, **kwargs):
     """If raw is True, don't wrap the result with F.lit()"""
     value = expr.op().value
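The heart of the change is F.nanvl: it returns its first argument wherever that value is not NaN and its second argument otherwise, so F.nanvl(col, F.lit(None)) rewrites NaN to NULL, which Spark aggregations skip. A standalone PySpark sketch of the difference, independent of ibis:

    from pyspark.sql import SparkSession, functions as F

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1.0,), (3.0,), (float("nan"),)], ["x"])

    # NaN propagates through avg(), but NULL rows are ignored.
    df.select(F.avg("x")).show()                        # NaN
    df.select(F.avg(F.nanvl("x", F.lit(None)))).show()  # 2.0
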
3 changes: 1 addition & 2 deletions ibis/backends/pyspark/tests/conftest.py
@@ -233,8 +233,7 @@ def get_common_spark_testing_client(data_directory, connect):
 
 def get_pyspark_testing_client(data_directory):
     return get_common_spark_testing_client(
-        data_directory,
-        lambda session: ibis.backends.pyspark.Backend().connect(session),
+        data_directory, ibis.pyspark.connect
     )
 
 
16 changes: 14 additions & 2 deletions ibis/backends/pyspark/tests/test_aggregation.py
@@ -1,10 +1,21 @@
-import numpy as np
 import pytest
+from pytest import param
 
 import ibis
 
 pytest.importorskip("pyspark")
 
+
+@pytest.fixture
+def treat_nan_as_null():
+    treat_nan_as_null = ibis.options.pyspark.treat_nan_as_null
+    ibis.options.pyspark.treat_nan_as_null = True
+    try:
+        yield
+    finally:
+        ibis.options.pyspark.treat_nan_as_null = treat_nan_as_null
+
+
 @pytest.mark.parametrize(
     ('result_fn', 'expected_fn'),
     [
@@ -24,6 +35,7 @@ def test_aggregation_float_nulls(
     client,
     result_fn,
     expected_fn,
+    treat_nan_as_null,
 ):
     table = client.table('null_table')
     df = table.compile().toPandas()
@@ -32,4 +44,4 @@
     result = expr.execute()
 
     expected = expected_fn(df)
-    np.testing.assert_allclose(result, expected)
+    assert pytest.approx(expected) == result
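The final hunk also swaps numpy's assert_allclose for pytest.approx, removing this test module's numpy dependency; pytest.approx wraps the expected value and compares within a relative tolerance (1e-6 by default):

    import pytest

    # Passes: the float rounding error is within the default tolerance.
    assert pytest.approx(0.1 + 0.2) == 0.3
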
1 change: 1 addition & 0 deletions ibis/config.py
@@ -113,6 +113,7 @@ class Options(BaseSettings):
     dask: Optional[BaseModel] = None
     impala: Optional[BaseModel] = None
     pandas: Optional[BaseModel] = None
+    pyspark: Optional[BaseModel] = None
 
     class Config:
         validate_assignment = True
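
This slot is what makes the option reachable as ibis.options.pyspark: when the backend is loaded, its nested Options model is instantiated into the field, and validate_assignment means pydantic type-checks writes to it. A rough sketch of the pattern (simplified, not the exact ibis loading code):

    from typing import Optional

    from pydantic import BaseModel, BaseSettings, Field

    class PySparkOptions(BaseModel):
        treat_nan_as_null: bool = Field(default=False)

    class Options(BaseSettings):
        pyspark: Optional[BaseModel] = None

        class Config:
            validate_assignment = True

    options = Options()
    options.pyspark = PySparkOptions()        # roughly what backend loading does
    options.pyspark.treat_nan_as_null = True  # users then toggle the option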
