Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle invert_axes for all Element types #1919

Merged
merged 12 commits into from
Oct 7, 2017
7 changes: 4 additions & 3 deletions holoviews/plotting/bokeh/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ def get_data(self, element, ranges=None):
data['text'] = [element.text]
return (data, mapping)


def get_batched_data(self, element, ranges=None):
data = defaultdict(list)
for key, el in element.data.items():
Expand All @@ -51,7 +50,6 @@ def get_batched_data(self, element, ranges=None):
data[k].extend(eld)
return data, elmapping


def get_extents(self, element, ranges=None):
return None, None, None, None

Expand Down Expand Up @@ -100,7 +98,10 @@ class SplinePlot(ElementPlot):
_plot_methods = dict(single='bezier')

def get_data(self, element, ranges=None):
data_attrs = ['x0', 'y0', 'cx0', 'cy0', 'cx1', 'cy1', 'x1', 'y1',]
if self.invert_axes:
data_attrs = ['y0', 'x0', 'cy0', 'cx0', 'cy1', 'cx1', 'y1', 'x1']
else:
data_attrs = ['x0', 'y0', 'cx0', 'cy0', 'cx1', 'cy1', 'x1', 'y1']
verts = np.array(element.data[0])
inds = np.where(np.array(element.data[1])==1)[0]
data = {da: [] for da in data_attrs}
Expand Down
14 changes: 8 additions & 6 deletions holoviews/plotting/bokeh/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,12 @@ def get_data(self, element, ranges=None):
input_scale = style.pop('scale', 1.0)

# Get x, y, angle, magnitude and color data
xidx, yidx = (1, 0) if self.invert_axes else (0, 1)
rads = element.dimension_values(2)
if self.invert_axes:
xidx, yidx = (1, 0)
rads = rads+1.5*np.pi
else:
xidx, yidx = (0, 1)
lens = self._get_lengths(element, ranges)/input_scale
cdim = element.get_dimension(self.color_index)
cdata, cmapping = self._get_color_data(element, ranges, style,
Expand Down Expand Up @@ -1001,7 +1005,7 @@ def _get_factors(self, element):
Get factors for categorical axes.
"""
if not element.kdims:
return [element.label], []
xfactors, yfactors = [element.label], []
else:
if bokeh_version < '0.12.7':
factors = [', '.join([d.pprint_value(v).replace(':', ';')
Expand All @@ -1011,10 +1015,8 @@ def _get_factors(self, element):
factors = [tuple(d.pprint_value(v) for d, v in zip(element.kdims, key))
for key in element.groupby(element.kdims).data.keys()]
factors = [f[0] if len(f) == 1 else f for f in factors]
if self.invert_axes:
return [], factors
else:
return factors, []
xfactors, yfactors = factors, []
return (yfactors, xfactors) if self.invert_axes else (xfactors, yfactors)

def get_data(self, element, ranges=None):
if element.kdims:
Expand Down
100 changes: 60 additions & 40 deletions holoviews/plotting/bokeh/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,16 @@ def __init__(self, *args, **kwargs):
self.invert_yaxis = not self.invert_yaxis


def _glyph_properties(self, plot, element, source, ranges):
properties = super(RasterPlot, self)._glyph_properties(plot, element,
source, ranges)
properties = {k: v for k, v in properties.items()}
val_dim = [d for d in element.vdims][0]
properties['color_mapper'] = self._get_colormapper(val_dim, element, ranges,
properties)
return properties


def get_data(self, element, ranges=None):
mapping = dict(image='image', x='x', y='y', dw='dw', dh='dh')
if self.static_source:
Expand All @@ -31,36 +41,30 @@ def get_data(self, element, ranges=None):
if img.dtype.kind == 'b':
img = img.astype(np.int8)

if isinstance(element, Image):
l, b, r, t = element.bounds.lbrt()
else:
img = img.T[::-1] if self.invert_yaxis else img.T
if type(element) is Raster:
l, b, r, t = element.extents
if self.invert_axes:
l, b, r, t = b, l, t, r
else:
img = img.T
else:
l, b, r, t = element.bounds.lbrt()
if self.invert_axes:
img = img.T
l, b, r, t = b, l, t, r

# Ensure axis inversions are handled correctly
if self.invert_xaxis:
l, r = r, l
img = img[:, ::-1]
if self.invert_yaxis:
img = img[::-1]
b, t = t, b
if type(element) is not Raster:
img = img[::-1]
dh, dw = t-b, r-l

data = dict(image=[img], x=[l], y=[b], dw=[dw], dh=[dh])
return (data, mapping)


def _glyph_properties(self, plot, element, source, ranges):
properties = super(RasterPlot, self)._glyph_properties(plot, element,
source, ranges)
properties = {k: v for k, v in properties.items()}
val_dim = [d for d in element.vdims][0]
properties['color_mapper'] = self._get_colormapper(val_dim, element, ranges,
properties)
return properties



class RGBPlot(RasterPlot):

Expand Down Expand Up @@ -88,12 +92,15 @@ def get_data(self, element, ranges=None):

# Ensure axis inversions are handled correctly
l, b, r, t = element.bounds.lbrt()
if self.invert_axes:
img = img.T
l, b, r, t = b, l, t, r
if self.invert_xaxis:
l, r = r, l
img = img[:, ::-1]
if self.invert_yaxis:
b, t = t, b
img = img[::-1]
b, t = t, b
dh, dw = t-b, r-l

data = dict(image=[img], x=[l], y=[b], dw=[dw], dh=[dh])
Expand Down Expand Up @@ -133,19 +140,27 @@ def _get_factors(self, element):

def get_data(self, element, ranges=None):
x, y, z = [dimension_sanitizer(d) for d in element.dimensions(label=True)[:3]]
if self.invert_axes: x, y = y, x
style = self.style[self.cyclic_index]
cmapper = self._get_colormapper(element.vdims[0], element, ranges, style)
if self.static_source:
data = {}
return {}, {'x': x, 'y': y, 'fill_color': {'field': 'zvalues', 'transform': cmapper}}

aggregate = element.gridded
xdim, ydim = aggregate.dimensions()[:2]
xvals, yvals = (aggregate.dimension_values(x),
aggregate.dimension_values(y))
zvals = aggregate.dimension_values(2, flat=False)
if self.invert_axes:
xdim, ydim = ydim, xdim
zvals = zvals.T.flatten()
else:
aggregate = element.gridded
xdim, ydim = aggregate.dimensions()[:2]
xvals, yvals, zvals = (aggregate.dimension_values(i) for i in range(3))
if xvals.dtype.kind not in 'SU':
xvals = [xdim.pprint_value(xv) for xv in xvals]
if yvals.dtype.kind not in 'SU':
yvals = [ydim.pprint_value(yv) for yv in yvals]
data = {x: xvals, y: yvals, 'zvalues': zvals}
zvals = zvals.T.flatten()
if xvals.dtype.kind not in 'SU':
xvals = [xdim.pprint_value(xv) for xv in xvals]
if yvals.dtype.kind not in 'SU':
yvals = [ydim.pprint_value(yv) for yv in yvals]
data = {x: xvals, y: yvals, 'zvalues': zvals}

if any(isinstance(t, HoverTool) for t in self.state.tools) and not self.static_source:
for vdim in element.vdims:
Expand All @@ -166,22 +181,27 @@ class QuadMeshPlot(ColorbarPlot):

def get_data(self, element, ranges=None):
x, y, z = element.dimensions(label=True)
if self.invert_axes: x, y = y, x
style = self.style[self.cyclic_index]
cmapper = self._get_colormapper(element.vdims[0], element, ranges, style)
if self.static_source:
data = {}
return {}, {'x': x, 'y': y, 'fill_color': {'field': z, 'transform': cmapper}}

if len(set(v.shape for v in element.data)) == 1:
raise SkipRendering("Bokeh QuadMeshPlot only supports rectangular meshes")
zdata = element.data[2]
xvals = element.dimension_values(0, False)
yvals = element.dimension_values(1, False)
widths = np.diff(element.data[0])
heights = np.diff(element.data[1])
if self.invert_axes:
zvals = zdata.flatten()
xvals, yvals, widths, heights = yvals, xvals, heights, widths
else:
if len(set(v.shape for v in element.data)) == 1:
raise SkipRendering("Bokeh QuadMeshPlot only supports rectangular meshes")
zvals = element.data[2].T.flatten()
xvals = element.dimension_values(0, False)
yvals = element.dimension_values(1, False)
widths = np.diff(element.data[0])
heights = np.diff(element.data[1])
xs, ys = cartesian_product([xvals, yvals], copy=True)
ws, hs = cartesian_product([widths, heights], copy=True)
data = {x: xs, y: ys, z: zvals, 'widths': ws, 'heights': hs}

zvals = zdata.T.flatten()
xs, ys = cartesian_product([xvals, yvals], copy=True)
ws, hs = cartesian_product([widths, heights], copy=True)
data = {x: xs, y: ys, z: zvals, 'widths': ws, 'heights': hs}
return (data, {'x': x, 'y': y,
'fill_color': {'field': z, 'transform': cmapper},
'height': 'heights', 'width': 'widths'})
12 changes: 10 additions & 2 deletions holoviews/plotting/mpl/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ class VLinePlot(AnnotationPlot):
style_opts = ['alpha', 'color', 'linewidth', 'linestyle', 'visible']

def draw_annotation(self, axis, position, opts):
return [axis.axvline(position, **opts)]
if self.invert_axes:
return [axis.axhline(position, **opts)]
else:
return [axis.axvline(position, **opts)]



Expand All @@ -56,7 +59,10 @@ class HLinePlot(AnnotationPlot):

def draw_annotation(self, axis, position, opts):
"Draw a horizontal line on the axis"
return [axis.axhline(position, **opts)]
if self.invert_axes:
return [axis.axvline(position, **opts)]
else:
return [axis.axhline(position, **opts)]


class TextPlot(AnnotationPlot):
Expand All @@ -67,6 +73,7 @@ class TextPlot(AnnotationPlot):
def draw_annotation(self, axis, data, opts):
(x,y, text, fontsize,
horizontalalignment, verticalalignment, rotation) = data
if self.invert_axes: x, y = y, x
opts['fontsize'] = fontsize
return [axis.text(x,y, text,
horizontalalignment = horizontalalignment,
Expand All @@ -85,6 +92,7 @@ class ArrowPlot(AnnotationPlot):

def draw_annotation(self, axis, data, opts):
x, y, text, direction, points, arrowstyle = data
if self.invert_axes: x, y = y, x
direction = direction.lower()
arrowprops = dict({'arrowstyle':arrowstyle},
**{k: opts[k] for k in self._arrow_style_opts if k in opts})
Expand Down
3 changes: 2 additions & 1 deletion holoviews/plotting/mpl/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,8 @@ def get_data(self, element, ranges, style):
xs = element.dimension_values(xidx) if len(element.data) else []
ys = element.dimension_values(yidx) if len(element.data) else []
radians = element.dimension_values(2) if len(element.data) else []
angles = list(np.rad2deg(radians))
if self.invert_axes: radians = radians+1.5*np.pi
angles = list(np.rad2deg(radians))
if self.rescale_lengths:
input_scale = input_scale / self._min_dist

Expand Down
15 changes: 10 additions & 5 deletions holoviews/plotting/mpl/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,22 +342,27 @@ def _set_axis_limits(self, axis, view, subplots, ranges):
if self.invert_xaxis or any(p.invert_xaxis for p in subplots):
r, l = l, r
if l != r:
lims = {}
if valid_lim(l):
axis.set_xlim(left=l)
lims['left'] = l
scalex = False
if valid_lim(r):
axis.set_xlim(right=r)
lims['right'] = r
scalex = False

if lims:
axis.set_xlim(**lims)
if self.invert_yaxis or any(p.invert_yaxis for p in subplots):
t, b = b, t
if b != t:
lims = {}
if valid_lim(b):
axis.set_ylim(bottom=b)
lims['bottom'] = b
scaley = False
if valid_lim(t):
axis.set_ylim(top=t)
lims['top'] = t
scaley = False
if lims:
axis.set_ylim(**lims)
axis.autoscale_view(scalex=scalex, scaley=scaley)


Expand Down
Loading