Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed handling of shared_axes #1187

Merged
merged 2 commits into from
Mar 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 39 additions & 14 deletions holoviews/plotting/bokeh/element.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ def __init__(self, element, plot=None, **params):
self.callbacks = self._construct_callbacks()
self.static_source = False

# Whether axes are shared between plots
self._shared = {'x': False, 'y': False}


def _construct_callbacks(self):
"""
Expand Down Expand Up @@ -260,6 +263,28 @@ def _get_hover_data(self, data, element, empty=False):
data[dim] = [v for _ in range(len(list(data.values())[0]))]


def _merge_ranges(self, plots, xlabel, ylabel):
"""
Given a list of other plots return axes that are shared
with another plot by matching the axes labels
"""
plot_ranges = {}
for plot in plots:
if plot is None:
continue
if hasattr(plot, 'xaxis'):
if plot.xaxis[0].axis_label == xlabel:
plot_ranges['x_range'] = plot.x_range
if plot.xaxis[0].axis_label == ylabel:
plot_ranges['y_range'] = plot.x_range
if hasattr(plot, 'yaxis'):
if plot.yaxis[0].axis_label == ylabel:
plot_ranges['y_range'] = plot.y_range
if plot.yaxis[0].axis_label == xlabel:
plot_ranges['x_range'] = plot.y_range
return plot_ranges
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a little confused here...there are only two keys 'x_range' and 'y_range' so if you iterate over a bunch of plots can't you end up clobbering these keys with different values? Won't you end up with whatever was found at the end of the loop?

Copy link
Member Author

@philippjfr philippjfr Mar 10, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is correct, I could have two separate loops for the x- and y-axis and break when a matching axis is found. However in practice if shared_axes are enabled all the plots will share the same axes anyway so it doesn't much matter whether I end up using the axis from the first or last plot, because they are all the same anyway.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, good that you've thought about this. It is also good to have this pulled out as a method.



def _axes_props(self, plots, subplots, element, ranges):
# Get the bottom layer and range element
el = element.traverse(lambda x: x, [Element])
Expand All @@ -273,16 +298,7 @@ def _axes_props(self, plots, subplots, element, ranges):
plot_ranges = {}
# Try finding shared ranges in other plots in the same Layout
if plots and self.shared_axes:
for plot in plots:
if plot is None or not hasattr(plot, 'xaxis'): continue
if plot.xaxis[0].axis_label == xlabel:
plot_ranges['x_range'] = plot.x_range
if plot.xaxis[0].axis_label == ylabel:
plot_ranges['y_range'] = plot.x_range
if plot.yaxis[0].axis_label == ylabel:
plot_ranges['y_range'] = plot.y_range
if plot.yaxis[0].axis_label == xlabel:
plot_ranges['x_range'] = plot.y_range
plot_ranges = self._merge_ranges(plots, xlabel, ylabel)

if el.get_dimension_type(0) in util.datetime_types:
x_axis_type = 'datetime'
Expand All @@ -300,6 +316,12 @@ def _axes_props(self, plots, subplots, element, ranges):
if self.invert_axes:
l, b, r, t = b, l, t, r

# Declare shared axes
if 'x_range' in plot_ranges:
self._shared['x'] = True
if 'y_range' in plot_ranges:
self._shared['y'] = True

categorical = any(self.traverse(lambda x: x._categorical))
categorical_x = any(isinstance(x, util.basestring) for x in (l, r))
categorical_y = any(isinstance(y, util.basestring) for y in (b, t))
Expand Down Expand Up @@ -418,7 +440,7 @@ def _axis_properties(self, axis, key, plot, dimension=None,
axis_props = {}
if ((axis == 'x' and self.xaxis in ['bottom-bare', 'top-bare']) or
(axis == 'y' and self.yaxis in ['left-bare', 'right-bare'])):
axis_props['axis_label'] = ''
axis_props['axis_label_text_font_size'] = value('0pt')
axis_props['major_label_text_font_size'] = value('0pt')
axis_props['major_tick_line_color'] = None
axis_props['minor_tick_line_color'] = None
Expand Down Expand Up @@ -495,18 +517,21 @@ def _update_ranges(self, element, ranges):
xfactors, yfactors = None, None
if any(isinstance(ax_range, FactorRange) for ax_range in [x_range, y_range]):
xfactors, yfactors = self._get_factors(element)
self._update_range(x_range, l, r, xfactors, self.invert_xaxis)
self._update_range(y_range, b, t, yfactors, self.invert_yaxis)
self._update_range(x_range, l, r, xfactors, self.invert_xaxis, self._shared['x'])
self._update_range(y_range, b, t, yfactors, self.invert_yaxis, self._shared['y'])


def _update_range(self, axis_range, low, high, factors, invert):
def _update_range(self, axis_range, low, high, factors, invert, shared):
if isinstance(axis_range, Range1d):
if (low == high and low is not None and
not isinstance(high, util.datetime_types)):
offset = abs(low*0.1 if low else 0.5)
low -= offset
high += offset
if invert: low, high = high, low
if shared:
shared = (axis_range.start, axis_range.end)
low, high = util.max_range([(low, high), shared])
if low is not None and (isinstance(low, util.datetime_types)
or np.isfinite(low)):
axis_range.start = low
Expand Down
19 changes: 19 additions & 0 deletions tests/testplotinstantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,25 @@ def formatter(x):
plot = bokeh_renderer.get_plot(curve).state
self.assertIsInstance(plot.yaxis[0].formatter, FuncTickFormatter)

def test_shared_axes(self):
curve = Curve(range(10))
img = Image(np.random.rand(10,10))
plot = bokeh_renderer.get_plot(curve+img)
plot = plot.subplots[(0, 1)].subplots['main']
x_range, y_range = plot.handles['x_range'], plot.handles['y_range']
self.assertEqual((x_range.start, x_range.end), (-.5, 9))
self.assertEqual((y_range.start, y_range.end), (-.5, 9))

def test_shared_axes_disable(self):
curve = Curve(range(10))
img = Image(np.random.rand(10,10))(plot=dict(shared_axes=False))
plot = bokeh_renderer.get_plot(curve+img)
plot = plot.subplots[(0, 1)].subplots['main']
x_range, y_range = plot.handles['x_range'], plot.handles['y_range']
self.assertEqual((x_range.start, x_range.end), (-.5, .5))
self.assertEqual((y_range.start, y_range.end), (-.5, .5))



class TestPlotlyPlotInstantiation(ComparisonTestCase):

Expand Down