Skip to content

Commit

Permalink
Refactor analysis config arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-ragonnet committed Sep 13, 2023
1 parent 78aaaa8 commit fea7596
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 186 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def make_country_output_tiling(iso3, uncertainty_dfs, diff_quantiles_df, output_
# ax11 = fig.add_subplot(inner_grid[2, 1], sharex=ax9)
# plot_incidence_by_age(derived_outputs, ax11, 1, as_proportion=True)

fig.savefig(os.path.join(output_folder, "tiling.png"), facecolor="white")
# fig.savefig(os.path.join(output_folder, "tiling.png"), facecolor="white")
fig.savefig(os.path.join(output_folder, "tiling.pdf"), facecolor="white")

plt.close()
Expand Down
65 changes: 40 additions & 25 deletions autumn/projects/sm_covid2/common_school/runner_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,33 @@


DEFAULT_RUN_CONFIG = {
"N_CORES": 8,
"N_CHAINS": 8,
"N_OPTI_SEARCHES": 16,
"OPTI_BUDGET": 10000,
"METROPOLIS_TUNE": 5000,
"METROPOLIS_DRAWS": 30000,
"METROPOLIS_METHOD": "DEMetropolisZ",
"FULL_RUNS_SAMPLES": 1000,
"BURN_IN": 20000
"n_cores": 8,
# Opti config
"n_opti_searches": 16,
"opti_budget": 10000,
# MCMC config
"n_chains": 8,
"metropolis_tune": 5000,
"metropolis_draws": 30000,
"metropolis_method": "DEMetropolisZ",
# Full runs config
"full_runs_samples": 1000,
"burn_in": 20000
}

# Run configuration with small search and sampling budgets.
TEST_RUN_CONFIG = dict(
    n_cores=8,
    # Optimisation stage: number of searches and per-search budget
    n_opti_searches=2,
    opti_budget=100,
    # MCMC stage: chains, tuning/draw iterations and sampler method
    n_chains=2,
    metropolis_tune=100,
    metropolis_draws=100,
    metropolis_method="DEMetropolisZ",
    # Full-run stage: number of samples and burn-in
    full_runs_samples=50,
    burn_in=50,
)


Expand Down Expand Up @@ -217,15 +235,12 @@ def calculate_diff_output_quantiles(full_runs, quantiles=[.025, .25, .5, .75, .9
def run_full_analysis(
iso3: str,
analysis: str = "main",
n_cores: int = DEFAULT_RUN_CONFIG['N_CORES'],
opti_params: dict = {'n_searches': DEFAULT_RUN_CONFIG['N_SEARCHES'], 'search_iterations': DEFAULT_RUN_CONFIG['OPTI_BUDGET']},
mcmc_params: dict = {'draws': DEFAULT_RUN_CONFIG['METROPOLIS_DRAWS'], 'tune': DEFAULT_RUN_CONFIG['METROPOLIS_TUNE'], 'chains': DEFAULT_RUN_CONFIG['N_CHAINS'], 'method': DEFAULT_RUN_CONFIG['METROPOLIS_METHOD']},
full_run_params: dict = {'samples': DEFAULT_RUN_CONFIG['FULL_RUNS_SAMPLES'], 'burn_in': DEFAULT_RUN_CONFIG['BURN_IN']},
run_config: dict = DEFAULT_RUN_CONFIG,
output_folder="test_outputs",
logger=None
):
out_path = Path(output_folder)
assert mcmc_params['chains'] <= opti_params['n_searches']
assert run_config['n_chains'] <= run_config['n_opti_searches']

# Create BayesianCompartmentalModel object
bcm = get_bcm_object(iso3, analysis)
Expand All @@ -236,21 +251,21 @@ def run_full_analysis(
# Sample optimisation starting points with LHS
if logger:
logger.info("Perform LHS sampling")
sample_as_dicts = sample_with_lhs(opti_params['n_searches'], bcm)
sample_as_dicts = sample_with_lhs(run_config['n_opti_searches'], bcm)

# Store starting points
with open(out_path / "LHS_init_points.yml", "w") as f:
yaml.dump(sample_as_dicts, f)

# Perform optimisation searches
if logger:
logger.info(f"Perform optimisation ({opti_params['n_searches']} searches)")
logger.info(f"Perform optimisation ({run_config['n_opti_searches']} searches)")
n_opti_workers = 8
def opti_func(sample_dict):
best_p, _ = optimise_model_fit(bcm, num_workers=n_opti_workers, search_iterations=opti_params['search_iterations'], suggested_start=sample_dict)
best_p, _ = optimise_model_fit(bcm, num_workers=n_opti_workers, search_iterations=run_config['opti_budget'], suggested_start=sample_dict)
return best_p

best_params = map_parallel(opti_func, sample_as_dicts, n_workers=int(2 * n_cores / n_opti_workers)) # oversubscribing
best_params = map_parallel(opti_func, sample_as_dicts, n_workers=int(2 * run_config['n_cores'] / n_opti_workers)) # oversubscribing
# Store optimal solutions
with open(out_path / "best_params.yml", "w") as f:
yaml.dump(best_params, f)
Expand All @@ -260,7 +275,7 @@ def opti_func(sample_dict):

# Keep only n_chains best solutions
loglikelihoods = [bcm.loglikelihood(**p) for p in best_params]
ll_cutoff = sorted(loglikelihoods, reverse=True)[mcmc_params['chains'] - 1]
ll_cutoff = sorted(loglikelihoods, reverse=True)[run_config['n_chains'] - 1]

retained_init_points, retained_best_params = [], []
for init_sample, best_p, ll in zip(sample_as_dicts, best_params, loglikelihoods):
Expand All @@ -279,31 +294,31 @@ def opti_func(sample_dict):
plot_multiple_model_fits(bcm, retained_best_params, out_path / "optimal_fits.png")

# Early return if MCMC not requested
if mcmc_params['draws'] == 0:
if run_config['metropolis_draws'] == 0:
return None, None, None

"""
MCMC
"""
if logger:
logger.info(f"Start MCMC for {mcmc_params['tune']} + {mcmc_params['draws']} iterations and {mcmc_params['chains']} chains...")
logger.info(f"Start MCMC for {run_config['metropolis_tune']} + {run_config['metropolis_draws']} iterations and {run_config['n_chains']} chains...")

n_repeat_seed = 1
init_vals = [[best_p] * n_repeat_seed for i, best_p in enumerate(retained_best_params)]
init_vals = [p_dict for sublist in init_vals for p_dict in sublist]
idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=n_cores, chains=mcmc_params['chains'], method=mcmc_params['method'])
idata = sample_with_pymc(bcm, initvals=init_vals, draws=run_config['metropolis_draws'], tune=run_config['metropolis_tune'], cores=run_config['n_cores'], chains=run_config['n_chains'], method=run_config['metropolis_method'])
idata.to_netcdf(out_path / "idata.nc")
make_post_mc_plots(idata, full_run_params['burn_in'], output_folder)
make_post_mc_plots(idata, run_config['burn_in'], output_folder)
if logger:
logger.info("... MCMC completed")

"""
Post-MCMC processes
"""
sampled_params = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in'])
sampled_params = extract_sample_subset(idata, run_config['full_runs_samples'], run_config['burn_in'])

if logger:
logger.info(f"Perform full runs for {full_run_params['samples']} samples")
logger.info(f"Perform full runs for {run_config['full_runs_samples']} samples")
full_runs = run_full_runs(sampled_params, iso3, analysis)

if logger:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"from autumn.projects.sm_covid2.common_school.runner_tools import run_full_analysis, DEFAULT_RUN_CONFIG"
"from autumn.projects.sm_covid2.common_school.runner_tools import run_full_analysis, DEFAULT_RUN_CONFIG, TEST_RUN_CONFIG"
]
},
{
Expand All @@ -24,7 +24,7 @@
"metadata": {},
"outputs": [],
"source": [
"DEFAULT_RUN_CONFIG"
"TEST_RUN_CONFIG"
]
},
{
Expand All @@ -34,20 +34,7 @@
"outputs": [],
"source": [
"ISO3 = \"FRA\"\n",
"ANALYSIS = \"main\"\n",
"\n",
"N_CORES = 8\n",
"\n",
"N_CHAINS = 2 # DEFAULT_RUN_CONFIG['N_CHAINS']\n",
"N_OPTI_SEARCHES = 2 # DEFAULT_RUN_CONFIG['N_OPTI_SEARCHES']\n",
"OPTI_BUDGET = 100 # DEFAULT_RUN_CONFIG['OPTI_BUDGET']\n",
"\n",
"METROPOLIS_TUNE = 100 # DEFAULT_RUN_CONFIG['METROPOLIS_TUNE']\n",
"METROPOLIS_DRAWS = 100 # DEFAULT_RUN_CONFIG['METROPOLIS_DRAWS']\n",
"METROPOLIS_METHOD = DEFAULT_RUN_CONFIG['METROPOLIS_METHOD']\n",
"\n",
"FULL_RUNS_SAMPLES = 100 # DEFAULT_RUN_CONFIG['FULL_RUNS_SAMPLES']\n",
"BURN_IN = 50 # DEFAULT_RUN_CONFIG['BURN_IN']"
"ANALYSIS = \"main\""
]
},
{
Expand All @@ -56,7 +43,9 @@
"metadata": {},
"outputs": [],
"source": [
"from estival.utils.parallel import map_parallel"
"from pathlib import Path\n",
"test_output_folder = Path.cwd() / \"test_full\"\n",
"test_output_folder.mkdir(exist_ok=True)"
]
},
{
Expand All @@ -68,20 +57,10 @@
"idata, unc_dfs, diff_quantiles_df = run_full_analysis(\n",
" ISO3,\n",
" analysis=ANALYSIS, \n",
" n_cores=N_CORES,\n",
" opti_params={'n_searches': N_OPTI_SEARCHES, 'search_iterations': OPTI_BUDGET},\n",
" mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n",
" full_run_params={'samples': FULL_RUNS_SAMPLES, 'burn_in': BURN_IN},\n",
" output_folder=\"test_full\",\n",
" run_config=TEST_RUN_CONFIG,\n",
" output_folder=test_output_folder,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Loading

0 comments on commit fea7596

Please sign in to comment.