diff --git a/CHANGES.rst b/CHANGES.rst index 623f02b176..73d39e2ad2 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -121,6 +121,9 @@ Other Changes and Additions - Added direct launchers for each config (e.g. ``specviz``) [#1960] +- Replacing existing data from a plugin (e.g., refitting a model with the same label) + now preserves the plot options of the data as previously displayed. [#2288] + 3.5.1 (unreleased) ================== diff --git a/jdaviz/configs/default/plugins/model_fitting/tests/test_plugin.py b/jdaviz/configs/default/plugins/model_fitting/tests/test_plugin.py index 4c9d89e072..2b287b1335 100644 --- a/jdaviz/configs/default/plugins/model_fitting/tests/test_plugin.py +++ b/jdaviz/configs/default/plugins/model_fitting/tests/test_plugin.py @@ -154,6 +154,35 @@ def test_register_cube_model(cubeviz_helper, spectrum1d_cube): assert test_label in cubeviz_helper.app.data_collection +def test_refit_plot_options(specviz_helper, spectrum1d): + specviz_helper.load_data(spectrum1d) + modelfit_plugin = specviz_helper.plugins['Model Fitting'] + + modelfit_plugin.model_comp_selected = 'Const1D' + modelfit_plugin._obj.comp_label = "C" + modelfit_plugin._obj.vue_add_model({}) + + with pytest.warns(AstropyUserWarning): + modelfit_plugin.calculate_fit(add_data=True) + + sv = specviz_helper.app.get_viewer('spectrum-viewer') + atts = {"color": "red", "linewidth": 2, "alpha": 0.8} + layer_state = [layer.state for layer in sv.layers if layer.layer.label == "model"][0] + for att in atts: + setattr(layer_state, att, atts[att]) + + # Refit using the same name, which will replace the data by default. + modelfit_plugin.create_model_component('Linear1D', 'L') + + with pytest.warns(AstropyUserWarning): + modelfit_plugin.calculate_fit(add_data=True) + + layer_state = [layer.state for layer in sv.layers if layer.layer.label == "model"][0] + + for att in atts: + assert atts[att] == getattr(layer_state, att) + + def test_user_api(specviz_helper, spectrum1d): with warnings.catch_warnings(): warnings.simplefilter('ignore') diff --git a/jdaviz/core/template_mixin.py b/jdaviz/core/template_mixin.py index 4b27787724..da8e6673de 100644 --- a/jdaviz/core/template_mixin.py +++ b/jdaviz/core/template_mixin.py @@ -1904,6 +1904,10 @@ def add_results_from_plugin(self, data_item, replace=None, label=None): Add ``data_item`` to the app's data_collection according to the default or user-provided label and adds to any requested viewers. """ + + # Note that we can only preserve one of percentile or vmin+vmax + ignore_attributes = ("layer", "attribute", "percentile") + if self.label_invalid_msg: raise ValueError(self.label_invalid_msg) @@ -1915,22 +1919,33 @@ def add_results_from_plugin(self, data_item, replace=None, label=None): # entry should be the same as the original entry (to avoid deleting reference data) add_to_viewer_refs = [] add_to_viewer_vis = [] + preserved_attributes = [] for viewer_select_item in self.add_to_viewer_items[1:]: # index 0 is for "None" viewer_ref = viewer_select_item['reference'] viewer_item = self.app._viewer_item_by_reference(viewer_ref) viewer = self.app.get_viewer(viewer_ref) - viewer_loaded_labels = [layer.layer.label for layer in viewer.layers] - if label in viewer_loaded_labels: - add_to_viewer_refs.append(viewer_ref) - add_to_viewer_vis.append(label in viewer_item['visible_layers']) + for layer in viewer.layers: + if layer.layer.label != label: + continue + else: + add_to_viewer_refs.append(viewer_ref) + add_to_viewer_vis.append(label in viewer_item['visible_layers']) + preserve_these = {} + for att in layer.state.as_dict(): + # Can't set cmap_att, size_att, etc + if att not in ignore_attributes and "_att" not in att: + preserve_these[att] = getattr(layer.state, att) + preserved_attributes.append(preserve_these) else: if self.add_to_viewer_selected == 'None': add_to_viewer_refs = [] add_to_viewer_vis = [] + preserved_attributes = [] else: add_to_viewer_refs = [self.add_to_viewer_selected] add_to_viewer_vis = [True] + preserved_attributes = [{}] if label in self.app.data_collection: for viewer_ref in add_to_viewer_refs: @@ -1944,18 +1959,25 @@ def add_results_from_plugin(self, data_item, replace=None, label=None): data_item.meta['mosviz_row'] = self.app.state.settings['mosviz_row'] self.app.add_data(data_item, label) - for viewer_ref, visible in zip(add_to_viewer_refs, add_to_viewer_vis): + for viewer_ref, visible, preserved in zip(add_to_viewer_refs, add_to_viewer_vis, + preserved_attributes): # replace the contents in the selected viewer with the results from this plugin + this_viewer = self.app.get_viewer(viewer_ref) if replace is not None: this_replace = replace else: - this_viewer = self.app.get_viewer(viewer_ref) this_replace = isinstance(this_viewer, BqplotImageView) self.app.add_data_to_viewer(viewer_ref, label, visible=visible, clear_other_data=this_replace) + if preserved != {}: + layer_state = [layer.state for layer in this_viewer.layers if + layer.layer.label == label][0] + for att in preserved: + setattr(layer_state, att, preserved[att]) + # update overwrite warnings, etc self._on_label_changed()