Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate VegaFusion into JupyterChart #3281

Merged
merged 14 commits into from
Dec 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 140 additions & 7 deletions altair/jupyter/js/index.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import embed from "https://esm.sh/vega-embed@6?deps=vega@5&deps=vega-lite@5.16.3";
import debounce from "https://esm.sh/lodash-es@4.17.21/debounce";
import vegaEmbed from "https://esm.sh/vega-embed@6?deps=vega@5&deps=vega-lite@5.16.3";
import lodashDebounce from "https://esm.sh/lodash-es@4.17.21/debounce";

export async function render({ model, el }) {
let finalize;
Expand All @@ -19,10 +19,21 @@ export async function render({ model, el }) {
finalize();
}

let spec = model.get("spec");
model.set("local_tz", Intl.DateTimeFormat().resolvedOptions().timeZone);
mattijn marked this conversation as resolved.
Show resolved Hide resolved

let spec = structuredClone(model.get("spec"));
if (spec == null) {
// Remove any existing chart and return
while (el.firstChild) {
el.removeChild(el.lastChild);
}
model.save_changes();
return;
}

let api;
try {
api = await embed(el, spec);
api = await vegaEmbed(el, spec);
} catch (error) {
showError(error)
return;
Expand All @@ -32,7 +43,10 @@ export async function render({ model, el }) {

// Debounce config
const wait = model.get("debounce_wait") ?? 10;
const maxWait = wait;
const debounceOpts = {leading: false, trailing: true};
if (model.get("max_wait") ?? true) {
debounceOpts["maxWait"] = wait;
}

const initialSelections = {};
for (const selectionName of Object.keys(model.get("_vl_selections"))) {
Expand All @@ -45,7 +59,7 @@ export async function render({ model, el }) {
model.set("_vl_selections", newSelections);
model.save_changes();
};
api.view.addSignalListener(selectionName, debounce(selectionHandler, wait, {maxWait}));
api.view.addSignalListener(selectionName, lodashDebounce(selectionHandler, wait, debounceOpts));

initialSelections[selectionName] = {
value: cleanJson(api.view.signal(selectionName) ?? {}),
Expand All @@ -62,7 +76,7 @@ export async function render({ model, el }) {
model.set("_params", newParams);
model.save_changes();
};
api.view.addSignalListener(paramName, debounce(paramHandler, wait, {maxWait}));
api.view.addSignalListener(paramName, lodashDebounce(paramHandler, wait, debounceOpts));

initialParams[paramName] = api.view.signal(paramName) ?? null
}
Expand All @@ -76,13 +90,132 @@ export async function render({ model, el }) {
}
await api.view.runAsync();
});

// Add signal/data listeners
for (const watch of model.get("_js_watch_plan") ?? []) {
if (watch.namespace === "data") {
const dataHandler = (_, value) => {
model.set("_js_to_py_updates", [{
namespace: "data",
name: watch.name,
scope: watch.scope,
value: cleanJson(value)
}]);
model.save_changes();
};
addDataListener(api.view, watch.name, watch.scope, lodashDebounce(dataHandler, wait, debounceOpts))

} else if (watch.namespace === "signal") {
const signalHandler = (_, value) => {
model.set("_js_to_py_updates", [{
namespace: "signal",
name: watch.name,
scope: watch.scope,
value: cleanJson(value)
}]);
model.save_changes();
};

addSignalListener(api.view, watch.name, watch.scope, lodashDebounce(signalHandler, wait, debounceOpts))
}
}

// Add signal/data updaters
model.on('change:_py_to_js_updates', async (updates) => {
for (const update of updates.changed._py_to_js_updates ?? []) {
if (update.namespace === "signal") {
setSignalValue(api.view, update.name, update.scope, update.value);
} else if (update.namespace === "data") {
setDataValue(api.view, update.name, update.scope, update.value);
}
}
await api.view.runAsync();
});
}

model.on('change:spec', reembed);
model.on('change:debounce_wait', reembed);
model.on('change:max_wait', reembed);
await reembed();
}

function cleanJson(data) {
return JSON.parse(JSON.stringify(data))
}

function getNestedRuntime(view, scope) {
var runtime = view._runtime;
for (const index of scope) {
runtime = runtime.subcontext[index];
}
return runtime
}

function lookupSignalOp(view, name, scope) {
let parent_runtime = getNestedRuntime(view, scope);
return parent_runtime.signals[name] ?? null;
}

function dataRef(view, name, scope) {
let parent_runtime = getNestedRuntime(view, scope);
return parent_runtime.data[name];
}

export function setSignalValue(view, name, scope, value) {
let signal_op = lookupSignalOp(view, name, scope);
view.update(signal_op, value);
}

export function setDataValue(view, name, scope, value) {
let dataset = dataRef(view, name, scope);
let changeset = view.changeset().remove(() => true).insert(value)
dataset.modified = true;
view.pulse(dataset.input, changeset);
}

export function addSignalListener(view, name, scope, handler) {
let signal_op = lookupSignalOp(view, name, scope);
return addOperatorListener(
view,
name,
signal_op,
handler,
);
}

export function addDataListener(view, name, scope, handler) {
let dataset = dataRef(view, name, scope).values;
return addOperatorListener(
view,
name,
dataset,
handler,
);
}

// Private helpers from Vega for dealing with nested signals/data
function findOperatorHandler(op, handler) {
const h = (op._targets || [])
.filter(op => op._update && op._update.handler === handler);
return h.length ? h[0] : null;
}

function addOperatorListener(view, name, op, handler) {
let h = findOperatorHandler(op, handler);
if (!h) {
h = trap(view, () => handler(name, op.value));
h.handler = handler;
view.on(op, null, h);
}
return view;
}

function trap(view, fn) {
return !fn ? null : function() {
try {
fn.apply(this, arguments);
} catch (error) {
view.error(error);
}
};
}
100 changes: 88 additions & 12 deletions altair/jupyter/jupyter_chart.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import json
import anywidget
import traitlets
import pathlib
from typing import Any, Set

import altair as alt
from altair.utils._vegafusion_data import using_vegafusion
from altair.utils._vegafusion_data import (
using_vegafusion,
compile_to_vegafusion_chart_state,
)
from altair import TopLevelSpec
from altair.utils.selection import IndexSelection, PointSelection, IntervalSelection

Expand All @@ -20,9 +24,7 @@ def __init__(self, trait_values):
super().__init__()

for key, value in trait_values.items():
if isinstance(value, int):
traitlet_type = traitlets.Int()
elif isinstance(value, float):
if isinstance(value, (int, float)):
traitlet_type = traitlets.Float()
elif isinstance(value, str):
traitlet_type = traitlets.Unicode()
Expand Down Expand Up @@ -101,9 +103,12 @@ class JupyterChart(anywidget.AnyWidget):
"""

# Public traitlets
chart = traitlets.Instance(TopLevelSpec)
spec = traitlets.Dict().tag(sync=True)
chart = traitlets.Instance(TopLevelSpec, allow_none=True)
spec = traitlets.Dict(allow_none=True).tag(sync=True)
debounce_wait = traitlets.Float(default_value=10).tag(sync=True)
max_wait = traitlets.Bool(default_value=True).tag(sync=True)
local_tz = traitlets.Unicode(default_value=None, allow_none=True).tag(sync=True)
debug = traitlets.Bool(default_value=False)

# Internal selection traitlets
_selection_types = traitlets.Dict()
Expand All @@ -112,7 +117,20 @@ class JupyterChart(anywidget.AnyWidget):
# Internal param traitlets
_params = traitlets.Dict().tag(sync=True)

def __init__(self, chart: TopLevelSpec, debounce_wait: int = 10, **kwargs: Any):
# Internal comm traitlets for VegaFusion support
_chart_state = traitlets.Any(allow_none=True)
_js_watch_plan = traitlets.Any(allow_none=True).tag(sync=True)
_js_to_py_updates = traitlets.Any(allow_none=True).tag(sync=True)
_py_to_js_updates = traitlets.Any(allow_none=True).tag(sync=True)

def __init__(
self,
chart: TopLevelSpec,
debounce_wait: int = 10,
max_wait: bool = True,
debug: bool = False,
**kwargs: Any,
):
"""
Jupyter Widget for displaying and updating Altair Charts, and
retrieving selection and parameter values
Expand All @@ -122,11 +140,24 @@ def __init__(self, chart: TopLevelSpec, debounce_wait: int = 10, **kwargs: Any):
chart: Chart
Altair Chart instance
debounce_wait: int
Debouncing wait time in milliseconds
Debouncing wait time in milliseconds. Updates will be sent from the client to the kernel
after debounce_wait milliseconds of no chart interactions.
max_wait: bool
If True (default), updates will be sent from the client to the kernel every debounce_wait
milliseconds even if there are ongoing chart interactions. If False, updates will not be
sent until chart interactions have completed.
debug: bool
If True, debug messages will be printed
"""
self.params = Params({})
self.selections = Selections({})
super().__init__(chart=chart, debounce_wait=debounce_wait, **kwargs)
super().__init__(
chart=chart,
debounce_wait=debounce_wait,
max_wait=max_wait,
debug=debug,
**kwargs,
)

@traitlets.observe("chart")
def _on_change_chart(self, change):
Expand All @@ -135,14 +166,22 @@ def _on_change_chart(self, change):
state when the wrapped Chart instance changes
"""
new_chart = change.new

params = getattr(new_chart, "params", [])
selection_watches = []
selection_types = {}
initial_params = {}
initial_vl_selections = {}
empty_selections = {}

if new_chart is None:
with self.hold_sync():
self.spec = None
self._selection_types = selection_types
self._vl_selections = initial_vl_selections
self._params = initial_params
return

params = getattr(new_chart, "params", [])

if params is not alt.Undefined:
for param in new_chart.params:
if isinstance(param.name, alt.ParameterName):
Expand Down Expand Up @@ -205,13 +244,50 @@ def on_param_traitlet_changed(param_change):
# Update properties all together
with self.hold_sync():
if using_vegafusion():
self.spec = new_chart.to_dict(format="vega")
if self.local_tz is None:
self.spec = None

def on_local_tz_change(change):
self._init_with_vegafusion(change["new"])

self.observe(on_local_tz_change, ["local_tz"])
else:
self._init_with_vegafusion(self.local_tz)
else:
self.spec = new_chart.to_dict()
self._selection_types = selection_types
self._vl_selections = initial_vl_selections
self._params = initial_params

def _init_with_vegafusion(self, local_tz: str):
if self.chart is not None:
vegalite_spec = self.chart.to_dict(context={"pre_transform": False})
with self.hold_sync():
self._chart_state = compile_to_vegafusion_chart_state(
vegalite_spec, local_tz
)
self._js_watch_plan = self._chart_state.get_watch_plan()[
"client_to_server"
]
self.spec = self._chart_state.get_transformed_spec()

# Callback to update chart state and send updates back to client
def on_js_to_py_updates(change):
if self.debug:
updates_str = json.dumps(change["new"], indent=2)
print(
f"JavaScript to Python VegaFusion updates:\n {updates_str}"
)
updates = self._chart_state.update(change["new"])
if self.debug:
updates_str = json.dumps(updates, indent=2)
print(
f"Python to JavaScript VegaFusion updates:\n {updates_str}"
)
self._py_to_js_updates = updates

self.observe(on_js_to_py_updates, ["_js_to_py_updates"])

@traitlets.observe("_params")
def _on_change_params(self, change):
for param_name, value in change.new.items():
Expand Down
2 changes: 1 addition & 1 deletion altair/utils/_importers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def import_vegafusion() -> ModuleType:
min_version = "1.4.0"
min_version = "1.5.0"
try:
version = importlib_version("vegafusion")
if Version(version) < Version(min_version):
Expand Down
Loading