From e65731a11efd951dd40a8fbb87fa0036f79b79ea Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Mon, 11 Sep 2023 15:36:33 +1000 Subject: [PATCH] N cores now an argument of full analysis function --- .../sm_covid2/common_school/runner_tools.py | 8 +- .../remote/sb_full_analysis.ipynb | 103 +++++++++++++++--- .../remote/sb_multi_full_analysis.ipynb | 3 + 3 files changed, 94 insertions(+), 20 deletions(-) diff --git a/autumn/projects/sm_covid2/common_school/runner_tools.py b/autumn/projects/sm_covid2/common_school/runner_tools.py index a87f302db..25c3ba452 100644 --- a/autumn/projects/sm_covid2/common_school/runner_tools.py +++ b/autumn/projects/sm_covid2/common_school/runner_tools.py @@ -40,6 +40,7 @@ DEFAULT_RUN_CONFIG = { + "N_CORES": 8, "N_CHAINS": 8, "N_OPTI_SEARCHES": 16, "OPTI_BUDGET": 10000, @@ -326,6 +327,7 @@ def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5, def run_full_analysis( iso3, analysis="main", + n_cores=8, opti_params={'n_searches': 8, 'search_iterations': 5000}, mcmc_params={'draws': 30000, 'tune': 5000, 'chains': 8, 'method': 'DEMetropolisZ'}, full_run_params={'samples': 1000, 'burn_in': 20000}, @@ -358,12 +360,12 @@ def opti_func(sample_dict): best_p, _ = optimise_model_fit(bcm, num_workers=n_opti_workers, search_iterations=opti_params['search_iterations'], suggested_start=sample_dict) return best_p - best_params = map_parallel(opti_func, sample_as_dicts, n_workers=2 * mcmc_params['chains'] / n_opti_workers) # oversubscribing + best_params = map_parallel(opti_func, sample_as_dicts, n_workers=2 * n_cores / n_opti_workers) # oversubscribing # Store optimal solutions with open(out_path / "best_params.yml", "w") as f: yaml.dump(best_params, f) - # Keep only n_best_retained best solutions + # Keep only n_chains best solutions loglikelihoods = [bcm.loglikelihood(**p) for p in best_params] ll_cutoff = sorted(loglikelihoods, reverse=True)[mcmc_params['chains'] - 1] @@ -404,7 +406,7 @@ def opti_func(sample_dict): n_repeat_seed = 1 init_vals = [[best_p] * n_repeat_seed for i, best_p in enumerate(retained_best_params)] init_vals = [p_dict for sublist in init_vals for p_dict in sublist] - idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=mcmc_params['chains'], chains=mcmc_params['chains'], method=mcmc_params['method']) + idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=n_cores, chains=mcmc_params['chains'], method=mcmc_params['method']) idata.to_netcdf(out_path / "idata.nc") make_post_mc_plots(idata, full_run_params['burn_in'], output_folder) if logger: diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis.ipynb index a18d5462b..19439c25d 100644 --- a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis.ipynb +++ b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis.ipynb @@ -2,10 +2,20 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "f5408fea", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\rrag0004\\.conda\\envs\\summer2\\lib\\site-packages\\summer\\runner\\vectorized_runner.py:363: NumbaDeprecationWarning: \u001b[1mThe 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\u001b[0m\n", + " def get_strain_infection_values(\n", + "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n" + ] + } + ], "source": [ "from autumn.infrastructure.remote import springboard\n", "from autumn.projects.sm_covid2.common_school.runner_tools import (\n", @@ -36,17 +46,35 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "c4aa031b", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'N_CHAINS': 8,\n", + " 'N_OPTI_SEARCHES': 16,\n", + " 'OPTI_BUDGET': 10000,\n", + " 'METROPOLIS_TUNE': 5000,\n", + " 'METROPOLIS_DRAWS': 30000,\n", + " 'METROPOLIS_METHOD': 'DEMetropolisZ',\n", + " 'FULL_RUNS_SAMPLES': 1000,\n", + " 'BURN_IN': 20000}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "DEFAULT_RUN_CONFIG" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "a5f49d54", "metadata": {}, "outputs": [], @@ -54,6 +82,8 @@ "ISO3 = \"FRA\"\n", "ANALYSIS = \"main\"\n", "\n", + "N_CORES = DEFAULT_RUN_CONFIG['N_CORES']\n", + "\n", "N_CHAINS = DEFAULT_RUN_CONFIG['N_CHAINS']\n", "N_OPTI_SEARCHES = DEFAULT_RUN_CONFIG['N_OPTI_SEARCHES']\n", "OPTI_BUDGET = DEFAULT_RUN_CONFIG['OPTI_BUDGET']\n", @@ -77,7 +107,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "86180b5e", "metadata": {}, "outputs": [], @@ -88,7 +118,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "4810f137", "metadata": {}, "outputs": [], @@ -103,6 +133,7 @@ " idata, uncertainty_df, diff_quantiles_df = run_full_analysis(\n", " iso3,\n", " analysis=analysis, \n", + " n_cores=N_CORES,\n", " opti_params={'n_searches': N_OPTI_SEARCHES, 'search_iterations': OPTI_BUDGET},\n", " mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n", " full_run_params={'samples': FULL_RUNS_SAMPLES, 'burn_in': BURN_IN},\n", @@ -115,12 +146,12 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "1a1207a8", "metadata": {}, "outputs": [], "source": [ - "mspec = springboard.EC2MachineSpec(N_CHAINS, 4, \"compute\")\n", + "mspec = springboard.EC2MachineSpec(N_CORES, 4, \"compute\")\n", "task_kwargs = {\n", " \"iso3\": ISO3,\n", " \"analysis\": ANALYSIS\n", @@ -130,12 +161,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "e12adbec", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'projects/school_project/FRA/2023-09-11T1117-test_4cpu_main_LHS16_opt10000_mc5000n30000'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "analysis_title = \"single_start\"\n", + "analysis_title = \"test_4cpu\"\n", "config_str = f\"_{ANALYSIS}_LHS{N_OPTI_SEARCHES}_opt{OPTI_BUDGET}_mc{METROPOLIS_TUNE}n{METROPOLIS_DRAWS}\"\n", "\n", "run_path = springboard.launch.get_autumn_project_run_path(\"school_project\", ISO3, analysis_title + config_str)\n", @@ -144,20 +186,39 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "16657a60", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "SSH connection still waiting, retrying 1/5\n" + ] + } + ], "source": [ "runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "id": "b416f02f", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'SUCCESS'" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "runner.s3.get_status()" ] @@ -195,13 +256,21 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "id": "a8d31fe3", "metadata": {}, "outputs": [], "source": [ "download_analysis(run_path)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82728672", + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_full_analysis.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_full_analysis.ipynb index 502784318..f0eb45ac9 100644 --- a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_full_analysis.ipynb +++ b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_full_analysis.ipynb @@ -33,6 +33,8 @@ "ISO3_LIST = [\"AUS\", \"BOL\", \"FRA\"]\n", "ANALYSIS = \"main\"\n", "\n", + "N_CORES = DEFAULT_RUN_CONFIG['N_CORES']\n", + "\n", "N_CHAINS = DEFAULT_RUN_CONFIG['N_CHAINS']\n", "N_OPTI_SEARCHES = DEFAULT_RUN_CONFIG['N_OPTI_SEARCHES']\n", "OPTI_BUDGET = DEFAULT_RUN_CONFIG['OPTI_BUDGET']\n", @@ -70,6 +72,7 @@ "\n", " idata, uncertainty_df, diff_quantiles_df = run_full_analysis(\n", " iso3,\n", + " n_cores=N_CORES,\n", " analysis=analysis, \n", " opti_params={'n_searches': N_OPTI_SEARCHES, 'search_iterations': OPTI_BUDGET},\n", " mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n",