Do not cast .transform() output back to input dtype (closes #10972) #15256

Status: Closed. Wants to merge 1 commit.
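For context, a minimal sketch (example mine, not part of the diff) of the behavior this PR targets: when a reducing function is passed to ``groupby().transform()`` on datetime input, the output dtype should be inferred from the function's results rather than cast back to the input dtype.

import pandas as pd

dates = pd.Series(pd.date_range('2015-01-01', periods=4))
labels = ['a', 'a', 'b', 'b']

# counting datetime values per group yields integers; previously the
# integer result was cast back to the input's datetime64[ns] dtype,
# producing nonsense timestamps near the epoch
out = dates.groupby(labels).transform(lambda x: len(x))
print(out.dtype)  # int64 with this change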
doc/source/whatsnew/v0.20.0.txt (2 changes: 1 addition & 1 deletion)

@@ -502,7 +502,7 @@ Bug Fixes

 - Bug in ``DataFrame.to_stata()`` and ``StataWriter`` which produced incorrectly formatted files for some locales (:issue:`13856`)

-
+- Bug in ``DataFrame.groupby().transform()`` where the dtype of the output was incorrectly cast to the dtype of the input (:issue:`10972`)

pandas/core/groupby.py (39 changes: 24 additions & 15 deletions)

@@ -829,7 +829,10 @@ def _python_agg_general(self, func, *args, **kwargs):
         for name, obj in self._iterate_slices():
             try:
                 result, counts = self.grouper.agg_series(obj, f)
-                output[name] = self._try_cast(result, obj)
+                if obj.dtype.kind not in ['M', 'm']:
+                    # don't cast back to datetime
+                    result = self._try_cast(result, obj)
+                output[name] = result
             except TypeError:
                 continue
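The datetime check above keys off numpy's dtype.kind codes; a quick illustration (example mine):

import numpy as np

# 'M' marks datetime64 dtypes, 'm' marks timedelta64 dtypes
print(np.dtype('datetime64[ns]').kind)   # 'M'
print(np.dtype('timedelta64[ns]').kind)  # 'm'
print(np.dtype('float64').kind)          # 'f'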

@@ -2748,6 +2751,7 @@ def aggregate(self, func_or_funcs, *args, **kwargs):
         if not _level and isinstance(ret, dict):
             from pandas import concat
             ret = concat(ret, axis=1)
+
         return ret

     agg = aggregate
@@ -2892,30 +2896,35 @@ def transform(self, func, *args, **kwargs):
             return self._transform_fast(
                 lambda: getattr(self, func)(*args, **kwargs))

-        # reg transform
-        dtype = self._selected_obj.dtype
-        result = self._selected_obj.values.copy()
+        # we'll store the results of each group in a list
+        # and infer the output dtype
+        # then we'll create an empty array with the correct dtype to put the
+        # values inside
+        result_values = []
+        result_indexers = []
+        result_dtypes = []

         wrapper = lambda x: func(x, *args, **kwargs)
         for i, (name, group) in enumerate(self):
             object.__setattr__(group, 'name', name)
             res = wrapper(group)

             if hasattr(res, 'values'):
                 res = res.values
-
-            # may need to astype
-            try:
-                common_type = np.common_type(np.array(res), result)
-                if common_type != result.dtype:
-                    result = result.astype(common_type)
-            except:
-                pass
+            result_values.append(res)
+            indexer = self._get_index(name)
+            result_indexers.append(indexer)
+            result_dtypes.append(np.array(res).dtype)

-            indexer = self._get_index(name)
+        dtype = np.find_common_type(array_types=result_dtypes, scalar_types=[])
+
+        result = np.zeros(self._selected_obj.shape, dtype=dtype)
+        for indexer, res in zip(result_indexers, result_values):
             result[indexer] = res

-        result = _possibly_downcast_to_dtype(result, dtype)
+        input_dtype = self._selected_obj.dtype
+        if input_dtype.kind not in ['M', 'm']:  # don't cast back to datetime
+            result = _possibly_downcast_to_dtype(result, input_dtype)

         return self._selected_obj.__class__(result,
                                             index=self._selected_obj.index,
                                             name=self._selected_obj.name)
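As an aside (example mine, not from the diff), np.find_common_type is what infers the single output dtype from the per-group result dtypes collected above:

import numpy as np

# one int64 group result plus one float64 group result promote to
# float64, so the preallocated output array can hold every group's values
dtypes = [np.dtype('int64'), np.dtype('float64')]
print(np.find_common_type(dtypes, []))  # float64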
@@ -3308,6 +3317,7 @@ def aggregate(self, arg, *args, **kwargs):

         _level = kwargs.pop('_level', None)
         result, how = self._aggregate(arg, _level=_level, *args, **kwargs)
+
         if how is None:
             return result

@@ -4245,7 +4255,6 @@ def fast_apply(self, f, names):

         sdata = self._get_sorted_data()
         results, mutated = lib.apply_frame_axis0(sdata, f, names, starts, ends)
-
         return results, mutated

     def _chop(self, sdata, slice_obj):
pandas/tests/groupby/test_groupby.py (16 changes: 12 additions & 4 deletions)

@@ -1308,17 +1308,22 @@ def test_transform_coercion(self):
     def test_with_na(self):
         index = Index(np.arange(10))

-        for dtype in ['float64', 'float32', 'int64', 'int32', 'int16', 'int8']:
+        for dtype in ['float64', 'float32', 'int64', 'int32', 'int16', 'int8',
+                      'datetime64[ns]']:
             values = Series(np.ones(10), index, dtype=dtype)
             labels = Series([nan, 'foo', 'bar', 'bar', nan, nan, 'bar',
                              'bar', nan, 'foo'], index=index)

             # this SHOULD be an int
             grouped = values.groupby(labels)
             agged = grouped.agg(len)
-            expected = Series([4, 2], index=['bar', 'foo'])
-
-            assert_series_equal(agged, expected, check_dtype=False)
+            if dtype != "datetime64[ns]":
+                expected = Series([4, 2], index=['bar', 'foo'], dtype=dtype)
+            else:
+                expected = Series([4, 2], index=['bar', 'foo'])
+
+            assert_series_equal(agged, expected)

             # self.assertTrue(issubclass(agged.dtype.type, np.integer))
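A small sketch (mine, not part of the test) of why the datetime branch matters: counting groups of datetime values yields integers, and that count must not be cast back to datetime64[ns].

import pandas as pd

values = pd.Series(pd.date_range('2017-01-01', periods=10))
labels = ['foo', 'bar'] * 5
counts = values.groupby(labels).agg(len)
print(counts.dtype)  # int64, not datetime64[ns]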

@@ -1328,8 +1333,11 @@ def f(x):

             agged = grouped.agg(f)
             expected = Series([4, 2], index=['bar', 'foo'])

             assert_series_equal(agged, expected, check_dtype=False)

+            if dtype == "datetime64[ns]":
+                continue
+
             self.assertTrue(issubclass(agged.dtype.type, np.dtype(dtype).type))

     def test_groupby_transform_with_int(self):
pandas/tseries/tests/test_resample.py (8 changes: 8 additions & 0 deletions)

@@ -1805,6 +1805,10 @@ def test_resample_median_bug_1688(self):
                                           datetime(2012, 1, 1, 0, 5, 0)],
                            dtype=dtype)

+            result = df.resample("T").mean()
+            exp = df.asfreq('T')
+            tm.assert_frame_equal(result, exp)
+

[Contributor review comment on the ``exp = df.asfreq('T')`` line: "why are you changing tests like this?"]

             result = df.resample("T").apply(lambda x: x.mean())
             exp = df.asfreq('T')
             tm.assert_frame_equal(result, exp)
@@ -1813,6 +1817,10 @@ def test_resample_median_bug_1688(self):
             exp = df.asfreq('T')
             tm.assert_frame_equal(result, exp)

+            result = df.resample("T").apply(lambda x: x.median())
+            exp = df.asfreq('T')
+            tm.assert_frame_equal(result, exp)
+
     def test_how_lambda_functions(self):

         ts = _simple_ts('1/1/2000', '4/1/2000')
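As an illustration of the identity these assertions rely on (sketch mine, using float data rather than the test's datetime values): when each minute bin holds exactly one row, the per-bin mean or median equals ``asfreq('T')``.

import numpy as np
import pandas as pd

idx = pd.date_range('2012-01-01', periods=6, freq='T')
df = pd.DataFrame({'x': np.arange(6.0)}, index=idx)

# one value per minute bin, so reducing each bin is a no-op
print(df.resample('T').mean().equals(df.asfreq('T')))    # True
print(df.resample('T').median().equals(df.asfreq('T')))  # True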