Skip to content

Commit

Permalink
[SPARK-44842][SPARK-43812][PS] Support stat functions for pandas 2.0.…
Browse files Browse the repository at this point in the history
…0 and enabling tests

### What changes were proposed in this pull request?

This PR proposes to match the behavior with pandas 2.0.0 and above for stat functions, such as `sum`, `quantile`, `prod`, etc. See pandas-dev/pandas#41480 and pandas-dev/pandas#47500 for more detail.

### Why are the changes needed?

To match the behavior to latest pandas.

### Does this PR introduce _any_ user-facing change?

Yes, the behaviors for stat funcs are now matched with pandas 2.0.0 and above.

### How was this patch tested?

Enabling & updating the existing UTs.

Closes apache#42526 from itholic/pandas_stat.

Authored-by: itholic <haejoon.lee@databricks.com>
Signed-off-by: Ruifeng Zheng <ruifengz@apache.org>
  • Loading branch information
itholic authored and vpolet committed Aug 24, 2023
1 parent e6b9878 commit 49618b6
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 131 deletions.
10 changes: 3 additions & 7 deletions python/pyspark/pandas/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
DecimalType,
TimestampType,
TimestampNTZType,
NullType,
)
from pyspark.sql.window import Window

Expand Down Expand Up @@ -797,7 +798,7 @@ def _reduce_for_stat_function(
new_column_labels.append(label)

if len(exprs) == 1:
return Series([])
return Series([], dtype="float64")

sdf = self._internal.spark_frame.select(*exprs)

Expand Down Expand Up @@ -12128,11 +12129,6 @@ def quantile(
0.50 3.0 7.0
0.75 4.0 8.0
"""
warnings.warn(
"Default value of `numeric_only` will be changed to `False` "
"instead of `True` in 4.0.0.",
FutureWarning,
)
axis = validate_axis(axis)
if axis != 0:
raise NotImplementedError('axis should be either 0 or "index" currently.')
Expand All @@ -12155,7 +12151,7 @@ def quantile(
def quantile(psser: "Series") -> PySparkColumn:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, (BooleanType, NumericType)):
if isinstance(spark_type, (BooleanType, NumericType, NullType)):
return F.percentile_approx(spark_column.cast(DoubleType()), qq, accuracy)
else:
raise TypeError(
Expand Down
5 changes: 0 additions & 5 deletions python/pyspark/pandas/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1419,11 +1419,6 @@ def product(
nan
"""
axis = validate_axis(axis)
warnings.warn(
"Default value of `numeric_only` will be changed to `False` "
"instead of `None` in 4.0.0.",
FutureWarning,
)

if numeric_only is None and axis == 0:
numeric_only = True
Expand Down
20 changes: 12 additions & 8 deletions python/pyspark/pandas/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,9 +614,10 @@ def mean(self, numeric_only: Optional[bool] = True) -> FrameLike:
Parameters
----------
numeric_only : bool, default False
numeric_only : bool, default True
Include only float, int, boolean columns. If None, will attempt to use
everything, then use only numeric data.
everything, then use only numeric data. False is not supported.
This parameter is mainly for pandas compatibility.
.. versionadded:: 3.4.0
Expand Down Expand Up @@ -646,11 +647,6 @@ def mean(self, numeric_only: Optional[bool] = True) -> FrameLike:
2 4.0 1.500000 1.000000
"""
self._validate_agg_columns(numeric_only=numeric_only, function_name="median")
warnings.warn(
"Default value of `numeric_only` will be changed to `False` "
"instead of `True` in 4.0.0.",
FutureWarning,
)

return self._reduce_for_stat_function(
F.mean, accepted_spark_types=(NumericType,), bool_to_numeric=True
Expand Down Expand Up @@ -920,7 +916,7 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameL
)

# TODO: sync the doc.
def var(self, ddof: int = 1) -> FrameLike:
def var(self, ddof: int = 1, numeric_only: Optional[bool] = True) -> FrameLike:
"""
Compute variance of groups, excluding missing values.
Expand All @@ -935,6 +931,13 @@ def var(self, ddof: int = 1) -> FrameLike:
.. versionchanged:: 3.4.0
Supported including arbitary integers.
numeric_only : bool, default True
Include only float, int, boolean columns. If None, will attempt to use
everything, then use only numeric data. False is not supported.
This parameter is mainly for pandas compatibility.
.. versionadded:: 4.0.0
Examples
--------
>>> df = ps.DataFrame({"A": [1, 2, 1, 2], "B": [True, False, False, True],
Expand All @@ -961,6 +964,7 @@ def var(col: Column) -> Column:
var,
accepted_spark_types=(NumericType,),
bool_to_numeric=True,
numeric_only=numeric_only,
)

def skew(self) -> FrameLike:
Expand Down
16 changes: 5 additions & 11 deletions python/pyspark/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
Row,
StructType,
TimestampType,
NullType,
)
from pyspark.sql.window import Window
from pyspark.sql.utils import get_column_class, get_window_class
Expand Down Expand Up @@ -4024,7 +4025,7 @@ def quantile(
def quantile(psser: Series) -> PySparkColumn:
spark_type = psser.spark.data_type
spark_column = psser.spark.column
if isinstance(spark_type, (BooleanType, NumericType)):
if isinstance(spark_type, (BooleanType, NumericType, NullType)):
return F.percentile_approx(spark_column.cast(DoubleType()), q_float, accuracy)
else:
raise TypeError(
Expand Down Expand Up @@ -4059,7 +4060,8 @@ def rank(
ascending : boolean, default True
False for ranks by high (1) to low (N)
numeric_only : bool, optional
If set to True, rank numeric Series, or return an empty Series for non-numeric Series
If set to True, rank numeric Series, or raise TypeError for non-numeric Series.
False is not supported. This parameter is mainly for pandas compatibility.
Returns
-------
Expand Down Expand Up @@ -4127,18 +4129,10 @@ def rank(
y b
z c
Name: A, dtype: object
>>> s.rank(numeric_only=True)
Series([], Name: A, dtype: float64)
"""
warnings.warn(
"Default value of `numeric_only` will be changed to `False` "
"instead of `None` in 4.0.0.",
FutureWarning,
)
is_numeric = isinstance(self.spark.data_type, (NumericType, BooleanType))
if numeric_only and not is_numeric:
return ps.Series([], dtype="float64", name=self.name)
raise TypeError("Series.rank does not allow numeric_only=True with non-numeric dtype.")
else:
return self._rank(method, ascending).spark.analyzed

Expand Down
14 changes: 8 additions & 6 deletions python/pyspark/pandas/tests/computation/test_any_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ def df_pair(self):
psdf = ps.from_pandas(pdf)
return pdf, psdf

@unittest.skipIf(
LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
"TODO(SPARK-43812): Enable DataFrameTests.test_all for pandas 2.0.0.",
)
def test_all(self):
pdf = pd.DataFrame(
{
Expand Down Expand Up @@ -105,9 +101,15 @@ def test_all(self):
self.assert_eq(psdf.all(skipna=True), pdf.all(skipna=True))
self.assert_eq(psdf.all(), pdf.all())
self.assert_eq(
ps.DataFrame([np.nan]).all(skipna=False), pd.DataFrame([np.nan]).all(skipna=False)
ps.DataFrame([np.nan]).all(skipna=False),
pd.DataFrame([np.nan]).all(skipna=False),
almost=True,
)
self.assert_eq(
ps.DataFrame([None]).all(skipna=True),
pd.DataFrame([None]).all(skipna=True),
almost=True,
)
self.assert_eq(ps.DataFrame([None]).all(skipna=True), pd.DataFrame([None]).all(skipna=True))

def test_any(self):
pdf = pd.DataFrame(
Expand Down
38 changes: 16 additions & 22 deletions python/pyspark/pandas/tests/computation/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,10 +283,6 @@ def test_nunique(self):
self.assert_eq(psdf.nunique(), pdf.nunique())
self.assert_eq(psdf.nunique(dropna=False), pdf.nunique(dropna=False))

@unittest.skipIf(
LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
"TODO(SPARK-43810): Enable DataFrameSlowTests.test_quantile for pandas 2.0.0.",
)
def test_quantile(self):
pdf, psdf = self.df_pair

Expand Down Expand Up @@ -332,59 +328,57 @@ def test_quantile(self):
pdf = pd.DataFrame({"x": ["a", "b", "c"]})
psdf = ps.from_pandas(pdf)

self.assert_eq(psdf.quantile(0.5), pdf.quantile(0.5))
self.assert_eq(psdf.quantile([0.25, 0.5, 0.75]), pdf.quantile([0.25, 0.5, 0.75]))
self.assert_eq(psdf.quantile(0.5), pdf.quantile(0.5, numeric_only=True))
self.assert_eq(
psdf.quantile([0.25, 0.5, 0.75]), pdf.quantile([0.25, 0.5, 0.75], numeric_only=True)
)

with self.assertRaisesRegex(TypeError, "Could not convert object \\(string\\) to numeric"):
psdf.quantile(0.5, numeric_only=False)
with self.assertRaisesRegex(TypeError, "Could not convert object \\(string\\) to numeric"):
psdf.quantile([0.25, 0.5, 0.75], numeric_only=False)

@unittest.skipIf(
LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
"TODO(SPARK-43558): Enable DataFrameSlowTests.test_product for pandas 2.0.0.",
)
def test_product(self):
pdf = pd.DataFrame(
{"A": [1, 2, 3, 4, 5], "B": [10, 20, 30, 40, 50], "C": ["a", "b", "c", "d", "e"]}
)
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index())
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index())

# Named columns
pdf.columns.name = "Koalas"
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index())
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index())

# MultiIndex columns
pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index())
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index())

# Named MultiIndex columns
pdf.columns.names = ["Hello", "Koalas"]
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index())
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index())

# No numeric columns
pdf = pd.DataFrame({"key": ["a", "b", "c"], "val": ["x", "y", "z"]})
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index())
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index())

# No numeric named columns
pdf.columns.name = "Koalas"
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index(), almost=True)
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), almost=True)

# No numeric MultiIndex columns
pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y")])
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index(), almost=True)
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), almost=True)

# No numeric named MultiIndex columns
pdf.columns.names = ["Hello", "Koalas"]
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index(), almost=True)
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), almost=True)

# All NaN columns
pdf = pd.DataFrame(
Expand All @@ -395,22 +389,22 @@ def test_product(self):
}
)
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index(), check_exact=False)
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), check_exact=False)

# All NaN named columns
pdf.columns.name = "Koalas"
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index(), check_exact=False)
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), check_exact=False)

# All NaN MultiIndex columns
pdf.columns = pd.MultiIndex.from_tuples([("a", "x"), ("b", "y"), ("c", "z")])
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index(), check_exact=False)
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), check_exact=False)

# All NaN named MultiIndex columns
pdf.columns.names = ["Hello", "Koalas"]
psdf = ps.from_pandas(pdf)
self.assert_eq(pdf.prod(), psdf.prod().sort_index(), check_exact=False)
self.assert_eq(pdf.prod(numeric_only=True), psdf.prod().sort_index(), check_exact=False)


class FrameComputeTests(FrameComputeMixin, ComparisonTestBase, SQLTestUtils):
Expand Down
33 changes: 7 additions & 26 deletions python/pyspark/pandas/tests/groupby/test_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,10 @@ def _test_stat_func(self, func, check_exact=True):
check_exact=check_exact,
)

@unittest.skipIf(
LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
"TODO(SPARK-43554): Enable GroupByTests.test_basic_stat_funcs for pandas 2.0.0.",
)
def test_basic_stat_funcs(self):
self._test_stat_func(lambda groupby_obj: groupby_obj.var(), check_exact=False)
self._test_stat_func(
lambda groupby_obj: groupby_obj.var(numeric_only=True), check_exact=False
)

pdf, psdf = self.pdf, self.psdf

Expand Down Expand Up @@ -102,30 +100,24 @@ def test_basic_stat_funcs(self):

self.assert_eq(
psdf.groupby("A").std().sort_index(),
pdf.groupby("A").std().sort_index(),
pdf.groupby("A").std(numeric_only=True).sort_index(),
check_exact=False,
)
self.assert_eq(
psdf.groupby("A").sem().sort_index(),
pdf.groupby("A").sem().sort_index(),
pdf.groupby("A").sem(numeric_only=True).sort_index(),
check_exact=False,
)

# TODO: fix bug of `sum` and re-enable the test below
# self._test_stat_func(lambda groupby_obj: groupby_obj.sum(), check_exact=False)
self.assert_eq(
psdf.groupby("A").sum().sort_index(),
pdf.groupby("A").sum().sort_index(),
pdf.groupby("A").sum(numeric_only=True).sort_index(),
check_exact=False,
)

@unittest.skipIf(
LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
"TODO(SPARK-43706): Enable GroupByTests.test_mean " "for pandas 2.0.0.",
)
def test_mean(self):
self._test_stat_func(lambda groupby_obj: groupby_obj.mean())
self._test_stat_func(lambda groupby_obj: groupby_obj.mean(numeric_only=None))
self._test_stat_func(lambda groupby_obj: groupby_obj.mean(numeric_only=True))
psdf = self.psdf
with self.assertRaises(TypeError):
Expand Down Expand Up @@ -267,10 +259,6 @@ def test_nth(self):
with self.assertRaisesRegex(TypeError, "Invalid index"):
self.psdf.groupby("B").nth("x")

@unittest.skipIf(
LooseVersion(pd.__version__) >= LooseVersion("2.0.0"),
"TODO(SPARK-43551): Enable GroupByTests.test_prod for pandas 2.0.0.",
)
def test_prod(self):
pdf = pd.DataFrame(
{
Expand All @@ -286,19 +274,12 @@ def test_prod(self):
psdf = ps.from_pandas(pdf)

for n in [0, 1, 2, 128, -1, -2, -128]:
self._test_stat_func(
lambda groupby_obj: groupby_obj.prod(min_count=n), check_exact=False
)
self._test_stat_func(
lambda groupby_obj: groupby_obj.prod(numeric_only=None, min_count=n),
check_exact=False,
)
self._test_stat_func(
lambda groupby_obj: groupby_obj.prod(numeric_only=True, min_count=n),
check_exact=False,
)
self.assert_eq(
pdf.groupby("A").prod(min_count=n).sort_index(),
pdf.groupby("A").prod(min_count=n, numeric_only=True).sort_index(),
psdf.groupby("A").prod(min_count=n).sort_index(),
almost=True,
)
Expand Down
Loading

0 comments on commit 49618b6

Please sign in to comment.