diff --git a/pyerrors/fits.py b/pyerrors/fits.py index 04051ffd..5ec857e0 100644 --- a/pyerrors/fits.py +++ b/pyerrors/fits.py @@ -230,6 +230,12 @@ def func_b(a, x): n_parms_ls.append(n_loc) n_parms = max(n_parms_ls) + + if len(key_ls) > 1: + for key in key_ls: + if np.asarray(yd[key]).shape != funcd[key](np.arange(n_parms), xd[key]).shape: + raise ValueError(f"Fit function {key} returns the wrong shape ({funcd[key](np.arange(n_parms), xd[key]).shape} instead of {xd[key].shape})\nIf the fit function is just a constant you could try adding x*0 to get the correct shape.") + if not silent: print('Fit with', n_parms, 'parameter' + 's' * (n_parms > 1)) diff --git a/tests/fits_test.py b/tests/fits_test.py index 48e788bd..80b6de5a 100644 --- a/tests/fits_test.py +++ b/tests/fits_test.py @@ -1143,6 +1143,23 @@ def func(a, x): assert np.all(np.array(cd[1:]) > 0) +def test_combined_fit_constant_shape(): + N1 = 16 + N2 = 10 + x = {"a": np.arange(N1), + "": np.arange(N2)} + y = {"a": [pe.pseudo_Obs(o + np.random.normal(0.0, 0.1), 0.1, "test") for o in range(N1)], + "": [pe.pseudo_Obs(o + np.random.normal(0.0, 0.1), 0.1, "test") for o in range(N2)]} + funcs = {"a": lambda a, x: a[0] + a[1] * x, + "": lambda a, x: a[1]} + with pytest.raises(ValueError): + pe.fits.least_squares(x, y, funcs, method='migrad') + + funcs = {"a": lambda a, x: a[0] + a[1] * x, + "": lambda a, x: a[1] + x * 0} + pe.fits.least_squares(x, y, funcs, method='migrad') + + def fit_general(x, y, func, silent=False, **kwargs): """Performs a non-linear fit to y = func(x) and returns a list of Obs corresponding to the fit parameters.