Skip to content

Commit

Permalink
BUG: Fixes issue pandas-dev#3334: brittle margin computation in pivot…
Browse files Browse the repository at this point in the history
…_table

Adds support for margin computation when all columns are used in rows and cols.
  • Loading branch information
guyrt committed Aug 1, 2013
1 parent 527db38 commit 50790c8
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 24 deletions.
3 changes: 3 additions & 0 deletions doc/source/v0.13.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ Bug Fixes

- Fixed bug in ``PeriodIndex.map`` where using ``str`` would return the str
representation of the index (:issue:`4136`)

- Fixed some edge cases in pivot_table where margins did not compute if values is the index.


- Fixed test failure ``test_time_series_plot_color_with_empty_kwargs`` when
using custom matplotlib default colors (:issue:`4345`)
Expand Down
114 changes: 90 additions & 24 deletions pandas/tools/pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from pandas import Series, DataFrame
from pandas.core.index import MultiIndex
from pandas.core.reshape import _unstack_multiple
from pandas.tools.merge import concat
from pandas.tools.util import cartesian_product
from pandas.compat import range, lrange, zip
Expand Down Expand Up @@ -149,17 +148,64 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
DataFrame.pivot_table = pivot_table


def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
grand_margin = {}
for k, v in compat.iteritems(data[values]):
try:
if isinstance(aggfunc, compat.string_types):
grand_margin[k] = getattr(v, aggfunc)()
else:
grand_margin[k] = aggfunc(v)
except TypeError:
pass
def _add_margins(table, data, values, rows, cols, aggfunc):

grand_margin = _compute_grand_margin(data, values, aggfunc)

if not values and isinstance(table, Series):
# If there are no values and the table is a series, then there is only
# one column in the data. Compute grand margin and return it.
row_key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'
return table.append(Series({row_key: grand_margin['All']}))

if values:
marginal_result_set = _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin)
if not isinstance(marginal_result_set, tuple):
return marginal_result_set
result, margin_keys, row_margin = marginal_result_set
else:
marginal_result_set = _generate_marginal_results_without_values(table, data, rows, cols, aggfunc)
if not isinstance(marginal_result_set, tuple):
return marginal_result_set
result, margin_keys, row_margin = marginal_result_set

key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'

row_margin = row_margin.reindex(result.columns)
# populate grand margin
for k in margin_keys:
if isinstance(k, basestring):
row_margin[k] = grand_margin[k]
else:
row_margin[k] = grand_margin[k[0]]

margin_dummy = DataFrame(row_margin, columns=[key]).T

row_names = result.index.names
result = result.append(margin_dummy)
result.index.names = row_names

return result


def _compute_grand_margin(data, values, aggfunc):

if values:
grand_margin = {}
for k, v in data[values].iteritems():
try:
if isinstance(aggfunc, basestring):
grand_margin[k] = getattr(v, aggfunc)()
else:
grand_margin[k] = aggfunc(v)
except TypeError:
pass
return grand_margin
else:
return {'All': aggfunc(data.index)}


def _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin):
if len(cols) > 0:
# need to "interleave" the margins
table_pieces = []
Expand Down Expand Up @@ -203,23 +249,43 @@ def _all_key(key):
else:
row_margin = Series(np.nan, index=result.columns)

key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'
return result, margin_keys, row_margin

row_margin = row_margin.reindex(result.columns)
# populate grand margin
for k in margin_keys:
if len(cols) > 0:
row_margin[k] = grand_margin[k[0]]
else:
row_margin[k] = grand_margin[k]

margin_dummy = DataFrame(row_margin, columns=[key]).T
def _generate_marginal_results_without_values(table, data, rows, cols, aggfunc):
if len(cols) > 0:
# need to "interleave" the margins
margin_keys = []

row_names = result.index.names
result = result.append(margin_dummy)
result.index.names = row_names
def _all_key():
if len(cols) == 1:
return 'All'
return ('All', ) + ('', ) * (len(cols) - 1)

return result
if len(rows) > 0:
margin = data[rows].groupby(rows).apply(aggfunc)
all_key = _all_key()
table[all_key] = margin
result = table
margin_keys.append(all_key)

else:
margin = data.groupby(level=0, axis=0).apply(aggfunc)
all_key = _all_key()
table[all_key] = margin
result = table
margin_keys.append(all_key)
return result
else:
result = table
margin_keys = table.columns

if len(cols):
row_margin = data[cols].groupby(cols).apply(aggfunc)
else:
row_margin = Series(np.nan, index=result.columns)

return result, margin_keys, row_margin


def _convert_by(by):
Expand Down
22 changes: 22 additions & 0 deletions pandas/tools/tests/test_pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,28 @@ def test_pivot_complex_aggfunc(self):

tm.assert_frame_equal(result, expected)

def test_margins_no_values_no_cols(self):
# Regression test on pivot table: no values or cols passed.
result = self.data[['A', 'B']].pivot_table(rows=['A', 'B'], aggfunc=len, margins=True)
result_list = result.tolist()
self.assertEqual(sum(result_list[:-1]), result_list[-1])

def test_margins_no_values_two_rows(self):
# Regression test on pivot table: no values passed but rows are a multi-index
result = self.data[['A', 'B', 'C']].pivot_table(rows=['A', 'B'], cols='C', aggfunc=len, margins=True)
self.assertEqual(result.All.tolist(), [3.0, 1.0, 4.0, 3.0, 11.0])

def test_margins_no_values_one_row_one_col(self):
# Regression test on pivot table: no values passed but row and col defined
result = self.data[['A', 'B']].pivot_table(rows='A', cols='B', aggfunc=len, margins=True)
self.assertEqual(result.All.tolist(), [4.0, 7.0, 11.0])

def test_margins_no_values_two_row_two_cols(self):
# Regression test on pivot table: no values passed but rows and cols are multi-indexed
self.data['D'] = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k']
result = self.data[['A', 'B', 'C', 'D']].pivot_table(rows=['A', 'B'], cols=['C', 'D'], aggfunc=len, margins=True)
self.assertEqual(result.All.tolist(), [3.0, 1.0, 4.0, 3.0, 11.0])


class TestCrosstab(unittest.TestCase):

Expand Down

0 comments on commit 50790c8

Please sign in to comment.