Skip to content

Commit

Permalink
Preserve plot options when replacing data from a plugin (#2288)
Browse files Browse the repository at this point in the history
* Preserve plot options when replacing data from a plugin

* Add changelog, fix codestyle

* Add test for refitting model

* Codestyle

* Update test

Update test so that it actually fails on main

* Switch to attribute ignore list

* Ignore other attributes that we can't set
  • Loading branch information
rosteen authored Jul 13, 2023
1 parent b18ed7f commit cc81032
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ Other Changes and Additions

- Added direct launchers for each config (e.g. ``specviz``) [#1960]

- Replacing existing data from a plugin (e.g., refitting a model with the same label)
now preserves the plot options of the data as previously displayed. [#2288]

3.5.1 (unreleased)
==================

Expand Down
29 changes: 29 additions & 0 deletions jdaviz/configs/default/plugins/model_fitting/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,35 @@ def test_register_cube_model(cubeviz_helper, spectrum1d_cube):
assert test_label in cubeviz_helper.app.data_collection


def test_refit_plot_options(specviz_helper, spectrum1d):
specviz_helper.load_data(spectrum1d)
modelfit_plugin = specviz_helper.plugins['Model Fitting']

modelfit_plugin.model_comp_selected = 'Const1D'
modelfit_plugin._obj.comp_label = "C"
modelfit_plugin._obj.vue_add_model({})

with pytest.warns(AstropyUserWarning):
modelfit_plugin.calculate_fit(add_data=True)

sv = specviz_helper.app.get_viewer('spectrum-viewer')
atts = {"color": "red", "linewidth": 2, "alpha": 0.8}
layer_state = [layer.state for layer in sv.layers if layer.layer.label == "model"][0]
for att in atts:
setattr(layer_state, att, atts[att])

# Refit using the same name, which will replace the data by default.
modelfit_plugin.create_model_component('Linear1D', 'L')

with pytest.warns(AstropyUserWarning):
modelfit_plugin.calculate_fit(add_data=True)

layer_state = [layer.state for layer in sv.layers if layer.layer.label == "model"][0]

for att in atts:
assert atts[att] == getattr(layer_state, att)


def test_user_api(specviz_helper, spectrum1d):
with warnings.catch_warnings():
warnings.simplefilter('ignore')
Expand Down
34 changes: 28 additions & 6 deletions jdaviz/core/template_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1904,6 +1904,10 @@ def add_results_from_plugin(self, data_item, replace=None, label=None):
Add ``data_item`` to the app's data_collection according to the default or user-provided
label and adds to any requested viewers.
"""

# Note that we can only preserve one of percentile or vmin+vmax
ignore_attributes = ("layer", "attribute", "percentile")

if self.label_invalid_msg:
raise ValueError(self.label_invalid_msg)

Expand All @@ -1915,22 +1919,33 @@ def add_results_from_plugin(self, data_item, replace=None, label=None):
# entry should be the same as the original entry (to avoid deleting reference data)
add_to_viewer_refs = []
add_to_viewer_vis = []
preserved_attributes = []
for viewer_select_item in self.add_to_viewer_items[1:]:
# index 0 is for "None"
viewer_ref = viewer_select_item['reference']
viewer_item = self.app._viewer_item_by_reference(viewer_ref)
viewer = self.app.get_viewer(viewer_ref)
viewer_loaded_labels = [layer.layer.label for layer in viewer.layers]
if label in viewer_loaded_labels:
add_to_viewer_refs.append(viewer_ref)
add_to_viewer_vis.append(label in viewer_item['visible_layers'])
for layer in viewer.layers:
if layer.layer.label != label:
continue
else:
add_to_viewer_refs.append(viewer_ref)
add_to_viewer_vis.append(label in viewer_item['visible_layers'])
preserve_these = {}
for att in layer.state.as_dict():
# Can't set cmap_att, size_att, etc
if att not in ignore_attributes and "_att" not in att:
preserve_these[att] = getattr(layer.state, att)
preserved_attributes.append(preserve_these)
else:
if self.add_to_viewer_selected == 'None':
add_to_viewer_refs = []
add_to_viewer_vis = []
preserved_attributes = []
else:
add_to_viewer_refs = [self.add_to_viewer_selected]
add_to_viewer_vis = [True]
preserved_attributes = [{}]

if label in self.app.data_collection:
for viewer_ref in add_to_viewer_refs:
Expand All @@ -1944,18 +1959,25 @@ def add_results_from_plugin(self, data_item, replace=None, label=None):
data_item.meta['mosviz_row'] = self.app.state.settings['mosviz_row']
self.app.add_data(data_item, label)

for viewer_ref, visible in zip(add_to_viewer_refs, add_to_viewer_vis):
for viewer_ref, visible, preserved in zip(add_to_viewer_refs, add_to_viewer_vis,
preserved_attributes):
# replace the contents in the selected viewer with the results from this plugin
this_viewer = self.app.get_viewer(viewer_ref)
if replace is not None:
this_replace = replace
else:
this_viewer = self.app.get_viewer(viewer_ref)
this_replace = isinstance(this_viewer, BqplotImageView)

self.app.add_data_to_viewer(viewer_ref,
label,
visible=visible, clear_other_data=this_replace)

if preserved != {}:
layer_state = [layer.state for layer in this_viewer.layers if
layer.layer.label == label][0]
for att in preserved:
setattr(layer_state, att, preserved[att])

# update overwrite warnings, etc
self._on_label_changed()

Expand Down

0 comments on commit cc81032

Please sign in to comment.