From beb7c97f21ec4e55e981de237b2f0c97c0c3292b Mon Sep 17 00:00:00 2001 From: "Richard T. Guy" Date: Wed, 3 Jul 2013 10:35:53 -0400 Subject: [PATCH] BUG: Fixes issue #3334: brittle margin computation in pivot_table Adds support for margin computation when all columns are used in rows and cols. --- doc/source/release.rst | 1 + doc/source/v0.13.0.txt | 4 ++ pandas/tools/pivot.py | 114 ++++++++++++++++++++++++------- pandas/tools/tests/test_pivot.py | 22 ++++++ 4 files changed, 117 insertions(+), 24 deletions(-) diff --git a/doc/source/release.rst b/doc/source/release.rst index 90f7585ba7ab9..ba1446d033010 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -102,6 +102,7 @@ pandas 0.13 set _ref_locs (:issue:`4403`) - Fixed an issue where hist subplots were being overwritten when they were called using the top level matplotlib API (:issue:`4408`) + - Fixed (:issue:`3334`) in pivot_table. Margins did not compute if values is the index. pandas 0.12 =========== diff --git a/doc/source/v0.13.0.txt b/doc/source/v0.13.0.txt index 0a62322fa2996..d849fa38f0783 100644 --- a/doc/source/v0.13.0.txt +++ b/doc/source/v0.13.0.txt @@ -52,6 +52,10 @@ Bug Fixes - Fixed bug in ``PeriodIndex.map`` where using ``str`` would return the str representation of the index (:issue:`4136`) + + - Fixed (:issue:`3334`) in pivot_table. Margins did not compute if values is the index. + + - Fixed test failure ``test_time_series_plot_color_with_empty_kwargs`` when using custom matplotlib default colors (:issue:`4345`) diff --git a/pandas/tools/pivot.py b/pandas/tools/pivot.py index effcc3ff7695f..df84aeef03f2a 100644 --- a/pandas/tools/pivot.py +++ b/pandas/tools/pivot.py @@ -2,7 +2,6 @@ from pandas import Series, DataFrame from pandas.core.index import MultiIndex -from pandas.core.reshape import _unstack_multiple from pandas.tools.merge import concat from pandas.tools.util import cartesian_product from pandas.compat import range, lrange, zip @@ -149,17 +148,64 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean', DataFrame.pivot_table = pivot_table -def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean): - grand_margin = {} - for k, v in compat.iteritems(data[values]): - try: - if isinstance(aggfunc, compat.string_types): - grand_margin[k] = getattr(v, aggfunc)() - else: - grand_margin[k] = aggfunc(v) - except TypeError: - pass +def _add_margins(table, data, values, rows, cols, aggfunc): + + grand_margin = _compute_grand_margin(data, values, aggfunc) + + if not values and isinstance(table, Series): + # If there are no values and the table is a series, then there is only + # one column in the data. Compute grand margin and return it. + row_key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All' + return table.append(Series({row_key: grand_margin['All']})) + + if values: + marginal_result_set = _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin) + if not isinstance(marginal_result_set, tuple): + return marginal_result_set + result, margin_keys, row_margin = marginal_result_set + else: + marginal_result_set = _generate_marginal_results_without_values(table, data, rows, cols, aggfunc) + if not isinstance(marginal_result_set, tuple): + return marginal_result_set + result, margin_keys, row_margin = marginal_result_set + + key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All' + + row_margin = row_margin.reindex(result.columns) + # populate grand margin + for k in margin_keys: + if isinstance(k, basestring): + row_margin[k] = grand_margin[k] + else: + row_margin[k] = grand_margin[k[0]] + margin_dummy = DataFrame(row_margin, columns=[key]).T + + row_names = result.index.names + result = result.append(margin_dummy) + result.index.names = row_names + + return result + + +def _compute_grand_margin(data, values, aggfunc): + + if values: + grand_margin = {} + for k, v in data[values].iteritems(): + try: + if isinstance(aggfunc, basestring): + grand_margin[k] = getattr(v, aggfunc)() + else: + grand_margin[k] = aggfunc(v) + except TypeError: + pass + return grand_margin + else: + return {'All': aggfunc(data.index)} + + +def _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin): if len(cols) > 0: # need to "interleave" the margins table_pieces = [] @@ -203,23 +249,43 @@ def _all_key(key): else: row_margin = Series(np.nan, index=result.columns) - key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All' + return result, margin_keys, row_margin - row_margin = row_margin.reindex(result.columns) - # populate grand margin - for k in margin_keys: - if len(cols) > 0: - row_margin[k] = grand_margin[k[0]] - else: - row_margin[k] = grand_margin[k] - margin_dummy = DataFrame(row_margin, columns=[key]).T +def _generate_marginal_results_without_values(table, data, rows, cols, aggfunc): + if len(cols) > 0: + # need to "interleave" the margins + margin_keys = [] - row_names = result.index.names - result = result.append(margin_dummy) - result.index.names = row_names + def _all_key(): + if len(cols) == 1: + return 'All' + return ('All', ) + ('', ) * (len(cols) - 1) - return result + if len(rows) > 0: + margin = data[rows].groupby(rows).apply(aggfunc) + all_key = _all_key() + table[all_key] = margin + result = table + margin_keys.append(all_key) + + else: + margin = data.groupby(level=0, axis=0).apply(aggfunc) + all_key = _all_key() + table[all_key] = margin + result = table + margin_keys.append(all_key) + return result + else: + result = table + margin_keys = table.columns + + if len(cols): + row_margin = data[cols].groupby(cols).apply(aggfunc) + else: + row_margin = Series(np.nan, index=result.columns) + + return result, margin_keys, row_margin def _convert_by(by): diff --git a/pandas/tools/tests/test_pivot.py b/pandas/tools/tests/test_pivot.py index 57e7d2f7f6ae9..935e7da69ffdd 100644 --- a/pandas/tools/tests/test_pivot.py +++ b/pandas/tools/tests/test_pivot.py @@ -296,6 +296,28 @@ def test_pivot_complex_aggfunc(self): tm.assert_frame_equal(result, expected) + def test_margins_no_values_no_cols(self): + # Regression test on pivot table: no values or cols passed. + result = self.data[['A', 'B']].pivot_table(rows=['A', 'B'], aggfunc=len, margins=True) + result_list = result.tolist() + self.assertEqual(sum(result_list[:-1]), result_list[-1]) + + def test_margins_no_values_two_rows(self): + # Regression test on pivot table: no values passed but rows are a multi-index + result = self.data[['A', 'B', 'C']].pivot_table(rows=['A', 'B'], cols='C', aggfunc=len, margins=True) + self.assertEqual(result.All.tolist(), [3.0, 1.0, 4.0, 3.0, 11.0]) + + def test_margins_no_values_one_row_one_col(self): + # Regression test on pivot table: no values passed but row and col defined + result = self.data[['A', 'B']].pivot_table(rows='A', cols='B', aggfunc=len, margins=True) + self.assertEqual(result.All.tolist(), [4.0, 7.0, 11.0]) + + def test_margins_no_values_two_row_two_cols(self): + # Regression test on pivot table: no values passed but rows and cols are multi-indexed + self.data['D'] = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k'] + result = self.data[['A', 'B', 'C', 'D']].pivot_table(rows=['A', 'B'], cols=['C', 'D'], aggfunc=len, margins=True) + self.assertEqual(result.All.tolist(), [3.0, 1.0, 4.0, 3.0, 11.0]) + class TestCrosstab(unittest.TestCase):