Skip to content

Commit

Permalink
Support write and read of boolean masks in wcs1d-fits (#1051)
Browse files Browse the repository at this point in the history
* Test warning on unspecified uncertainty type

* Store and retrieve boolean masks in wcs1d-fits

* Move uncertainty before mask in hdulist
  • Loading branch information
dhomeier authored Apr 18, 2023
1 parent 2f44758 commit b332b87
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 37 deletions.
48 changes: 30 additions & 18 deletions specutils/io/default_loaders/wcs_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,19 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None,
# Read mask from specified HDU or try to auto-detect it.
mask = None
if mask_hdu is True:
mask_hdu = None
for ext in ('MASK', 'DQ', 'QUALITY'):
if ext in hdulist and (isinstance(hdulist[ext], fits.ImageHDU) and
hdulist[ext].data is not None):
mask = hdulist[ext].data
mask_hdu = ext
break
elif isinstance(mask_hdu, (int, str)):
if isinstance(mask_hdu, (int, str)):
if mask_hdu in hdulist:
mask = hdulist[mask_hdu].data
# Do we have a mask originally converted from bool/bit?
if hdulist[mask_hdu].header.get('BFORM', 'B') in ('L', 'X'):
# ToDo: check for overflow and warn on values != (0, 1)
mask = mask.astype(bool)
else:
warnings.warn(f"No HDU '{mask_hdu}' for mask found in file.",
AstropyUserWarning)
Expand Down Expand Up @@ -272,19 +277,6 @@ def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False,
if flux_name is not None:
hdulist[0].name = flux_name

# Append mask array
if spectrum.mask is not None and mask_name is not None:
hdulist.append(fits.ImageHDU())
hdulist[-1].data = spectrum.mask
if mask_name == '':
hdulist[-1].name = 'MASK'
else:
hdulist[-1].name = mask_name
hdu += 1
# Warn if saving was requested (per explicitly choosing extension name).
elif mask_name is not None and mask_name != '':
warnings.warn("No mask found in this Spectrum1D, none saved.", AstropyUserWarning)

# Append uncertainty array
if spectrum.uncertainty is not None and uncertainty_name is not None:
hdulist.append(fits.ImageHDU())
Expand All @@ -310,6 +302,25 @@ def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False,
warnings.warn("No uncertainty array found in this Spectrum1D, none saved.",
AstropyUserWarning)

# Append mask array
if spectrum.mask is not None and mask_name is not None:
hdulist.append(fits.ImageHDU())
# No standard representation of bool in FITS;
# introducing 'BFORM' here in analogy to BINTABLE 'TFORMn'.
if spectrum.mask.dtype == bool:
hdulist[-1].data = spectrum.mask.astype(np.uint8)
hdulist[-1].header['BFORM'] = 'L'
else:
hdulist[-1].data = spectrum.mask
if mask_name == '':
hdulist[-1].name = 'MASK'
else:
hdulist[-1].name = mask_name
hdu += 1
# Warn if saving was requested (per explicitly choosing extension name).
elif mask_name is not None and mask_name != '':
warnings.warn("No mask found in this Spectrum1D, none saved.", AstropyUserWarning)

if hasattr(funit, 'long_names') and len(funit.long_names) > 0:
comment = f'[{funit.long_names[0]}] {funit.physical_type}'
else:
Expand Down Expand Up @@ -451,8 +462,8 @@ def non_linear_multispec_fits(file_obj, **kwargs):
return SpectrumCollection(flux=flux, spectral_axis=spectral_axis, meta=meta)


def _read_non_linear_iraf_fits(file_obj, spectral_axis_unit=None, flux_unit=None, verbose=False,
**kwargs):
def _read_non_linear_iraf_fits(file_obj, spectral_axis_unit=None, flux_unit=None,
verbose=False, **kwargs):
"""Read spectrum data with WCS spectral axis from FITS files written by IRAF
IRAF does not strictly follow the fits standard especially for non-linear
Expand Down Expand Up @@ -483,7 +494,8 @@ def _read_non_linear_iraf_fits(file_obj, spectral_axis_unit=None, flux_unit=None
Returns
-------
Tuple of data to pass to `~specutils.SpectrumCollection` or `~specutils.Spectrum1D`:
Tuple of data to pass to :class:`~specutils.SpectrumCollection` or
`:class:`~specutils.Spectrum1D` on initialization:
spectral_axis : :class:`~astropy.units.Quantity`
The spectral axis or axes as constructed from WCS(hdulist[0].header).
Expand Down
110 changes: 91 additions & 19 deletions specutils/tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,12 +710,10 @@ def test_tabular_fits_compressed(compress, tmp_path):
assert quantity_allclose(spec.flux, spectrum.flux)


@pytest.mark.parametrize("spectral_axis",
['WAVE', 'FREQ', 'ENER', 'WAVN'])
@pytest.mark.parametrize("with_mask", [False, True])
@pytest.mark.parametrize("spectral_axis", ['WAVE', 'FREQ', 'ENER', 'WAVN'])
@pytest.mark.parametrize("uncertainty",
[None, StdDevUncertainty, VarianceUncertainty, InverseVariance])
def test_wcs1d_fits_writer(tmp_path, spectral_axis, with_mask, uncertainty):
def test_wcs1d_fits_writer(tmp_path, spectral_axis, uncertainty):
"""Test write/read for Spectrum1D with WCS-constructed spectral_axis."""
wlunits = {'WAVE': 'Angstrom', 'FREQ': 'GHz', 'ENER': 'eV', 'WAVN': 'cm**-1'}
# Header dictionary for constructing WCS
Expand All @@ -729,10 +727,7 @@ def test_wcs1d_fits_writer(tmp_path, spectral_axis, with_mask, uncertainty):
disp = np.arange(wl0, wl0 + (len(flux) - 0.5) * dwl, dwl) * wlu
tmpfile = tmp_path / 'wcs_tst.fits'

if with_mask:
mask = np.array([0, 0, 0, 0, 1, 0, 0, 0, 1, 0], dtype=np.uint16)
else:
mask = None
mask = np.array([0, 0, 0, 0, 1, 0, 0, 0, 1, 0], dtype=np.uint8)

# ToDo: test with explicit (and different from flux) units.
if uncertainty is None:
Expand Down Expand Up @@ -770,6 +765,65 @@ def test_wcs1d_fits_writer(tmp_path, spectral_axis, with_mask, uncertainty):
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)


@pytest.mark.parametrize("spectral_axis", ['WAVE', 'FREQ'])
@pytest.mark.parametrize("mask_type", [None, bool, np.uint8, np.int8, np.uint16, np.int16, '>i2'])
@pytest.mark.parametrize("uncertainty", [StdDevUncertainty, InverseVariance])
def test_wcs1d_fits_masks(tmp_path, spectral_axis, mask_type, uncertainty):
"""Test write/read for Spectrum1D with WCS-constructed spectral_axis."""
wlunits = {'WAVE': 'nm', 'FREQ': 'GHz', 'ENER': 'eV', 'WAVN': 'cm**-1'}
# Header dictionary for constructing WCS
hdr = {'CTYPE1': spectral_axis, 'CUNIT1': wlunits[spectral_axis],
'CRPIX1': 1, 'CRVAL1': 1, 'CDELT1': 0.01}
# Create a small data set
flux = np.arange(1, 11)**2 * 1.e-14 * u.Jy
wlu = u.Unit(hdr['CUNIT1'])
wl0 = hdr['CRVAL1']
dwl = hdr['CDELT1']
disp = np.arange(wl0, wl0 + (len(flux) - 0.5) * dwl, dwl) * wlu
unc = uncertainty(0.1 * np.sqrt(np.abs(flux.value)))
tmpfile = tmp_path / 'wcs_tst.fits'

if mask_type is None:
mask = None
spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), uncertainty=unc)
assert spectrum.mask is None
else:
mask = np.array([0, 0, 1, 0, 3, 0, 0, -99, -199, 0]).astype(mask_type)
spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask, uncertainty=unc)
assert spectrum.mask.dtype == mask.dtype

spectrum.write(tmpfile, hdu=0)

# Read it in and check against the original
spec = Spectrum1D.read(tmpfile)
assert quantity_allclose(spec.spectral_axis, spectrum.spectral_axis)
assert quantity_allclose(spec.spectral_axis, disp)
assert quantity_allclose(spec.flux, spectrum.flux)
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)
assert np.all(spec.mask == spectrum.mask)
# int16 is returned as FITS-native '>i2'
if mask_type == np.int16:
assert spec.mask.dtype.kind == spectrum.mask.dtype.kind
assert spec.mask.dtype.itemsize == spectrum.mask.dtype.itemsize
else:
assert np.array(spec.mask).dtype == np.array(spectrum.mask).dtype

# Read from HDUList
with fits.open(tmpfile) as hdulist:
spec = Spectrum1D.read(hdulist, format='wcs1d-fits')

assert isinstance(spec, Spectrum1D)
assert quantity_allclose(spec.spectral_axis, spectrum.spectral_axis)
assert quantity_allclose(spec.flux, spectrum.flux)
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)
assert np.all(spec.mask == spectrum.mask)
if mask_type == np.int16:
assert spec.mask.dtype.kind == spectrum.mask.dtype.kind
assert spec.mask.dtype.itemsize == spectrum.mask.dtype.itemsize
else:
assert np.array(spec.mask).dtype == np.array(spectrum.mask).dtype


@pytest.mark.parametrize("spectral_axis",
['WAVE', 'FREQ', 'ENER', 'WAVN'])
@pytest.mark.parametrize("with_mask", [False, True])
Expand Down Expand Up @@ -846,7 +900,8 @@ def test_wcs1d_fits_cube(tmp_path, spectral_axis, with_mask, uncertainty):


@pytest.mark.parametrize("uncertainty_rsv", ['STD', 'ERR', 'UNCERT', 'VAR', 'IVAR'])
def test_wcs1d_fits_uncertainty(tmp_path, uncertainty_rsv):
@pytest.mark.parametrize("hdu", [None, 0, 1])
def test_wcs1d_fits_uncertainty(tmp_path, uncertainty_rsv, hdu):
"""
Test Spectrum1D.write with custom `uncertainty` names,
ensure it raises on illegal (reserved) names.
Expand Down Expand Up @@ -877,21 +932,38 @@ def test_wcs1d_fits_uncertainty(tmp_path, uncertainty_rsv):
# Set permitted custom name
uncertainty_type = spectrum.uncertainty.uncertainty_type
uncertainty_alt = UNCERT_ALT[uncertainty_type]
spectrum.write(tmpfile, format='wcs1d-fits', uncertainty_name=uncertainty_alt)
if hdu is None:
spectrum.write(tmpfile, format='wcs1d-fits', uncertainty_name=uncertainty_alt)
hdu = 0
else:
spectrum.write(tmpfile, format='wcs1d-fits', uncertainty_name=uncertainty_alt, hdu=hdu)

# Auto-detect only works with flux in default (primary) HDU.
if hdu == 0:
kwargs = {'uncertainty_hdu': hdu+1}
else:
kwargs = {'uncertainty_hdu': hdu+1, 'format': 'wcs1d-fits'}

# Check EXTNAME (uncertainty is in last HDU)
# Check EXTNAME (uncertainty is in first HDU following flux spectrum)
with fits.open(tmpfile) as hdulist:
assert hdulist[-1].name == uncertainty_alt.upper()
assert hdulist[hdu+1].name == uncertainty_alt.upper()

# Read it in and check against the original
with pytest.raises(ValueError, match=f"Invalid uncertainty type: '{uncertainty_alt}'; should"):
spec = Spectrum1D.read(tmpfile, uncertainty_hdu=2, uncertainty_type=uncertainty_alt)
spec = Spectrum1D.read(tmpfile, uncertainty_hdu=2, uncertainty_type=uncertainty_type)
spec = Spectrum1D.read(tmpfile, uncertainty_type=uncertainty_alt, **kwargs)
# Need to specify type if not default
with pytest.warns(AstropyUserWarning, match="Could not determine uncertainty type for HDU "
rf"'{hdu+1}' .'{uncertainty_alt.upper()}'., assuming 'StdDev'"):
spec = Spectrum1D.read(tmpfile, **kwargs)
if uncertainty_type != 'std':
assert spec.uncertainty.uncertainty_type != uncertainty_type
spec = Spectrum1D.read(tmpfile, uncertainty_type=uncertainty_type, **kwargs)

assert spec.flux.unit == spectrum.flux.unit
assert spec.spectral_axis.unit == spectrum.spectral_axis.unit
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)
spec = Spectrum1D.read(tmpfile, uncertainty_hdu=uncertainty_alt,
uncertainty_type=uncertainty_type)
kwargs['uncertainty_hdu'] = uncertainty_alt
spec = Spectrum1D.read(tmpfile, uncertainty_type=uncertainty_type, **kwargs)
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)


Expand Down Expand Up @@ -996,7 +1068,7 @@ def test_wcs1d_fits_non1d(tmp_path, spectral_axis):
def test_wcs1d_fits_compressed(compress, tmp_path):
"""Test automatic recognition of supported compression formats for IMAGE/WCS.
"""
ext = {'gzip': 'gz', 'bzip2': 'bz2', 'xz': 'xz'}
ext = {'gzip': '.gz', 'bzip2': '.bz2', 'xz': '.xz'}
if compress == 'bzip2' and not HAS_BZ2:
pytest.xfail("Python installation has no bzip2 support")
if compress == 'xz' and not HAS_LZMA:
Expand All @@ -1020,7 +1092,7 @@ def test_wcs1d_fits_compressed(compress, tmp_path):
with warnings.catch_warnings():
warnings.simplefilter('ignore', FITSFixedWarning)
os.system(f'{compress} {tmpfile}')
spec = Spectrum1D.read(tmpfile.with_suffix(f'.fits.{ext[compress]}'))
spec = Spectrum1D.read(tmpfile.with_suffix(f'{tmpfile.suffix}{ext[compress]}'))

assert isinstance(spec, Spectrum1D)
assert quantity_allclose(spec.spectral_axis, disp)
Expand All @@ -1029,7 +1101,7 @@ def test_wcs1d_fits_compressed(compress, tmp_path):
# Try again without compression suffix:
with warnings.catch_warnings():
warnings.simplefilter('ignore', FITSFixedWarning)
os.system(f'mv {tmpfile}.{ext[compress]} {tmpfile}')
shutil.move(tmpfile.with_suffix(f'{tmpfile.suffix}{ext[compress]}'), tmpfile)
spec = Spectrum1D.read(tmpfile)

assert isinstance(spec, Spectrum1D)
Expand Down

0 comments on commit b332b87

Please sign in to comment.