Skip to content

Commit

Permalink
Handle invert_axes for all Element types (#1919)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored and jlstevens committed Oct 7, 2017
1 parent 4e02ccc commit 3b11689
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 66 deletions.
7 changes: 4 additions & 3 deletions holoviews/plotting/bokeh/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def get_data(self, element, ranges=None):
data['text'] = [element.text]
return (data, mapping)


def get_batched_data(self, element, ranges=None):
data = defaultdict(list)
for key, el in element.data.items():
Expand All @@ -51,7 +50,6 @@ def get_batched_data(self, element, ranges=None):
data[k].extend(eld)
return data, elmapping


def get_extents(self, element, ranges=None):
return None, None, None, None

Expand Down Expand Up @@ -100,7 +98,10 @@ class SplinePlot(ElementPlot):
_plot_methods = dict(single='bezier')

def get_data(self, element, ranges=None):
data_attrs = ['x0', 'y0', 'cx0', 'cy0', 'cx1', 'cy1', 'x1', 'y1',]
if self.invert_axes:
data_attrs = ['y0', 'x0', 'cy0', 'cx0', 'cy1', 'cx1', 'y1', 'x1']
else:
data_attrs = ['x0', 'y0', 'cx0', 'cy0', 'cx1', 'cy1', 'x1', 'y1']
verts = np.array(element.data[0])
inds = np.where(np.array(element.data[1])==1)[0]
data = {da: [] for da in data_attrs}
Expand Down
14 changes: 8 additions & 6 deletions holoviews/plotting/bokeh/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,12 @@ def get_data(self, element, ranges=None):
input_scale = style.pop('scale', 1.0)

# Get x, y, angle, magnitude and color data
xidx, yidx = (1, 0) if self.invert_axes else (0, 1)
rads = element.dimension_values(2)
if self.invert_axes:
xidx, yidx = (1, 0)
rads = rads+1.5*np.pi
else:
xidx, yidx = (0, 1)
lens = self._get_lengths(element, ranges)/input_scale
cdim = element.get_dimension(self.color_index)
cdata, cmapping = self._get_color_data(element, ranges, style,
Expand Down Expand Up @@ -1008,7 +1012,7 @@ def _get_factors(self, element):
Get factors for categorical axes.
"""
if not element.kdims:
return [element.label], []
xfactors, yfactors = [element.label], []
else:
if bokeh_version < '0.12.7':
factors = [', '.join([d.pprint_value(v).replace(':', ';')
Expand All @@ -1018,10 +1022,8 @@ def _get_factors(self, element):
factors = [tuple(d.pprint_value(v) for d, v in zip(element.kdims, key))
for key in element.groupby(element.kdims).data.keys()]
factors = [f[0] if len(f) == 1 else f for f in factors]
if self.invert_axes:
return [], factors
else:
return factors, []
xfactors, yfactors = factors, []
return (yfactors, xfactors) if self.invert_axes else (xfactors, yfactors)

def get_data(self, element, ranges=None):
if element.kdims:
Expand Down
100 changes: 60 additions & 40 deletions holoviews/plotting/bokeh/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def __init__(self, *args, **kwargs):
self.invert_yaxis = not self.invert_yaxis


def _glyph_properties(self, plot, element, source, ranges):
properties = super(RasterPlot, self)._glyph_properties(plot, element,
source, ranges)
properties = {k: v for k, v in properties.items()}
val_dim = [d for d in element.vdims][0]
properties['color_mapper'] = self._get_colormapper(val_dim, element, ranges,
properties)
return properties


def get_data(self, element, ranges=None):
mapping = dict(image='image', x='x', y='y', dw='dw', dh='dh')
if self.static_source:
Expand All @@ -31,36 +41,30 @@ def get_data(self, element, ranges=None):
if img.dtype.kind == 'b':
img = img.astype(np.int8)

if isinstance(element, Image):
l, b, r, t = element.bounds.lbrt()
else:
img = img.T[::-1] if self.invert_yaxis else img.T
if type(element) is Raster:
l, b, r, t = element.extents
if self.invert_axes:
l, b, r, t = b, l, t, r
else:
img = img.T
else:
l, b, r, t = element.bounds.lbrt()
if self.invert_axes:
img = img.T
l, b, r, t = b, l, t, r

# Ensure axis inversions are handled correctly
if self.invert_xaxis:
l, r = r, l
img = img[:, ::-1]
if self.invert_yaxis:
img = img[::-1]
b, t = t, b
if type(element) is not Raster:
img = img[::-1]
dh, dw = t-b, r-l

data = dict(image=[img], x=[l], y=[b], dw=[dw], dh=[dh])
return (data, mapping)


def _glyph_properties(self, plot, element, source, ranges):
properties = super(RasterPlot, self)._glyph_properties(plot, element,
source, ranges)
properties = {k: v for k, v in properties.items()}
val_dim = [d for d in element.vdims][0]
properties['color_mapper'] = self._get_colormapper(val_dim, element, ranges,
properties)
return properties



class RGBPlot(RasterPlot):

Expand Down Expand Up @@ -88,12 +92,15 @@ def get_data(self, element, ranges=None):

# Ensure axis inversions are handled correctly
l, b, r, t = element.bounds.lbrt()
if self.invert_axes:
img = img.T
l, b, r, t = b, l, t, r
if self.invert_xaxis:
l, r = r, l
img = img[:, ::-1]
if self.invert_yaxis:
b, t = t, b
img = img[::-1]
b, t = t, b
dh, dw = t-b, r-l

data = dict(image=[img], x=[l], y=[b], dw=[dw], dh=[dh])
Expand Down Expand Up @@ -133,19 +140,27 @@ def _get_factors(self, element):

def get_data(self, element, ranges=None):
x, y, z = [dimension_sanitizer(d) for d in element.dimensions(label=True)[:3]]
if self.invert_axes: x, y = y, x
style = self.style[self.cyclic_index]
cmapper = self._get_colormapper(element.vdims[0], element, ranges, style)
if self.static_source:
data = {}
return {}, {'x': x, 'y': y, 'fill_color': {'field': 'zvalues', 'transform': cmapper}}

aggregate = element.gridded
xdim, ydim = aggregate.dimensions()[:2]
xvals, yvals = (aggregate.dimension_values(x),
aggregate.dimension_values(y))
zvals = aggregate.dimension_values(2, flat=False)
if self.invert_axes:
xdim, ydim = ydim, xdim
zvals = zvals.T.flatten()
else:
aggregate = element.gridded
xdim, ydim = aggregate.dimensions()[:2]
xvals, yvals, zvals = (aggregate.dimension_values(i) for i in range(3))
if xvals.dtype.kind not in 'SU':
xvals = [xdim.pprint_value(xv) for xv in xvals]
if yvals.dtype.kind not in 'SU':
yvals = [ydim.pprint_value(yv) for yv in yvals]
data = {x: xvals, y: yvals, 'zvalues': zvals}
zvals = zvals.T.flatten()
if xvals.dtype.kind not in 'SU':
xvals = [xdim.pprint_value(xv) for xv in xvals]
if yvals.dtype.kind not in 'SU':
yvals = [ydim.pprint_value(yv) for yv in yvals]
data = {x: xvals, y: yvals, 'zvalues': zvals}

if any(isinstance(t, HoverTool) for t in self.state.tools) and not self.static_source:
for vdim in element.vdims:
Expand All @@ -166,22 +181,27 @@ class QuadMeshPlot(ColorbarPlot):

def get_data(self, element, ranges=None):
x, y, z = element.dimensions(label=True)
if self.invert_axes: x, y = y, x
style = self.style[self.cyclic_index]
cmapper = self._get_colormapper(element.vdims[0], element, ranges, style)
if self.static_source:
data = {}
return {}, {'x': x, 'y': y, 'fill_color': {'field': z, 'transform': cmapper}}

if len(set(v.shape for v in element.data)) == 1:
raise SkipRendering("Bokeh QuadMeshPlot only supports rectangular meshes")
zdata = element.data[2]
xvals = element.dimension_values(0, False)
yvals = element.dimension_values(1, False)
widths = np.diff(element.data[0])
heights = np.diff(element.data[1])
if self.invert_axes:
zvals = zdata.flatten()
xvals, yvals, widths, heights = yvals, xvals, heights, widths
else:
if len(set(v.shape for v in element.data)) == 1:
raise SkipRendering("Bokeh QuadMeshPlot only supports rectangular meshes")
zvals = element.data[2].T.flatten()
xvals = element.dimension_values(0, False)
yvals = element.dimension_values(1, False)
widths = np.diff(element.data[0])
heights = np.diff(element.data[1])
xs, ys = cartesian_product([xvals, yvals], copy=True)
ws, hs = cartesian_product([widths, heights], copy=True)
data = {x: xs, y: ys, z: zvals, 'widths': ws, 'heights': hs}

zvals = zdata.T.flatten()
xs, ys = cartesian_product([xvals, yvals], copy=True)
ws, hs = cartesian_product([widths, heights], copy=True)
data = {x: xs, y: ys, z: zvals, 'widths': ws, 'heights': hs}
return (data, {'x': x, 'y': y,
'fill_color': {'field': z, 'transform': cmapper},
'height': 'heights', 'width': 'widths'})
12 changes: 10 additions & 2 deletions holoviews/plotting/mpl/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ class VLinePlot(AnnotationPlot):
style_opts = ['alpha', 'color', 'linewidth', 'linestyle', 'visible']

def draw_annotation(self, axis, position, opts):
return [axis.axvline(position, **opts)]
if self.invert_axes:
return [axis.axhline(position, **opts)]
else:
return [axis.axvline(position, **opts)]



Expand All @@ -56,7 +59,10 @@ class HLinePlot(AnnotationPlot):

def draw_annotation(self, axis, position, opts):
"Draw a horizontal line on the axis"
return [axis.axhline(position, **opts)]
if self.invert_axes:
return [axis.axvline(position, **opts)]
else:
return [axis.axhline(position, **opts)]


class TextPlot(AnnotationPlot):
Expand All @@ -67,6 +73,7 @@ class TextPlot(AnnotationPlot):
def draw_annotation(self, axis, data, opts):
(x,y, text, fontsize,
horizontalalignment, verticalalignment, rotation) = data
if self.invert_axes: x, y = y, x
opts['fontsize'] = fontsize
return [axis.text(x,y, text,
horizontalalignment = horizontalalignment,
Expand All @@ -85,6 +92,7 @@ class ArrowPlot(AnnotationPlot):

def draw_annotation(self, axis, data, opts):
x, y, text, direction, points, arrowstyle = data
if self.invert_axes: x, y = y, x
direction = direction.lower()
arrowprops = dict({'arrowstyle':arrowstyle},
**{k: opts[k] for k in self._arrow_style_opts if k in opts})
Expand Down
3 changes: 2 additions & 1 deletion holoviews/plotting/mpl/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,8 @@ def get_data(self, element, ranges, style):
xs = element.dimension_values(xidx) if len(element.data) else []
ys = element.dimension_values(yidx) if len(element.data) else []
radians = element.dimension_values(2) if len(element.data) else []
angles = list(np.rad2deg(radians))
if self.invert_axes: radians = radians+1.5*np.pi
angles = list(np.rad2deg(radians))
if self.rescale_lengths:
input_scale = input_scale / self._min_dist

Expand Down
15 changes: 10 additions & 5 deletions holoviews/plotting/mpl/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,22 +342,27 @@ def _set_axis_limits(self, axis, view, subplots, ranges):
if self.invert_xaxis or any(p.invert_xaxis for p in subplots):
r, l = l, r
if l != r:
lims = {}
if valid_lim(l):
axis.set_xlim(left=l)
lims['left'] = l
scalex = False
if valid_lim(r):
axis.set_xlim(right=r)
lims['right'] = r
scalex = False

if lims:
axis.set_xlim(**lims)
if self.invert_yaxis or any(p.invert_yaxis for p in subplots):
t, b = b, t
if b != t:
lims = {}
if valid_lim(b):
axis.set_ylim(bottom=b)
lims['bottom'] = b
scaley = False
if valid_lim(t):
axis.set_ylim(top=t)
lims['top'] = t
scaley = False
if lims:
axis.set_ylim(**lims)
axis.autoscale_view(scalex=scalex, scaley=scaley)


Expand Down
Loading

0 comments on commit 3b11689

Please sign in to comment.