Skip to content

Commit

Permalink
Merge pull request #164 from StingraySoftware/fix_ml_timing_first_guess
Browse files Browse the repository at this point in the history
Separate parameter guessing and test properly in ml_fitting
  • Loading branch information
matteobachetti authored Jun 21, 2024
2 parents 5dc8c88 + e5ae7d9 commit 9b580a3
Showing 1 changed file with 72 additions and 27 deletions.
99 changes: 72 additions & 27 deletions hendrics/ml_timing.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,55 @@ def normalized_template(template, tomax=False, subtract_min=True):
# return np.std(bases), np.std(amps), np.std(phases)


def _func_for_toa_fitting(phases, pars, template_fun):
amp, shift = pars[:2]
base = 0
if len(pars) > 2:
base = pars[2]
return base + amp * template_fun(phases - shift)


def _guess_start_pars(profile, template, fit_base=True, mean_phase=None):
"""Guess the starting parameters for the fit.
Examples
--------
>>> phases = np.linspace(0, 1, 10001)
>>> template_fun = lambda x : 3.11 * np.exp(-(x - 0.5)**2 / (2 * 0.05**2)) + 355
>>> template = template_fun(phases)
>>> profile1 = 34 * template_fun(phases - 0.2) + 11.
>>> profile2 = 53 * template_fun(phases - 0.2)
>>> pars1 = _guess_start_pars(profile1, template, fit_base=True)
>>> newprof1 = _func_for_toa_fitting(phases, pars1, template_fun)
>>> assert np.allclose(profile1, newprof1, atol=1e-10)
>>> pars2 = _guess_start_pars(profile2, template, fit_base=False)
>>> newprof2 = _func_for_toa_fitting(phases, pars2, template_fun)
>>> assert np.allclose(profile2, newprof2, atol=1e-10)
"""
minp = np.min(profile)
maxp = np.max(profile)
mint = np.min(template)
maxt = np.max(template)

dph = 1 / profile.size
if mean_phase is None:
mean_phase = ((np.argmax(profile) - np.argmax(template))) * dph

if fit_base:
amp_tr = (maxp - minp) / (maxt - mint)
x0 = (
amp_tr,
phases_from_zero_to_one(mean_phase),
minp - mint * amp_tr,
)
else:
x0 = (
maxp / maxt,
phases_from_zero_to_one(mean_phase),
)
return x0


def ml_pulsefit(
profile,
template,
Expand Down Expand Up @@ -321,28 +370,11 @@ def func(pars):
return np.inf
return ll

minp = np.min(profile)
maxp = np.max(profile)
mint = np.min(template)
maxt = np.max(template)

dph = 1 / profile.size
if mean_phase is None:
mean_phase = ((np.argmax(profile) - np.argmax(template)) + 0.5) * dph
x0 = _guess_start_pars(profile, template, fit_base=fit_base, mean_phase=mean_phase)

if fit_base:
x0 = (
(maxp - minp) / (maxt - mint),
phases_from_zero_to_one(mean_phase),
minp - mint,
)
bounds = [(0, np.inf), (0, 1), (0, np.inf)]

else:
x0 = (
maxp / maxt,
phases_from_zero_to_one(mean_phase),
)
bounds = [(0, np.inf), (0, 1)]

res = minimize(func, x0, bounds=bounds)
Expand All @@ -368,25 +400,38 @@ def func(pars):
errs = np.concatenate((errs, [0]))

# import matplotlib.pyplot as plt

# amp_tr, shift_tr = x0[:2]
# base_tr = x0[2] if fit_base else 0

# plt.figure()
# phases_fine = np.linspace(0, 1, 300)

# amp, shift, base = final_pars
# amp_tr, shift_tr = x0[:2]
# base_tr = x0[2] if fit_base else 0

# shift = phases_from_zero_to_one(shift)
# plt.title(f"{template.size} {shift}")
# plt.plot(phases_fine, base + amp * template_fun(phases_fine - shift), label="Best fit")
# plt.plot(phases_fine,
# base_tr + amp_tr * template_fun(phases_fine - shift_tr),
# color="grey", label="Start guess")
# plt.plot(phases_fine, base + amp * template_fun(phases_fine), color="grey", alpha=0.5,
# label="Template")
# plt.plot(
# phases_fine, base + amp * template_fun(phases_fine - shift), label="Best fit"
# )
# plt.plot(
# phases_fine,
# base_tr + amp_tr * template_fun(phases_fine - shift_tr),
# color="grey",
# label="Start guess",
# )
# plt.plot(
# phases_fine,
# base + amp * template_fun(phases_fine),
# color="grey",
# alpha=0.5,
# label="Template",
# )
# plt.axvline(shift - errs[1])
# plt.axvline(shift + errs[1])
# plt.axvline(phases_from_zero_to_one(mean_phase), color="k")
# plt.plot(phases, profile, label="Data")
# plt.show()
# plt.legend()
# plt.show()
# plt.savefig(f"{np.random.random()}.png")
return final_pars, errs

0 comments on commit 9b580a3

Please sign in to comment.