Skip to content

Commit

Permalink
Fix embed code
Browse files Browse the repository at this point in the history
  • Loading branch information
vidartf committed Jun 19, 2017
1 parent 3ed678d commit 8b1ae9c
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 21 deletions.
22 changes: 12 additions & 10 deletions ipywidgets/widgets/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,12 @@ def dependency_state(widgets, drop_defaults=True):
# collect the state of all relevant widgets
if widgets is None:
# Get state of all widgets, no smart resolution needed.
widgets = Widget.widgets.values()
state = Widget.get_manager_state(drop_defaults=drop_defaults, widgets=widgets)['state']
state = Widget.get_manager_state(drop_defaults=drop_defaults, widgets=None)['state']
else:
try:
widgets[0]
except (IndexError, TypeError):
widgets = [widgets]
state = {}
for widget in widgets:
_get_recursive_state(widget, state, drop_defaults)
Expand Down Expand Up @@ -160,24 +163,23 @@ def embed_data(views, drop_defaults=True, state=None):
manager_state: dict of the widget manager state data
view_specs: a list of widget view specs
"""
if views is not None:
if views is None:
views = [w for w in Widget.widgets.values() if isinstance(w, DOMWidget)]
else:
try:
views[0]
except (IndexError, TypeError):
views = [views]
if include_all:
state = Widget.get_manager_state(drop_defaults=drop_defaults, widgets=None)['state']
else:
state = dependency_state(views, drop_defaults)

if state is None:
# Get state of all known widgets
state = state = Widget.get_manager_state(drop_defaults=drop_defaults, widgets=None)['state']

# Rely on ipywidget to get the default values
json_data = Widget.get_manager_state(widgets=[])
# but plug in our own state
json_data['state'] = state

if views is None:
views = [w for w in Widget.widgets.values() if isinstance(w, DOMWidget)]

view_specs = [w.get_view_spec() for w in views]

return dict(manager_state=json_data, view_specs=view_specs)
Expand Down
31 changes: 20 additions & 11 deletions ipywidgets/widgets/tests/test_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import traitlets

from .. import IntSlider, IntText, Widget, jslink, HBox, widget_serialization
from ..embed import embed_data, embed_snippet, embed_minimal_html
from ..embed import embed_data, embed_snippet, embed_minimal_html, dependency_state

try:
from io import StringIO
Expand Down Expand Up @@ -37,7 +37,8 @@ def teardown(self):

def test_embed_data_simple(self):
w = IntText(4)
data = embed_data(views=w, include_all=False, drop_defaults=True)
state = dependency_state(w, drop_defaults=True)
data = embed_data(views=w, drop_defaults=True, state=state)

state = data['manager_state']['state']
views = data['view_specs']
Expand All @@ -52,7 +53,8 @@ def test_embed_data_two_widgets(self):
w1 = IntText(4)
w2 = IntSlider(min=0, max=100)
jslink((w1, 'value'), (w2, 'value'))
data = embed_data(views=[w1, w2], include_all=False, drop_defaults=True)
state = dependency_state([w1, w2], drop_defaults=True)
data = embed_data(views=[w1, w2], drop_defaults=True, state=state)

state = data['manager_state']['state']
views = data['view_specs']
Expand Down Expand Up @@ -82,13 +84,9 @@ def test_embed_data_complex(self):
# Put it in an HBox
HBox(children=[w4])

data = embed_data(views=w4, include_all=False, drop_defaults=True)

state = data['manager_state']['state']
views = data['view_specs']
state = dependency_state(w3)

assert len(state) == 9
assert len(views) == 1

model_names = [s['model_name'] for s in state.values()]
assert 'IntTextModel' in model_names
Expand All @@ -99,9 +97,18 @@ def test_embed_data_complex(self):
# Check that HBox is not collected
assert 'HBoxModel' not in model_names

# Check that views make sense:

data = embed_data(views=w3, drop_defaults=True, state=state)
assert state is data['manager_state']['state']
views = data['view_specs']
assert len(views) == 1


def test_snippet(self):
w = IntText(4)
snippet = embed_snippet(views=w, include_all=False, drop_defaults=True)
state = dependency_state(w, drop_defaults=True)
snippet = embed_snippet(views=w, drop_defaults=True, state=state)

lines = snippet.splitlines()

Expand Down Expand Up @@ -132,7 +139,8 @@ def test_minimal_html_filename(self):

try:
output = os.path.join(tmpd, 'test.html')
embed_minimal_html(output, views=w, include_all=False, drop_defaults=True)
state = dependency_state(w, drop_defaults=True)
embed_minimal_html(output, views=w, drop_defaults=True, state=state)
# Check that the file is written to the intended destination:
with open(output, 'r') as f:
content = f.read()
Expand All @@ -143,6 +151,7 @@ def test_minimal_html_filename(self):
def test_minimal_html_filehandle(self):
w = IntText(4)
output = StringIO()
embed_minimal_html(output, views=w, include_all=False, drop_defaults=True)
state = dependency_state(w, drop_defaults=True)
embed_minimal_html(output, views=w, drop_defaults=True, state=state)
content = output.getvalue()
assert content.splitlines()[0] == '<!DOCTYPE html>'

0 comments on commit 8b1ae9c

Please sign in to comment.