From 7b0d8c7fd16989c0fae7f5e353eede49e0817589 Mon Sep 17 00:00:00 2001 From: Ricky O'Steen Date: Thu, 8 Aug 2024 11:23:28 -0400 Subject: [PATCH] Ensure arithmetic results preserve spectral axis shifts --- specutils/spectra/spectrum1d.py | 14 +++++++++----- specutils/tests/test_arithmetic.py | 9 +++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/specutils/spectra/spectrum1d.py b/specutils/spectra/spectrum1d.py index 7702b58f5..5942a139f 100644 --- a/specutils/spectra/spectrum1d.py +++ b/specutils/spectra/spectrum1d.py @@ -680,6 +680,10 @@ def redshift(self, val): def radial_velocity(self, val): self.shift_spectrum_to(radial_velocity=val) + def _return_with_redshift(self, result): + result.shift_spectrum_to(redshift=self.redshift) + return result + def __add__(self, other): if not isinstance(other, (NDCube, u.Quantity)): try: @@ -687,7 +691,7 @@ def __add__(self, other): except TypeError: return NotImplemented - return self.add(other) + return self._return_with_redshift(self.add(other)) def __sub__(self, other): if not isinstance(other, NDCube): @@ -696,25 +700,25 @@ def __sub__(self, other): except TypeError: return NotImplemented - return self.subtract(other) + return self._return_with_redshift(self.subtract(other)) def __mul__(self, other): if not isinstance(other, NDCube): other = u.Quantity(other) - return self.multiply(other) + return self._return_with_redshift(self.multiply(other)) def __div__(self, other): if not isinstance(other, NDCube): other = u.Quantity(other) - return self.divide(other) + return self._return_with_redshift(self.divide(other)) def __truediv__(self, other): if not isinstance(other, NDCube): other = u.Quantity(other) - return self.divide(other) + return self._return_with_redshift(self.divide(other)) __radd__ = __add__ diff --git a/specutils/tests/test_arithmetic.py b/specutils/tests/test_arithmetic.py index b432609b7..4ac9c5933 100644 --- a/specutils/tests/test_arithmetic.py +++ b/specutils/tests/test_arithmetic.py @@ -131,3 +131,12 @@ def test_with_constants(simulated_spectra): r_sub_result = 2 - spec l_sub_result = -1 * (spec - 2) assert_quantity_allclose(r_sub_result.flux, l_sub_result.flux) + + +def test_arithmetic_after_shift(simulated_spectra): + spec = simulated_spectra.s1_um_mJy_e1 + spec.shift_spectrum_to(redshift = 1) + + # Test that doing arithmetic preserves the shifted spectral axis + spec *= 2 + assert_quantity_allclose(spec.spectral_axis, 2*np.linspace(0.4, 1.05, 100)*u.um) \ No newline at end of file