Skip to content

Commit

Permalink
Merge pull request #362 from ioam/bokeh_improvements
Browse files Browse the repository at this point in the history
Bokeh improvements
  • Loading branch information
jlstevens committed Dec 16, 2015
2 parents b8a7efd + 0659263 commit 0b9e977
Show file tree
Hide file tree
Showing 9 changed files with 184 additions and 108 deletions.
5 changes: 5 additions & 0 deletions holoviews/plotting/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,8 @@
options.Raster = Options('style', cmap='hot')
options.QuadMesh = Options('style', cmap='hot')
options.HeatMap = Options('style', cmap='RdYlBu_r', line_alpha=0)

# Annotations
options.HLine = Options('style', line_color='black', line_width=3, line_alpha=1)
options.VLine = Options('style', line_color='black', line_width=3, line_alpha=1)

55 changes: 35 additions & 20 deletions holoviews/plotting/bokeh/annotation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from bokeh.models import BoxAnnotation

from ...element import HLine, VLine
from .element import ElementPlot, text_properties, line_properties
Expand All @@ -9,8 +10,10 @@ class TextPlot(ElementPlot):
style_opts = text_properties
_plot_method = 'text'

def get_data(self, element, ranges=None):
def get_data(self, element, ranges=None, empty=False):
mapping = dict(x='x', y='y', text='text')
if empty:
return dict(x=[], y=[], text=[]), mapping
return (dict(x=[element.x], y=[element.y],
text=[element.text]), mapping)

Expand All @@ -21,22 +24,29 @@ def get_extents(self, element, ranges=None):
class LineAnnotationPlot(ElementPlot):

style_opts = line_properties
_plot_method = 'segment'

def get_data(self, element, ranges=None):
def get_data(self, element, ranges=None, empty=False):
plot = self.handles['plot']
data, mapping = {}, {}
if isinstance(element, HLine):
x0 = plot.x_range.start
y0 = element.data
x1 = plot.x_range.end
y1 = element.data
mapping['bottom'] = element.data
mapping['top'] = element.data
elif isinstance(element, VLine):
x0 = element.data
y0 = plot.y_range.start
x1 = element.data
y1 = plot.y_range.end
return (dict(x0=[x0], y0=[y0], x1=[x1], y1=[y1]),
dict(x0='x0', y0='y0', x1='x1', y1='y1'))
mapping['left'] = element.data
mapping['right'] = element.data
return (data, mapping)


def _init_glyph(self, plot, mapping, properties):
"""
Returns a Bokeh glyph object.
"""
properties.pop('source')
properties.pop('legend')
box = BoxAnnotation(plot=plot, level='overlay',
**dict(mapping, **properties))
plot.renderers.append(box)
return box


def get_extents(self, element, ranges=None):
Expand All @@ -53,10 +63,15 @@ class SplinePlot(ElementPlot):
style_opts = line_properties
_plot_method = 'bezier'

def get_data(self, element, ranges=None):
verts = np.array(element.data[0])
xs, ys = verts[:, 0], verts[:, 1]
return (dict(x0=[xs[0]], y0=[ys[0]], x1=[xs[-1]], y1=[ys[-1]],
cx0=[xs[1]], cy0=[ys[1]], cx1=[xs[2]], cy1=[ys[2]]),
dict(x0='x0', y0='y0', x1='x1', y1='y1',
cx0='cx0', cx1='cx1', cy0='cy0', cy1='cy1'))
def get_data(self, element, ranges=None, empty=False):
data_attrs = ['x0', 'y0', 'x1', 'y1',
'cx0', 'cx1', 'cy0', 'cy1']
if empty:
data = {attr: [] for attr in data_attrs}
else:
verts = np.array(element.data[0])
xs, ys = verts[:, 0], verts[:, 1]
data = dict(x0=[xs[0]], y0=[ys[0]], x1=[xs[-1]], y1=[ys[-1]],
cx0=[xs[1]], cy0=[ys[1]], cx1=[xs[2]], cy1=[ys[2]])

return (data, dict(zip(data_attrs, data_attrs)))
42 changes: 29 additions & 13 deletions holoviews/plotting/bokeh/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,22 @@ def serialize(self, objects):



class DownsampleImage(Callback):
class DownsampleCallback(Callback):
"""
DownsampleCallbacks can downsample the data before it is
plotted and can therefore provide major speed optimizations.
"""

apply_on_update = param.Boolean(default=True, doc="""
Callback should always be applied after each update to
downsample the data before it is displayed.""")

reinitialize = param.Boolean(default=True, doc="""
DownsampleColumns should be reinitialized per plot object""")



class DownsampleImage(DownsampleCallback):
"""
Downsamples any Image plot to the specified
max_width and max_height by slicing the
Expand All @@ -165,10 +180,6 @@ class DownsampleImage(Callback):
constraints.
"""

apply_on_update = param.Boolean(default=True, doc="""
Callback should always be applied after each update to
downsample the data before it is displayed.""")

max_width = param.Integer(default=250, doc="""
Maximum plot width in pixels after slicing and downsampling.""")

Expand Down Expand Up @@ -210,26 +221,22 @@ def __call__(self, data):



class DownsampleColumns(Callback):
class DownsampleColumns(DownsampleCallback):
"""
Downsamples any column based Element by randomizing
the rows and updating the ColumnDataSource with
up to max_samples.
"""

apply_on_update = param.Boolean(default=True, doc="""
Callback should always be applied after each update to
downsample the data before it is displayed.""")
compute_ranges = param.Boolean(default=False, doc="""
Whether the ranges are recomputed for the sliced region""")

max_samples = param.Integer(default=800, doc="""
Maximum number of samples to display at the same time.""")

random_seed = param.Integer(default=42, doc="""
Seed used to initialize randomization.""")

reinitialize = param.Boolean(default=True, doc="""
DownsampleColumns should be reinitialized per plot object""")

plot_attributes = param.Dict(default={'x_range': ['start', 'end'],
'y_range': ['start', 'end']})

Expand All @@ -248,13 +255,17 @@ def __call__(self, data):
element = plot.current_frame
if element.interface is not ArrayColumns:
element = plot.current_frame.clone(datatype=['array'])
ranges = plot.current_ranges

# Slice element to current ranges
xdim, ydim = element.dimensions(label=True)[0:2]
sliced = element.select(**{xdim: (xstart, xend),
ydim: (ystart, yend)})

if self.compute_ranges:
ranges = {d: element.range(d) for d in element.dimensions()}
else:
ranges = plot.current_ranges

# Avoid randomizing if possible (expensive)
if len(sliced) > self.max_samples:
# Randomize element samples and slice to region
Expand Down Expand Up @@ -381,6 +392,11 @@ def _chain_callbacks(self, plot, cb_obj, callbacks):
else:
cb_obj.callback = callback

@property
def downsample(self):
return any(isinstance(v, DownsampleCallback)
for _ , v in self.get_param_values())


def __call__(self, plot):
"""
Expand Down
99 changes: 61 additions & 38 deletions holoviews/plotting/bokeh/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class PointPlot(ElementPlot):
_plot_method = 'scatter'


def get_data(self, element, ranges=None):
def get_data(self, element, ranges=None, empty=False):
style = self.style[self.cyclic_index]
dims = element.dimensions(label=True)

Expand All @@ -52,22 +52,29 @@ def get_data(self, element, ranges=None):
if self.color_index < len(dims) and cmap:
map_key = 'color_' + dims[self.color_index]
mapping['color'] = map_key
cmap = get_cmap(cmap)
colors = element.dimension_values(self.color_index)
crange = ranges.get(dims[self.color_index], None)
data[map_key] = map_colors(colors, crange, cmap)
if empty:
data[map_key] = []
else:
cmap = get_cmap(cmap)
colors = element.dimension_values(self.color_index)
crange = ranges.get(dims[self.color_index], None)
data[map_key] = map_colors(colors, crange, cmap)
if self.size_index < len(dims):
map_key = 'size_' + dims[self.size_index]
mapping['size'] = map_key
ms = style.get('size', 1)
sizes = element.dimension_values(self.size_index)
data[map_key] = compute_sizes(sizes, self.size_fn,
self.scaling_factor, ms)
data[dims[0]] = element.dimension_values(0)
data[dims[1]] = element.dimension_values(1)
if empty:
data[map_key] = []
else:
ms = style.get('size', 1)
sizes = element.dimension_values(self.size_index)
data[map_key] = compute_sizes(sizes, self.size_fn,
self.scaling_factor, ms)

data[dims[0]] = [] if empty else element.dimension_values(0)
data[dims[1]] = [] if empty else element.dimension_values(1)
if 'hover' in self.tools:
for d in dims[2:]:
data[d] = element.dimension_values(d)
data[d] = [] if empty else element.dimension_values(d)
return data, mapping


Expand All @@ -84,24 +91,24 @@ def _init_glyph(self, plot, mapping, properties):
color = mapping.pop('color', color)
properties.pop('legend', None)
unselected = Circle(**dict(properties, fill_color=unselect_color, **mapping))
selected = Circle(**dict(properties, fill_color=color, **mapping))
plot.add_glyph(source, selected, selection_glyph=selected,
glyph = Circle(**dict(properties, fill_color=color, **mapping))
plot.add_glyph(source, selected, selection_glyph=glyph,
nonselection_glyph=unselected)
else:
getattr(plot, self._plot_method)(**dict(properties, **mapping))

glyph = getattr(plot, self._plot_method)(**dict(properties, **mapping))
return glyph


class CurvePlot(ElementPlot):

style_opts = ['color'] + line_properties
_plot_method = 'line'

def get_data(self, element, ranges=None):
def get_data(self, element, ranges=None, empty=False):
x = element.get_dimension(0).name
y = element.get_dimension(1).name
return ({x: element.dimension_values(0),
y: element.dimension_values(1)},
return ({x: [] if empty else element.dimension_values(0),
y: [] if empty else element.dimension_values(1)},
dict(x=x, y=y))


Expand All @@ -112,7 +119,9 @@ class SpreadPlot(PolygonPlot):
def __init__(self, *args, **kwargs):
super(SpreadPlot, self).__init__(*args, **kwargs)

def get_data(self, element, ranges=None):
def get_data(self, element, ranges=None, empty=None):
if empty:
return dict(xs=[], ys=[]), self._mapping

xvals = element.dimension_values(0)
mean = element.dimension_values(1)
Expand All @@ -132,13 +141,16 @@ class HistogramPlot(ElementPlot):
style_opts = ['color'] + line_properties + fill_properties
_plot_method = 'quad'

def get_data(self, element, ranges=None):
def get_data(self, element, ranges=None, empty=None):
mapping = dict(top='top', bottom=0, left='left', right='right')
data = dict(top=element.values, left=element.edges[:-1],
right=element.edges[1:])
if empty:
data = dict(top=[], left=[], right=[])
else:
data = dict(top=element.values, left=element.edges[:-1],
right=element.edges[1:])

if 'hover' in self.default_tools + self.tools:
data.update({d: element.dimension_values(d)
data.update({d: [] if empty else element.dimension_values(d)
for d in element.dimensions(label=True)})
return (data, mapping)

Expand All @@ -154,14 +166,17 @@ class SideHistogramPlot(HistogramPlot):
show_title = param.Boolean(default=False, doc="""
Whether to display the plot title.""")

def get_data(self, element, ranges=None):
def get_data(self, element, ranges=None, empty=None):
if self.invert_axes:
mapping = dict(top='left', bottom='right', left=0, right='top')
else:
mapping = dict(top='top', bottom=0, left='left', right='right')

data = dict(top=element.values, left=element.edges[:-1],
right=element.edges[1:])
if empty:
data = dict(top=[], left=[], right=[])
else:
data = dict(top=element.values, left=element.edges[:-1],
right=element.edges[1:])

dim = element.get_dimension(0).name
main = self.adjoined.main
Expand All @@ -174,12 +189,11 @@ def get_data(self, element, ranges=None):

if 'cmap' in style or 'palette' in style:
cmap = get_cmap(style.get('cmap', style.get('palette', None)))
colors = map_colors(vals, main_range, cmap)
data['color'] = colors
data['color'] = [] if empty else map_colors(vals, main_range, cmap)
mapping['fill_color'] = 'color'

if 'hover' in self.default_tools + self.tools:
data.update({d: element.dimension_values(d)
data.update({d: [] if empty else element.dimension_values(d)
for d in element.dimensions(label=True)})
return (data, mapping)

Expand All @@ -191,7 +205,10 @@ class ErrorPlot(PathPlot):

style_opts = ['color'] + line_properties

def get_data(self, element, ranges=None):
def get_data(self, element, ranges=None, empty=False):
if empty:
return dict(xs=[], ys=[]), self._mapping

data = element.array(dimensions=element.dimensions()[0:4])
err_xs = []
err_ys = []
Expand Down Expand Up @@ -231,12 +248,14 @@ def get_extents(self, element, ranges):
return l, b, r, t


def get_data(self, element, ranges=None):
def get_data(self, element, ranges=None, empty=False):
style = self.style[self.cyclic_index]
dims = element.dimensions(label=True)

pos = self.position
if len(dims) > 1:
if empty:
xs, ys, keys = [], [], []
elif len(dims) > 1:
xs, ys = zip(*(((x, x), (pos, pos+y))
for x, y in element.array()))
mapping = dict(xs=dims[0], ys=dims[1])
Expand All @@ -248,18 +267,22 @@ def get_data(self, element, ranges=None):
mapping = dict(xs=dims[0], ys='heights')
keys = (dims[0], 'heights')

if self.invert_axes: keys = keys[::-1]
if not empty and self.invert_axes: keys = keys[::-1]
data = dict(zip(keys, (xs, ys)))

cmap = style.get('palette', style.get('cmap', None))
if self.color_index < len(dims) and cmap:
cdim = dims[self.color_index]
map_key = 'color_' + cdim
mapping['color'] = map_key
cmap = get_cmap(cmap)
colors = element.dimension_values(cdim)
crange = ranges.get(cdim, None)
data[map_key] = map_colors(colors, crange, cmap)
if empty:
colors = []
else:
cmap = get_cmap(cmap)
cvals = element.dimension_values(cdim)
crange = ranges.get(cdim, None)
colors = map_colors(cvals, crange, cmap)
data[map_key] = colors

return data, mapping

Expand Down
Loading

0 comments on commit 0b9e977

Please sign in to comment.