dwelltime: add asymptotic standard errors
When calculating the asymptotic uncertainty intervals, we should take into account that we impose a linear sum constraint on the amplitudes (otherwise, the amplitudes would have indeterminate confidence intervals).

One could add the constraint explicitly to the Hessian by simply adding a large penalty term to the relevant derivatives. Considering that the sum constraint is of the form (1 - sum(a_i))**2, this would amount to adding a constant term d^2f/(da_i da_j) = -c, with c large, to all amplitude entries.

What is ugly about this is that the constant would have to be chosen as large as possible without incurring numerical issues. It is therefore preferable to project the Hessian onto the null space of the constraint, invert it there, and map the result back instead.
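
As an illustration, here is a minimal numpy/scipy sketch of that projection (not the Pylake code itself; the toy hessian and the parameter layout [a1, a2, tau] are made up for illustration):

import numpy as np
from scipy.linalg import null_space

# Toy Hessian of the negative log-likelihood for [a1, a2, tau],
# fitted under the constraint a1 + a2 = 1.
hessian = np.array([
    [40.0, 30.0, 1.0],
    [30.0, 45.0, 2.0],
    [1.0, 2.0, 8.0],
])

# Gradient of the constraint residual (1 - a1 - a2) w.r.t. [a1, a2, tau].
constraint = np.array([[-1.0, -1.0, 0.0]])
n = null_space(constraint)  # directions in which the constraint stays fulfilled

# Invert the Hessian only on the null space, then map back to parameter space.
covariance = n @ np.linalg.pinv(n.T @ hessian @ n) @ n.T
std_errs = np.sqrt(np.abs(np.diag(covariance)))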
JoepVanlier committed Sep 5, 2024
1 parent cde5cc4 commit 516e86e
Showing 4 changed files with 118 additions and 26 deletions.
1 change: 1 addition & 0 deletions changelog.md
@@ -13,6 +13,7 @@
* Added improved printing of calibration items under `channel.calibration` providing a more convenient overview of the items associated with a `Slice`.
* Added improved printing of calibrations performed with `Pylake`.
* Added parameter `titles` to customize title of each subplot in [`Kymo.plot_with_channels()`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.kymo.Kymo.html#lumicks.pylake.kymo.Kymo.plot_with_channels).
* Added `err_amplitudes` and `err_lifetimes` to [`DwelltimeModel`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.DwelltimeModel.html). These return asymptotic standard errors and should only be used for well-identified models.

## v1.5.2 | 2024-07-24

19 changes: 11 additions & 8 deletions lumicks/pylake/kymotracker/tests/test_kymotrack.py
@@ -173,7 +173,7 @@ def test_kymotrackgroup_remove(blank_kymo, remove, remaining):
tracks = KymoTrackGroup(src_tracks)
for track in remove:
tracks.remove(src_tracks[track])
for track, should_be_present in zip(src_tracks, remaining):
for track in src_tracks:
if remaining:
assert track in tracks
else:
@@ -300,9 +300,10 @@ def test_kymotrack_merge():
time_idx = ([1, 2, 3, 4, 5], [6, 7, 8], [6, 7, 8], [1, 2, 3])
pos_idx = ([1, 1, 1, 3, 3], [4, 4, 4], [9, 9, 9], [10, 10, 10])

make_tracks = lambda: KymoTrackGroup(
[KymoTrack(t, p, kymo, "green", 0) for t, p in zip(time_idx, pos_idx)]
)
def make_tracks():
return KymoTrackGroup(
[KymoTrack(t, p, kymo, "green", 0) for t, p in zip(time_idx, pos_idx)]
)

# connect first two
tracks = make_tracks()
@@ -873,6 +874,7 @@ def test_fit_binding_times_nonzero(blank_kymo, blank_kymo_track_args):
np.testing.assert_equal(dwelltime_model.dwelltimes, [4, 4, 4, 4])
np.testing.assert_equal(dwelltime_model._observation_limits[0], 4)
np.testing.assert_allclose(dwelltime_model.lifetimes[0], [0.4])
np.testing.assert_allclose(dwelltime_model.err_lifetimes[0], 0.199994, rtol=1e-5)


def test_fit_binding_times_empty():
@@ -1174,10 +1176,11 @@ def make_coordinates(length, divisor):

good_tracks = KymoTrackGroup([track for j, track in enumerate(tracks) if j in (0, 3, 4)])

warning_string = lambda n_discarded: (
f"{n_discarded} tracks were shorter than the specified min_length "
"and discarded from the analysis."
)
def warning_string(n_discarded):
return (
f"{n_discarded} tracks were shorter than the specified min_length and discarded "
f"from the analysis."
)

# test algorithms with default min_length
with pytest.warns(RuntimeWarning, match=warning_string(3)):
73 changes: 65 additions & 8 deletions lumicks/pylake/population/dwelltime.py
@@ -8,6 +8,7 @@

from lumicks.pylake.fitting.parameters import Params, Parameter
from lumicks.pylake.fitting.profile_likelihood import ProfileLikelihood1D
from lumicks.pylake.fitting.detail.derivative_manipulation import numerical_jacobian


@dataclass(frozen=True)
@@ -332,7 +333,7 @@ def _sample(optimized, iterations) -> np.ndarray:
)
sample, min_obs, max_obs = to_resample[choices, :].T

result, _ = _exponential_mle_optimize(
result, _, _ = _exponential_mle_optimize(
optimized.n_components,
sample,
min_obs,
@@ -575,7 +576,7 @@ def __init__(
if value is not None
}

self._parameters, self._log_likelihood = _exponential_mle_optimize(
self._parameters, self._log_likelihood, self._std_errs = _exponential_mle_optimize(
n_components,
dwelltimes,
min_observation_time,
@@ -619,6 +620,34 @@ def lifetimes(self):
"""Lifetime parameter (in time units) of each model component."""
return self._parameters[self.n_components :]

@property
def err_amplitudes(self):
"""Asymptotic standard error estimate on the model amplitudes.

Returns an asymptotic standard error on the amplitude parameters. These error estimates
are only reliable when a lot of data is available and the model does not suffer from
identifiability issues. To verify that these conditions are met, please use either
:meth:`DwelltimeModel.profile_likelihood()` or :meth:`DwelltimeModel.calculate_bootstrap()`.

Note that `np.nan` will be returned if a parameter was either not estimated or no
error could be obtained."""
return self._std_errs[: self.n_components]

@property
def err_lifetimes(self):
"""Asymptotic standard error estimate on the model lifetimes.

Returns an asymptotic standard error on the lifetime parameters. These error estimates
are only reliable when a lot of data is available and the model does not suffer from
identifiability issues. To verify that these conditions are met, please use either
:meth:`DwelltimeModel.profile_likelihood()` or :meth:`DwelltimeModel.calculate_bootstrap()`.

Note that `np.nan` will be returned if a parameter was either not estimated or no
error could be obtained."""
return self._std_errs[self.n_components :]

@property
def rate_constants(self):
"""First order rate constant (units of per time) of each model component."""
@@ -1191,7 +1220,7 @@ def _handle_amplitude_constraint(
else ()
)

return fitted_param_mask, constraints, params
return fitted_param_mask, constraints, params, num_free_amps


def _exponential_mle_bounds(n_components, min_observation_time, max_observation_time):
@@ -1207,6 +1236,24 @@ def _exponential_mle_bounds(n_components, min_observation_time, max_observation_time):
)


def _calculate_std_errs(jac_fun, constraints, num_free_amps, current_params, fitted_param_mask):
hessian_approx = numerical_jacobian(jac_fun, current_params[fitted_param_mask], dx=1e-6)

if constraints:
from scipy.linalg import null_space

# When we have a constraint, we should enforce it. We do this by projecting the Hessian
# onto the null space of the constraint and inverting it there. This null space only
# includes directions in which the constraint does not change (i.e. is fulfilled).
constraint = np.zeros((1, hessian_approx.shape[0]))
constraint[0, :num_free_amps] = -1
n = null_space(constraint)

return np.sqrt(np.diag(np.abs(n @ np.linalg.pinv(n.T @ hessian_approx @ n) @ n.T)))

return np.sqrt(np.diag(np.abs(np.linalg.pinv(hessian_approx))))
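
One way to see why the projection removes the indeterminacy (a hedged check, reusing the toy constraint and covariance from the sketch in the commit message): the projected covariance carries no variance along the constrained direction, so the sum of the amplitudes is treated as exactly known.

# constraint @ n == 0 by construction, so the projected covariance
# assigns zero variance to any move that violates sum(a_i) = 1.
variance_along_constraint = constraint @ covariance @ constraint.T
assert np.allclose(variance_along_constraint, 0.0)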


def _exponential_mle_optimize(
n_components,
t,
@@ -1273,7 +1320,7 @@

bounds = _exponential_mle_bounds(n_components, min_observation_time, max_observation_time)

fitted_param_mask, constraints, initial_guess = _handle_amplitude_constraint(
fitted_param_mask, constraints, initial_guess, num_free_amps = _handle_amplitude_constraint(
n_components, initial_guess, fixed_param_mask
)

@@ -1291,10 +1338,11 @@ def cost_fun(params):
)

def jac_fun(params):
current_params[fitted_param_mask] = params
jac_params = current_params.copy()
jac_params[fitted_param_mask] = params

gradient = _exponential_mixture_log_likelihood_jacobian(
current_params,
jac_params,
t=t,
t_min=min_observation_time,
t_max=max_observation_time,
@@ -1305,7 +1353,7 @@ def jac_fun(params):

# Nothing to fit, return!
if np.sum(fitted_param_mask) == 0:
return initial_guess, -cost_fun([])
return initial_guess, -cost_fun([]), np.full(initial_guess.shape, np.nan)

# SLSQP is overly cautious when it comes to warning about bounds. The bound violations are
# typically on the order of 1-2 ULP for a float32 and do not matter for our problem. Initially
@@ -1328,7 +1376,16 @@

# output parameters as [amplitudes, lifetimes], -log_likelihood
current_params[fitted_param_mask] = result.x
return current_params, -result.fun

if use_jacobian:
std_errs = np.full(current_params.shape, np.nan)
std_errs[fitted_param_mask] = _calculate_std_errs(
jac_fun, constraints, num_free_amps, current_params, fitted_param_mask
)
else:
std_errs = np.full(current_params.shape, np.nan)

return current_params, -result.fun, std_errs


def _dwellcounts_from_statepath(statepath, exclude_ambiguous_dwells):
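A hedged usage sketch of the new errors (the dwell-time data, observation limits, and component count below are placeholders, not values from this commit):

import numpy as np
import lumicks.pylake as lk

dwelltimes = np.array([0.3, 1.2, 0.5, 2.8, 0.9])  # made-up dwell times (seconds)
model = lk.DwelltimeModel(
    dwelltimes,
    n_components=1,
    min_observation_time=0.2,
    max_observation_time=5.0,
)

# Asymptotic standard errors; parameters that were not estimated come back as np.nan.
# For a single component the amplitude is fixed at 1, so err_amplitudes is [nan].
print(model.err_amplitudes, model.err_lifetimes)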
51 changes: 41 additions & 10 deletions lumicks/pylake/population/tests/test_dwelltimes.py
@@ -253,6 +253,31 @@ def test_dwelltime_profiles(exponential_data, exp_name, reference_bounds, reinte
profiles.get_interval("amplitude", 0, 0.001)


@pytest.mark.parametrize(
# fmt:off
"exp_name, n_components, ref_std_errs",
[
("dataset_2exp", 1, [np.nan, 0.117634]), # Amplitude is not fitted!
("dataset_2exp", 2, [0.072455, 0.072456, 0.212814, 0.449388]),
("dataset_2exp_discrete", 2, [0.068027, 0.068027, 0.21403 , 0.350355]),
("dataset_2exp_discrete", 3, [0.097556, 0.380667, 0.395212, 0.252004, 1.229997, 4.500617]),
("dataset_2exp_discrete", 4, [9.755185e-02, 4.999662e-05, 3.788707e-01, 3.934488e-01, 2.520029e-01, 1.889606e+00, 1.227551e+00, 4.489603e+00]),
]
)
def test_std_errs(exponential_data, exp_name, n_components, ref_std_errs):
dataset = exponential_data[exp_name]

fit = DwelltimeModel(
dataset["data"],
n_components=n_components,
**dataset["parameters"].observation_limits,
discretization_timestep=dataset["parameters"].dt,
)
np.testing.assert_allclose(fit._std_errs, ref_std_errs, rtol=1e-4)
np.testing.assert_allclose(fit.err_amplitudes, ref_std_errs[:n_components], rtol=1e-4)
np.testing.assert_allclose(fit.err_lifetimes, ref_std_errs[n_components:], rtol=1e-4)


@pytest.mark.parametrize("n_components", [2, 1])
def test_dwelltime_profile_plots(n_components):
"""Verify that the threshold moves appropriately"""
@@ -400,14 +425,15 @@ def test_invalid_bootstrap(exponential_data):
def test_integration_dwelltime_fixing_parameters(exponential_data):
dataset = exponential_data["dataset_2exp"]
initial_params = np.array([0.2, 0.2, 0.5, 0.5])
pars, log_likelihood = _exponential_mle_optimize(
pars, log_likelihood, std_errs = _exponential_mle_optimize(
2,
dataset["data"],
**dataset["parameters"].observation_limits,
initial_guess=initial_params,
fixed_param_mask=[False, True, False, True],
)
np.testing.assert_allclose(pars, [0.8, 0.2, 4.27753, 0.5], rtol=1e-4)
np.testing.assert_allclose(std_errs, [np.nan, np.nan, 0.15625, np.nan], rtol=1e-4)


@pytest.mark.parametrize(
@@ -443,38 +469,38 @@ def fit_data(dt):


@pytest.mark.parametrize(
"n_components,params,fixed_param_mask,ref_fitted,ref_const_fun,free_amplitudes,ref_par",
"n_components,params,fixed_param_mask,ref_fitted,ref_const_fun,free_amplitudes,ref_par,ref_amp",
[
# fmt:off
# 2 components, fix one amplitude => everything fixed in the end
[
2, np.array([0.3, 0.4, 0.3, 0.3]), [True, False, False, False],
[False, False, True, True], None, 0, [0.3, 0.7, 0.3, 0.3],
[False, False, True, True], None, 0, [0.3, 0.7, 0.3, 0.3], 0,
],
# 2 components, fix both amplitudes
[
2, np.array([0.3, 0.7, 0.3, 0.3]), [True, True, False, False],
[False, False, True, True], 0, 0, [0.3, 0.7, 0.3, 0.3]
[False, False, True, True], 0, 0, [0.3, 0.7, 0.3, 0.3], 0,
],
# 2 components, free amplitudes
[
2, np.array([0.3, 0.7, 0.3, 0.3]), [False, False, True, False],
[True, True, False, True], 0.75, 2, [0.3, 0.7, 0.3, 0.3],
[True, True, False, True], 0.75, 2, [0.3, 0.7, 0.3, 0.3], 2,
],
# 3 components, fix one amplitude => End up with two free ones
[
3, np.array([0.3, 0.4, 0.2, 0.3, 0.3, 0.3]), [True, False, False, False, False, False],
[False, True, True, True, True, True], 1.6 / 3, 2, [0.3, 0.4, 0.2, 0.3, 0.3, 0.3],
[False, True, True, True, True, True], 1.6 / 3, 2, [0.3, 0.4, 0.2, 0.3, 0.3, 0.3], 2,
],
# 3 components, fix two amplitudes => Amplitudes are now fully determined
[
3, np.array([0.3, 0.4, 0.2, 0.3, 0.3, 0.3]), [True, True, False, False, False, False],
[False, False, False, True, True, True], 0, 0, [0.3, 0.4, 0.3, 0.3, 0.3, 0.3],
[False, False, False, True, True, True], 0, 0, [0.3, 0.4, 0.3, 0.3, 0.3, 0.3], 0,
],
# 1 component, no amplitudes required
[
1, np.array([0.3, 0.5]), [False, False],
[False, True], None, 0, [1.0, 0.5],
[False, True], None, 0, [1.0, 0.5], 0,
],
# fmt:on
],
@@ -487,11 +513,13 @@ def test_parameter_fixing(
ref_const_fun,
free_amplitudes,
ref_par,
ref_amp,
):
old_params = np.copy(params)
fitted_param_mask, constraints, out_params = _handle_amplitude_constraint(
fitted_param_mask, constraints, out_params, free_amplitudes = _handle_amplitude_constraint(
n_components, params, np.array(fixed_param_mask)
)
assert free_amplitudes == ref_amp

# Verify that we didn't modify the input
np.testing.assert_allclose(params, old_params)
@@ -641,8 +669,11 @@ def quick_fit(fixed_params):
fixed_param_mask=fixed_params,
)

x, cost = quick_fit([False, True]) # Amplitude is 1 -> Problem fully determined -> No fit
x, cost, std_err = quick_fit(
[False, True]
) # Amplitude is 1 -> Problem fully determined -> No fit
np.testing.assert_allclose(x, np.array([1.0, 2.0]))
np.testing.assert_equal(std_err, [np.nan, np.nan]) # no uncertainty estimates

with pytest.raises(StopIteration):
quick_fit([True, False]) # Lifetime unknown -> Need to fit
