diff --git a/autumn/projects/sm_covid2/common_school/runner_tools.py b/autumn/projects/sm_covid2/common_school/runner_tools.py
index dcea8b489..5a7f01c40 100644
--- a/autumn/projects/sm_covid2/common_school/runner_tools.py
+++ b/autumn/projects/sm_covid2/common_school/runner_tools.py
@@ -312,23 +312,67 @@ def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5,
 def run_full_analysis(
     iso3,
     analysis="main",
-    best_param_dicts=[],
+    opti_params={'n_searches': 8, 'n_best_retained': 4, 'num_workers': 8, 'parallel_opti_jobs': 4, 'warmup_iterations': 2000, 'search_iterations': 5000, 'init_method': "LHS"},
     mcmc_params={'draws': 10000, 'tune': 1000, 'cores': 32, 'chains': 32, 'method': 'DEMetropolis'},
-    full_run_params={'burn_in': 25000},
+    full_run_params={'samples': 1000, 'burn_in': 5000},
     output_folder="test_outputs",
-    logger=None,
-    _pymc_transform_eps_scale=.1
+    logger=None
 ):
     out_path = Path(output_folder)
 
+    # Check we requested enough searches and number of requested MCMC chains is a multiple of number of retained optimisation searches
+    assert opti_params['n_best_retained'] <= opti_params['n_searches']
+    assert mcmc_params['chains'] % opti_params['n_best_retained'] == 0
+
     # Create BayesianCompartmentalModel object
-    bcm = get_bcm_object(iso3, analysis, _pymc_transform_eps_scale=_pymc_transform_eps_scale)
+    bcm = get_bcm_object(iso3, analysis)
 
-    retained_best_params = best_param_dicts
+    """
+        OPTIMISATION
+    """
+    # Sample optimisation starting points with LHS
+    if logger:
+        logger.info("Perform LHS sampling")
+    if opti_params['init_method'] == "LHS":
+        sample_as_dicts = sample_with_lhs(opti_params['n_searches'], bcm)
+    elif opti_params['init_method'] == "midpoint":
+        sample_as_dicts = [{}] * opti_params['n_searches']
+    else:
+        raise ValueError('init_method optimisation argument not supported')
+
+    # Store starting points
+    with open(out_path / "LHS_init_points.yml", "w") as f:
+        yaml.dump(sample_as_dicts, f)
+
+    # Perform optimisation searches
+    if logger:
+        logger.info(f"Perform optimisation ({opti_params['n_searches']} searches)")
+
+    def opti_func(sample_dict):
+        best_p, _ = optimise_model_fit(bcm, num_workers=opti_params['num_workers'], warmup_iterations=opti_params['warmup_iterations'], search_iterations=opti_params['search_iterations'], suggested_start=sample_dict)
+        return best_p
+
+    best_params = map_parallel(opti_func, sample_as_dicts, n_workers=opti_params['parallel_opti_jobs'])
+    # Store optimal solutions
+    with open(out_path / "best_params.yml", "w") as f:
+        yaml.dump(best_params, f)
+
+    # Keep only n_best_retained best solutions
+    loglikelihoods = [bcm.loglikelihood(**p) for p in best_params]
+    ll_cutoff = sorted(loglikelihoods, reverse=True)[opti_params['n_best_retained'] - 1]
+
+    retained_init_points, retained_best_params = [], []
+    for init_sample, best_p, ll in zip(sample_as_dicts, best_params, loglikelihoods):
+        if ll >= ll_cutoff:
+            retained_init_points.append(init_sample)
+            retained_best_params.append(best_p)
 
     # Store retained optimal solutions
     with open(out_path / "retained_best_params.yml", "w") as f:
-        yaml.dump(retained_best_params, f)
+        yaml.dump(retained_best_params, f)
+
+    # Plot optimal solutions and starting points
+    plot_opti_params(retained_init_points, retained_best_params, bcm, output_folder)
 
     # Plot optimal model fits
     opt_fits_path = out_path / "optimised_fits"
@@ -337,6 +381,9 @@ def run_full_analysis(
         plot_model_fit(bcm, best_p, iso3, opt_fits_path / f"best_fit_{j}.png")
 
     plot_multiple_model_fits(bcm, retained_best_params, out_path / "optimal_fits.png")
+
+    if logger:
+        logger.info("... optimisation completed")
 
     # Early return if MCMC not requested
     if mcmc_params['draws'] == 0:
@@ -356,8 +403,27 @@ def run_full_analysis(
     make_post_mc_plots(idata, full_run_params['burn_in'], output_folder)
     if logger:
         logger.info("... MCMC completed")
-
-    return idata
+
+    """
+        Post-MCMC processes
+    """
+    sample_df = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in'])
+    if logger:
+        logger.info(f"Perform full runs for {full_run_params['samples']} samples")
+    outputs_df = run_full_runs(sample_df, iso3, analysis)
+    if logger:
+        logger.info("Calculate differential outputs")
+    diff_outputs_df = calculate_diff_outputs(outputs_df)
+    if logger:
+        logger.info("Calculate quantiles")
+    uncertainty_df, diff_quantiles_df = get_quantile_outputs(outputs_df, diff_outputs_df)
+    uncertainty_df.to_parquet(out_path / "uncertainty_df.parquet")
+    diff_quantiles_df.to_parquet(out_path / "diff_quantiles_df.parquet")
+
+    make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder)
+
+    return idata, uncertainty_df, diff_quantiles_df
+
 
 
 def run_full_analysis_smc(