Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support write and read of boolean masks in wcs1d-fits #1051

Merged
merged 3 commits into from
Apr 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions specutils/io/default_loaders/wcs_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,19 @@ def wcs1d_fits_loader(file_obj, spectral_axis_unit=None, flux_unit=None,
# Read mask from specified HDU or try to auto-detect it.
mask = None
if mask_hdu is True:
mask_hdu = None
for ext in ('MASK', 'DQ', 'QUALITY'):
if ext in hdulist and (isinstance(hdulist[ext], fits.ImageHDU) and
hdulist[ext].data is not None):
mask = hdulist[ext].data
mask_hdu = ext
break
elif isinstance(mask_hdu, (int, str)):
if isinstance(mask_hdu, (int, str)):
if mask_hdu in hdulist:
mask = hdulist[mask_hdu].data
# Do we have a mask originally converted from bool/bit?
if hdulist[mask_hdu].header.get('BFORM', 'B') in ('L', 'X'):
# ToDo: check for overflow and warn on values != (0, 1)
mask = mask.astype(bool)
else:
warnings.warn(f"No HDU '{mask_hdu}' for mask found in file.",
AstropyUserWarning)
Expand Down Expand Up @@ -272,19 +277,6 @@ def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False,
if flux_name is not None:
hdulist[0].name = flux_name

# Append mask array
if spectrum.mask is not None and mask_name is not None:
hdulist.append(fits.ImageHDU())
hdulist[-1].data = spectrum.mask
if mask_name == '':
hdulist[-1].name = 'MASK'
else:
hdulist[-1].name = mask_name
hdu += 1
# Warn if saving was requested (per explicitly choosing extension name).
elif mask_name is not None and mask_name != '':
warnings.warn("No mask found in this Spectrum1D, none saved.", AstropyUserWarning)

# Append uncertainty array
if spectrum.uncertainty is not None and uncertainty_name is not None:
hdulist.append(fits.ImageHDU())
Expand All @@ -310,6 +302,25 @@ def wcs1d_fits_writer(spectrum, file_name, hdu=0, update_header=False,
warnings.warn("No uncertainty array found in this Spectrum1D, none saved.",
AstropyUserWarning)

# Append mask array
if spectrum.mask is not None and mask_name is not None:
hdulist.append(fits.ImageHDU())
# No standard representation of bool in FITS;
# introducing 'BFORM' here in analogy to BINTABLE 'TFORMn'.
if spectrum.mask.dtype == bool:
hdulist[-1].data = spectrum.mask.astype(np.uint8)
hdulist[-1].header['BFORM'] = 'L'
else:
hdulist[-1].data = spectrum.mask
if mask_name == '':
hdulist[-1].name = 'MASK'
else:
hdulist[-1].name = mask_name
hdu += 1
# Warn if saving was requested (per explicitly choosing extension name).
elif mask_name is not None and mask_name != '':
warnings.warn("No mask found in this Spectrum1D, none saved.", AstropyUserWarning)

if hasattr(funit, 'long_names') and len(funit.long_names) > 0:
comment = f'[{funit.long_names[0]}] {funit.physical_type}'
else:
Expand Down Expand Up @@ -451,8 +462,8 @@ def non_linear_multispec_fits(file_obj, **kwargs):
return SpectrumCollection(flux=flux, spectral_axis=spectral_axis, meta=meta)


def _read_non_linear_iraf_fits(file_obj, spectral_axis_unit=None, flux_unit=None, verbose=False,
**kwargs):
def _read_non_linear_iraf_fits(file_obj, spectral_axis_unit=None, flux_unit=None,
verbose=False, **kwargs):
"""Read spectrum data with WCS spectral axis from FITS files written by IRAF

IRAF does not strictly follow the fits standard especially for non-linear
Expand Down Expand Up @@ -483,7 +494,8 @@ def _read_non_linear_iraf_fits(file_obj, spectral_axis_unit=None, flux_unit=None

Returns
-------
Tuple of data to pass to `~specutils.SpectrumCollection` or `~specutils.Spectrum1D`:
Tuple of data to pass to :class:`~specutils.SpectrumCollection` or
`:class:`~specutils.Spectrum1D` on initialization:

spectral_axis : :class:`~astropy.units.Quantity`
The spectral axis or axes as constructed from WCS(hdulist[0].header).
Expand Down
110 changes: 91 additions & 19 deletions specutils/tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,12 +710,10 @@ def test_tabular_fits_compressed(compress, tmp_path):
assert quantity_allclose(spec.flux, spectrum.flux)


@pytest.mark.parametrize("spectral_axis",
['WAVE', 'FREQ', 'ENER', 'WAVN'])
@pytest.mark.parametrize("with_mask", [False, True])
@pytest.mark.parametrize("spectral_axis", ['WAVE', 'FREQ', 'ENER', 'WAVN'])
@pytest.mark.parametrize("uncertainty",
[None, StdDevUncertainty, VarianceUncertainty, InverseVariance])
def test_wcs1d_fits_writer(tmp_path, spectral_axis, with_mask, uncertainty):
def test_wcs1d_fits_writer(tmp_path, spectral_axis, uncertainty):
"""Test write/read for Spectrum1D with WCS-constructed spectral_axis."""
wlunits = {'WAVE': 'Angstrom', 'FREQ': 'GHz', 'ENER': 'eV', 'WAVN': 'cm**-1'}
# Header dictionary for constructing WCS
Expand All @@ -729,10 +727,7 @@ def test_wcs1d_fits_writer(tmp_path, spectral_axis, with_mask, uncertainty):
disp = np.arange(wl0, wl0 + (len(flux) - 0.5) * dwl, dwl) * wlu
tmpfile = tmp_path / 'wcs_tst.fits'

if with_mask:
mask = np.array([0, 0, 0, 0, 1, 0, 0, 0, 1, 0], dtype=np.uint16)
else:
mask = None
mask = np.array([0, 0, 0, 0, 1, 0, 0, 0, 1, 0], dtype=np.uint8)

# ToDo: test with explicit (and different from flux) units.
if uncertainty is None:
Expand Down Expand Up @@ -770,6 +765,65 @@ def test_wcs1d_fits_writer(tmp_path, spectral_axis, with_mask, uncertainty):
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)


@pytest.mark.parametrize("spectral_axis", ['WAVE', 'FREQ'])
@pytest.mark.parametrize("mask_type", [None, bool, np.uint8, np.int8, np.uint16, np.int16, '>i2'])
@pytest.mark.parametrize("uncertainty", [StdDevUncertainty, InverseVariance])
def test_wcs1d_fits_masks(tmp_path, spectral_axis, mask_type, uncertainty):
"""Test write/read for Spectrum1D with WCS-constructed spectral_axis."""
wlunits = {'WAVE': 'nm', 'FREQ': 'GHz', 'ENER': 'eV', 'WAVN': 'cm**-1'}
# Header dictionary for constructing WCS
hdr = {'CTYPE1': spectral_axis, 'CUNIT1': wlunits[spectral_axis],
'CRPIX1': 1, 'CRVAL1': 1, 'CDELT1': 0.01}
# Create a small data set
flux = np.arange(1, 11)**2 * 1.e-14 * u.Jy
wlu = u.Unit(hdr['CUNIT1'])
wl0 = hdr['CRVAL1']
dwl = hdr['CDELT1']
disp = np.arange(wl0, wl0 + (len(flux) - 0.5) * dwl, dwl) * wlu
unc = uncertainty(0.1 * np.sqrt(np.abs(flux.value)))
tmpfile = tmp_path / 'wcs_tst.fits'

if mask_type is None:
mask = None
spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), uncertainty=unc)
assert spectrum.mask is None
else:
mask = np.array([0, 0, 1, 0, 3, 0, 0, -99, -199, 0]).astype(mask_type)
spectrum = Spectrum1D(flux=flux, wcs=WCS(hdr), mask=mask, uncertainty=unc)
assert spectrum.mask.dtype == mask.dtype

spectrum.write(tmpfile, hdu=0)

# Read it in and check against the original
spec = Spectrum1D.read(tmpfile)
assert quantity_allclose(spec.spectral_axis, spectrum.spectral_axis)
assert quantity_allclose(spec.spectral_axis, disp)
assert quantity_allclose(spec.flux, spectrum.flux)
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)
assert np.all(spec.mask == spectrum.mask)
# int16 is returned as FITS-native '>i2'
if mask_type == np.int16:
assert spec.mask.dtype.kind == spectrum.mask.dtype.kind
assert spec.mask.dtype.itemsize == spectrum.mask.dtype.itemsize
else:
assert np.array(spec.mask).dtype == np.array(spectrum.mask).dtype

# Read from HDUList
with fits.open(tmpfile) as hdulist:
spec = Spectrum1D.read(hdulist, format='wcs1d-fits')

assert isinstance(spec, Spectrum1D)
assert quantity_allclose(spec.spectral_axis, spectrum.spectral_axis)
assert quantity_allclose(spec.flux, spectrum.flux)
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)
assert np.all(spec.mask == spectrum.mask)
if mask_type == np.int16:
assert spec.mask.dtype.kind == spectrum.mask.dtype.kind
assert spec.mask.dtype.itemsize == spectrum.mask.dtype.itemsize
else:
assert np.array(spec.mask).dtype == np.array(spectrum.mask).dtype


@pytest.mark.parametrize("spectral_axis",
['WAVE', 'FREQ', 'ENER', 'WAVN'])
@pytest.mark.parametrize("with_mask", [False, True])
Expand Down Expand Up @@ -846,7 +900,8 @@ def test_wcs1d_fits_cube(tmp_path, spectral_axis, with_mask, uncertainty):


@pytest.mark.parametrize("uncertainty_rsv", ['STD', 'ERR', 'UNCERT', 'VAR', 'IVAR'])
def test_wcs1d_fits_uncertainty(tmp_path, uncertainty_rsv):
@pytest.mark.parametrize("hdu", [None, 0, 1])
def test_wcs1d_fits_uncertainty(tmp_path, uncertainty_rsv, hdu):
"""
Test Spectrum1D.write with custom `uncertainty` names,
ensure it raises on illegal (reserved) names.
Expand Down Expand Up @@ -877,21 +932,38 @@ def test_wcs1d_fits_uncertainty(tmp_path, uncertainty_rsv):
# Set permitted custom name
uncertainty_type = spectrum.uncertainty.uncertainty_type
uncertainty_alt = UNCERT_ALT[uncertainty_type]
spectrum.write(tmpfile, format='wcs1d-fits', uncertainty_name=uncertainty_alt)
if hdu is None:
spectrum.write(tmpfile, format='wcs1d-fits', uncertainty_name=uncertainty_alt)
hdu = 0
else:
spectrum.write(tmpfile, format='wcs1d-fits', uncertainty_name=uncertainty_alt, hdu=hdu)

# Auto-detect only works with flux in default (primary) HDU.
if hdu == 0:
kwargs = {'uncertainty_hdu': hdu+1}
else:
kwargs = {'uncertainty_hdu': hdu+1, 'format': 'wcs1d-fits'}

# Check EXTNAME (uncertainty is in last HDU)
# Check EXTNAME (uncertainty is in first HDU following flux spectrum)
with fits.open(tmpfile) as hdulist:
assert hdulist[-1].name == uncertainty_alt.upper()
assert hdulist[hdu+1].name == uncertainty_alt.upper()

# Read it in and check against the original
with pytest.raises(ValueError, match=f"Invalid uncertainty type: '{uncertainty_alt}'; should"):
spec = Spectrum1D.read(tmpfile, uncertainty_hdu=2, uncertainty_type=uncertainty_alt)
spec = Spectrum1D.read(tmpfile, uncertainty_hdu=2, uncertainty_type=uncertainty_type)
spec = Spectrum1D.read(tmpfile, uncertainty_type=uncertainty_alt, **kwargs)
# Need to specify type if not default
with pytest.warns(AstropyUserWarning, match="Could not determine uncertainty type for HDU "
rf"'{hdu+1}' .'{uncertainty_alt.upper()}'., assuming 'StdDev'"):
spec = Spectrum1D.read(tmpfile, **kwargs)
if uncertainty_type != 'std':
assert spec.uncertainty.uncertainty_type != uncertainty_type
spec = Spectrum1D.read(tmpfile, uncertainty_type=uncertainty_type, **kwargs)

assert spec.flux.unit == spectrum.flux.unit
assert spec.spectral_axis.unit == spectrum.spectral_axis.unit
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)
spec = Spectrum1D.read(tmpfile, uncertainty_hdu=uncertainty_alt,
uncertainty_type=uncertainty_type)
kwargs['uncertainty_hdu'] = uncertainty_alt
spec = Spectrum1D.read(tmpfile, uncertainty_type=uncertainty_type, **kwargs)
assert quantity_allclose(spec.uncertainty.quantity, spectrum.uncertainty.quantity)


Expand Down Expand Up @@ -996,7 +1068,7 @@ def test_wcs1d_fits_non1d(tmp_path, spectral_axis):
def test_wcs1d_fits_compressed(compress, tmp_path):
"""Test automatic recognition of supported compression formats for IMAGE/WCS.
"""
ext = {'gzip': 'gz', 'bzip2': 'bz2', 'xz': 'xz'}
ext = {'gzip': '.gz', 'bzip2': '.bz2', 'xz': '.xz'}
if compress == 'bzip2' and not HAS_BZ2:
pytest.xfail("Python installation has no bzip2 support")
if compress == 'xz' and not HAS_LZMA:
Expand All @@ -1020,7 +1092,7 @@ def test_wcs1d_fits_compressed(compress, tmp_path):
with warnings.catch_warnings():
warnings.simplefilter('ignore', FITSFixedWarning)
os.system(f'{compress} {tmpfile}')
spec = Spectrum1D.read(tmpfile.with_suffix(f'.fits.{ext[compress]}'))
spec = Spectrum1D.read(tmpfile.with_suffix(f'{tmpfile.suffix}{ext[compress]}'))

assert isinstance(spec, Spectrum1D)
assert quantity_allclose(spec.spectral_axis, disp)
Expand All @@ -1029,7 +1101,7 @@ def test_wcs1d_fits_compressed(compress, tmp_path):
# Try again without compression suffix:
with warnings.catch_warnings():
warnings.simplefilter('ignore', FITSFixedWarning)
os.system(f'mv {tmpfile}.{ext[compress]} {tmpfile}')
shutil.move(tmpfile.with_suffix(f'{tmpfile.suffix}{ext[compress]}'), tmpfile)
spec = Spectrum1D.read(tmpfile)

assert isinstance(spec, Spectrum1D)
Expand Down