Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dwelltime: add standard errors based on hessian approximation #690

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* Added improved printing of calibration items under `channel.calibration` providing a more convenient overview of the items associated with a `Slice`.
* Added improved printing of calibrations performed with `Pylake`.
* Added parameter `titles` to customize title of each subplot in [`Kymo.plot_with_channels()`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.kymo.Kymo.html#lumicks.pylake.kymo.Kymo.plot_with_channels).
* Added `err_amplitudes` and `err_lifetimes` to [`DwelltimeModel`](https://lumicks-pylake.readthedocs.io/en/latest/_api/lumicks.pylake.DwelltimeModel.html). These return asymptotic standard errors and should only be used for well identified models.

## v1.5.2 | 2024-07-24

Expand Down
3 changes: 2 additions & 1 deletion lumicks/pylake/fitting/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
fixed=False,
shared=False,
unit=None,
stderr=None,
):
"""Model parameter

Expand Down Expand Up @@ -90,7 +91,7 @@ def __init__(
from the data. See also: :meth:`~lumicks.pylake.FdFit.profile_likelihood()`.
"""

self.stderr = None
self.stderr = stderr
"""Standard error of this parameter.

Standard errors are calculated after fitting the model. These asymptotic errors are based
Expand Down
12 changes: 11 additions & 1 deletion lumicks/pylake/fitting/profile_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,17 +624,27 @@ def chi2(self):
def p(self):
return self.parameters[:, self.profile_info.profiled_parameter_index]

def plot(self, *, significance_level=None, **kwargs):
def plot(self, *, significance_level=None, std_err=None, **kwargs):
"""Plot profile likelihood

Parameters
----------
significance_level : float, optional
Desired significance level (resulting in a 100 * (1 - alpha)% confidence interval) to
plot. Default is the significance level specified when the profile was generated.
std_err : float | None
If provided, also make a quadratic plot based on a standard error.
"""
import matplotlib.pyplot as plt

if std_err:
x = np.arange(-3 * std_err, 3 * std_err, 0.1 * std_err)
plt.plot(
self.p[np.argmin(self.chi2)] + x,
self.profile_info.minimum_chi2 + x**2 / (2 * std_err**2),
"k--",
)

dash_length = 5
plt.plot(self.p, self.chi2, **kwargs)

Expand Down
19 changes: 11 additions & 8 deletions lumicks/pylake/kymotracker/tests/test_kymotrack.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def test_kymotrackgroup_remove(blank_kymo, remove, remaining):
tracks = KymoTrackGroup(src_tracks)
for track in remove:
tracks.remove(src_tracks[track])
for track, should_be_present in zip(src_tracks, remaining):
for track in src_tracks:
if remaining:
assert track in tracks
else:
Expand Down Expand Up @@ -300,9 +300,10 @@ def test_kymotrack_merge():
time_idx = ([1, 2, 3, 4, 5], [6, 7, 8], [6, 7, 8], [1, 2, 3])
pos_idx = ([1, 1, 1, 3, 3], [4, 4, 4], [9, 9, 9], [10, 10, 10])

make_tracks = lambda: KymoTrackGroup(
[KymoTrack(t, p, kymo, "green", 0) for t, p in zip(time_idx, pos_idx)]
)
def make_tracks():
return KymoTrackGroup(
[KymoTrack(t, p, kymo, "green", 0) for t, p in zip(time_idx, pos_idx)]
)

# connect first two
tracks = make_tracks()
Expand Down Expand Up @@ -873,6 +874,7 @@ def test_fit_binding_times_nonzero(blank_kymo, blank_kymo_track_args):
np.testing.assert_equal(dwelltime_model.dwelltimes, [4, 4, 4, 4])
np.testing.assert_equal(dwelltime_model._observation_limits[0], 4)
np.testing.assert_allclose(dwelltime_model.lifetimes[0], [0.4])
np.testing.assert_allclose(dwelltime_model.err_lifetimes[0], 0.199994, rtol=1e-5)


def test_fit_binding_times_empty():
Expand Down Expand Up @@ -1174,10 +1176,11 @@ def make_coordinates(length, divisor):

good_tracks = KymoTrackGroup([track for j, track in enumerate(tracks) if j in (0, 3, 4)])

warning_string = lambda n_discarded: (
f"{n_discarded} tracks were shorter than the specified min_length "
"and discarded from the analysis."
)
def warning_string(n_discarded):
return (
f"{n_discarded} tracks were shorter than the specified min_length and discarded "
f"from the analysis."
)

# test algorithms with default min_length
with pytest.warns(RuntimeWarning, match=warning_string(3)):
Expand Down
97 changes: 83 additions & 14 deletions lumicks/pylake/population/dwelltime.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from lumicks.pylake.fitting.parameters import Params, Parameter
from lumicks.pylake.fitting.profile_likelihood import ProfileLikelihood1D
from lumicks.pylake.fitting.detail.derivative_manipulation import numerical_jacobian


@dataclass(frozen=True)
Expand Down Expand Up @@ -98,8 +99,10 @@ def fit_func(params, lb, ub, fitted):
)
parameters = Params(
**{
key: Parameter(param, lower_bound=lb, upper_bound=ub)
for key, param, (lb, ub) in zip(keys, dwelltime_model._parameters, bounds)
key: Parameter(param, lower_bound=lb, upper_bound=ub, stderr=std_err)
for key, param, (lb, ub), std_err in zip(
keys, dwelltime_model._parameters, bounds, dwelltime_model._std_errs
)
}
)

Expand Down Expand Up @@ -194,7 +197,7 @@ def n_components(self):
"""Number of components in the model."""
return self.model.n_components

def plot(self, alpha=None):
def plot(self, alpha=None, *, with_stderr=False, **kwargs):
"""Plot the profile likelihoods for the parameters of a model.

Confidence interval is indicated by the region where the profile crosses the chi squared
Expand All @@ -206,16 +209,26 @@ def plot(self, alpha=None):
Significance level. Confidence intervals are calculated as 100 * (1 - alpha)%. The
default value of `None` results in using the significance level applied when
profiling (default: 0.05).
with_stderr : bool
Also show bounds based on standard errors.
"""
import matplotlib.pyplot as plt

std_errs = self.model._std_errs[~np.isnan(self.model._std_errs)]
if self.n_components == 1:
next(iter(self.profiles.values())).plot(significance_level=alpha)
next(iter(self.profiles.values())).plot(
significance_level=alpha,
std_err=std_errs[0] if with_stderr else None,
)
else:
plot_idx = np.reshape(np.arange(1, len(self.profiles) + 1), (-1, 2)).T.flatten()
for idx, profile in zip(plot_idx, self.profiles.values()):
for par_idx, (idx, profile) in enumerate(zip(plot_idx, self.profiles.values())):
plt.subplot(self.n_components, 2, idx)
profile.plot(significance_level=alpha)
profile.plot(
significance_level=alpha,
std_err=std_errs[par_idx] if with_stderr else None,
**kwargs,
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -332,7 +345,7 @@ def _sample(optimized, iterations) -> np.ndarray:
)
sample, min_obs, max_obs = to_resample[choices, :].T

result, _ = _exponential_mle_optimize(
result, _, _ = _exponential_mle_optimize(
optimized.n_components,
sample,
min_obs,
Expand Down Expand Up @@ -575,7 +588,7 @@ def __init__(
if value is not None
}

self._parameters, self._log_likelihood = _exponential_mle_optimize(
self._parameters, self._log_likelihood, self._std_errs = _exponential_mle_optimize(
n_components,
dwelltimes,
min_observation_time,
Expand Down Expand Up @@ -619,6 +632,34 @@ def lifetimes(self):
"""Lifetime parameter (in time units) of each model component."""
return self._parameters[self.n_components :]

@property
def err_amplitudes(self):
"""Asymptotic standard error estimate on the model amplitudes.

Returns an asymptotic standard error on the amplitude parameters. These error estimates
are only reliable for estimates where a lot of data is available and the model does
not suffer from identifiability issues. To verify that these conditions are met, please
use either :meth:`DwelltimeModel.profile_likelihood()` method or
:meth:`DwelltimeModel.calculate_bootstrap()`.

Note that `np.nan` will be returned in case the parameter was either not estimated or no
error could be obtained."""
return self._std_errs[: self.n_components]

@property
def err_lifetimes(self):
"""Asymptotic standard error estimate on the model lifetimes.

Returns an asymptotic standard error on the amplitude parameters. These error estimates
are only reliable for estimates where a lot of data is available and the model does
not suffer from identifiability issues. To verify that these conditions are met, please
use either :meth:`DwelltimeModel.profile_likelihood()` method or
:meth:`DwelltimeModel.calculate_bootstrap()`.

Note that `np.nan` will be returned in case the parameter was either not estimated or no
error could be obtained."""
return self._std_errs[self.n_components :]

@property
def rate_constants(self):
"""First order rate constant (units of per time) of each model component."""
Expand Down Expand Up @@ -1191,7 +1232,7 @@ def _handle_amplitude_constraint(
else ()
)

return fitted_param_mask, constraints, params
return fitted_param_mask, constraints, params, num_free_amps


def _exponential_mle_bounds(n_components, min_observation_time, max_observation_time):
Expand All @@ -1207,6 +1248,24 @@ def _exponential_mle_bounds(n_components, min_observation_time, max_observation_
)


def _calculate_std_errs(jac_fun, constraints, num_free_amps, current_params, fitted_param_mask):
hessian_approx = numerical_jacobian(jac_fun, current_params[fitted_param_mask], dx=1e-6)

if constraints:
from scipy.linalg import null_space

# When we have a constraint, we should enforce it. We do this by projecting the Hessian
# onto the null space of the constraint and inverting it there. This null space only
# includes directions in which the constraint does not change (i.e. is fulfilled).
constraint = np.zeros((1, hessian_approx.shape[0]))
constraint[0, :num_free_amps] = -1
n = null_space(constraint)

return np.sqrt(np.diag(np.abs(n @ np.linalg.pinv(n.T @ hessian_approx @ n) @ n.T)))

return np.sqrt(np.diag(np.abs(np.linalg.pinv(hessian_approx))))


def _exponential_mle_optimize(
n_components,
t,
Expand Down Expand Up @@ -1273,7 +1332,7 @@ def _exponential_mle_optimize(

bounds = _exponential_mle_bounds(n_components, min_observation_time, max_observation_time)

fitted_param_mask, constraints, initial_guess = _handle_amplitude_constraint(
fitted_param_mask, constraints, initial_guess, num_free_amps = _handle_amplitude_constraint(
n_components, initial_guess, fixed_param_mask
)

Expand All @@ -1291,10 +1350,11 @@ def cost_fun(params):
)

def jac_fun(params):
current_params[fitted_param_mask] = params
jac_params = current_params.copy()
jac_params[fitted_param_mask] = params

gradient = _exponential_mixture_log_likelihood_jacobian(
current_params,
jac_params,
t=t,
t_min=min_observation_time,
t_max=max_observation_time,
Expand All @@ -1305,7 +1365,7 @@ def jac_fun(params):

# Nothing to fit, return!
if np.sum(fitted_param_mask) == 0:
return initial_guess, -cost_fun([])
return initial_guess, -cost_fun([]), np.full(initial_guess.shape, np.nan)

# SLSQP is overly cautious when it comes to warning about bounds. The bound violations are
# typically on the order of 1-2 ULP for a float32 and do not matter for our problem. Initially
Expand All @@ -1328,7 +1388,16 @@ def jac_fun(params):

# output parameters as [amplitudes, lifetimes], -log_likelihood
current_params[fitted_param_mask] = result.x
return current_params, -result.fun

if use_jacobian:
std_errs = np.full(current_params.shape, np.nan)
std_errs[fitted_param_mask] = _calculate_std_errs(
jac_fun, constraints, num_free_amps, current_params, fitted_param_mask
)
else:
std_errs = np.full(current_params.shape, np.nan)

return current_params, -result.fun, std_errs


def _dwellcounts_from_statepath(statepath, exclude_ambiguous_dwells):
Expand Down
Loading