From 7352138545543604256cac603814f965ac59037f Mon Sep 17 00:00:00 2001 From: "Brett M. Morris" Date: Fri, 4 Aug 2023 10:12:09 -0400 Subject: [PATCH] Roman L1 ramp cube parser for cubeviz --- jdaviz/app.py | 9 +- jdaviz/configs/cubeviz/plugins/parsers.py | 111 ++++++++++++++++++ jdaviz/configs/cubeviz/plugins/slice/slice.py | 18 ++- .../configs/cubeviz/plugins/slice/slice.vue | 2 +- .../plugins/model_fitting/model_fitting.py | 16 ++- jdaviz/configs/specviz/helper.py | 5 +- .../plugins/line_analysis/line_analysis.py | 4 +- jdaviz/configs/specviz/plugins/viewers.py | 46 ++++++-- jdaviz/core/template_mixin.py | 6 +- jdaviz/core/user_api.py | 3 +- 10 files changed, 197 insertions(+), 23 deletions(-) diff --git a/jdaviz/app.py b/jdaviz/app.py index 83e635f93d..f530efacdd 100644 --- a/jdaviz/app.py +++ b/jdaviz/app.py @@ -99,7 +99,11 @@ def to_unit(self, data, cid, values, original_units, target_units): # should return the converted values. Note that original_units # gives the units of the values array, which might not be the same # as the original native units of the component in the data. - if cid.label == "flux": + if cid.label == 'Pixel Axis 0 [z]' and target_units == '': + # handle ramps loaded into Cubeviz by avoiding conversion + # of the groups axis: + return values + elif cid.label == "flux": spec = data.get_object(cls=Spectrum1D) if len(values) == 2: # Need this for setting the y-limits @@ -1219,6 +1223,9 @@ def _get_display_unit(self, axis): elif axis == 'flux': sv = self.get_viewer(self._jdaviz_helper._default_spectrum_viewer_reference_name) return sv.data()[0].flux.unit + elif axis == 'data': + sv = self.get_viewer(self._jdaviz_helper._default_spectrum_viewer_reference_name) + return sv.data()[0].unit else: raise ValueError(f"could not find units for axis='{axis}'") try: diff --git a/jdaviz/configs/cubeviz/plugins/parsers.py b/jdaviz/configs/cubeviz/plugins/parsers.py index 54f8463074..25688269f5 100644 --- a/jdaviz/configs/cubeviz/plugins/parsers.py +++ b/jdaviz/configs/cubeviz/plugins/parsers.py @@ -13,12 +13,20 @@ from jdaviz.core.registries import data_parser_registry from jdaviz.utils import standardize_metadata, PRIHDR_KEY +try: + from roman_datamodels import datamodels as rdd +except ImportError: + HAS_ROMAN_DATAMODELS = False +else: + HAS_ROMAN_DATAMODELS = True + __all__ = ['parse_data'] EXT_TYPES = dict(flux=['flux', 'sci', 'data'], uncert=['ivar', 'err', 'var', 'uncert'], mask=['mask', 'dq', 'quality']) +cubeviz_ramp_meta_flag = '_roman_ramp' @data_parser_registry("cubeviz-data-parser") def parse_data(app, file_obj, data_type=None, data_label=None): @@ -64,6 +72,19 @@ def parse_data(app, file_obj, data_type=None, data_label=None): flux_viewer_reference_name=flux_viewer_reference_name, spectrum_viewer_reference_name=spectrum_viewer_reference_name) return + elif file_obj.lower().endswith('.asdf'): + if not HAS_ROMAN_DATAMODELS: + raise ImportError( + "ASDF detected but roman-datamodels is not installed." + ) + with rdd.open(file_obj) as pf: + _roman_3d_to_glue_data( + app, pf, data_label, + flux_viewer_reference_name=flux_viewer_reference_name, + spectrum_viewer_reference_name=spectrum_viewer_reference_name, + uncert_viewer_reference_name=uncert_viewer_reference_name + ) + return file_name = os.path.basename(file_obj) @@ -126,6 +147,14 @@ def parse_data(app, file_obj, data_type=None, data_label=None): flux_viewer_reference_name=flux_viewer_reference_name, spectrum_viewer_reference_name=spectrum_viewer_reference_name, uncert_viewer_reference_name=uncert_viewer_reference_name) + elif HAS_ROMAN_DATAMODELS and isinstance(file_obj, rdd.DataModel): + with rdd.open(file_obj) as pf: + _roman_3d_to_glue_data( + app, pf, data_label, + flux_viewer_reference_name=flux_viewer_reference_name, + spectrum_viewer_reference_name=spectrum_viewer_reference_name, + uncert_viewer_reference_name=uncert_viewer_reference_name + ) else: raise NotImplementedError(f'Unsupported data format: {file_obj}') @@ -447,3 +476,85 @@ def _get_data_type_by_hdu(hdu): else: data_type = '' return data_type + + +def _roman_3d_to_glue_data( + app, file_obj, data_label, + flux_viewer_reference_name=None, + spectrum_viewer_reference_name=None, + uncert_viewer_reference_name=None, +): + """ + Parse a Roman 3D ramp cube file (Level 1), + usually with suffix '_uncal.asdf'. + """ + def _swap_axes(x): + # swap axes per the conventions of Roman cubes + # (group axis comes first) and the default in + # Cubeviz (wavelength axis expected last) + return np.swapaxes(x, 0, -1) + + # update viewer reference names for Roman ramp cubes: + # app._update_viewer_reference_name() + + data = file_obj.data + + if data_label is None: + data_label = app.return_data_label(file_obj) + + # last axis is the group axis, first two are spatial axes: + diff_data = np.vstack([ + # begin with a group of zeros, so + # that `diff_data.ndim == data.ndim` + np.zeros((1, *data[0].shape)), + np.diff(data, axis=0) + ]) + + # load the `data` cube into what's usually the "flux-viewer" + _parse_ndarray( + app, + file_obj=_swap_axes(data), + data_label=f"{data_label}[DATA]", + data_type="flux", + flux_viewer_reference_name=flux_viewer_reference_name, + spectrum_viewer_reference_name=spectrum_viewer_reference_name + ) + + # load the diff of the data cube + # into what's usually the "uncert-viewer" + _parse_ndarray( + app, + file_obj=_swap_axes(diff_data), + data_type="uncert", + data_label=f"{data_label}[DIFF]", + uncert_viewer_reference_name=uncert_viewer_reference_name + ) + + # If the default Cubeviz viewers are still their defaults, rename them to + # names that are appropriate for the Roman ramp files that we just parsed: + if 'flux-viewer' in app.get_viewer_reference_names(): + app._update_viewer_reference_name('flux-viewer', 'group-viewer') + app._update_viewer_reference_name('uncert-viewer', 'group-diff-viewer') + app._update_viewer_reference_name('spectrum-viewer', 'integration-viewer') + + # the default collapse function in the profile viewer is "sum", + # but for ramp files, "median" is more useful: + viewer = app.get_viewer('integration-viewer') + viewer.state.function = 'median' + + # some Cubeviz plugins aren't relevant for ramps, so remove them: + remove_tray_items = [ + 'g-line-list', + 'specviz-line-analysis', + 'cubeviz-moment-maps', + 'g-gaussian-smooth' + ] + + for item_name in remove_tray_items: + item_names = [ + tray_item['name'] for tray_item in app.state.tray_items + ] + + app.state.tray_items.pop( + item_names.index(item_name) + ) diff --git a/jdaviz/configs/cubeviz/plugins/slice/slice.py b/jdaviz/configs/cubeviz/plugins/slice/slice.py index fd98a42d1b..964a3bfa63 100644 --- a/jdaviz/configs/cubeviz/plugins/slice/slice.py +++ b/jdaviz/configs/cubeviz/plugins/slice/slice.py @@ -15,6 +15,7 @@ from jdaviz.core.registries import tray_registry from jdaviz.core.template_mixin import PluginTemplateMixin from jdaviz.core.user_api import PluginUserApi +from jdaviz.configs.cubeviz.plugins.parsers import cubeviz_ramp_meta_flag __all__ = ['Slice'] @@ -111,9 +112,15 @@ def _watch_viewer(self, viewer, watch=True): self._viewer_slices_changed) self._watched_viewers.remove(viewer) elif isinstance(viewer, BqplotProfileView) and watch: + viewer_data = viewer.data() if self._x_all is None and len(viewer.data()): - # cache wavelengths so that wavelength <> slice conversion can be done efficiently - self._update_data(viewer.data()[0].spectral_axis) + if hasattr(viewer_data, 'spectral_axis'): + # cache wavelengths so that wavelength <> slice + # conversion can be done efficiently + self._update_data(viewer_data[0].spectral_axis) + else: + sample_index = np.arange(1, 1 + viewer_data[0].shape[-1]) * u.one + self._update_data(sample_index) if viewer not in self._indicator_viewers: self._indicator_viewers.append(viewer) @@ -134,7 +141,12 @@ def _on_data_added(self, msg): def _update_reference_data(self, reference_data): if reference_data is None: return # pragma: no cover - self._update_data(reference_data.get_object(cls=Spectrum1D).spectral_axis) + + if reference_data.meta.get(cubeviz_ramp_meta_flag, False): + sample_index = np.arange(1, 1 + reference_data['data'].shape[-1]) * u.one + self._update_data(sample_index) + else: + self._update_data(reference_data.get_object(cls=Spectrum1D).spectral_axis) def _update_data(self, x_all): self._x_all = x_all.value diff --git a/jdaviz/configs/cubeviz/plugins/slice/slice.vue b/jdaviz/configs/cubeviz/plugins/slice/slice.vue index e8bf78d4d9..3138c9272e 100644 --- a/jdaviz/configs/cubeviz/plugins/slice/slice.vue +++ b/jdaviz/configs/cubeviz/plugins/slice/slice.vue @@ -58,7 +58,7 @@ - +