diff --git a/jdaviz/app.py b/jdaviz/app.py index 7c97807cf9..78b642619e 100644 --- a/jdaviz/app.py +++ b/jdaviz/app.py @@ -1979,7 +1979,8 @@ def set_data_visibility(self, viewer_reference, data_label, visible=True, replac layer_is_wcs_only = getattr(layer.layer, 'meta', {}).get(self._wcs_only_label, False) if layer.layer.data.label == data_label and layer_is_wcs_only: layer.visible = False - viewer.state.wcs_only_layers.append(data_label) + if data_label not in viewer.state.wcs_only_layers: + viewer.state.wcs_only_layers.append(data_label) selected_items.pop(data_id) # Sets the plot axes labels to be the units of the most recently @@ -2254,7 +2255,7 @@ def _create_viewer_item(self, viewer, vid=None, name=None, reference=None): 'config': self.config, # give viewer access to app config/layout 'data_open': False, 'collapse': True, - 'reference': reference, + 'reference': reference or name or vid, 'linked_by_wcs': linked_by_wcs, } @@ -2282,6 +2283,7 @@ def _on_new_viewer(self, msg, vid=None, name=None, add_layers_to_viewer=False): viewer : `~glue_jupyter.bqplot.common.BqplotBaseView` The new viewer instance. """ + viewer = self._application_handler.new_data_viewer( msg.cls, data=msg.data, show=False) viewer.figure_widget.layout.height = '100%' @@ -2319,6 +2321,9 @@ def _on_new_viewer(self, msg, vid=None, name=None, add_layers_to_viewer=False): ref_data = self._jdaviz_helper.default_viewer._obj.state.reference_data new_viewer_item['reference_data_label'] = getattr(ref_data, 'label', None) + if hasattr(viewer, 'reference'): + viewer.state.reference_data = ref_data + new_stack_item = self._create_stack_item( container='gl-stack', viewers=[new_viewer_item]) @@ -2333,14 +2338,13 @@ def _on_new_viewer(self, msg, vid=None, name=None, add_layers_to_viewer=False): self.session.application.viewers.append(viewer) - # Send out a toast message - self.hub.broadcast(ViewerAddedMessage(vid, sender=self)) - if add_layers_to_viewer: for layer_label in add_layers_to_viewer: - self.add_data_to_viewer(viewer.reference, layer_label) + if hasattr(viewer, 'reference'): + self.add_data_to_viewer(viewer.reference, layer_label) - viewer.state.reference_data = ref_data + # Send out a toast message + self.hub.broadcast(ViewerAddedMessage(vid, sender=self)) return viewer diff --git a/jdaviz/configs/default/plugins/plot_options/plot_options.py b/jdaviz/configs/default/plugins/plot_options/plot_options.py index 391748f737..f5ab09951a 100644 --- a/jdaviz/configs/default/plugins/plot_options/plot_options.py +++ b/jdaviz/configs/default/plugins/plot_options/plot_options.py @@ -357,11 +357,7 @@ def supports_line(state): return is_profile(state) or is_scatter(state) def is_image(state): - wcs_only = ( - hasattr(state, 'wcs_only_layers') and - self.layer.selected in state.wcs_only_layers - ) - return isinstance(state, BqplotImageLayerState) and not wcs_only + return isinstance(state, BqplotImageLayerState) def not_image(state): return not is_image(state) diff --git a/jdaviz/configs/imviz/plugins/orientation/orientation.py b/jdaviz/configs/imviz/plugins/orientation/orientation.py index 786f8906d4..d0eb034352 100644 --- a/jdaviz/configs/imviz/plugins/orientation/orientation.py +++ b/jdaviz/configs/imviz/plugins/orientation/orientation.py @@ -14,7 +14,7 @@ from jdaviz.core.events import ( LinkUpdatedMessage, ExitBatchLoadMessage, ChangeRefDataMessage, AstrowidgetMarkersChangedMessage, MarkersPluginUpdate, - SnackbarMessage, AddDataToViewerMessage, ViewerAddedMessage + SnackbarMessage, ViewerAddedMessage, AddDataMessage ) from jdaviz.core.custom_traitlets import FloatHandleEmpty from jdaviz.core.registries import tray_registry @@ -130,7 +130,7 @@ def __init__(self, *args, **kwargs): self.hub.subscribe(self, ViewerAddedMessage, handler=self._on_viewer_added) - self.hub.subscribe(self, AddDataToViewerMessage, + self.hub.subscribe(self, AddDataMessage, handler=self._on_data_add_to_viewer) self._update_layer_label_default() @@ -224,7 +224,7 @@ def _update_link(self, msg={}): # load data into the viewer that are now compatible with the # new link type, remove data from the viewer that are now # incompatible: - wcs_linked = self.link_type.selected.lower() == 'wcs' + wcs_linked = self.link_type.selected == 'WCS' viewer_selected = self.app.get_viewer(self.viewer.selected) data_in_viewer = self.app.get_viewer(viewer_selected.reference).data() @@ -240,6 +240,9 @@ def _update_link(self, msg={}): f"Data '{data.label}' does not have a valid WCS - removing from viewer.", sender=self, color="warning")) + if wcs_linked: + self._send_wcs_layers_to_all_viewers() + self.linking_in_progress = False self._update_layer_label_default() @@ -351,36 +354,53 @@ def add_orientation(self, rotation_angle=None, east_left=None, label=None, ) # add orientation layer to all viewers: - self._add_data_to_all_viewers(label) + for viewer_ref in self.app._viewer_store: + self._add_data_to_viewer(label, viewer_ref) if set_on_create: - self.orientation._update_layer_items() self.orientation.selected = label - def _add_data_to_all_viewers(self, data_label): - for viewer_ref in self.app.get_viewer_reference_names(): - layers = [ - layer.label for layer in - self.app.get_viewer(viewer_ref).layers - ] - if data_label not in layers: - self.app.add_data_to_viewer(viewer_ref, data_label) + def _add_data_to_viewer(self, data_label, viewer_id): + viewer = self.app.get_viewer_by_id(viewer_id) + + wcs_only_layers = viewer.state.wcs_only_layers + if data_label not in wcs_only_layers: + self.app.add_data_to_viewer(viewer_id, data_label) def _on_viewer_added(self, msg): - for data_label in self.orientation.choices: - self._add_data_to_all_viewers(data_label) + self._send_wcs_layers_to_all_viewers(viewers_to_update=[msg._viewer_id]) + + @observe('viewer_items') + def _send_wcs_layers_to_all_viewers(self, *args, **kwargs): + if not hasattr(self, 'viewer'): + return - if hasattr(self, 'orientation') and len(self.orientation.choices): - self.viewer.selected = msg._viewer_id - self.orientation.selected = self.orientation.choices[0] + wcs_only_layers = self.app._jdaviz_helper.default_viewer._obj.state.wcs_only_layers + + viewers_to_update = kwargs.get( + 'viewers_to_update', self.app._viewer_store.keys() + ) + for viewer_ref in viewers_to_update: + self.viewer.selected = viewer_ref + self.orientation.update_wcs_only_filter(wcs_only=self.link_type_selected == 'WCS') + for wcs_layer in wcs_only_layers: + if wcs_layer not in self.viewer.selected_obj.layers: + self.app.add_data_to_viewer(viewer_ref, wcs_layer) + if ( + self.orientation.selected not in + self.viewer.selected_obj.state.wcs_only_layers and + self.link_type_selected == 'WCS' + ): + self.orientation.selected = base_wcs_layer_label def _on_data_add_to_viewer(self, msg): - if ( - msg._viewer_reference != 'imviz-0' and - self.app.get_viewer_by_id(msg._viewer_reference).state.reference_data is None - ): - self.viewer.selected = msg._viewer_reference - self.orientation.selected = base_wcs_layer_label + all_wcs_only_layers = all( + layer.layer.meta.get(self.app._wcs_only_label) + for layer in self.viewer.selected_obj.layers + ) + if all_wcs_only_layers and msg.data.meta.get(self.app._wcs_only_label, False): + # on adding first data layer, reset the limits: + self.viewer.selected_obj.state.reset_limits() def vue_add_orientation(self, *args, **kwargs): self.add_orientation(set_on_create=True) @@ -391,12 +411,15 @@ def _change_reference_data(self, *args, **kwargs): self.app._change_reference_data( self.orientation.selected, viewer_id=self.viewer.selected ) + viewer_item = self.app._viewer_item_by_id(self.viewer.selected) + if viewer_item != self.orientation.selected: + viewer_item['reference_data_label'] = self.orientation.selected + + def _on_refdata_change(self, msg): - def _on_refdata_change(self, msg={}): - self.orientation.only_wcs_layers = msg.data.meta.get('_WCS_ONLY', False) if hasattr(self, 'viewer'): ref_data = self.ref_data - viewer = self.app.get_viewer(self.viewer.selected) + viewer = self.viewer.selected_obj # don't select until reference data are available: if ref_data is not None: @@ -406,6 +429,11 @@ def _on_refdata_change(self, msg={}): elif not len(viewer.data()): self.link_type_selected = link_type_msg_to_trait['pixels'] + if msg.data.label not in self.orientation.choices: + return + + self.orientation.selected = msg.data.label + # we never want to highlight subsets of pixels within WCS-only layers, # so if this layer is an ImageSubsetLayerState on a WCS-only layer, # ensure that it is never visible: @@ -426,7 +454,6 @@ def ref_data(self): @property def _refdata_change_available(self): viewer = self.app.get_viewer(self.viewer.selected) - ref_data = self.ref_data selected_layer = [lyr.layer for lyr in viewer.layers if lyr.layer.label == self.orientation.selected] if len(selected_layer): @@ -434,8 +461,8 @@ def _refdata_change_available(self): else: is_subset = False return ( - ref_data is not None and - len(self.orientation.selected) and len(self.viewer.selected) and + len(self.orientation.selected) and + len(self.viewer.selected) and not is_subset ) @@ -444,7 +471,6 @@ def _on_viewer_change(self, msg={}): # don't update choices until viewer is available: ref_data = self.ref_data if hasattr(self, 'viewer') and ref_data is not None: - self.orientation._update_layer_items() if ref_data.label in self.orientation.choices: self.orientation.selected = ref_data.label diff --git a/jdaviz/configs/imviz/plugins/viewers.py b/jdaviz/configs/imviz/plugins/viewers.py index 41ab9a9f1e..1433cf40a5 100644 --- a/jdaviz/configs/imviz/plugins/viewers.py +++ b/jdaviz/configs/imviz/plugins/viewers.py @@ -299,7 +299,7 @@ def get_link_type(self, data_label): if len(self.session.application.data_collection) == 0: raise ValueError('No reference data for link look-up') - ref_label = self.state.reference_data.label + ref_label = getattr(self.state.reference_data, 'label', None) if data_label == ref_label: return 'self' diff --git a/jdaviz/configs/imviz/wcs_utils.py b/jdaviz/configs/imviz/wcs_utils.py index 3e914243c2..315b268aab 100644 --- a/jdaviz/configs/imviz/wcs_utils.py +++ b/jdaviz/configs/imviz/wcs_utils.py @@ -440,12 +440,10 @@ def _prepare_rotated_nddata(real_image_shape, wcs, rotation_angle, refdata_shape # create a fake NDData (we use arange so data boundaries show up in Imviz # if it ever is accidentally exposed) with the rotated GWCS: - sequential_data = np.arange( - np.prod(refdata_shape), dtype=np.int8 - ).reshape(refdata_shape) + placeholder_data = np.nan * np.ones(refdata_shape) ndd = NDData( - data=sequential_data, + data=placeholder_data, wcs=new_rotated_gwcs, meta={wcs_only_key: True, '_pixel_scales': pixel_scales} ) diff --git a/jdaviz/core/astrowidgets_api.py b/jdaviz/core/astrowidgets_api.py index c598fa4a31..752342a99e 100644 --- a/jdaviz/core/astrowidgets_api.py +++ b/jdaviz/core/astrowidgets_api.py @@ -9,7 +9,7 @@ from glue.config import colormaps from glue.core import Data -from jdaviz.configs.imviz.helper import get_top_layer_index +from jdaviz.configs.imviz.helper import get_top_layer_index, get_reference_image_data from jdaviz.core.events import SnackbarMessage, AstrowidgetMarkersChangedMessage from jdaviz.core.helpers import data_has_valid_wcs @@ -166,7 +166,11 @@ def offset_by(self, dx, dy): @property def zoom_level(self): - """Zoom level: + """ + The zoom level for an image viewer (not linked by WCS). + + .. warning:: when a viewer is linked by WCS, the result corresponds + to the ``zoom_level`` of the reference data. * 1 means real-pixel-size. * 2 means zoomed in by a factor of 2. @@ -178,8 +182,11 @@ def zoom_level(self): raise ValueError('Viewer is still loading, try again later') if hasattr(self, '_get_real_xy'): - i_top = get_top_layer_index(self) - image = self.layers[i_top].layer + if self.state.reference_data is not None: + image, i_ref = get_reference_image_data(self.jdaviz_app, self.reference) + else: + i_top = get_top_layer_index(self) + image = self.layers[i_top].layer real_min = self._get_real_xy(image, self.state.x_min, self.state.y_min) real_max = self._get_real_xy(image, self.state.x_max, self.state.y_max) else: @@ -210,8 +217,11 @@ def zoom_level(self, val): new_dx = self.shape[1] * 0.5 / val if hasattr(self, '_get_real_xy'): - i_top = get_top_layer_index(self) - image = self.layers[i_top].layer + if self.state.reference_data is not None: + image, i_ref = get_reference_image_data(self.jdaviz_app, self.reference) + else: + i_top = get_top_layer_index(self) + image = self.layers[i_top].layer real_min = self._get_real_xy(image, self.state.x_min, self.state.y_min) real_max = self._get_real_xy(image, self.state.x_max, self.state.y_max) cur_xcen = (real_min[0] + real_max[0]) * 0.5 diff --git a/jdaviz/core/template_mixin.py b/jdaviz/core/template_mixin.py index ea72f856a3..5f09d1e10b 100644 --- a/jdaviz/core/template_mixin.py +++ b/jdaviz/core/template_mixin.py @@ -1226,7 +1226,7 @@ class LayerSelect(SelectPluginComponent): * register with all the automatic logic in the plugin's init by passing the string names of the respective traitlets. * use component in plugin template (see below) - * refer to properties above based on the interally stored reference to the + * refer to properties above based on the internally stored reference to the instantiated object of this component * observe the traitlets created and defined in the plugin, as necessary @@ -1265,7 +1265,6 @@ def __init__(self, plugin, items, selected, viewer, ``default`` text is provided but not in ``manual_options`` it will still be included as the first item in the list. """ - super().__init__(plugin, items=items, selected=selected, @@ -1275,8 +1274,6 @@ def __init__(self, plugin, items, selected, viewer, manual_options=manual_options, default_mode=default_mode) - self.only_wcs_layers = only_wcs_layers - self.hub.subscribe(self, AddDataMessage, handler=self._on_data_added) self.hub.subscribe(self, RemoveDataMessage, @@ -1295,6 +1292,7 @@ def __init__(self, plugin, items, selected, viewer, self.add_observe(viewer, self._on_viewer_selected_changed) self.add_observe(selected, self._update_layer_items) self._update_layer_items() + self.update_wcs_only_filter(only_wcs_layers) def _get_viewer(self, viewer): # newer will likely be the viewer name in most cases, but viewer id in the case @@ -1424,31 +1422,19 @@ def _update_layer_items(self, msg={}): manual_items = [{'label': label} for label in self.manual_options] # use getattr so the super() call above doesn't try to access the attr before # it is initialized: - if not getattr(self, 'only_wcs_layers', False): - all_layers = [ - layer for viewer in self.viewer_objs - for layer in getattr(viewer, 'layers', []) - # don't include WCS-only layers unless asked: - if not hasattr(layer.layer, 'meta') or ( - not layer.layer.meta.get('_WCS_ONLY', False) - ) - ] - else: - all_layers = [ - layer for viewer in self.viewer_objs - for layer in getattr(viewer, 'layers', []) - # only include WCS-only layers: - if ( - hasattr(layer.layer, 'meta') and - layer.layer.meta.get('_WCS_ONLY', False) - ) - ] + + all_layers = [ + layer for viewer in self.viewer_objs + for layer in getattr(viewer, 'layers', []) + if self._is_valid_item(layer) + ] + # remove duplicates - we'll loop back through all selected viewers to get a list of colors # and visibilities later within _layer_to_dict layer_labels = [ layer.layer.label for layer in all_layers if self.app.state.layer_icons.get(layer.layer.label) or - getattr(self, 'only_wcs_layers', False) + self.only_wcs_layers ] unique_layer_labels = list(set(layer_labels)) layer_items = [self._layer_to_dict(layer_label) for layer_label in unique_layer_labels] @@ -1463,6 +1449,35 @@ def _sort_by_icon(items_dict): self._apply_default_selection() + def update_wcs_only_filter(self, wcs_only): + """ + The layers that are populated in LayerSelect.choices + will be either WCS-only layers (for setting viewer orientation) + or non-WCS-only layers (for "real data"). This method toggles + the layer choices by adjusting the layer filters on this + LayerSelect instance. + + Parameters + ---------- + wcs_only : bool + `True` will filter only the WCS-only layers, `False` will + give the non-WCS-only layers. + """ + def is_wcs_only(layer): + return getattr(layer.layer, 'meta', {}).get(self.app._wcs_only_label, False) + + filter_names = [getattr(filt, '__name__', '') for filt in self.filters] + + if not wcs_only and 'is_wcs_only' in filter_names: + self.filters.remove(*[filt for filt in self.filters + if getattr(filt, '__name__', '') == 'is_wcs_only']) + elif wcs_only and 'is_wcs_only' not in filter_names: + self.add_filter(is_wcs_only) + + @property + def only_wcs_layers(self): + return 'is_wcs_only' in [getattr(filt, '__name__', '') for filt in self.filters] + @cached_property def selected_obj(self): viewer_names = self.viewer diff --git a/jdaviz/core/tools.py b/jdaviz/core/tools.py index cbb7765a89..21588991ae 100644 --- a/jdaviz/core/tools.py +++ b/jdaviz/core/tools.py @@ -2,7 +2,6 @@ import time import numpy as np -from astropy import units as u from echo import delay_callback from glue.config import viewer_tool from glue.core import HubListener @@ -92,7 +91,7 @@ def on_limits_change(self, *args): from_lims = {k: getattr(self.viewer.state, k) for k in self.match_keys} orig_refdata = self.viewer.state.reference_data if hasattr(self.viewer, '_get_fov') and orig_refdata and orig_refdata.coords: - orig_fov_sky = self.viewer._get_fov() + orig_fov_sky = self.viewer._get_fov(wcs=orig_refdata.coords) sky_cen = self.viewer._get_center_skycoord() else: orig_fov_sky = sky_cen = None @@ -102,45 +101,37 @@ def on_limits_change(self, *args): # to_lims: proposed new limits for this "matched" viewer orig_lims = {k: getattr(viewer.state, k) for k in self.match_keys} to_lims = self._map_limits(self.viewer, viewer, from_lims) - - to_refdata = viewer.state.reference_data - - if (to_refdata and to_refdata.coords and (orig_refdata != to_refdata) - and (orig_fov_sky is not None)): - # if the viewers have different reference data, - # rescale the zoom and center allowing for different - # viewer rotations: - to_fov_sky = viewer._get_fov(wcs=orig_refdata.coords) - - viewer_center = viewer._get_center_skycoord(orig_refdata) - if sky_cen.separation(viewer_center) > 0.1 * u.arcsec: - # avoid recentering if the viewer is already nearly centered - viewer.center_on(sky_cen) - - viewer.zoom( - float(to_fov_sky / orig_fov_sky) - ) - continue - - # if the viewers have the same reference data, - # make their limits match as usual: - with delay_callback(viewer.state, *self.match_keys): - for ax in self.match_axes: - # to avoid recursion we'll only update the state if there is a change - # outside a tolerance set by some fraction of the limits range - if None in orig_lims.values(): - orig_range = np.inf - else: - orig_range = abs(orig_lims.get(f'{ax}_max') - orig_lims.get(f'{ax}_min')) - to_range = abs(to_lims.get(f'{ax}_max') - to_lims.get(f'{ax}_min')) - tol = 1e-6 * min(orig_range, to_range) - - for k in (f'{ax}_min', f'{ax}_max'): - value = to_lims.get(k) - orig_value = orig_lims.get(k) - if not np.isnan(value) and (orig_value is None or - abs(value-orig_lims.get(k, np.inf)) > tol): - setattr(viewer.state, k, value) + matched_refdata = viewer.state.reference_data + + if hasattr(viewer, '_get_fov'): + to_fov_sky = viewer._get_fov(wcs=matched_refdata.coords) + else: + to_fov_sky = None + + if to_fov_sky is not None and orig_fov_sky is not None: + old_level = viewer.zoom_level + viewer.zoom_level = old_level * float(to_fov_sky / orig_fov_sky) + viewer.center_on(sky_cen) + + else: + with delay_callback(viewer.state, *self.match_keys): + for ax in self.match_axes: + if None in orig_lims.values(): + orig_range = np.inf + else: + orig_range = abs(orig_lims.get(f'{ax}_max') - + orig_lims.get(f'{ax}_min')) + to_range = abs(to_lims.get(f'{ax}_max') - + to_lims.get(f'{ax}_min')) + tol = 1e-6 * min(orig_range, to_range) + + for k in (f'{ax}_min', f'{ax}_max'): + value = to_lims.get(k) + orig_value = orig_lims.get(k) + + if not np.isnan(value) and (orig_value is None or + abs(value-orig_lims.get(k, np.inf)) > tol): + setattr(viewer.state, k, value) def is_visible(self): return len(self.viewer.jdaviz_app._viewer_store) > 1