Skip to content

Commit

Permalink
Ensure arithmetic results preserve spectral axis shifts
Browse files Browse the repository at this point in the history
  • Loading branch information
rosteen committed Aug 8, 2024
1 parent c006678 commit 7b0d8c7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
14 changes: 9 additions & 5 deletions specutils/spectra/spectrum1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,14 +680,18 @@ def redshift(self, val):
def radial_velocity(self, val):
self.shift_spectrum_to(radial_velocity=val)

def _return_with_redshift(self, result):
result.shift_spectrum_to(redshift=self.redshift)
return result

def __add__(self, other):
if not isinstance(other, (NDCube, u.Quantity)):
try:
other = u.Quantity(other, unit=self.unit)
except TypeError:
return NotImplemented

return self.add(other)
return self._return_with_redshift(self.add(other))

def __sub__(self, other):
if not isinstance(other, NDCube):
Expand All @@ -696,25 +700,25 @@ def __sub__(self, other):
except TypeError:
return NotImplemented

return self.subtract(other)
return self._return_with_redshift(self.subtract(other))

def __mul__(self, other):
if not isinstance(other, NDCube):
other = u.Quantity(other)

return self.multiply(other)
return self._return_with_redshift(self.multiply(other))

def __div__(self, other):
if not isinstance(other, NDCube):
other = u.Quantity(other)

return self.divide(other)
return self._return_with_redshift(self.divide(other))

def __truediv__(self, other):
if not isinstance(other, NDCube):
other = u.Quantity(other)

return self.divide(other)
return self._return_with_redshift(self.divide(other))

__radd__ = __add__

Expand Down
9 changes: 9 additions & 0 deletions specutils/tests/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,12 @@ def test_with_constants(simulated_spectra):
r_sub_result = 2 - spec
l_sub_result = -1 * (spec - 2)
assert_quantity_allclose(r_sub_result.flux, l_sub_result.flux)


def test_arithmetic_after_shift(simulated_spectra):
spec = simulated_spectra.s1_um_mJy_e1
spec.shift_spectrum_to(redshift = 1)

# Test that doing arithmetic preserves the shifted spectral axis
spec *= 2
assert_quantity_allclose(spec.spectral_axis, 2*np.linspace(0.4, 1.05, 100)*u.um)

0 comments on commit 7b0d8c7

Please sign in to comment.