Skip to content

Commit

Permalink
Fix fit_lines with single quantity as window value
Browse files Browse the repository at this point in the history
  • Loading branch information
rosteen committed Sep 4, 2024
1 parent cadfe73 commit 812dcac
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 32 deletions.
26 changes: 16 additions & 10 deletions specutils/fitting/fitmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def find_lines_derivative(spectrum, flux_threshold=None):
Parameters
----------
spectrum : Spectrum1D
spectrum : `~specutils.Spectrum1D`
The spectrum object over which the equivalent width will be calculated.
flux_threshold : float, `~astropy.units.Quantity` or None
The threshold a pixel must be above to be considered part of a line. If
Expand Down Expand Up @@ -267,7 +267,7 @@ def fit_lines(spectrum, model, fitter=fitting.LevMarLSQFitter(calc_uncertainties
Parameters
----------
spectrum : Spectrum1D
spectrum : `~specutils.Spectrum1D`
The spectrum object over which the equivalent width will be calculated.
model: `~astropy.modeling.Model` or list of `~astropy.modeling.Model`
The model or list of models that contain the initial guess.
Expand All @@ -281,9 +281,10 @@ def fit_lines(spectrum, model, fitter=fitting.LevMarLSQFitter(calc_uncertainties
use in the fitting. Note that if a mask is present on the spectrum, it
will be applied to the ``weights`` as it would be to the spectrum
itself.
window : `~specutils.SpectralRegion` or list of `~specutils.SpectralRegion`
window : `~specutils.SpectralRegion`, `~astropy.units.Quantity`, or list of either
Regions of the spectrum to use in the fitting. If None, then the
whole spectrum will be used in the fitting.\
whole spectrum will be used in the fitting. If a single `~astropy.units.Quantity`
is input, it will be used as the width of the region around the model mean.
get_fit_info : bool, optional
Flag to return the ``fit_info`` from the underlying scipy optimizer used
in the fitting. If True, the returned model will have a ``fit_info``
Expand Down Expand Up @@ -430,16 +431,21 @@ def _fit_lines(spectrum, model, fitter=fitting.LevMarLSQFitter(calc_uncertaintie

# In this case the window defines the area around the center of each model
window_indices = None
if window is not None and isinstance(window, (float, int)):
center = model.mean
window_indices = np.nonzero((dispersion >= center-window) &
(dispersion < center+window))
if window is not None and isinstance(window, u.Quantity):
print("Got quantity window")
if isinstance(window.value, (float, int)):
center = model.mean
window_indices = np.nonzero((dispersion >= center-window) &
(dispersion < center+window))
elif len(window.value) == 2:
window_indices = np.nonzero((dispersion >= window.min()) &
(dispersion <= window.max()))

# In this case the window is the start and end points of where we
# should fit
elif window is not None and isinstance(window, tuple):
window_indices = np.nonzero((dispersion >= window[0]) &
(dispersion <= window[1]))
window_indices = np.nonzero((dispersion >= min(window)) &
(dispersion <= max(window)))

# in this case the window is spectral regions that determine where
# to fit.
Expand Down
44 changes: 22 additions & 22 deletions specutils/tests/test_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def test_single_peak_fit_window():
Single Peak fit with a window specified
"""

# Create the sepctrum
# Create the spectrum
x_single, y_single = single_peak()
s_single = Spectrum1D(flux=y_single*u.Jy, spectral_axis=x_single*u.um)

Expand All @@ -262,11 +262,11 @@ def test_single_peak_fit_window():
y_single_fit = g_fit(x_single*u.um)

# Comparing every 10th value.
y_single_fit_expected = np.array([3.69669474e-13, 3.57992454e-11, 2.36719426e-09, 1.06879318e-07,
3.29498310e-06, 6.93605383e-05, 9.96945607e-04, 9.78431032e-03,
6.55675141e-02, 3.00017760e-01, 9.37356842e-01, 1.99969007e+00,
2.91286375e+00, 2.89719280e+00, 1.96758892e+00, 9.12412206e-01,
2.88900005e-01, 6.24602556e-02, 9.22061121e-03, 9.29427266e-04]) * u.Jy
y_single_fit_expected = np.array([5.37863281e-13, 4.90559594e-11, 3.07085351e-09, 1.31939209e-07,
3.89077487e-06, 7.87491324e-05, 1.09396368e-03, 1.04305565e-02,
6.82589342e-02, 3.06590700e-01, 9.45161235e-01, 1.99986313e+00,
2.90430417e+00, 2.89488537e+00, 1.98046921e+00, 9.29934299e-01,
2.99698030e-01, 6.62922807e-02, 1.00644369e-02, 1.04872942e-03]) * u.Jy

assert np.allclose(y_single_fit.value[::10], y_single_fit_expected.value, atol=1e-5)

Expand Down Expand Up @@ -359,11 +359,11 @@ def test_double_peak_fit_window():
y2_double_fit = g2_fit(x_double*u.um)

# Comparing every 10th value.
y2_double_fit_expected = np.array([1.66363393e-128, 5.28910721e-102, 1.40949521e-078, 3.14848385e-058,
5.89516506e-041, 9.25224449e-027, 1.21718016e-015, 1.34220626e-007,
1.24062432e-002, 9.61209273e-001, 6.24240938e-002, 3.39815491e-006,
1.55056770e-013, 5.93054936e-024, 1.90132233e-037, 5.10943886e-054,
1.15092572e-073, 2.17309153e-096, 3.43926290e-122, 4.56256813e-151])
y2_double_fit_expected = np.array([1.19027928e-116, 1.26897102e-092, 2.20632873e-071, 6.25611070e-053,
2.89303919e-037, 2.18182491e-024, 2.68349729e-014, 5.38267443e-007,
1.76080293e-002, 9.39375172e-001, 8.17302995e-002, 1.15969234e-005,
2.68360243e-012, 1.01276632e-021, 6.23327102e-034, 6.25660097e-049,
1.02418071e-066, 2.73420084e-087, 1.19041919e-110, 8.45249876e-137])

assert np.allclose(y2_double_fit.value[::10], y2_double_fit_expected, atol=1e-5)

Expand All @@ -385,21 +385,21 @@ def test_double_peak_fit_separate_window():
yr_double_fit = gr_fit(x_double*u.um)

# Comparing every 10th value.
yl_double_fit_expected = np.array([3.40725147e-18, 5.05500395e-15, 3.59471319e-12, 1.22527176e-09,
2.00182467e-07, 1.56763547e-05, 5.88422893e-04, 1.05866724e-02,
9.12966452e-02, 3.77377148e-01, 7.47690410e-01, 7.10057397e-01,
3.23214276e-01, 7.05201207e-02, 7.37498248e-03, 3.69687164e-04,
8.88245844e-06, 1.02295712e-07, 5.64686114e-10, 1.49410879e-12])
yl_double_fit_expected = np.array([1.17816675e-82, 1.35059822e-65, 1.58326993e-50, 1.89798658e-37,
2.32670163e-26, 2.91673945e-17, 3.73907313e-10, 4.90162030e-05,
6.57089846e-02, 9.00781025e-01, 1.26276660e-01, 1.81024070e-04,
2.65374353e-09, 3.97823976e-16, 6.09863080e-25, 9.56055524e-36,
1.53265099e-48, 2.51253899e-63, 4.21203267e-80, 7.22071220e-99])

assert np.allclose(yl_double_fit.value[::10], yl_double_fit_expected, atol=1e-5)

# Comparing every 10th value.
yr_double_fit_expected = np.array([0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 3.04416285e-259,
3.85323221e-198, 2.98888589e-145, 1.42075875e-100, 4.13864520e-064,
7.38793226e-036, 8.08191847e-016, 5.41792361e-004, 2.22575901e+000,
5.60338234e-005, 8.64468603e-018, 8.17287853e-039, 4.73508430e-068,
1.68115300e-105, 3.65774659e-151, 4.87693358e-205, 3.98480359e-267])

yr_double_fit_expected = np.array([0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 7.01997018e-269,
1.34316521e-205, 8.31931099e-151, 1.66804993e-104, 1.08266599e-066,
2.27480254e-037, 1.54723641e-016, 3.40669622e-004, 2.42814100e+000,
5.60245215e-005, 4.18452504e-018, 1.01176103e-039, 7.91905729e-070,
2.00647052e-108, 1.64571961e-155, 4.36960988e-211, 3.75572034e-275])
print(yr_double_fit.value[::10])
assert np.allclose(yr_double_fit.value[::10], yr_double_fit_expected, atol=1e-5)


Expand Down

0 comments on commit 812dcac

Please sign in to comment.