From 7097d8afdda973ad9b61e662c99d360b33f9bc67 Mon Sep 17 00:00:00 2001 From: "P. L. Lim" <2090236+pllim@users.noreply.github.com> Date: Fri, 11 Oct 2024 18:15:01 -0400 Subject: [PATCH] FEAT: with_spectral_axis_and_flux_units --- CHANGES.rst | 3 + specutils/spectra/spectrum_mixin.py | 71 ++++++++------ specutils/tests/test_spectrum1d.py | 137 +++++++++++++++------------- 3 files changed, 121 insertions(+), 90 deletions(-) diff --git a/CHANGES.rst b/CHANGES.rst index 7b0dd3146..d2a0ebb60 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,6 +4,9 @@ New Features ^^^^^^^^^^^^ +- New ``Spectrum1D.with_spectral_axis_and_flux_units`` method to convert both + spectral axis and flux units at the same time. [#1184] + Bug Fixes ^^^^^^^^^ diff --git a/specutils/spectra/spectrum_mixin.py b/specutils/spectra/spectrum_mixin.py index 6479c0df5..06fa32123 100644 --- a/specutils/spectra/spectrum_mixin.py +++ b/specutils/spectra/spectrum_mixin.py @@ -86,6 +86,26 @@ def new_flux_unit(self, unit, equivalencies=None, suppress_conversion=False): return self.with_flux_unit(unit, equivalencies=equivalencies, suppress_conversion=suppress_conversion) + def _convert_flux(self, unit, equivalencies=None, suppress_conversion=False): + """This is always done in-place. + Also see :meth:`with_flux_unit`.""" + + if not suppress_conversion: + if equivalencies is None: + equivalencies = eq.spectral_density(self.spectral_axis) + + new_data = self.flux.to(unit, equivalencies=equivalencies) + + self._data = new_data.value + self._unit = new_data.unit + else: + self._unit = u.Unit(unit) + + if self.uncertainty is not None: + self.uncertainty = StdDevUncertainty( + self.uncertainty.represent_as(StdDevUncertainty).quantity.to( + unit, equivalencies=equivalencies)) + def with_flux_unit(self, unit, equivalencies=None, suppress_conversion=False): """Returns a new spectrum with a different flux unit. If uncertainty is defined, it will be converted to @@ -107,30 +127,14 @@ def with_flux_unit(self, unit, equivalencies=None, suppress_conversion=False): Returns ------- - `~specutils.Spectrum1D` + new_spec : `~specutils.Spectrum1D` A new spectrum with the converted flux array (and uncertainty, if applicable). """ new_spec = deepcopy(self) - - if not suppress_conversion: - if equivalencies is None: - equivalencies = eq.spectral_density(self.spectral_axis) - - new_data = self.flux.to( - unit, equivalencies=equivalencies) - - new_spec._data = new_data.value - new_spec._unit = new_data.unit - else: - new_spec._unit = u.Unit(unit) - - if self.uncertainty is not None: - new_spec.uncertainty = StdDevUncertainty( - self.uncertainty.represent_as(StdDevUncertainty).quantity.to( - unit, equivalencies=equivalencies)) - + new_spec._convert_flux( + unit, equivalencies=equivalencies, suppress_conversion=suppress_conversion) return new_spec @property @@ -175,7 +179,7 @@ def velocity(self): Returns ------- - ~`astropy.units.Quantity` + new_data : `~astropy.units.Quantity` The converted dispersion array in the new dispersion space. """ if self.rest_value is None: @@ -202,8 +206,7 @@ def with_spectral_unit(self, unit, velocity_convention=None, self.with_spectral_axis_unit(unit, velocity_convention=velocity_convention, rest_value=rest_value) - def with_spectral_axis_unit(self, unit, velocity_convention=None, - rest_value=None): + def with_spectral_axis_unit(self, unit, velocity_convention=None, rest_value=None): """ Returns a new spectrum with a different spectral axis unit. Note that this creates a new object using the converted spectral axis and thus drops the original WCS, if it existed, @@ -230,11 +233,9 @@ def with_spectral_axis_unit(self, unit, velocity_convention=None, even if your spectrum has air wavelength units """ - - velocity_convention = velocity_convention if velocity_convention is not None else self.velocity_convention # noqa + velocity_convention = velocity_convention if velocity_convention is not None else self.velocity_convention # noqa rest_value = rest_value if rest_value is not None else self.rest_value - unit = self._new_wcs_argument_validation(unit, velocity_convention, - rest_value) + unit = self._new_wcs_argument_validation(unit, velocity_convention, rest_value) # Store the original unit information and WCS for posterity meta = deepcopy(self._meta) @@ -252,6 +253,24 @@ def with_spectral_axis_unit(self, unit, velocity_convention=None, return self.__class__(flux=self.flux, spectral_axis=new_spectral_axis, meta=meta, uncertainty=self.uncertainty, mask=self.mask) + def with_spectral_axis_and_flux_units(self, spectral_axis_unit, flux_unit, + velocity_convention=None, rest_value=None, + flux_equivalencies=None, suppress_flux_conversion=False): + """Perform :meth:`with_spectral_axis_unit` and :meth:`with_flux_unit` together. + See the respective methods for input and output definitions. + + Returns + ------- + new_spec : `~specutils.Spectrum1D` + Spectrum in requested units. + + """ + new_spec = self.with_spectral_axis_unit( + spectral_axis_unit, velocity_convention=velocity_convention, rest_value=rest_value) + new_spec._convert_flux( + flux_unit, equivalencies=flux_equivalencies, suppress_conversion=suppress_flux_conversion) + return new_spec + def _new_wcs_argument_validation(self, unit, velocity_convention, rest_value): # Allow string specification of units, for example diff --git a/specutils/tests/test_spectrum1d.py b/specutils/tests/test_spectrum1d.py index 7b8e1aac4..39ea23278 100644 --- a/specutils/tests/test_spectrum1d.py +++ b/specutils/tests/test_spectrum1d.py @@ -25,7 +25,7 @@ def test_empty_spectrum(): def test_create_from_arrays(): spec = Spectrum1D(spectral_axis=np.arange(50) * u.AA, - flux=np.random.randn(50) * u.Jy) + flux=np.ones(50) * u.Jy) assert isinstance(spec.spectral_axis, SpectralCoord) assert spec.spectral_axis.size == 50 @@ -36,7 +36,7 @@ def test_create_from_arrays(): # Test creating spectrum with unknown arguments with pytest.raises(ValueError): spec = Spectrum1D(wavelength=np.arange(1, 50) * u.nm, - flux=np.random.randn(48) * u.Jy) + flux=np.ones(48) * u.Jy) def test_create_from_multidimensional_arrays(): @@ -47,7 +47,7 @@ def test_create_from_multidimensional_arrays(): """ freqs = np.arange(50) * u.GHz - flux = np.random.random((5, len(freqs))) * u.Jy + flux = np.ones((5, len(freqs))) * u.Jy spec = Spectrum1D(spectral_axis=freqs, flux=flux) assert (spec.frequency == freqs).all() @@ -55,15 +55,15 @@ def test_create_from_multidimensional_arrays(): # Mis-matched lengths should raise an exception (unless freqs is one longer # than flux, in which case it's interpreted as bin edges) - freqs = np.arange(50) * u.GHz - flux = np.random.random((5, len(freqs)-10)) * u.Jy + flux = np.ones((5, len(freqs) - 10)) * u.Jy with pytest.raises(ValueError): spec = Spectrum1D(spectral_axis=freqs, flux=flux) def test_create_from_quantities(): - spec = Spectrum1D(spectral_axis=np.arange(1, 50) * u.nm, - flux=np.random.randn(49) * u.Jy) + wav = np.arange(1, 50) * u.nm + flux = np.ones(49) * u.Jy + spec = Spectrum1D(spectral_axis=wav, flux=flux) assert isinstance(spec.spectral_axis, SpectralCoord) assert spec.spectral_axis.unit == u.nm @@ -72,13 +72,12 @@ def test_create_from_quantities(): # Mis-matched lengths should raise an exception (unless freqs is one longer # than flux, in which case it's interpreted as bin edges) with pytest.raises(ValueError): - spec = Spectrum1D(spectral_axis=np.arange(1, 50) * u.nm, - flux=np.random.randn(47) * u.Jy) + spec = Spectrum1D(spectral_axis=wav, flux=np.ones(47) * u.Jy) def test_create_implicit_wcs(): spec = Spectrum1D(spectral_axis=np.arange(50) * u.AA, - flux=np.random.randn(50) * u.Jy) + flux=np.ones(50) * u.Jy) assert isinstance(spec.wcs, gwcs.wcs.WCS) @@ -90,7 +89,7 @@ def test_create_implicit_wcs(): def test_create_implicit_wcs_with_spectral_unit(): spec = Spectrum1D(spectral_axis=np.arange(1, 50) * u.nm, - flux=np.random.randn(49) * u.Jy) + flux=np.ones(49) * u.Jy) assert isinstance(spec.wcs, gwcs.wcs.WCS) @@ -102,8 +101,8 @@ def test_create_implicit_wcs_with_spectral_unit(): def test_create_with_spectral_coord(): - spectral_coord = SpectralCoord(np.arange(5100, 5150)*u.AA, radial_velocity=u.Quantity(1000.0, "km/s")) - flux = np.random.randn(50)*u.Jy + spectral_coord = SpectralCoord(np.arange(5100, 5150) * u.AA, radial_velocity=u.Quantity(1000.0, "km/s")) + flux = np.ones(50) * u.Jy spec = Spectrum1D(spectral_axis=spectral_coord, flux=flux) assert spec.radial_velocity == u.Quantity(1000.0, "km/s") @@ -137,26 +136,24 @@ def test_spectral_axis_conversions(): assert np.all(spec.spectral_axis == np.array([400, 500]) * u.angstrom) assert spec.spectral_axis.unit == u.angstrom - spec = Spectrum1D(spectral_axis=np.arange(50) * u.AA, - flux=np.random.randn(50) * u.Jy) + flux = np.ones(49) * u.Jy + spec = Spectrum1D(spectral_axis=np.arange(flux.size) * u.AA, flux=flux) assert spec.wavelength.unit == u.AA - spec = Spectrum1D(spectral_axis=np.arange(1, 50) * u.nm, - flux=np.random.randn(49) * u.Jy) + spec = Spectrum1D(spectral_axis=np.arange(1, 50) * u.nm, flux=flux) assert spec.frequency.unit == u.GHz with pytest.raises(ValueError): spec.velocity - spec = Spectrum1D(spectral_axis=np.arange(100, 150) * u.nm, - flux=np.random.randn(49) * u.Jy) + spec = Spectrum1D(spectral_axis=np.arange(100, 150) * u.nm, flux=flux) - new_spec = spec.with_spectral_axis_unit(u.km/u.s, rest_value=125*u.um, - velocity_convention="relativistic") + new_spec = spec.with_spectral_axis_unit(u.km / u.s, rest_value=125 * u.um, + velocity_convention="relativistic") - assert new_spec.spectral_axis.unit == u.km/u.s + assert new_spec.spectral_axis.unit == u.km / u.s assert new_spec.wcs.world_axis_units[0] == "km.s**-1" # Make sure meta stored the old WCS correctly assert new_spec.meta["original_wcs"].world_axis_units[0] == "nm" @@ -164,10 +161,10 @@ def test_spectral_axis_conversions(): wcs_dict = {"CTYPE1": "WAVE", "CRVAL1": 3.622e3, "CDELT1": 8e-2, "CRPIX1": 0, "CUNIT1": "Angstrom"} - wcs_spec = Spectrum1D(flux=np.random.randn(49) * u.Jy, wcs=WCS(wcs_dict), + wcs_spec = Spectrum1D(flux=flux, wcs=WCS(wcs_dict), meta={'header': wcs_dict.copy()}) - new_spec = wcs_spec.with_spectral_axis_unit(u.km/u.s, rest_value=125*u.um, - velocity_convention="relativistic") + new_spec = wcs_spec.with_spectral_axis_unit(u.km / u.s, rest_value=125 * u.um, + velocity_convention="relativistic") new_spec.meta['original_wcs'].wcs.crval = [3.777e-7] new_spec.meta['header']['CRVAL1'] = 3777.0 @@ -175,9 +172,26 @@ def test_spectral_axis_conversions(): assert wcs_spec.meta['header']['CRVAL1'] == 3622. +def test_spectral_axis_and_flux_conversions(): + """A little bit from both sets of tests.""" + spec = Spectrum1D(spectral_axis=np.arange(100, 150) * u.nm, + flux=np.ones(49) * u.Jy) + + new_spec = spec.with_spectral_axis_and_flux_units( + u.km / u.s, u.uJy, rest_value=125 * u.um, velocity_convention="relativistic") + + assert new_spec.spectral_axis.unit == u.km/u.s + assert new_spec.wcs.world_axis_units[0] == "km.s**-1" + # Make sure meta stored the old WCS correctly + assert new_spec.meta["original_wcs"].world_axis_units[0] == "nm" + assert new_spec.meta["original_spectral_axis_unit"] == "nm" + assert new_spec.flux.unit == u.uJy + assert_allclose(new_spec.flux.value, 1000000) + + def test_spectral_slice(): spec = Spectrum1D(spectral_axis=np.linspace(100, 1000, 10) * u.nm, - flux=np.random.random(10) * u.Jy) + flux=np.ones(10) * u.Jy) sliced_spec = spec[300*u.nm:600*u.nm] assert np.all(sliced_spec.spectral_axis == [300, 400, 500] * u.nm) @@ -192,7 +206,7 @@ def test_spectral_slice(): # Test higher dimensional slicing spec = Spectrum1D(spectral_axis=np.linspace(100, 1000, 10) * u.nm, - flux=np.random.random((10, 10)) * u.Jy) + flux=np.ones((10, 10)) * u.Jy) sliced_spec = spec[300*u.nm:600*u.nm] assert np.all(sliced_spec.spectral_axis == [300, 400, 500] * u.nm) @@ -302,7 +316,7 @@ def test_flux_unit_conversion(): def test_wcs_transformations(): # Test with a GWCS spec = Spectrum1D(spectral_axis=np.arange(1, 50) * u.nm, - flux=np.random.randn(49) * u.Jy) + flux=np.ones(49) * u.Jy) pix_axis = spec.wcs.world_to_pixel(np.arange(20, 30) * u.nm) disp_axis = spec.wcs.pixel_to_world(np.arange(20, 30)) @@ -359,24 +373,19 @@ def test_create_explicit_fitswcs(): def test_create_with_uncertainty(): spec = Spectrum1D(spectral_axis=np.arange(1, 50) * u.nm, - flux=np.random.sample(49) * u.Jy, - uncertainty=StdDevUncertainty(np.random.sample(49) * 0.1)) + flux=np.ones(49) * u.Jy, + uncertainty=StdDevUncertainty(np.ones(49) * 0.1)) assert isinstance(spec.uncertainty, StdDevUncertainty) - - spec = Spectrum1D(spectral_axis=np.arange(1, 50) * u.nm, - flux=np.random.sample(49) * u.Jy, - uncertainty=StdDevUncertainty(np.random.sample(49) * 0.1)) - assert spec.flux.unit == spec.uncertainty.unit # If flux and uncertainty are different sizes then raise exception - wavelengths = np.arange(0, 10) - flux=100*np.abs(np.random.randn(3, 4, 10))*u.Jy - uncertainty = StdDevUncertainty(np.abs(np.random.randn(3, 2, 10))*u.Jy) + wavelengths = np.arange(10) * u.um + flux= np.ones((3, 4, 10)) * u.Jy + uncertainty = StdDevUncertainty(np.ones((3, 2, 10)) * u.Jy) with pytest.raises(ValueError): - Spectrum1D(spectral_axis=wavelengths*u.um, flux=flux, uncertainty=uncertainty) + Spectrum1D(spectral_axis=wavelengths, flux=flux, uncertainty=uncertainty) @pytest.mark.parametrize("flux_unit", ["adu", "ct/s", "count"]) @@ -407,7 +416,7 @@ def test_read_linear_solution(remote_data_path): def test_energy_photon_flux(): spec = Spectrum1D(spectral_axis=np.linspace(100, 1000, 10) * u.nm, - flux=np.random.randn(10)*u.Jy) + flux=np.ones(10) * u.Jy) assert spec.energy.size == 10 assert spec.photon_flux.size == 10 assert spec.photon_flux.unit == u.photon * u.cm**-2 * u.s**-1 * u.nm**-1 @@ -415,7 +424,7 @@ def test_energy_photon_flux(): def test_flux_nans_propagate_to_mask(): """Check that indices in input flux with NaNs get propagated to the mask""" - flux = np.random.randn(10) + flux = np.ones(10) nan_idx = [0, 3, 5] flux[nan_idx] = np.nan spec = Spectrum1D(spectral_axis=np.linspace(100, 1000, 10) * u.nm, @@ -424,16 +433,15 @@ def test_flux_nans_propagate_to_mask(): def test_repr(): - spec_with_wcs = Spectrum1D(spectral_axis=np.linspace(100, 1000, 10) * u.nm, - flux=np.random.random(10) * u.Jy) + wav = np.linspace(100, 1000, 10) * u.nm + flux = np.ones(10) * u.Jy + spec_with_wcs = Spectrum1D(spectral_axis=wav, flux=flux) result = repr(spec_with_wcs) assert result.startswith('