From 516e86e51e0c22131a4363a2203fb626a4161750 Mon Sep 17 00:00:00 2001 From: Joep Vanlier Date: Thu, 15 Aug 2024 22:09:51 +0200 Subject: [PATCH] dwelltime: add asymptotic standard errors When calculating the asymptotic uncertainty interval, we should take into account that we actually impose a linear sum constraint (otherwise, the amplitudes will have indeterminate confidence intervals). One could add the constraint explicitly to the Hessian by simply adding a large penalty term to the relevant derivatives. Considering that the sum constraint is of the form (1 - sum(a_i)) ** 2, this would result in adding a constant term d^2f/daidaj = -c with c large to all amplitude terms. What is ugly is that we would need to choose this constant as large as possible without incurring numerical issues. This is why it is preferable to project onto the null space and then calculate the result back instead. --- changelog.md | 1 + .../kymotracker/tests/test_kymotrack.py | 19 +++-- lumicks/pylake/population/dwelltime.py | 73 +++++++++++++++++-- .../population/tests/test_dwelltimes.py | 51 ++++++++++--- 4 files changed, 118 insertions(+), 26 deletions(-) diff --git a/changelog.md b/changelog.md index 25a813dd2..885cafe36 100644 --- a/changelog.md +++ b/changelog.md @@ -13,6 +13,7 @@ * Added improved printing of calibration items under `channel.calibration` providing a more convenient overview of the items associated with a `Slice`. * Added improved printing of calibrations performed with `Pylake`. * Added parameter `titles` to customize title of each subplot in [`Kymo.plot_with_channels()`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.kymo.Kymo.html#lumicks.pylake.kymo.Kymo.plot_with_channels). +* Added `err_amplitudes` and `err_lifetimes` to [`DwelltimeModel`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.DwelltimeModel.html). These return asymptotic standard errors and should only be used for well identified models. ## v1.5.2 | 2024-07-24 diff --git a/lumicks/pylake/kymotracker/tests/test_kymotrack.py b/lumicks/pylake/kymotracker/tests/test_kymotrack.py index c680d19bb..f4f8e1b6f 100644 --- a/lumicks/pylake/kymotracker/tests/test_kymotrack.py +++ b/lumicks/pylake/kymotracker/tests/test_kymotrack.py @@ -173,7 +173,7 @@ def test_kymotrackgroup_remove(blank_kymo, remove, remaining): tracks = KymoTrackGroup(src_tracks) for track in remove: tracks.remove(src_tracks[track]) - for track, should_be_present in zip(src_tracks, remaining): + for track in src_tracks: if remaining: assert track in tracks else: @@ -300,9 +300,10 @@ def test_kymotrack_merge(): time_idx = ([1, 2, 3, 4, 5], [6, 7, 8], [6, 7, 8], [1, 2, 3]) pos_idx = ([1, 1, 1, 3, 3], [4, 4, 4], [9, 9, 9], [10, 10, 10]) - make_tracks = lambda: KymoTrackGroup( - [KymoTrack(t, p, kymo, "green", 0) for t, p in zip(time_idx, pos_idx)] - ) + def make_tracks(): + return KymoTrackGroup( + [KymoTrack(t, p, kymo, "green", 0) for t, p in zip(time_idx, pos_idx)] + ) # connect first two tracks = make_tracks() @@ -873,6 +874,7 @@ def test_fit_binding_times_nonzero(blank_kymo, blank_kymo_track_args): np.testing.assert_equal(dwelltime_model.dwelltimes, [4, 4, 4, 4]) np.testing.assert_equal(dwelltime_model._observation_limits[0], 4) np.testing.assert_allclose(dwelltime_model.lifetimes[0], [0.4]) + np.testing.assert_allclose(dwelltime_model.err_lifetimes[0], 0.199994, rtol=1e-5) def test_fit_binding_times_empty(): @@ -1174,10 +1176,11 @@ def make_coordinates(length, divisor): good_tracks = KymoTrackGroup([track for j, track in enumerate(tracks) if j in (0, 3, 4)]) - warning_string = lambda n_discarded: ( - f"{n_discarded} tracks were shorter than the specified min_length " - "and discarded from the analysis." - ) + def warning_string(n_discarded): + return ( + f"{n_discarded} tracks were shorter than the specified min_length and discarded " + f"from the analysis." + ) # test algorithms with default min_length with pytest.warns(RuntimeWarning, match=warning_string(3)): diff --git a/lumicks/pylake/population/dwelltime.py b/lumicks/pylake/population/dwelltime.py index 9e6683ee8..35f16c59d 100644 --- a/lumicks/pylake/population/dwelltime.py +++ b/lumicks/pylake/population/dwelltime.py @@ -8,6 +8,7 @@ from lumicks.pylake.fitting.parameters import Params, Parameter from lumicks.pylake.fitting.profile_likelihood import ProfileLikelihood1D +from lumicks.pylake.fitting.detail.derivative_manipulation import numerical_jacobian @dataclass(frozen=True) @@ -332,7 +333,7 @@ def _sample(optimized, iterations) -> np.ndarray: ) sample, min_obs, max_obs = to_resample[choices, :].T - result, _ = _exponential_mle_optimize( + result, _, _ = _exponential_mle_optimize( optimized.n_components, sample, min_obs, @@ -575,7 +576,7 @@ def __init__( if value is not None } - self._parameters, self._log_likelihood = _exponential_mle_optimize( + self._parameters, self._log_likelihood, self._std_errs = _exponential_mle_optimize( n_components, dwelltimes, min_observation_time, @@ -619,6 +620,34 @@ def lifetimes(self): """Lifetime parameter (in time units) of each model component.""" return self._parameters[self.n_components :] + @property + def err_amplitudes(self): + """Asymptotic standard error estimate on the model amplitudes. + + Returns an asymptotic standard error on the amplitude parameters. These error estimates + are only reliable for estimates where a lot of data is available and the model does + not suffer from identifiability issues. To verify that these conditions are met, please + use either :meth:`DwelltimeModel.profile_likelihood()` method or + :meth:`DwelltimeModel.calculate_bootstrap()`. + + Note that `np.nan` will be returned in case the parameter was either not estimated or no + error could be obtained.""" + return self._std_errs[: self.n_components] + + @property + def err_lifetimes(self): + """Asymptotic standard error estimate on the model lifetimes. + + Returns an asymptotic standard error on the amplitude parameters. These error estimates + are only reliable for estimates where a lot of data is available and the model does + not suffer from identifiability issues. To verify that these conditions are met, please + use either :meth:`DwelltimeModel.profile_likelihood()` method or + :meth:`DwelltimeModel.calculate_bootstrap()`. + + Note that `np.nan` will be returned in case the parameter was either not estimated or no + error could be obtained.""" + return self._std_errs[self.n_components :] + @property def rate_constants(self): """First order rate constant (units of per time) of each model component.""" @@ -1191,7 +1220,7 @@ def _handle_amplitude_constraint( else () ) - return fitted_param_mask, constraints, params + return fitted_param_mask, constraints, params, num_free_amps def _exponential_mle_bounds(n_components, min_observation_time, max_observation_time): @@ -1207,6 +1236,24 @@ def _exponential_mle_bounds(n_components, min_observation_time, max_observation_ ) +def _calculate_std_errs(jac_fun, constraints, num_free_amps, current_params, fitted_param_mask): + hessian_approx = numerical_jacobian(jac_fun, current_params[fitted_param_mask], dx=1e-6) + + if constraints: + from scipy.linalg import null_space + + # When we have a constraint, we should enforce it. We do this by projecting the Hessian + # onto the null space of the constraint and inverting it there. This null space only + # includes directions in which the constraint does not change (i.e. is fulfilled). + constraint = np.zeros((1, hessian_approx.shape[0])) + constraint[0, :num_free_amps] = -1 + n = null_space(constraint) + + return np.sqrt(np.diag(np.abs(n @ np.linalg.pinv(n.T @ hessian_approx @ n) @ n.T))) + + return np.sqrt(np.diag(np.abs(np.linalg.pinv(hessian_approx)))) + + def _exponential_mle_optimize( n_components, t, @@ -1273,7 +1320,7 @@ def _exponential_mle_optimize( bounds = _exponential_mle_bounds(n_components, min_observation_time, max_observation_time) - fitted_param_mask, constraints, initial_guess = _handle_amplitude_constraint( + fitted_param_mask, constraints, initial_guess, num_free_amps = _handle_amplitude_constraint( n_components, initial_guess, fixed_param_mask ) @@ -1291,10 +1338,11 @@ def cost_fun(params): ) def jac_fun(params): - current_params[fitted_param_mask] = params + jac_params = current_params.copy() + jac_params[fitted_param_mask] = params gradient = _exponential_mixture_log_likelihood_jacobian( - current_params, + jac_params, t=t, t_min=min_observation_time, t_max=max_observation_time, @@ -1305,7 +1353,7 @@ def jac_fun(params): # Nothing to fit, return! if np.sum(fitted_param_mask) == 0: - return initial_guess, -cost_fun([]) + return initial_guess, -cost_fun([]), np.full(initial_guess.shape, np.nan) # SLSQP is overly cautious when it comes to warning about bounds. The bound violations are # typically on the order of 1-2 ULP for a float32 and do not matter for our problem. Initially @@ -1328,7 +1376,16 @@ def jac_fun(params): # output parameters as [amplitudes, lifetimes], -log_likelihood current_params[fitted_param_mask] = result.x - return current_params, -result.fun + + if use_jacobian: + std_errs = np.full(current_params.shape, np.nan) + std_errs[fitted_param_mask] = _calculate_std_errs( + jac_fun, constraints, num_free_amps, current_params, fitted_param_mask + ) + else: + std_errs = np.full(current_params.shape, np.nan) + + return current_params, -result.fun, std_errs def _dwellcounts_from_statepath(statepath, exclude_ambiguous_dwells): diff --git a/lumicks/pylake/population/tests/test_dwelltimes.py b/lumicks/pylake/population/tests/test_dwelltimes.py index 4f1576d80..e8b812ffb 100644 --- a/lumicks/pylake/population/tests/test_dwelltimes.py +++ b/lumicks/pylake/population/tests/test_dwelltimes.py @@ -253,6 +253,31 @@ def test_dwelltime_profiles(exponential_data, exp_name, reference_bounds, reinte profiles.get_interval("amplitude", 0, 0.001) +@pytest.mark.parametrize( + # fmt:off + "exp_name, n_components, ref_std_errs", + [ + ("dataset_2exp", 1, [np.nan, 0.117634]), # Amplitude is not fitted! + ("dataset_2exp", 2, [0.072455, 0.072456, 0.212814, 0.449388]), + ("dataset_2exp_discrete", 2, [0.068027, 0.068027, 0.21403 , 0.350355]), + ("dataset_2exp_discrete", 3, [0.097556, 0.380667, 0.395212, 0.252004, 1.229997, 4.500617]), + ("dataset_2exp_discrete", 4, [9.755185e-02, 4.999662e-05, 3.788707e-01, 3.934488e-01, 2.520029e-01, 1.889606e+00, 1.227551e+00, 4.489603e+00]), + ] +) +def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs): + dataset = exponential_data[exp_name] + + fit = DwelltimeModel( + dataset["data"], + n_components=n_components, + **dataset["parameters"].observation_limits, + discretization_timestep=dataset["parameters"].dt, + ) + np.testing.assert_allclose(fit._std_errs, ref_std_errs, rtol=1e-4) + np.testing.assert_allclose(fit.err_amplitudes, ref_std_errs[:n_components], rtol=1e-4) + np.testing.assert_allclose(fit.err_lifetimes, ref_std_errs[n_components:], rtol=1e-4) + + @pytest.mark.parametrize("n_components", [2, 1]) def test_dwelltime_profile_plots(n_components): """Verify that the threshold moves appropriately""" @@ -400,7 +425,7 @@ def test_invalid_bootstrap(exponential_data): def test_integration_dwelltime_fixing_parameters(exponential_data): dataset = exponential_data["dataset_2exp"] initial_params = np.array([0.2, 0.2, 0.5, 0.5]) - pars, log_likelihood = _exponential_mle_optimize( + pars, log_likelihood, std_errs = _exponential_mle_optimize( 2, dataset["data"], **dataset["parameters"].observation_limits, @@ -408,6 +433,7 @@ def test_integration_dwelltime_fixing_parameters(exponential_data): fixed_param_mask=[False, True, False, True], ) np.testing.assert_allclose(pars, [0.8, 0.2, 4.27753, 0.5], rtol=1e-4) + np.testing.assert_allclose(std_errs, [np.nan, np.nan, 0.15625, np.nan], rtol=1e-4) @pytest.mark.parametrize( @@ -443,38 +469,38 @@ def fit_data(dt): @pytest.mark.parametrize( - "n_components,params,fixed_param_mask,ref_fitted,ref_const_fun,free_amplitudes,ref_par", + "n_components,params,fixed_param_mask,ref_fitted,ref_const_fun,free_amplitudes,ref_par,ref_amp", [ # fmt:off # 2 components, fix one amplitude => everything fixed in the end [ 2, np.array([0.3, 0.4, 0.3, 0.3]), [True, False, False, False], - [False, False, True, True], None, 0, [0.3, 0.7, 0.3, 0.3], + [False, False, True, True], None, 0, [0.3, 0.7, 0.3, 0.3], 0, ], # 2 components, fix both amplitudes [ 2, np.array([0.3, 0.7, 0.3, 0.3]), [True, True, False, False], - [False, False, True, True], 0, 0, [0.3, 0.7, 0.3, 0.3] + [False, False, True, True], 0, 0, [0.3, 0.7, 0.3, 0.3], 0, ], # 2 components, free amplitudes [ 2, np.array([0.3, 0.7, 0.3, 0.3]), [False, False, True, False], - [True, True, False, True], 0.75, 2, [0.3, 0.7, 0.3, 0.3], + [True, True, False, True], 0.75, 2, [0.3, 0.7, 0.3, 0.3], 2, ], # 3 components, fix one amplitude => End up with two free ones [ 3, np.array([0.3, 0.4, 0.2, 0.3, 0.3, 0.3]), [True, False, False, False, False, False], - [False, True, True, True, True, True], 1.6 / 3, 2, [0.3, 0.4, 0.2, 0.3, 0.3, 0.3], + [False, True, True, True, True, True], 1.6 / 3, 2, [0.3, 0.4, 0.2, 0.3, 0.3, 0.3], 2, ], # 3 components, fix two amplitudes => Amplitudes are now fully determined [ 3, np.array([0.3, 0.4, 0.2, 0.3, 0.3, 0.3]), [True, True, False, False, False, False], - [False, False, False, True, True, True], 0, 0, [0.3, 0.4, 0.3, 0.3, 0.3, 0.3], + [False, False, False, True, True, True], 0, 0, [0.3, 0.4, 0.3, 0.3, 0.3, 0.3], 0, ], # 1 component, no amplitudes required [ 1, np.array([0.3, 0.5]), [False, False], - [False, True], None, 0, [1.0, 0.5], + [False, True], None, 0, [1.0, 0.5], 0, ], # fmt:on ], @@ -487,11 +513,13 @@ def test_parameter_fixing( ref_const_fun, free_amplitudes, ref_par, + ref_amp, ): old_params = np.copy(params) - fitted_param_mask, constraints, out_params = _handle_amplitude_constraint( + fitted_param_mask, constraints, out_params, free_amplitudes = _handle_amplitude_constraint( n_components, params, np.array(fixed_param_mask) ) + assert free_amplitudes == ref_amp # Verify that we didn't modify the input np.testing.assert_allclose(params, old_params) @@ -641,8 +669,11 @@ def quick_fit(fixed_params): fixed_param_mask=fixed_params, ) - x, cost = quick_fit([False, True]) # Amplitude is 1 -> Problem fully determined -> No fit + x, cost, std_err = quick_fit( + [False, True] + ) # Amplitude is 1 -> Problem fully determined -> No fit np.testing.assert_allclose(x, np.array([1.0, 2.0])) + np.testing.assert_equal(std_err, [np.nan, np.nan]) # no uncertainty estimates with pytest.raises(StopIteration): quick_fit([True, False]) # Lifetime unknown -> Need to fit