diff --git a/holoviews/plotting/mpl/chart.py b/holoviews/plotting/mpl/chart.py index c762a57b85..65956f1d20 100644 --- a/holoviews/plotting/mpl/chart.py +++ b/holoviews/plotting/mpl/chart.py @@ -1,13 +1,18 @@ from __future__ import unicode_literals from itertools import product +from distutils.version import LooseVersion import numpy as np +import matplotlib as mpl from matplotlib import cm from matplotlib import pyplot as plt from matplotlib.collections import LineCollection +from matplotlib.path import Path as MPLPath from matplotlib.dates import date2num, DateFormatter +mpl_version = LooseVersion(mpl.__version__) + import param from ...core import OrderedDict, Dimension @@ -75,7 +80,8 @@ def get_data(self, element, ranges, style): if xs.dtype.kind == 'M': dt_format = Dimension.type_formatters[np.datetime64] dims[0] = dims[0](value_format=DateFormatter(dt_format)) - return (xs, ys), style, {'dimensions': dims} + coords = (ys, xs) if self.invert_axes else (xs, ys) + return coords, style, {'dimensions': dims} def init_artists(self, ax, plot_args, plot_kwargs): xs, ys = plot_args @@ -111,7 +117,14 @@ class ErrorPlot(ChartPlot): _plot_methods = dict(single='errorbar') def init_artists(self, ax, plot_data, plot_kwargs): - _, (bottoms, tops), verts = ax.errorbar(*plot_data, **plot_kwargs) + handles = ax.errorbar(*plot_data, **plot_kwargs) + bottoms, tops = None, None + if mpl_version >= str('2.0'): + _, caps, verts = handles + if caps: + bottoms, tops = caps + else: + _, (bottoms, tops), verts = handles return {'bottoms': bottoms, 'tops': tops, 'verts': verts[0]} @@ -120,8 +133,15 @@ def get_data(self, element, ranges, style): dims = element.dimensions() xs, ys = (element.dimension_values(i) for i in range(2)) yerr = element.array(dimensions=dims[2:4]) - style['yerr'] = yerr.T if len(dims) > 3 else yerr[:, 0] - return (xs, ys), style, {} + + if self.invert_axes: + coords = (ys, xs) + err_key = 'xerr' + else: + coords = (xs, ys) + err_key = 'yerr' + style[err_key] = yerr.T if len(dims) > 3 else yerr[:, 0] + return coords, style, {} def update_handles(self, key, axis, element, ranges, style): @@ -130,30 +150,37 @@ def update_handles(self, key, axis, element, ranges, style): verts = self.handles['verts'] paths = verts.get_paths() - (xs, ys), style, axis_kwargs = self.get_data(element, ranges, style) - - neg_error = element.dimension_values(2) + _, style, axis_kwargs = self.get_data(element, ranges, style) + xs, ys, neg_error = (element.dimension_values(i) for i in range(3)) + samples = len(xs) + npaths = len(paths) pos_error = element.dimension_values(3) if len(element.dimensions()) > 3 else neg_error if self.invert_axes: - bdata = xs - neg_error - tdata = xs + pos_error - tops.set_xdata(bdata) - tops.set_ydata(ys) - bottoms.set_xdata(tdata) - bottoms.set_ydata(ys) - for i, path in enumerate(paths): - path.vertices = np.array([[bdata[i], ys[i]], - [tdata[i], ys[i]]]) + bxs, bys = ys - neg_error, xs + txs, tys = ys + pos_error, xs + new_arrays = [np.array([[bxs[i], xs[i]], [txs[i], xs[i]]]) + for i in range(samples)] else: - bdata = ys - neg_error - tdata = ys + pos_error - bottoms.set_xdata(xs) - bottoms.set_ydata(bdata) - tops.set_xdata(xs) - tops.set_ydata(tdata) - for i, path in enumerate(paths): - path.vertices = np.array([[xs[i], bdata[i]], - [xs[i], tdata[i]]]) + bxs, bys = xs, ys - neg_error + txs, tys = xs, ys + pos_error + new_arrays = [np.array([[xs[i], bys[i]], [xs[i], tys[i]]]) + for i in range(samples)] + + new_paths = [] + for i, arr in enumerate(new_arrays): + if i < npaths: + paths[i].vertices = arr + new_paths.append(paths[i]) + else: + new_paths.append(MPLPath(arr)) + verts.set_paths(new_paths) + + if bottoms: + bottoms.set_xdata(bxs) + bottoms.set_ydata(bys) + if tops: + tops.set_xdata(txs) + tops.set_ydata(tys) return axis_kwargs