From 5f79132f0f7ced4b7a023f9acf7db03c7ce9b9b2 Mon Sep 17 00:00:00 2001 From: Derek Homeier Date: Sun, 19 Feb 2023 00:08:24 +0100 Subject: [PATCH] WIP: basic uncertainty r/w with wcs1d-fits --- specutils/io/default_loaders/wcs_fits.py | 155 ++++++++++++++++++----- specutils/tests/test_loaders.py | 54 ++++++-- 2 files changed, 165 insertions(+), 44 deletions(-) diff --git a/specutils/io/default_loaders/wcs_fits.py b/specutils/io/default_loaders/wcs_fits.py index 6dbf32397..068049de9 100644 --- a/specutils/io/default_loaders/wcs_fits.py +++ b/specutils/io/default_loaders/wcs_fits.py @@ -5,6 +5,7 @@ from astropy.io import fits from astropy.wcs import WCS from astropy.modeling import models +from astropy.nddata import StdDevUncertainty, InverseVariance, VarianceUncertainty from astropy.utils.exceptions import AstropyUserWarning import numpy as np @@ -16,6 +17,14 @@ __all__ = ['wcs1d_fits_loader', 'non_linear_wcs1d_fits', 'non_linear_multispec_fits'] +UNCERT_REF = {'STD': StdDevUncertainty, + 'ERR': StdDevUncertainty, + 'UNCERT': StdDevUncertainty, + 'VAR': VarianceUncertainty, + 'IVAR': InverseVariance} + +UNCERT_EXP = {'STD': 1, 'VAR': 2, 'IVAR': -2} + def identify_wcs1d_fits(origin, *args, **kwargs): """ @@ -27,9 +36,8 @@ def identify_wcs1d_fits(origin, *args, **kwargs): # Default FITS format is BINTABLE in 1st extension HDU, unless IMAGE is # indicated via naming pattern or (explicitly) selecting primary HDU. if origin == 'write': - return ((args[0].endswith(('wcs.fits', 'wcs1d.fits', 'wcs.fit')) or - (args[0].endswith(('.fits', '.fit')) and whdu == 0)) and not - hasattr(args[2], 'uncertainty')) + return (args[0].endswith(('wcs.fits', 'wcs1d.fits', 'wcs.fit')) or + (args[0].endswith(('.fits', '.fit')) and whdu == 0)) hdu = kwargs.get('hdu', 0) # Check if number of axes is one and dimension of WCS is one @@ -45,8 +53,8 @@ def identify_wcs1d_fits(origin, *args, **kwargs): @data_loader("wcs1d-fits", identifier=identify_wcs1d_fits, dtype=Spectrum1D, extensions=['fits', 'fit'], priority=5) -def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None, - hdu=None, verbose=False, **kwargs): +def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None, hdu=None, + verbose=False, mask_hdu=True, uncertainty_hdu=True, **kwargs): """ Loader for single spectrum-per-HDU spectra in FITS files, with the spectral axis stored in the header as FITS-WCS. The flux unit of the spectrum is @@ -68,10 +76,17 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None, Units of the flux for this spectrum. If not given (or None), the unit will be inferred from the BUNIT keyword in the header. Note that this unit will attempt to convert from BUNIT if BUNIT is present. - hdu : int, str or None. optional - The index or name of the HDU to load into this spectrum (default: search 1st). + hdu : int, str or None, optional + The index or name of the HDU to load into this spectrum + (default: find 1st applicable `ImageHDU`). verbose : bool. optional Print extra info. + mask_hdu : int, str, bool or None, optional + The index or name of the HDU to read mask from + (default: try to autodetect; `False`|`None`: do not read in). + uncertainy_hdu : int, str, bool or None, optional + The index or name of the HDU to read uncertainy from + (default: try to autodetect; `False`|`None`: do not read in). **kwargs Extra keywords for :func:`~specutils.io.parsing_utils.read_fileobj_or_hdulist`. @@ -88,15 +103,22 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None, with read_fileobj_or_hdulist(file_obj, **kwargs) as hdulist: if hdu is None: - for ext in ('FLUX', 'SCI', 'PRIMARY'): + for ext in ('FLUX', 'SCI', 'DATA', 'PRIMARY'): # For now rely on extension containing spectral data. if ext in hdulist and ( isinstance(hdulist[ext], (fits.ImageHDU, fits.PrimaryHDU)) and hdulist[ext].data is not None): - hdu = ext - break - else: - raise ValueError('No HDU with spectral data found.') + if hdu is None or hdulist[ext] == hdulist[hdu]: + hdu = ext + continue + else: + warnings.warn(f"Found multiple data HDUs '{hdu}' and '{ext}', " + f"will read '{hdu}'! Please use `hdu=` " + "to select a specific one.", AstropyUserWarning) + break + + if hdu is None: + raise ValueError('No HDU with spectral data found.') header = hdulist[hdu].header wcs = WCS(header) @@ -108,11 +130,48 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None, else: data = u.Quantity(hdulist[hdu].data, unit=flux_unit) - # TODO: add additional HDU data like 'IVAR' - if 'MASK' in hdulist: - mask = hdulist['MASK'].data - else: - mask = None + # Read mask from specified HDU or try to auto-detect it. + mask = None + if mask_hdu is True: + for ext in ('MASK', 'DQ', 'QUALITY'): + if ext in hdulist and (isinstance(hdulist[ext], fits.ImageHDU) and + hdulist[ext].data is not None): + mask = hdulist[ext].data + break + elif isinstance(mask_hdu, (int, str)): + if mask_hdu in hdulist: + mask = hdulist[mask_hdu].data + else: + warnings.warn(f"No HDU '{mask_hdu}' for mask found in file.", + AstropyUserWarning) + + uncertainty = None + if uncertainty_hdu is True: + for ext in UNCERT_REF: + if ext in hdulist and (isinstance(hdulist[ext], fits.ImageHDU) and + hdulist[ext].data is not None): + uncertainty = hdulist[ext].data + break + elif isinstance(uncertainty_hdu, (int, str)): + if uncertainty_hdu in hdulist: + ext = uncertainty_hdu + uncertainty = hdulist[ext].data + else: + warnings.warn(f"No HDU '{uncertainty_hdu}' for uncertainty found in file.", + AstropyUserWarning) + + if uncertainty is not None: + if hdulist[ext].name in UNCERT_REF: + unc_type = hdulist[ext].name + else: + warnings.warn(f"Could not determine uncertainty type for HDU '{ext}' " + f"('{hdulist[ext].name}'), assuming 'StdDev'.", + AstropyUserWarning) + unc_type = 'STD' + uunit = u.Unit(header.get('BUNIT', flux_unit)) + if unc_type != 'STD': + uunit = uunit**UNCERT_EXP[unc_type] + uncertainty = UNCERT_REF[unc_type](u.Quantity(uncertainty, unit=uunit)) if spectral_axis_unit is not None: wcs.wcs.cunit[0] = str(spectral_axis_unit) @@ -134,14 +193,16 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None, meta = {'header': header} + # Is this restriction still appropriate? if wcs.naxis > 4: raise ValueError('FITS file input to wcs1d_fits_loader is > 4D') - return Spectrum1D(flux=data, wcs=wcs, mask=mask, meta=meta) + return Spectrum1D(flux=data, wcs=wcs, mask=mask, uncertainty=uncertainty, meta=meta) @custom_writer("wcs1d-fits") -def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False, flux_name='FLUX', **kwargs): +def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False, + flux_name='FLUX', mask_name='', uncertainty_name='', **kwargs): """ Write spectrum with spectral axis defined by its WCS to (primary) IMAGE_HDU of a FITS file. @@ -157,6 +218,10 @@ def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False, flux_name Update FITS header with all compatible entries in `spectrum.meta` flux_name : str, optional HDU name to store flux spectrum under (default 'FLUX') + mask_name : str or `None`, optional + HDU name to store mask under (default 'MASK'; `None`: do not save) + uncertainty_name : str or `None`, optional + HDU name to store uncertainty under (default set from type; `None`: do not save) unit : str or :class:`~astropy.units.Unit`, optional Unit for the flux (and associated uncertainty; defaults to `spectrum.flux.unit`) dtype : str or :class:`~numpy.dtype`, optional @@ -171,16 +236,16 @@ def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False, flux_name raise ValueError(f'Only Spectrum1D objects with valid WCS can be written as wcs1d: {err}') # Verify spectral axis constructed from WCS - wl = spectrum.spectral_axis + disp = spectrum.spectral_axis # Not sure why the extra check is necessary for FITS WCS if hasattr(wcs, 'celestial') and wcs.celestial.naxis > 0: - dwl = (wcs.spectral.all_pix2world(np.arange(len(wl)), 0) - wl.value) / wl.value + ddisp = (wcs.spectral.all_pix2world(np.arange(len(disp)), 0) - disp.value) / disp.value else: - dwl = (wcs.all_pix2world(np.arange(len(wl)), 0) - wl.value) / wl.value - if np.abs(dwl).max() > 1.e-10: - m = np.abs(dwl).argmax() + ddisp = (wcs.all_pix2world(np.arange(len(disp)), 0) - disp.value) / disp.value + if np.abs(ddisp).max() > 1.e-10: + m = np.abs(ddisp).argmax() raise ValueError('Relative difference between WCS spectral axis and' - f'spectral_axis at {m:}: {dwl[m]}') + f'spectral_axis at {m:}: {ddisp[m]}') if update_header: hdr_types = (str, int, float, complex, bool, @@ -189,26 +254,46 @@ def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False, flux_name (isinstance(keyword[1], hdr_types) and keyword[0] not in ('NAXIS', 'NAXIS1', 'NAXIS2'))]) - # Cannot include uncertainty in IMAGE_HDU - maybe provide option to - # separately write this to BINARY_TBL extension later. - if spectrum.uncertainty is not None: - warnings.warn("Saving uncertainties in wcs1d format is not yet supported!", - AstropyUserWarning) - # Add flux array and unit ftype = kwargs.pop('dtype', spectrum.flux.dtype) funit = u.Unit(kwargs.pop('unit', spectrum.flux.unit)) - flux = spectrum.flux.to(funit, equivalencies=u.spectral_density(wl)) + flux = spectrum.flux.to(funit, equivalencies=u.spectral_density(disp)) hdulist[0].data = flux.value.astype(ftype) if flux_name is not None: hdulist[0].name = flux_name # Append mask array (duplicate WCS for that extension)? - if spectrum.mask is not None: + if spectrum.mask is not None and mask_name is not None: hdulist.append(wcs.to_fits()[0]) - hdulist[1].data = spectrum.mask - hdulist[1].name = 'MASK' + hdulist[-1].data = spectrum.mask + if mask_name == '': + hdulist[-1].name = 'MASK' + else: + hdulist[-1].name = mask_name + hdu += 1 + # Warn if extension name for saving was explicitly chosen. + elif mask_name is not None and mask_name != '': + warnings.warn("No mask found in this Spectrum1D, none saved.", AstropyUserWarning) + + # Append uncertainty array (duplicate WCS for that extension)? + if spectrum.uncertainty is not None and uncertainty_name is not None: + hdulist.append(wcs.to_fits()[0]) + # uncertainty - units to be inferred from spectrum.flux + if uncertainty_name == '': + uncertainty_name = [n for n in UNCERT_REF if + isinstance(spectrum.uncertainty, UNCERT_REF[n])][0] + if uncertainty_name in ('VAR', 'IVAR'): + uunit = funit**UNCERT_EXP[uncertainty_name] + else: + uunit = funit + sig = spectrum.uncertainty.quantity.to_value(uunit, equivalencies=u.spectral_density(disp)) + hdulist[-1].data = sig.astype(ftype) + hdulist[-1].name = uncertainty_name hdu += 1 + # Warn if extension name for saving was explicitly chosen. + elif uncertainty_name is not None and uncertainty_name != '': + warnings.warn("No uncertainty array found in this Spectrum1D, none saved.", + AstropyUserWarning) if hasattr(funit, 'long_names') and len(funit.long_names) > 0: comment = f'[{funit.long_names[0]}] {funit.physical_type}' diff --git a/specutils/tests/test_loaders.py b/specutils/tests/test_loaders.py index bd1dacf9d..eff4db904 100644 --- a/specutils/tests/test_loaders.py +++ b/specutils/tests/test_loaders.py @@ -11,11 +11,12 @@ from astropy.io.fits.verify import VerifyWarning from astropy.table import Table from astropy.units import UnitsWarning -from astropy.wcs import FITSFixedWarning, InconsistentAxisTypesError, WCS +from astropy.wcs import FITSFixedWarning, WCS from astropy.io.registry import IORegistryError from astropy.modeling import models from astropy.tests.helper import quantity_allclose -from astropy.nddata import StdDevUncertainty +from astropy.nddata import StdDevUncertainty, VarianceUncertainty, InverseVariance +from astropy.utils.exceptions import AstropyUserWarning from numpy.testing import assert_allclose @@ -709,7 +710,9 @@ def test_tabular_fits_compressed(compress, tmpdir): @pytest.mark.parametrize("spectral_axis", ['WAVE', 'FREQ', 'ENER', 'WAVN']) @pytest.mark.parametrize("with_mask", [False, True]) -def test_wcs1d_fits_writer(tmpdir, spectral_axis, with_mask): +@pytest.mark.parametrize("uncertainty", + [None, StdDevUncertainty, VarianceUncertainty, InverseVariance]) +def test_wcs1d_fits_writer(tmpdir, spectral_axis, with_mask, uncertainty): """Test write/read for Spectrum1D with WCS-constructed spectral_axis.""" wlunits = {'WAVE': 'Angstrom', 'FREQ': 'GHz', 'ENER': 'eV', 'WAVN': 'cm**-1'} # Header dictionary for constructing WCS @@ -721,13 +724,20 @@ def test_wcs1d_fits_writer(tmpdir, spectral_axis, with_mask): wl0 = hdr['CRVAL1'] dwl = hdr['CDELT1'] disp = np.arange(wl0, wl0 + (len(flux) - 0.5) * dwl, dwl) * wlu + tmpfile = str(tmpdir.join('_tst.fits')) + if with_mask: mask = np.array([0, 0, 0, 0, 1, 0, 0, 0, 1, 0], dtype=np.uint16) else: mask = None - spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask) - tmpfile = str(tmpdir.join('_tst.fits')) + # ToDo: test with explicit (and different from flux) units. + if uncertainty is None: + spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask) + assert spectrum.uncertainty is None + else: + unc = uncertainty(0.1 * np.sqrt(np.abs(flux.value))) + spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask, uncertainty=unc) spectrum.write(tmpfile, hdu=0) # Read it in and check against the original @@ -738,6 +748,10 @@ def test_wcs1d_fits_writer(tmpdir, spectral_axis, with_mask): assert quantity_allclose(spec.spectral_axis, disp) assert quantity_allclose(spec.flux, spectrum.flux) assert np.all(spec.mask == spectrum.mask) + if uncertainty is None: + assert spec.uncertainty is None + else: + assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity) # Read from HDUList hdulist = fits.open(tmpfile) @@ -746,12 +760,18 @@ def test_wcs1d_fits_writer(tmpdir, spectral_axis, with_mask): assert quantity_allclose(spec.spectral_axis, spectrum.spectral_axis) assert quantity_allclose(spec.flux, spectrum.flux) assert np.all(spec.mask == spectrum.mask) + if uncertainty is None: + assert spec.uncertainty is None + else: + assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity) @pytest.mark.parametrize("spectral_axis", ['WAVE', 'FREQ', 'ENER', 'WAVN']) @pytest.mark.parametrize("with_mask", [False, True]) -def test_wcs1d_fits_cube(tmpdir, spectral_axis, with_mask): +@pytest.mark.parametrize("uncertainty", + [None, StdDevUncertainty, VarianceUncertainty, InverseVariance]) +def test_wcs1d_fits_cube(tmpdir, spectral_axis, with_mask, uncertainty): """Test write/read for Spectrum1D spectral cube with WCS spectral_axis.""" wlunits = {'WAVE': 'Angstrom', 'FREQ': 'GHz', 'ENER': 'eV', 'WAVN': 'cm**-1'} # Header dictionary for constructing WCS @@ -770,15 +790,23 @@ def test_wcs1d_fits_cube(tmpdir, spectral_axis, with_mask): wl0 = hdr['CRVAL1'] dwl = hdr['CDELT1'] disp = np.arange(wl0, wl0 + (flux.shape[2] - 0.5) * dwl, dwl) * wlu + tmpfile = str(tmpdir.join('_tst.fits')) + if with_mask: die = np.random.Generator(np.random.MT19937(23)) mask = die.choice([0, 0, 0, 0, 0, 1], size=flux.shape).astype(np.uint16) else: mask = None - spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask) - tmpfile = str(tmpdir.join('_tst.fits')) - spectrum.write(tmpfile, hdu=0) + if uncertainty is None: + spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask) + assert spectrum.uncertainty is None + with pytest.warns(AstropyUserWarning, match='No uncertainty array found'): + spectrum.write(tmpfile, hdu=0, uncertainty_name='STD') + else: + unc = uncertainty(0.1 * np.sqrt(np.abs(flux.value))) + spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask, uncertainty=unc) + spectrum.write(tmpfile, hdu=0) # Broken reader! # Read it in and check against the original @@ -790,6 +818,10 @@ def test_wcs1d_fits_cube(tmpdir, spectral_axis, with_mask): assert quantity_allclose(spec.spectral_axis, disp) assert quantity_allclose(spec.flux, spectrum.flux) assert np.all(spec.mask == spectrum.mask) + if uncertainty is None: + assert spec.uncertainty is None + else: + assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity) # Read from HDUList with fits.open(tmpfile) as hdulist: @@ -803,6 +835,10 @@ def test_wcs1d_fits_cube(tmpdir, spectral_axis, with_mask): assert quantity_allclose(spec.spectral_axis, spectrum.spectral_axis) assert quantity_allclose(spec.flux, spectrum.flux) assert np.all(spec.mask == spectrum.mask) + if uncertainty is None: + assert spec.uncertainty is None + else: + assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity) @pytest.mark.filterwarnings('ignore:Card is too long')