diff --git a/specutils/fitting/fitmodels.py b/specutils/fitting/fitmodels.py index 66f19de49..c85cf77ba 100644 --- a/specutils/fitting/fitmodels.py +++ b/specutils/fitting/fitmodels.py @@ -169,7 +169,7 @@ def find_lines_derivative(spectrum, flux_threshold=None): Parameters ---------- - spectrum : Spectrum1D + spectrum : `~specutils.Spectrum1D` The spectrum object over which the equivalent width will be calculated. flux_threshold : float, `~astropy.units.Quantity` or None The threshold a pixel must be above to be considered part of a line. If @@ -267,7 +267,7 @@ def fit_lines(spectrum, model, fitter=fitting.LevMarLSQFitter(calc_uncertainties Parameters ---------- - spectrum : Spectrum1D + spectrum : `~specutils.Spectrum1D` The spectrum object over which the equivalent width will be calculated. model: `~astropy.modeling.Model` or list of `~astropy.modeling.Model` The model or list of models that contain the initial guess. @@ -281,9 +281,10 @@ def fit_lines(spectrum, model, fitter=fitting.LevMarLSQFitter(calc_uncertainties use in the fitting. Note that if a mask is present on the spectrum, it will be applied to the ``weights`` as it would be to the spectrum itself. - window : `~specutils.SpectralRegion` or list of `~specutils.SpectralRegion` + window : `~specutils.SpectralRegion`, `~astropy.units.Quantity`, or list of either Regions of the spectrum to use in the fitting. If None, then the - whole spectrum will be used in the fitting.\ + whole spectrum will be used in the fitting. If a single `~astropy.units.Quantity` + is input, it will be used as the width of the region around the model mean. get_fit_info : bool, optional Flag to return the ``fit_info`` from the underlying scipy optimizer used in the fitting. If True, the returned model will have a ``fit_info`` @@ -430,16 +431,21 @@ def _fit_lines(spectrum, model, fitter=fitting.LevMarLSQFitter(calc_uncertaintie # In this case the window defines the area around the center of each model window_indices = None - if window is not None and isinstance(window, (float, int)): - center = model.mean - window_indices = np.nonzero((dispersion >= center-window) & - (dispersion < center+window)) + if window is not None and isinstance(window, u.Quantity): + print("Got quantity window") + if isinstance(window.value, (float, int)): + center = model.mean + window_indices = np.nonzero((dispersion >= center-window) & + (dispersion < center+window)) + elif len(window.value) == 2: + window_indices = np.nonzero((dispersion >= window.min()) & + (dispersion <= window.max())) # In this case the window is the start and end points of where we # should fit elif window is not None and isinstance(window, tuple): - window_indices = np.nonzero((dispersion >= window[0]) & - (dispersion <= window[1])) + window_indices = np.nonzero((dispersion >= min(window)) & + (dispersion <= max(window))) # in this case the window is spectral regions that determine where # to fit. diff --git a/specutils/tests/test_fitting.py b/specutils/tests/test_fitting.py index 329863297..c79e93b2a 100644 --- a/specutils/tests/test_fitting.py +++ b/specutils/tests/test_fitting.py @@ -252,7 +252,7 @@ def test_single_peak_fit_window(): Single Peak fit with a window specified """ - # Create the sepctrum + # Create the spectrum x_single, y_single = single_peak() s_single = Spectrum1D(flux=y_single*u.Jy, spectral_axis=x_single*u.um) @@ -262,11 +262,11 @@ def test_single_peak_fit_window(): y_single_fit = g_fit(x_single*u.um) # Comparing every 10th value. - y_single_fit_expected = np.array([3.69669474e-13, 3.57992454e-11, 2.36719426e-09, 1.06879318e-07, - 3.29498310e-06, 6.93605383e-05, 9.96945607e-04, 9.78431032e-03, - 6.55675141e-02, 3.00017760e-01, 9.37356842e-01, 1.99969007e+00, - 2.91286375e+00, 2.89719280e+00, 1.96758892e+00, 9.12412206e-01, - 2.88900005e-01, 6.24602556e-02, 9.22061121e-03, 9.29427266e-04]) * u.Jy + y_single_fit_expected = np.array([5.37863281e-13, 4.90559594e-11, 3.07085351e-09, 1.31939209e-07, + 3.89077487e-06, 7.87491324e-05, 1.09396368e-03, 1.04305565e-02, + 6.82589342e-02, 3.06590700e-01, 9.45161235e-01, 1.99986313e+00, + 2.90430417e+00, 2.89488537e+00, 1.98046921e+00, 9.29934299e-01, + 2.99698030e-01, 6.62922807e-02, 1.00644369e-02, 1.04872942e-03]) * u.Jy assert np.allclose(y_single_fit.value[::10], y_single_fit_expected.value, atol=1e-5) @@ -359,11 +359,11 @@ def test_double_peak_fit_window(): y2_double_fit = g2_fit(x_double*u.um) # Comparing every 10th value. - y2_double_fit_expected = np.array([1.66363393e-128, 5.28910721e-102, 1.40949521e-078, 3.14848385e-058, - 5.89516506e-041, 9.25224449e-027, 1.21718016e-015, 1.34220626e-007, - 1.24062432e-002, 9.61209273e-001, 6.24240938e-002, 3.39815491e-006, - 1.55056770e-013, 5.93054936e-024, 1.90132233e-037, 5.10943886e-054, - 1.15092572e-073, 2.17309153e-096, 3.43926290e-122, 4.56256813e-151]) + y2_double_fit_expected = np.array([1.19027928e-116, 1.26897102e-092, 2.20632873e-071, 6.25611070e-053, + 2.89303919e-037, 2.18182491e-024, 2.68349729e-014, 5.38267443e-007, + 1.76080293e-002, 9.39375172e-001, 8.17302995e-002, 1.15969234e-005, + 2.68360243e-012, 1.01276632e-021, 6.23327102e-034, 6.25660097e-049, + 1.02418071e-066, 2.73420084e-087, 1.19041919e-110, 8.45249876e-137]) assert np.allclose(y2_double_fit.value[::10], y2_double_fit_expected, atol=1e-5) @@ -385,21 +385,21 @@ def test_double_peak_fit_separate_window(): yr_double_fit = gr_fit(x_double*u.um) # Comparing every 10th value. - yl_double_fit_expected = np.array([3.40725147e-18, 5.05500395e-15, 3.59471319e-12, 1.22527176e-09, - 2.00182467e-07, 1.56763547e-05, 5.88422893e-04, 1.05866724e-02, - 9.12966452e-02, 3.77377148e-01, 7.47690410e-01, 7.10057397e-01, - 3.23214276e-01, 7.05201207e-02, 7.37498248e-03, 3.69687164e-04, - 8.88245844e-06, 1.02295712e-07, 5.64686114e-10, 1.49410879e-12]) + yl_double_fit_expected = np.array([1.17816675e-82, 1.35059822e-65, 1.58326993e-50, 1.89798658e-37, + 2.32670163e-26, 2.91673945e-17, 3.73907313e-10, 4.90162030e-05, + 6.57089846e-02, 9.00781025e-01, 1.26276660e-01, 1.81024070e-04, + 2.65374353e-09, 3.97823976e-16, 6.09863080e-25, 9.56055524e-36, + 1.53265099e-48, 2.51253899e-63, 4.21203267e-80, 7.22071220e-99]) assert np.allclose(yl_double_fit.value[::10], yl_double_fit_expected, atol=1e-5) # Comparing every 10th value. - yr_double_fit_expected = np.array([0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 3.04416285e-259, - 3.85323221e-198, 2.98888589e-145, 1.42075875e-100, 4.13864520e-064, - 7.38793226e-036, 8.08191847e-016, 5.41792361e-004, 2.22575901e+000, - 5.60338234e-005, 8.64468603e-018, 8.17287853e-039, 4.73508430e-068, - 1.68115300e-105, 3.65774659e-151, 4.87693358e-205, 3.98480359e-267]) - + yr_double_fit_expected = np.array([0.00000000e+000, 0.00000000e+000, 0.00000000e+000, 7.01997018e-269, + 1.34316521e-205, 8.31931099e-151, 1.66804993e-104, 1.08266599e-066, + 2.27480254e-037, 1.54723641e-016, 3.40669622e-004, 2.42814100e+000, + 5.60245215e-005, 4.18452504e-018, 1.01176103e-039, 7.91905729e-070, + 2.00647052e-108, 1.64571961e-155, 4.36960988e-211, 3.75572034e-275]) + print(yr_double_fit.value[::10]) assert np.allclose(yr_double_fit.value[::10], yr_double_fit_expected, atol=1e-5)