Skip to content

Commit

Permalink
fix: explicit Exception for combined fit constant edge case. (#202)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjosw committed Jul 14, 2023
1 parent 6dcd0c3 commit 5c2a6de
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
6 changes: 6 additions & 0 deletions pyerrors/fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,12 @@ def func_b(a, x):
n_parms_ls.append(n_loc)

n_parms = max(n_parms_ls)

if len(key_ls) > 1:
for key in key_ls:
if np.asarray(yd[key]).shape != funcd[key](np.arange(n_parms), xd[key]).shape:
raise ValueError(f"Fit function {key} returns the wrong shape ({funcd[key](np.arange(n_parms), xd[key]).shape} instead of {xd[key].shape})\nIf the fit function is just a constant you could try adding x*0 to get the correct shape.")

if not silent:
print('Fit with', n_parms, 'parameter' + 's' * (n_parms > 1))

Expand Down
17 changes: 17 additions & 0 deletions tests/fits_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,6 +1143,23 @@ def func(a, x):
assert np.all(np.array(cd[1:]) > 0)


def test_combined_fit_constant_shape():
N1 = 16
N2 = 10
x = {"a": np.arange(N1),
"": np.arange(N2)}
y = {"a": [pe.pseudo_Obs(o + np.random.normal(0.0, 0.1), 0.1, "test") for o in range(N1)],
"": [pe.pseudo_Obs(o + np.random.normal(0.0, 0.1), 0.1, "test") for o in range(N2)]}
funcs = {"a": lambda a, x: a[0] + a[1] * x,
"": lambda a, x: a[1]}
with pytest.raises(ValueError):
pe.fits.least_squares(x, y, funcs, method='migrad')

funcs = {"a": lambda a, x: a[0] + a[1] * x,
"": lambda a, x: a[1] + x * 0}
pe.fits.least_squares(x, y, funcs, method='migrad')


def fit_general(x, y, func, silent=False, **kwargs):
"""Performs a non-linear fit to y = func(x) and returns a list of Obs corresponding to the fit parameters.
Expand Down

0 comments on commit 5c2a6de

Please sign in to comment.