From 91e43cc7961ad1de6026840c1e945b50d1e75525 Mon Sep 17 00:00:00 2001 From: Joep Vanlier Date: Thu, 15 Aug 2024 22:09:51 +0200 Subject: [PATCH 1/2] dwelltime: add asymptotic standard errors When calculating the asymptotic uncertainty interval, we should take into account that we actually impose a linear sum constraint (otherwise, the amplitudes will have indeterminate confidence intervals). One could add the constraint explicitly to the Hessian by simply adding a large penalty term to the relevant derivatives. Considering that the sum constraint is of the form (1 - sum(a_i)) ** 2, this would result in adding a constant term d^2f/daidaj = -c with c large to all amplitude terms. What is ugly is that we would need to choose this constant as large as possible without incurring numerical issues. This is why it is preferable to project onto the null space and then calculate the result back instead. --- changelog.md | 1 + .../kymotracker/tests/test_kymotrack.py | 19 +++-- lumicks/pylake/population/dwelltime.py | 73 +++++++++++++++++-- .../population/tests/test_dwelltimes.py | 51 ++++++++++--- 4 files changed, 118 insertions(+), 26 deletions(-) diff --git a/changelog.md b/changelog.md index 89948df01..e08b752ea 100644 --- a/changelog.md +++ b/changelog.md @@ -17,6 +17,7 @@ * Added parameter `titles` to customize title of each subplot in [`Kymo.plot_with_channels()`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.kymo.Kymo.html#lumicks.pylake.kymo.Kymo.plot_with_channels). * Added [`KymoTrack.sample_from_channel()`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.kymotracker.kymotrack.KymoTrack.html#lumicks.pylake.kymotracker.kymotrack.KymoTrack.sample_from_channel) to downsample channel data to the time points of a kymotrack. * Added support for file names with spaces in [`lk.download_from_doi()`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.download_from_doi.html#lumicks.pylake.download_from_doi). +* Added `err_amplitudes` and `err_lifetimes` to [`DwelltimeModel`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.DwelltimeModel.html). These return asymptotic standard errors and should only be used for well identified models. ## v1.5.3 | 2024-10-29 diff --git a/lumicks/pylake/kymotracker/tests/test_kymotrack.py b/lumicks/pylake/kymotracker/tests/test_kymotrack.py index c680d19bb..f4f8e1b6f 100644 --- a/lumicks/pylake/kymotracker/tests/test_kymotrack.py +++ b/lumicks/pylake/kymotracker/tests/test_kymotrack.py @@ -173,7 +173,7 @@ def test_kymotrackgroup_remove(blank_kymo, remove, remaining): tracks = KymoTrackGroup(src_tracks) for track in remove: tracks.remove(src_tracks[track]) - for track, should_be_present in zip(src_tracks, remaining): + for track in src_tracks: if remaining: assert track in tracks else: @@ -300,9 +300,10 @@ def test_kymotrack_merge(): time_idx = ([1, 2, 3, 4, 5], [6, 7, 8], [6, 7, 8], [1, 2, 3]) pos_idx = ([1, 1, 1, 3, 3], [4, 4, 4], [9, 9, 9], [10, 10, 10]) - make_tracks = lambda: KymoTrackGroup( - [KymoTrack(t, p, kymo, "green", 0) for t, p in zip(time_idx, pos_idx)] - ) + def make_tracks(): + return KymoTrackGroup( + [KymoTrack(t, p, kymo, "green", 0) for t, p in zip(time_idx, pos_idx)] + ) # connect first two tracks = make_tracks() @@ -873,6 +874,7 @@ def test_fit_binding_times_nonzero(blank_kymo, blank_kymo_track_args): np.testing.assert_equal(dwelltime_model.dwelltimes, [4, 4, 4, 4]) np.testing.assert_equal(dwelltime_model._observation_limits[0], 4) np.testing.assert_allclose(dwelltime_model.lifetimes[0], [0.4]) + np.testing.assert_allclose(dwelltime_model.err_lifetimes[0], 0.199994, rtol=1e-5) def test_fit_binding_times_empty(): @@ -1174,10 +1176,11 @@ def make_coordinates(length, divisor): good_tracks = KymoTrackGroup([track for j, track in enumerate(tracks) if j in (0, 3, 4)]) - warning_string = lambda n_discarded: ( - f"{n_discarded} tracks were shorter than the specified min_length " - "and discarded from the analysis." - ) + def warning_string(n_discarded): + return ( + f"{n_discarded} tracks were shorter than the specified min_length and discarded " + f"from the analysis." + ) # test algorithms with default min_length with pytest.warns(RuntimeWarning, match=warning_string(3)): diff --git a/lumicks/pylake/population/dwelltime.py b/lumicks/pylake/population/dwelltime.py index 9e6683ee8..35f16c59d 100644 --- a/lumicks/pylake/population/dwelltime.py +++ b/lumicks/pylake/population/dwelltime.py @@ -8,6 +8,7 @@ from lumicks.pylake.fitting.parameters import Params, Parameter from lumicks.pylake.fitting.profile_likelihood import ProfileLikelihood1D +from lumicks.pylake.fitting.detail.derivative_manipulation import numerical_jacobian @dataclass(frozen=True) @@ -332,7 +333,7 @@ def _sample(optimized, iterations) -> np.ndarray: ) sample, min_obs, max_obs = to_resample[choices, :].T - result, _ = _exponential_mle_optimize( + result, _, _ = _exponential_mle_optimize( optimized.n_components, sample, min_obs, @@ -575,7 +576,7 @@ def __init__( if value is not None } - self._parameters, self._log_likelihood = _exponential_mle_optimize( + self._parameters, self._log_likelihood, self._std_errs = _exponential_mle_optimize( n_components, dwelltimes, min_observation_time, @@ -619,6 +620,34 @@ def lifetimes(self): """Lifetime parameter (in time units) of each model component.""" return self._parameters[self.n_components :] + @property + def err_amplitudes(self): + """Asymptotic standard error estimate on the model amplitudes. + + Returns an asymptotic standard error on the amplitude parameters. These error estimates + are only reliable for estimates where a lot of data is available and the model does + not suffer from identifiability issues. To verify that these conditions are met, please + use either :meth:`DwelltimeModel.profile_likelihood()` method or + :meth:`DwelltimeModel.calculate_bootstrap()`. + + Note that `np.nan` will be returned in case the parameter was either not estimated or no + error could be obtained.""" + return self._std_errs[: self.n_components] + + @property + def err_lifetimes(self): + """Asymptotic standard error estimate on the model lifetimes. + + Returns an asymptotic standard error on the amplitude parameters. These error estimates + are only reliable for estimates where a lot of data is available and the model does + not suffer from identifiability issues. To verify that these conditions are met, please + use either :meth:`DwelltimeModel.profile_likelihood()` method or + :meth:`DwelltimeModel.calculate_bootstrap()`. + + Note that `np.nan` will be returned in case the parameter was either not estimated or no + error could be obtained.""" + return self._std_errs[self.n_components :] + @property def rate_constants(self): """First order rate constant (units of per time) of each model component.""" @@ -1191,7 +1220,7 @@ def _handle_amplitude_constraint( else () ) - return fitted_param_mask, constraints, params + return fitted_param_mask, constraints, params, num_free_amps def _exponential_mle_bounds(n_components, min_observation_time, max_observation_time): @@ -1207,6 +1236,24 @@ def _exponential_mle_bounds(n_components, min_observation_time, max_observation_ ) +def _calculate_std_errs(jac_fun, constraints, num_free_amps, current_params, fitted_param_mask): + hessian_approx = numerical_jacobian(jac_fun, current_params[fitted_param_mask], dx=1e-6) + + if constraints: + from scipy.linalg import null_space + + # When we have a constraint, we should enforce it. We do this by projecting the Hessian + # onto the null space of the constraint and inverting it there. This null space only + # includes directions in which the constraint does not change (i.e. is fulfilled). + constraint = np.zeros((1, hessian_approx.shape[0])) + constraint[0, :num_free_amps] = -1 + n = null_space(constraint) + + return np.sqrt(np.diag(np.abs(n @ np.linalg.pinv(n.T @ hessian_approx @ n) @ n.T))) + + return np.sqrt(np.diag(np.abs(np.linalg.pinv(hessian_approx)))) + + def _exponential_mle_optimize( n_components, t, @@ -1273,7 +1320,7 @@ def _exponential_mle_optimize( bounds = _exponential_mle_bounds(n_components, min_observation_time, max_observation_time) - fitted_param_mask, constraints, initial_guess = _handle_amplitude_constraint( + fitted_param_mask, constraints, initial_guess, num_free_amps = _handle_amplitude_constraint( n_components, initial_guess, fixed_param_mask ) @@ -1291,10 +1338,11 @@ def cost_fun(params): ) def jac_fun(params): - current_params[fitted_param_mask] = params + jac_params = current_params.copy() + jac_params[fitted_param_mask] = params gradient = _exponential_mixture_log_likelihood_jacobian( - current_params, + jac_params, t=t, t_min=min_observation_time, t_max=max_observation_time, @@ -1305,7 +1353,7 @@ def jac_fun(params): # Nothing to fit, return! if np.sum(fitted_param_mask) == 0: - return initial_guess, -cost_fun([]) + return initial_guess, -cost_fun([]), np.full(initial_guess.shape, np.nan) # SLSQP is overly cautious when it comes to warning about bounds. The bound violations are # typically on the order of 1-2 ULP for a float32 and do not matter for our problem. Initially @@ -1328,7 +1376,16 @@ def jac_fun(params): # output parameters as [amplitudes, lifetimes], -log_likelihood current_params[fitted_param_mask] = result.x - return current_params, -result.fun + + if use_jacobian: + std_errs = np.full(current_params.shape, np.nan) + std_errs[fitted_param_mask] = _calculate_std_errs( + jac_fun, constraints, num_free_amps, current_params, fitted_param_mask + ) + else: + std_errs = np.full(current_params.shape, np.nan) + + return current_params, -result.fun, std_errs def _dwellcounts_from_statepath(statepath, exclude_ambiguous_dwells): diff --git a/lumicks/pylake/population/tests/test_dwelltimes.py b/lumicks/pylake/population/tests/test_dwelltimes.py index 4f1576d80..e8b812ffb 100644 --- a/lumicks/pylake/population/tests/test_dwelltimes.py +++ b/lumicks/pylake/population/tests/test_dwelltimes.py @@ -253,6 +253,31 @@ def test_dwelltime_profiles(exponential_data, exp_name, reference_bounds, reinte profiles.get_interval("amplitude", 0, 0.001) +@pytest.mark.parametrize( + # fmt:off + "exp_name, n_components, ref_std_errs", + [ + ("dataset_2exp", 1, [np.nan, 0.117634]), # Amplitude is not fitted! + ("dataset_2exp", 2, [0.072455, 0.072456, 0.212814, 0.449388]), + ("dataset_2exp_discrete", 2, [0.068027, 0.068027, 0.21403 , 0.350355]), + ("dataset_2exp_discrete", 3, [0.097556, 0.380667, 0.395212, 0.252004, 1.229997, 4.500617]), + ("dataset_2exp_discrete", 4, [9.755185e-02, 4.999662e-05, 3.788707e-01, 3.934488e-01, 2.520029e-01, 1.889606e+00, 1.227551e+00, 4.489603e+00]), + ] +) +def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs): + dataset = exponential_data[exp_name] + + fit = DwelltimeModel( + dataset["data"], + n_components=n_components, + **dataset["parameters"].observation_limits, + discretization_timestep=dataset["parameters"].dt, + ) + np.testing.assert_allclose(fit._std_errs, ref_std_errs, rtol=1e-4) + np.testing.assert_allclose(fit.err_amplitudes, ref_std_errs[:n_components], rtol=1e-4) + np.testing.assert_allclose(fit.err_lifetimes, ref_std_errs[n_components:], rtol=1e-4) + + @pytest.mark.parametrize("n_components", [2, 1]) def test_dwelltime_profile_plots(n_components): """Verify that the threshold moves appropriately""" @@ -400,7 +425,7 @@ def test_invalid_bootstrap(exponential_data): def test_integration_dwelltime_fixing_parameters(exponential_data): dataset = exponential_data["dataset_2exp"] initial_params = np.array([0.2, 0.2, 0.5, 0.5]) - pars, log_likelihood = _exponential_mle_optimize( + pars, log_likelihood, std_errs = _exponential_mle_optimize( 2, dataset["data"], **dataset["parameters"].observation_limits, @@ -408,6 +433,7 @@ def test_integration_dwelltime_fixing_parameters(exponential_data): fixed_param_mask=[False, True, False, True], ) np.testing.assert_allclose(pars, [0.8, 0.2, 4.27753, 0.5], rtol=1e-4) + np.testing.assert_allclose(std_errs, [np.nan, np.nan, 0.15625, np.nan], rtol=1e-4) @pytest.mark.parametrize( @@ -443,38 +469,38 @@ def fit_data(dt): @pytest.mark.parametrize( - "n_components,params,fixed_param_mask,ref_fitted,ref_const_fun,free_amplitudes,ref_par", + "n_components,params,fixed_param_mask,ref_fitted,ref_const_fun,free_amplitudes,ref_par,ref_amp", [ # fmt:off # 2 components, fix one amplitude => everything fixed in the end [ 2, np.array([0.3, 0.4, 0.3, 0.3]), [True, False, False, False], - [False, False, True, True], None, 0, [0.3, 0.7, 0.3, 0.3], + [False, False, True, True], None, 0, [0.3, 0.7, 0.3, 0.3], 0, ], # 2 components, fix both amplitudes [ 2, np.array([0.3, 0.7, 0.3, 0.3]), [True, True, False, False], - [False, False, True, True], 0, 0, [0.3, 0.7, 0.3, 0.3] + [False, False, True, True], 0, 0, [0.3, 0.7, 0.3, 0.3], 0, ], # 2 components, free amplitudes [ 2, np.array([0.3, 0.7, 0.3, 0.3]), [False, False, True, False], - [True, True, False, True], 0.75, 2, [0.3, 0.7, 0.3, 0.3], + [True, True, False, True], 0.75, 2, [0.3, 0.7, 0.3, 0.3], 2, ], # 3 components, fix one amplitude => End up with two free ones [ 3, np.array([0.3, 0.4, 0.2, 0.3, 0.3, 0.3]), [True, False, False, False, False, False], - [False, True, True, True, True, True], 1.6 / 3, 2, [0.3, 0.4, 0.2, 0.3, 0.3, 0.3], + [False, True, True, True, True, True], 1.6 / 3, 2, [0.3, 0.4, 0.2, 0.3, 0.3, 0.3], 2, ], # 3 components, fix two amplitudes => Amplitudes are now fully determined [ 3, np.array([0.3, 0.4, 0.2, 0.3, 0.3, 0.3]), [True, True, False, False, False, False], - [False, False, False, True, True, True], 0, 0, [0.3, 0.4, 0.3, 0.3, 0.3, 0.3], + [False, False, False, True, True, True], 0, 0, [0.3, 0.4, 0.3, 0.3, 0.3, 0.3], 0, ], # 1 component, no amplitudes required [ 1, np.array([0.3, 0.5]), [False, False], - [False, True], None, 0, [1.0, 0.5], + [False, True], None, 0, [1.0, 0.5], 0, ], # fmt:on ], @@ -487,11 +513,13 @@ def test_parameter_fixing( ref_const_fun, free_amplitudes, ref_par, + ref_amp, ): old_params = np.copy(params) - fitted_param_mask, constraints, out_params = _handle_amplitude_constraint( + fitted_param_mask, constraints, out_params, free_amplitudes = _handle_amplitude_constraint( n_components, params, np.array(fixed_param_mask) ) + assert free_amplitudes == ref_amp # Verify that we didn't modify the input np.testing.assert_allclose(params, old_params) @@ -641,8 +669,11 @@ def quick_fit(fixed_params): fixed_param_mask=fixed_params, ) - x, cost = quick_fit([False, True]) # Amplitude is 1 -> Problem fully determined -> No fit + x, cost, std_err = quick_fit( + [False, True] + ) # Amplitude is 1 -> Problem fully determined -> No fit np.testing.assert_allclose(x, np.array([1.0, 2.0])) + np.testing.assert_equal(std_err, [np.nan, np.nan]) # no uncertainty estimates with pytest.raises(StopIteration): quick_fit([True, False]) # Lifetime unknown -> Need to fit From 4776d86a40484f84f3c81ce14e5cf5e357eeca32 Mon Sep 17 00:00:00 2001 From: Joep Vanlier Date: Thu, 15 Aug 2024 22:56:49 +0200 Subject: [PATCH 2/2] profiles: provide the option to plot w/stderr --- lumicks/pylake/fitting/parameters.py | 3 ++- lumicks/pylake/fitting/profile_likelihood.py | 12 +++++++++- lumicks/pylake/population/dwelltime.py | 24 +++++++++++++++----- 3 files changed, 31 insertions(+), 8 deletions(-) diff --git a/lumicks/pylake/fitting/parameters.py b/lumicks/pylake/fitting/parameters.py index 18f841315..c9352d9fd 100644 --- a/lumicks/pylake/fitting/parameters.py +++ b/lumicks/pylake/fitting/parameters.py @@ -43,6 +43,7 @@ def __init__( fixed=False, shared=False, unit=None, + stderr=None, ): """Model parameter @@ -90,7 +91,7 @@ def __init__( from the data. See also: :meth:`~lumicks.pylake.FdFit.profile_likelihood()`. """ - self.stderr = None + self.stderr = stderr """Standard error of this parameter. Standard errors are calculated after fitting the model. These asymptotic errors are based diff --git a/lumicks/pylake/fitting/profile_likelihood.py b/lumicks/pylake/fitting/profile_likelihood.py index db973e1e3..715dfb534 100644 --- a/lumicks/pylake/fitting/profile_likelihood.py +++ b/lumicks/pylake/fitting/profile_likelihood.py @@ -624,7 +624,7 @@ def chi2(self): def p(self): return self.parameters[:, self.profile_info.profiled_parameter_index] - def plot(self, *, significance_level=None, **kwargs): + def plot(self, *, significance_level=None, std_err=None, **kwargs): """Plot profile likelihood Parameters @@ -632,9 +632,19 @@ def plot(self, *, significance_level=None, **kwargs): significance_level : float, optional Desired significance level (resulting in a 100 * (1 - alpha)% confidence interval) to plot. Default is the significance level specified when the profile was generated. + std_err : float | None + If provided, also make a quadratic plot based on a standard error. """ import matplotlib.pyplot as plt + if std_err: + x = np.arange(-3 * std_err, 3 * std_err, 0.1 * std_err) + plt.plot( + self.p[np.argmin(self.chi2)] + x, + self.profile_info.minimum_chi2 + x**2 / (2 * std_err**2), + "k--", + ) + dash_length = 5 plt.plot(self.p, self.chi2, **kwargs) diff --git a/lumicks/pylake/population/dwelltime.py b/lumicks/pylake/population/dwelltime.py index 35f16c59d..0644a10b8 100644 --- a/lumicks/pylake/population/dwelltime.py +++ b/lumicks/pylake/population/dwelltime.py @@ -99,8 +99,10 @@ def fit_func(params, lb, ub, fitted): ) parameters = Params( **{ - key: Parameter(param, lower_bound=lb, upper_bound=ub) - for key, param, (lb, ub) in zip(keys, dwelltime_model._parameters, bounds) + key: Parameter(param, lower_bound=lb, upper_bound=ub, stderr=std_err) + for key, param, (lb, ub), std_err in zip( + keys, dwelltime_model._parameters, bounds, dwelltime_model._std_errs + ) } ) @@ -195,7 +197,7 @@ def n_components(self): """Number of components in the model.""" return self.model.n_components - def plot(self, alpha=None): + def plot(self, alpha=None, *, with_stderr=False, **kwargs): """Plot the profile likelihoods for the parameters of a model. Confidence interval is indicated by the region where the profile crosses the chi squared @@ -207,16 +209,26 @@ def plot(self, alpha=None): Significance level. Confidence intervals are calculated as 100 * (1 - alpha)%. The default value of `None` results in using the significance level applied when profiling (default: 0.05). + with_stderr : bool + Also show bounds based on standard errors. """ import matplotlib.pyplot as plt + std_errs = self.model._std_errs[~np.isnan(self.model._std_errs)] if self.n_components == 1: - next(iter(self.profiles.values())).plot(significance_level=alpha) + next(iter(self.profiles.values())).plot( + significance_level=alpha, + std_err=std_errs[0] if with_stderr else None, + ) else: plot_idx = np.reshape(np.arange(1, len(self.profiles) + 1), (-1, 2)).T.flatten() - for idx, profile in zip(plot_idx, self.profiles.values()): + for par_idx, (idx, profile) in enumerate(zip(plot_idx, self.profiles.values())): plt.subplot(self.n_components, 2, idx) - profile.plot(significance_level=alpha) + profile.plot( + significance_level=alpha, + std_err=std_errs[par_idx] if with_stderr else None, + **kwargs, + ) @dataclass(frozen=True)