diff --git a/doc/source/whatsnew/v1.5.1.rst b/doc/source/whatsnew/v1.5.1.rst index a3c82111c630a..6a8fae61e23ba 100644 --- a/doc/source/whatsnew/v1.5.1.rst +++ b/doc/source/whatsnew/v1.5.1.rst @@ -76,6 +76,7 @@ Fixed regressions - Fixed regression in :meth:`DataFrame.plot` ignoring invalid ``colormap`` for ``kind="scatter"`` (:issue:`48726`) - Fixed performance regression in :func:`factorize` when ``na_sentinel`` is not ``None`` and ``sort=False`` (:issue:`48620`) - Fixed Regression in :meth:`DataFrameGroupBy.apply` when user defined function is called on an empty dataframe (:issue:`47985`) +- Fixed :meth:`.DataFrameGroupBy.size` not returning a Series when ``axis=1`` (:issue:`48738`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index d5f22b816f908..46ab1f6a86329 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1869,11 +1869,13 @@ def _wrap_transform_fast_result(self, result: NDFrameT) -> NDFrameT: out = algorithms.take_nd(result._values, ids) output = obj._constructor(out, index=obj.index, name=obj.name) else: + # `.size()` gives Series output on DataFrame input, need axis 0 + axis = 0 if result.ndim == 1 else self.axis # GH#46209 # Don't convert indices: negative indices need to give rise # to null values in the result - output = result._take(ids, axis=self.axis, convert_indices=False) - output = output.set_axis(obj._get_axis(self.axis), axis=self.axis) + output = result._take(ids, axis=axis, convert_indices=False) + output = output.set_axis(obj._get_axis(self.axis), axis=axis) return output # ----------------------------------------------------------------- @@ -2397,13 +2399,6 @@ def size(self) -> DataFrame | Series: """ result = self.grouper.size() - if self.axis == 1: - return DataFrame( - data=np.tile(result.values, (self.obj.shape[0], 1)), - columns=result.index, - index=self.obj.index, - ) - # GH28330 preserve subclassed Series/DataFrames through calls if isinstance(self.obj, Series): result = self._obj_1d_constructor(result, name=self.obj.name) diff --git a/pandas/tests/groupby/test_size.py b/pandas/tests/groupby/test_size.py index a614cf7abd684..92012436f6b47 100644 --- a/pandas/tests/groupby/test_size.py +++ b/pandas/tests/groupby/test_size.py @@ -33,12 +33,12 @@ def test_size_axis_1(df, axis_1, by, sort, dropna): counts = {key: sum(value == key for value in by) for key in dict.fromkeys(by)} if dropna: counts = {key: value for key, value in counts.items() if key is not None} - expected = DataFrame(counts, index=df.index) + expected = Series(counts) if sort: - expected = expected.sort_index(axis=1) + expected = expected.sort_index() grouped = df.groupby(by=by, axis=axis_1, sort=sort, dropna=dropna) result = grouped.size() - tm.assert_frame_equal(result, expected) + tm.assert_series_equal(result, expected) @pytest.mark.parametrize("by", ["A", "B", ["A", "B"]]) diff --git a/pandas/tests/groupby/transform/test_transform.py b/pandas/tests/groupby/transform/test_transform.py index 113bd4a0c4c65..8a2bd64a3deb0 100644 --- a/pandas/tests/groupby/transform/test_transform.py +++ b/pandas/tests/groupby/transform/test_transform.py @@ -213,13 +213,9 @@ def test_transform_axis_1_reducer(request, reduction_func): df = DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}, index=["x", "y"]) with tm.assert_produces_warning(warn, match=msg): result = df.groupby([0, 0, 1], axis=1).transform(reduction_func) - if reduction_func == "size": - # size doesn't behave in the same manner; hardcode expected result - expected = DataFrame(2 * [[2, 2, 1]], index=df.index, columns=df.columns) - else: - warn = FutureWarning if reduction_func == "mad" else None - with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"): - expected = df.T.groupby([0, 0, 1]).transform(reduction_func).T + warn = FutureWarning if reduction_func == "mad" else None + with tm.assert_produces_warning(warn, match="The 'mad' method is deprecated"): + expected = df.T.groupby([0, 0, 1]).transform(reduction_func).T tm.assert_equal(result, expected)