
Clean up full analysis runners
romain-ragonnet committed Sep 7, 2023
1 parent 0699a67 commit a97f4b9
Showing 6 changed files with 119 additions and 456 deletions.
104 changes: 31 additions & 73 deletions autumn/projects/sm_covid2/common_school/runner_tools.py
@@ -19,6 +19,7 @@
from estival.utils.parallel import map_parallel

from autumn.core.runs import ManagedRun
from autumn.infrastructure.remote import springboard

from autumn.settings.folders import PROJECTS_PATH
from autumn.projects.sm_covid2.common_school.calibration import get_bcm_object
@@ -37,6 +38,19 @@
with countries_path.open() as f:
INCLUDED_COUNTRIES = yaml.unsafe_load(f)


DEFAULT_RUN_CONFIG = {
"N_CHAINS": 8,
"N_OPTI_SEARCHES": 16,
"OPTI_BUDGET": 10000,
"METROPOLIS_TUNE": 5000,
"METROPOLIS_DRAWS": 30000,
"METROPOLIS_METHOD": "DEMetropolisZ",
"FULL_RUNS_SAMPLES": 1000,
"BURN_IN": 20000
}


"""
Functions related to model calibration
"""
@@ -312,17 +326,14 @@ def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5,
def run_full_analysis(
iso3,
analysis="main",
opti_params={'n_searches': 8, 'n_best_retained': 4, 'num_workers': 8, 'parallel_opti_jobs': 4, 'warmup_iterations': 2000, 'search_iterations': 5000, 'init_method': "LHS"},
mcmc_params={'draws': 10000, 'tune': 1000, 'cores': 32, 'chains': 32, 'method': 'DEMetropolis'},
full_run_params={'samples': 1000, 'burn_in': 5000},
opti_params={'n_searches': 8, 'search_iterations': 5000},
mcmc_params={'draws': 30000, 'tune': 5000, 'chains': 8, 'method': 'DEMetropolisZ'},
full_run_params={'samples': 1000, 'burn_in': 20000},
output_folder="test_outputs",
logger=None
):
out_path = Path(output_folder)

# Check we requested enough searches and number of requested MCMC chains is a multiple of number of retained optimisation searches
assert opti_params['n_best_retained'] <= opti_params['n_searches']
assert mcmc_params['chains'] % opti_params['n_best_retained'] == 0
assert mcmc_params['chains'] <= opti_params['n_searches']

# Create BayesianCompartmentalModel object
bcm = get_bcm_object(iso3, analysis)
@@ -333,33 +344,28 @@ def run_full_analysis(
# Sample optimisation starting points with LHS
if logger:
logger.info("Perform LHS sampling")
if opti_params['init_method'] == "LHS":
sample_as_dicts = sample_with_lhs(opti_params['n_searches'], bcm)
elif opti_params['init_method'] == "midpoint":
sample_as_dicts = [{}] * opti_params['n_searches']
else:
raise ValueError('init_method optimisation argument not supported')

sample_as_dicts = sample_with_lhs(opti_params['n_searches'], bcm)

# Store starting points
with open(out_path / "LHS_init_points.yml", "w") as f:
yaml.dump(sample_as_dicts, f)

# Perform optimisation searches
if logger:
logger.info(f"Perform optimisation ({opti_params['n_searches']} searches)")

n_opti_workers = 8
def opti_func(sample_dict):
best_p, _ = optimise_model_fit(bcm, num_workers=opti_params['num_workers'], warmup_iterations=opti_params['warmup_iterations'], search_iterations=opti_params['search_iterations'], suggested_start=sample_dict)
best_p, _ = optimise_model_fit(bcm, num_workers=n_opti_workers, search_iterations=opti_params['search_iterations'], suggested_start=sample_dict)
return best_p

best_params = map_parallel(opti_func, sample_as_dicts, n_workers=opti_params['parallel_opti_jobs'])
best_params = map_parallel(opti_func, sample_as_dicts, n_workers=int(2 * mcmc_params['chains'] / n_opti_workers)) # oversubscribing; cast to int for the worker pool
# Store optimal solutions
with open(out_path / "best_params.yml", "w") as f:
yaml.dump(best_params, f)

# Keep only the best solutions, one retained per requested MCMC chain
loglikelihoods = [bcm.loglikelihood(**p) for p in best_params]
ll_cutoff = sorted(loglikelihoods, reverse=True)[opti_params['n_best_retained'] - 1]
ll_cutoff = sorted(loglikelihoods, reverse=True)[mcmc_params['chains'] - 1]

retained_init_points, retained_best_params = [], []
for init_sample, best_p, ll in zip(sample_as_dicts, best_params, loglikelihoods):
@@ -374,11 +380,11 @@ def opti_func(sample_dict):
# Plot optimal solutions and starting points
plot_opti_params(retained_init_points, retained_best_params, bcm, output_folder)

# Plot optimal model fits
opt_fits_path = out_path / "optimised_fits"
opt_fits_path.mkdir(exist_ok=True)
for j, best_p in enumerate(retained_best_params):
plot_model_fit(bcm, best_p, iso3, opt_fits_path / f"best_fit_{j}.png")
# # Plot optimal model fits
# opt_fits_path = out_path / "optimised_fits"
# opt_fits_path.mkdir(exist_ok=True)
# for j, best_p in enumerate(retained_best_params):
# plot_model_fit(bcm, best_p, iso3, opt_fits_path / f"best_fit_{j}.png")

plot_multiple_model_fits(bcm, retained_best_params, out_path / "optimal_fits.png")

@@ -395,57 +401,10 @@ def opti_func(sample_dict):
if logger:
logger.info(f"Start MCMC for {mcmc_params['tune']} + {mcmc_params['draws']} iterations and {mcmc_params['chains']} chains...")

n_repeat_seed = int(mcmc_params['chains'] / opti_params['n_best_retained'])
n_repeat_seed = 1
init_vals = [[best_p] * n_repeat_seed for i, best_p in enumerate(retained_best_params)]
init_vals = [p_dict for sublist in init_vals for p_dict in sublist]
idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=mcmc_params['cores'], chains=mcmc_params['chains'], method=mcmc_params['method'])
idata.to_netcdf(out_path / "idata.nc")
make_post_mc_plots(idata, full_run_params['burn_in'], output_folder)
if logger:
logger.info("... MCMC completed")

"""
Post-MCMC processes
"""
sample_df = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in'])
if logger:
logger.info(f"Perform full runs for {full_run_params['samples']} samples")
outputs_df = run_full_runs(sample_df, iso3, analysis)
if logger:
logger.info("Calculate differential outputs")
diff_outputs_df = calculate_diff_outputs(outputs_df)
if logger:
logger.info("Calculate quantiles")
uncertainty_df, diff_quantiles_df = get_quantile_outputs(outputs_df, diff_outputs_df)
uncertainty_df.to_parquet(out_path / "uncertainty_df.parquet")
diff_quantiles_df.to_parquet(out_path / "diff_quantiles_df.parquet")

make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder)

return idata, uncertainty_df, diff_quantiles_df



def run_full_analysis_smc(
iso3,
analysis="main",
smc_params={'draws': 2000, 'cores': 32, 'chains': 4},
full_run_params={'samples': 1000, 'burn_in': 5000},
output_folder="test_outputs",
logger=None
):
out_path = Path(output_folder)

# Create BayesianCompartmentalModel object
bcm = get_bcm_object(iso3, analysis)

"""
SMC
"""
if logger:
logger.info(f"Start SMC for {smc_params['draws']} draws and {smc_params['chains']} chains...")

idata = sample_with_pymc_smc(bcm, draws=smc_params['draws'], cores=smc_params['cores'], chains=smc_params['chains'])
idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=mcmc_params['chains'], chains=mcmc_params['chains'], method=mcmc_params['method'])
idata.to_netcdf(out_path / "idata.nc")
make_post_mc_plots(idata, full_run_params['burn_in'], output_folder)
if logger:
@@ -471,7 +430,6 @@ def run_full_analysis_smc(

return idata, uncertainty_df, diff_quantiles_df


"""
Helper functions for remote runs
"""
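For context, a minimal sketch of how the simplified runner could be invoked outside the notebook, using only the run_full_analysis signature and the DEFAULT_RUN_CONFIG keys shown in the diff above. The ISO3 code, output folder name and logger setup are illustrative assumptions, not part of the commit.

import logging

from autumn.projects.sm_covid2.common_school.runner_tools import (
    DEFAULT_RUN_CONFIG,
    run_full_analysis,
)

# Illustrative logger; run_full_analysis only calls logger.info(), so any
# standard logging.Logger should work (assumption based on the calls above).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("full_analysis")

cfg = DEFAULT_RUN_CONFIG  # defaults defined in runner_tools.py above

idata, uncertainty_df, diff_quantiles_df = run_full_analysis(
    "FRA",  # illustrative ISO3 code
    analysis="main",
    opti_params={"n_searches": cfg["N_OPTI_SEARCHES"], "search_iterations": cfg["OPTI_BUDGET"]},
    mcmc_params={
        "draws": cfg["METROPOLIS_DRAWS"],
        "tune": cfg["METROPOLIS_TUNE"],
        "chains": cfg["N_CHAINS"],
        "method": cfg["METROPOLIS_METHOD"],
    },
    full_run_params={"samples": cfg["FULL_RUNS_SAMPLES"], "burn_in": cfg["BURN_IN"]},
    output_folder="test_outputs",  # illustrative folder, assumed to already exist
    logger=logger,
)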
@@ -6,11 +6,7 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"\n",
"from autumn.projects.sm_covid2.common_school.runner_tools import run_full_analysis\n",
"from autumn.projects.sm_covid2.common_school.output_plots.country_spec import make_country_output_tiling"
"from autumn.projects.sm_covid2.common_school.runner_tools import run_full_analysis, DEFAULT_RUN_CONFIG"
]
},
{
@@ -22,6 +18,15 @@
"Note that the three returned objects will be dumped automatically into the specified output folder (1 .nc file, 2 .parquet files)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"DEFAULT_RUN_CONFIG"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -31,17 +36,16 @@
"ISO3 = \"FRA\"\n",
"ANALYSIS = \"main\"\n",
"\n",
"N_CHAINS = 8\n",
"N_OPTI_SEARCHES = 2\n",
"N_BEST_RETAINED = 1\n",
"OPTI_BUDGET = 100\n",
"N_CHAINS = 2 # DEFAULT_RUN_CONFIG['N_CHAINS']\n",
"N_OPTI_SEARCHES = 2 # DEFAULT_RUN_CONFIG['N_OPTI_SEARCHES']\n",
"OPTI_BUDGET = 100 # DEFAULT_RUN_CONFIG['OPTI_BUDGET']\n",
"\n",
"METROPOLIS_TUNE = 100\n",
"METROPOLIS_DRAWS = 500\n",
"METROPOLIS_METHOD = \"DEMetropolis\"\n",
"METROPOLIS_TUNE = 100 # DEFAULT_RUN_CONFIG['METROPOLIS_TUNE']\n",
"METROPOLIS_DRAWS = 100 # DEFAULT_RUN_CONFIG['METROPOLIS_DRAWS']\n",
"METROPOLIS_METHOD = DEFAULT_RUN_CONFIG['METROPOLIS_METHOD']\n",
"\n",
"FULL_RUNS_SAMPLES = 100\n",
"BURN_IN = 50"
"FULL_RUNS_SAMPLES = 100 # DEFAULT_RUN_CONFIG['FULL_RUNS_SAMPLES']\n",
"BURN_IN = 50 # DEFAULT_RUN_CONFIG['BURN_IN']"
]
},
{
@@ -53,55 +57,12 @@
"idata, uncertainty_df, diff_quantiles_df = run_full_analysis(\n",
" ISO3,\n",
" analysis=ANALYSIS, \n",
" opti_params={'n_searches': N_OPTI_SEARCHES, \"n_best_retained\": N_BEST_RETAINED, 'num_workers': 8, 'parallel_opti_jobs': 4, 'warmup_iterations': 0, 'search_iterations': OPTI_BUDGET, 'init_method': 'LHS'},\n",
" mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'cores': N_CHAINS, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n",
" opti_params={'n_searches': N_OPTI_SEARCHES, 'search_iterations': OPTI_BUDGET},\n",
" mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n",
" full_run_params={'samples': FULL_RUNS_SAMPLES, 'burn_in': BURN_IN},\n",
" output_folder=\"test_full\",\n",
")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot analysis outputs"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Load data from previous run if required"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# uncertainty_df = pd.read_parquet(os.path.join(output_folder, \"uncertainty_df.parquet\"))\n",
"# diff_quantiles_df = pd.read_parquet(os.path.join(output_folder, \"diff_quantiles_df.parquet\"))"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"Make tiling plot"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder)"
]
}
],
"metadata": {
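For completeness, a minimal sketch of reloading the dumped artefacts from a previous run and rebuilding the country tiling plot. It assumes the file names written by run_full_analysis above (idata.nc, uncertainty_df.parquet, diff_quantiles_df.parquet), the make_country_output_tiling import used elsewhere in this project, and an illustrative ISO3 code and folder name.

from pathlib import Path

import arviz as az
import pandas as pd

from autumn.projects.sm_covid2.common_school.output_plots.country_spec import (
    make_country_output_tiling,
)

output_folder = "test_full"  # illustrative; must match the folder passed to run_full_analysis
out_path = Path(output_folder)

# File names as written by run_full_analysis.
idata = az.from_netcdf(out_path / "idata.nc")
uncertainty_df = pd.read_parquet(out_path / "uncertainty_df.parquet")
diff_quantiles_df = pd.read_parquet(out_path / "diff_quantiles_df.parquet")

make_country_output_tiling("FRA", uncertainty_df, diff_quantiles_df, output_folder)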
