Skip to content

Commit

Permalink
WIP: basic uncertainty r/w with wcs1d-fits
Browse files Browse the repository at this point in the history
  • Loading branch information
dhomeier committed Mar 3, 2023
1 parent 232af68 commit 5f79132
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 44 deletions.
155 changes: 120 additions & 35 deletions specutils/io/default_loaders/wcs_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from astropy.io import fits
from astropy.wcs import WCS
from astropy.modeling import models
from astropy.nddata import StdDevUncertainty, InverseVariance, VarianceUncertainty
from astropy.utils.exceptions import AstropyUserWarning

import numpy as np
Expand All @@ -16,6 +17,14 @@

__all__ = ['wcs1d_fits_loader', 'non_linear_wcs1d_fits', 'non_linear_multispec_fits']

UNCERT_REF = {'STD': StdDevUncertainty,
'ERR': StdDevUncertainty,
'UNCERT': StdDevUncertainty,
'VAR': VarianceUncertainty,
'IVAR': InverseVariance}

UNCERT_EXP = {'STD': 1, 'VAR': 2, 'IVAR': -2}


def identify_wcs1d_fits(origin, *args, **kwargs):
"""
Expand All @@ -27,9 +36,8 @@ def identify_wcs1d_fits(origin, *args, **kwargs):
# Default FITS format is BINTABLE in 1st extension HDU, unless IMAGE is
# indicated via naming pattern or (explicitly) selecting primary HDU.
if origin == 'write':
return ((args[0].endswith(('wcs.fits', 'wcs1d.fits', 'wcs.fit')) or
(args[0].endswith(('.fits', '.fit')) and whdu == 0)) and not
hasattr(args[2], 'uncertainty'))
return (args[0].endswith(('wcs.fits', 'wcs1d.fits', 'wcs.fit')) or
(args[0].endswith(('.fits', '.fit')) and whdu == 0))

hdu = kwargs.get('hdu', 0)
# Check if number of axes is one and dimension of WCS is one
Expand All @@ -45,8 +53,8 @@ def identify_wcs1d_fits(origin, *args, **kwargs):

@data_loader("wcs1d-fits", identifier=identify_wcs1d_fits,
dtype=Spectrum1D, extensions=['fits', 'fit'], priority=5)
def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None,
hdu=None, verbose=False, **kwargs):
def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None, hdu=None,
verbose=False, mask_hdu=True, uncertainty_hdu=True, **kwargs):
"""
Loader for single spectrum-per-HDU spectra in FITS files, with the spectral
axis stored in the header as FITS-WCS. The flux unit of the spectrum is
Expand All @@ -68,10 +76,17 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None,
Units of the flux for this spectrum. If not given (or None), the unit
will be inferred from the BUNIT keyword in the header. Note that this
unit will attempt to convert from BUNIT if BUNIT is present.
hdu : int, str or None. optional
The index or name of the HDU to load into this spectrum (default: search 1st).
hdu : int, str or None, optional
The index or name of the HDU to load into this spectrum
(default: find 1st applicable `ImageHDU`).
verbose : bool. optional
Print extra info.
mask_hdu : int, str, bool or None, optional
The index or name of the HDU to read mask from
(default: try to autodetect; `False`|`None`: do not read in).
uncertainy_hdu : int, str, bool or None, optional
The index or name of the HDU to read uncertainy from
(default: try to autodetect; `False`|`None`: do not read in).
**kwargs
Extra keywords for :func:`~specutils.io.parsing_utils.read_fileobj_or_hdulist`.
Expand All @@ -88,15 +103,22 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None,

with read_fileobj_or_hdulist(file_obj, **kwargs) as hdulist:
if hdu is None:
for ext in ('FLUX', 'SCI', 'PRIMARY'):
for ext in ('FLUX', 'SCI', 'DATA', 'PRIMARY'):
# For now rely on extension containing spectral data.
if ext in hdulist and (
isinstance(hdulist[ext], (fits.ImageHDU, fits.PrimaryHDU)) and
hdulist[ext].data is not None):
hdu = ext
break
else:
raise ValueError('No HDU with spectral data found.')
if hdu is None or hdulist[ext] == hdulist[hdu]:
hdu = ext
continue
else:
warnings.warn(f"Found multiple data HDUs '{hdu}' and '{ext}', "
f"will read '{hdu}'! Please use `hdu=<flux_hdu>` "
"to select a specific one.", AstropyUserWarning)
break

if hdu is None:
raise ValueError('No HDU with spectral data found.')

header = hdulist[hdu].header
wcs = WCS(header)
Expand All @@ -108,11 +130,48 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None,
else:
data = u.Quantity(hdulist[hdu].data, unit=flux_unit)

# TODO: add additional HDU data like 'IVAR'
if 'MASK' in hdulist:
mask = hdulist['MASK'].data
else:
mask = None
# Read mask from specified HDU or try to auto-detect it.
mask = None
if mask_hdu is True:
for ext in ('MASK', 'DQ', 'QUALITY'):
if ext in hdulist and (isinstance(hdulist[ext], fits.ImageHDU) and
hdulist[ext].data is not None):
mask = hdulist[ext].data
break
elif isinstance(mask_hdu, (int, str)):
if mask_hdu in hdulist:
mask = hdulist[mask_hdu].data
else:
warnings.warn(f"No HDU '{mask_hdu}' for mask found in file.",
AstropyUserWarning)

uncertainty = None
if uncertainty_hdu is True:
for ext in UNCERT_REF:
if ext in hdulist and (isinstance(hdulist[ext], fits.ImageHDU) and
hdulist[ext].data is not None):
uncertainty = hdulist[ext].data
break
elif isinstance(uncertainty_hdu, (int, str)):
if uncertainty_hdu in hdulist:
ext = uncertainty_hdu
uncertainty = hdulist[ext].data
else:
warnings.warn(f"No HDU '{uncertainty_hdu}' for uncertainty found in file.",
AstropyUserWarning)

if uncertainty is not None:
if hdulist[ext].name in UNCERT_REF:
unc_type = hdulist[ext].name
else:
warnings.warn(f"Could not determine uncertainty type for HDU '{ext}' "
f"('{hdulist[ext].name}'), assuming 'StdDev'.",
AstropyUserWarning)
unc_type = 'STD'
uunit = u.Unit(header.get('BUNIT', flux_unit))
if unc_type != 'STD':
uunit = uunit**UNCERT_EXP[unc_type]
uncertainty = UNCERT_REF[unc_type](u.Quantity(uncertainty, unit=uunit))

if spectral_axis_unit is not None:
wcs.wcs.cunit[0] = str(spectral_axis_unit)
Expand All @@ -134,14 +193,16 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None,

meta = {'header': header}

# Is this restriction still appropriate?
if wcs.naxis > 4:
raise ValueError('FITS file input to wcs1d_fits_loader is > 4D')

return Spectrum1D(flux=data, wcs=wcs, mask=mask, meta=meta)
return Spectrum1D(flux=data, wcs=wcs, mask=mask, uncertainty=uncertainty, meta=meta)


@custom_writer("wcs1d-fits")
def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False, flux_name='FLUX', **kwargs):
def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False,
flux_name='FLUX', mask_name='', uncertainty_name='', **kwargs):
"""
Write spectrum with spectral axis defined by its WCS to (primary)
IMAGE_HDU of a FITS file.
Expand All @@ -157,6 +218,10 @@ def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False, flux_name
Update FITS header with all compatible entries in `spectrum.meta`
flux_name : str, optional
HDU name to store flux spectrum under (default 'FLUX')
mask_name : str or `None`, optional
HDU name to store mask under (default 'MASK'; `None`: do not save)
uncertainty_name : str or `None`, optional
HDU name to store uncertainty under (default set from type; `None`: do not save)
unit : str or :class:`~astropy.units.Unit`, optional
Unit for the flux (and associated uncertainty; defaults to `spectrum.flux.unit`)
dtype : str or :class:`~numpy.dtype`, optional
Expand All @@ -171,16 +236,16 @@ def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False, flux_name
raise ValueError(f'Only Spectrum1D objects with valid WCS can be written as wcs1d: {err}')

# Verify spectral axis constructed from WCS
wl = spectrum.spectral_axis
disp = spectrum.spectral_axis
# Not sure why the extra check is necessary for FITS WCS
if hasattr(wcs, 'celestial') and wcs.celestial.naxis > 0:
dwl = (wcs.spectral.all_pix2world(np.arange(len(wl)), 0) - wl.value) / wl.value
ddisp = (wcs.spectral.all_pix2world(np.arange(len(disp)), 0) - disp.value) / disp.value
else:
dwl = (wcs.all_pix2world(np.arange(len(wl)), 0) - wl.value) / wl.value
if np.abs(dwl).max() > 1.e-10:
m = np.abs(dwl).argmax()
ddisp = (wcs.all_pix2world(np.arange(len(disp)), 0) - disp.value) / disp.value
if np.abs(ddisp).max() > 1.e-10:
m = np.abs(ddisp).argmax()
raise ValueError('Relative difference between WCS spectral axis and'
f'spectral_axis at {m:}: {dwl[m]}')
f'spectral_axis at {m:}: {ddisp[m]}')

if update_header:
hdr_types = (str, int, float, complex, bool,
Expand All @@ -189,26 +254,46 @@ def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False, flux_name
(isinstance(keyword[1], hdr_types) and
keyword[0] not in ('NAXIS', 'NAXIS1', 'NAXIS2'))])

# Cannot include uncertainty in IMAGE_HDU - maybe provide option to
# separately write this to BINARY_TBL extension later.
if spectrum.uncertainty is not None:
warnings.warn("Saving uncertainties in wcs1d format is not yet supported!",
AstropyUserWarning)

# Add flux array and unit
ftype = kwargs.pop('dtype', spectrum.flux.dtype)
funit = u.Unit(kwargs.pop('unit', spectrum.flux.unit))
flux = spectrum.flux.to(funit, equivalencies=u.spectral_density(wl))
flux = spectrum.flux.to(funit, equivalencies=u.spectral_density(disp))
hdulist[0].data = flux.value.astype(ftype)
if flux_name is not None:
hdulist[0].name = flux_name

# Append mask array (duplicate WCS for that extension)?
if spectrum.mask is not None:
if spectrum.mask is not None and mask_name is not None:
hdulist.append(wcs.to_fits()[0])
hdulist[1].data = spectrum.mask
hdulist[1].name = 'MASK'
hdulist[-1].data = spectrum.mask
if mask_name == '':
hdulist[-1].name = 'MASK'
else:
hdulist[-1].name = mask_name
hdu += 1
# Warn if extension name for saving was explicitly chosen.
elif mask_name is not None and mask_name != '':
warnings.warn("No mask found in this Spectrum1D, none saved.", AstropyUserWarning)

# Append uncertainty array (duplicate WCS for that extension)?
if spectrum.uncertainty is not None and uncertainty_name is not None:
hdulist.append(wcs.to_fits()[0])
# uncertainty - units to be inferred from spectrum.flux
if uncertainty_name == '':
uncertainty_name = [n for n in UNCERT_REF if
isinstance(spectrum.uncertainty, UNCERT_REF[n])][0]
if uncertainty_name in ('VAR', 'IVAR'):
uunit = funit**UNCERT_EXP[uncertainty_name]
else:
uunit = funit
sig = spectrum.uncertainty.quantity.to_value(uunit, equivalencies=u.spectral_density(disp))
hdulist[-1].data = sig.astype(ftype)
hdulist[-1].name = uncertainty_name
hdu += 1
# Warn if extension name for saving was explicitly chosen.
elif uncertainty_name is not None and uncertainty_name != '':
warnings.warn("No uncertainty array found in this Spectrum1D, none saved.",
AstropyUserWarning)

if hasattr(funit, 'long_names') and len(funit.long_names) > 0:
comment = f'[{funit.long_names[0]}] {funit.physical_type}'
Expand Down
54 changes: 45 additions & 9 deletions specutils/tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@
from astropy.io.fits.verify import VerifyWarning
from astropy.table import Table
from astropy.units import UnitsWarning
from astropy.wcs import FITSFixedWarning, InconsistentAxisTypesError, WCS
from astropy.wcs import FITSFixedWarning, WCS
from astropy.io.registry import IORegistryError
from astropy.modeling import models
from astropy.tests.helper import quantity_allclose
from astropy.nddata import StdDevUncertainty
from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance
from astropy.utils.exceptions import AstropyUserWarning

from numpy.testing import assert_allclose

Expand Down Expand Up @@ -709,7 +710,9 @@ def test_tabular_fits_compressed(compress, tmpdir):
@pytest.mark.parametrize("spectral_axis",
['WAVE', 'FREQ', 'ENER', 'WAVN'])
@pytest.mark.parametrize("with_mask", [False, True])
def test_wcs1d_fits_writer(tmpdir, spectral_axis, with_mask):
@pytest.mark.parametrize("uncertainty",
[None, StdDevUncertainty, VarianceUncertainty, InverseVariance])
def test_wcs1d_fits_writer(tmpdir, spectral_axis, with_mask, uncertainty):
"""Test write/read for Spectrum1D with WCS-constructed spectral_axis."""
wlunits = {'WAVE': 'Angstrom', 'FREQ': 'GHz', 'ENER': 'eV', 'WAVN': 'cm**-1'}
# Header dictionary for constructing WCS
Expand All @@ -721,13 +724,20 @@ def test_wcs1d_fits_writer(tmpdir, spectral_axis, with_mask):
wl0 = hdr['CRVAL1']
dwl = hdr['CDELT1']
disp = np.arange(wl0, wl0 + (len(flux) - 0.5) * dwl, dwl) * wlu
tmpfile = str(tmpdir.join('_tst.fits'))

if with_mask:
mask = np.array([0, 0, 0, 0, 1, 0, 0, 0, 1, 0], dtype=np.uint16)
else:
mask = None

spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask)
tmpfile = str(tmpdir.join('_tst.fits'))
# ToDo: test with explicit (and different from flux) units.
if uncertainty is None:
spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask)
assert spectrum.uncertainty is None
else:
unc = uncertainty(0.1 * np.sqrt(np.abs(flux.value)))
spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask, uncertainty=unc)
spectrum.write(tmpfile, hdu=0)

# Read it in and check against the original
Expand All @@ -738,6 +748,10 @@ def test_wcs1d_fits_writer(tmpdir, spectral_axis, with_mask):
assert quantity_allclose(spec.spectral_axis, disp)
assert quantity_allclose(spec.flux, spectrum.flux)
assert np.all(spec.mask == spectrum.mask)
if uncertainty is None:
assert spec.uncertainty is None
else:
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)

# Read from HDUList
hdulist = fits.open(tmpfile)
Expand All @@ -746,12 +760,18 @@ def test_wcs1d_fits_writer(tmpdir, spectral_axis, with_mask):
assert quantity_allclose(spec.spectral_axis, spectrum.spectral_axis)
assert quantity_allclose(spec.flux, spectrum.flux)
assert np.all(spec.mask == spectrum.mask)
if uncertainty is None:
assert spec.uncertainty is None
else:
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)


@pytest.mark.parametrize("spectral_axis",
['WAVE', 'FREQ', 'ENER', 'WAVN'])
@pytest.mark.parametrize("with_mask", [False, True])
def test_wcs1d_fits_cube(tmpdir, spectral_axis, with_mask):
@pytest.mark.parametrize("uncertainty",
[None, StdDevUncertainty, VarianceUncertainty, InverseVariance])
def test_wcs1d_fits_cube(tmpdir, spectral_axis, with_mask, uncertainty):
"""Test write/read for Spectrum1D spectral cube with WCS spectral_axis."""
wlunits = {'WAVE': 'Angstrom', 'FREQ': 'GHz', 'ENER': 'eV', 'WAVN': 'cm**-1'}
# Header dictionary for constructing WCS
Expand All @@ -770,15 +790,23 @@ def test_wcs1d_fits_cube(tmpdir, spectral_axis, with_mask):
wl0 = hdr['CRVAL1']
dwl = hdr['CDELT1']
disp = np.arange(wl0, wl0 + (flux.shape[2] - 0.5) * dwl, dwl) * wlu
tmpfile = str(tmpdir.join('_tst.fits'))

if with_mask:
die = np.random.Generator(np.random.MT19937(23))
mask = die.choice([0, 0, 0, 0, 0, 1], size=flux.shape).astype(np.uint16)
else:
mask = None

spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask)
tmpfile = str(tmpdir.join('_tst.fits'))
spectrum.write(tmpfile, hdu=0)
if uncertainty is None:
spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask)
assert spectrum.uncertainty is None
with pytest.warns(AstropyUserWarning, match='No uncertainty array found'):
spectrum.write(tmpfile, hdu=0, uncertainty_name='STD')
else:
unc = uncertainty(0.1 * np.sqrt(np.abs(flux.value)))
spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask, uncertainty=unc)
spectrum.write(tmpfile, hdu=0)

# Broken reader!
# Read it in and check against the original
Expand All @@ -790,6 +818,10 @@ def test_wcs1d_fits_cube(tmpdir, spectral_axis, with_mask):
assert quantity_allclose(spec.spectral_axis, disp)
assert quantity_allclose(spec.flux, spectrum.flux)
assert np.all(spec.mask == spectrum.mask)
if uncertainty is None:
assert spec.uncertainty is None
else:
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)

# Read from HDUList
with fits.open(tmpfile) as hdulist:
Expand All @@ -803,6 +835,10 @@ def test_wcs1d_fits_cube(tmpdir, spectral_axis, with_mask):
assert quantity_allclose(spec.spectral_axis, spectrum.spectral_axis)
assert quantity_allclose(spec.flux, spectrum.flux)
assert np.all(spec.mask == spectrum.mask)
if uncertainty is None:
assert spec.uncertainty is None
else:
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)


@pytest.mark.filterwarnings('ignore:Card is too long')
Expand Down

0 comments on commit 5f79132

Please sign in to comment.