Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JWST L1 ramp parser for Rampviz #3148

Merged
merged 3 commits into from
Aug 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ New Features

- The standalone version of jdaviz now uses solara instead of voila, resulting in faster load times. [#2909]

- New configuration for ramp/Level 1 data products from Roman WFI and JWST [#3120, #3148]

Cubeviz
^^^^^^^

Expand Down
156 changes: 82 additions & 74 deletions jdaviz/configs/rampviz/plugins/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
import astropy.units as u
from astropy.io import fits
from astropy.nddata import NDData, NDDataArray
from astropy.time import Time
from stdatamodels.jwst.datamodels import Level1bModel

from jdaviz.core.registries import data_parser_registry
from jdaviz.configs.cubeviz.plugins.parsers import _get_data_type_by_hdu
from jdaviz.utils import (
standardize_metadata, download_uri_to_path,
PRIHDR_KEY, standardize_roman_metadata
Expand All @@ -25,7 +24,8 @@

@data_parser_registry("ramp-data-parser")
def parse_data(app, file_obj, data_type=None, data_label=None,
parent=None, cache=None, local_path=None, timeout=None):
parent=None, cache=None, local_path=None, timeout=None,
integration=0):
"""
Attempts to parse a data file and auto-populate available viewers in
rampviz.
Expand Down Expand Up @@ -53,6 +53,10 @@
remote requests in seconds (passed to
`~astropy.utils.data.download_file` or
`~astroquery.mast.Conf.timeout`).
integration : int, optional
JWST Level 1b products bundle multiple integrations in a time-series into the
same ramp file. If this keyword is specified and the observations
are JWST Level 1b products, this integration in the time series will be selected.
"""

group_viewer_reference_name = app._jdaviz_helper._default_group_viewer_reference_name
Expand Down Expand Up @@ -101,6 +105,7 @@
with fits.open(file_obj) as hdulist:
_parse_hdulist(
app, hdulist, file_name=data_label or file_name,
integration=integration,
group_viewer_reference_name=group_viewer_reference_name,
diff_viewer_reference_name=diff_viewer_reference_name,
)
Expand All @@ -121,6 +126,20 @@
meta=getattr(file_obj, 'meta')
)

elif isinstance(file_obj, Level1bModel):
metadata = standardize_metadata({
key: value for key, value in file_obj.to_flat_dict().items()
if key.startswith('meta')
})

_parse_ramp_cube(
app, file_obj.data[integration], u.DN,
data_label or file_obj.__class__.__name__,
group_viewer_reference_name,
diff_viewer_reference_name,
meta=metadata
)

elif HAS_ROMAN_DATAMODELS and isinstance(file_obj, rdd.DataModel):
with rdd.open(file_obj) as pf:
_roman_3d_to_glue_data(
Expand All @@ -133,6 +152,13 @@
raise NotImplementedError(f'Unsupported data format: {file_obj}')


def _swap_axes(x):
# swap axes per the conventions of ramp cubes
# (group axis comes first) and the default in
# rampviz (group axis expected last)
return np.swapaxes(x, 0, -1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you ever plan to support beyond Roman vs JWST? Maybe instead of swapping, should hardcode the index by telescope name?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not in the plans for now. If they become part of the plan, we'd need to build a parser for those "non-standard" cubes anyway, so I'll keep it like this for now.



def _roman_3d_to_glue_data(
app, file_obj, data_label,
group_viewer_reference_name=None,
Expand All @@ -143,12 +169,6 @@
Parse a Roman 3D ramp cube file (Level 1),
usually with suffix '_uncal.asdf'.
"""
def _swap_axes(x):
# swap axes per the conventions of Roman cubes
# (group axis comes first) and the default in
# Cubeviz (wavelength axis expected last)
return np.swapaxes(x, 0, -1)

# update viewer reference names for Roman ramp cubes:
# app._update_viewer_reference_name()

Expand Down Expand Up @@ -199,88 +219,76 @@

def _parse_hdulist(
app, hdulist, file_name=None,
viewer_reference_name=None
integration=None,
group_viewer_reference_name=None,
diff_viewer_reference_name=None,
):
if file_name is None and hasattr(hdulist, 'file_name'):
file_name = hdulist.file_name
else:
file_name = file_name or "Unknown HDU object"

is_loaded = []

# TODO: This needs refactoring to be more robust.
# Current logic fails if there are multiple EXTVER.
for hdu in hdulist:
if hdu.data is None or not hdu.is_image or hdu.data.ndim != 3:
continue

data_type = _get_data_type_by_hdu(hdu)
if not data_type:
continue

# Only load each type once.
if data_type in is_loaded:
continue

is_loaded.append(data_type)
data_label = app.return_data_label(file_name, hdu.name)
hdu = hdulist[1] # extension containing the ramp
if hdu.header['NAXIS'] != 4:
raise ValueError(f"Expected a ramp with NAXIS=4 (with axes:"

Check warning on line 233 in jdaviz/configs/rampviz/plugins/parsers.py

View check run for this annotation

Codecov / codecov/patch

jdaviz/configs/rampviz/plugins/parsers.py#L231-L233

Added lines #L231 - L233 were not covered by tests
f"integrations, groups, x, y), but got "
f"NAXIS={hdu.header['NAXIS']}.")

if 'BUNIT' in hdu.header:
try:
flux_unit = u.Unit(hdu.header['BUNIT'])
except Exception:
logging.warning("Invalid BUNIT, using DN as data unit")
flux_unit = u.DN
else:
if 'BUNIT' in hdu.header:
try:
flux_unit = u.Unit(hdu.header['BUNIT'])
except Exception:

Check warning on line 240 in jdaviz/configs/rampviz/plugins/parsers.py

View check run for this annotation

Codecov / codecov/patch

jdaviz/configs/rampviz/plugins/parsers.py#L237-L240

Added lines #L237 - L240 were not covered by tests
logging.warning("Invalid BUNIT, using DN as data unit")
flux_unit = u.DN
else:
logging.warning("Invalid BUNIT, using DN as data unit")
flux_unit = u.DN

Check warning on line 245 in jdaviz/configs/rampviz/plugins/parsers.py

View check run for this annotation

Codecov / codecov/patch

jdaviz/configs/rampviz/plugins/parsers.py#L244-L245

Added lines #L244 - L245 were not covered by tests

flux = hdu.data << flux_unit
metadata = standardize_metadata(hdu.header)
if hdu.name != 'PRIMARY' and 'PRIMARY' in hdulist:
metadata[PRIHDR_KEY] = standardize_metadata(hdulist['PRIMARY'].header)
# index the ramp array by the integration to load. returns all groups and pixels.
# cast from uint16 to integers:
ramp_cube = hdu.data[integration].astype(int)

Check warning on line 249 in jdaviz/configs/rampviz/plugins/parsers.py

View check run for this annotation

Codecov / codecov/patch

jdaviz/configs/rampviz/plugins/parsers.py#L249

Added line #L249 was not covered by tests

app.add_data(flux, data_label)
app.data_collection[data_label].get_component("data").units = flux_unit
app.add_data_to_viewer(viewer_reference_name, data_label)
app._jdaviz_helper._loaded_flux_cube = app.data_collection[data_label]
metadata = standardize_metadata(hdu.header)
if hdu.name != 'PRIMARY' and 'PRIMARY' in hdulist:
metadata[PRIHDR_KEY] = standardize_metadata(hdulist['PRIMARY'].header)

Check warning on line 253 in jdaviz/configs/rampviz/plugins/parsers.py

View check run for this annotation

Codecov / codecov/patch

jdaviz/configs/rampviz/plugins/parsers.py#L251-L253

Added lines #L251 - L253 were not covered by tests

_parse_ramp_cube(

Check warning on line 255 in jdaviz/configs/rampviz/plugins/parsers.py

View check run for this annotation

Codecov / codecov/patch

jdaviz/configs/rampviz/plugins/parsers.py#L255

Added line #L255 was not covered by tests
app, ramp_cube, flux_unit, file_name,
group_viewer_reference_name,
diff_viewer_reference_name,
meta=metadata
)


def _parse_jwst_level1(
app, hdulist, data_label, ext='SCI',
viewer_name=None,
):
hdu = hdulist[ext]
data_type = _get_data_type_by_hdu(hdu)

# Manually inject MJD-OBS until we can support GWCS, see
# https://github.com/spacetelescope/jdaviz/issues/690 and
# https://github.com/glue-viz/glue-astronomy/issues/59
if ext == 'SCI' and 'MJD-OBS' not in hdu.header:
for key in ('MJD-BEG', 'DATE-OBS'): # Possible alternatives
if key in hdu.header:
if key.startswith('MJD'):
hdu.header['MJD-OBS'] = hdu.header[key]
break
else:
t = Time(hdu.header[key])
hdu.header['MJD-OBS'] = t.mjd
break

unit = u.Unit(hdu.header.get('BUNIT', 'count'))
flux = hdu.data << unit
def _parse_ramp_cube(app, ramp_cube_data, flux_unit, file_name,
group_viewer_reference_name, diff_viewer_reference_name,
meta=None):
# last axis is the group axis, first two are spatial axes:
diff_data = np.vstack([
# begin with a group of zeros, so
# that `diff_data.ndim == data.ndim`
np.zeros((1, *ramp_cube_data[0].shape)),
np.diff(ramp_cube_data, axis=0)
])

metadata = standardize_metadata(hdu.header)
app.data_collection[data_label] = NDData(data=flux, meta=metadata)
ramp_cube = NDDataArray(_swap_axes(ramp_cube_data), unit=flux_unit, meta=meta)
diff_cube = NDDataArray(_swap_axes(diff_data), unit=flux_unit, meta=meta)

group_data_label = app.return_data_label(file_name, ext="DATA")
diff_data_label = app.return_data_label(file_name, ext="DIFF")

if data_type == 'flux':
app.data_collection[-1].get_component("data").units = flux.unit
for data_entry, data_label, viewer_ref in zip(
(ramp_cube, diff_cube),
(group_data_label, diff_data_label),
(group_viewer_reference_name, diff_viewer_reference_name)
):
app.add_data(data_entry, data_label)
app.add_data_to_viewer(viewer_ref, data_label)

if viewer_name is not None:
app.add_data_to_viewer(viewer_name, data_label)
# load these cubes into the cache:
app._jdaviz_helper.cube_cache[data_label] = data_entry

if data_type == 'flux':
app._jdaviz_helper._loaded_flux_cube = app.data_collection[data_label]
app._jdaviz_helper._loaded_flux_cube = app.data_collection[group_data_label]


def _parse_ndarray(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ def _on_subset_update(self, msg={}):
return

# glue region has transposed coords relative to cached cube:
region_mask = region.to_mask().to_image(self.cube.shape[:-1]).astype(bool).T
region_mask = region.to_mask().to_image(
self.cube.shape[:-1][::-1]
).astype(bool).T
cube_subset = self.cube[region_mask] # shape: (N pixels extracted, M groups)

n_pixels_in_extraction = cube_subset.shape[0]
Expand Down Expand Up @@ -292,7 +294,7 @@ def _update_aperture_method_on_function_change(self, *args):

@property
def cube(self):
return self.app._jdaviz_helper.cube_cache[self.dataset.selected]
return self.app._jdaviz_helper.cube_cache.get(self.dataset.selected)

@property
def slice_display_unit(self):
Expand Down
18 changes: 17 additions & 1 deletion jdaviz/configs/rampviz/tests/test_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@


@pytest.mark.skipif(not HAS_ROMAN_DATAMODELS, reason="roman_datamodels is not installed")
def test_load_data(rampviz_helper, roman_level_1_ramp):
def test_load_data_roman(rampviz_helper, roman_level_1_ramp):
rampviz_helper.load_data(roman_level_1_ramp)

# on ramp cube load (1), the parser loads a diff cube (2) and
Expand All @@ -17,3 +17,19 @@ def test_load_data(rampviz_helper, roman_level_1_ramp):

assert viewer.axis_x.label == 'Group'
assert viewer.axis_y.label == 'DN'


def test_load_data_jwst(rampviz_helper, jwst_level_1b_ramp):
rampviz_helper.load_data(jwst_level_1b_ramp)

# on ramp cube load (1), the parser loads a diff cube (2) and
# the ramp extraction plugin produces a default extraction (3):
assert len(rampviz_helper.app.data_collection) == 3

# each viewer should have one loaded data entry:
for refname in 'group-viewer, diff-viewer, integration-viewer'.split(', '):
viewer = rampviz_helper.app.get_viewer(refname)
assert len(viewer.state.layers) == 1

assert viewer.axis_x.label == 'Group'
assert viewer.axis_y.label == 'DN'
22 changes: 15 additions & 7 deletions jdaviz/configs/rampviz/tests/test_ramp_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,30 @@


@pytest.mark.skipif(not HAS_ROMAN_DATAMODELS, reason="roman_datamodels is not installed")
def test_previews(rampviz_helper, roman_level_1_ramp):
rampviz_helper.load_data(roman_level_1_ramp)
def test_previews_roman(rampviz_helper, roman_level_1_ramp):
_ramp_extraction_previews(rampviz_helper, roman_level_1_ramp)


def test_previews_jwst(rampviz_helper, jwst_level_1b_ramp):
_ramp_extraction_previews(rampviz_helper, jwst_level_1b_ramp)


def _ramp_extraction_previews(_rampviz_helper, _ramp_file):
_rampviz_helper.load_data(_ramp_file)

# add subset:
region = CirclePixelRegion(center=PixCoord(12.5, 15.5), radius=2)
rampviz_helper.load_regions(region)
ramp_extr = rampviz_helper.plugins['Ramp Extraction']._obj
_rampviz_helper.load_regions(region)
ramp_extr = _rampviz_helper.plugins['Ramp Extraction']._obj

subsets = rampviz_helper.app.get_subsets()
ramp_cube = rampviz_helper.app.data_collection[0]
subsets = _rampviz_helper.app.get_subsets()
ramp_cube = _rampviz_helper.app.data_collection[0]
n_groups = ramp_cube.shape[-1]

assert len(subsets) == 1
assert 'Subset 1' in subsets

integration_viewer = rampviz_helper.app.get_viewer('integration-viewer')
integration_viewer = _rampviz_helper.app.get_viewer('integration-viewer')

# contains a layer for the default ramp extraction and the subset:
assert len(integration_viewer.layers) == 2
Expand Down
Loading