[SPARK-43295][PS] Support string type columns for DataFrameGroupBy.sum #42798

Closed · wants to merge 9 commits
32 changes: 24 additions & 8 deletions python/pyspark/pandas/groupby.py
@@ -857,10 +857,10 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameL
... "C": [3, 4, 3, 4], "D": ["a", "a", "b", "a"]})

>>> df.groupby("A").sum().sort_index()
B C
B C D
A
1 1 6
2 1 8
1 1 6 ab
2 1 8 aa

>>> df.groupby("D").sum().sort_index()
A B C
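
The docstring change above is the user-visible behavior: string columns are no longer dropped from the result, but concatenated per group in row order. A runnable sketch of the same example (the "A" and "B" values are reconstructed from the expected output, since the first line of the docstring example falls outside this hunk):

```python
import pyspark.pandas as ps

# "A" and "B" are reconstructed from the expected output above; the docstring
# line that defines them is not visible in this hunk.
df = ps.DataFrame(
    {
        "A": [1, 2, 1, 2],
        "B": [True, False, False, True],
        "C": [3, 4, 3, 4],
        "D": ["a", "a", "b", "a"],
    }
)

# With this patch, the string column "D" is kept and summed by concatenation:
#    B  C   D
# A
# 1  1  6  ab
# 2  1  8  aa
print(df.groupby("A").sum().sort_index())
```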
@@ -900,17 +900,17 @@ def sum(self, numeric_only: Optional[bool] = True, min_count: int = 0) -> FrameL
        unsupported = [
            col.name
            for col in self._agg_columns
-           if not isinstance(col.spark.data_type, (NumericType, BooleanType))
+           if not isinstance(col.spark.data_type, (NumericType, BooleanType, StringType))
        ]
        if len(unsupported) > 0:
            log_advice(
-               "GroupBy.sum() can only support numeric and bool columns even if"
+               "GroupBy.sum() can only support numeric, bool and string columns even if "
                f"numeric_only=False, skip unsupported columns: {unsupported}"
            )

Member:

I think you gotta fix the log above too, since now we support strings too?

Contributor Author:

Yeah, we should update. Thanks for catching this!

        return self._reduce_for_stat_function(
            F.sum,
-           accepted_spark_types=(NumericType, BooleanType),
+           accepted_spark_types=(NumericType, BooleanType, StringType),
            bool_to_numeric=True,
            min_count=min_count,
        )
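
With strings now accepted, the advice message only fires for the remaining unsupported types. A minimal sketch of that path, assuming a running pandas-on-Spark session (the date column "E" is an illustrative unsupported type, not part of the patch):

```python
import datetime
import pyspark.pandas as ps

df = ps.DataFrame(
    {
        "A": [1, 1, 2],
        "B": [1.0, 2.0, 3.0],
        "C": ["x", "y", "z"],
        "E": [datetime.date(2023, 1, d) for d in (1, 2, 3)],
    }
)

# "E" is neither numeric, bool, nor string, so it should be skipped with the
# updated advice message; "C" is now kept and concatenated per group.
print(df.groupby("A").sum(numeric_only=False).sort_index())
```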
@@ -3534,7 +3534,21 @@ def _reduce_for_stat_function(
        for label in psdf._internal.column_labels:
            psser = psdf._psser_for(label)
            input_scol = psser._dtype_op.nan_to_null(psser).spark.column
-           output_scol = sfun(input_scol)
+           if sfun.__name__ == "sum" and isinstance(
+               psdf._internal.spark_type_for(label), StringType
+           ):
+               input_scol_name = psser._internal.data_spark_column_names[0]
+               # Sort the collected structs by the natural order column so that
+               # concatenation preserves the original row order.
+               sorted_array = F.array_sort(
+                   F.collect_list(F.struct(NATURAL_ORDER_COLUMN_NAME, input_scol))
+               )
+
+               # Extract the string field from each struct and concatenate.
+               output_scol = F.concat_ws(
+                   "", F.transform(sorted_array, lambda x: x.getField(input_scol_name))
+               )
+           else:
+               output_scol = sfun(input_scol)

            if min_count > 0:
                output_scol = F.when(
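
Stripped of the pandas-on-Spark internals, the string branch above is a plain Spark SQL pattern: collect (order, value) structs, sort them, then concatenate. A self-contained sketch, with a monotonically increasing id standing in for the internal NATURAL_ORDER_COLUMN_NAME:

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()

sdf = spark.createDataFrame(
    [(1, "a"), (2, "a"), (1, "b"), (2, "a")], ["A", "D"]
).withColumn("__order__", F.monotonically_increasing_id())

# array_sort compares struct fields left to right, so putting the order
# column first sorts the collected values by original row order.
sorted_array = F.array_sort(F.collect_list(F.struct("__order__", "D")))

# Pull the string field back out of each struct, then concatenate.
summed = F.concat_ws("", F.transform(sorted_array, lambda x: x.getField("D")))

sdf.groupBy("A").agg(summed.alias("D")).show()
# +---+---+
# |  A|  D|
# +---+---+
# |  1| ab|
# |  2| aa|
# +---+---+
```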
@@ -3591,7 +3605,9 @@ def _prepare_reduce(
            ):
                agg_columns.append(psser)
        sdf = self._psdf._internal.spark_frame.select(
-           *groupkey_scols, *[psser.spark.column for psser in agg_columns]
+           *groupkey_scols,
+           *[psser.spark.column for psser in agg_columns],
+           NATURAL_ORDER_COLUMN_NAME,
        )
        internal = InternalFrame(
            spark_frame=sdf,
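
This projection change is what keeps the natural order column alive through `_prepare_reduce`; without it, the `F.array_sort` call in `_reduce_for_stat_function` would have no ordering field to collect. A quick end-to-end check of the ordering guarantee (assumes this patch is applied):

```python
import pyspark.pandas as ps

# "b" appears before "a" in the input, so the group sum should be "ba";
# if row order were not preserved, "ab" could come back instead.
psdf = ps.DataFrame({"k": [1, 1], "v": ["b", "a"]})
print(psdf.groupby("k").sum())  # expected: v == "ba"
```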
6 changes: 0 additions & 6 deletions python/pyspark/pandas/tests/groupby/test_groupby.py
@@ -59,9 +59,6 @@ def test_groupby_simple(self):
            },
            index=[0, 1, 3, 5, 6, 8, 9, 9, 9],
        )
-       if LooseVersion(pd.__version__) >= LooseVersion("2.0.0"):
-           # TODO(SPARK-43295): Make DataFrameGroupBy.sum support for string type columns
-           pdf = pdf[["a", "b", "c", "e"]]
        psdf = ps.from_pandas(pdf)

        for as_index in [True, False]:
@@ -180,9 +177,6 @@ def sort(df):
            index=[0, 1, 3, 5, 6, 8, 9, 9, 9],
        )
        psdf = ps.from_pandas(pdf)
-       if LooseVersion(pd.__version__) >= LooseVersion("2.0.0"):
-           # TODO(SPARK-43295): Make DataFrameGroupBy.sum support for string type columns
-           pdf = pdf[[10, 20, 30]]

        for as_index in [True, False]:
            if as_index:
2 changes: 1 addition & 1 deletion python/pyspark/pandas/tests/groupby/test_stat.py
@@ -113,7 +113,7 @@ def test_basic_stat_funcs(self):
        # self._test_stat_func(lambda groupby_obj: groupby_obj.sum(), check_exact=False)
        self.assert_eq(
            psdf.groupby("A").sum().sort_index(),
-           pdf.groupby("A").sum(numeric_only=True).sort_index(),
+           pdf.groupby("A").sum().sort_index(),
            check_exact=False,
        )

@@ -66,7 +66,7 @@ def sort(df):

        self.assert_eq(
            sort(psdf1.groupby(psdf2.a, as_index=as_index).sum()),
-           sort(pdf1.groupby(pdf2.a, as_index=as_index).sum(numeric_only=True)),
+           sort(pdf1.groupby(pdf2.a, as_index=as_index).sum()),
            almost=as_index,
        )

@@ -93,7 +93,7 @@ def test_groupby_multiindex_columns(self):

        self.assert_eq(
            psdf1.groupby(psdf2[("x", "a")]).sum().sort_index(),
-           pdf1.groupby(pdf2[("x", "a")]).sum(numeric_only=True).sort_index(),
+           pdf1.groupby(pdf2[("x", "a")]).sum().sort_index(),
        )

        self.assert_eq(
@@ -102,7 +102,7 @@
            .sort_values(("y", "c"))
            .reset_index(drop=True),
            pdf1.groupby(pdf2[("x", "a")], as_index=False)
-           .sum(numeric_only=True)
+           .sum()
            .sort_values(("y", "c"))
            .reset_index(drop=True),
        )