diff --git a/holoviews/plotting/mpl/__init__.py b/holoviews/plotting/mpl/__init__.py index 4c8dd541a1..74683a6385 100644 --- a/holoviews/plotting/mpl/__init__.py +++ b/holoviews/plotting/mpl/__init__.py @@ -138,7 +138,7 @@ def grid_selector(grid): # Raster plots QuadMesh: QuadMeshPlot, Raster: RasterPlot, - HeatMap: RasterPlot, + HeatMap: HeatMapPlot, Image: RasterPlot, RGB: RasterPlot, HSV: RasterPlot, @@ -177,7 +177,7 @@ def grid_selector(grid): options.Bars = Options('style', ec='k', color=Cycle()) options.Histogram = Options('style', ec='k', facecolor=Cycle()) options.Points = Options('style', color=Cycle(), marker='o') -options.Scatter3D = Options('style', facecolors=Cycle(), marker='o') +options.Scatter3D = Options('style', c=Cycle(), marker='o') options.Scatter3D = Options('plot', fig_size=150) options.Surface = Options('plot', fig_size=150) options.Spikes = Options('style', color='black') diff --git a/holoviews/plotting/mpl/annotation.py b/holoviews/plotting/mpl/annotation.py index 6e9205ee95..f162e357f2 100644 --- a/holoviews/plotting/mpl/annotation.py +++ b/holoviews/plotting/mpl/annotation.py @@ -26,15 +26,12 @@ def initialize_plot(self, ranges=None): self.handles['annotations'] = handles return self._finalize_axis(key, ranges=ranges) - - def update_handles(self, axis, annotation, key, ranges=None): + def update_handles(self, key, axis, annotation, ranges, style): # Clear all existing annotations for element in self.handles['annotations']: element.remove() - self.handles['annotations']=[] - opts = self.style[self.cyclic_index] - self.handles['annotations'] = self.draw_annotation(axis, annotation.data, opts) + self.handles['annotations'] = self.draw_annotation(axis, annotation.data, style) class VLinePlot(AnnotationPlot): @@ -42,9 +39,6 @@ class VLinePlot(AnnotationPlot): style_opts = ['alpha', 'color', 'linewidth', 'linestyle', 'visible'] - def __init__(self, annotation, **params): - super(VLinePlot, self).__init__(annotation, **params) - def draw_annotation(self, axis, position, opts): return [axis.axvline(position, **opts)] @@ -55,9 +49,6 @@ class HLinePlot(AnnotationPlot): style_opts = ['alpha', 'color', 'linewidth', 'linestyle', 'visible'] - def __init__(self, annotation, **params): - super(HLinePlot, self).__init__(annotation, **params) - def draw_annotation(self, axis, position, opts): "Draw a horizontal line on the axis" return [axis.axhline(position, **opts)] @@ -68,9 +59,6 @@ class TextPlot(AnnotationPlot): style_opts = ['alpha', 'color', 'family', 'weight', 'rotation', 'fontsize', 'visible'] - def __init__(self, annotation, **params): - super(TextPlot, self).__init__(annotation, **params) - def draw_annotation(self, axis, data, opts): (x,y, text, fontsize, horizontalalignment, verticalalignment, rotation) = data @@ -90,9 +78,6 @@ class ArrowPlot(AnnotationPlot): style_opts = sorted(set(_arrow_style_opts + _text_style_opts)) - def __init__(self, annotation, **params): - super(ArrowPlot, self).__init__(annotation, **params) - def draw_annotation(self, axis, data, opts): direction, text, xy, points, arrowstyle = data arrowprops = dict({'arrowstyle':arrowstyle}, @@ -113,9 +98,6 @@ class SplinePlot(AnnotationPlot): style_opts = ['alpha', 'edgecolor', 'linewidth', 'linestyle', 'visible'] - def __init__(self, annotation, **params): - super(SplinePlot, self).__init__(annotation, **params) - def draw_annotation(self, axis, data, opts): verts, codes = data patch = patches.PathPatch(matplotlib.path.Path(verts, codes), diff --git a/holoviews/plotting/mpl/chart.py b/holoviews/plotting/mpl/chart.py index 40b8a27255..8e5c6a3bb6 100644 --- a/holoviews/plotting/mpl/chart.py +++ b/holoviews/plotting/mpl/chart.py @@ -115,26 +115,10 @@ class CurvePlot(ChartPlot): style_opts = ['alpha', 'color', 'visible', 'linewidth', 'linestyle', 'marker'] - def initialize_plot(self, ranges=None): - element = self.hmap.last - axis = self.handles['axis'] - key = self.keys[-1] - - ranges = self.compute_ranges(self.hmap, key, ranges) - ranges = match_spec(element, ranges) + def init_artists(self, ax, plot_data, plot_kwargs): + return {'artist': ax.plot(*plot_data, **plot_kwargs)[0]} - xs, ys, xticks = self.get_data(element) - # Create line segments and apply style - style = self.style[self.cyclic_index] - legend = element.label if self.show_legend else '' - line_segment = axis.plot(xs, ys, label=legend, - zorder=self.zorder, **style)[0] - - self.handles['artist'] = line_segment - return self._finalize_axis(self.keys[-1], ranges=ranges, xticks=xticks) - - - def get_data(self, element): + def get_data(self, element, ranges, style): # Create xticks and reorder data if cyclic xticks = None if self.cyclic_range and all(v is not None for v in self.cyclic_range): @@ -146,17 +130,15 @@ def get_data(self, element): else: xs = element.dimension_values(0) ys = element.dimension_values(1) - return xs, ys, xticks + return (xs, ys), style, {'xticks': xticks} - def update_handles(self, axis, element, key, ranges=None): + def update_handles(self, key, axis, element, ranges, style): artist = self.handles['artist'] - xs, ys, xticks = self.get_data(element) - + (xs, ys), style, axis_kwargs = self.get_data(element, ranges, style) artist.set_xdata(xs) artist.set_ydata(ys) - return {'xticks': xticks} - + return axis_kwargs @@ -174,63 +156,51 @@ class ErrorPlot(ChartPlot): 'markerfacecolor', 'markersize', 'solid_capstyle', 'solid_joinstyle', 'dashes', 'color'] + def init_artists(self, ax, plot_data, plot_kwargs): + _, (bottoms, tops), verts = ax.errorbar(*plot_data, **plot_kwargs) + return {'bottoms': bottoms, 'tops': tops, 'verts': verts[0]} - def initialize_plot(self, ranges=None): - element = self.hmap.last - axis = self.handles['axis'] - key = self.keys[-1] - - ranges = self.compute_ranges(self.hmap, key, ranges) - ranges = match_spec(element, ranges) + def get_data(self, element, ranges, style): + style['fmt'] = 'none' dims = element.dimensions() - error_kwargs = dict(self.style[self.cyclic_index], fmt='none', - zorder=self.zorder) + xs, ys = (element.dimension_values(i) for i in range(2)) yerr = element.array(dimensions=dims[2:4]) - error_kwargs['yerr'] = yerr.T if len(dims) > 3 else yerr - _, (bottoms, tops), verts = axis.errorbar(element.dimension_values(0), - element.dimension_values(1), - **error_kwargs) - self.handles['bottoms'] = bottoms - self.handles['tops'] = tops - self.handles['verts'] = verts[0] + style['yerr'] = yerr.T if len(dims) > 3 else yerr + return (xs, ys), style, {} - return self._finalize_axis(self.keys[-1], ranges=ranges) - - def update_handles(self, axis, element, key, ranges=None): + def update_handles(self, key, axis, element, ranges, style): bottoms = self.handles['bottoms'] tops = self.handles['tops'] verts = self.handles['verts'] paths = verts.get_paths() - xvals = element.dimension_values(0) - mean = element.dimension_values(1) - neg_error = element.dimension_values(2) - pos_idx = 3 if len(element.dimensions()) > 3 else 2 - pos_error = element.dimension_values(pos_idx) + (xs, ys), style, axis_kwargs = self.get_data(element, ranges, style) + neg_error = element.dimension_values(2) + pos_error = element.dimension_values(3) if len(element.dimensions()) > 3 else neg_error if self.invert_axes: - bdata = xvals - neg_error - tdata = xvals + pos_error + bdata = xs - neg_error + tdata = xs + pos_error tops.set_xdata(bdata) - tops.set_ydata(mean) + tops.set_ydata(ys) bottoms.set_xdata(tdata) - bottoms.set_ydata(mean) + bottoms.set_ydata(ys) for i, path in enumerate(paths): - path.vertices = np.array([[bdata[i], mean[i]], - [tdata[i], mean[i]]]) + path.vertices = np.array([[bdata[i], ys[i]], + [tdata[i], ys[i]]]) else: - bdata = mean - neg_error - tdata = mean + pos_error - bottoms.set_xdata(xvals) + bdata = ys - neg_error + tdata = ys + pos_error + bottoms.set_xdata(xs) bottoms.set_ydata(bdata) - tops.set_xdata(xvals) + tops.set_xdata(xs) tops.set_ydata(tdata) for i, path in enumerate(paths): - path.vertices = np.array([[xvals[i], bdata[i]], - [xvals[i], tdata[i]]]) - + path.vertices = np.array([[xs[i], bdata[i]], + [xs[i], tdata[i]]]) + return axis_kwargs class AreaPlot(ChartPlot): @@ -242,23 +212,15 @@ class AreaPlot(ChartPlot): 'hatch', 'linestyle', 'joinstyle', 'fill', 'capstyle', 'interpolate'] - def initialize_plot(self, ranges=None): - element = self.hmap.last - axis = self.handles['axis'] - key = self.keys[-1] - - ranges = self.compute_ranges(self.hmap, key, ranges) - ranges = match_spec(element, ranges) - - self.update_handles(axis, element, key, ranges) - - ylabel = str(element.vdims[0]) - return self._finalize_axis(self.keys[-1], ranges=ranges, ylabel=ylabel) - - def get_data(self, element): + def get_data(self, element, ranges, style): xs = element.dimension_values(0) ys = [element.dimension_values(vdim) for vdim in element.vdims] - return tuple([xs]+ys) + return tuple([xs]+ys), style, {} + + def init_artists(self, ax, plot_data, plot_kwargs): + fill_fn = ax.fill_betweenx if self.invert_axes else ax.fill_between + stack = fill_fn(*plot_data, **plot_kwargs) + return {'artist': stack} def get_extents(self, element, ranges): vdims = element.vdims @@ -266,16 +228,6 @@ def get_extents(self, element, ranges): ranges[vdim] = max_range([ranges[vd.name] for vd in vdims]) return super(AreaPlot, self).get_extents(element, ranges) - def update_handles(self, axis, element, key, ranges=None): - if 'artist' in self.handles: - self.handles['artist'].remove() - - # Create line segments and apply style - style = self.style[self.cyclic_index] - data = self.get_data(element) - fill_fn = axis.fill_betweenx if self.invert_axes else axis.fill_between - stack = fill_fn(*data, zorder=self.zorder, **style) - self.handles['artist'] = stack @@ -288,17 +240,16 @@ class SpreadPlot(AreaPlot): Whether to show legend for the plot.""") def __init__(self, element, **params): - self.table = element.table() super(SpreadPlot, self).__init__(element, **params) self._extents = None - def get_data(self, element): + def get_data(self, element, ranges, style): xs = element.dimension_values(0) mean = element.dimension_values(1) neg_error = element.dimension_values(2) pos_idx = 3 if len(element.dimensions()) > 3 else 2 pos_error = element.dimension_values(pos_idx) - return xs, mean-neg_error, mean+pos_error + return (xs, mean-neg_error, mean+pos_error), style, {} class HistogramPlot(ChartPlot): @@ -432,11 +383,7 @@ def _update_artists(self, key, hist, edges, hvals, widths, lims, ranges): bar.set_width(width) - def update_handles(self, axis, element, key, ranges=None): - """ - Update the plot for an animation. - :param axis: - """ + def update_handles(self, key, axis, element, ranges, style): # Process values, axes and style edges, hvals, widths, lims = self._process_hist(element) @@ -578,59 +525,49 @@ class PointPlot(ChartPlot, ColorbarPlot): _disabled_opts = ['size'] - def initialize_plot(self, ranges=None): - points = self.hmap.last - axis = self.handles['axis'] + def init_artists(self, ax, plot_args, plot_kwargs): + return {'artist': ax.scatter(*plot_args, **plot_kwargs)} - ranges = self.compute_ranges(self.hmap, self.keys[-1], ranges) - ranges = match_spec(points, ranges) - xs = points.dimension_values(0) if len(points.data) else [] - ys = points.dimension_values(1) if len(points.data) else [] + def get_data(self, element, ranges, style): + xs, ys = (element.dimension_values(i) for i in range(2)) + self._compute_styles(element, ranges, style) + return (xs, ys), style, {} - style = self.style[self.cyclic_index] - cdim = points.get_dimension(self.color_index) + + def _compute_styles(self, element, ranges, style): + cdim = element.get_dimension(self.color_index) color = style.pop('color', None) if cdim: - cs = points.dimension_values(self.color_index) + cs = element.dimension_values(self.color_index) style['c'] = cs if 'clim' not in style: clims = ranges[cdim.name] style.update(vmin=clims[0], vmax=clims[1]) - else: + elif color: style['c'] = color - edgecolor = style.pop('edgecolors', style.pop('edgecolor', 'none')) - - if points.get_dimension(self.size_index): - style['s'] = self._compute_size(points, style) + style['edgecolors'] = style.pop('edgecolors', style.pop('edgecolor', 'none')) - legend = points.label if self.show_legend else '' - scatterplot = axis.scatter(xs, ys, zorder=self.zorder, label=legend, - edgecolors=edgecolor, **style) - self.handles['artist'] = scatterplot + if element.get_dimension(self.size_index): + sizes = element.dimension_values(self.size_index) + ms = style.pop('s') if 's' in style else plt.rcParams['lines.markersize'] + style['s'] = compute_sizes(sizes, self.size_fn, self.scaling_factor, + self.scaling_method, ms) + style['edgecolors'] = style.pop('edgecolors', 'none') - return self._finalize_axis(self.keys[-1], ranges=ranges) - - def _compute_size(self, element, opts): - sizes = element.dimension_values(self.size_index) - ms = opts.pop('s') if 's' in opts else plt.rcParams['lines.markersize'] - return compute_sizes(sizes, self.size_fn, self.scaling_factor, self.scaling_method, ms) - - - def update_handles(self, axis, element, key, ranges=None): + def update_handles(self, key, axis, element, ranges, style): paths = self.handles['artist'] - paths.set_offsets(element.array(dimensions=[0, 1])) + (xs, ys), style, _ = self.get_data(element, ranges, style) + paths.set_offsets(np.column_stack([xs, ys])) sdim = element.get_dimension(self.size_index) if sdim: - opts = self.style[self.cyclic_index] - paths.set_sizes(self._compute_size(element, opts)) + paths.set_sizes(style['s']) cdim = element.get_dimension(self.color_index) if cdim: - cs = element.dimension_values(self.color_index) - paths.set_clim(ranges[cdim.name]) - paths.set_array(cs) + paths.set_clim(style['vmin'], style['vmax']) + paths.set_array(style['c']) @@ -679,104 +616,73 @@ def _get_map_info(self, vmap): """ Get the minimum sample distance and maximum magnitude """ - dists = [] - for vfield in vmap: - dists.append(self._get_min_dist(vfield)) - return min(dists) - - - def _get_info(self, vfield, input_scale, ranges): - ndims = len(vfield.dimensions()) - xs = vfield.dimension_values(0) if len(vfield.data) else [] - ys = vfield.dimension_values(1) if len(vfield.data) else [] - radians = vfield.dimension_values(2) if len(vfield.data) else [] - magnitudes = vfield.dimension_values(3) if ndims>=4 else np.array([1.0] * len(xs)) - colors = magnitudes if self.color_dim == 'magnitude' else radians - - if ndims >= 4: - magnitude_dim = vfield.get_dimension(3).name - _, max_magnitude = ranges[magnitude_dim] - else: - max_magnitude = 1.0 - - min_dist = self._min_dist if self._min_dist else self._get_min_dist(vfield) - - if self.normalize_lengths and max_magnitude != 0: - magnitudes = magnitudes / max_magnitude - - return (xs, ys, list((radians / np.pi) * 180), - magnitudes, colors, input_scale / min_dist) + return np.min([self._get_min_dist(vfield) for vfield in vmap]) def _get_min_dist(self, vfield): "Get the minimum sampling distance." - xys = np.array([complex(x,y) for x,y in zip(vfield.dimension_values(0), - vfield.dimension_values(1))]) + xys = vfield.array([0, 1]).view(dtype=np.complex128) m, n = np.meshgrid(xys, xys) - distances = abs(m-n) + distances = np.abs(m-n) np.fill_diagonal(distances, np.inf) - return distances.min() - - - def initialize_plot(self, ranges=None): - vfield = self.hmap.last - axis = self.handles['axis'] - - colorized = self.color_dim is not None - kwargs = self.style[self.cyclic_index] - input_scale = kwargs.pop('scale', 1.0) - ranges = self.compute_ranges(self.hmap, self.keys[-1], ranges) - ranges = match_spec(vfield, ranges) - xs, ys, angles, lens, colors, scale = self._get_info(vfield, input_scale, ranges) - - args = (xs, ys, lens, [0.0] * len(vfield)) - args = args + (colors,) if colorized else args - + return distances.min() + + + def get_data(self, element, ranges, style): + input_scale = style.pop('scale', 1.0) + mag_dim = element.get_dimension(3) + xs = element.dimension_values(0) if len(element.data) else [] + ys = element.dimension_values(1) if len(element.data) else [] + radians = element.dimension_values(2) if len(element.data) else [] + angles = list(np.rad2deg(radians)) + scale = input_scale / self._min_dist + + if mag_dim: + magnitudes = element.dimension_values(3) + _, max_magnitude = ranges[mag_dim.name] + if self.normalize_lengths and max_magnitude != 0: + magnitudes = magnitudes / max_magnitude + else: + magnitudes = np.ones(len(xs)) + + args = (xs, ys, magnitudes, [0.0] * len(element)) + if self.color_dim: + colors = magnitudes if self.color_dim == 'magnitude' else radians + args = args + (colors,) + if self.color_dim == 'angle': + style['clim'] = element.get_dimension(2).range + elif self.color_dim == 'magnitude': + magnitude_dim = element.get_dimension(3).name + style['clim'] = ranges[magnitude_dim] + style.pop('color', None) + + if 'pivot' not in style: style['pivot'] = 'mid' if not self.arrow_heads: - kwargs['headaxislength'] = 0 - - if 'pivot' not in kwargs: kwargs['pivot'] = 'mid' + style['headaxislength'] = 0 + style.update(dict(scale=scale, angles=angles)) - legend = vfield.label if self.show_legend else '' - quiver = axis.quiver(*args, zorder=self.zorder, units='x', label=legend, - scale_units='x', scale = scale, angles = angles , - **({k:v for k,v in kwargs.items() if k!='color'} - if colorized else kwargs)) + return args, style, {} - if self.color_dim == 'angle': - clims = vfield.get_dimension(2).range - quiver.set_clim(clims) - elif self.color_dim == 'magnitude': - magnitude_dim = vfield.get_dimension(3).name - quiver.set_clim(ranges[magnitude_dim]) + def init_artists(self, ax, plot_args, plot_kwargs): + quiver = ax.quiver(*plot_args, units='x', scale_units='x', **plot_kwargs) + return {'artist': quiver} - self.handles['axis'].add_collection(quiver) - self.handles['artist'] = quiver - self.handles['input_scale'] = input_scale - return self._finalize_axis(self.keys[-1], ranges=ranges) - - - def update_handles(self, axis, element, key, ranges=None): - artist = self.handles['artist'] - artist.set_offsets(element.array()[:,0:2]) - input_scale = self.handles['input_scale'] - ranges = self.compute_ranges(self.hmap, key, ranges) - ranges = match_spec(element, ranges) - - xs, ys, angles, lens, colors, scale = self._get_info(element, input_scale, ranges) + def update_handles(self, key, axis, element, ranges, style): + args, style, axis_kwargs = self.get_data(element, ranges, style) # Set magnitudes, angles and colors if supplied. quiver = self.handles['artist'] - quiver.U = lens - quiver.angles = angles - if self.color_dim is not None: - quiver.set_array(colors) + quiver.set_offsets(np.column_stack(args[:2])) + quiver.U = args[2] + quiver.angles = style['angles'] + if self.color_dim: + quiver.set_array(args[-1]) if self.color_dim == 'magnitude': - magnitude_dim = element.get_dimension(3).name - quiver.set_clim(ranges[magnitude_dim]) + quiver.set_clim(style['clim']) + return axis_kwargs class BarPlot(LegendPlot): @@ -984,7 +890,7 @@ def _create_bars(self, axis, element): return bars, xticks, xlabel - def update_handles(self, axis, element, key, ranges=None): + def update_handles(self, key, axis, element, ranges, style): dims = element.dimensions('key', label=True) ndims = len(dims) ci, gi, si = self.category_index, self.group_index, self.stack_index @@ -1024,29 +930,12 @@ class SpikesPlot(PathPlot): style_opts = PathPlot.style_opts + ['cmap'] - def initialize_plot(self, ranges=None): - lines = self.hmap.last - key = self.keys[-1] + def init_artists(self, ax, plot_args, plot_kwargs): + line_segments = LineCollection(*plot_args, **plot_kwargs) + ax.add_collection(line_segments) + return {'artist': line_segments} - ranges = self.compute_ranges(self.hmap, key, ranges) - ranges = match_spec(lines, ranges) - style = self.style[self.cyclic_index] - label = lines.label if self.show_legend else '' - - data, array, clim = self.get_data(lines, ranges) - if array is not None: - style['array'] = array - style['clim'] = clim - - line_segments = LineCollection(data, label=label, - zorder=self.zorder, **style) - self.handles['artist'] = line_segments - self.handles['axis'].add_collection(line_segments) - - return self._finalize_axis(key, ranges=ranges) - - - def get_data(self, element, ranges): + def get_data(self, element, ranges, style): dimensions = element.dimensions(label=True) ndims = len(dimensions) @@ -1065,18 +954,20 @@ def get_data(self, element, ranges): if cdim: array = element.dimension_values(cdim) clim = ranges[cdim.name] - return data, array, clim + style['array'] = array + style['clim'] = clim + return (np.array(data),), style, {} - def update_handles(self, axis, element, key, ranges=None): + def update_handles(self, key, axis, element, ranges, style): artist = self.handles['artist'] - data, array, clim = self.get_data(element, ranges) + (data,), kwargs, axis_kwargs = self.get_data(element, ranges, style) artist.set_paths(data) - visible = self.style[self.cyclic_index].get('visible', True) - artist.set_visible(visible) - if array is not None: - artist.set_clim(clim) - artist.set_array(array) + artist.set_visible(style.get('visible', True)) + if 'array' not in kwargs: + artist.set_clim(kwargs['clim']) + artist.set_array(kwargs['array']) + return axis_kwargs class SideSpikesPlot(AdjoinedPlot, SpikesPlot): @@ -1125,23 +1016,8 @@ class BoxPlot(ChartPlot): def get_extents(self, element, ranges): return (np.NaN,)*4 - def initialize_plot(self, ranges=None): - element = self.hmap.last - axis = self.handles['axis'] - key = self.keys[-1] - - ranges = self.compute_ranges(self.hmap, key, ranges) - ranges = match_spec(element, ranges) - xlabel = ','.join([str(d) for d in element.kdims]) - ylabel = str(element.vdims[0]) - - self.handles['artist'] = self.get_artist(element, axis) - - return self._finalize_axis(self.keys[-1], ranges=ranges, xlabel=xlabel, - ylabel=ylabel) - - def get_artist(self, element, axis): + def get_data(self, element, ranges, style): groups = element.groupby(element.kdims) data, labels = [], [] @@ -1156,16 +1032,26 @@ def get_artist(self, element, axis): label = key data.append(group[group.vdims[0]]) labels.append(label) - return axis.boxplot(data, labels=labels, vert=not self.invert_axes, - **self.style[self.cyclic_index]) + style['labels'] = labels + style.pop('zorder') + style.pop('label') + style['vert'] = not self.invert_axes + + xlabel = ','.join([str(d) for d in element.kdims]) + ylabel = str(element.vdims[0]) + + return (data,), style, {'xlabel': xlabel, 'ylabel': ylabel} - def update_handles(self, axis, element, key, ranges=None): + def init_artists(self, ax, plot_args, plot_kwargs): + boxplot = ax.boxplot(*plot_args, **plot_kwargs) + return {'artist': boxplot} + + + def teardown_handles(self): for k, group in self.handles['artist'].items(): for v in group: v.remove() - self.handles['artist'] = self.get_artist(element, axis) - class SideBoxPlot(AdjoinedPlot, BoxPlot): diff --git a/holoviews/plotting/mpl/chart3d.py b/holoviews/plotting/mpl/chart3d.py index 3824c114c6..2c05b1d72e 100644 --- a/holoviews/plotting/mpl/chart3d.py +++ b/holoviews/plotting/mpl/chart3d.py @@ -109,45 +109,29 @@ class Scatter3DPlot(Plot3D, PointPlot): allow_None=True, doc=""" Index of the dimension from which the sizes will the drawn.""") - def initialize_plot(self, ranges=None): - axis = self.handles['axis'] - points = self.hmap.last - ranges = self.compute_ranges(self.hmap, self.keys[-1], ranges) - ranges = match_spec(points, ranges) - key = self.keys[-1] - xs, ys, zs = (points.dimension_values(i) for i in range(3)) - - style = self.style[self.cyclic_index] - cdim = points.get_dimension(self.color_index) - if cdim and 'cmap' in style: - cs = points.dimension_values(self.color_index) - style['c'] = cs - if 'clim' not in style: - clims = ranges[cdim.name] - style.update(vmin=clims[0], vmax=clims[1]) - if points.get_dimension(self.size_index): - style['s'] = self._compute_size(points, style) - - scatterplot = axis.scatter(xs, ys, zs, zorder=self.zorder, **style) - - self.handles['axis'].add_collection(scatterplot) - self.handles['artist'] = scatterplot - - return self._finalize_axis(key, ranges=ranges) - - def update_handles(self, axis, points, key, ranges=None): + def get_data(self, element, ranges, style): + xs, ys, zs = (element.dimension_values(i) for i in range(3)) + self._compute_styles(element, ranges, style) + # Temporary fix until color handling is deterministic in mpl+py3 + if not element.get_dimension(self.color_index) and 'c' in style: + style['color'] = style['c'] + return (xs, ys, zs), style, {} + + def init_artists(self, ax, plot_data, plot_kwargs): + scatterplot = ax.scatter(*plot_data, **plot_kwargs) + ax.add_collection(scatterplot) + return {'artist': scatterplot} + + def update_handles(self, key, axis, element, ranges, style): artist = self.handles['artist'] - artist._offsets3d = tuple(points[d] for d in points.dimensions()) - cdim = points.get_dimension(self.color_index) - style = self.style[self.cyclic_index] + artist._offsets3d, style, _ = self.get_data(element, ranges, style) + cdim = element.get_dimension(self.color_index) if cdim and 'cmap' in style: - cs = points.dimension_values(self.color_index) clim = style['clim'] if 'clim' in style else ranges[cdim.name] cmap = cm.get_cmap(style['cmap']) - artist._facecolor3d = map_colors(cs, clim, cmap, False) - if points.get_dimension(self.size_index): - artist.set_sizes(self._compute_size(points, style)) - + artist._facecolor3d = map_colors(style['c'], clim, cmap, hex=False) + if element.get_dimension(self.size_index): + artist.set_sizes(style['s']) @@ -171,36 +155,24 @@ class SurfacePlot(Plot3D): style_opts = ['antialiased', 'cmap', 'color', 'shade', 'linewidth', 'facecolors', 'rstride', 'cstride'] - def initialize_plot(self, ranges=None): - element = self.hmap.last - key = self.keys[-1] - - ranges = self.compute_ranges(self.hmap, self.keys[-1], ranges) - ranges = match_spec(element, ranges) - - self.update_handles(self.handles['axis'], element, key, ranges) - return self._finalize_axis(key, ranges=ranges) - + def init_artists(self, ax, plot_data, plot_kwargs): + if self.plot_type == "wireframe": + artist = ax.plot_wireframe(*plot_data, **plot_kwargs) + elif self.plot_type == "surface": + artist = ax.plot_surface(*plot_data, **plot_kwargs) + elif self.plot_type == "contour": + artist = ax.contour3D(*plot_data, **plot_kwargs) + return {'artist': artist} - def update_handles(self, axis, element, key, ranges=None): - if 'artist' in self.handles: - self.handles['axis'].collections.remove(self.handles['artist']) + def get_data(self, element, ranges, style): mat = element.data rn, cn = mat.shape l, b, zmin, r, t, zmax = self.get_extents(element, ranges) r, c = np.mgrid[l:r:(r-l)/float(rn), b:t:(t-b)/float(cn)] - - style_opts = self.style[self.cyclic_index] - - if self.plot_type == "wireframe": - self.handles['artist'] = self.handles['axis'].plot_wireframe(r, c, mat, **style_opts) - elif self.plot_type == "surface": - style_opts['vmin'] = zmin - style_opts['vmax'] = zmax - self.handles['artist'] = self.handles['axis'].plot_surface(r, c, mat, **style_opts) - elif self.plot_type == "contour": - self.handles['artist'] = self.handles['axis'].contour3D(r, c, mat, **style_opts) - + style['vmin'] = zmin + style['vmax'] = zmax + return (r, c, mat), style, {} + class TrisurfacePlot(Plot3D): @@ -214,24 +186,13 @@ class TrisurfacePlot(Plot3D): style_opts = ['cmap', 'color', 'shade', 'linewidth', 'edgecolor'] - def initialize_plot(self, ranges=None): - element = self.hmap.last - key = self.keys[-1] - - ranges = self.compute_ranges(self.hmap, self.keys[-1], ranges) - ranges = match_spec(element, ranges) - - self.update_handles(self.handles['axis'], element, key, ranges) - return self._finalize_axis(key, ranges=ranges) - - - def update_handles(self, axis, element, key, ranges=None): - if 'artist' in self.handles: - self.handles['axis'].collections.remove(self.handles['artist']) - style_opts = self.style[self.cyclic_index] + def get_data(self, element, ranges, style): dims = element.dimensions(label=True) vrange = ranges[dims[2]] + style['vmin'] = vrange[0] + style['vmax'] = vrange[1] x, y, z = [element.dimension_values(d) for d in dims] - artist = axis.plot_trisurf(x, y, z, vmax=vrange[1], - vmin=vrange[0], **style_opts) - self.handles['artist'] = artist + return (x, y, z), style, {} + + def init_artists(self, ax, plot_data, plot_kwargs): + return {'artist': ax.plot_trisurf(*plot_data, **plot_kwargs)} diff --git a/holoviews/plotting/mpl/element.py b/holoviews/plotting/mpl/element.py index 55ecef9c71..cb58953f9c 100644 --- a/holoviews/plotting/mpl/element.py +++ b/holoviews/plotting/mpl/element.py @@ -424,19 +424,57 @@ def update_frame(self, key, ranges=None, element=None): handle.set_visible(element is not None) if element is None: return + ranges = self.compute_ranges(self.hmap, key, ranges) if not self.adjoined: ranges = util.match_spec(element, ranges) - axis_kwargs = self.update_handles(axis, element, key if element is not None else {}, ranges) + + label = element.label if self.show_legend else '' + style = dict(label=label, zorder=self.zorder, **self.style[self.cyclic_index]) + axis_kwargs = self.update_handles(key, axis, element, ranges, style) self._finalize_axis(key, ranges=ranges, **(axis_kwargs if axis_kwargs else {})) - def update_handles(self, axis, view, key, ranges=None): + def initialize_plot(self, ranges=None): + element = self.hmap.last + ax = self.handles['axis'] + key = list(self.hmap.data.keys())[-1] + dim_map = dict(zip((d.name for d in self.hmap.kdims), key)) + key = tuple(dim_map.get(d.name, None) for d in self.dimensions) + + ranges = self.compute_ranges(self.hmap, key, ranges) + ranges = util.match_spec(element, ranges) + + style = dict(zorder=self.zorder, **self.style[self.cyclic_index]) + if self.show_legend: + style['label'] = element.label + + plot_data, plot_kwargs, axis_kwargs = self.get_data(element, ranges, style) + handles = self.init_artists(ax, plot_data, plot_kwargs) + self.handles.update(handles) + + return self._finalize_axis(self.keys[-1], ranges=ranges, **axis_kwargs) + + + def update_handles(self, key, axis, element, ranges, style): """ Update the elements of the plot. - :param axis: """ - raise NotImplementedError + self.teardown_handles() + plot_data, plot_kwargs, axis_kwargs = self.get_data(element, ranges, style) + handles = self.init_artists(axis, plot_data, plot_kwargs) + self.handles.update(handles) + return axis_kwargs + + def teardown_handles(self): + """ + If no custom update_handles method is supplied this method + is called to tear down any previous handles before replacing + them. + """ + if 'artist' in self.handles: + self.handles['artist'].remove() + @@ -548,7 +586,8 @@ def _norm_kwargs(self, element, ranges, opts): linthresh=clim[1]/np.e) else: norm = colors.LogNorm(vmin=clim[0], vmax=clim[1]) - return clim, norm, opts + opts['norm'] = norm + opts['clim'] = clim @@ -711,40 +750,3 @@ def update_frame(self, key, ranges=None, element=None): self._adjust_legend(element, axis) self._finalize_axis(key, ranges=ranges) - - - -class DrawPlot(ElementPlot): - """ - A DrawPlot is an ElementPlot that uses a draw method for - rendering. The draw method is also called per update such that a - full redraw is triggered per frame. - - Although not optimized for HoloMaps (due to the full redraw), - DrawPlot is very easy to subclass to interface HoloViews with any - third-party libraries offering matplotlib plotting functionality. - """ - - _abstract = True - - def draw(self, axis, element, ranges=None): - """ - The only method that needs to be overridden in subclasses. - - The current axis and element are supplied as arguments. The - job of this function is to apply the appropriate matplotlib - commands to render the element to the supplied axis. - """ - raise NotImplementedError - - def initialize_plot(self, ranges=None): - element = self.hmap.last - key = self.keys[-1] - ranges = self.compute_ranges(self.hmap, key, ranges) - ranges = util.match_spec(element, ranges) - self.draw(self.handles['axis'], self.hmap.last, ranges) - return self._finalize_axis(self.keys[-1], ranges=ranges) - - def update_handles(self, axis, element, key, ranges=None): - if self.zorder == 0 and axis: axis.cla() - self.draw(axis, element, ranges) diff --git a/holoviews/plotting/mpl/pandas.py b/holoviews/plotting/mpl/pandas.py index ea79e617ef..2c5635ae79 100644 --- a/holoviews/plotting/mpl/pandas.py +++ b/holoviews/plotting/mpl/pandas.py @@ -123,8 +123,8 @@ def _validate(self, dfview): raise Exception("Multiple %s plots cannot be composed." % self.plot_type) - def _update_plot(self, axis, view): - style = self._process_style(self.style[self.cyclic_index]) + def _update_plot(self, axis, view, style): + style = self._process_style(style) if self.plot_type == 'scatter_matrix': pd.scatter_matrix(view.data, ax=axis, **style) elif self.plot_type == 'autocorrelation_plot': @@ -136,14 +136,14 @@ def _update_plot(self, axis, view): getattr(view.data, self.plot_type)(ax=axis, **style) - def update_handles(self, axis, view, key, ranges=None): + def update_handles(self, key, axis, view, ranges, style): """ Update the plot for an animation. """ if not self.plot_type in ['hist', 'scatter_matrix']: if self.zorder == 0 and axis: axis.cla() - self._update_plot(axis, view) + self._update_plot(axis, view, style) Store.register({DataFrameView: DFrameViewPlot, diff --git a/holoviews/plotting/mpl/path.py b/holoviews/plotting/mpl/path.py index 6100bdf2cb..1baed8bc81 100644 --- a/holoviews/plotting/mpl/path.py +++ b/holoviews/plotting/mpl/path.py @@ -15,29 +15,19 @@ class PathPlot(ElementPlot): style_opts = ['alpha', 'color', 'linestyle', 'linewidth', 'visible'] - def __init__(self, *args, **params): - super(PathPlot, self).__init__(*args, **params) + def get_data(self, element, ranges, style): + return (element.data,), style, {} - def initialize_plot(self, ranges=None): - lines = self.hmap.last - key = self.keys[-1] - ranges = self.compute_ranges(self.hmap, key, ranges) - ranges = match_spec(lines, ranges) - style = self.style[self.cyclic_index] - label = lines.label if self.show_legend else '' - line_segments = LineCollection(lines.data, label=label, - zorder=self.zorder, **style) - self.handles['artist'] = line_segments - self.handles['axis'].add_collection(line_segments) + def init_artists(self, ax, plot_args, plot_kwargs): + line_segments = LineCollection(*plot_args, **plot_kwargs) + ax.add_collection(line_segments) + return {'artist': line_segments} - return self._finalize_axis(key, ranges=ranges) - - - def update_handles(self, axis, element, key, ranges=None): + def update_handles(self, key, axis, element, ranges, style): artist = self.handles['artist'] - artist.set_paths(element.data) - visible = self.style[self.cyclic_index].get('visible', True) - artist.set_visible(visible) + data, style, axis_kwargs = self.get_data(element, ranges, style) + artist.set_paths(data[0]) + artist.set_visible(style.get('visible', True)) @@ -56,53 +46,32 @@ class PolygonPlot(ColorbarPlot): style_opts = ['alpha', 'cmap', 'facecolor', 'edgecolor', 'linewidth', 'hatch', 'linestyle', 'joinstyle', 'fill', 'capstyle'] - def initialize_plot(self, ranges=None): - element = self.hmap.last - key = self.keys[-1] - axis = self.handles['axis'] - ranges = self.compute_ranges(self.hmap, key, ranges) - ranges = match_spec(element, ranges) - collection, polys = self._create_polygons(element, ranges) - axis.add_collection(collection) - self.handles['polys'] = polys - - if self.colorbar: - self._draw_colorbar(collection, element) - - self.handles['artist'] = collection - - return self._finalize_axis(self.keys[-1], ranges=ranges) - - def _create_polygons(self, element, ranges): + def get_data(self, element, ranges, style): value = element.level vdim = element.vdims[0] - - style = self.style[self.cyclic_index] polys = [] for segments in element.data: if segments.shape[0]: polys.append(Polygon(segments)) - legend = element.label if self.show_legend else '' - collection = PatchCollection(polys, clim=ranges[vdim.name], - zorder=self.zorder, label=legend, **style) + style['clim'] = ranges[vdim.name] if value is not None and np.isfinite(value): - collection.set_array(np.array([value]*len(polys))) - return collection, polys + style['array'] = np.array([value]*len(polys)) + return (polys,), style, {} + def init_artists(self, ax, plot_args, plot_kwargs): + collection = PatchCollection(*plot_args, **plot_kwargs) + ax.add_collection(collection) + if self.colorbar: + self._draw_colorbar(collection, element) + return {'artist': collection, 'polys': plot_args[0]} - def update_handles(self, axis, element, key, ranges=None): - vdim = element.vdims[0] - collection = self.handles['artist'] - value = element.level + def update_handles(self, key, axis, element, ranges, style): + collection = self.handles['artist'] if any(not np.array_equal(data, poly.get_xy()) for data, poly in zip(element.data, self.handles['polys'])): - collection.remove() - collection, polys = self._create_polygons(element, ranges) - self.handles['polys'] = polys - self.handles['artist'] = collection - axis.add_collection(collection) + return super(PolygonPlot, self).update_handles(key, axis, element, ranges, style) elif value is not None and np.isfinite(value): collection.set_array(np.array([value]*len(element.data))) collection.set_clim(ranges[vdim.name]) diff --git a/holoviews/plotting/mpl/plot.py b/holoviews/plotting/mpl/plot.py index 4babee9613..88693d9fbe 100644 --- a/holoviews/plotting/mpl/plot.py +++ b/holoviews/plotting/mpl/plot.py @@ -220,8 +220,13 @@ def update_frame(self, key, ranges=None): ranges = self.compute_ranges(self.layout, key, ranges) for subplot in self.subplots.values(): subplot.update_frame(key, ranges=ranges) - axis = self.handles['axis'] - self.update_handles(axis, self.layout, key, ranges) + + title = self._format_title(key) if self.show_title else '' + if 'title' in self.handles: + self.handles['title'].set_text(title) + else: + title = axis.set_title(title, **self._fontsize('title')) + self.handles['title'] = title @@ -441,17 +446,6 @@ def _readjust_axes(self, axis): self._adjust_subplots(self.handles['axis'], self.subaxes) - def update_handles(self, axis, view, key, ranges=None): - """ - Should be called by the update_frame class to update - any handles on the plot. - """ - if self.show_title: - title = axis.set_title(self._format_title(key), - **self._fontsize('title')) - self.handles['title'] = title - - def _layout_axis(self, layout, axis): fig = self.handles['fig'] axkwargs = {'gid': str(self.position)} if axis else {} @@ -911,11 +905,6 @@ def _compute_gridspec(self, layout): padding = dict(w_pad=self.tight_padding, h_pad=self.tight_padding) self.gs.tight_layout(self.handles['fig'], rect=self.fig_bounds, **padding) - # Create title handle - if self.show_title and len(self.coords) > 1: - title = self.handles['fig'].suptitle('', **self._fontsize('title')) - self.handles['title'] = title - return layout_subplots, layout_axes, collapsed_layout @@ -1033,23 +1022,19 @@ def _create_subplots(self, layout, positions, layout_dimensions, ranges, axes={} return subplots, adjoint_clone, projections - def update_handles(self, axis, view, key, ranges=None): - """ - Should be called by the update_frame class to update - any handles on the plot. - """ - if self.show_title and 'title' in self.handles and len(self.coords) > 1: - self.handles['title'].set_text(self._format_title(key)) - - def initialize_plot(self): axis = self.handles['axis'] - self.update_handles(axis, None, self.keys[-1]) - - ranges = self.compute_ranges(self.layout, self.keys[-1], None) + key = self.keys[-1] + ranges = self.compute_ranges(self.layout, key, None) for subplot in self.subplots.values(): subplot.initialize_plot(ranges=ranges) + # Create title handle + if self.show_title and len(self.coords) > 1: + title = self._format_title(key) + title = self.handles['fig'].suptitle(title, **self._fontsize('title')) + self.handles['title'] = title + return self._finalize_axis(None) diff --git a/holoviews/plotting/mpl/raster.py b/holoviews/plotting/mpl/raster.py index 1d205f396f..f8879b648b 100644 --- a/holoviews/plotting/mpl/raster.py +++ b/holoviews/plotting/mpl/raster.py @@ -27,9 +27,6 @@ class RasterPlot(ColorbarPlot): situate_axes = param.Boolean(default=False, doc=""" Whether to situate the image relative to other plots. """) - show_values = param.Boolean(default=False, doc=""" - Whether to annotate each pixel with its value.""") - symmetric = param.Boolean(default=False, doc=""" Whether to make the colormap symmetric around zero.""") @@ -54,18 +51,15 @@ def get_extents(self, element, ranges): return element.extents - def initialize_plot(self, ranges=None): - element = self.hmap.last - axis = self.handles['axis'] + def _compute_ticks(self, element, ranges): + return None, None - ranges = self.compute_ranges(self.hmap, self.keys[-1], ranges) - ranges = match_spec(element, ranges) + def get_data(self, element, ranges, style): xticks, yticks = self._compute_ticks(element, ranges) - opts = self.style[self.cyclic_index] if element.depth != 1: - opts.pop('cmap', None) + style.pop('cmap', None) data = element.data if isinstance(element, Image): @@ -77,49 +71,46 @@ def initialize_plot(self, ranges=None): if isinstance(element, RGB): data = element.rgb.data - elif isinstance(element, HeatMap): - data = element.raster - data = np.ma.array(data, mask=np.logical_not(np.isfinite(data))) - cmap_name = opts.pop('cmap', None) - cmap = copy.copy(plt.cm.get_cmap('gray' if cmap_name is None else cmap_name)) - cmap.set_bad('w', 1.) - opts['cmap'] = cmap + self._norm_kwargs(element, ranges, style) + style['extent'] = [l, r, b, t] - clim, norm, opts = self._norm_kwargs(element, ranges, opts) - im = axis.imshow(data, extent=[l, r, b, t], zorder=self.zorder, - clim=clim, norm=norm, **opts) - self.handles['artist'] = im + return [data], style, {'xticks': xticks, 'yticks': yticks} - if isinstance(element, HeatMap): - self.handles['axis'].set_aspect(float(r - l)/(t-b)) - self.handles['annotations'] = {} - if self.show_values: - self._annotate_values(element) + def init_artists(self, ax, plot_args, plot_kwargs): + im = ax.imshow(*plot_args, **plot_kwargs) + return {'artist': im} - return self._finalize_axis(self.keys[-1], ranges=ranges, - xticks=xticks, yticks=yticks) + def update_handles(self, key, axis, element, ranges, style): + im = self.handles['artist'] + data, style, axis_kwargs = self.get_data(element, ranges, style) + l, r, b, t = style['extent'] + im.set_data(data[0]) + im.set_extent((l, r, b, t)) + im.set_clim(style['clim']) + if 'norm' in style: + im.norm = style['norm'] - def _compute_ticks(self, element, ranges): - if isinstance(element, HeatMap): - xdim, ydim = element.kdims - dim1_keys, dim2_keys = [element.dimension_values(i, True) - for i in range(2)] - num_x, num_y = len(dim1_keys), len(dim2_keys) - x0, y0, x1, y1 = element.extents - xstep, ystep = ((x1-x0)/num_x, (y1-y0)/num_y) - xpos = np.linspace(x0+xstep/2., x1-xstep/2., num_x) - ypos = np.linspace(y0+ystep/2., y1-ystep/2., num_y) - xlabels = [xdim.pprint_value(k) for k in dim1_keys] - ylabels = [ydim.pprint_value(k) for k in dim2_keys] - return (xpos, xlabels), (ypos, ylabels) - else: - return None, None + return axis_kwargs + + +class HeatMapPlot(RasterPlot): + + show_values = param.Boolean(default=False, doc=""" + Whether to annotate each pixel with its value.""") + + def _annotate_plot(self, ax, annotations): + handles = {} + for plot_coord, text in annotations.items(): + handles[plot_coord] = ax.annotate(text, xy=plot_coord, + xycoords='axes fraction', + horizontalalignment='center', + verticalalignment='center') + return handles def _annotate_values(self, element): - axis = self.handles['axis'] val_dim = element.vdims[0] vals = np.rot90(element.raster, 3).flatten() d1uniq, d2uniq = [np.unique(element.dimension_values(i)) for i in range(2)] @@ -128,48 +119,73 @@ def _annotate_values(self, element): xpos = np.linspace(xstep/2., 1.0-xstep/2., num_x) ypos = np.linspace(ystep/2., 1.0-ystep/2., num_y) plot_coords = product(xpos, ypos) + annotations = {} for plot_coord, v in zip(plot_coords, vals): text = val_dim.pprint_value(v) text = '' if v is np.nan else text - if plot_coord not in self.handles['annotations']: - annotation = axis.annotate(text, xy=plot_coord, - xycoords='axes fraction', - horizontalalignment='center', - verticalalignment='center') - self.handles['annotations'][plot_coord] = annotation - else: - self.handles['annotations'][plot_coord].set_text(text) - old_coords = set(self.handles['annotations'].keys()) - set(product(xpos, ypos)) - for plot_coord in old_coords: - annotation = self.handles['annotations'].pop(plot_coord) - annotation.remove() - + annotations[plot_coord] = text + return annotations - def update_handles(self, axis, element, key, ranges=None): - im = self.handles.get('artist', None) - data = np.ma.array(element.data, - mask=np.logical_not(np.isfinite(element.data))) - im.set_data(data) - if isinstance(element, HeatMap) and self.show_values: - self._annotate_values(element) + def _compute_ticks(self, element, ranges): + xdim, ydim = element.kdims + dim1_keys, dim2_keys = [element.dimension_values(i, True) + for i in range(2)] + num_x, num_y = len(dim1_keys), len(dim2_keys) + x0, y0, x1, y1 = element.extents + xstep, ystep = ((x1-x0)/num_x, (y1-y0)/num_y) + xpos = np.linspace(x0+xstep/2., x1-xstep/2., num_x) + ypos = np.linspace(y0+ystep/2., y1-ystep/2., num_y) + xlabels = [xdim.pprint_value(k) for k in dim1_keys] + ylabels = [ydim.pprint_value(k) for k in dim2_keys] + return (xpos, xlabels), (ypos, ylabels) + + + def init_artists(self, ax, plot_args, plot_kwargs): + l, r, b, t = plot_kwargs['extent'] + ax.set_aspect(float(r - l)/(t-b)) + + handles = {} + annotations = plot_kwargs.pop('annotations', None) + handles['artist'] = ax.imshow(*plot_args, **plot_kwargs) + if self.show_values and annotations: + handles['annotations'] = self._annotate_plot(ax, annotations) + return handles + + + def get_data(self, element, ranges, style): + _, style, axis_kwargs = super(HeatMapPlot, self).get_data(element, ranges, style) + data = element.raster + data = np.ma.array(data, mask=np.logical_not(np.isfinite(data))) + cmap_name = style.pop('cmap', None) + cmap = copy.copy(plt.cm.get_cmap('gray' if cmap_name is None else cmap_name)) + cmap.set_bad('w', 1.) + style['cmap'] = cmap + style['annotations'] = self._annotate_values(element) + return [data], style, axis_kwargs + + + def update_handles(self, key, axis, element, ranges, style): + im = self.handles['artist'] + data, style, axis_kwargs = self.get_data(element, ranges, style) + l, r, b, t = style['extent'] + im.set_data(data[0]) + im.set_extent((l, r, b, t)) + im.set_clim(style['clim']) + if 'norm' in style: + im.norm = style['norm'] - if isinstance(element, Image): - l, b, r, t = element.bounds.lbrt() - else: - l, b, r, t = element.extents - if type(element) == Raster: - b, t = t, b + if self.show_values: + annotations = self.handles['annotations'] + for annotation in annotations.values(): + try: + annotation.remove() + except: + pass + self._annotate_plot(axis, style['annotations']) + return axis_kwargs - opts = self.style[self.cyclic_index] - clim, norm, opts = self._norm_kwargs(element, ranges, opts) - im.set_clim(clim) - if norm: - im.norm = norm - im.set_extent((l, r, b, t)) - xticks, yticks = self._compute_ticks(element, ranges) - return {'xticks': xticks, 'yticks': yticks} class QuadMeshPlot(ColorbarPlot): @@ -180,45 +196,35 @@ class QuadMeshPlot(ColorbarPlot): style_opts = ['alpha', 'cmap', 'clim', 'edgecolors', 'norm', 'shading', 'linestyles', 'linewidths', 'hatch', 'visible'] - def initialize_plot(self, ranges=None): - key = self.hmap.keys()[-1] - element = self.hmap.last - axis = self.handles['axis'] - - ranges = self.compute_ranges(self.hmap, self.keys[-1], ranges) - ranges = match_spec(element, ranges) - self._init_cmesh(axis, element, ranges) - - return self._finalize_axis(key, ranges) - - def _init_cmesh(self, axis, element, ranges): - opts = self.style[self.cyclic_index] - if 'cmesh' in self.handles: - self.handles['cmesh'].remove() + def get_data(self, element, ranges, style): data = np.ma.array(element.data[2], mask=np.logical_not(np.isfinite(element.data[2]))) cmesh_data = list(element.data[:2]) + [data] - clim, norm, opts = self._norm_kwargs(element, ranges, opts) - self.handles['artist'] = axis.pcolormesh(*cmesh_data, zorder=self.zorder, - vmin=clim[0], vmax=clim[1], norm=norm, - **opts) - self.handles['locs'] = np.concatenate(element.data[:2]) + style['locs'] = np.concatenate(element.data[:2]) + self._norm_kwargs(element, ranges, style) + return tuple(cmesh_data), style, {} - def update_handles(self, axis, element, key, ranges=None): + def init_artists(self, ax, plot_args, plot_kwargs): + locs = plot_kwargs.pop('locs') + artist = ax.pcolormesh(*plot_args, **plot_kwargs) + return {'artist': artist, 'locs': locs} + + + def update_handles(self, key, axis, element, ranges, style): cmesh = self.handles['artist'] - opts = self.style[self.cyclic_index] locs = np.concatenate(element.data[:2]) + if (locs != self.handles['locs']).any(): - self._init_cmesh(axis, element, ranges) + return super(QuadMeshPlot, self).update_handles(key, axis, element, + ranges, style) else: - mask_array = np.logical_not(np.isfinite(element.data[2])) - data = np.ma.array(element.data[2], mask=mask_array) - cmesh.set_array(data.ravel()) - clim, norm, opts = self._norm_kwargs(element, ranges, opts) - cmesh.set_clim(clim) - if norm: - cmesh.norm = norm + data, style, axis_kwargs = self.get_data(element, ranges, style) + cmesh.set_array(data[-1]) + cmesh.set_clim(style['clim']) + if 'norm' in style: + cmesh.norm = style['norm'] + return axis_kwargs class RasterGridPlot(GridPlot, OverlayPlot): diff --git a/holoviews/plotting/mpl/seaborn.py b/holoviews/plotting/mpl/seaborn.py index 203003968c..5acbaf99d9 100644 --- a/holoviews/plotting/mpl/seaborn.py +++ b/holoviews/plotting/mpl/seaborn.py @@ -18,9 +18,9 @@ from .plot import MPLPlot, AdjoinedPlot -class FullRedrawPlot(ElementPlot): +class SeabornPlot(ElementPlot): """ - FullRedrawPlot provides an abstract baseclass, defining an + SeabornPlot provides an abstract baseclass, defining an update_frame method, which completely wipes the axis and redraws the plot. """ @@ -39,14 +39,12 @@ class FullRedrawPlot(ElementPlot): _abstract = True - def update_handles(self, axis, view, key, ranges=None): - if self.zorder == 0 and axis: - axis.cla() - self._update_plot(axis, view) + def teardown_handles(self): + if self.zorder == 0: + self.handles['axis'].cla() - -class RegressionPlot(FullRedrawPlot): +class RegressionPlot(SeabornPlot): """ RegressionPlot visualizes Regression Views using the Seaborn regplot interface, allowing the user to perform and plot @@ -60,22 +58,15 @@ class RegressionPlot(FullRedrawPlot): 'scatter_kws', 'line_kws', 'ci', 'dropna', 'x_jitter', 'y_jitter', 'x_partial', 'y_partial'] - def initialize_plot(self, ranges=None): - self._update_plot(self.handles['axis'], self.hmap.last) - return self._finalize_axis(self.keys[-1]) + def init_artists(self, ax, plot_data, plot_kwargs): + return {'axis': sns.regplot(*plot_data, ax=ax, **plot_kwargs)} + def get_data(self, element, ranges, style): + xs, ys = (element[d] for d in self.dimensions()[:1]) + return (xs, ys), style, {} - def _update_plot(self, axis, view): - kwargs = self.style[self.cyclic_index] - label = view.label if self.overlaid >= 1 else '' - if label: - kwargs['label'] = label - sns.regplot(view.dimension_values(0), view.dimension_values(1), - ax=axis, **kwargs) - - -class BivariatePlot(FullRedrawPlot): +class BivariatePlot(SeabornPlot): """ Bivariate plot visualizes two-dimensional kernel density estimates using the Seaborn kdeplot function. Additionally, @@ -93,33 +84,24 @@ class BivariatePlot(FullRedrawPlot): 'ci', 'kind', 'bw', 'kernel', 'cumulative', 'shade', 'vertical', 'cmap'] - def initialize_plot(self, ranges=None): - kdeview = self.hmap.last - axis = self.handles['axis'] - if self.joint and self.subplot: - raise Exception("Joint plots can't be animated or laid out in a grid.") - self._update_plot(axis, kdeview) - - return self._finalize_axis(self.keys[-1]) - - - def _update_plot(self, axis, view): - kwargs = self.style[self.cyclic_index] + def init_artists(self, ax, plot_data, plot_kwargs): if self.joint: - kwargs.pop('cmap', None) - self.handles['fig'] = sns.jointplot(view.data[:,0], - view.data[:,1], - **kwargs).fig + if self.joint and self.subplot: + raise Exception("Joint plots can't be animated or laid out in a grid.") + return {'fig': sns.jointplot(*plot_data, **plot_kwargs).fig} else: - kwargs = self.style[self.cyclic_index] - label = view.label if self.overlaid >= 1 else '' - if label: - kwargs['label'] = label - sns.kdeplot(view.data, ax=axis, zorder=self.zorder, **kwargs) + return {'axis': sns.kdeplot(*plot_data, ax=ax, **plot_kwargs)} + def get_data(self, element, ranges, style): + xs, ys = (element[d] for d in element.dimensions()[:2]) + if self.joint: + style.pop('cmap', None) + style.pop('zorder', None) + return (xs, ys), style, {} -class TimeSeriesPlot(FullRedrawPlot): + +class TimeSeriesPlot(SeabornPlot): """ TimeSeries visualizes sets of curves using the Seaborn tsplot function. This provides functionality to plot @@ -137,27 +119,20 @@ class TimeSeriesPlot(FullRedrawPlot): 'ci', 'n_boot', 'err_kws', 'err_palette', 'estimator', 'kwargs'] - def initialize_plot(self, ranges=None): - element = self.hmap.last - axis = self.handles['axis'] - self._update_plot(axis, element) + def get_data(self, element, ranges, style): + style.pop('zorder', None) + if 'label' in style: + style['condition'] = style.pop('label') + axis_kwargs = {'xlabel': str(element.kdims[0]), + 'ylabel': str(element.vdims[0])} + return (element.data, element.xdata), style, axis_kwargs - return self._finalize_axis(self.keys[-1]) + def init_artists(self, ax, plot_data, plot_kwargs): + return {'axis': sns.tsplot(*plot_data, ax=ax, **plot_kwargs)} - def _update_plot(self, axis, view): - kwargs = self.style[self.cyclic_index] - label = view.label if self.overlaid >= 1 else '' - if label: - kwargs['condition'] = label - sns.tsplot(view.data, view.xdata, ax=axis, zorder=self.zorder, **kwargs) - - def _axis_labels(self, view, subplots, xlabel=None, ylabel=None, zlabel=None): - xlabel = xlabel if xlabel else str(view.kdims[0]) - ylabel = ylabel if ylabel else str(view.vdims[0]) - return xlabel, ylabel, zlabel -class DistributionPlot(FullRedrawPlot): +class DistributionPlot(SeabornPlot): """ DistributionPlot visualizes Distribution Views using the Seaborn distplot function. This allows visualizing a 1D @@ -173,23 +148,16 @@ class DistributionPlot(FullRedrawPlot): style_opts = ['bins', 'hist', 'kde', 'rug', 'fit', 'hist_kws', 'kde_kws', 'rug_kws', 'fit_kws', 'color'] - def initialize_plot(self, ranges=None): - element = self.hmap.last - axis = self.handles['axis'] - self._update_plot(axis, element) - dim = element.get_dimension(0) - - return self._finalize_axis(self.keys[-1], xlabel='', ylabel=str(dim)) + def get_data(self, element, ranges, style): + style.pop('zorder', None) + if self.invert_axes: + style['vertical'] = True + axis_kwargs = dict(xlabel='', ylabel=str(element.get_dimension(0))) + return (element.dimension_values(0),), style, axis_kwargs + def init_artists(self, ax, plot_data, plot_kwargs): + return {'axis': sns.distplot(*plot_data, ax=ax, **plot_kwargs)} - def _update_plot(self, axis, view): - kwargs = self.style[self.cyclic_index] - label = view.label if self.overlaid >= 1 else '' - if label: - kwargs['label'] = label - if self.invert_axes: - kwargs['vertical'] = True - sns.distplot(view.dimension_values(0), ax=axis, **kwargs) class SideDistributionPlot(AdjoinedPlot, DistributionPlot): @@ -296,7 +264,9 @@ def update_frame(self, key, ranges=None): axis = self.handles['axis'] if axis: axis.set_visible(view is not None) - axis_kwargs = self.update_handles(axis, view, key, ranges) + + style = dict(label=label, zorder=self.zorder, **self.style[self.cyclic_index]) + axis_kwargs = self.update_handles(key, axis, view, key, ranges, style) if axis: self._finalize_axis(key, **(axis_kwargs if axis_kwargs else {})) diff --git a/holoviews/plotting/mpl/tabular.py b/holoviews/plotting/mpl/tabular.py index feaa5960cb..4cf11d7713 100644 --- a/holoviews/plotting/mpl/tabular.py +++ b/holoviews/plotting/mpl/tabular.py @@ -132,7 +132,7 @@ def initialize_plot(self, ranges=None): return self._finalize_axis(self.keys[-1]) - def update_handles(self, axis, view, key, ranges=None): + def update_handles(self, key, axis, view, ranges, style): table = self.handles['artist'] for coords, cell in table.get_celld().items():