From 125aa0913fc4ae38421c788878c384222c55e0ea Mon Sep 17 00:00:00 2001 From: "Brett M. Morris" Date: Mon, 29 Jul 2024 16:48:20 -0400 Subject: [PATCH] add base profile viewer class --- jdaviz/__init__.py | 9 +- jdaviz/configs/__init__.py | 7 +- jdaviz/configs/default/plugins/viewers.py | 423 ++++++++++++++++++++- jdaviz/configs/rampviz/__init__.py | 2 + jdaviz/configs/rampviz/plugins/__init__.py | 0 jdaviz/configs/rampviz/plugins/viewers.py | 23 ++ jdaviz/configs/specviz/plugins/viewers.py | 393 +------------------ jdaviz/core/helpers.py | 2 +- jdaviz/utils.py | 38 +- 9 files changed, 498 insertions(+), 399 deletions(-) create mode 100644 jdaviz/configs/rampviz/__init__.py create mode 100644 jdaviz/configs/rampviz/plugins/__init__.py create mode 100644 jdaviz/configs/rampviz/plugins/viewers.py diff --git a/jdaviz/__init__.py b/jdaviz/__init__.py index 8553c6671e..1da9ab765b 100644 --- a/jdaviz/__init__.py +++ b/jdaviz/__init__.py @@ -13,11 +13,14 @@ # Top-level API as exposed to users. from jdaviz.app import * # noqa: F401, F403 -from jdaviz.configs.specviz import Specviz # noqa: F401 -from jdaviz.configs.specviz2d import Specviz2d # noqa: F401 -from jdaviz.configs.mosviz import Mosviz # noqa: F401 + from jdaviz.configs.cubeviz import Cubeviz # noqa: F401 from jdaviz.configs.imviz import Imviz # noqa: F401 +from jdaviz.configs.mosviz import Mosviz # noqa: F401 +from jdaviz.configs.rampviz import Rampviz # noqa: F401 +from jdaviz.configs.specviz import Specviz # noqa: F401 +from jdaviz.configs.specviz2d import Specviz2d # noqa: F401 + from jdaviz.utils import enable_hot_reloading # noqa: F401 from jdaviz.core.launcher import open # noqa: F401 diff --git a/jdaviz/configs/__init__.py b/jdaviz/configs/__init__.py index eb078d8c5c..fca0324a7a 100644 --- a/jdaviz/configs/__init__.py +++ b/jdaviz/configs/__init__.py @@ -1,6 +1,7 @@ from .cubeviz import * # noqa -from .specviz import * # noqa -from .specviz2d import * # noqa from .default import * # noqa -from .mosviz import * # noqa from .imviz import * # noqa +from .mosviz import * # noqa +from .rampviz import * # noqa +from .specviz import * # noqa +from .specviz2d import * # noqa diff --git a/jdaviz/configs/default/plugins/viewers.py b/jdaviz/configs/default/plugins/viewers.py index 4f11fd511f..8d8298c9a6 100644 --- a/jdaviz/configs/default/plugins/viewers.py +++ b/jdaviz/configs/default/plugins/viewers.py @@ -1,20 +1,47 @@ -import numpy as np from echo import delay_callback +import numpy as np + +from glue.config import data_translator +from glue.core import BaseData +from glue.core.exceptions import IncompatibleAttribute +from glue.core.units import UnitConverter +from glue.core.subset import Subset +from glue.core.subset_group import GroupedSubset from glue.viewers.scatter.state import ScatterLayerState as BqplotScatterLayerState + +from glue_astronomy.spectral_coordinates import SpectralCoordinates from glue_jupyter.bqplot.profile import BqplotProfileView from glue_jupyter.bqplot.image import BqplotImageView from glue_jupyter.table import TableViewer +from astropy import units as u +from astropy.nddata import ( + NDDataArray, StdDevUncertainty, VarianceUncertainty, InverseVariance +) +from specutils import Spectrum1D + from jdaviz.components.toolbar_nested import NestedJupyterToolbar from jdaviz.core.astrowidgets_api import AstrowidgetsImageViewerMixin +from jdaviz.core.events import SnackbarMessage +from jdaviz.core.freezable_state import FreezableProfileViewerState +from jdaviz.core.marks import LineUncertainties, ScatterMask, OffscreenLinesMarks from jdaviz.core.registries import viewer_registry from jdaviz.core.template_mixin import WithCache from jdaviz.core.user_api import ViewerUserApi from jdaviz.utils import (ColorCycler, get_subset_type, _wcs_only_label, layer_is_image_data, layer_is_not_dq) -__all__ = ['JdavizViewerMixin'] +uc = UnitConverter() + +uncertainty_str_to_cls_mapping = { + "std": StdDevUncertainty, + "var": VarianceUncertainty, + "ivar": InverseVariance +} + + +__all__ = ['JdavizViewerMixin', 'JdavizProfileView'] viewer_registry.add("g-profile-viewer", label="Profile 1D", cls=BqplotProfileView) viewer_registry.add("g-image-viewer", label="Image 2D", cls=BqplotImageView) @@ -362,3 +389,395 @@ def _ref_or_id(self): def set_plot_axes(self): # individual viewers can override to set custom axes labels/ticks/styling return + + +@viewer_registry("jdaviz-profile-viewer", label="Profile 1D") +class JdavizProfileView(JdavizViewerMixin, BqplotProfileView): + # categories: zoom resets, zoom, pan, subset, select tools, shortcuts + tools_nested = [ + ['jdaviz:homezoom', 'jdaviz:prevzoom'], + ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'], + ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'], + ['bqplot:xrange'], + ['jdaviz:sidebar_plot', 'jdaviz:sidebar_export'] + ] + + default_class = NDDataArray + _state_cls = FreezableProfileViewerState + _default_profile_subset_type = None + + def __init__(self, *args, **kwargs): + default_tool_priority = kwargs.pop('default_tool_priority', []) + super().__init__(*args, **kwargs) + + self._subscribe_to_layers_update() + self.initialize_toolbar(default_tool_priority=default_tool_priority) + self._offscreen_lines_marks = OffscreenLinesMarks(self) + self.figure.marks = self.figure.marks + self._offscreen_lines_marks.marks + + self.state.add_callback('show_uncertainty', self._show_uncertainty_changed) + + self.display_mask = False + + # Change collapse function to sum + default_collapse_function = kwargs.pop('default_collapse_function', 'sum') + + self.state.function = default_collapse_function + + def _expected_subset_layer_default(self, layer_state): + super()._expected_subset_layer_default(layer_state) + + layer_state.linewidth = 3 + + def data(self, cls=None): + # Grab the user's chosen statistic for collapsing data + statistic = getattr(self.state, 'function', None) + data = [] + + for layer_state in self.state.layers: + if hasattr(layer_state, 'layer'): + lyr = layer_state.layer + + # For raw data, just include the data itself + if isinstance(lyr, BaseData): + _class = cls or self.default_class + + if _class is not None: + cache_key = (lyr.label, statistic) + if cache_key in self.jdaviz_app._get_object_cache: + layer_data = self.jdaviz_app._get_object_cache[cache_key] + else: + # If spectrum, collapse via the defined statistic + if _class == Spectrum1D: + layer_data = lyr.get_object(cls=_class, statistic=statistic) + else: + layer_data = lyr.get_object(cls=_class) + self.jdaviz_app._get_object_cache[cache_key] = layer_data + + data.append(layer_data) + + # For subsets, make sure to apply the subset mask to the layer data first + elif isinstance(lyr, Subset): + layer_data = lyr + + if _class is not None: + handler, _ = data_translator.get_handler_for(_class) + try: + layer_data = handler.to_object(layer_data, statistic=statistic) + except IncompatibleAttribute: + continue + data.append(layer_data) + + return data + + def get_scales(self): + fig = self.figure + # Deselect any pan/zoom or subsetting tools so they don't interfere + # with the scale retrieval + if self.toolbar.active_tool is not None: + self.toolbar.active_tool = None + return {'x': fig.interaction.x_scale, 'y': fig.interaction.y_scale} + + def _show_uncertainty_changed(self, msg=None): + # this is subscribed in init to watch for changes to the state + # object since uncertainty handling is in jdaviz instead of glue/glue-jupyter + if self.state.show_uncertainty: + self._plot_uncertainties() + else: + self._clean_error() + + def show_mask(self): + self.display_mask = True + self._plot_mask() + + def clean(self): + # Remove extra traces, in case they exist. + self.display_mask = False + self._clean_mask() + + # this will automatically call _clean_error via _show_uncertainty_changed + self.state.show_uncertainty = False + + def _clean_mask(self): + fig = self.figure + fig.marks = [x for x in fig.marks if not isinstance(x, ScatterMask)] + + def _clean_error(self): + fig = self.figure + fig.marks = [x for x in fig.marks if not isinstance(x, LineUncertainties)] + + def add_data(self, data, color=None, alpha=None, **layer_state): + """ + Overrides the base class to add markers for plotting + uncertainties and data quality flags. + + Parameters + ---------- + spectrum : :class:`glue.core.data.Data` + Data object with the spectrum. + color : obj + Color value for plotting. + alpha : float + Alpha value for plotting. + + Returns + ------- + result : bool + `True` if successful, `False` otherwise. + """ + # If this is the first loaded data, set things up for unit conversion. + if len(self.layers) == 0: + reset_plot_axes = True + else: + # Check if the new data flux unit is actually compatible since flux not linked. + try: + uc.to_unit(data, data.find_component_id("flux"), [1, 1], + u.Unit(self.state.y_display_unit)) # Error if incompatible + except Exception as err: + # Raising exception here introduces a dirty state that messes up next load_data + # but not raising exception also causes weird behavior unless we remove the data + # completely. + self.session.hub.broadcast(SnackbarMessage( + f"Failed to load {data.label}, so removed it: {repr(err)}", + sender=self, color='error')) + self.jdaviz_app.data_collection.remove(data) + return False + reset_plot_axes = False + + # The base class handles the plotting of the main + # trace representing the profile itself. + result = super().add_data(data, color, alpha, **layer_state) + + if reset_plot_axes: + x_units = data.get_component(self.state.x_att.label).units + y_units = data.get_component("flux").units + with delay_callback(self.state, "x_display_unit", "y_display_unit"): + self.state.x_display_unit = x_units if len(x_units) else None + self.state.y_display_unit = y_units if len(y_units) else None + self.set_plot_axes() + + self._plot_uncertainties() + + self._plot_mask() + + # Set default linewidth on any created spectral subset layers + # NOTE: this logic will need updating if we add support for multiple cubes as this assumes + # that new data entries (from model fitting or gaussian smooth, etc) will only be spectra + # and all subsets affected will be spectral + for layer in self.state.layers: + if (isinstance(layer.layer, GroupedSubset) + and get_subset_type(layer.layer) == self._default_profile_subset_type + and layer.layer.data.label == data.label): + layer.linewidth = 3 + + return result + + def _plot_mask(self): + if not self.display_mask: + return + + # Remove existing mask marks + self._clean_mask() + + # Loop through all active data in the viewer + for index, layer_state in enumerate(self.state.layers): + lyr = layer_state.layer + comps = [str(component) for component in lyr.components] + + # Skip subsets + if hasattr(lyr, "subset_state"): + continue + + # Ignore data that does not have a mask component + if "mask" in comps: + mask = np.array(lyr['mask'].data) + + data_obj = lyr.data.get_object(cls=self.default_class) + + if self.default_class == Spectrum1D: + data_x = data_obj.spectral_axis.value + data_y = data_obj.flux.value + else: + data_x = np.arange(data_obj.shape[-1]) + data_y = data_obj.data.value + + # For plotting markers only for the masked data + # points, erase un-masked data from trace. + y = np.where(np.asarray(mask) == 0, np.nan, data_y) + + # A subclass of the bqplot Scatter object, ScatterMask places + # 'X' marks where there is masked data in the viewer. + color = layer_state.color + alpha_shade = layer_state.alpha / 3 + mask_line_mark = ScatterMask(scales=self.scales, + marker='cross', + x=data_x, + y=y, + stroke_width=0.5, + colors=[color], + default_size=25, + default_opacities=[alpha_shade] + ) + # Add mask marks to viewer + self.figure.marks = list(self.figure.marks) + [mask_line_mark] + + def _plot_uncertainties(self): + if not self.state.show_uncertainty: + return + + # Remove existing error bars + self._clean_error() + + # Loop through all active data in the viewer + for index, layer_state in enumerate(self.state.layers): + lyr = layer_state.layer + + # Skip subsets + if hasattr(lyr, "subset_state"): + continue + + comps = [str(component) for component in lyr.components] + + # Ignore data that does not have an uncertainty component + if "uncertainty" in comps: # noqa + error = np.array(lyr['uncertainty'].data) + + # ensure that the uncertainties are represented as stddev: + uncertainty_type_str = lyr.meta.get('uncertainty_type', 'stddev') + uncert_cls = uncertainty_str_to_cls_mapping[uncertainty_type_str] + error = uncert_cls(error).represent_as(StdDevUncertainty).array + + # Then we assume that last axis is always wavelength. + # This may need adjustment after the following + # specutils PR is merged: https://github.com/astropy/specutils/pull/1033 + spectral_axis = -1 + data_obj = lyr.data.get_object(cls=self.default_class, statistic=None) + + if isinstance(lyr.data.coords, SpectralCoordinates): + spectral_wcs = lyr.data.coords + data_x = spectral_wcs.pixel_to_world_values( + np.arange(lyr.data.shape[spectral_axis]) + ) + if isinstance(data_x, tuple): + data_x = data_x[0] + else: + if hasattr(lyr.data.coords, 'spectral_wcs'): + spectral_wcs = lyr.data.coords.spectral_wcs + elif hasattr(lyr.data.coords, 'spectral'): + spectral_wcs = lyr.data.coords.spectral + data_x = spectral_wcs.pixel_to_world( + np.arange(lyr.data.shape[spectral_axis]) + ) + + data_y = data_obj.data + + # The shaded band around the spectrum trace is bounded by + # two lines, above and below the spectrum trace itself. + data_x_list = np.ndarray.tolist(data_x) + x = [data_x_list, data_x_list] + y = [np.ndarray.tolist(data_y - error), + np.ndarray.tolist(data_y + error)] + + if layer_state.as_steps: + for i in (0, 1): + a = np.insert(x[i], 0, 2*x[i][0] - x[i][1]) + b = np.append(x[i], 2*x[i][-1] - x[i][-2]) + edges = (a + b) / 2 + x[i] = np.concatenate((edges[:1], np.repeat(edges[1:-1], 2), edges[-1:])) + y[i] = np.repeat(y[i], 2) + x, y = np.asarray(x), np.asarray(y) + + # A subclass of the bqplot Lines object, LineUncertainties keeps + # track of uncertainties plotted in the viewer. LineUncertainties + # appear with two lines and shaded area in between. + color = layer_state.color + alpha_shade = layer_state.alpha / 3 + error_line_mark = LineUncertainties(viewer=self, + x=[x], + y=[y], + scales=self.scales, + stroke_width=1, + colors=[color, color], + fill_colors=[color, color], + opacities=[0.0, 0.0], + fill_opacities=[alpha_shade, + alpha_shade], + fill='between', + close_path=False + ) + + # Add error lines to viewer + self.figure.marks = list(self.figure.marks) + [error_line_mark] + + def set_plot_axes(self): + # Set y axes labels for the spectrum viewer + y_display_unit = self.state.y_display_unit + y_unit = u.Unit(y_display_unit) if y_display_unit else u.dimensionless_unscaled + + # Get local units. + locally_defined_flux_units = [ + u.Jy, u.mJy, u.uJy, u.MJy, + u.W / (u.m**2 * u.Hz), + u.eV / (u.s * u.m**2 * u.Hz), + u.erg / (u.s * u.cm**2), + u.erg / (u.s * u.cm**2 * u.Angstrom), + u.erg / (u.s * u.cm**2 * u.Hz), + u.ph / (u.s * u.cm**2 * u.Angstrom), + u.ph / (u.s * u.cm**2 * u.Hz), + u.bol, u.AB, u.ST + ] + + locally_defined_sb_units = [ + unit / u.sr for unit in locally_defined_flux_units + ] + + if any(y_unit.is_equivalent(unit) for unit in locally_defined_sb_units): + flux_unit_type = "Surface Brightness" + elif any(y_unit.is_equivalent(unit) for unit in locally_defined_flux_units): + flux_unit_type = 'Flux' + elif y_unit.is_equivalent(u.electron / u.s) or y_unit.physical_type == 'dimensionless': + # electron / s or 'dimensionless_unscaled' should be labeled counts + flux_unit_type = "Counts" + elif y_unit.is_equivalent(u.W): + flux_unit_type = "Luminosity" + else: + # default to Flux Density for flux density or uncaught types + flux_unit_type = "Flux density" + + # Set x axes labels for the spectrum viewer + x_disp_unit = self.state.x_display_unit + x_unit = u.Unit(x_disp_unit) if x_disp_unit else u.dimensionless_unscaled + + if self._state_cls == NDDataArray: + # enter this case for ramps: + spectral_axis_unit_type = "Sample" + self.state.x_display_unit = '' + elif x_unit.is_equivalent(u.m): + spectral_axis_unit_type = "Wavelength" + elif x_unit.is_equivalent(u.Hz): + spectral_axis_unit_type = "Frequency" + elif x_unit.is_equivalent(u.pixel): + spectral_axis_unit_type = "Pixel" + else: + spectral_axis_unit_type = str(x_unit.physical_type).title() + + with self.figure.hold_sync(): + self.figure.axes[0].label = f"{spectral_axis_unit_type} [{self.state.x_display_unit}]" + self.figure.axes[1].label = f"{flux_unit_type} [{self.state.y_display_unit}]" + + # Make it so axis labels are not covering tick numbers. + self.figure.fig_margin["left"] = 95 + self.figure.fig_margin["bottom"] = 60 + self.figure.send_state('fig_margin') # Force update + self.figure.axes[0].label_offset = "40" + self.figure.axes[1].label_offset = "-70" + # NOTE: with tick_style changed below, the default responsive ticks in bqplot result + # in overlapping tick labels. For now we'll hardcode at 8, but this could be removed + # (default to None) if/when bqplot auto ticks react to styling options. + self.figure.axes[1].num_ticks = 8 + + # Set Y-axis to scientific notation + self.figure.axes[1].tick_format = '0.1e' + + for i in (0, 1): + self.figure.axes[i].tick_style = {'font-size': 15, 'font-weight': 600} diff --git a/jdaviz/configs/rampviz/__init__.py b/jdaviz/configs/rampviz/__init__.py new file mode 100644 index 0000000000..af257ac8cc --- /dev/null +++ b/jdaviz/configs/rampviz/__init__.py @@ -0,0 +1,2 @@ +from .plugins import * # noqa +from .helper import Rampviz # noqa diff --git a/jdaviz/configs/rampviz/plugins/__init__.py b/jdaviz/configs/rampviz/plugins/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jdaviz/configs/rampviz/plugins/viewers.py b/jdaviz/configs/rampviz/plugins/viewers.py new file mode 100644 index 0000000000..659ef36727 --- /dev/null +++ b/jdaviz/configs/rampviz/plugins/viewers.py @@ -0,0 +1,23 @@ +from astropy.nddata import NDDataArray + +from jdaviz.core.registries import viewer_registry +from jdaviz.core.freezable_state import FreezableProfileViewerState +from jdaviz.configs.default.plugins.viewers import JdavizProfileView + +__all__ = ['RampvizProfileView'] + + +@viewer_registry("rampviz-profile-viewer", label="Profile 1D (Rampviz)") +class RampvizProfileView(JdavizProfileView): + # categories: zoom resets, zoom, pan, subset, select tools, shortcuts + tools_nested = [ + ['jdaviz:homezoom', 'jdaviz:prevzoom'], + ['jdaviz:boxzoom', 'jdaviz:xrangezoom', 'jdaviz:yrangezoom'], + ['jdaviz:panzoom', 'jdaviz:panzoom_x', 'jdaviz:panzoom_y'], + ['bqplot:xrange'], + ['jdaviz:sidebar_plot', 'jdaviz:sidebar_export'] + ] + + default_class = NDDataArray + _state_cls = FreezableProfileViewerState + _default_profile_subset_type = 'temporal' diff --git a/jdaviz/configs/specviz/plugins/viewers.py b/jdaviz/configs/specviz/plugins/viewers.py index a228c2b026..bde414f458 100644 --- a/jdaviz/configs/specviz/plugins/viewers.py +++ b/jdaviz/configs/specviz/plugins/viewers.py @@ -2,41 +2,21 @@ import numpy as np from astropy import table -from astropy import units as u -from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance -from echo import delay_callback -from glue.config import data_translator -from glue.core import BaseData -from glue.core.exceptions import IncompatibleAttribute -from glue.core.units import UnitConverter -from glue.core.subset import Subset -from glue.core.subset_group import GroupedSubset -from glue_astronomy.spectral_coordinates import SpectralCoordinates -from glue_jupyter.bqplot.profile import BqplotProfileView from matplotlib.colors import cnames from specutils import Spectrum1D -from jdaviz.core.events import SpectralMarksChangedMessage, LineIdentifyMessage, SnackbarMessage +from jdaviz.core.events import SpectralMarksChangedMessage, LineIdentifyMessage from jdaviz.core.registries import viewer_registry -from jdaviz.core.marks import SpectralLine, LineUncertainties, ScatterMask, OffscreenLinesMarks +from jdaviz.core.marks import SpectralLine from jdaviz.core.linelists import load_preset_linelist, get_available_linelists from jdaviz.core.freezable_state import FreezableProfileViewerState -from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin -from jdaviz.utils import get_subset_type +from jdaviz.configs.default.plugins.viewers import JdavizProfileView __all__ = ['SpecvizProfileView'] -uc = UnitConverter() - -uncertainty_str_to_cls_mapping = { - "std": StdDevUncertainty, - "var": VarianceUncertainty, - "ivar": InverseVariance -} - @viewer_registry("specviz-profile-viewer", label="Profile 1D (Specviz)") -class SpecvizProfileView(JdavizViewerMixin, BqplotProfileView): +class SpecvizProfileView(JdavizProfileView): # categories: zoom resets, zoom, pan, subset, select tools, shortcuts tools_nested = [ ['jdaviz:homezoom', 'jdaviz:prevzoom'], @@ -50,68 +30,7 @@ class SpecvizProfileView(JdavizViewerMixin, BqplotProfileView): default_class = Spectrum1D spectral_lines = None _state_cls = FreezableProfileViewerState - - def __init__(self, *args, **kwargs): - default_tool_priority = kwargs.pop('default_tool_priority', []) - super().__init__(*args, **kwargs) - - self._subscribe_to_layers_update() - self.initialize_toolbar(default_tool_priority=default_tool_priority) - self._offscreen_lines_marks = OffscreenLinesMarks(self) - self.figure.marks = self.figure.marks + self._offscreen_lines_marks.marks - - self.state.add_callback('show_uncertainty', self._show_uncertainty_changed) - - self.display_mask = False - - # Change collapse function to sum - self.state.function = 'sum' - - def _expected_subset_layer_default(self, layer_state): - super()._expected_subset_layer_default(layer_state) - - layer_state.linewidth = 3 - - def data(self, cls=None): - # Grab the user's chosen statistic for collapsing data - statistic = getattr(self.state, 'function', None) - data = [] - - for layer_state in self.state.layers: - if hasattr(layer_state, 'layer'): - lyr = layer_state.layer - - # For raw data, just include the data itself - if isinstance(lyr, BaseData): - _class = cls or self.default_class - - if _class is not None: - cache_key = (lyr.label, statistic) - if cache_key in self.jdaviz_app._get_object_cache: - layer_data = self.jdaviz_app._get_object_cache[cache_key] - else: - # If spectrum, collapse via the defined statistic - if _class == Spectrum1D: - layer_data = lyr.get_object(cls=_class, statistic=statistic) - else: - layer_data = lyr.get_object(cls=_class) - self.jdaviz_app._get_object_cache[cache_key] = layer_data - - data.append(layer_data) - - # For subsets, make sure to apply the subset mask to the layer data first - elif isinstance(lyr, Subset): - layer_data = lyr - - if _class is not None: - handler, _ = data_translator.get_handler_for(_class) - try: - layer_data = handler.to_object(layer_data, statistic=statistic) - except IncompatibleAttribute: - continue - data.append(layer_data) - - return data + _default_profile_subset_type = 'spectral' @property def redshift(self): @@ -253,14 +172,6 @@ def erase_spectral_lines(self, name=None, name_rest=None, show_none=True): fig.marks = temp_marks self._broadcast_plotted_lines() - def get_scales(self): - fig = self.figure - # Deselect any pan/zoom or subsetting tools so they don't interfere - # with the scale retrieval - if self.toolbar.active_tool is not None: - self.toolbar.active_tool = None - return {'x': fig.interaction.x_scale, 'y': fig.interaction.y_scale} - def plot_spectral_line(self, line, global_redshift=None, plot_units=None, **kwargs): if isinstance(line, str): # Try the full index first (for backend calls), otherwise name only @@ -327,297 +238,3 @@ def plot_spectral_lines(self, colors=["blue"], global_redshift=None, **kwargs): def available_linelists(self): return get_available_linelists() - - def _show_uncertainty_changed(self, msg=None): - # this is subscribed in init to watch for changes to the state - # object since uncertainty handling is in jdaviz instead of glue/glue-jupyter - if self.state.show_uncertainty: - self._plot_uncertainties() - else: - self._clean_error() - - def show_mask(self): - self.display_mask = True - self._plot_mask() - - def clean(self): - # Remove extra traces, in case they exist. - self.display_mask = False - self._clean_mask() - - # this will automatically call _clean_error via _show_uncertainty_changed - self.state.show_uncertainty = False - - def _clean_mask(self): - fig = self.figure - fig.marks = [x for x in fig.marks if not isinstance(x, ScatterMask)] - - def _clean_error(self): - fig = self.figure - fig.marks = [x for x in fig.marks if not isinstance(x, LineUncertainties)] - - def add_data(self, data, color=None, alpha=None, **layer_state): - """ - Overrides the base class to add markers for plotting - uncertainties and data quality flags. - - Parameters - ---------- - spectrum : :class:`glue.core.data.Data` - Data object with the spectrum. - color : obj - Color value for plotting. - alpha : float - Alpha value for plotting. - - Returns - ------- - result : bool - `True` if successful, `False` otherwise. - """ - # If this is the first loaded data, set things up for unit conversion. - if len(self.layers) == 0: - reset_plot_axes = True - else: - # Check if the new data flux unit is actually compatible since flux not linked. - try: - uc.to_unit(data, data.find_component_id("flux"), [1, 1], - u.Unit(self.state.y_display_unit)) # Error if incompatible - except Exception as err: - # Raising exception here introduces a dirty state that messes up next load_data - # but not raising exception also causes weird behavior unless we remove the data - # completely. - self.session.hub.broadcast(SnackbarMessage( - f"Failed to load {data.label}, so removed it: {repr(err)}", - sender=self, color='error')) - self.jdaviz_app.data_collection.remove(data) - return False - reset_plot_axes = False - - # The base class handles the plotting of the main - # trace representing the spectrum itself. - result = super().add_data(data, color, alpha, **layer_state) - - if reset_plot_axes: - x_units = data.get_component(self.state.x_att.label).units - y_units = data.get_component("flux").units - with delay_callback(self.state, "x_display_unit", "y_display_unit"): - self.state.x_display_unit = x_units if len(x_units) else None - self.state.y_display_unit = y_units if len(y_units) else None - self.set_plot_axes() - - self._plot_uncertainties() - - self._plot_mask() - - # Set default linewidth on any created spectral subset layers - # NOTE: this logic will need updating if we add support for multiple cubes as this assumes - # that new data entries (from model fitting or gaussian smooth, etc) will only be spectra - # and all subsets affected will be spectral - for layer in self.state.layers: - if (isinstance(layer.layer, GroupedSubset) - and get_subset_type(layer.layer) == 'spectral' - and layer.layer.data.label == data.label): - layer.linewidth = 3 - - return result - - def _plot_mask(self): - if not self.display_mask: - return - - # Remove existing mask marks - self._clean_mask() - - # Loop through all active data in the viewer - for index, layer_state in enumerate(self.state.layers): - lyr = layer_state.layer - comps = [str(component) for component in lyr.components] - - # Skip subsets - if hasattr(lyr, "subset_state"): - continue - - # Ignore data that does not have a mask component - if "mask" in comps: - mask = np.array(lyr['mask'].data) - - data_obj = lyr.data.get_object() - data_x = data_obj.spectral_axis.value - data_y = data_obj.flux.value - - # For plotting markers only for the masked data - # points, erase un-masked data from trace. - y = np.where(np.asarray(mask) == 0, np.nan, data_y) - - # A subclass of the bqplot Scatter object, ScatterMask places - # 'X' marks where there is masked data in the viewer. - color = layer_state.color - alpha_shade = layer_state.alpha / 3 - mask_line_mark = ScatterMask(scales=self.scales, - marker='cross', - x=data_x, - y=y, - stroke_width=0.5, - colors=[color], - default_size=25, - default_opacities=[alpha_shade] - ) - # Add mask marks to viewer - self.figure.marks = list(self.figure.marks) + [mask_line_mark] - - def _plot_uncertainties(self): - if not self.state.show_uncertainty: - return - - # Remove existing error bars - self._clean_error() - - # Loop through all active data in the viewer - for index, layer_state in enumerate(self.state.layers): - lyr = layer_state.layer - - # Skip subsets - if hasattr(lyr, "subset_state"): - continue - - comps = [str(component) for component in lyr.components] - - # Ignore data that does not have an uncertainty component - if "uncertainty" in comps: # noqa - error = np.array(lyr['uncertainty'].data) - - # ensure that the uncertainties are represented as stddev: - uncertainty_type_str = lyr.meta.get('uncertainty_type', 'stddev') - uncert_cls = uncertainty_str_to_cls_mapping[uncertainty_type_str] - error = uncert_cls(error).represent_as(StdDevUncertainty).array - - # Then we assume that last axis is always wavelength. - # This may need adjustment after the following - # specutils PR is merged: https://github.com/astropy/specutils/pull/1033 - spectral_axis = -1 - data_obj = lyr.data.get_object(cls=Spectrum1D, statistic=None) - - if isinstance(lyr.data.coords, SpectralCoordinates): - spectral_wcs = lyr.data.coords - data_x = spectral_wcs.pixel_to_world_values( - np.arange(lyr.data.shape[spectral_axis]) - ) - if isinstance(data_x, tuple): - data_x = data_x[0] - else: - if hasattr(lyr.data.coords, 'spectral_wcs'): - spectral_wcs = lyr.data.coords.spectral_wcs - elif hasattr(lyr.data.coords, 'spectral'): - spectral_wcs = lyr.data.coords.spectral - data_x = spectral_wcs.pixel_to_world( - np.arange(lyr.data.shape[spectral_axis]) - ) - - data_y = data_obj.data - - # The shaded band around the spectrum trace is bounded by - # two lines, above and below the spectrum trace itself. - data_x_list = np.ndarray.tolist(data_x) - x = [data_x_list, data_x_list] - y = [np.ndarray.tolist(data_y - error), - np.ndarray.tolist(data_y + error)] - - if layer_state.as_steps: - for i in (0, 1): - a = np.insert(x[i], 0, 2*x[i][0] - x[i][1]) - b = np.append(x[i], 2*x[i][-1] - x[i][-2]) - edges = (a + b) / 2 - x[i] = np.concatenate((edges[:1], np.repeat(edges[1:-1], 2), edges[-1:])) - y[i] = np.repeat(y[i], 2) - x, y = np.asarray(x), np.asarray(y) - - # A subclass of the bqplot Lines object, LineUncertainties keeps - # track of uncertainties plotted in the viewer. LineUncertainties - # appear with two lines and shaded area in between. - color = layer_state.color - alpha_shade = layer_state.alpha / 3 - error_line_mark = LineUncertainties(viewer=self, - x=[x], - y=[y], - scales=self.scales, - stroke_width=1, - colors=[color, color], - fill_colors=[color, color], - opacities=[0.0, 0.0], - fill_opacities=[alpha_shade, - alpha_shade], - fill='between', - close_path=False - ) - - # Add error lines to viewer - self.figure.marks = list(self.figure.marks) + [error_line_mark] - - def set_plot_axes(self): - # Set y axes labels for the spectrum viewer - y_display_unit = self.state.y_display_unit - y_unit = u.Unit(y_display_unit) if y_display_unit else u.dimensionless_unscaled - - # Get local units. - locally_defined_flux_units = [ - u.Jy, u.mJy, u.uJy, u.MJy, - u.W / (u.m**2 * u.Hz), - u.eV / (u.s * u.m**2 * u.Hz), - u.erg / (u.s * u.cm**2), - u.erg / (u.s * u.cm**2 * u.Angstrom), - u.erg / (u.s * u.cm**2 * u.Hz), - u.ph / (u.s * u.cm**2 * u.Angstrom), - u.ph / (u.s * u.cm**2 * u.Hz), - u.bol, u.AB, u.ST - ] - - locally_defined_sb_units = [ - unit / u.sr for unit in locally_defined_flux_units - ] - - if any(y_unit.is_equivalent(unit) for unit in locally_defined_sb_units): - flux_unit_type = "Surface Brightness" - elif any(y_unit.is_equivalent(unit) for unit in locally_defined_flux_units): - flux_unit_type = 'Flux' - elif y_unit.is_equivalent(u.electron / u.s) or y_unit.physical_type == 'dimensionless': - # electron / s or 'dimensionless_unscaled' should be labeled counts - flux_unit_type = "Counts" - elif y_unit.is_equivalent(u.W): - flux_unit_type = "Luminosity" - else: - # default to Flux Density for flux density or uncaught types - flux_unit_type = "Flux density" - - # Set x axes labels for the spectrum viewer - x_disp_unit = self.state.x_display_unit - x_unit = u.Unit(x_disp_unit) if x_disp_unit else u.dimensionless_unscaled - if x_unit.is_equivalent(u.m): - spectral_axis_unit_type = "Wavelength" - elif x_unit.is_equivalent(u.Hz): - spectral_axis_unit_type = "Frequency" - elif x_unit.is_equivalent(u.pixel): - spectral_axis_unit_type = "Pixel" - else: - spectral_axis_unit_type = str(x_unit.physical_type).title() - - with self.figure.hold_sync(): - self.figure.axes[0].label = f"{spectral_axis_unit_type} [{self.state.x_display_unit}]" - self.figure.axes[1].label = f"{flux_unit_type} [{self.state.y_display_unit}]" - - # Make it so axis labels are not covering tick numbers. - self.figure.fig_margin["left"] = 95 - self.figure.fig_margin["bottom"] = 60 - self.figure.send_state('fig_margin') # Force update - self.figure.axes[0].label_offset = "40" - self.figure.axes[1].label_offset = "-70" - # NOTE: with tick_style changed below, the default responsive ticks in bqplot result - # in overlapping tick labels. For now we'll hardcode at 8, but this could be removed - # (default to None) if/when bqplot auto ticks react to styling options. - self.figure.axes[1].num_ticks = 8 - - # Set Y-axis to scientific notation - self.figure.axes[1].tick_format = '0.1e' - - for i in (0, 1): - self.figure.axes[i].tick_style = {'font-size': 15, 'font-weight': 600} diff --git a/jdaviz/core/helpers.py b/jdaviz/core/helpers.py index cebc6762cc..21584fcca0 100644 --- a/jdaviz/core/helpers.py +++ b/jdaviz/core/helpers.py @@ -32,7 +32,7 @@ from jdaviz.utils import data_has_valid_wcs, flux_conversion, spectral_axis_conversion -__all__ = ['ConfigHelper', 'ImageConfigHelper'] +__all__ = ['ConfigHelper', 'ImageConfigHelper', 'CubeConfigHelper'] class ConfigHelper(HubListener): diff --git a/jdaviz/utils.py b/jdaviz/utils.py index e381aea407..14d2fc0fc7 100644 --- a/jdaviz/utils.py +++ b/jdaviz/utils.py @@ -19,6 +19,7 @@ from glue.core import BaseData from glue.core.exceptions import IncompatibleAttribute from glue.core.subset import SubsetState, RangeSubsetState, RoiSubsetState +from glue_astronomy.spectral_coordinates import SpectralCoordinates from ipyvue import watch __all__ = ['SnackbarQueue', 'enable_hot_reloading', 'bqplot_clear_figure', @@ -422,7 +423,7 @@ def get_subset_type(subset): Returns ------- subset_type : str or None - 'spatial', 'spectral', or None + 'spatial', 'spectral', 'temporal', or None """ if not hasattr(subset, 'subset_state'): return None @@ -435,7 +436,40 @@ def get_subset_type(subset): if isinstance(subset.subset_state, RoiSubsetState): return 'spatial' elif isinstance(subset.subset_state, RangeSubsetState): - return 'spectral' + # look within a SubsetGroup, or a single Subset + subset_list = getattr(subset, 'subsets', [subset]) + + for ss in subset_list: + if hasattr(ss, 'data'): + ss_data = ss.data + elif hasattr(ss.att, 'parent'): + # if `ss` is a subset state, it won't have a `data` attr, + # check the world coordinate's parent data: + ss_data = ss.att.parent + else: + # if we reach this `else`, continue searching + # through other subsets in the group to identify the + # subset type: + continue + + # check for a spectral coordinate in FITS WCS: + wcs_coords = ( + ss_data.coords.wcs.ctype if hasattr(ss_data.coords, 'wcs') + else [] + ) + + has_spectral_coords = ( + any(str(coord).startswith('WAVE') for coord in wcs_coords) or + + # also check for a spectral coordinate from the glue_astronomy translator: + isinstance(ss_data.coords, SpectralCoordinates) + ) + + if has_spectral_coords: + return 'spectral' + + # otherwise, assume temporal: + return 'temporal' else: return None