From ec52ab4015cf046ed752de887063d0c587ff4755 Mon Sep 17 00:00:00 2001 From: Nicolas Bonnotte Date: Sat, 28 Jan 2017 17:40:24 +0100 Subject: [PATCH] Do not cast .transform() output back to input dtype: datetime input --- doc/source/whatsnew/v0.20.0.txt | 2 +- pandas/core/groupby.py | 39 ++++++++++++++++----------- pandas/tests/groupby/test_groupby.py | 16 ++++++++--- pandas/tseries/tests/test_resample.py | 8 ++++++ 4 files changed, 45 insertions(+), 20 deletions(-) diff --git a/doc/source/whatsnew/v0.20.0.txt b/doc/source/whatsnew/v0.20.0.txt index afe2758db9ab1..fcca834f68ab4 100644 --- a/doc/source/whatsnew/v0.20.0.txt +++ b/doc/source/whatsnew/v0.20.0.txt @@ -502,7 +502,7 @@ Bug Fixes - Bug in ``DataFrame.to_stata()`` and ``StataWriter`` which produces incorrectly formatted files to be produced for some locales (:issue:`13856`) - +- Bug in ``DataFrame.groupby().transform()`` where the dtype of the output was cast back to the dtype of the input (:issue:`10972`) diff --git a/pandas/core/groupby.py b/pandas/core/groupby.py index 99220232114ce..c0f5e6021c52e 100644 --- a/pandas/core/groupby.py +++ b/pandas/core/groupby.py @@ -829,7 +829,10 @@ def _python_agg_general(self, func, *args, **kwargs): for name, obj in self._iterate_slices(): try: result, counts = self.grouper.agg_series(obj, f) - output[name] = self._try_cast(result, obj) + if not obj.dtypes.kind in ['M', 'm']: + # don't cast back to datetime + result = self._try_cast(result, obj) + output[name] = result except TypeError: continue @@ -2748,6 +2751,7 @@ def aggregate(self, func_or_funcs, *args, **kwargs): if not _level and isinstance(ret, dict): from pandas import concat ret = concat(ret, axis=1) + return ret agg = aggregate @@ -2892,30 +2896,35 @@ def transform(self, func, *args, **kwargs): return self._transform_fast( lambda: getattr(self, func)(*args, **kwargs)) - # reg transform - dtype = self._selected_obj.dtype - result = self._selected_obj.values.copy() + # we'll store the results of each group in a list + # and infer the
output dtype + # then we'll create an empty array with the correct dtype to put the + # values inside + result_values = [] + result_indexers = [] + result_dtypes = [] wrapper = lambda x: func(x, *args, **kwargs) for i, (name, group) in enumerate(self): object.__setattr__(group, 'name', name) res = wrapper(group) - if hasattr(res, 'values'): res = res.values + result_values.append(res) + indexer = self._get_index(name) + result_indexers.append(indexer) + result_dtypes.append(np.array(res).dtype) - # may need to astype - try: - common_type = np.common_type(np.array(res), result) - if common_type != result.dtype: - result = result.astype(common_type) - except: - pass + dtype = np.find_common_type(array_types=result_dtypes, scalar_types=[]) - indexer = self._get_index(name) + result = np.zeros(self._selected_obj.shape, dtype=dtype) + for indexer, res in zip(result_indexers, result_values): result[indexer] = res - result = _possibly_downcast_to_dtype(result, dtype) + input_dtype = self._selected_obj.dtype + if not input_dtype.kind in ['M', 'm']: # don't cast back to datetime + result = _possibly_downcast_to_dtype(result, input_dtype) + return self._selected_obj.__class__(result, index=self._selected_obj.index, name=self._selected_obj.name) @@ -3308,6 +3317,7 @@ def aggregate(self, arg, *args, **kwargs): _level = kwargs.pop('_level', None) result, how = self._aggregate(arg, _level=_level, *args, **kwargs) + if how is None: return result @@ -4245,7 +4255,6 @@ def fast_apply(self, f, names): sdata = self._get_sorted_data() results, mutated = lib.apply_frame_axis0(sdata, f, names, starts, ends) - return results, mutated def _chop(self, sdata, slice_obj): diff --git a/pandas/tests/groupby/test_groupby.py b/pandas/tests/groupby/test_groupby.py index ffb6025163a6b..e1f031954c1eb 100644 --- a/pandas/tests/groupby/test_groupby.py +++ b/pandas/tests/groupby/test_groupby.py @@ -1308,7 +1308,8 @@ def test_transform_coercion(self): def test_with_na(self): index = 
Index(np.arange(10)) - for dtype in ['float64', 'float32', 'int64', 'int32', 'int16', 'int8']: + for dtype in ['float64', 'float32', 'int64', 'int32', 'int16', 'int8', + 'datetime64[ns]']: values = Series(np.ones(10), index, dtype=dtype) labels = Series([nan, 'foo', 'bar', 'bar', nan, nan, 'bar', 'bar', nan, 'foo'], index=index) @@ -1316,9 +1317,13 @@ def test_with_na(self): # this SHOULD be an int grouped = values.groupby(labels) agged = grouped.agg(len) - expected = Series([4, 2], index=['bar', 'foo']) - assert_series_equal(agged, expected, check_dtype=False) + if dtype != "datetime64[ns]": + expected = Series([4, 2], index=['bar', 'foo'], dtype=dtype) + else: + expected = Series([4, 2], index=['bar', 'foo']) + + assert_series_equal(agged, expected) # self.assertTrue(issubclass(agged.dtype.type, np.integer)) @@ -1328,8 +1333,11 @@ def f(x): agged = grouped.agg(f) expected = Series([4, 2], index=['bar', 'foo']) - assert_series_equal(agged, expected, check_dtype=False) + + if dtype == "datetime64[ns]": + continue + self.assertTrue(issubclass(agged.dtype.type, np.dtype(dtype).type)) def test_groupby_transform_with_int(self): diff --git a/pandas/tseries/tests/test_resample.py b/pandas/tseries/tests/test_resample.py index 56953541265a6..45e217283af08 100755 --- a/pandas/tseries/tests/test_resample.py +++ b/pandas/tseries/tests/test_resample.py @@ -1805,6 +1805,10 @@ def test_resample_median_bug_1688(self): datetime(2012, 1, 1, 0, 5, 0)], dtype=dtype) + result = df.resample("T").mean() + exp = df.asfreq('T') + tm.assert_frame_equal(result, exp) + result = df.resample("T").apply(lambda x: x.mean()) exp = df.asfreq('T') tm.assert_frame_equal(result, exp) @@ -1813,6 +1817,10 @@ def test_resample_median_bug_1688(self): exp = df.asfreq('T') tm.assert_frame_equal(result, exp) + result = df.resample("T").apply(lambda x: x.median()) + exp = df.asfreq('T') + tm.assert_frame_equal(result, exp) + def test_how_lambda_functions(self): ts = _simple_ts('1/1/2000', '4/1/2000')