Skip to content

Commit

Permalink
Roman L1 ramp cube parser for cubeviz
Browse files Browse the repository at this point in the history
  • Loading branch information
bmorris3 committed Aug 22, 2023
1 parent d067429 commit 7352138
Show file tree
Hide file tree
Showing 10 changed files with 197 additions and 23 deletions.
9 changes: 8 additions & 1 deletion jdaviz/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ def to_unit(self, data, cid, values, original_units, target_units):
# should return the converted values. Note that original_units
# gives the units of the values array, which might not be the same
# as the original native units of the component in the data.
if cid.label == "flux":
if cid.label == 'Pixel Axis 0 [z]' and target_units == '':
# handle ramps loaded into Cubeviz by avoiding conversion
# of the groups axis:
return values
elif cid.label == "flux":
spec = data.get_object(cls=Spectrum1D)
if len(values) == 2:
# Need this for setting the y-limits
Expand Down Expand Up @@ -1219,6 +1223,9 @@ def _get_display_unit(self, axis):
elif axis == 'flux':
sv = self.get_viewer(self._jdaviz_helper._default_spectrum_viewer_reference_name)
return sv.data()[0].flux.unit
elif axis == 'data':
sv = self.get_viewer(self._jdaviz_helper._default_spectrum_viewer_reference_name)
return sv.data()[0].unit
else:
raise ValueError(f"could not find units for axis='{axis}'")
try:
Expand Down
111 changes: 111 additions & 0 deletions jdaviz/configs/cubeviz/plugins/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@
from jdaviz.core.registries import data_parser_registry
from jdaviz.utils import standardize_metadata, PRIHDR_KEY

try:
from roman_datamodels import datamodels as rdd
except ImportError:
HAS_ROMAN_DATAMODELS = False
else:
HAS_ROMAN_DATAMODELS = True

__all__ = ['parse_data']

EXT_TYPES = dict(flux=['flux', 'sci', 'data'],
uncert=['ivar', 'err', 'var', 'uncert'],
mask=['mask', 'dq', 'quality'])

cubeviz_ramp_meta_flag = '_roman_ramp'

@data_parser_registry("cubeviz-data-parser")
def parse_data(app, file_obj, data_type=None, data_label=None):
Expand Down Expand Up @@ -64,6 +72,19 @@ def parse_data(app, file_obj, data_type=None, data_label=None):
flux_viewer_reference_name=flux_viewer_reference_name,
spectrum_viewer_reference_name=spectrum_viewer_reference_name)
return
elif file_obj.lower().endswith('.asdf'):
if not HAS_ROMAN_DATAMODELS:
raise ImportError(
"ASDF detected but roman-datamodels is not installed."
)
with rdd.open(file_obj) as pf:
_roman_3d_to_glue_data(
app, pf, data_label,
flux_viewer_reference_name=flux_viewer_reference_name,
spectrum_viewer_reference_name=spectrum_viewer_reference_name,
uncert_viewer_reference_name=uncert_viewer_reference_name
)
return

file_name = os.path.basename(file_obj)

Expand Down Expand Up @@ -126,6 +147,14 @@ def parse_data(app, file_obj, data_type=None, data_label=None):
flux_viewer_reference_name=flux_viewer_reference_name,
spectrum_viewer_reference_name=spectrum_viewer_reference_name,
uncert_viewer_reference_name=uncert_viewer_reference_name)
elif HAS_ROMAN_DATAMODELS and isinstance(file_obj, rdd.DataModel):
with rdd.open(file_obj) as pf:
_roman_3d_to_glue_data(
app, pf, data_label,
flux_viewer_reference_name=flux_viewer_reference_name,
spectrum_viewer_reference_name=spectrum_viewer_reference_name,
uncert_viewer_reference_name=uncert_viewer_reference_name
)
else:
raise NotImplementedError(f'Unsupported data format: {file_obj}')

Expand Down Expand Up @@ -447,3 +476,85 @@ def _get_data_type_by_hdu(hdu):
else:
data_type = ''
return data_type


def _roman_3d_to_glue_data(
app, file_obj, data_label,
flux_viewer_reference_name=None,
spectrum_viewer_reference_name=None,
uncert_viewer_reference_name=None,
):
"""
Parse a Roman 3D ramp cube file (Level 1),
usually with suffix '_uncal.asdf'.
"""
def _swap_axes(x):
# swap axes per the conventions of Roman cubes
# (group axis comes first) and the default in
# Cubeviz (wavelength axis expected last)
return np.swapaxes(x, 0, -1)

# update viewer reference names for Roman ramp cubes:
# app._update_viewer_reference_name()

data = file_obj.data

if data_label is None:
data_label = app.return_data_label(file_obj)

# last axis is the group axis, first two are spatial axes:
diff_data = np.vstack([
# begin with a group of zeros, so
# that `diff_data.ndim == data.ndim`
np.zeros((1, *data[0].shape)),
np.diff(data, axis=0)
])

# load the `data` cube into what's usually the "flux-viewer"
_parse_ndarray(
app,
file_obj=_swap_axes(data),
data_label=f"{data_label}[DATA]",
data_type="flux",
flux_viewer_reference_name=flux_viewer_reference_name,
spectrum_viewer_reference_name=spectrum_viewer_reference_name
)

# load the diff of the data cube
# into what's usually the "uncert-viewer"
_parse_ndarray(
app,
file_obj=_swap_axes(diff_data),
data_type="uncert",
data_label=f"{data_label}[DIFF]",
uncert_viewer_reference_name=uncert_viewer_reference_name
)

# If the default Cubeviz viewers are still their defaults, rename them to
# names that are appropriate for the Roman ramp files that we just parsed:
if 'flux-viewer' in app.get_viewer_reference_names():
app._update_viewer_reference_name('flux-viewer', 'group-viewer')
app._update_viewer_reference_name('uncert-viewer', 'group-diff-viewer')
app._update_viewer_reference_name('spectrum-viewer', 'integration-viewer')

# the default collapse function in the profile viewer is "sum",
# but for ramp files, "median" is more useful:
viewer = app.get_viewer('integration-viewer')
viewer.state.function = 'median'

# some Cubeviz plugins aren't relevant for ramps, so remove them:
remove_tray_items = [
'g-line-list',
'specviz-line-analysis',
'cubeviz-moment-maps',
'g-gaussian-smooth'
]

for item_name in remove_tray_items:
item_names = [
tray_item['name'] for tray_item in app.state.tray_items
]

app.state.tray_items.pop(
item_names.index(item_name)
)
18 changes: 15 additions & 3 deletions jdaviz/configs/cubeviz/plugins/slice/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import PluginTemplateMixin
from jdaviz.core.user_api import PluginUserApi
from jdaviz.configs.cubeviz.plugins.parsers import cubeviz_ramp_meta_flag

__all__ = ['Slice']

Expand Down Expand Up @@ -111,9 +112,15 @@ def _watch_viewer(self, viewer, watch=True):
self._viewer_slices_changed)
self._watched_viewers.remove(viewer)
elif isinstance(viewer, BqplotProfileView) and watch:
viewer_data = viewer.data()
if self._x_all is None and len(viewer.data()):
# cache wavelengths so that wavelength <> slice conversion can be done efficiently
self._update_data(viewer.data()[0].spectral_axis)
if hasattr(viewer_data, 'spectral_axis'):
# cache wavelengths so that wavelength <> slice
# conversion can be done efficiently
self._update_data(viewer_data[0].spectral_axis)
else:
sample_index = np.arange(1, 1 + viewer_data[0].shape[-1]) * u.one
self._update_data(sample_index)

if viewer not in self._indicator_viewers:
self._indicator_viewers.append(viewer)
Expand All @@ -134,7 +141,12 @@ def _on_data_added(self, msg):
def _update_reference_data(self, reference_data):
if reference_data is None:
return # pragma: no cover
self._update_data(reference_data.get_object(cls=Spectrum1D).spectral_axis)

if reference_data.meta.get(cubeviz_ramp_meta_flag, False):
sample_index = np.arange(1, 1 + reference_data['data'].shape[-1]) * u.one
self._update_data(sample_index)
else:
self._update_data(reference_data.get_object(cls=Spectrum1D).spectral_axis)

def _update_data(self, x_all):
self._x_all = x_all.value
Expand Down
2 changes: 1 addition & 1 deletion jdaviz/configs/cubeviz/plugins/slice/slice.vue
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
</v-col>
</v-row>

<v-row class="row-no-outside-padding">
<v-row class="row-no-outside-padding" v-if="wavelength_unit !== 'pix'">
<v-col>
<v-text-field
v-model="wavelength"
Expand Down
16 changes: 15 additions & 1 deletion jdaviz/configs/default/plugins/model_fitting/model_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from traitlets import Bool, List, Unicode, observe
from glue.core.data import Data

from jdaviz.core.events import SnackbarMessage, GlobalDisplayUnitChanged
from jdaviz.core.events import (
SnackbarMessage, GlobalDisplayUnitChanged, ViewerRenamedMessage
)
from jdaviz.core.registries import tray_registry
from jdaviz.core.template_mixin import (PluginTemplateMixin,
SelectPluginComponent,
Expand All @@ -25,6 +27,7 @@
from jdaviz.configs.default.plugins.model_fitting.initializers import (MODELS,
initialize,
get_model_parameters)
from jdaviz.configs.cubeviz.plugins.parsers import cubeviz_ramp_meta_flag

__all__ = ['ModelFitting']

Expand Down Expand Up @@ -167,6 +170,9 @@ def __init__(self, *args, **kwargs):
self.hub.subscribe(self, GlobalDisplayUnitChanged,
handler=self._on_global_display_unit_changed)

self.hub.subscribe(self, ViewerRenamedMessage,
handler=self._on_viewer_renamed)

@property
def _default_spectrum_viewer_reference_name(self):
return getattr(
Expand Down Expand Up @@ -333,6 +339,10 @@ def _dataset_selected_changed(self, event=None):
# during initial init, this can trigger before the component is initialized
return

if self.dataset.selected_obj:
if self.dataset.selected_obj.meta.get(cubeviz_ramp_meta_flag, False):
return

selected_spec = self._get_1d_spectrum()
if selected_spec is None:
return
Expand Down Expand Up @@ -999,3 +1009,7 @@ def _apply_subset_masks(self, spectrum, subset_component):

spectrum.mask = subset_mask
return spectrum

def _on_viewer_renamed(self, msg):
if msg.old_viewer_ref in self.dataset._viewers:
self.dataset._viewers = [self._default_spectrum_viewer_reference_name]
5 changes: 4 additions & 1 deletion jdaviz/configs/specviz/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,10 @@ def get_data(self, data_label=None, spectral_subset=None, cls=None,
if self.app.config == 'cubeviz':
# then this is a specviz instance inside cubeviz and we want to default to the
# viewer's collapse function
default_sp_viewer = self.app.get_viewer(self._default_spectrum_viewer_reference_name)
default_sp_viewer = self.app.get_viewer(
# use the viewer reference name of the Cubeviz spectrum viewer:
self.app._jdaviz_helper._default_spectrum_viewer_reference_name
)
if function is True or function is None:
function = getattr(default_sp_viewer.state, 'function', None)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ def _on_viewer_data_changed(self, msg):
self.disabled_msg = 'Line Analysis unavailable without spectral data'
return

if viewer_data.spectral_axis.unit == u.pix:
if not hasattr(viewer_data, 'spectral_axis'):
self.disabled_msg = 'Line Analysis plugin unavailable when viewing ramps'
elif viewer_data.spectral_axis.unit == u.pix:
# disable the plugin until we can address this properly (either using the wavelength
# solution to support pixels in line-lists, or properly displaying the extracted
# 1d spectrum in wavelength-space)
Expand Down
46 changes: 34 additions & 12 deletions jdaviz/configs/specviz/plugins/viewers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from jdaviz.core.linelists import load_preset_linelist, get_available_linelists
from jdaviz.core.freezable_state import FreezableProfileViewerState
from jdaviz.configs.default.plugins.viewers import JdavizViewerMixin
from jdaviz.configs.cubeviz.plugins.parsers import cubeviz_ramp_meta_flag
from jdaviz.utils import get_subset_type

__all__ = ['SpecvizProfileView']
Expand Down Expand Up @@ -373,8 +374,16 @@ def add_data(self, data, color=None, alpha=None, **layer_state):
result = super().add_data(data, color, alpha, **layer_state)

if reset_plot_axes:
x_units = data.get_component(self.state.x_att.label).units
y_units = data.get_component("flux").units
component_labels = [comp.label for comp in data.components]
if 'flux' in component_labels:
# interpret this as a spectrum:
x_units = data.get_component(self.state.x_att.label).units
y_units = data.get_component('flux').units
elif 'data' in component_labels:
# interpret this as a ramp:
x_units = ''
y_units = data.get_component('data').units

self.state.x_display_unit = x_units if len(x_units) else None
self.state.y_display_unit = y_units if len(y_units) else None
self.set_plot_axes()
Expand Down Expand Up @@ -529,19 +538,32 @@ def _plot_uncertainties(self):

def set_plot_axes(self):
# Set axes labels for the spectrum viewer
flux_unit_type = "Flux density"

x_disp_unit = self.state.x_display_unit
x_unit = u.Unit(x_disp_unit) if x_disp_unit else u.dimensionless_unscaled
if x_unit.is_equivalent(u.m):
spectral_axis_unit_type = "Wavelength"
elif x_unit.is_equivalent(u.Hz):
spectral_axis_unit_type = "Frequency"
elif x_unit.is_equivalent(u.pixel):
spectral_axis_unit_type = "Pixel"

is_ramp = (
self.state.reference_data and
self.state.reference_data.meta.get(cubeviz_ramp_meta_flag, False)
)

if not is_ramp:
flux_unit_type = "Flux density"
x_unit = u.Unit(x_disp_unit) if x_disp_unit else u.dimensionless_unscaled
if x_unit.is_equivalent(u.m):
spectral_axis_unit_type = "Wavelength"
elif x_unit.is_equivalent(u.Hz):
spectral_axis_unit_type = "Frequency"
elif x_unit.is_equivalent(u.pixel):
spectral_axis_unit_type = "Pixel"
else:
spectral_axis_unit_type = str(x_unit.physical_type).title()
x_disp_unit = f'[{self.state.x_display_unit}]'
else:
spectral_axis_unit_type = str(x_unit.physical_type).title()
flux_unit_type = "Counts"
spectral_axis_unit_type = "Sample"
x_disp_unit = ''

self.figure.axes[0].label = f"{spectral_axis_unit_type} [{self.state.x_display_unit}]"
self.figure.axes[0].label = f"{spectral_axis_unit_type} {x_disp_unit}"
self.figure.axes[1].label = f"{flux_unit_type} [{self.state.y_display_unit}]"

# Make it so y axis label is not covering tick numbers.
Expand Down
6 changes: 4 additions & 2 deletions jdaviz/core/template_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1812,12 +1812,14 @@ def selected_obj(self):

def selected_spectrum_for_spatial_subset(self,
spatial_subset=SPATIAL_DEFAULT_TEXT,
use_display_units=True):
use_display_units=True,
cls=None):
if spatial_subset == SPATIAL_DEFAULT_TEXT:
spatial_subset = None
return self.plugin._specviz_helper.get_data(data_label=self.selected,
spatial_subset=spatial_subset,
use_display_units=use_display_units)
use_display_units=use_display_units,
cls=cls)

def _is_valid_item(self, data):
def not_from_plugin(data):
Expand Down
3 changes: 2 additions & 1 deletion jdaviz/core/user_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def __setattr__(self, attr, value):
return
elif isinstance(exp_obj, PlotOptionsSyncState):
if not len(exp_obj.linked_states):
raise ValueError("there are currently no synced glue states to set")
raise ValueError("There are currently no synced glue states to set. "
"Check the selected viewer and/or layer.")

# this allows setting the value immediately, and unmixing state, if appropriate,
# even if the value matches the current value
Expand Down

0 comments on commit 7352138

Please sign in to comment.