From 3b11689beddd71ed3fa9df09357113c487df53b3 Mon Sep 17 00:00:00 2001 From: Philipp Rudiger Date: Sat, 7 Oct 2017 19:09:09 +0100 Subject: [PATCH] Handle invert_axes for all Element types (#1919) --- holoviews/plotting/bokeh/annotation.py | 7 +- holoviews/plotting/bokeh/chart.py | 14 ++-- holoviews/plotting/bokeh/raster.py | 100 +++++++++++++++---------- holoviews/plotting/mpl/annotation.py | 12 ++- holoviews/plotting/mpl/chart.py | 3 +- holoviews/plotting/mpl/element.py | 15 ++-- holoviews/plotting/mpl/raster.py | 42 +++++++++-- tests/testplotinstantiation.py | 75 ++++++++++++++++++- 8 files changed, 202 insertions(+), 66 deletions(-) diff --git a/holoviews/plotting/bokeh/annotation.py b/holoviews/plotting/bokeh/annotation.py index f9a139243f..f738b3fb26 100644 --- a/holoviews/plotting/bokeh/annotation.py +++ b/holoviews/plotting/bokeh/annotation.py @@ -42,7 +42,6 @@ def get_data(self, element, ranges=None): data['text'] = [element.text] return (data, mapping) - def get_batched_data(self, element, ranges=None): data = defaultdict(list) for key, el in element.data.items(): @@ -51,7 +50,6 @@ def get_batched_data(self, element, ranges=None): data[k].extend(eld) return data, elmapping - def get_extents(self, element, ranges=None): return None, None, None, None @@ -100,7 +98,10 @@ class SplinePlot(ElementPlot): _plot_methods = dict(single='bezier') def get_data(self, element, ranges=None): - data_attrs = ['x0', 'y0', 'cx0', 'cy0', 'cx1', 'cy1', 'x1', 'y1',] + if self.invert_axes: + data_attrs = ['y0', 'x0', 'cy0', 'cx0', 'cy1', 'cx1', 'y1', 'x1'] + else: + data_attrs = ['x0', 'y0', 'cx0', 'cy0', 'cx1', 'cy1', 'x1', 'y1'] verts = np.array(element.data[0]) inds = np.where(np.array(element.data[1])==1)[0] data = {da: [] for da in data_attrs} diff --git a/holoviews/plotting/bokeh/chart.py b/holoviews/plotting/bokeh/chart.py index ea939d70f3..02855aa5f5 100644 --- a/holoviews/plotting/bokeh/chart.py +++ b/holoviews/plotting/bokeh/chart.py @@ -190,8 +190,12 @@ def get_data(self, element, ranges=None): input_scale = style.pop('scale', 1.0) # Get x, y, angle, magnitude and color data - xidx, yidx = (1, 0) if self.invert_axes else (0, 1) rads = element.dimension_values(2) + if self.invert_axes: + xidx, yidx = (1, 0) + rads = rads+1.5*np.pi + else: + xidx, yidx = (0, 1) lens = self._get_lengths(element, ranges)/input_scale cdim = element.get_dimension(self.color_index) cdata, cmapping = self._get_color_data(element, ranges, style, @@ -1008,7 +1012,7 @@ def _get_factors(self, element): Get factors for categorical axes. """ if not element.kdims: - return [element.label], [] + xfactors, yfactors = [element.label], [] else: if bokeh_version < '0.12.7': factors = [', '.join([d.pprint_value(v).replace(':', ';') @@ -1018,10 +1022,8 @@ def _get_factors(self, element): factors = [tuple(d.pprint_value(v) for d, v in zip(element.kdims, key)) for key in element.groupby(element.kdims).data.keys()] factors = [f[0] if len(f) == 1 else f for f in factors] - if self.invert_axes: - return [], factors - else: - return factors, [] + xfactors, yfactors = factors, [] + return (yfactors, xfactors) if self.invert_axes else (xfactors, yfactors) def get_data(self, element, ranges=None): if element.kdims: diff --git a/holoviews/plotting/bokeh/raster.py b/holoviews/plotting/bokeh/raster.py index 1b4f73d6f6..4cb822f7b6 100644 --- a/holoviews/plotting/bokeh/raster.py +++ b/holoviews/plotting/bokeh/raster.py @@ -22,6 +22,16 @@ def __init__(self, *args, **kwargs): self.invert_yaxis = not self.invert_yaxis + def _glyph_properties(self, plot, element, source, ranges): + properties = super(RasterPlot, self)._glyph_properties(plot, element, + source, ranges) + properties = {k: v for k, v in properties.items()} + val_dim = [d for d in element.vdims][0] + properties['color_mapper'] = self._get_colormapper(val_dim, element, ranges, + properties) + return properties + + def get_data(self, element, ranges=None): mapping = dict(image='image', x='x', y='y', dw='dw', dh='dh') if self.static_source: @@ -31,36 +41,30 @@ def get_data(self, element, ranges=None): if img.dtype.kind == 'b': img = img.astype(np.int8) - if isinstance(element, Image): - l, b, r, t = element.bounds.lbrt() - else: - img = img.T[::-1] if self.invert_yaxis else img.T + if type(element) is Raster: l, b, r, t = element.extents + if self.invert_axes: + l, b, r, t = b, l, t, r + else: + img = img.T + else: + l, b, r, t = element.bounds.lbrt() + if self.invert_axes: + img = img.T + l, b, r, t = b, l, t, r - # Ensure axis inversions are handled correctly if self.invert_xaxis: l, r = r, l img = img[:, ::-1] if self.invert_yaxis: + img = img[::-1] b, t = t, b - if type(element) is not Raster: - img = img[::-1] dh, dw = t-b, r-l - + data = dict(image=[img], x=[l], y=[b], dw=[dw], dh=[dh]) return (data, mapping) - def _glyph_properties(self, plot, element, source, ranges): - properties = super(RasterPlot, self)._glyph_properties(plot, element, - source, ranges) - properties = {k: v for k, v in properties.items()} - val_dim = [d for d in element.vdims][0] - properties['color_mapper'] = self._get_colormapper(val_dim, element, ranges, - properties) - return properties - - class RGBPlot(RasterPlot): @@ -88,12 +92,15 @@ def get_data(self, element, ranges=None): # Ensure axis inversions are handled correctly l, b, r, t = element.bounds.lbrt() + if self.invert_axes: + img = img.T + l, b, r, t = b, l, t, r if self.invert_xaxis: l, r = r, l img = img[:, ::-1] if self.invert_yaxis: - b, t = t, b img = img[::-1] + b, t = t, b dh, dw = t-b, r-l data = dict(image=[img], x=[l], y=[b], dw=[dw], dh=[dh]) @@ -133,19 +140,27 @@ def _get_factors(self, element): def get_data(self, element, ranges=None): x, y, z = [dimension_sanitizer(d) for d in element.dimensions(label=True)[:3]] + if self.invert_axes: x, y = y, x style = self.style[self.cyclic_index] cmapper = self._get_colormapper(element.vdims[0], element, ranges, style) if self.static_source: - data = {} + return {}, {'x': x, 'y': y, 'fill_color': {'field': 'zvalues', 'transform': cmapper}} + + aggregate = element.gridded + xdim, ydim = aggregate.dimensions()[:2] + xvals, yvals = (aggregate.dimension_values(x), + aggregate.dimension_values(y)) + zvals = aggregate.dimension_values(2, flat=False) + if self.invert_axes: + xdim, ydim = ydim, xdim + zvals = zvals.T.flatten() else: - aggregate = element.gridded - xdim, ydim = aggregate.dimensions()[:2] - xvals, yvals, zvals = (aggregate.dimension_values(i) for i in range(3)) - if xvals.dtype.kind not in 'SU': - xvals = [xdim.pprint_value(xv) for xv in xvals] - if yvals.dtype.kind not in 'SU': - yvals = [ydim.pprint_value(yv) for yv in yvals] - data = {x: xvals, y: yvals, 'zvalues': zvals} + zvals = zvals.T.flatten() + if xvals.dtype.kind not in 'SU': + xvals = [xdim.pprint_value(xv) for xv in xvals] + if yvals.dtype.kind not in 'SU': + yvals = [ydim.pprint_value(yv) for yv in yvals] + data = {x: xvals, y: yvals, 'zvalues': zvals} if any(isinstance(t, HoverTool) for t in self.state.tools) and not self.static_source: for vdim in element.vdims: @@ -166,22 +181,27 @@ class QuadMeshPlot(ColorbarPlot): def get_data(self, element, ranges=None): x, y, z = element.dimensions(label=True) + if self.invert_axes: x, y = y, x style = self.style[self.cyclic_index] cmapper = self._get_colormapper(element.vdims[0], element, ranges, style) if self.static_source: - data = {} + return {}, {'x': x, 'y': y, 'fill_color': {'field': z, 'transform': cmapper}} + + if len(set(v.shape for v in element.data)) == 1: + raise SkipRendering("Bokeh QuadMeshPlot only supports rectangular meshes") + zdata = element.data[2] + xvals = element.dimension_values(0, False) + yvals = element.dimension_values(1, False) + widths = np.diff(element.data[0]) + heights = np.diff(element.data[1]) + if self.invert_axes: + zvals = zdata.flatten() + xvals, yvals, widths, heights = yvals, xvals, heights, widths else: - if len(set(v.shape for v in element.data)) == 1: - raise SkipRendering("Bokeh QuadMeshPlot only supports rectangular meshes") - zvals = element.data[2].T.flatten() - xvals = element.dimension_values(0, False) - yvals = element.dimension_values(1, False) - widths = np.diff(element.data[0]) - heights = np.diff(element.data[1]) - xs, ys = cartesian_product([xvals, yvals], copy=True) - ws, hs = cartesian_product([widths, heights], copy=True) - data = {x: xs, y: ys, z: zvals, 'widths': ws, 'heights': hs} - + zvals = zdata.T.flatten() + xs, ys = cartesian_product([xvals, yvals], copy=True) + ws, hs = cartesian_product([widths, heights], copy=True) + data = {x: xs, y: ys, z: zvals, 'widths': ws, 'heights': hs} return (data, {'x': x, 'y': y, 'fill_color': {'field': z, 'transform': cmapper}, 'height': 'heights', 'width': 'widths'}) diff --git a/holoviews/plotting/mpl/annotation.py b/holoviews/plotting/mpl/annotation.py index b169ca59f6..92e41256b2 100644 --- a/holoviews/plotting/mpl/annotation.py +++ b/holoviews/plotting/mpl/annotation.py @@ -45,7 +45,10 @@ class VLinePlot(AnnotationPlot): style_opts = ['alpha', 'color', 'linewidth', 'linestyle', 'visible'] def draw_annotation(self, axis, position, opts): - return [axis.axvline(position, **opts)] + if self.invert_axes: + return [axis.axhline(position, **opts)] + else: + return [axis.axvline(position, **opts)] @@ -56,7 +59,10 @@ class HLinePlot(AnnotationPlot): def draw_annotation(self, axis, position, opts): "Draw a horizontal line on the axis" - return [axis.axhline(position, **opts)] + if self.invert_axes: + return [axis.axvline(position, **opts)] + else: + return [axis.axhline(position, **opts)] class TextPlot(AnnotationPlot): @@ -67,6 +73,7 @@ class TextPlot(AnnotationPlot): def draw_annotation(self, axis, data, opts): (x,y, text, fontsize, horizontalalignment, verticalalignment, rotation) = data + if self.invert_axes: x, y = y, x opts['fontsize'] = fontsize return [axis.text(x,y, text, horizontalalignment = horizontalalignment, @@ -85,6 +92,7 @@ class ArrowPlot(AnnotationPlot): def draw_annotation(self, axis, data, opts): x, y, text, direction, points, arrowstyle = data + if self.invert_axes: x, y = y, x direction = direction.lower() arrowprops = dict({'arrowstyle':arrowstyle}, **{k: opts[k] for k in self._arrow_style_opts if k in opts}) diff --git a/holoviews/plotting/mpl/chart.py b/holoviews/plotting/mpl/chart.py index 433d288661..149f3ae184 100644 --- a/holoviews/plotting/mpl/chart.py +++ b/holoviews/plotting/mpl/chart.py @@ -632,7 +632,8 @@ def get_data(self, element, ranges, style): xs = element.dimension_values(xidx) if len(element.data) else [] ys = element.dimension_values(yidx) if len(element.data) else [] radians = element.dimension_values(2) if len(element.data) else [] - angles = list(np.rad2deg(radians)) + if self.invert_axes: radians = radians+1.5*np.pi + angles = list(np.rad2deg(radians)) if self.rescale_lengths: input_scale = input_scale / self._min_dist diff --git a/holoviews/plotting/mpl/element.py b/holoviews/plotting/mpl/element.py index eea93cc3a4..47aa52657b 100644 --- a/holoviews/plotting/mpl/element.py +++ b/holoviews/plotting/mpl/element.py @@ -342,22 +342,27 @@ def _set_axis_limits(self, axis, view, subplots, ranges): if self.invert_xaxis or any(p.invert_xaxis for p in subplots): r, l = l, r if l != r: + lims = {} if valid_lim(l): - axis.set_xlim(left=l) + lims['left'] = l scalex = False if valid_lim(r): - axis.set_xlim(right=r) + lims['right'] = r scalex = False - + if lims: + axis.set_xlim(**lims) if self.invert_yaxis or any(p.invert_yaxis for p in subplots): t, b = b, t if b != t: + lims = {} if valid_lim(b): - axis.set_ylim(bottom=b) + lims['bottom'] = b scaley = False if valid_lim(t): - axis.set_ylim(top=t) + lims['top'] = t scaley = False + if lims: + axis.set_ylim(**lims) axis.autoscale_view(scalex=scalex, scaley=scaley) diff --git a/holoviews/plotting/mpl/raster.py b/holoviews/plotting/mpl/raster.py index fa4997a3f4..adf85e18f8 100644 --- a/holoviews/plotting/mpl/raster.py +++ b/holoviews/plotting/mpl/raster.py @@ -57,14 +57,22 @@ def get_data(self, element, ranges, style): if isinstance(element, RGB): style.pop('cmap', None) + data = get_raster_array(element) if type(element) is Raster: l, b, r, t = element.extents - if self.invert_yaxis: - b, t = t, b + if self.invert_axes: + data = data[:, ::-1] + else: + data = data[::-1] else: l, b, r, t = element.bounds.lbrt() + if self.invert_axes: + data = data[::-1, ::-1] + + if self.invert_axes: + data = data.transpose([1, 0, 2]) if isinstance(element, RGB) else data.T + l, b, r, t = b, l, t, r - data = get_raster_array(element) vdim = element.vdims[0] self._norm_kwargs(element, ranges, style, vdim) style['extent'] = [l, r, b, t] @@ -109,8 +117,15 @@ def _annotate_plot(self, ax, annotations): def _annotate_values(self, element): val_dim = element.vdims[0] - vals = element.dimension_values(2) + vals = element.dimension_values(2, flat=False) d1uniq, d2uniq = [element.dimension_values(i, False) for i in range(2)] + if self.invert_axes: + d1uniq, d2uniq = d2uniq, d1uniq + else: + vals = vals.T + if self.invert_xaxis: vals = vals[::-1] + if self.invert_yaxis: vals = vals[:, ::-1] + vals = vals.flatten() num_x, num_y = len(d1uniq), len(d2uniq) xpos = np.linspace(0.5, num_x-0.5, num_x) ypos = np.linspace(0.5, num_y-0.5, num_y) @@ -127,6 +142,8 @@ def _compute_ticks(self, element, ranges): agg = element.gridded dim1_keys, dim2_keys = [unique_array(agg.dimension_values(i, False)) for i in range(2)] + if self.invert_axes: + dim1_keys, dim2_keys = dim2_keys, dim1_keys num_x, num_y = len(dim1_keys), len(dim2_keys) xpos = np.linspace(.5, num_x-0.5, num_x) ypos = np.linspace(.5, num_y-0.5, num_y) @@ -151,6 +168,9 @@ def get_data(self, element, ranges, style): data = np.flipud(element.gridded.dimension_values(2, flat=False)) data = np.ma.array(data, mask=np.logical_not(np.isfinite(data))) + if self.invert_axes: data = data.T[::-1, ::-1] + if self.invert_xaxis: data = data[:, ::-1] + if self.invert_yaxis: data = data[::-1] shape = data.shape style['aspect'] = shape[0]/shape[1] style['extent'] = (0, shape[1], 0, shape[0]) @@ -164,8 +184,7 @@ def update_handles(self, key, axis, element, ranges, style): im = self.handles['artist'] data, style, axis_kwargs = self.get_data(element, ranges, style) im.set_data(data[0]) - shape = data[0].shape - im.set_extent((0, shape[1], 0, shape[0])) + im.set_extent(style['extent']) im.set_clim((style['vmin'], style['vmax'])) if 'norm' in style: im.norm = style['norm'] @@ -187,9 +206,12 @@ class ImagePlot(RasterPlot): def get_data(self, element, ranges, style): data = np.flipud(element.dimension_values(2, flat=False)) data = np.ma.array(data, mask=np.logical_not(np.isfinite(data))) + l, b, r, t = element.bounds.lbrt() + if self.invert_axes: + data = data[::-1].T + l, b, r, t = b, l, t, r vdim = element.vdims[0] self._norm_kwargs(element, ranges, style, vdim) - l, b, r, t = element.bounds.lbrt() style['extent'] = [l, r, b, t] return (data,), style, {} @@ -211,7 +233,11 @@ class QuadMeshPlot(ColorbarPlot): def get_data(self, element, ranges, style): data = np.ma.array(element.data[2], mask=np.logical_not(np.isfinite(element.data[2]))) - cmesh_data = list(element.data[:2]) + [data] + coords = list(element.data[:2]) + if self.invert_axes: + coords = coords[::-1] + data = data.T + cmesh_data = coords + [data] style['locs'] = np.concatenate(element.data[:2]) vdim = element.vdims[0] self._norm_kwargs(element, ranges, style, vdim) diff --git a/tests/testplotinstantiation.py b/tests/testplotinstantiation.py index 19d36ffb59..ceb4a1b3ac 100644 --- a/tests/testplotinstantiation.py +++ b/tests/testplotinstantiation.py @@ -19,7 +19,7 @@ from holoviews.element import (Curve, Scatter, Image, VLine, Points, HeatMap, QuadMesh, Spikes, ErrorBars, Scatter3D, Path, Polygons, Bars, Text, - BoxWhisker, HLine, RGB) + BoxWhisker, HLine, RGB, Raster) from holoviews.element.comparison import ComparisonTestCase from holoviews.streams import Stream, PointerXY, PointerX from holoviews.operation import gridmatrix @@ -315,6 +315,38 @@ def test_polygons_colored(self): self.assertEqual(artist.get_array(), np.array([j])) self.assertEqual(artist.get_clim(), (0, 4)) + def test_raster_invert_axes(self): + arr = np.array([[0, 1, 2], [3, 4, 5]]) + raster = Raster(arr).opts(plot=dict(invert_axes=True)) + plot = mpl_renderer.get_plot(raster) + artist = plot.handles['artist'] + self.assertEqual(artist.get_array().data, arr.T[::-1]) + self.assertEqual(artist.get_extent(), [0, 2, 0, 3]) + + def test_image_invert_axes(self): + arr = np.array([[0, 1, 2], [3, 4, 5]]) + raster = Image(arr).opts(plot=dict(invert_axes=True)) + plot = mpl_renderer.get_plot(raster) + artist = plot.handles['artist'] + self.assertEqual(artist.get_array().data, arr.T[::-1, ::-1]) + self.assertEqual(artist.get_extent(), [-0.5, 0.5, -0.5, 0.5]) + + def test_quadmesh_invert_axes(self): + arr = np.array([[0, 1, 2], [3, 4, 5]]) + qmesh = QuadMesh(Image(arr)).opts(plot=dict(invert_axes=True)) + plot = mpl_renderer.get_plot(qmesh) + artist = plot.handles['artist'] + self.assertEqual(artist.get_array().data, arr.T[:, ::-1].flatten()) + + def test_heatmap_invert_axes(self): + arr = np.array([[0, 1, 2], [3, 4, 5]]) + hm = HeatMap(Image(arr)).opts(plot=dict(invert_axes=True)) + plot = mpl_renderer.get_plot(hm) + artist = plot.handles['artist'] + self.assertEqual(artist.get_array().data, arr.T[::-1, ::-1]) + self.assertEqual(artist.get_extent(), (0, 2, 0, 3)) + + class TestBokehPlotInstantiation(ComparisonTestCase): @@ -1234,6 +1266,47 @@ def test_vline_plot(self): self.assertEqual(span.dimension, 'height') self.assertEqual(span.location, 1.1) + def test_raster_invert_axes(self): + arr = np.array([[0, 1, 2], [3, 4, 5]]) + raster = Raster(arr).opts(plot=dict(invert_axes=True)) + plot = bokeh_renderer.get_plot(raster) + source = plot.handles['source'] + self.assertEqual(source.data['image'][0], np.rot90(arr)) + self.assertEqual(source.data['x'][0], 0) + self.assertEqual(source.data['y'][0], 3) + self.assertEqual(source.data['dw'][0], 2) + self.assertEqual(source.data['dh'][0], -3) + + def test_image_invert_axes(self): + arr = np.array([[0, 1, 2], [3, 4, 5]]) + raster = Image(arr).opts(plot=dict(invert_axes=True)) + plot = bokeh_renderer.get_plot(raster) + source = plot.handles['source'] + self.assertEqual(source.data['image'][0], np.rot90(arr)[::-1, ::-1]) + self.assertEqual(source.data['x'][0], -.5) + self.assertEqual(source.data['y'][0], -.5) + self.assertEqual(source.data['dw'][0], 1) + self.assertEqual(source.data['dh'][0], 1) + + def test_quadmesh_invert_axes(self): + arr = np.array([[0, 1, 2], [3, 4, 5]]) + qmesh = QuadMesh(Image(arr)).opts(plot=dict(invert_axes=True)) + plot = bokeh_renderer.get_plot(qmesh) + source = plot.handles['source'] + self.assertEqual(source.data['z'], qmesh.dimension_values(2, flat=False).flatten()) + self.assertEqual(source.data['x'], qmesh.dimension_values(0)) + self.assertEqual(source.data['y'], qmesh.dimension_values(1)) + + def test_heatmap_invert_axes(self): + arr = np.array([[0, 1, 2], [3, 4, 5]]) + hm = HeatMap(Image(arr)).opts(plot=dict(invert_axes=True)) + plot = bokeh_renderer.get_plot(hm) + xdim, ydim = hm.kdims + source = plot.handles['source'] + self.assertEqual(source.data['zvalues'], hm.dimension_values(2, flat=False).T.flatten()) + self.assertEqual(source.data['x'], [xdim.pprint_value(v) for v in hm.dimension_values(0)]) + self.assertEqual(source.data['y'], [ydim.pprint_value(v) for v in hm.dimension_values(1)]) + def test_box_whisker_datetime(self): times = np.arange(dt.datetime(2017,1,1), dt.datetime(2017,2,1), dt.timedelta(days=1))