Skip to content

Commit

Permalink
BUG: Fixes issue pandas-dev#3334: brittle margin in pivot_table.
Browse files Browse the repository at this point in the history
Adds support for margin computation when all columns are used in rows and cols
  • Loading branch information
guyrt committed Aug 1, 2013
1 parent 527db38 commit 1feaf7b
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 82 deletions.
1 change: 1 addition & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ pandas 0.13
set _ref_locs (:issue:`4403`)
- Fixed an issue where hist subplots were being overwritten when they were
called using the top level matplotlib API (:issue:`4408`)
- Fixed (:issue:`3334`). Margins did not compute if values is the index.

pandas 0.12
===========
Expand Down
40 changes: 2 additions & 38 deletions doc/source/v0.13.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,35 +11,12 @@ API changes

- ``read_excel`` now supports an integer in its ``sheetname`` argument giving
the index of the sheet to read in (:issue:`4301`).
- Text parser now treats anything that reads like inf ("inf", "Inf", "-Inf",
"iNf", etc.) as infinity. (:issue:`4220`, :issue:`4219`), affecting
``read_table``, ``read_csv``, etc.
- ``pandas`` now is Python 2/3 compatible without the need for 2to3 thanks to
@jtratner. As a result, pandas now uses iterators more extensively. This
also led to the introduction of substantive parts of the Benjamin
Peterson's ``six`` library into compat. (:issue:`4384`, :issue:`4375`,
:issue:`4372`)
- ``pandas.util.compat`` and ``pandas.util.py3compat`` have been merged into
``pandas.compat``. ``pandas.compat`` now includes many functions allowing
2/3 compatibility. It contains both list and iterator versions of range,
filter, map and zip, plus other necessary elements for Python 3
compatibility. ``lmap``, ``lzip``, ``lrange`` and ``lfilter`` all produce
lists instead of iterators, for compatibility with ``numpy``, subscripting
and ``pandas`` constructors.(:issue:`4384`, :issue:`4375`, :issue:`4372`)
- deprecated ``iterkv``, which will be removed in a future release (was just
an alias of iteritems used to get around ``2to3``'s changes).
(:issue:`4384`, :issue:`4375`, :issue:`4372`)
- ``Series.get`` with negative indexers now returns the same as ``[]`` (:issue:`4390`)

Enhancements
~~~~~~~~~~~~

- ``read_html`` now raises a ``URLError`` instead of catching and raising a
``ValueError`` (:issue:`4303`, :issue:`4305`)
- Added a test for ``read_clipboard()`` and ``to_clipboard()`` (:issue:`4282`)
- Clipboard functionality now works with PySide (:issue:`4282`)
- Added a more informative error message when plot arguments contain
overlapping color and style arguments (:issue:`4402`)

Bug Fixes
~~~~~~~~~
Expand All @@ -52,22 +29,9 @@ Bug Fixes

- Fixed bug in ``PeriodIndex.map`` where using ``str`` would return the str
representation of the index (:issue:`4136`)

- Fixed (:issue:`3334`). Margins did not compute if values is the index.

- Fixed test failure ``test_time_series_plot_color_with_empty_kwargs`` when
using custom matplotlib default colors (:issue:`4345`)

- Fix running of stata IO tests. Now uses temporary files to write
(:issue:`4353`)

- Fixed an issue where ``DataFrame.sum`` was slower than ``DataFrame.mean``
for integer valued frames (:issue:`4365`)

- ``read_html`` tests now work with Python 2.6 (:issue:`4351`)

- Fixed bug where ``network`` testing was throwing ``NameError`` because a
local variable was undefined (:issue:`4381`)

- Suppressed DeprecationWarning associated with internal calls issued by repr() (:issue:`4391`)

See the :ref:`full release notes
<release>` or issue tracker
Expand Down
118 changes: 91 additions & 27 deletions pandas/tools/pivot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@

from pandas import Series, DataFrame
from pandas.core.index import MultiIndex
from pandas.core.reshape import _unstack_multiple
from pandas.tools.merge import concat
from pandas.tools.util import cartesian_product
from pandas.compat import range, lrange, zip
from pandas import compat
import pandas.core.common as com
import numpy as np

Expand Down Expand Up @@ -149,17 +146,64 @@ def pivot_table(data, values=None, rows=None, cols=None, aggfunc='mean',
DataFrame.pivot_table = pivot_table


def _add_margins(table, data, values, rows=None, cols=None, aggfunc=np.mean):
grand_margin = {}
for k, v in compat.iteritems(data[values]):
try:
if isinstance(aggfunc, compat.string_types):
grand_margin[k] = getattr(v, aggfunc)()
else:
grand_margin[k] = aggfunc(v)
except TypeError:
pass
def _add_margins(table, data, values, rows, cols, aggfunc):

grand_margin = _compute_grand_margin(data, values, aggfunc)

if not values and isinstance(table, Series):
# If there are no values and the table is a series, then there is only
# one column in the data. Compute grand margin and return it.
row_key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'
return table.append(Series({row_key: grand_margin['All']}))

if values:
marginal_result_set = _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin)
if not isinstance(marginal_result_set, tuple):
return marginal_result_set
result, margin_keys, row_margin = marginal_result_set
else:
marginal_result_set = _generate_marginal_results_without_values(table, data, rows, cols, aggfunc)
if not isinstance(marginal_result_set, tuple):
return marginal_result_set
result, margin_keys, row_margin = marginal_result_set

key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'

row_margin = row_margin.reindex(result.columns)
# populate grand margin
for k in margin_keys:
if isinstance(k, basestring):
row_margin[k] = grand_margin[k]
else:
row_margin[k] = grand_margin[k[0]]

margin_dummy = DataFrame(row_margin, columns=[key]).T

row_names = result.index.names
result = result.append(margin_dummy)
result.index.names = row_names

return result


def _compute_grand_margin(data, values, aggfunc):

if values:
grand_margin = {}
for k, v in data[values].iteritems():
try:
if isinstance(aggfunc, basestring):
grand_margin[k] = getattr(v, aggfunc)()
else:
grand_margin[k] = aggfunc(v)
except TypeError:
pass
return grand_margin
else:
return {'All': aggfunc(data.index)}


def _generate_marginal_results(table, data, values, rows, cols, aggfunc, grand_margin):
if len(cols) > 0:
# need to "interleave" the margins
table_pieces = []
Expand Down Expand Up @@ -198,28 +242,48 @@ def _all_key(key):
row_margin = row_margin.stack()

# slight hack
new_order = [len(cols)] + lrange(len(cols))
new_order = [len(cols)] + range(len(cols))
row_margin.index = row_margin.index.reorder_levels(new_order)
else:
row_margin = Series(np.nan, index=result.columns)

key = ('All',) + ('',) * (len(rows) - 1) if len(rows) > 1 else 'All'
return result, margin_keys, row_margin

row_margin = row_margin.reindex(result.columns)
# populate grand margin
for k in margin_keys:
if len(cols) > 0:
row_margin[k] = grand_margin[k[0]]
else:
row_margin[k] = grand_margin[k]

margin_dummy = DataFrame(row_margin, columns=[key]).T
def _generate_marginal_results_without_values(table, data, rows, cols, aggfunc):
if len(cols) > 0:
# need to "interleave" the margins
margin_keys = []

row_names = result.index.names
result = result.append(margin_dummy)
result.index.names = row_names
def _all_key():
if len(cols) == 1:
return 'All'
return ('All', ) + ('', ) * (len(cols) - 1)

return result
if len(rows) > 0:
margin = data[rows].groupby(rows).apply(aggfunc)
all_key = _all_key()
table[all_key] = margin
result = table
margin_keys.append(all_key)

else:
margin = data.groupby(level=0, axis=0).apply(aggfunc)
all_key = _all_key()
table[all_key] = margin
result = table
margin_keys.append(all_key)
return result
else:
result = table
margin_keys = table.columns

if len(cols):
row_margin = data[cols].groupby(cols).apply(aggfunc)
else:
row_margin = Series(np.nan, index=result.columns)

return result, margin_keys, row_margin


def _convert_by(by):
Expand Down
50 changes: 33 additions & 17 deletions pandas/tools/tests/test_pivot.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import datetime
import unittest

import numpy as np
from numpy.testing import assert_equal

import pandas
from pandas import DataFrame, Series, Index, MultiIndex
from pandas.tools.merge import concat
from pandas.tools.pivot import pivot_table, crosstab
from pandas.compat import range, u, product
import pandas.util.testing as tm


Expand Down Expand Up @@ -75,18 +72,9 @@ def test_pivot_table_dropna(self):
pv_col = df.pivot_table('quantity', 'month', ['customer', 'product'], dropna=False)
pv_ind = df.pivot_table('quantity', ['customer', 'product'], 'month', dropna=False)

m = MultiIndex.from_tuples([(u('A'), u('a')),
(u('A'), u('b')),
(u('A'), u('c')),
(u('A'), u('d')),
(u('B'), u('a')),
(u('B'), u('b')),
(u('B'), u('c')),
(u('B'), u('d')),
(u('C'), u('a')),
(u('C'), u('b')),
(u('C'), u('c')),
(u('C'), u('d'))])
m = MultiIndex.from_tuples([(u'A', u'a'), (u'A', u'b'), (u'A', u'c'), (u'A', u'd'),
(u'B', u'a'), (u'B', u'b'), (u'B', u'c'), (u'B', u'd'),
(u'C', u'a'), (u'C', u'b'), (u'C', u'c'), (u'C', u'd')])

assert_equal(pv_col.columns.values, m.values)
assert_equal(pv_ind.index.values, m.values)
Expand Down Expand Up @@ -211,17 +199,20 @@ def _check_output(res, col, rows=['A', 'B'], cols=['C']):
# no rows
rtable = self.data.pivot_table(cols=['AA', 'BB'], margins=True,
aggfunc=np.mean)
tm.assert_isinstance(rtable, Series)
self.assert_(isinstance(rtable, Series))
for item in ['DD', 'EE', 'FF']:
gmarg = table[item]['All', '']
self.assertEqual(gmarg, self.data[item].mean())

def test_pivot_integer_columns(self):
# caused by upstream bug in unstack
from pandas.util.compat import product
import datetime
import pandas

d = datetime.date.min
data = list(product(['foo', 'bar'], ['A', 'B', 'C'], ['x1', 'x2'],
[d + datetime.timedelta(i) for i in range(20)], [1.0]))
[d + datetime.timedelta(i) for i in xrange(20)], [1.0]))
df = pandas.DataFrame(data)
table = df.pivot_table(values=4, rows=[0, 1, 3], cols=[2])

Expand All @@ -245,6 +236,9 @@ def test_pivot_no_level_overlap(self):
tm.assert_frame_equal(table, expected)

def test_pivot_columns_lexsorted(self):
import datetime
import numpy as np
import pandas

n = 10000

Expand Down Expand Up @@ -296,6 +290,28 @@ def test_pivot_complex_aggfunc(self):

tm.assert_frame_equal(result, expected)

def test_margins_no_values_no_cols(self):
# Regression test on pivot table: no values or cols passed.
result = self.data[['A', 'B']].pivot_table(rows=['A', 'B'], aggfunc=len, margins=True)
result_list = result.tolist()
self.assertEqual(sum(result_list[:-1]), result_list[-1])

def test_margins_no_values_two_rows(self):
# Regression test on pivot table: no values passed but rows are a multi-index
result = self.data[['A', 'B', 'C']].pivot_table(rows=['A', 'B'], cols='C', aggfunc=len, margins=True)
self.assertEqual(result.All.tolist(), [3.0, 1.0, 4.0, 3.0, 11.0])

def test_margins_no_values_one_row_one_col(self):
# Regression test on pivot table: no values passed but row and col defined
result = self.data[['A', 'B']].pivot_table(rows='A', cols='B', aggfunc=len, margins=True)
self.assertEqual(result.All.tolist(), [4.0, 7.0, 11.0])

def test_margins_no_values_two_row_two_cols(self):
# Regression test on pivot table: no values passed but rows and cols are multi-indexed
self.data['D'] = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k']
result = self.data[['A', 'B', 'C', 'D']].pivot_table(rows=['A', 'B'], cols=['C', 'D'], aggfunc=len, margins=True)
self.assertEqual(result.All.tolist(), [3.0, 1.0, 4.0, 3.0, 11.0])


class TestCrosstab(unittest.TestCase):

Expand Down

0 comments on commit 1feaf7b

Please sign in to comment.