Skip to content

Commit

Permalink
N cores now an argument of full analysis function
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-ragonnet committed Sep 11, 2023
1 parent b872508 commit e65731a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 20 deletions.
8 changes: 5 additions & 3 deletions autumn/projects/sm_covid2/common_school/runner_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@


DEFAULT_RUN_CONFIG = {
"N_CORES": 8,
"N_CHAINS": 8,
"N_OPTI_SEARCHES": 16,
"OPTI_BUDGET": 10000,
Expand Down Expand Up @@ -326,6 +327,7 @@ def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5,
def run_full_analysis(
iso3,
analysis="main",
n_cores=8,
opti_params={'n_searches': 8, 'search_iterations': 5000},
mcmc_params={'draws': 30000, 'tune': 5000, 'chains': 8, 'method': 'DEMetropolisZ'},
full_run_params={'samples': 1000, 'burn_in': 20000},
Expand Down Expand Up @@ -358,12 +360,12 @@ def opti_func(sample_dict):
best_p, _ = optimise_model_fit(bcm, num_workers=n_opti_workers, search_iterations=opti_params['search_iterations'], suggested_start=sample_dict)
return best_p

best_params = map_parallel(opti_func, sample_as_dicts, n_workers=2 * mcmc_params['chains'] / n_opti_workers) # oversubscribing
best_params = map_parallel(opti_func, sample_as_dicts, n_workers=2 * n_cores / n_opti_workers) # oversubscribing
# Store optimal solutions
with open(out_path / "best_params.yml", "w") as f:
yaml.dump(best_params, f)

# Keep only n_best_retained best solutions
# Keep only n_chains best solutions
loglikelihoods = [bcm.loglikelihood(**p) for p in best_params]
ll_cutoff = sorted(loglikelihoods, reverse=True)[mcmc_params['chains'] - 1]

Expand Down Expand Up @@ -404,7 +406,7 @@ def opti_func(sample_dict):
n_repeat_seed = 1
init_vals = [[best_p] * n_repeat_seed for i, best_p in enumerate(retained_best_params)]
init_vals = [p_dict for sublist in init_vals for p_dict in sublist]
idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=mcmc_params['chains'], chains=mcmc_params['chains'], method=mcmc_params['method'])
idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=n_cores, chains=mcmc_params['chains'], method=mcmc_params['method'])
idata.to_netcdf(out_path / "idata.nc")
make_post_mc_plots(idata, full_run_params['burn_in'], output_folder)
if logger:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,20 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "f5408fea",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\rrag0004\\.conda\\envs\\summer2\\lib\\site-packages\\summer\\runner\\vectorized_runner.py:363: NumbaDeprecationWarning: \u001b[1mThe 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\u001b[0m\n",
" def get_strain_infection_values(\n",
"WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n"
]
}
],
"source": [
"from autumn.infrastructure.remote import springboard\n",
"from autumn.projects.sm_covid2.common_school.runner_tools import (\n",
Expand Down Expand Up @@ -36,24 +46,44 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "c4aa031b",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"{'N_CHAINS': 8,\n",
" 'N_OPTI_SEARCHES': 16,\n",
" 'OPTI_BUDGET': 10000,\n",
" 'METROPOLIS_TUNE': 5000,\n",
" 'METROPOLIS_DRAWS': 30000,\n",
" 'METROPOLIS_METHOD': 'DEMetropolisZ',\n",
" 'FULL_RUNS_SAMPLES': 1000,\n",
" 'BURN_IN': 20000}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"DEFAULT_RUN_CONFIG"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "a5f49d54",
"metadata": {},
"outputs": [],
"source": [
"ISO3 = \"FRA\"\n",
"ANALYSIS = \"main\"\n",
"\n",
"N_CORES = DEFAULT_RUN_CONFIG['N_CORES']\n",
"\n",
"N_CHAINS = DEFAULT_RUN_CONFIG['N_CHAINS']\n",
"N_OPTI_SEARCHES = DEFAULT_RUN_CONFIG['N_OPTI_SEARCHES']\n",
"OPTI_BUDGET = DEFAULT_RUN_CONFIG['OPTI_BUDGET']\n",
Expand All @@ -77,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "86180b5e",
"metadata": {},
"outputs": [],
Expand All @@ -88,7 +118,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "4810f137",
"metadata": {},
"outputs": [],
Expand All @@ -103,6 +133,7 @@
" idata, uncertainty_df, diff_quantiles_df = run_full_analysis(\n",
" iso3,\n",
" analysis=analysis, \n",
" n_cores=N_CORES,\n",
" opti_params={'n_searches': N_OPTI_SEARCHES, 'search_iterations': OPTI_BUDGET},\n",
" mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n",
" full_run_params={'samples': FULL_RUNS_SAMPLES, 'burn_in': BURN_IN},\n",
Expand All @@ -115,12 +146,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "1a1207a8",
"metadata": {},
"outputs": [],
"source": [
"mspec = springboard.EC2MachineSpec(N_CHAINS, 4, \"compute\")\n",
"mspec = springboard.EC2MachineSpec(N_CORES, 4, \"compute\")\n",
"task_kwargs = {\n",
" \"iso3\": ISO3,\n",
" \"analysis\": ANALYSIS\n",
Expand All @@ -130,12 +161,23 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "e12adbec",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'projects/school_project/FRA/2023-09-11T1117-test_4cpu_main_LHS16_opt10000_mc5000n30000'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"analysis_title = \"single_start\"\n",
"analysis_title = \"test_4cpu\"\n",
"config_str = f\"_{ANALYSIS}_LHS{N_OPTI_SEARCHES}_opt{OPTI_BUDGET}_mc{METROPOLIS_TUNE}n{METROPOLIS_DRAWS}\"\n",
"\n",
"run_path = springboard.launch.get_autumn_project_run_path(\"school_project\", ISO3, analysis_title + config_str)\n",
Expand All @@ -144,20 +186,39 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "16657a60",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"SSH connection still waiting, retrying 1/5\n"
]
}
],
"source": [
"runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"id": "b416f02f",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'SUCCESS'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"runner.s3.get_status()"
]
Expand Down Expand Up @@ -195,13 +256,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"id": "a8d31fe3",
"metadata": {},
"outputs": [],
"source": [
"download_analysis(run_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82728672",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
"ISO3_LIST = [\"AUS\", \"BOL\", \"FRA\"]\n",
"ANALYSIS = \"main\"\n",
"\n",
"N_CORES = DEFAULT_RUN_CONFIG['N_CORES']\n",
"\n",
"N_CHAINS = DEFAULT_RUN_CONFIG['N_CHAINS']\n",
"N_OPTI_SEARCHES = DEFAULT_RUN_CONFIG['N_OPTI_SEARCHES']\n",
"OPTI_BUDGET = DEFAULT_RUN_CONFIG['OPTI_BUDGET']\n",
Expand Down Expand Up @@ -70,6 +72,7 @@
"\n",
" idata, uncertainty_df, diff_quantiles_df = run_full_analysis(\n",
" iso3,\n",
" n_cores=N_CORES,\n",
" analysis=analysis, \n",
" opti_params={'n_searches': N_OPTI_SEARCHES, 'search_iterations': OPTI_BUDGET},\n",
" mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n",
Expand Down

0 comments on commit e65731a

Please sign in to comment.