From 282f19300ac5806360a2fa146483b1decd5f3dca Mon Sep 17 00:00:00 2001 From: "Brett M. Morris" Date: Wed, 7 Aug 2024 16:01:28 -0400 Subject: [PATCH] adding support for visualizing all pixels in subsets, and per pixel on hover --- jdaviz/configs/cubeviz/plugins/__init__.py | 2 +- jdaviz/configs/cubeviz/plugins/tools.py | 52 ++---------- jdaviz/configs/default/plugins/tools.py | 47 +++++++++++ jdaviz/configs/default/plugins/viewers.py | 2 +- jdaviz/configs/rampviz/helper.py | 39 +-------- jdaviz/configs/rampviz/plugins/__init__.py | 1 + jdaviz/configs/rampviz/plugins/parsers.py | 15 ++-- .../ramp_extraction/ramp_extraction.py | 82 +++++++++++++------ jdaviz/configs/rampviz/plugins/tools.py | 43 ++++++++++ jdaviz/configs/rampviz/plugins/viewers.py | 50 ++++++++++- jdaviz/core/marks.py | 1 + 11 files changed, 216 insertions(+), 118 deletions(-) create mode 100644 jdaviz/configs/default/plugins/tools.py create mode 100644 jdaviz/configs/rampviz/plugins/tools.py diff --git a/jdaviz/configs/cubeviz/plugins/__init__.py b/jdaviz/configs/cubeviz/plugins/__init__.py index 61390844ae..4cc5c65a3a 100644 --- a/jdaviz/configs/cubeviz/plugins/__init__.py +++ b/jdaviz/configs/cubeviz/plugins/__init__.py @@ -1,7 +1,7 @@ -from .tools import * # noqa from .mixins import * # noqa from .viewers import * # noqa from .parsers import * # noqa from .moment_maps.moment_maps import * # noqa from .slice.slice import * # noqa from .spectral_extraction.spectral_extraction import * # noqa +from .tools import * # noqa diff --git a/jdaviz/configs/cubeviz/plugins/tools.py b/jdaviz/configs/cubeviz/plugins/tools.py index 0974b03299..5914151dcc 100644 --- a/jdaviz/configs/cubeviz/plugins/tools.py +++ b/jdaviz/configs/cubeviz/plugins/tools.py @@ -3,14 +3,13 @@ from glue.config import viewer_tool from glue_jupyter.bqplot.image import BqplotImageView -from glue_jupyter.bqplot.profile import BqplotProfileView from glue.viewers.common.tool import CheckableTool import numpy as np from specutils import Spectrum1D from jdaviz.core.events import SliceToolStateMessage, SliceSelectSliceMessage -from jdaviz.core.tools import PanZoom, BoxZoom, SinglePixelRegion, _MatchedZoomMixin -from jdaviz.core.marks import PluginLine +from jdaviz.core.tools import PanZoom, BoxZoom, _MatchedZoomMixin +from jdaviz.configs.default.plugins.tools import ProfileFromCube __all__ = [] @@ -81,52 +80,17 @@ def on_mouse_event(self, data): @viewer_tool -class SpectrumPerSpaxel(SinglePixelRegion): +class SpectrumPerSpaxel(ProfileFromCube): icon = os.path.join(ICON_DIR, 'pixelspectra.svg') tool_id = 'jdaviz:spectrumperspaxel' action_text = 'See spectrum at a single spaxel' tool_tip = 'Click on the viewer and see the spectrum at that spaxel in the spectrum viewer' - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._spectrum_viewer = None - self._previous_bounds = None - self._mark = None - self._data = None - - def _reset_spectrum_viewer_bounds(self): - sv_state = self._spectrum_viewer.state - sv_state.x_min = self._previous_bounds[0] - sv_state.x_max = self._previous_bounds[1] - sv_state.y_min = self._previous_bounds[2] - sv_state.y_max = self._previous_bounds[3] - - def activate(self): - self.viewer.add_event_callback(self.on_mouse_move, events=['mousemove', 'mouseleave']) - if self._spectrum_viewer is None: - # Get first profile viewer - for _, viewer in self.viewer.jdaviz_helper.app._viewer_store.items(): - if isinstance(viewer, BqplotProfileView): - self._spectrum_viewer = viewer - break - if self._mark is None: - self._mark = PluginLine(self._spectrum_viewer, visible=False) - self._spectrum_viewer.figure.marks = self._spectrum_viewer.figure.marks + [self._mark,] - # Store these so we can revert to previous user-set zoom after preview view - sv_state = self._spectrum_viewer.state - self._previous_bounds = [sv_state.x_min, sv_state.x_max, sv_state.y_min, sv_state.y_max] - super().activate() - - def deactivate(self): - self.viewer.remove_event_callback(self.on_mouse_move) - self._reset_spectrum_viewer_bounds() - super().deactivate() - def on_mouse_move(self, data): if data['event'] == 'mouseleave': self._mark.visible = False - self._reset_spectrum_viewer_bounds() + self._reset_profile_viewer_bounds() return x = int(np.round(data['domain']['x'])) @@ -157,13 +121,13 @@ def on_mouse_move(self, data): else: spectrum = cube_data.get_object(statistic=None) # Note: change this when Spectrum1D.with_spectral_axis is fixed. - x_unit = self._spectrum_viewer.state.x_display_unit + x_unit = self._profile_viewer.state.x_display_unit if spectrum.spectral_axis.unit != x_unit: new_spectral_axis = spectrum.spectral_axis.to(x_unit) spectrum = Spectrum1D(spectrum.flux, new_spectral_axis) if x >= spectrum.flux.shape[0] or x < 0 or y >= spectrum.flux.shape[1] or y < 0: - self._reset_spectrum_viewer_bounds() + self._reset_profile_viewer_bounds() self._mark.visible = False else: y_values = spectrum.flux[x, y, :] @@ -172,5 +136,5 @@ def on_mouse_move(self, data): return self._mark.update_xy(spectrum.spectral_axis.value, y_values) self._mark.visible = True - self._spectrum_viewer.state.y_max = np.nanmax(y_values.value) * 1.2 - self._spectrum_viewer.state.y_min = np.nanmin(y_values.value) * 0.8 + self._profile_viewer.state.y_max = np.nanmax(y_values.value) * 1.2 + self._profile_viewer.state.y_min = np.nanmin(y_values.value) * 0.8 diff --git a/jdaviz/configs/default/plugins/tools.py b/jdaviz/configs/default/plugins/tools.py new file mode 100644 index 0000000000..d7a158365f --- /dev/null +++ b/jdaviz/configs/default/plugins/tools.py @@ -0,0 +1,47 @@ +from glue_jupyter.bqplot.profile import BqplotProfileView +from jdaviz.core.tools import SinglePixelRegion +from jdaviz.core.marks import PluginLine + + +__all__ = ['ProfileFromCube'] + + +class ProfileFromCube(SinglePixelRegion): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._profile_viewer = None + self._previous_bounds = None + self._mark = None + self._data = None + + def _reset_profile_viewer_bounds(self): + pv_state = self._profile_viewer.state + pv_state.x_min = self._previous_bounds[0] + pv_state.x_max = self._previous_bounds[1] + pv_state.y_min = self._previous_bounds[2] + pv_state.y_max = self._previous_bounds[3] + + def activate(self): + self.viewer.add_event_callback(self.on_mouse_move, events=['mousemove', 'mouseleave']) + if self._profile_viewer is None: + # Get first profile viewer + for _, viewer in self.viewer.jdaviz_helper.app._viewer_store.items(): + if isinstance(viewer, BqplotProfileView): + self._profile_viewer = viewer + break + if self._mark is None: + self._mark = PluginLine(self._profile_viewer, visible=False) + self._profile_viewer.figure.marks = self._profile_viewer.figure.marks + [self._mark, ] + # Store these so we can revert to previous user-set zoom after preview view + pv_state = self._profile_viewer.state + self._previous_bounds = [pv_state.x_min, pv_state.x_max, pv_state.y_min, pv_state.y_max] + super().activate() + + def deactivate(self): + self.viewer.remove_event_callback(self.on_mouse_move) + self._reset_profile_viewer_bounds() + super().deactivate() + + def on_mouse_move(self, data): + raise NotImplementedError("must be implemented by sublcasses") diff --git a/jdaviz/configs/default/plugins/viewers.py b/jdaviz/configs/default/plugins/viewers.py index c484660e14..4f54a247f3 100644 --- a/jdaviz/configs/default/plugins/viewers.py +++ b/jdaviz/configs/default/plugins/viewers.py @@ -191,7 +191,7 @@ def _apply_layer_defaults(self, layer_state): layer_state.add_callback('as_steps', self._show_uncertainty_changed) def _expected_subset_layer_default(self, layer_state): - if self.__class__.__name__ == 'CubevizImageView': + if self.__class__.__name__ in ('CubevizImageView', 'RampvizImageView'): # Do not override default for subsets as for some reason # this isn't getting called when they're first added, but rather when # the next state change is made (for example: manually changing the visibility) diff --git a/jdaviz/configs/rampviz/helper.py b/jdaviz/configs/rampviz/helper.py index 4fb68c717a..67f6ebeddd 100644 --- a/jdaviz/configs/rampviz/helper.py +++ b/jdaviz/configs/rampviz/helper.py @@ -1,5 +1,5 @@ from jdaviz.core.events import SliceSelectSliceMessage -from jdaviz.core.events import AddDataMessage, SnackbarMessage +from jdaviz.core.events import AddDataMessage from jdaviz.core.helpers import CubeConfigHelper from jdaviz.configs.rampviz.plugins.viewers import RampvizImageView @@ -49,42 +49,11 @@ def load_data(self, data, data_label=None, **kwargs): self.app.hub.subscribe(self, AddDataMessage, handler=self._set_x_axis) - if 'Ramp Extraction' not in self.plugins: # pragma: no cover - msg = SnackbarMessage( - "Automatic ramp extraction requires the Ramp Extraction plugin to be enabled", # noqa - color='error', sender=self, timeout=10000) - self.app.hub.broadcast(msg) - else: - try: - self.plugins['Ramp Extraction']._obj._extract_in_new_instance(auto_update=False, add_data=True) # noqa - except Exception as err: - msg = SnackbarMessage( - "Automatic ramp extraction for the entire cube failed." - f" See the ramp extraction plugin to perform a custom extraction: {err}", - color='error', sender=self, timeout=10000) - else: - msg = SnackbarMessage( - "The extracted ramp profile was generated automatically for the entire cube." - " See the ramp extraction plugin for details or to" - " perform a custom extraction.", - color='warning', sender=self, timeout=10000) - self.app.hub.broadcast(msg) - def _set_x_axis(self, msg): - viewer = self.app.get_viewer(self._default_integration_viewer_reference_name) - if msg.viewer_id != viewer.reference_id: - return + viewer = self.app.get_viewer(self._default_group_viewer_reference_name) ref_data = viewer.state.reference_data - if ref_data and ref_data.ndim == 3: - for att_name in _temporal_axis_names: - if att_name in ref_data.component_ids(): - if viewer.state.x_att != ref_data.id[att_name]: - viewer.state.x_att = ref_data.id[att_name] - viewer.state.reset_limits() - break - else: - viewer.state.x_att = ref_data.id["Pixel Axis 2 [x]"] - viewer.state.reset_limits() + viewer.state.x_att = ref_data.id["Pixel Axis 2 [x]"] + viewer.state.reset_limits() def select_group(self, group_index): """ diff --git a/jdaviz/configs/rampviz/plugins/__init__.py b/jdaviz/configs/rampviz/plugins/__init__.py index f3bea6d9d8..e00c3cbac2 100644 --- a/jdaviz/configs/rampviz/plugins/__init__.py +++ b/jdaviz/configs/rampviz/plugins/__init__.py @@ -1,3 +1,4 @@ from .viewers import * # noqa from .parsers import * # noqa from .ramp_extraction import * # noqa +from .tools import * # noqa diff --git a/jdaviz/configs/rampviz/plugins/parsers.py b/jdaviz/configs/rampviz/plugins/parsers.py index 67888a88dc..d6b844c18c 100644 --- a/jdaviz/configs/rampviz/plugins/parsers.py +++ b/jdaviz/configs/rampviz/plugins/parsers.py @@ -112,8 +112,6 @@ def parse_data(app, file_obj, data_type=None, data_label=None, ) elif isinstance(file_obj, (np.ndarray, NDData)) and file_obj.ndim in (1, 2): - if file_obj.ndim == 2: - app.get_viewer(integration_viewer_reference_name).is2d = True # load 1D profile(s) to integration_viewer _parse_ndarray( app, file_obj, data_label=data_label, @@ -166,24 +164,25 @@ def _swap_axes(x): ]) ramp_cube_data_label = f"{data_label}[DATA]" + ramp_diff_data_label = f"{data_label}[DIFF]" + + # load these cubes into the cache: + app._jdaviz_helper.cube_cache[ramp_cube_data_label] = NDDataArray(_swap_axes(data)) + app._jdaviz_helper.cube_cache[ramp_diff_data_label] = NDDataArray(_swap_axes(diff_data)) + + # load these cubes into the app: _parse_ndarray( app, file_obj=_swap_axes(data), data_label=ramp_cube_data_label, viewer_reference_name=group_viewer_reference_name, ) - - app._jdaviz_helper.cube_cache[ramp_cube_data_label] = NDDataArray(_swap_axes(data)) - - # load the diff of the data cube - ramp_diff_data_label = f"{data_label}[DIFF]" _parse_ndarray( app, file_obj=_swap_axes(diff_data), data_label=ramp_diff_data_label, viewer_reference_name=diff_viewer_reference_name, ) - app._jdaviz_helper.cube_cache[ramp_diff_data_label] = NDDataArray(_swap_axes(diff_data)) # the default collapse function in the profile viewer is "sum", # but for ramp files, "median" is more useful: diff --git a/jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.py b/jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.py index 84aba31a0b..8344af58d1 100644 --- a/jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.py +++ b/jdaviz/configs/rampviz/plugins/ramp_extraction/ramp_extraction.py @@ -4,6 +4,7 @@ from functools import cached_property from traitlets import Bool, Float, List, Unicode, observe +from glue.core.message import DataCollectionAddMessage, SubsetUpdateMessage from jdaviz.core.events import SnackbarMessage, SliceValueUpdatedMessage from jdaviz.core.marks import PluginLine @@ -102,6 +103,12 @@ def __init__(self, *args, **kwargs): self.session.hub.subscribe(self, SliceValueUpdatedMessage, handler=self._on_slice_changed) + self.session.hub.subscribe(self, DataCollectionAddMessage, + handler=self._on_data_added) + + self.session.hub.subscribe(self, SubsetUpdateMessage, + handler=self._on_subset_update) + self._update_disabled_msg() if self.app.state.settings.get('server_is_remote', False): @@ -109,6 +116,48 @@ def __init__(self, *args, **kwargs): # on the user's machine, so export support in cubeviz should be disabled self.export_enabled = False + @property + def integration_viewer(self): + viewer = self.app.get_viewer( + self.app._jdaviz_helper._default_integration_viewer_reference_name + ) + return viewer + + def _on_data_added(self, msg={}): + if msg.data.label.endswith('[DATA]'): + self.extract(add_data=True) + self.integration_viewer._initialize_x_axis() + + def _on_subset_update(self, msg={}): + + if not hasattr(self, 'aperture') or not hasattr(self.app._jdaviz_helper, 'cube_cache'): + return + + cube_cache = self.app._jdaviz_helper.cube_cache + cube = cube_cache[list(cube_cache.keys())[0]] + + subset_lbl = msg.subset.label + color = msg.subset.style.color + + subset = self.app.get_subsets(subset_lbl)[0] + region = subset['region'] + # glue region has transposed coords relative to cached cube: + region_mask = region.to_mask().to_image(cube.shape[:-1]).astype(bool).T + cube_subset = cube[region_mask] + + mark = [ + PluginLine(self.integration_viewer, x=np.arange(cube_subset.shape[1]), y=y, + stroke_width=1.5, colors=[color], opacities=[0.3], label=subset_lbl) + for y in cube_subset + ] + + self.integration_viewer.figure.marks = [ + mark for mark in self.integration_viewer.figure.marks + if getattr(mark, 'label', None) != subset_lbl + ] + mark + + self.integration_viewer.reset_limits() + @property def user_api(self): expose = ['dataset', 'function', 'aperture', @@ -158,31 +207,6 @@ def _active_step_changed(self, *args): def slice_plugin(self): return self.app._jdaviz_helper.plugins['Slice'] - @observe('aperture_items') - @skip_if_not_tray_instance() - def _aperture_items_changed(self, msg): - if not self.do_auto_extraction: - return - if not hasattr(self, 'aperture'): - return - for item in msg['new']: - if item not in msg['old']: - if item.get('type') != 'spatial': - continue - subset_lbl = item.get('label') - try: - self._extract_in_new_instance(subset_lbl=subset_lbl, - auto_update=True, add_data=True) - except Exception as err: - msg = SnackbarMessage( - f"Automatic {self.resulting_product_name} extraction for {subset_lbl} failed: {err}", # noqa - color='error', sender=self, timeout=10000) - else: - msg = SnackbarMessage( - f"Automatic {self.resulting_product_name} extraction for {subset_lbl} successful", # noqa - color='success', sender=self) - self.app.hub.broadcast(msg) - def _extract_in_new_instance(self, dataset=None, function='Mean', subset_lbl=None, auto_update=False, add_data=False): # create a new instance of the Ramp Extraction plugin (to not affect the instance in @@ -261,9 +285,13 @@ def _extract_from_aperture(self, **kwargs): collapsed = getattr(np, selected_func)( nddata.data, **collapse_kwargs ) << nddata.unit + + def expand(x): + return np.expand_dims(x, axis=(0, 1)) + return NDDataArray( - data=collapsed, - mask=mask.all(axis=self.spatial_axes), + data=expand(collapsed), + mask=expand(mask.all(axis=self.spatial_axes)), meta=nddata.meta ) diff --git a/jdaviz/configs/rampviz/plugins/tools.py b/jdaviz/configs/rampviz/plugins/tools.py new file mode 100644 index 0000000000..a50fe7ffbe --- /dev/null +++ b/jdaviz/configs/rampviz/plugins/tools.py @@ -0,0 +1,43 @@ +import os +import numpy as np +from glue.config import viewer_tool +from jdaviz.configs.default.plugins.tools import ProfileFromCube + +__all__ = ['RampPerPixel'] + +ICON_DIR = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'data', 'icons') + + +@viewer_tool +class RampPerPixel(ProfileFromCube): + + # TODO: replace "pixelspectra" graphic with a "pixelramp" equivalent + icon = os.path.join(ICON_DIR, 'pixelspectra.svg') + tool_id = 'jdaviz:rampperpixel' + action_text = 'See ramp at a single pixel' + tool_tip = ( + 'Click on the viewer and see the ramp profile ' + 'at that pixel in the integration viewer' + ) + + def on_mouse_move(self, data): + if data['event'] == 'mouseleave': + self._mark.visible = False + self._reset_profile_viewer_bounds() + return + + x = int(np.round(data['domain']['x'])) + y = int(np.round(data['domain']['y'])) + + cube_cache = self.viewer.jdaviz_app._jdaviz_helper.cube_cache + spectrum = cube_cache[list(cube_cache.keys())[0]].data + + if x >= spectrum.shape[0] or x < 0 or y >= spectrum.shape[1] or y < 0: + self._mark.visible = False + else: + y_values = spectrum[x, y, :] + if np.all(np.isnan(y_values)): + self._mark.visible = False + return + self._mark.update_xy(np.arange(y_values.size), y_values) + self._mark.visible = True diff --git a/jdaviz/configs/rampviz/plugins/viewers.py b/jdaviz/configs/rampviz/plugins/viewers.py index 43551fb238..109ab8ae67 100644 --- a/jdaviz/configs/rampviz/plugins/viewers.py +++ b/jdaviz/configs/rampviz/plugins/viewers.py @@ -1,3 +1,4 @@ +import numpy as np from astropy.nddata import NDDataArray from glue.core import BaseData from glue_jupyter.bqplot.image import BqplotImageView @@ -5,7 +6,7 @@ from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin, JdavizProfileView from jdaviz.configs.cubeviz.plugins.mixins import WithSliceSelection, WithSliceIndicator from jdaviz.core.registries import viewer_registry -from jdaviz.core.freezable_state import FreezableBqplotImageViewerState, FreezableProfileViewerState +from jdaviz.core.freezable_state import FreezableBqplotImageViewerState __all__ = ['RampvizProfileView', 'RampvizImageView'] @@ -23,13 +24,57 @@ class RampvizProfileView(JdavizProfileView, WithSliceIndicator): ] default_class = NDDataArray - _state_cls = FreezableProfileViewerState _default_profile_subset_type = 'temporal' def __init__(self, *args, **kwargs): kwargs.setdefault('default_tool_priority', ['jdaviz:selectslice']) super().__init__(*args, **kwargs) + def _initialize_x_axis(self): + self.state.x_att = self.state.x_att_helper.choices[-1] + self.set_plot_axes() + self.reset_limits() + + def reset_limits(self): + super().reset_limits() + + # override to reset to the global y limits including marks: + global_y_min = float(self.state.y_min) + global_y_max = float(self.state.y_max) + for mark in self.figure.marks: + if len(mark.y): + global_y_min = min(global_y_min, np.nanmin(mark.y)) + global_y_max = max(global_y_max, np.nanmax(mark.y)) + + if global_y_min != self.state.y_min or global_y_max != self.state.y_max: + self.set_limits( + y_min=global_y_min * 0.9, + y_max=global_y_max * 1.1 + ) + + def set_plot_axes(self): + + with self.figure.hold_sync(): + self.figure.axes[0].label = "Group" + self.figure.axes[1].label = self.state.y_display_unit + + # Make it so axis labels are not covering tick numbers. + self.figure.fig_margin["left"] = 95 + self.figure.fig_margin["bottom"] = 60 + self.figure.send_state('fig_margin') # Force update + self.figure.axes[0].label_offset = "40" + self.figure.axes[1].label_offset = "-70" + # NOTE: with tick_style changed below, the default responsive ticks in bqplot result + # in overlapping tick labels. For now we'll hardcode at 8, but this could be removed + # (default to None) if/when bqplot auto ticks react to styling options. + self.figure.axes[1].num_ticks = 8 + + # Set Y-axis to scientific notation + self.figure.axes[1].tick_format = '0.1e' + + for i in (0, 1): + self.figure.axes[i].tick_style = {'font-size': 15, 'font-weight': 600} + @viewer_registry("rampviz-image-viewer", label="Image 2D (Rampviz)") class RampvizImageView(JdavizViewerMixin, WithSliceSelection, BqplotImageView): @@ -42,6 +87,7 @@ class RampvizImageView(JdavizViewerMixin, WithSliceSelection, BqplotImageView): ['jdaviz:pixelpanzoommatch', 'jdaviz:panzoom'], ['bqplot:truecircle', 'bqplot:rectangle', 'bqplot:ellipse', 'bqplot:circannulus'], + ['jdaviz:rampperpixel'], ['jdaviz:sidebar_plot', 'jdaviz:sidebar_export'] ] diff --git a/jdaviz/core/marks.py b/jdaviz/core/marks.py index 6ba08b63af..4906143ef3 100644 --- a/jdaviz/core/marks.py +++ b/jdaviz/core/marks.py @@ -542,6 +542,7 @@ def __init__(self, viewer, x=[], y=[], **kwargs): self.viewer = viewer # color is same blue as import button kwargs.setdefault('colors', [accent_color]) + self.label = kwargs.get('label') super().__init__(x=x, y=y, scales=kwargs.pop('scales', viewer.scales), **kwargs)