Skip to content

Commit

Permalink
Refactor geom.draw_group to take a dataframe
Browse files Browse the repository at this point in the history
Finally got rid of ``geom._make_pinfos`. Had to create a wrapper
`groupby_with_null` around `DataFrame.groupby` to allow grouping
on columns with Null values. Almost at the same time a PR [1]
popped up to probably solve this issue.

---

[1] pandas-dev/pandas#12607
  • Loading branch information
has2k1 committed Mar 17, 2016
1 parent 499e61b commit 95c2a29
Show file tree
Hide file tree
Showing 24 changed files with 648 additions and 592 deletions.
96 changes: 27 additions & 69 deletions ggplot/geoms/geom.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,6 @@ class geom(object):
# not implemented
legend_geom = 'point'

# A matplotlib plot function may require that an aethestic have a
# single unique value. e.g. linestyle='dashed' and not
# linestyle=['dashed', 'dotted', ...].
# A single call to such a function can only plot lines with the
# same linestyle. However, if the plot we want has more than one
# line with different linestyles, we need to group the lines with
# the same linestyle and plot them as one unit.
#
# geoms should fill out this set with such aesthetics so that the
# plot information they receive can be plotted in a single call.
# See: geom_point
_units = set()

# Whether to divide the distance between any two points into
# multiple segments. This is done during coord.transform time
_munch = False
Expand Down Expand Up @@ -167,12 +154,35 @@ def draw_panel(self, data, panel_scales, coord, ax, **params):
"""
data = coord.transform(data, panel_scales, self._munch)
for _, gdata in data.groupby('group'):
pinfos = self._make_pinfos(gdata, params)
for pinfo in pinfos:
self.draw_group(pinfo, panel_scales, coord, ax, **params)
gdata.reset_index(inplace=True, drop=True)
gdata.is_copy = None
self.draw_group(gdata, panel_scales, coord, ax, **params)

@staticmethod
def draw_group(data, panel_scales, coord, ax, **params):
"""
Plot data
"""
msg = "The geom should implement this method."
raise NotImplementedError(msg)

@staticmethod
def draw_group(pinfo, panel_scales, coord, ax, **params):
def draw_unit(data, panel_scales, coord, ax, **params):
"""
Plot data
A matplotlib plot function may require that an aethestic
have a single unique value. e.g. linestyle='dashed' and
not linestyle=['dashed', 'dotted', ...].
A single call to such a function can only plot lines with
the same linestyle. However, if the plot we want has more
than one line with different linestyles, we need to group
the lines with the same linestyle and plot them as one
unit. In this case, draw_group calls this function to do
the plotting.
See: geom_point
"""
msg = "The geom should implement this method."
raise NotImplementedError(msg)

Expand Down Expand Up @@ -284,55 +294,3 @@ def verify_arguments(self, kwargs):
if unknown:
msg = 'Unknown parameters {}'
raise GgplotError(msg.format(unknown))

def _make_pinfos(self, data, params):
units = []
for col in data.columns:
if col in self._units:
units.append(col)

shrinkable = {'alpha', 'fill', 'color', 'size', 'linetype',
'shape'}

def prep(pinfo):
"""
Reduce shrinkable parameters & append zorder
"""
# If it is the same value in the list make it a scalar
# This can help the matplotlib functions draw faster
for ae in set(pinfo) & shrinkable:
with suppress(TypeError, IndexError):
if all(pinfo[ae][0] == v for v in pinfo[ae]):
pinfo[ae] = pinfo[ae][0]
pinfo['zorder'] = params['zorder']
return pinfo

out = []
if units:
# Currently groupby does not like None values in any of
# the columns that participate in the grouping. These
# Nones come in when the default aesthetics are added to
# the data. We drop these columns and after turning the
# the dataframe into a dictionary insert a None for that
# aesthetic
_units = []
_none_units = []
for unit in units:
if data[unit].iloc[0] is None:
_none_units.append(unit)
del data[unit]
else:
_units.append(unit)

for name, _data in data.groupby(_units):
pinfo = _data.to_dict('list')
for ae in _units:
pinfo[ae] = pinfo[ae][0]
for ae in _none_units:
pinfo[ae] = None
out.append(prep(pinfo))
else:
pinfo = data.to_dict('list')
out.append(prep(pinfo))

return out
8 changes: 4 additions & 4 deletions ggplot/geoms/geom_abline.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def draw_panel(self, data, panel_scales, coord, ax, **params):
data = data.drop_duplicates()

for _, gdata in data.groupby('group'):
pinfos = self._make_pinfos(gdata, params)
for pinfo in pinfos:
geom_segment.draw_group(pinfo, panel_scales,
coord, ax, **params)
gdata.reset_index(inplace=True)
gdata.is_copy = None
geom_segment.draw_group(gdata, panel_scales,
coord, ax, **params)
78 changes: 42 additions & 36 deletions ggplot/geoms/geom_boxplot.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
from copy import deepcopy

import numpy as np
import pandas as pd
import matplotlib.lines as mlines
from matplotlib.patches import Rectangle

from ..scales.utils import resolution
from ..utils import make_iterable_ntimes, to_rgba
from ..utils import make_iterable_ntimes, to_rgba, copy_missing_columns
from .geom_point import geom_point
from .geom_segment import geom_segment
from .geom_crossbar import geom_crossbar
Expand All @@ -22,7 +23,8 @@ class geom_boxplot(geom):
'outlier_alpha': 1, 'outlier_color': None,
'outlier_shape': 'o', 'outlier_size': 5,
'outlier_stroke': 0, 'notch': False,
'varwidth': False, 'notchwidth': 0.5}
'varwidth': False, 'notchwidth': 0.5,
'fatten': 2}

def setup_data(self, data):
if 'width' not in data:
Expand Down Expand Up @@ -50,59 +52,63 @@ def setup_data(self, data):
return data

@staticmethod
def draw_group(pinfo, panel_scales, coord, ax, **params):
def draw_group(data, panel_scales, coord, ax, **params):
def flat(*args):
"""Flatten list-likes"""
return np.hstack(args)

def subdict(keys):
d = {}
for key in keys:
d[key] = deepcopy(pinfo[key])
return d

common = subdict(('color', 'size', 'linetype',
'fill', 'group', 'alpha',
'zorder'))

whiskers = subdict(('x',))
whiskers.update(deepcopy(common))
whiskers['x'] = whiskers['x'] * 2
common_columns = ['color', 'size', 'linetype',
'fill', 'group', 'alpha', 'shape']
# whiskers
whiskers = pd.DataFrame({
'x': flat(data['x'], data['x']),
'y': flat(data['upper'], data['lower']),
'yend': flat(data['ymax'], data['ymin'])})
whiskers['xend'] = whiskers['x']
whiskers['y'] = pinfo['upper'] + pinfo['lower']
whiskers['yend'] = pinfo['ymax'] + pinfo['ymin']

box = subdict(('xmin', 'xmax', 'lower', 'middle', 'upper'))
box.update(deepcopy(common))
box['ymin'] = box.pop('lower')
box['y'] = box.pop('middle')
box['ymax'] = box.pop('upper')
box['notchwidth'] = params['notchwidth']
copy_missing_columns(whiskers, data[common_columns])

# box
box_columns = ['xmin', 'xmax', 'lower', 'middle', 'upper']
box = data[common_columns + box_columns].copy()
box.rename(columns={'lower': 'ymin',
'middle': 'y',
'upper': 'ymax'},
inplace=True)

# notch
if params['notch']:
box['ynotchlower'] = pinfo['notchlower']
box['ynotchupper'] = pinfo['notchupper']
box['ynotchlower'] = data['notchlower']
box['ynotchupper'] = data['notchupper']

if 'outliers' in pinfo and len(pinfo['outliers'][0]):
outliers = subdict(('alpha', 'zorder'))
# outliers
try:
num_outliers = len(data['outliers'].iloc[0])
except KeyError:
num_outliers = 0

if num_outliers:
def outlier_value(param):
oparam = 'outlier_{}'.format(param)
if params[oparam] is not None:
return params[oparam]
return pinfo[param]
return data[param].iloc[0]

outliers['y'] = pinfo['outliers'][0]
outliers['x'] = make_iterable_ntimes(pinfo['x'][0],
len(outliers['y']))
outliers = pd.DataFrame({
'y': data['outliers'].iloc[0],
'x': make_iterable_ntimes(data['x'][0],
num_outliers),
'fill': None})
outliers['alpha'] = outlier_value('alpha')
outliers['color'] = outlier_value('color')
outliers['fill'] = None
outliers['shape'] = outlier_value('shape')
outliers['size'] = outlier_value('size')
outliers['stroke'] = outlier_value('stroke')
geom_point.draw_group(outliers, panel_scales,
coord, ax, **params)

# plot
geom_segment.draw_group(whiskers, panel_scales,
coord, ax, **params)
params['fatten'] = geom_crossbar.DEFAULT_PARAMS['fatten']
geom_crossbar.draw_group(box, panel_scales,
coord, ax, **params)

Expand Down
63 changes: 29 additions & 34 deletions ggplot/geoms/geom_crossbar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
unicode_literals)

import numpy as np
import pandas as pd
import matplotlib.lines as mlines
from matplotlib.patches import Rectangle

from ..scales.utils import resolution
from ..utils.exceptions import gg_warn
from ..utils import to_rgba
from ..utils import copy_missing_columns
from .geom import geom
from .geom_polygon import geom_polygon
from .geom_segment import geom_segment
Expand All @@ -33,64 +34,58 @@ def setup_data(self, data):
return data

@staticmethod
def draw_group(pinfo, panel_scales, coord, ax, **params):
y = pinfo['y']
xmin = np.array(pinfo['xmin'])
xmax = np.array(pinfo['xmax'])
ymin = np.array(pinfo['ymin'])
ymax = np.array(pinfo['ymax'])
notchwidth = pinfo.get('notchwidth')
ynotchupper = pinfo.get('ynotchupper')
ynotchlower = pinfo.get('ynotchlower')

keys = ['alpha', 'color', 'fill', 'size',
'linetype', 'zorder']

def copy_keys(d):
for k in keys:
d[k] = pinfo[k]
def draw_group(data, panel_scales, coord, ax, **params):
y = data['y']
xmin = data['xmin']
xmax = data['xmax']
ymin = data['ymin']
ymax = data['ymax']
group = data['group']

# From violin
notchwidth = data.get('notchwidth')
ynotchupper = data.get('ynotchupper')
ynotchlower = data.get('ynotchlower')

def flat(*args):
"""Flatten list-likes"""
return [i for arg in args for i in arg]
return np.hstack(args)

middle = {'x': xmin,
'y': y,
'xend': xmax,
'yend': y,
'group': pinfo['group']}
copy_keys(middle)
middle['size'] = np.asarray(middle['size'])*params['fatten'],
middle = pd.DataFrame({'x': xmin,
'y': y,
'xend': xmax,
'yend': y,
'group': group})
copy_missing_columns(middle, data)
middle['size'] *= params['fatten']

has_notch = ynotchlower is not None and ynotchupper is not None
if has_notch: # 10 points + 1 closing
ynotchlower = np.array(ynotchlower)
ynotchupper = np.array(ynotchupper)
if (any(ynotchlower < ymin) or any(ynotchupper > ymax)):
msg = ("Notch went outside hinges."
" Try setting notch=False.")
gg_warn(msg)

notchindent = (1 - notchwidth) * (xmax-xmin)/2

middle['x'] = np.array(middle['x']) + notchindent
middle['xend'] = np.array(middle['xend']) - notchindent
box = {
middle['x'] += notchindent
middle['xend'] -= notchindent
box = pd.DataFrame({
'x': flat(xmin, xmin, xmin+notchindent, xmin, xmin,
xmax, xmax, xmax-notchindent, xmax, xmax,
xmin),
'y': flat(ymax, ynotchupper, y, ynotchlower, ymin,
ymin, ynotchlower, y, ynotchupper, ymax,
ymax),
'group': np.tile(np.arange(1, len(pinfo['group'])+1), 11)}
'group': np.tile(np.arange(1, len(group)+1), 11)})
else:
# No notch, 4 points + 1 closing
box = {
box = pd.DataFrame({
'x': flat(xmin, xmin, xmax, xmax, xmin),
'y': flat(ymax, ymax, ymax, ymin, ymin),
'group': np.tile(np.arange(1, len(pinfo['group'])+1), 5)}
copy_keys(box)
'group': np.tile(np.arange(1, len(group)+1), 5)})

copy_missing_columns(box, data)
geom_polygon.draw_group(box, panel_scales, coord, ax, **params)
geom_segment.draw_group(middle, panel_scales, coord, ax, **params)

Expand Down
Loading

0 comments on commit 95c2a29

Please sign in to comment.