Skip to content

Commit

Permalink
Create full analysis using smc
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-ragonnet committed Aug 28, 2023
1 parent 27cb664 commit 16df0a2
Show file tree
Hide file tree
Showing 2 changed files with 285 additions and 0 deletions.
57 changes: 57 additions & 0 deletions autumn/projects/sm_covid2/common_school/runner_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,16 @@ def sample_with_pymc(bcm, initvals, draws=1000, tune=500, cores=8, chains=8, met

return idata


def sample_with_pymc_smc(bcm, draws=2000, cores=8, chains=4):
    """
    Sample the calibration model with PyMC's Sequential Monte Carlo sampler.

    Args:
        bcm: The model object handed to `epm.use_model` (callers in this file
            pass the object returned by `get_bcm_object`).
        draws: Forwarded to `pm.smc.sample_smc` as `draws`.
        cores: Forwarded to `pm.smc.sample_smc` as `cores`.
        chains: Forwarded to `pm.smc.sample_smc` as `chains`.

    Returns:
        The object returned by `pm.smc.sample_smc`
        (presumably an arviz InferenceData -- TODO confirm).
    """
    with pm.Model():
        # Called for its effect inside the active pm.Model context; the return
        # value is not needed here.
        # NOTE(review): presumably this registers the bcm priors/likelihood
        # with the PyMC model -- confirm against the epm wrapper.
        epm.use_model(bcm)
        idata = pm.smc.sample_smc(draws=draws, chains=chains, cores=cores, progressbar=False)

    return idata


"""
Functions related to post-calibration processes
"""
Expand Down Expand Up @@ -413,6 +423,53 @@ def opti_func(sample_dict):
return idata, uncertainty_df, diff_quantiles_df



def run_full_analysis_smc(
    iso3,
    analysis="main",
    smc_params=None,
    full_run_params=None,
    output_folder="test_outputs",
    logger=None
):
    """
    Run the complete analysis for one country using SMC sampling.

    The function builds the model object, samples it with
    `sample_with_pymc_smc`, stores the sampling output, runs the full model
    for a subset of samples and writes the resulting quantile tables and
    figures to `output_folder`.

    Args:
        iso3: Country identifier passed to `get_bcm_object`, `run_full_runs`
            and `make_country_output_tiling`.
        analysis: Analysis name passed to `get_bcm_object` and `run_full_runs`.
        smc_params: Dict with the keys 'draws', 'cores' and 'chains'.
            Defaults to {'draws': 2000, 'cores': 32, 'chains': 4} when None.
        full_run_params: Dict with the keys 'samples' and 'burn_in'.
            Defaults to {'samples': 1000, 'burn_in': 5000} when None.
        output_folder: Folder receiving "idata.nc", "uncertainty_df.parquet",
            "diff_quantiles_df.parquet" and the plots. Created if missing.
        logger: Optional logger; progress is reported through `logger.info`
            when one is supplied.

    Returns:
        Tuple (idata, uncertainty_df, diff_quantiles_df).
    """
    # Build the default dicts per call so that no dict object is shared
    # between successive calls.
    if smc_params is None:
        smc_params = {'draws': 2000, 'cores': 32, 'chains': 4}
    if full_run_params is None:
        full_run_params = {'samples': 1000, 'burn_in': 5000}

    out_path = Path(output_folder)
    # The files below are written directly into out_path, so it must exist.
    out_path.mkdir(parents=True, exist_ok=True)

    # Create BayesianCompartmentalModel object
    bcm = get_bcm_object(iso3, analysis)

    # ---- SMC ----
    if logger:
        logger.info(f"Start SMC for {smc_params['draws']} draws and {smc_params['chains']} chains...")

    idata = sample_with_pymc_smc(
        bcm,
        draws=smc_params['draws'],
        cores=smc_params['cores'],
        chains=smc_params['chains'],
    )
    idata.to_netcdf(out_path / "idata.nc")
    # NOTE(review): burn_in is forwarded unchanged to the plotting and
    # sub-sampling helpers. The defaults combine burn_in=5000 with draws=2000;
    # confirm that a burn-in is meaningful for SMC output.
    make_post_mc_plots(idata, full_run_params['burn_in'], output_folder)
    if logger:
        logger.info("... SMC completed")

    # ---- Post-sampling processes ----
    sample_df = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in'])
    if logger:
        logger.info(f"Perform full runs for {full_run_params['samples']} samples")
    outputs_df = run_full_runs(sample_df, iso3, analysis)

    if logger:
        logger.info("Calculate differential outputs")
    diff_outputs_df = calculate_diff_outputs(outputs_df)

    if logger:
        logger.info("Calculate quantiles")
    uncertainty_df, diff_quantiles_df = get_quantile_outputs(outputs_df, diff_outputs_df)
    uncertainty_df.to_parquet(out_path / "uncertainty_df.parquet")
    diff_quantiles_df.to_parquet(out_path / "diff_quantiles_df.parquet")

    make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder)

    return idata, uncertainty_df, diff_quantiles_df


"""
Helper functions for remote runs
"""
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "f5408fea",
"metadata": {},
"outputs": [],
"source": [
"from autumn.infrastructure.remote import springboard\n",
"from autumn.projects.sm_covid2.common_school.runner_tools import run_full_analysis_smc, print_continuous_status, download_analysis"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "e7255011",
"metadata": {},
"source": [
"### Define task function"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "2da27a35",
"metadata": {},
"source": [
"#### Standard config"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a5f49d54",
"metadata": {},
"outputs": [],
"source": [
"ISO3 = \"FRA\"\n",
"ANALYSIS = \"main\"\n",
"\n",
"N_CHAINS = 8\n",
"N_CORES = 32\n",
"\n",
"METROPOLIS_DRAWS = 2000\n",
"\n",
"FULL_RUNS_SAMPLES = 1000\n",
"BURN_IN = 5000"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "253866ab",
"metadata": {},
"source": [
"#### Testing config"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ac0a9fdd",
"metadata": {},
"outputs": [],
"source": [
"# ISO3 = \"FRA\"\n",
"# ANALYSIS = \"main\"\n",
"\n",
"# N_CHAINS = 32\n",
"# N_OPTI_SEARCHES = 8\n",
"# OPTI_BUDGET = 700\n",
"\n",
"# METROPOLIS_TUNE = 200\n",
"# METROPOLIS_DRAWS = 1000\n",
"# METROPOLIS_METHOD = \"DEMetropolis\"\n",
"\n",
"# FULL_RUNS_SAMPLES = 100\n",
"# BURN_IN = 500"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "47d05a73",
"metadata": {},
"outputs": [],
"source": [
"def remote_full_analysis_smc_task(bridge: springboard.task.TaskBridge, iso3: str = 'FRA', analysis: str = \"main\"):\n",
" \n",
" import multiprocessing as mp\n",
" mp.set_start_method('forkserver')\n",
"\n",
" bridge.logger.info(f\"Running full analysis for {iso3}. Analysis type: {analysis}.\")\n",
"\n",
" idata, uncertainty_df, diff_quantiles_df = run_full_analysis_smc(\n",
" iso3,\n",
" analysis=analysis, \n",
" smc_params={'draws': METROPOLIS_DRAWS, 'cores': N_CORES, 'chains': N_CHAINS},\n",
" full_run_params={'samples': FULL_RUNS_SAMPLES, 'burn_in': BURN_IN},\n",
" output_folder=bridge.out_path,\n",
" logger=bridge.logger\n",
" )\n",
" \n",
" bridge.logger.info(\"Full analysis complete\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1a1207a8",
"metadata": {},
"outputs": [],
"source": [
"mspec = springboard.EC2MachineSpec(N_CORES, 4, \"compute\")\n",
"task_kwargs = {\n",
" \"iso3\": ISO3,\n",
" \"analysis\": ANALYSIS\n",
"}\n",
"tspec = springboard.TaskSpec(remote_full_analysis_smc_task, task_kwargs)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e12adbec",
"metadata": {},
"outputs": [],
"source": [
"analysis_title = \"SMC\"\n",
"config_str = f\"_{ANALYSIS}_mc{METROPOLIS_DRAWS}\"\n",
"\n",
"run_path = springboard.launch.get_autumn_project_run_path(\"school_project\", ISO3, analysis_title + config_str)\n",
"run_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16657a60",
"metadata": {},
"outputs": [],
"source": [
"runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b416f02f",
"metadata": {},
"outputs": [],
"source": [
"runner.s3.get_status()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5dfd7dfe",
"metadata": {},
"outputs": [],
"source": [
"print(runner.top(\"%CPU\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bc0faa03",
"metadata": {},
"outputs": [],
"source": [
"# wait function with status printing\n",
"print_continuous_status(runner)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f8019dfb",
"metadata": {},
"outputs": [],
"source": [
"run_path = 'projects/school_project/FRA/2023-08-22T1004-single_start_main_LHS4_opt10000_mc2000n10000'"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8d31fe3",
"metadata": {},
"outputs": [],
"source": [
"download_analysis(run_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "170c3a1f",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 16df0a2

Please sign in to comment.