Skip to content

Commit

Permalink
Restore full analysis function
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-ragonnet committed Sep 7, 2023
1 parent 46594bb commit 0699a67
Showing 1 changed file with 75 additions and 9 deletions.
84 changes: 75 additions & 9 deletions autumn/projects/sm_covid2/common_school/runner_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,23 +312,67 @@ def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5,
def run_full_analysis(
iso3,
analysis="main",
best_param_dicts=[],
opti_params={'n_searches': 8, 'n_best_retained': 4, 'num_workers': 8, 'parallel_opti_jobs': 4, 'warmup_iterations': 2000, 'search_iterations': 5000, 'init_method': "LHS"},
mcmc_params={'draws': 10000, 'tune': 1000, 'cores': 32, 'chains': 32, 'method': 'DEMetropolis'},
full_run_params={'burn_in': 25000},
full_run_params={'samples': 1000, 'burn_in': 5000},
output_folder="test_outputs",
logger=None,
_pymc_transform_eps_scale=.1
logger=None
):
out_path = Path(output_folder)

# Check we requested enough searches and number of requested MCMC chains is a multiple of number of retained optimisation searches
assert opti_params['n_best_retained'] <= opti_params['n_searches']
assert mcmc_params['chains'] % opti_params['n_best_retained'] == 0

# Create BayesianCompartmentalModel object
bcm = get_bcm_object(iso3, analysis, _pymc_transform_eps_scale=_pymc_transform_eps_scale)
bcm = get_bcm_object(iso3, analysis)

retained_best_params = best_param_dicts
"""
OPTIMISATION
"""
# Build optimisation starting points according to opti_params['init_method'] ("LHS" or "midpoint")
if logger:
logger.info("Perform LHS sampling")
if opti_params['init_method'] == "LHS":
sample_as_dicts = sample_with_lhs(opti_params['n_searches'], bcm)
elif opti_params['init_method'] == "midpoint":
sample_as_dicts = [{}] * opti_params['n_searches']
else:
raise ValueError('init_method optimisation argument not supported')

# Store starting points
with open(out_path / "LHS_init_points.yml", "w") as f:
yaml.dump(sample_as_dicts, f)

# Perform optimisation searches
if logger:
logger.info(f"Perform optimisation ({opti_params['n_searches']} searches)")

def opti_func(sample_dict):
best_p, _ = optimise_model_fit(bcm, num_workers=opti_params['num_workers'], warmup_iterations=opti_params['warmup_iterations'], search_iterations=opti_params['search_iterations'], suggested_start=sample_dict)
return best_p

best_params = map_parallel(opti_func, sample_as_dicts, n_workers=opti_params['parallel_opti_jobs'])
# Store optimal solutions
with open(out_path / "best_params.yml", "w") as f:
yaml.dump(best_params, f)

# Keep only n_best_retained best solutions
loglikelihoods = [bcm.loglikelihood(**p) for p in best_params]
ll_cutoff = sorted(loglikelihoods, reverse=True)[opti_params['n_best_retained'] - 1]

retained_init_points, retained_best_params = [], []
for init_sample, best_p, ll in zip(sample_as_dicts, best_params, loglikelihoods):
if ll >= ll_cutoff:
retained_init_points.append(init_sample)
retained_best_params.append(best_p)

# Store retained optimal solutions
with open(out_path / "retained_best_params.yml", "w") as f:
yaml.dump(retained_best_params, f)
yaml.dump(retained_best_params, f)

# Plot optimal solutions and starting points
plot_opti_params(retained_init_points, retained_best_params, bcm, output_folder)

# Plot optimal model fits
opt_fits_path = out_path / "optimised_fits"
Expand All @@ -337,6 +381,9 @@ def run_full_analysis(
plot_model_fit(bcm, best_p, iso3, opt_fits_path / f"best_fit_{j}.png")

plot_multiple_model_fits(bcm, retained_best_params, out_path / "optimal_fits.png")

if logger:
logger.info("... optimisation completed")

# Early return if MCMC not requested
if mcmc_params['draws'] == 0:
Expand All @@ -356,8 +403,27 @@ def run_full_analysis(
make_post_mc_plots(idata, full_run_params['burn_in'], output_folder)
if logger:
logger.info("... MCMC completed")

return idata

"""
Post-MCMC processes
"""
sample_df = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in'])
if logger:
logger.info(f"Perform full runs for {full_run_params['samples']} samples")
outputs_df = run_full_runs(sample_df, iso3, analysis)
if logger:
logger.info("Calculate differential outputs")
diff_outputs_df = calculate_diff_outputs(outputs_df)
if logger:
logger.info("Calculate quantiles")
uncertainty_df, diff_quantiles_df = get_quantile_outputs(outputs_df, diff_outputs_df)
uncertainty_df.to_parquet(out_path / "uncertainty_df.parquet")
diff_quantiles_df.to_parquet(out_path / "diff_quantiles_df.parquet")

make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder)

return idata, uncertainty_df, diff_quantiles_df



def run_full_analysis_smc(
Expand Down

0 comments on commit 0699a67

Please sign in to comment.