Skip to content

Commit

Permalink
Merge branch 'master' of github.com:monash-emu/AuTuMN
Browse files Browse the repository at this point in the history
  • Loading branch information
dshipman committed Sep 11, 2023
2 parents 70d0c98 + 91add12 commit 19f4351
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 52 deletions.
8 changes: 5 additions & 3 deletions autumn/projects/sm_covid2/common_school/runner_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@


DEFAULT_RUN_CONFIG = {
"N_CORES": 8,
"N_CHAINS": 8,
"N_OPTI_SEARCHES": 16,
"OPTI_BUDGET": 10000,
Expand Down Expand Up @@ -326,6 +327,7 @@ def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5,
def run_full_analysis(
iso3,
analysis="main",
n_cores=8,
opti_params={'n_searches': 8, 'search_iterations': 5000},
mcmc_params={'draws': 30000, 'tune': 5000, 'chains': 8, 'method': 'DEMetropolisZ'},
full_run_params={'samples': 1000, 'burn_in': 20000},
Expand Down Expand Up @@ -358,12 +360,12 @@ def opti_func(sample_dict):
best_p, _ = optimise_model_fit(bcm, num_workers=n_opti_workers, search_iterations=opti_params['search_iterations'], suggested_start=sample_dict)
return best_p

best_params = map_parallel(opti_func, sample_as_dicts, n_workers=2 * mcmc_params['chains'] / n_opti_workers) # oversubscribing
best_params = map_parallel(opti_func, sample_as_dicts, n_workers=2 * n_cores / n_opti_workers) # oversubscribing
# Store optimal solutions
with open(out_path / "best_params.yml", "w") as f:
yaml.dump(best_params, f)

# Keep only n_best_retained best solutions
# Keep only n_chains best solutions
loglikelihoods = [bcm.loglikelihood(**p) for p in best_params]
ll_cutoff = sorted(loglikelihoods, reverse=True)[mcmc_params['chains'] - 1]

Expand Down Expand Up @@ -404,7 +406,7 @@ def opti_func(sample_dict):
n_repeat_seed = 1
init_vals = [[best_p] * n_repeat_seed for i, best_p in enumerate(retained_best_params)]
init_vals = [p_dict for sublist in init_vals for p_dict in sublist]
idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=mcmc_params['chains'], chains=mcmc_params['chains'], method=mcmc_params['method'])
idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=n_cores, chains=mcmc_params['chains'], method=mcmc_params['method'])
idata.to_netcdf(out_path / "idata.nc")
make_post_mc_plots(idata, full_run_params['burn_in'], output_folder)
if logger:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from autumn.infrastructure.remote import springboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rts = springboard.task.RemoteTaskStore(\"projects/school_project\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run_name = \"2023-09-11T1023-test_multirun_main_LHS16_opt10000_mc5000n30000\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run_paths = rts.glob(f\"*/{run_name}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mt_list = [rts.get_managed_task(run_path) for run_path in run_paths]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for mt in mt_list:\n",
" print(mt.get_status())"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "summer2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,20 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"id": "f5408fea",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"c:\\Users\\rrag0004\\.conda\\envs\\summer2\\lib\\site-packages\\summer\\runner\\vectorized_runner.py:363: NumbaDeprecationWarning: \u001b[1mThe 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\u001b[0m\n",
" def get_strain_infection_values(\n",
"WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n"
]
}
],
"source": [
"from autumn.infrastructure.remote import springboard\n",
"from autumn.projects.sm_covid2.common_school.runner_tools import (\n",
Expand Down Expand Up @@ -36,24 +46,44 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"id": "c4aa031b",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"{'N_CHAINS': 8,\n",
" 'N_OPTI_SEARCHES': 16,\n",
" 'OPTI_BUDGET': 10000,\n",
" 'METROPOLIS_TUNE': 5000,\n",
" 'METROPOLIS_DRAWS': 30000,\n",
" 'METROPOLIS_METHOD': 'DEMetropolisZ',\n",
" 'FULL_RUNS_SAMPLES': 1000,\n",
" 'BURN_IN': 20000}"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"DEFAULT_RUN_CONFIG"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "a5f49d54",
"metadata": {},
"outputs": [],
"source": [
"ISO3 = \"FRA\"\n",
"ANALYSIS = \"main\"\n",
"\n",
"N_CORES = DEFAULT_RUN_CONFIG['N_CORES']\n",
"\n",
"N_CHAINS = DEFAULT_RUN_CONFIG['N_CHAINS']\n",
"N_OPTI_SEARCHES = DEFAULT_RUN_CONFIG['N_OPTI_SEARCHES']\n",
"OPTI_BUDGET = DEFAULT_RUN_CONFIG['OPTI_BUDGET']\n",
Expand All @@ -77,7 +107,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"id": "86180b5e",
"metadata": {},
"outputs": [],
Expand All @@ -88,7 +118,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "4810f137",
"metadata": {},
"outputs": [],
Expand All @@ -103,6 +133,7 @@
" idata, uncertainty_df, diff_quantiles_df = run_full_analysis(\n",
" iso3,\n",
" analysis=analysis, \n",
" n_cores=N_CORES,\n",
" opti_params={'n_searches': N_OPTI_SEARCHES, 'search_iterations': OPTI_BUDGET},\n",
" mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n",
" full_run_params={'samples': FULL_RUNS_SAMPLES, 'burn_in': BURN_IN},\n",
Expand All @@ -115,12 +146,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "1a1207a8",
"metadata": {},
"outputs": [],
"source": [
"mspec = springboard.EC2MachineSpec(N_CHAINS, 4, \"compute\")\n",
"mspec = springboard.EC2MachineSpec(N_CORES, 4, \"compute\")\n",
"task_kwargs = {\n",
" \"iso3\": ISO3,\n",
" \"analysis\": ANALYSIS\n",
Expand All @@ -130,12 +161,23 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 7,
"id": "e12adbec",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'projects/school_project/FRA/2023-09-11T1117-test_4cpu_main_LHS16_opt10000_mc5000n30000'"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"analysis_title = \"single_start\"\n",
"analysis_title = \"test_4cpu\"\n",
"config_str = f\"_{ANALYSIS}_LHS{N_OPTI_SEARCHES}_opt{OPTI_BUDGET}_mc{METROPOLIS_TUNE}n{METROPOLIS_DRAWS}\"\n",
"\n",
"run_path = springboard.launch.get_autumn_project_run_path(\"school_project\", ISO3, analysis_title + config_str)\n",
Expand All @@ -144,20 +186,39 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 8,
"id": "16657a60",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"SSH connection still waiting, retrying 1/5\n"
]
}
],
"source": [
"runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 15,
"id": "b416f02f",
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"'SUCCESS'"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"runner.s3.get_status()"
]
Expand Down Expand Up @@ -195,13 +256,21 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 16,
"id": "a8d31fe3",
"metadata": {},
"outputs": [],
"source": [
"download_analysis(run_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "82728672",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand Down
Loading

0 comments on commit 19f4351

Please sign in to comment.