Skip to content

Commit

Permalink
Merge pull request #1254 from ioam/shared_source_fix
Browse files Browse the repository at this point in the history
Correctly sync shared datasources
  • Loading branch information
jlstevens authored Apr 9, 2017
2 parents 193aefc + 380d263 commit d408876
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 5 deletions.
14 changes: 10 additions & 4 deletions holoviews/plotting/bokeh/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ..util import get_dynamic_mode, initialize_sampled
from .renderer import BokehRenderer
from .util import (bokeh_version, layout_padding, pad_plots,
filter_toolboxes, make_axis)
filter_toolboxes, make_axis, update_shared_sources)

if bokeh_version >= '0.12':
from bokeh.layouts import gridplot
Expand Down Expand Up @@ -153,6 +153,8 @@ def sync_sources(self):
and 'source' in x.handles)
data_sources = self.traverse(get_sources, [filter_fn])
grouped_sources = groupby(sorted(data_sources, key=lambda x: x[0]), lambda x: x[0])
shared_sources = []
source_cols = {}
for _, group in grouped_sources:
group = list(group)
if len(group) > 1:
Expand All @@ -169,6 +171,10 @@ def sync_sources(self):
else:
renderer.update(source=new_source)
plot.handles['source'] = new_source
shared_sources.append(new_source)
source_cols[id(new_source)] = [c for c in new_source.data]
self.handles['shared_sources'] = shared_sources
self.handles['source_cols'] = source_cols



Expand Down Expand Up @@ -441,7 +447,7 @@ def _make_axes(self, plot):
plot = Column(*models)
return plot


@update_shared_sources
def update_frame(self, key, ranges=None):
"""
Update the internal state of the Plot to represent the given
Expand All @@ -450,7 +456,7 @@ def update_frame(self, key, ranges=None):
"""
ranges = self.compute_ranges(self.layout, key, ranges)
for coord in self.layout.keys(full_grid=True):
subplot = self.subplots.get(coord, None)
subplot = self.subplots.get(wrap_tuple(coord), None)
if subplot is not None:
subplot.update_frame(key, ranges)
title = self._get_title(key)
Expand Down Expand Up @@ -692,7 +698,7 @@ def initialize_plot(self, plots=None, ranges=None):

return self.handles['plot']


@update_shared_sources
def update_frame(self, key, ranges=None):
"""
Update the internal state of the Plot to represent the given
Expand Down
26 changes: 26 additions & 0 deletions holoviews/plotting/bokeh/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,3 +612,29 @@ def filter_batched_data(data, mapping):
del data[v]
except:
pass


def update_shared_sources(f):
"""
Context manager to ensures data sources shared between multiple
plots are cleared and updated appropriately avoiding warnings and
allowing empty frames on subplots. Expects a list of
shared_sources and a mapping of the columns expected columns for
each source in the plots handles.
"""
def wrapper(self, *args, **kwargs):
source_cols = self.handles.get('source_cols', {})
shared_sources = self.handles.get('shared_sources', [])
for source in shared_sources:
source.data.clear()

ret = f(self, *args, **kwargs)

for source in shared_sources:
expected = source_cols[id(source)]
found = [c for c in expected if c in source.data]
empty = np.full_like(source.data[found[0]], np.NaN) if found else []
patch = {c: empty for c in expected if c not in source.data}
source.data.update(patch)
return ret
return wrapper
75 changes: 74 additions & 1 deletion tests/testplotinstantiation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import param
import numpy as np
from holoviews import (Dimension, Overlay, DynamicMap, Store,
from holoviews import (Dimension, Overlay, DynamicMap, Store, Dataset,
NdOverlay, GridSpace, HoloMap, Layout, Cycle)
from holoviews.core.util import pd
from holoviews.element import (Curve, Scatter, Image, VLine, Points,
Expand Down Expand Up @@ -1224,6 +1224,79 @@ def test_shared_axes_disable(self):
self.assertEqual((x_range.start, x_range.end), (-.5, .5))
self.assertEqual((y_range.start, y_range.end), (-.5, .5))

def test_layout_shared_source_synced_update(self):
hmap = HoloMap({i: Dataset({chr(65+j): np.random.rand(i+2)
for j in range(4)}, kdims=['A', 'B', 'C', 'D'])
for i in range(3)})

# Create two holomaps of points sharing the same data source
hmap1= hmap.map(lambda x: Points(x.clone(kdims=['A', 'B'])), Dataset)
hmap2 = hmap.map(lambda x: Points(x.clone(kdims=['D', 'C'])), Dataset)

# Pop key (1,) for one of the HoloMaps and make Layout
hmap2.pop((1,))
layout = (hmap1 + hmap2)(plot=dict(shared_datasource=True))

# Get plot
plot = bokeh_renderer.get_plot(layout)

# Check plot created shared data source and recorded expected columns
sources = plot.handles.get('shared_sources', [])
source_cols = plot.handles.get('source_cols', {})
self.assertEqual(len(sources), 1)
source = sources[0]
data = source.data
cols = source_cols[id(source)]
self.assertEqual(set(cols), {'A', 'B', 'C', 'D'})

# Ensure the source contains the expected columns
self.assertEqual(set(data.keys()), {'A', 'B', 'C', 'D'})

# Update to key (1,) and check the source contains data
# corresponding to hmap1 and filled in NaNs for hmap2,
# which was popped above
plot.update((1,))
self.assertEqual(data['A'], hmap1[1].dimension_values(0))
self.assertEqual(data['B'], hmap1[1].dimension_values(1))
self.assertEqual(data['C'], np.full_like(hmap1[1].dimension_values(0), np.NaN))
self.assertEqual(data['D'], np.full_like(hmap1[1].dimension_values(0), np.NaN))

def test_grid_shared_source_synced_update(self):
hmap = HoloMap({i: Dataset({chr(65+j): np.random.rand(i+2)
for j in range(4)}, kdims=['A', 'B', 'C', 'D'])
for i in range(3)})

# Create two holomaps of points sharing the same data source
hmap1= hmap.map(lambda x: Points(x.clone(kdims=['A', 'B'])), Dataset)
hmap2 = hmap.map(lambda x: Points(x.clone(kdims=['D', 'C'])), Dataset)

# Pop key (1,) for one of the HoloMaps and make GridSpace
hmap2.pop(1)
grid = GridSpace({0: hmap1, 2: hmap2}, kdims=['X'])(plot=dict(shared_datasource=True))

# Get plot
plot = bokeh_renderer.get_plot(grid)

# Check plot created shared data source and recorded expected columns
sources = plot.handles.get('shared_sources', [])
source_cols = plot.handles.get('source_cols', {})
self.assertEqual(len(sources), 1)
source = sources[0]
data = source.data
cols = source_cols[id(source)]
self.assertEqual(set(cols), {'A', 'B', 'C', 'D'})

# Ensure the source contains the expected columns
self.assertEqual(set(data.keys()), {'A', 'B', 'C', 'D'})

# Update to key (1,) and check the source contains data
# corresponding to hmap1 and filled in NaNs for hmap2,
# which was popped above
plot.update((1,))
self.assertEqual(data['A'], hmap1[1].dimension_values(0))
self.assertEqual(data['B'], hmap1[1].dimension_values(1))
self.assertEqual(data['C'], np.full_like(hmap1[1].dimension_values(0), np.NaN))
self.assertEqual(data['D'], np.full_like(hmap1[1].dimension_values(0), np.NaN))


class TestPlotlyPlotInstantiation(ComparisonTestCase):
Expand Down

0 comments on commit d408876

Please sign in to comment.