From 2c3db4793e964cb94ff1793a7724e64445646208 Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Wed, 30 Aug 2023 16:12:23 +1000 Subject: [PATCH 01/18] Multicountry plots --- .../output_plots/multicountry.py | 107 +++++++++++++++ .../School_Closure/multic_plots.ipynb | 125 ++++++++++++++++++ 2 files changed, 232 insertions(+) create mode 100644 autumn/projects/sm_covid2/common_school/output_plots/multicountry.py create mode 100644 notebooks/user/rragonnet/project_specific/School_Closure/multic_plots.ipynb diff --git a/autumn/projects/sm_covid2/common_school/output_plots/multicountry.py b/autumn/projects/sm_covid2/common_school/output_plots/multicountry.py new file mode 100644 index 000000000..cff610eef --- /dev/null +++ b/autumn/projects/sm_covid2/common_school/output_plots/multicountry.py @@ -0,0 +1,107 @@ +from matplotlib import pyplot as plt +from matplotlib.patches import Rectangle +import pandas as pd + +import plotly.graph_objects as go + + +def plot_multic_relative_outputs(output_dfs_dict: dict[str, pd.DataFrame], req_outputs=["cases_averted_relative", "deaths_averted_relative", "delta_hospital_peak_relative"]): + n_subplots = len(req_outputs) + fig, axes = plt.subplots(n_subplots, 1, figsize=(25, n_subplots*6)) + + this_iso3_list = list(output_dfs_dict.keys()) + n_countries = len(this_iso3_list) + ylab_lookup = { + "cases_averted_relative": "% infections averted by school closure", + "deaths_averted_relative": "% deaths averted by school closure", + "delta_hospital_peak_relative": "Relative reduction in peak hospital occupancy (%)" + } + + box_width = .4 + med_color = 'white' + box_colors= ['black', 'purple', 'firebrick'] + + for i_output, output in enumerate(req_outputs): + box_color = box_colors[i_output] + axis = axes[i_output] + + mean_values = [output_dfs_dict[iso3][output].loc[0.5] for iso3 in this_iso3_list] + sorted_iso3_list = [iso3 for _ , iso3 in sorted(zip(mean_values, this_iso3_list))] + + y_max_abs = 0. + for i_iso3, iso3 in enumerate(sorted_iso3_list): + x = i_iso3 + 1 + + data = - 100. * output_dfs_dict[iso3][output] # use %. And use "-" so positive nbs indicate positive effect of closures + + # median + axis.hlines(y=data.loc[0.5], xmin=x - box_width / 2. 
, xmax= x + box_width / 2., lw=2., color=med_color, zorder=3)
+
+            # IQR
+            q_75 = data.loc[0.75]
+            q_25 = data.loc[0.25]
+            rect = Rectangle(xy=(x - box_width / 2., q_25), width=box_width, height=q_75 - q_25, zorder=2, facecolor=box_color)
+            axis.add_patch(rect)
+
+            # 95% CI
+            q_025 = data.loc[0.025]
+            q_975 = data.loc[0.975]
+            axis.vlines(x=x, ymin=q_025 , ymax=q_975, lw=1.5, color=box_color, zorder=1)
+
+            y_max_abs = max(abs(q_975), y_max_abs)
+            y_max_abs = max(abs(q_025), y_max_abs)
+
+        axis.set_xlim((0, n_countries + 1))
+        axis.set_ylim(-1.2*y_max_abs, 1.2*y_max_abs)
+
+        axis.set_xticks(ticks=range(1, n_countries + 1), labels=sorted_iso3_list, rotation=90, fontsize=13)
+
+        y_label = ylab_lookup[output]
+        axis.set_ylabel(y_label, fontsize=13)
+
+        # add coloured background patches
+        xmin, xmax = axis.get_xlim()
+        ymin, ymax = axis.get_ylim()
+        rect_up = Rectangle(xy=(xmin, 0.), width=xmax - xmin, height=(ymax - ymin)/2., zorder=-1, facecolor="white") #"honeydew")
+        axis.add_patch(rect_up)
+        rect_low = Rectangle(xy=(xmin, ymin), width=xmax - xmin, height=(ymax - ymin)/2., zorder=-1, facecolor="gainsboro", alpha=.5) #"mistyrose")
+        axis.add_patch(rect_low)
+
+        axis.text(n_countries * .75, ymax / 2., s="Positive effect of\nschool closures", fontsize=13)
+        axis.text(n_countries * .25, ymin / 2., s="Negative effect of\nschool closures", fontsize=13)
+
+    plt.tight_layout()
+
+    return fig
+
+
+def plot_relative_map(output_dfs_dict: dict[str, pd.DataFrame], req_output="delta_hospital_peak_relative"):
+    this_iso3_list = list(output_dfs_dict.keys())
+    values = [- 100 * output_dfs_dict[iso3][req_output].loc[0.5] for iso3 in this_iso3_list]
+    data_df = pd.DataFrame.from_dict({"iso3": this_iso3_list, "values": values})
+
+    fig = go.Figure(
+        data=go.Choropleth(
+            locations=data_df["iso3"],
+            z=data_df["values"],
+            colorscale= [[0, 'lightblue'],
+                         # [0.5, 'darkgrey'],
+                         [1, 'blue']],  # "Plotly3",
+            marker_line_color='darkgrey',
+            marker_line_width=0.5,
+            colorbar_title="Relative reduction in
peak hospital occupancy", + colorbar_ticksuffix="%", + ) + ) + fig.update_layout( + geo=dict( + showframe=False, + showcoastlines=False, + projection_type='equirectangular' #'natural earth' 'equirectangular' + ), + margin={"r":0,"t":0,"l":0,"b":0}, + autosize=False, + height=500, + width=1200, + ) + return fig \ No newline at end of file diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/multic_plots.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/multic_plots.ipynb new file mode 100644 index 000000000..5a04ecaea --- /dev/null +++ b/notebooks/user/rragonnet/project_specific/School_Closure/multic_plots.ipynb @@ -0,0 +1,125 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from matplotlib import pyplot as plt\n", + "from matplotlib.patches import Rectangle\n", + "\n", + "from autumn.projects.sm_covid2.common_school.output_plots.multicountry import plot_multic_relative_outputs, plot_relative_map\n", + "\n", + "\n", + "import pandas as pd\n", + "from pathlib import Path\n", + "\n", + "from autumn.projects.sm_covid2.common_school.runner_tools import INCLUDED_COUNTRIES\n", + "iso3_list = list(INCLUDED_COUNTRIES['google_mobility'].keys())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "\n", + "run_paths_dict = {}\n", + "\n", + "available_iso3s = [\"FRA\", \"BOL\", \"MAR\", \"PHL\", \"AUS\"]\n", + "for iso3 in iso3_list:\n", + " if iso3 in available_iso3s:\n", + " run_paths_dict[iso3] = f\"projects/school_project/{iso3}/2023-08-30T1023-fixedbug_main_LHS16_opt10000_mc5000n50000\"\n", + " else: \n", + " proxy_iso3 = random.choice(available_iso3s)\n", + " run_paths_dict[iso3] = f\"projects/school_project/{proxy_iso3}/2023-08-30T1023-fixedbug_main_LHS16_opt10000_mc5000n50000\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_multic_relative_outputs(run_paths_dict: dict[str, str]) -> dict[str, pd.DataFrame]:\n", + " projects_path = Path.home() / \"Models/AuTuMN_new/data/outputs/runs/\"\n", + " diff_quantiles_dfs = {\n", + " iso3: pd.read_parquet(projects_path / run_path / \"output\" / \"diff_quantiles_df.parquet\") for iso3, run_path in run_paths_dict.items()\n", + " } \n", + " return diff_quantiles_dfs" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dfs_dict = get_multic_relative_outputs(run_paths_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plot_multic_relative_outputs(dfs_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plot_relative_map(dfs_dict, \"delta_hospital_peak_relative\")\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig = plot_relative_map(dfs_dict, \"cases_averted_relative\")\n", + "fig.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "fig.write_image(\"testMAP.pdf\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "summer2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + 
"name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} From f3c7f7b7abe84c297c0846a076ada37acff28c7c Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Wed, 30 Aug 2023 16:15:25 +1000 Subject: [PATCH 02/18] Update multi_opti notebook --- .../School_Closure/remote/sb_multi_opti.ipynb | 31 +++---------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_opti.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_opti.ipynb index a3fcf0a62..4298dfa55 100644 --- a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_opti.ipynb +++ b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_opti.ipynb @@ -23,26 +23,6 @@ "len(iso3_list)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "a31106a5", - "metadata": {}, - "outputs": [], - "source": [ - "iso3_list = iso3_list[:48]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d0149785", - "metadata": {}, - "outputs": [], - "source": [ - "len(iso3_list)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -50,7 +30,7 @@ "metadata": {}, "outputs": [], "source": [ - "N_CPUS = 64\n", + "N_CPUS = 32\n", "N_OPTI_WORKERS = 8\n", "assert N_CPUS % N_OPTI_WORKERS == 0" ] @@ -69,7 +49,7 @@ "\n", " bridge.logger.info(f\"Running optimisation for {len(iso3_list)} countries.\")\n", "\n", - " n_parallel_opti_jobs = N_CPUS / N_OPTI_WORKERS\n", + " n_parallel_opti_jobs = 2. * N_CPUS / N_OPTI_WORKERS\n", "\n", " multi_country_optimise(\n", " iso3_list=iso3_list, \n", @@ -104,7 +84,7 @@ "metadata": {}, "outputs": [], "source": [ - "analysis_title = \"multi_opti_first48_60dupdate\"\n", + "analysis_title = \"multi_opti_15k\"\n", "run_path = springboard.launch.get_autumn_project_run_path(\"school_project\", \"multicountry\", analysis_title)\n", "run_path" ] @@ -147,7 +127,7 @@ "outputs": [], "source": [ "# wait function with status printing\n", - "print_continuous_status(runner)" + "# print_continuous_status(runner)" ] }, { @@ -157,8 +137,7 @@ "metadata": {}, "outputs": [], "source": [ - "# run_path = 'projects/school_project/multicountry/2023-07-05T2142-multi_opti'\n", - "# run_path = 'projects/school_project/multicountry/2023-07-07T1117-multi_opti_first48_60dupdate'" + "# run_path = 'projects/school_project/multicountry/2023-08-30T1614-multi_opti_15k'" ] }, { From 1ffa59e127116c0598f497489d873fba6da1bdb3 Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Wed, 30 Aug 2023 16:27:46 +1000 Subject: [PATCH 03/18] Try to fix issues with figure mix up --- .../common_school/calibration_plots/mc_plots.py | 10 +++++----- .../common_school/calibration_plots/opti_plots.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/autumn/projects/sm_covid2/common_school/calibration_plots/mc_plots.py b/autumn/projects/sm_covid2/common_school/calibration_plots/mc_plots.py index 27ce249ab..7cfcfa24f 100644 --- a/autumn/projects/sm_covid2/common_school/calibration_plots/mc_plots.py +++ b/autumn/projects/sm_covid2/common_school/calibration_plots/mc_plots.py @@ -13,7 +13,7 @@ def make_post_mc_plots(idata, burn_in, output_folder=None): chain_length = idata.sample_stats.sizes['draw'] # Traces (including burn-in) - ax = az.plot_trace(idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False); + az.plot_trace(idata, figsize=(16, 5.0 * 
len(idata.posterior)), compact=False); plt.subplots_adjust(hspace=.7) if output_folder: plt.savefig(output_folder_path / "mc_traces.jpg", facecolor="white", bbox_inches='tight') @@ -23,7 +23,7 @@ def make_post_mc_plots(idata, burn_in, output_folder=None): burnt_idata = idata.sel(draw=range(burn_in, chain_length)) # Discard burn-in # Traces (after burn-in) - ax = az.plot_trace(burnt_idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False); + az.plot_trace(burnt_idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False); plt.subplots_adjust(hspace=.7) if output_folder: plt.savefig(output_folder_path / "mc_traces_postburnin.jpg", facecolor="white", bbox_inches='tight') @@ -48,9 +48,9 @@ def make_post_mc_plots(idata, burn_in, output_folder=None): rhat_df = raw_rhat_df.drop(columns="random_process.delta_values").loc[0] for i in range(len(raw_rhat_df)): rhat_df[f"random_process.delta_values[{i}]"] = raw_rhat_df['random_process.delta_values'][i] - ax = rhat_df.plot.barh(xlim=(1.,1.105)) - ax.vlines(x=1.05,ymin=-0.5, ymax=len(rhat_df), linestyles="--", color='orange') - ax.vlines(x=1.1,ymin=-0.5, ymax=len(rhat_df), linestyles="-",color='red') + axis = rhat_df.plot.barh(xlim=(1.,1.105)) + axis.vlines(x=1.05,ymin=-0.5, ymax=len(rhat_df), linestyles="--", color='orange') + axis.vlines(x=1.1,ymin=-0.5, ymax=len(rhat_df), linestyles="-",color='red') if output_folder: plt.savefig(output_folder_path / "r_hats.jpg", facecolor="white", bbox_inches='tight') plt.close() diff --git a/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py b/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py index fff1a1473..1e7471653 100644 --- a/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py +++ b/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py @@ -133,8 +133,8 @@ def plot_model_fit(bcm, params, iso3, outfile=None): if outfile: fig.savefig(outfile, facecolor="white") - return fig - + plt.close() + def plot_model_fit_with_ll(bcm, params, outfile=None): REF_DATE = datetime.date(2019,12,31) @@ -241,4 +241,4 @@ def plot_multiple_model_fits(bcm, params_list, outfile=None): if outfile: fig.savefig(outfile, facecolor="white") - return fig \ No newline at end of file + plt.close() \ No newline at end of file From 28a11e65ede000186cb14ffc46edd02ffd0208f0 Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Thu, 31 Aug 2023 10:30:52 +1000 Subject: [PATCH 04/18] Fix seeding time --- .../sm_covid2/common_school/project_maker.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/autumn/projects/sm_covid2/common_school/project_maker.py b/autumn/projects/sm_covid2/common_school/project_maker.py index 83192498e..393a6ff2c 100644 --- a/autumn/projects/sm_covid2/common_school/project_maker.py +++ b/autumn/projects/sm_covid2/common_school/project_maker.py @@ -214,10 +214,15 @@ def get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min, } baseline_params = baseline_params.update(sero_age_params) + # Set seeding time 40 days prior first reported death + baseline_params = baseline_params.update( + {"infectious_seed_time": first_date_with_death - 40.} + ) + # update using MLE params, if available - mle_path= param_path / "mle_files" / f"mle_{iso3}.yml" - if exists(mle_path): - baseline_params = baseline_params.update(mle_path, calibration_format=True) + # mle_path= param_path / "mle_files" / f"mle_{iso3}.yml" + # if exists(mle_path): + # baseline_params = baseline_params.update(mle_path, 
calibration_format=True) # update using potential Sensitivity Analysis params sa_params_path = param_path / "SA_analyses" / f"{analysis}.yml" @@ -360,12 +365,12 @@ def get_school_project_priors(first_date_with_death): """ # Work out max infectious seeding time so transmission starts before first observed deaths - min_seed_time = first_date_with_death - 100 - max_seed_time = first_date_with_death - 1 + # min_seed_time = first_date_with_death - 100 + # max_seed_time = first_date_with_death - 1 priors = [ UniformPrior("contact_rate", [0.01, 0.06]), - UniformPrior("infectious_seed_time", [min_seed_time, max_seed_time]), + # UniformPrior("infectious_seed_time", [min_seed_time, max_seed_time]), UniformPrior("age_stratification.ifr.multiplier", [0.5, 1.5]), # VOC-related parameters From 9ce264ff214f7a3dbab51af451e06ec791e9dadb Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Thu, 31 Aug 2023 11:58:07 +1000 Subject: [PATCH 05/18] Track random process AUC --- autumn/models/sm_covid2/model.py | 1 + autumn/models/sm_covid2/outputs.py | 15 +++++++++++++++ .../sm_covid2/common_school/runner_tools.py | 5 +++-- 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/autumn/models/sm_covid2/model.py b/autumn/models/sm_covid2/model.py index 796b407f0..7351330d2 100644 --- a/autumn/models/sm_covid2/model.py +++ b/autumn/models/sm_covid2/model.py @@ -526,6 +526,7 @@ def build_model(params: dict, build_options: dict = None, ret_builder=False) -> if params.activate_random_process: outputs_builder.request_random_process_outputs() + outputs_builder.request_random_process_auc() # request extra output to store the number of students*weeks of school missed outputs_builder.request_student_weeks_missed_output(student_weeks_missed) diff --git a/autumn/models/sm_covid2/outputs.py b/autumn/models/sm_covid2/outputs.py index 9996ae662..eb8983ad1 100644 --- a/autumn/models/sm_covid2/outputs.py +++ b/autumn/models/sm_covid2/outputs.py @@ -405,6 +405,21 @@ def array_max(x): func=peak_func ) + def request_random_process_auc(self): + """ + Create an output to calculate the area between the (transformed) random process and the horizontal line y=1. 
+ """ + + def sum_diffs_to_one(x): + return jnp.repeat(jnp.sum(jnp.abs(1 - x)), jnp.size(x)) + + sum_diffs_to_one_func = Function(sum_diffs_to_one, [DerivedOutput("transformed_random_process")]) + + self.model.request_function_output( + "random_process_auc", + func=sum_diffs_to_one_func + ) + # def request_icu_outputs( # self, # prop_icu_among_hospitalised: float, diff --git a/autumn/projects/sm_covid2/common_school/runner_tools.py b/autumn/projects/sm_covid2/common_school/runner_tools.py index 55b6c0b22..0644356b4 100644 --- a/autumn/projects/sm_covid2/common_school/runner_tools.py +++ b/autumn/projects/sm_covid2/common_school/runner_tools.py @@ -273,14 +273,15 @@ def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5, times = sorted(outputs_df.index.unique()) scenarios = outputs_df["scenario"].unique() unc_output_names = [ - "infection_deaths_ma7", "prop_ever_infected_age_matched", "prop_ever_infected", "transformed_random_process", "cumulative_incidence", "cumulative_infection_deaths", - "peak_hospital_occupancy", "hospital_occupancy" + "infection_deaths_ma7", "prop_ever_infected_age_matched", "prop_ever_infected", "transformed_random_process", "random_process_auc", + "cumulative_incidence", "cumulative_infection_deaths", "peak_hospital_occupancy", "hospital_occupancy" ] uncertainty_data = [] for scenario in scenarios: scenario_mask = outputs_df["scenario"] == scenario scenario_df = outputs_df[scenario_mask] + for time in times: masked_df = scenario_df.loc[time] if masked_df.empty: From d1adfa3a6bde8e2aa04d9a0d613d996781df886a Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Fri, 1 Sep 2023 13:34:38 +1000 Subject: [PATCH 06/18] Request missing output for random process auc --- autumn/projects/sm_covid2/common_school/runner_tools.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/autumn/projects/sm_covid2/common_school/runner_tools.py b/autumn/projects/sm_covid2/common_school/runner_tools.py index 0644356b4..5a7f01c40 100644 --- a/autumn/projects/sm_covid2/common_school/runner_tools.py +++ b/autumn/projects/sm_covid2/common_school/runner_tools.py @@ -193,7 +193,8 @@ def run_full_runs(sampled_df, iso3, analysis): output_names=[ "infection_deaths_ma7", "prop_ever_infected_age_matched", "prop_ever_infected", "cumulative_incidence", - "cumulative_infection_deaths", "hospital_occupancy", 'peak_hospital_occupancy', 'student_weeks_missed', "transformed_random_process" + "cumulative_infection_deaths", "hospital_occupancy", 'peak_hospital_occupancy', 'student_weeks_missed', + "transformed_random_process", "random_process_auc" ] project = get_school_project(iso3, analysis) From e0f547c805c88ad4ffa414d64df64dc2f9b7508a Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Fri, 1 Sep 2023 13:57:24 +1000 Subject: [PATCH 07/18] Fix issue with multiopti plots --- .../sm_covid2/common_school/calibration_plots/opti_plots.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py b/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py index 1e7471653..2aecd9e7f 100644 --- a/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py +++ b/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py @@ -189,9 +189,7 @@ def plot_multiple_model_fits(bcm, params_list, outfile=None): death_ax, rp_ax, sero_ax = axs[0], axs[1], axs[2] # set up the three axes - death_ax.set_ylabel("COVID-19 deaths") - 
targets["infection_deaths_ma7"].plot(style='.', ax=death_ax, label="", zorder=20, color='black') - + death_ax.set_ylabel("COVID-19 deaths") rp_ax.set_ylabel("Random process") if "prop_ever_infected_age_matched" in targets: @@ -228,6 +226,8 @@ def plot_multiple_model_fits(bcm, params_list, outfile=None): color=colors[i] ) + targets["infection_deaths_ma7"].plot(style='.', ax=death_ax, label="", zorder=20, color='black') + # Post plotting processes # death death_ax.legend(loc='best', ncols=2) From d2f3ed09e710d2986ef71d6b0930bd3ed258544d Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Mon, 4 Sep 2023 12:11:04 +1000 Subject: [PATCH 08/18] Estival requirement bump --- requirements/requirements310.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/requirements310.txt b/requirements/requirements310.txt index df0bbffac..98bb4a9d0 100644 --- a/requirements/requirements310.txt +++ b/requirements/requirements310.txt @@ -14,7 +14,7 @@ summerepi==3.6.4 summerepi2==1.2.9 -estival==0.3.4 +estival==0.4.2b0 # Jax for Windows (summer2 requirement) # Linux/OSX already installed via computegraph From 0c9abfd490e1fa9dcfd44d86c7c0a5284d942476 Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Mon, 4 Sep 2023 12:17:55 +1000 Subject: [PATCH 09/18] Adjust _pymc_transform_eps_scale in project code --- autumn/projects/sm_covid2/common_school/calibration.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/autumn/projects/sm_covid2/common_school/calibration.py b/autumn/projects/sm_covid2/common_school/calibration.py index 612cae6b2..40ecd39f7 100644 --- a/autumn/projects/sm_covid2/common_school/calibration.py +++ b/autumn/projects/sm_covid2/common_school/calibration.py @@ -23,8 +23,11 @@ def get_estival_uniform_priors(autumn_priors): assert prior_dict["distribution"] == "uniform", "Only uniform priors are currently supported" if not "random_process.delta_values" in prior_dict["param_name"] and not "dispersion_param" in prior_dict["param_name"]: + p = esp.UniformPrior(prior_dict["param_name"], prior_dict["distri_params"]) + p._pymc_transform_eps_scale = .1 + estival_priors.append( - esp.UniformPrior(prior_dict["param_name"], prior_dict["distri_params"]), + p ) ndelta_values = len([prior_dict for prior_dict in autumn_priors if prior_dict["param_name"].startswith("random_process.delta_values")]) From 0b726b999f54779fbd747a6e316ee6d69d5f7c60 Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Wed, 6 Sep 2023 10:03:23 +1000 Subject: [PATCH 10/18] Prepare code for epsilon comparison experiments --- .../sm_covid2/common_school/calibration.py | 13 +-- .../sm_covid2/common_school/runner_tools.py | 84 ++----------------- 2 files changed, 17 insertions(+), 80 deletions(-) diff --git a/autumn/projects/sm_covid2/common_school/calibration.py b/autumn/projects/sm_covid2/common_school/calibration.py index 40ecd39f7..9e755c6d0 100644 --- a/autumn/projects/sm_covid2/common_school/calibration.py +++ b/autumn/projects/sm_covid2/common_school/calibration.py @@ -17,14 +17,14 @@ ANALYSES_NAMES = ["main", "no_google_mobility", "increased_hh_contacts"] -def get_estival_uniform_priors(autumn_priors): +def get_estival_uniform_priors(autumn_priors, _pymc_transform_eps_scale): estival_priors = [] for prior_dict in autumn_priors: assert prior_dict["distribution"] == "uniform", "Only uniform priors are currently supported" if not "random_process.delta_values" in prior_dict["param_name"] and not "dispersion_param" in prior_dict["param_name"]: p = esp.UniformPrior(prior_dict["param_name"], 
prior_dict["distri_params"]) - p._pymc_transform_eps_scale = .1 + p._pymc_transform_eps_scale = _pymc_transform_eps_scale estival_priors.append( p @@ -49,18 +49,21 @@ def rp_loglikelihood(params): return rp_loglikelihood -def get_bcm_object(iso3, analysis="main"): +def get_bcm_object(iso3, analysis="main", _pymc_transform_eps_scale=.1): assert analysis in ANALYSES_NAMES, "wrong analysis name requested" project = get_school_project(iso3, analysis) death_target_data = project.calibration.targets[0].data + dispersion_prior = esp.UniformPrior("infection_deaths_dispersion_param", (200, 250)) + dispersion_prior._pymc_transform_eps_scale = _pymc_transform_eps_scale + targets = [ est.NegativeBinomialTarget( "infection_deaths_ma7", death_target_data, - dispersion_param=esp.UniformPrior("infection_deaths_dispersion_param", (200, 250)) + dispersion_param=dispersion_prior ) ] if len(project.calibration.targets) > 1: @@ -93,7 +96,7 @@ def censored_func(modelled, data, parameters, time_weights): default_configuration = project.param_set.baseline m = project.build_model(default_configuration.to_dict()) - priors = get_estival_uniform_priors(project.calibration.all_priors) + priors = get_estival_uniform_priors(project.calibration.all_priors, _pymc_transform_eps_scale) default_params = m.builder.get_default_parameters() diff --git a/autumn/projects/sm_covid2/common_school/runner_tools.py b/autumn/projects/sm_covid2/common_school/runner_tools.py index 5a7f01c40..dcea8b489 100644 --- a/autumn/projects/sm_covid2/common_school/runner_tools.py +++ b/autumn/projects/sm_covid2/common_school/runner_tools.py @@ -312,67 +312,23 @@ def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5, def run_full_analysis( iso3, analysis="main", - opti_params={'n_searches': 8, 'n_best_retained': 4, 'num_workers': 8, 'parallel_opti_jobs': 4, 'warmup_iterations': 2000, 'search_iterations': 5000, 'init_method': "LHS"}, + best_param_dicts=[], mcmc_params={'draws': 10000, 'tune': 1000, 'cores': 32, 'chains': 32, 'method': 'DEMetropolis'}, - full_run_params={'samples': 1000, 'burn_in': 5000}, + full_run_params={'burn_in': 25000}, output_folder="test_outputs", - logger=None + logger=None, + _pymc_transform_eps_scale=.1 ): out_path = Path(output_folder) - # Check we requested enough searches and number of requested MCMC chains is a multiple of number of retained optimisation searches - assert opti_params['n_best_retained'] <= opti_params['n_searches'] - assert mcmc_params['chains'] % opti_params['n_best_retained'] == 0 - # Create BayesianCompartmentalModel object - bcm = get_bcm_object(iso3, analysis) + bcm = get_bcm_object(iso3, analysis, _pymc_transform_eps_scale=_pymc_transform_eps_scale) - """ - OPTIMISATION - """ - # Sample optimisation starting points with LHS - if logger: - logger.info("Perform LHS sampling") - if opti_params['init_method'] == "LHS": - sample_as_dicts = sample_with_lhs(opti_params['n_searches'], bcm) - elif opti_params['init_method'] == "midpoint": - sample_as_dicts = [{}] * opti_params['n_searches'] - else: - raise ValueError('init_method optimisation argument not supported') - - # Store starting points - with open(out_path / "LHS_init_points.yml", "w") as f: - yaml.dump(sample_as_dicts, f) - - # Perform optimisation searches - if logger: - logger.info(f"Perform optimisation ({opti_params['n_searches']} searches)") - - def opti_func(sample_dict): - best_p, _ = optimise_model_fit(bcm, num_workers=opti_params['num_workers'], warmup_iterations=opti_params['warmup_iterations'], 
search_iterations=opti_params['search_iterations'], suggested_start=sample_dict) - return best_p - - best_params = map_parallel(opti_func, sample_as_dicts, n_workers=opti_params['parallel_opti_jobs']) - # Store optimal solutions - with open(out_path / "best_params.yml", "w") as f: - yaml.dump(best_params, f) - - # Keep only n_best_retained best solutions - loglikelihoods = [bcm.loglikelihood(**p) for p in best_params] - ll_cutoff = sorted(loglikelihoods, reverse=True)[opti_params['n_best_retained'] - 1] - - retained_init_points, retained_best_params = [], [] - for init_sample, best_p, ll in zip(sample_as_dicts, best_params, loglikelihoods): - if ll >= ll_cutoff: - retained_init_points.append(init_sample) - retained_best_params.append(best_p) + retained_best_params = best_param_dicts # Store retained optimal solutions with open(out_path / "retained_best_params.yml", "w") as f: - yaml.dump(retained_best_params, f) - - # Plot optimal solutions and starting points - plot_opti_params(retained_init_points, retained_best_params, bcm, output_folder) + yaml.dump(retained_best_params, f) # Plot optimal model fits opt_fits_path = out_path / "optimised_fits" @@ -381,9 +337,6 @@ def opti_func(sample_dict): plot_model_fit(bcm, best_p, iso3, opt_fits_path / f"best_fit_{j}.png") plot_multiple_model_fits(bcm, retained_best_params, out_path / "optimal_fits.png") - - if logger: - logger.info("... optimisation completed") # Early return if MCMC not requested if mcmc_params['draws'] == 0: @@ -403,27 +356,8 @@ def opti_func(sample_dict): make_post_mc_plots(idata, full_run_params['burn_in'], output_folder) if logger: logger.info("... MCMC completed") - - """ - Post-MCMC processes - """ - sample_df = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in']) - if logger: - logger.info(f"Perform full runs for {full_run_params['samples']} samples") - outputs_df = run_full_runs(sample_df, iso3, analysis) - if logger: - logger.info("Calculate differential outputs") - diff_outputs_df = calculate_diff_outputs(outputs_df) - if logger: - logger.info("Calculate quantiles") - uncertainty_df, diff_quantiles_df = get_quantile_outputs(outputs_df, diff_outputs_df) - uncertainty_df.to_parquet(out_path / "uncertainty_df.parquet") - diff_quantiles_df.to_parquet(out_path / "diff_quantiles_df.parquet") - - make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder) - - return idata, uncertainty_df, diff_quantiles_df - + + return idata def run_full_analysis_smc( From 8d5213de645820ecd8f46988214310191116c25f Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Wed, 6 Sep 2023 10:17:41 +1000 Subject: [PATCH 11/18] Fix analysis runner function --- autumn/projects/sm_covid2/common_school/runner_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autumn/projects/sm_covid2/common_school/runner_tools.py b/autumn/projects/sm_covid2/common_school/runner_tools.py index dcea8b489..0f6fb0ca3 100644 --- a/autumn/projects/sm_covid2/common_school/runner_tools.py +++ b/autumn/projects/sm_covid2/common_school/runner_tools.py @@ -348,7 +348,7 @@ def run_full_analysis( if logger: logger.info(f"Start MCMC for {mcmc_params['tune']} + {mcmc_params['draws']} iterations and {mcmc_params['chains']} chains...") - n_repeat_seed = int(mcmc_params['chains'] / opti_params['n_best_retained']) + n_repeat_seed = 1 init_vals = [[best_p] * n_repeat_seed for i, best_p in enumerate(retained_best_params)] init_vals = [p_dict for sublist in init_vals for p_dict in sublist] idata = 
sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=mcmc_params['cores'], chains=mcmc_params['chains'], method=mcmc_params['method']) From 46594bb9ec39e93655ea8faab953a05c1dc26c6c Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Thu, 7 Sep 2023 11:21:44 +1000 Subject: [PATCH 12/18] Restore full analysis code after testing --- autumn/projects/sm_covid2/common_school/runner_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autumn/projects/sm_covid2/common_school/runner_tools.py b/autumn/projects/sm_covid2/common_school/runner_tools.py index 0f6fb0ca3..dcea8b489 100644 --- a/autumn/projects/sm_covid2/common_school/runner_tools.py +++ b/autumn/projects/sm_covid2/common_school/runner_tools.py @@ -348,7 +348,7 @@ def run_full_analysis( if logger: logger.info(f"Start MCMC for {mcmc_params['tune']} + {mcmc_params['draws']} iterations and {mcmc_params['chains']} chains...") - n_repeat_seed = 1 + n_repeat_seed = int(mcmc_params['chains'] / opti_params['n_best_retained']) init_vals = [[best_p] * n_repeat_seed for i, best_p in enumerate(retained_best_params)] init_vals = [p_dict for sublist in init_vals for p_dict in sublist] idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=mcmc_params['cores'], chains=mcmc_params['chains'], method=mcmc_params['method']) From 0699a6751f9076adc19ddf173d7dec1b29b9ebad Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Thu, 7 Sep 2023 11:25:37 +1000 Subject: [PATCH 13/18] Restore full analysis function --- .../sm_covid2/common_school/runner_tools.py | 84 +++++++++++++++++-- 1 file changed, 75 insertions(+), 9 deletions(-) diff --git a/autumn/projects/sm_covid2/common_school/runner_tools.py b/autumn/projects/sm_covid2/common_school/runner_tools.py index dcea8b489..5a7f01c40 100644 --- a/autumn/projects/sm_covid2/common_school/runner_tools.py +++ b/autumn/projects/sm_covid2/common_school/runner_tools.py @@ -312,23 +312,67 @@ def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5, def run_full_analysis( iso3, analysis="main", - best_param_dicts=[], + opti_params={'n_searches': 8, 'n_best_retained': 4, 'num_workers': 8, 'parallel_opti_jobs': 4, 'warmup_iterations': 2000, 'search_iterations': 5000, 'init_method': "LHS"}, mcmc_params={'draws': 10000, 'tune': 1000, 'cores': 32, 'chains': 32, 'method': 'DEMetropolis'}, - full_run_params={'burn_in': 25000}, + full_run_params={'samples': 1000, 'burn_in': 5000}, output_folder="test_outputs", - logger=None, - _pymc_transform_eps_scale=.1 + logger=None ): out_path = Path(output_folder) + # Check we requested enough searches and number of requested MCMC chains is a multiple of number of retained optimisation searches + assert opti_params['n_best_retained'] <= opti_params['n_searches'] + assert mcmc_params['chains'] % opti_params['n_best_retained'] == 0 + # Create BayesianCompartmentalModel object - bcm = get_bcm_object(iso3, analysis, _pymc_transform_eps_scale=_pymc_transform_eps_scale) + bcm = get_bcm_object(iso3, analysis) - retained_best_params = best_param_dicts + """ + OPTIMISATION + """ + # Sample optimisation starting points with LHS + if logger: + logger.info("Perform LHS sampling") + if opti_params['init_method'] == "LHS": + sample_as_dicts = sample_with_lhs(opti_params['n_searches'], bcm) + elif opti_params['init_method'] == "midpoint": + sample_as_dicts = [{}] * opti_params['n_searches'] + else: + raise ValueError('init_method optimisation argument not 
supported') + + # Store starting points + with open(out_path / "LHS_init_points.yml", "w") as f: + yaml.dump(sample_as_dicts, f) + + # Perform optimisation searches + if logger: + logger.info(f"Perform optimisation ({opti_params['n_searches']} searches)") + + def opti_func(sample_dict): + best_p, _ = optimise_model_fit(bcm, num_workers=opti_params['num_workers'], warmup_iterations=opti_params['warmup_iterations'], search_iterations=opti_params['search_iterations'], suggested_start=sample_dict) + return best_p + + best_params = map_parallel(opti_func, sample_as_dicts, n_workers=opti_params['parallel_opti_jobs']) + # Store optimal solutions + with open(out_path / "best_params.yml", "w") as f: + yaml.dump(best_params, f) + + # Keep only n_best_retained best solutions + loglikelihoods = [bcm.loglikelihood(**p) for p in best_params] + ll_cutoff = sorted(loglikelihoods, reverse=True)[opti_params['n_best_retained'] - 1] + + retained_init_points, retained_best_params = [], [] + for init_sample, best_p, ll in zip(sample_as_dicts, best_params, loglikelihoods): + if ll >= ll_cutoff: + retained_init_points.append(init_sample) + retained_best_params.append(best_p) # Store retained optimal solutions with open(out_path / "retained_best_params.yml", "w") as f: - yaml.dump(retained_best_params, f) + yaml.dump(retained_best_params, f) + + # Plot optimal solutions and starting points + plot_opti_params(retained_init_points, retained_best_params, bcm, output_folder) # Plot optimal model fits opt_fits_path = out_path / "optimised_fits" @@ -337,6 +381,9 @@ def run_full_analysis( plot_model_fit(bcm, best_p, iso3, opt_fits_path / f"best_fit_{j}.png") plot_multiple_model_fits(bcm, retained_best_params, out_path / "optimal_fits.png") + + if logger: + logger.info("... optimisation completed") # Early return if MCMC not requested if mcmc_params['draws'] == 0: @@ -356,8 +403,27 @@ def run_full_analysis( make_post_mc_plots(idata, full_run_params['burn_in'], output_folder) if logger: logger.info("... 
MCMC completed") - - return idata + + """ + Post-MCMC processes + """ + sample_df = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in']) + if logger: + logger.info(f"Perform full runs for {full_run_params['samples']} samples") + outputs_df = run_full_runs(sample_df, iso3, analysis) + if logger: + logger.info("Calculate differential outputs") + diff_outputs_df = calculate_diff_outputs(outputs_df) + if logger: + logger.info("Calculate quantiles") + uncertainty_df, diff_quantiles_df = get_quantile_outputs(outputs_df, diff_outputs_df) + uncertainty_df.to_parquet(out_path / "uncertainty_df.parquet") + diff_quantiles_df.to_parquet(out_path / "diff_quantiles_df.parquet") + + make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder) + + return idata, uncertainty_df, diff_quantiles_df + def run_full_analysis_smc( From a97f4b925dd4d49d4a7339e6510dae95cca60100 Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Thu, 7 Sep 2023 14:15:32 +1000 Subject: [PATCH 14/18] Clean up full analysis runners --- .../sm_covid2/common_school/runner_tools.py | 104 +++----- .../School_Closure/full_analysis_runner.ipynb | 79 ++---- .../remote/sb_full_analysis.ipynb | 74 ++---- .../remote/sb_full_analysis_smc.ipynb | 228 ------------------ .../School_Closure/remote/sb_mcmc.ipynb | 0 .../remote/sb_multi_full_analysis.ipynb | 90 ++++--- 6 files changed, 119 insertions(+), 456 deletions(-) delete mode 100644 notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis_smc.ipynb delete mode 100644 notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_mcmc.ipynb diff --git a/autumn/projects/sm_covid2/common_school/runner_tools.py b/autumn/projects/sm_covid2/common_school/runner_tools.py index 5a7f01c40..a87f302db 100644 --- a/autumn/projects/sm_covid2/common_school/runner_tools.py +++ b/autumn/projects/sm_covid2/common_school/runner_tools.py @@ -19,6 +19,7 @@ from estival.utils.parallel import map_parallel from autumn.core.runs import ManagedRun +from autumn.infrastructure.remote import springboard from autumn.settings.folders import PROJECTS_PATH from autumn.projects.sm_covid2.common_school.calibration import get_bcm_object @@ -37,6 +38,19 @@ with countries_path.open() as f: INCLUDED_COUNTRIES = yaml.unsafe_load(f) + +DEFAULT_RUN_CONFIG = { + "N_CHAINS": 8, + "N_OPTI_SEARCHES": 16, + "OPTI_BUDGET": 10000, + "METROPOLIS_TUNE": 5000, + "METROPOLIS_DRAWS": 30000, + "METROPOLIS_METHOD": "DEMetropolisZ", + "FULL_RUNS_SAMPLES": 1000, + "BURN_IN": 20000 +} + + """ Functions related to model calibration """ @@ -312,17 +326,14 @@ def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5, def run_full_analysis( iso3, analysis="main", - opti_params={'n_searches': 8, 'n_best_retained': 4, 'num_workers': 8, 'parallel_opti_jobs': 4, 'warmup_iterations': 2000, 'search_iterations': 5000, 'init_method': "LHS"}, - mcmc_params={'draws': 10000, 'tune': 1000, 'cores': 32, 'chains': 32, 'method': 'DEMetropolis'}, - full_run_params={'samples': 1000, 'burn_in': 5000}, + opti_params={'n_searches': 8, 'search_iterations': 5000}, + mcmc_params={'draws': 30000, 'tune': 5000, 'chains': 8, 'method': 'DEMetropolisZ'}, + full_run_params={'samples': 1000, 'burn_in': 20000}, output_folder="test_outputs", logger=None ): out_path = Path(output_folder) - - # Check we requested enough searches and number of requested MCMC chains is a multiple of number of retained optimisation searches - assert opti_params['n_best_retained'] <= 
opti_params['n_searches'] - assert mcmc_params['chains'] % opti_params['n_best_retained'] == 0 + assert mcmc_params['chains'] <= opti_params['n_searches'] # Create BayesianCompartmentalModel object bcm = get_bcm_object(iso3, analysis) @@ -333,13 +344,8 @@ def run_full_analysis( # Sample optimisation starting points with LHS if logger: logger.info("Perform LHS sampling") - if opti_params['init_method'] == "LHS": - sample_as_dicts = sample_with_lhs(opti_params['n_searches'], bcm) - elif opti_params['init_method'] == "midpoint": - sample_as_dicts = [{}] * opti_params['n_searches'] - else: - raise ValueError('init_method optimisation argument not supported') - + sample_as_dicts = sample_with_lhs(opti_params['n_searches'], bcm) + # Store starting points with open(out_path / "LHS_init_points.yml", "w") as f: yaml.dump(sample_as_dicts, f) @@ -347,19 +353,19 @@ def run_full_analysis( # Perform optimisation searches if logger: logger.info(f"Perform optimisation ({opti_params['n_searches']} searches)") - + n_opti_workers = 8 def opti_func(sample_dict): - best_p, _ = optimise_model_fit(bcm, num_workers=opti_params['num_workers'], warmup_iterations=opti_params['warmup_iterations'], search_iterations=opti_params['search_iterations'], suggested_start=sample_dict) + best_p, _ = optimise_model_fit(bcm, num_workers=n_opti_workers, search_iterations=opti_params['search_iterations'], suggested_start=sample_dict) return best_p - best_params = map_parallel(opti_func, sample_as_dicts, n_workers=opti_params['parallel_opti_jobs']) + best_params = map_parallel(opti_func, sample_as_dicts, n_workers=2 * mcmc_params['chains'] / n_opti_workers) # oversubscribing # Store optimal solutions with open(out_path / "best_params.yml", "w") as f: yaml.dump(best_params, f) # Keep only n_best_retained best solutions loglikelihoods = [bcm.loglikelihood(**p) for p in best_params] - ll_cutoff = sorted(loglikelihoods, reverse=True)[opti_params['n_best_retained'] - 1] + ll_cutoff = sorted(loglikelihoods, reverse=True)[mcmc_params['chains'] - 1] retained_init_points, retained_best_params = [], [] for init_sample, best_p, ll in zip(sample_as_dicts, best_params, loglikelihoods): @@ -374,11 +380,11 @@ def opti_func(sample_dict): # Plot optimal solutions and starting points plot_opti_params(retained_init_points, retained_best_params, bcm, output_folder) - # Plot optimal model fits - opt_fits_path = out_path / "optimised_fits" - opt_fits_path.mkdir(exist_ok=True) - for j, best_p in enumerate(retained_best_params): - plot_model_fit(bcm, best_p, iso3, opt_fits_path / f"best_fit_{j}.png") + # # Plot optimal model fits + # opt_fits_path = out_path / "optimised_fits" + # opt_fits_path.mkdir(exist_ok=True) + # for j, best_p in enumerate(retained_best_params): + # plot_model_fit(bcm, best_p, iso3, opt_fits_path / f"best_fit_{j}.png") plot_multiple_model_fits(bcm, retained_best_params, out_path / "optimal_fits.png") @@ -395,57 +401,10 @@ def opti_func(sample_dict): if logger: logger.info(f"Start MCMC for {mcmc_params['tune']} + {mcmc_params['draws']} iterations and {mcmc_params['chains']} chains...") - n_repeat_seed = int(mcmc_params['chains'] / opti_params['n_best_retained']) + n_repeat_seed = 1 init_vals = [[best_p] * n_repeat_seed for i, best_p in enumerate(retained_best_params)] init_vals = [p_dict for sublist in init_vals for p_dict in sublist] - idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=mcmc_params['cores'], chains=mcmc_params['chains'], method=mcmc_params['method']) - 
idata.to_netcdf(out_path / "idata.nc") - make_post_mc_plots(idata, full_run_params['burn_in'], output_folder) - if logger: - logger.info("... MCMC completed") - - """ - Post-MCMC processes - """ - sample_df = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in']) - if logger: - logger.info(f"Perform full runs for {full_run_params['samples']} samples") - outputs_df = run_full_runs(sample_df, iso3, analysis) - if logger: - logger.info("Calculate differential outputs") - diff_outputs_df = calculate_diff_outputs(outputs_df) - if logger: - logger.info("Calculate quantiles") - uncertainty_df, diff_quantiles_df = get_quantile_outputs(outputs_df, diff_outputs_df) - uncertainty_df.to_parquet(out_path / "uncertainty_df.parquet") - diff_quantiles_df.to_parquet(out_path / "diff_quantiles_df.parquet") - - make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder) - - return idata, uncertainty_df, diff_quantiles_df - - - -def run_full_analysis_smc( - iso3, - analysis="main", - smc_params={'draws': 2000, 'cores': 32, 'chains': 4}, - full_run_params={'samples': 1000, 'burn_in': 5000}, - output_folder="test_outputs", - logger=None -): - out_path = Path(output_folder) - - # Create BayesianCompartmentalModel object - bcm = get_bcm_object(iso3, analysis) - - """ - SMC - """ - if logger: - logger.info(f"Start SMC for {smc_params['draws']} draws and {smc_params['chains']} chains...") - - idata = sample_with_pymc_smc(bcm, draws=smc_params['draws'], cores=smc_params['cores'], chains=smc_params['chains']) + idata = sample_with_pymc(bcm, initvals=init_vals, draws=mcmc_params['draws'], tune=mcmc_params['tune'], cores=mcmc_params['chains'], chains=mcmc_params['chains'], method=mcmc_params['method']) idata.to_netcdf(out_path / "idata.nc") make_post_mc_plots(idata, full_run_params['burn_in'], output_folder) if logger: @@ -471,7 +430,6 @@ def run_full_analysis_smc( return idata, uncertainty_df, diff_quantiles_df - """ Helper functions for remote runs """ diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/full_analysis_runner.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/full_analysis_runner.ipynb index 979220923..8518fd74f 100644 --- a/notebooks/user/rragonnet/project_specific/School_Closure/full_analysis_runner.ipynb +++ b/notebooks/user/rragonnet/project_specific/School_Closure/full_analysis_runner.ipynb @@ -6,11 +6,7 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "import pandas as pd\n", - "\n", - "from autumn.projects.sm_covid2.common_school.runner_tools import run_full_analysis\n", - "from autumn.projects.sm_covid2.common_school.output_plots.country_spec import make_country_output_tiling" + "from autumn.projects.sm_covid2.common_school.runner_tools import run_full_analysis, DEFAULT_RUN_CONFIG" ] }, { @@ -22,6 +18,15 @@ "Note that the three returned objects will be dumped automatically into the specified output folder (1 .nc file, 2 .csv files)." 
] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "DEFAULT_RUN_CONFIG" + ] + }, { "cell_type": "code", "execution_count": null, @@ -31,17 +36,16 @@ "ISO3 = \"FRA\"\n", "ANALYSIS = \"main\"\n", "\n", - "N_CHAINS = 8\n", - "N_OPTI_SEARCHES = 2\n", - "N_BEST_RETAINED = 1\n", - "OPTI_BUDGET = 100\n", + "N_CHAINS = 2 # DEFAULT_RUN_CONFIG['N_CHAINS']\n", + "N_OPTI_SEARCHES = 2 # DEFAULT_RUN_CONFIG['N_OPTI_SEARCHES']\n", + "OPTI_BUDGET = 100 # DEFAULT_RUN_CONFIG['OPTI_BUDGET']\n", "\n", - "METROPOLIS_TUNE = 100\n", - "METROPOLIS_DRAWS = 500\n", - "METROPOLIS_METHOD = \"DEMetropolis\"\n", + "METROPOLIS_TUNE = 100 # DEFAULT_RUN_CONFIG['METROPOLIS_TUNE']\n", + "METROPOLIS_DRAWS = 100 # DEFAULT_RUN_CONFIG['METROPOLIS_DRAWS']\n", + "METROPOLIS_METHOD = DEFAULT_RUN_CONFIG['METROPOLIS_METHOD']\n", "\n", - "FULL_RUNS_SAMPLES = 100\n", - "BURN_IN = 50" + "FULL_RUNS_SAMPLES = 100 # DEFAULT_RUN_CONFIG['FULL_RUNS_SAMPLES']\n", + "BURN_IN = 50 # DEFAULT_RUN_CONFIG['BURN_IN']" ] }, { @@ -53,55 +57,12 @@ "idata, uncertainty_df, diff_quantiles_df = run_full_analysis(\n", " ISO3,\n", " analysis=ANALYSIS, \n", - " opti_params={'n_searches': N_OPTI_SEARCHES, \"n_best_retained\": N_BEST_RETAINED, 'num_workers': 8, 'parallel_opti_jobs': 4, 'warmup_iterations': 0, 'search_iterations': OPTI_BUDGET, 'init_method': 'LHS'},\n", - " mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'cores': N_CHAINS, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n", + " opti_params={'n_searches': N_OPTI_SEARCHES, 'search_iterations': OPTI_BUDGET},\n", + " mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n", " full_run_params={'samples': FULL_RUNS_SAMPLES, 'burn_in': BURN_IN},\n", " output_folder=\"test_full\",\n", ")" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Plot analysis outputs" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Load data from previous run if required" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# uncertainty_df = pd.read_parquet(os.path.join(output_folder, \"uncertainty_df.parquet\"))\n", - "# diff_quantiles_df = pd.read_parquet(os.path.join(output_folder, \"diff_quantiles_df.parquet\"))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Make tiling plot" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder)" - ] } ], "metadata": { diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis.ipynb index 5d65b6816..a18d5462b 100644 --- a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis.ipynb +++ b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis.ipynb @@ -8,7 +8,12 @@ "outputs": [], "source": [ "from autumn.infrastructure.remote import springboard\n", - "from autumn.projects.sm_covid2.common_school.runner_tools import run_full_analysis, print_continuous_status, download_analysis" + "from autumn.projects.sm_covid2.common_school.runner_tools import (\n", + " run_full_analysis, \n", + " print_continuous_status, \n", + " download_analysis,\n", 
+ " DEFAULT_RUN_CONFIG,\n", + ")" ] }, { @@ -32,55 +37,33 @@ { "cell_type": "code", "execution_count": null, - "id": "a5f49d54", + "id": "c4aa031b", "metadata": {}, "outputs": [], "source": [ - "ISO3 = \"FRA\"\n", - "ANALYSIS = \"main\"\n", - "\n", - "N_CHAINS = 32\n", - "N_OPTI_SEARCHES = 4\n", - "N_BEST_RETAINED = 1\n", - "OPTI_BUDGET = 10000\n", - "\n", - "METROPOLIS_TUNE = 2000\n", - "METROPOLIS_DRAWS = 10000\n", - "METROPOLIS_METHOD = \"DEMetropolis\"\n", - "\n", - "FULL_RUNS_SAMPLES = 1000\n", - "BURN_IN = 5000" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "253866ab", - "metadata": {}, - "source": [ - "#### Testing config" + "DEFAULT_RUN_CONFIG" ] }, { "cell_type": "code", "execution_count": null, - "id": "ac0a9fdd", + "id": "a5f49d54", "metadata": {}, "outputs": [], "source": [ - "# ISO3 = \"FRA\"\n", - "# ANALYSIS = \"main\"\n", + "ISO3 = \"FRA\"\n", + "ANALYSIS = \"main\"\n", "\n", - "# N_CHAINS = 32\n", - "# N_OPTI_SEARCHES = 8\n", - "# OPTI_BUDGET = 700\n", + "N_CHAINS = DEFAULT_RUN_CONFIG['N_CHAINS']\n", + "N_OPTI_SEARCHES = DEFAULT_RUN_CONFIG['N_OPTI_SEARCHES']\n", + "OPTI_BUDGET = DEFAULT_RUN_CONFIG['OPTI_BUDGET']\n", "\n", - "# METROPOLIS_TUNE = 200\n", - "# METROPOLIS_DRAWS = 1000\n", - "# METROPOLIS_METHOD = \"DEMetropolis\"\n", + "METROPOLIS_TUNE = DEFAULT_RUN_CONFIG['METROPOLIS_TUNE']\n", + "METROPOLIS_DRAWS = DEFAULT_RUN_CONFIG['METROPOLIS_DRAWS']\n", + "METROPOLIS_METHOD = DEFAULT_RUN_CONFIG['METROPOLIS_METHOD']\n", "\n", - "# FULL_RUNS_SAMPLES = 100\n", - "# BURN_IN = 500" + "FULL_RUNS_SAMPLES = DEFAULT_RUN_CONFIG['FULL_RUNS_SAMPLES']\n", + "BURN_IN = DEFAULT_RUN_CONFIG['BURN_IN']" ] }, { @@ -99,7 +82,6 @@ "metadata": {}, "outputs": [], "source": [ - "assert N_CHAINS % N_OPTI_SEARCHES == 0\n", "assert (METROPOLIS_DRAWS - BURN_IN) * N_CHAINS >= FULL_RUNS_SAMPLES\n", "assert METROPOLIS_METHOD in (\"DEMetropolis\", \"DEMetropolisZ\")" ] @@ -107,7 +89,7 @@ { "cell_type": "code", "execution_count": null, - "id": "47d05a73", + "id": "4810f137", "metadata": {}, "outputs": [], "source": [ @@ -121,8 +103,8 @@ " idata, uncertainty_df, diff_quantiles_df = run_full_analysis(\n", " iso3,\n", " analysis=analysis, \n", - " opti_params={'n_searches': N_OPTI_SEARCHES, \"n_best_retained\": N_BEST_RETAINED, 'num_workers': 8, 'parallel_opti_jobs': 8, 'warmup_iterations': 0, 'search_iterations': OPTI_BUDGET, 'init_method': 'LHS'},\n", - " mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'cores': N_CHAINS, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n", + " opti_params={'n_searches': N_OPTI_SEARCHES, 'search_iterations': OPTI_BUDGET},\n", + " mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n", " full_run_params={'samples': FULL_RUNS_SAMPLES, 'burn_in': BURN_IN},\n", " output_folder=bridge.out_path,\n", " logger=bridge.logger\n", @@ -198,7 +180,7 @@ "outputs": [], "source": [ "# wait function with status printing\n", - "print_continuous_status(runner)" + "# print_continuous_status(runner)" ] }, { @@ -208,7 +190,7 @@ "metadata": {}, "outputs": [], "source": [ - "run_path = 'projects/school_project/FRA/2023-08-22T1004-single_start_main_LHS4_opt10000_mc2000n10000'" + "# run_path = 'projects/school_project/FRA/2023-08-22T1004-single_start_main_LHS4_opt10000_mc2000n10000'" ] }, { @@ -220,14 +202,6 @@ "source": [ "download_analysis(run_path)" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "170c3a1f", - "metadata": {}, - "outputs": [], - "source": [] } ], 
"metadata": { diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis_smc.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis_smc.ipynb deleted file mode 100644 index 275d64cd2..000000000 --- a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_full_analysis_smc.ipynb +++ /dev/null @@ -1,228 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "f5408fea", - "metadata": {}, - "outputs": [], - "source": [ - "from autumn.infrastructure.remote import springboard\n", - "from autumn.projects.sm_covid2.common_school.runner_tools import run_full_analysis_smc, print_continuous_status, download_analysis" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "e7255011", - "metadata": {}, - "source": [ - "### Define task function" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "2da27a35", - "metadata": {}, - "source": [ - "#### Standard config" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a5f49d54", - "metadata": {}, - "outputs": [], - "source": [ - "ISO3 = \"FRA\"\n", - "ANALYSIS = \"main\"\n", - "\n", - "N_CHAINS = 8\n", - "N_CORES = 32\n", - "\n", - "METROPOLIS_DRAWS = 2000\n", - "\n", - "FULL_RUNS_SAMPLES = 1000\n", - "BURN_IN = 5000" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "253866ab", - "metadata": {}, - "source": [ - "#### Testing config" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ac0a9fdd", - "metadata": {}, - "outputs": [], - "source": [ - "# ISO3 = \"FRA\"\n", - "# ANALYSIS = \"main\"\n", - "\n", - "# N_CHAINS = 32\n", - "# N_OPTI_SEARCHES = 8\n", - "# OPTI_BUDGET = 700\n", - "\n", - "# METROPOLIS_TUNE = 200\n", - "# METROPOLIS_DRAWS = 1000\n", - "# METROPOLIS_METHOD = \"DEMetropolis\"\n", - "\n", - "# FULL_RUNS_SAMPLES = 100\n", - "# BURN_IN = 500" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "47d05a73", - "metadata": {}, - "outputs": [], - "source": [ - "def remote_full_analysis_smc_task(bridge: springboard.task.TaskBridge, iso3: str = 'FRA', analysis: str = \"main\"):\n", - " \n", - " import multiprocessing as mp\n", - " mp.set_start_method('forkserver')\n", - "\n", - " bridge.logger.info(f\"Running full analysis for {iso3}. 
Analysis type: {analysis}.\")\n", - "\n", - " idata, uncertainty_df, diff_quantiles_df = run_full_analysis_smc(\n", - " iso3,\n", - " analysis=analysis, \n", - " smc_params={'draws': METROPOLIS_DRAWS, 'cores': N_CORES, 'chains': N_CHAINS},\n", - " full_run_params={'samples': FULL_RUNS_SAMPLES, 'burn_in': BURN_IN},\n", - " output_folder=bridge.out_path,\n", - " logger=bridge.logger\n", - " )\n", - " \n", - " bridge.logger.info(\"Full analysis complete\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1a1207a8", - "metadata": {}, - "outputs": [], - "source": [ - "mspec = springboard.EC2MachineSpec(N_CORES, 4, \"compute\")\n", - "task_kwargs = {\n", - " \"iso3\": ISO3,\n", - " \"analysis\": ANALYSIS\n", - "}\n", - "tspec = springboard.TaskSpec(remote_full_analysis_task, task_kwargs)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e12adbec", - "metadata": {}, - "outputs": [], - "source": [ - "analysis_title = \"SMC\"\n", - "config_str = f\"_{ANALYSIS}_mc{METROPOLIS_DRAWS}\"\n", - "\n", - "run_path = springboard.launch.get_autumn_project_run_path(\"school_project\", ISO3, analysis_title + config_str)\n", - "run_path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "16657a60", - "metadata": {}, - "outputs": [], - "source": [ - "runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b416f02f", - "metadata": {}, - "outputs": [], - "source": [ - "runner.s3.get_status()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5dfd7dfe", - "metadata": {}, - "outputs": [], - "source": [ - "print(runner.top(\"%CPU\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bc0faa03", - "metadata": {}, - "outputs": [], - "source": [ - "# wait function with status printing\n", - "print_continuous_status(runner)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f8019dfb", - "metadata": {}, - "outputs": [], - "source": [ - "run_path = 'projects/school_project/FRA/2023-08-22T1004-single_start_main_LHS4_opt10000_mc2000n10000'" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a8d31fe3", - "metadata": {}, - "outputs": [], - "source": [ - "download_analysis(run_path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "170c3a1f", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_mcmc.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_mcmc.ipynb deleted file mode 100644 index e69de29bb..000000000 diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_full_analysis.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_full_analysis.ipynb index ca4a6fdae..45d3eebc6 100644 --- a/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_full_analysis.ipynb +++ 
b/notebooks/user/rragonnet/project_specific/School_Closure/remote/sb_multi_full_analysis.ipynb @@ -2,19 +2,36 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\rrag0004\\.conda\\envs\\summer2\\lib\\site-packages\\summer\\runner\\vectorized_runner.py:363: NumbaDeprecationWarning: \u001b[1mThe 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\u001b[0m\n", + " def get_strain_infection_values(\n", + "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n" + ] + } + ], "source": [ "from autumn.infrastructure.remote import springboard\n", - "from autumn.projects.sm_covid2.common_school.runner_tools import run_full_analysis, print_continuous_status, download_analysis" + "from autumn.projects.sm_covid2.common_school.runner_tools import (\n", + " run_full_analysis, \n", + " print_continuous_status, \n", + " download_analysis,\n", + " DEFAULT_RUN_CONFIG,\n", + ")" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "#### Standard config" + "DEFAULT_RUN_CONFIG" ] }, { @@ -26,17 +43,26 @@ "ISO3_LIST = [\"AUS\", \"BOL\", \"FRA\", \"MAR\", \"PHL\"]\n", "ANALYSIS = \"main\"\n", "\n", - "N_CHAINS = 32\n", - "N_OPTI_SEARCHES = 32\n", - "N_BEST_RETAINED = 16\n", - "OPTI_BUDGET = 10000\n", + "N_CHAINS = DEFAULT_RUN_CONFIG['N_CHAINS']\n", + "N_OPTI_SEARCHES = DEFAULT_RUN_CONFIG['N_OPTI_SEARCHES']\n", + "OPTI_BUDGET = DEFAULT_RUN_CONFIG['OPTI_BUDGET']\n", "\n", - "METROPOLIS_TUNE = 2000\n", - "METROPOLIS_DRAWS = 10000\n", - "METROPOLIS_METHOD = \"DEMetropolis\"\n", + "METROPOLIS_TUNE = DEFAULT_RUN_CONFIG['METROPOLIS_TUNE']\n", + "METROPOLIS_DRAWS = DEFAULT_RUN_CONFIG['METROPOLIS_DRAWS']\n", + "METROPOLIS_METHOD = DEFAULT_RUN_CONFIG['METROPOLIS_METHOD']\n", "\n", - "FULL_RUNS_SAMPLES = 1000\n", - "BURN_IN = 5000" + "FULL_RUNS_SAMPLES = DEFAULT_RUN_CONFIG['FULL_RUNS_SAMPLES']\n", + "BURN_IN = DEFAULT_RUN_CONFIG['BURN_IN']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "assert (METROPOLIS_DRAWS - BURN_IN) * N_CHAINS >= FULL_RUNS_SAMPLES\n", + "assert METROPOLIS_METHOD in (\"DEMetropolis\", \"DEMetropolisZ\")" ] }, { @@ -55,8 +81,8 @@ " idata, uncertainty_df, diff_quantiles_df = run_full_analysis(\n", " iso3,\n", " analysis=analysis, \n", - " opti_params={'n_searches': N_OPTI_SEARCHES, \"n_best_retained\": N_BEST_RETAINED, 'num_workers': 8, 'parallel_opti_jobs': 8, 'warmup_iterations': 0, 'search_iterations': OPTI_BUDGET, 'init_method': 'LHS'},\n", - " mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'cores': N_CHAINS, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n", + " opti_params={'n_searches': N_OPTI_SEARCHES, 'search_iterations': OPTI_BUDGET},\n", + " mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n", " full_run_params={'samples': FULL_RUNS_SAMPLES, 'burn_in': BURN_IN},\n", " output_folder=bridge.out_path,\n", " logger=bridge.logger\n", @@ -117,37 +143,9 @@ "metadata": {}, "outputs": [], "source": [ - 
"runners_dict = {iso3: runners[list(runners.keys())[i]] for i, iso3 in enumerate(ISO3_LIST)}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "runner.instance" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "runners_dict[\"AUS\"].s3.get_status()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for iso3 in ISO3_LIST:\n", - " runner = runners_dict[iso3]\n", + "for run_path, runner in runners:\n", " print(runner.s3.get_status())\n", - " print(runner.top(\"%MEM\"))" + " # print(runner.top(\"%MEM\"))" ] }, { From 7638f3a594fe062a7029086c5311fd87c5c6d8ab Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Thu, 7 Sep 2023 14:19:21 +1000 Subject: [PATCH 15/18] Remove VNM data prior to local transmission --- autumn/projects/sm_covid2/common_school/project_maker.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/autumn/projects/sm_covid2/common_school/project_maker.py b/autumn/projects/sm_covid2/common_school/project_maker.py index 393a6ff2c..21a1e403f 100644 --- a/autumn/projects/sm_covid2/common_school/project_maker.py +++ b/autumn/projects/sm_covid2/common_school/project_maker.py @@ -282,6 +282,8 @@ def get_school_project_timeseries(iso3, sero_data): table_name="owid", conditions={"iso_code": iso3}, columns=["date", "new_deaths"] ) data = remove_death_outliers(iso3, data) + if iso3 == "VNM": # remove early data points associated with few deaths prior to local transmission + data = data[pd.to_datetime(data["date"]) >= "15 May 2021"] # apply moving average data["smoothed_new_deaths"] = data["new_deaths"].rolling(7).mean()[6:] From b8b346f765b3708747c7c855e691236cc72f6f3b Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Thu, 7 Sep 2023 14:35:00 +1000 Subject: [PATCH 16/18] Update params following Angus's review --- autumn/models/sm_covid2/params.yml | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/autumn/models/sm_covid2/params.yml b/autumn/models/sm_covid2/params.yml index 4477f2251..97abfe8aa 100644 --- a/autumn/models/sm_covid2/params.yml +++ b/autumn/models/sm_covid2/params.yml @@ -38,8 +38,8 @@ infectious_seed_time: 60 seed_duration: 7 sojourns: - latent: 6.57 # Pooled estimate from Wu et al. doi:10.1001/jamanetworkopen.2022.28008 - active: 8. + latent: 6.65 # wild-type + active: 6.5 mobility: region: null @@ -150,9 +150,9 @@ is_dynamic_mixing_matrix: true # parameters related to immunity stratification (two strata: unvaccinated / vaccinated) vaccine_effects: - ve_infection: .5 #FIXME: dummy data - ve_hospitalisation: .9 #FIXME: dummy data - ve_death: .9 #FIXME: dummy data + ve_infection: .7 + ve_hospitalisation: .9 + ve_death: .9 # voc_emergence: null voc_emergence: @@ -176,11 +176,11 @@ voc_emergence: new_voc_seed: time_from_gisaid_report: 0. seed_duration: 10. - contact_rate_multiplier: 1.38 # DOI: 10.1503/cmaj.211248 + contact_rate_multiplier: 1.5 incubation_overwrite_value: 4.41 # Wu et al. doi:10.1001/jamanetworkopen.2022.28008 - vacc_immune_escape: .32 # DOI: 10.1503/cmaj.211248 - hosp_risk_adjuster: 2.1 # Fisman et al. doi: 10.1503/cmaj.211248 - death_risk_adjuster: 2.3 # Fisman et al. doi: 10.1503/cmaj.211248 + vacc_immune_escape: .3 + hosp_risk_adjuster: 2.0 + death_risk_adjuster: 2.3 icu_risk_adjuster: 3.4 # Fisman et al. doi: 10.1503/cmaj.211248 cross_protection: wild_type: 1. 
@@ -192,11 +192,11 @@ voc_emergence: new_voc_seed: time_from_gisaid_report: 0. seed_duration: 10. - contact_rate_multiplier: 1.92 # DOI: 10.1503/cmaj.211248 + contact_rate_multiplier: 2. # DOI: 10.1503/cmaj.211248 incubation_overwrite_value: 3.42 # Wu et al. doi:10.1001/jamanetworkopen.2022.28008 - vacc_immune_escape: .55 # DOI: 10.1503/cmaj.211248 - hosp_risk_adjuster: 0.861 # 2.1 * .41 using HR of Omicron vs Delta (Nyberg doi.org/10.1016/S0140-6736(22)00462-7) - death_risk_adjuster: 0.713 # 2.3 * 0.31 using HR of Omicron vs Delta (Nyberg doi.org/10.1016/S0140-6736(22)00462-7) + vacc_immune_escape: .6 # DOI: 10.1503/cmaj.211248 + hosp_risk_adjuster: 0.82 # Multiply Nyberg by Fisman = 0.41*2 = 0.82 + death_risk_adjuster: 0.71 # Multiply Nyberg by Fisman =0.31*2.3=0.713 icu_risk_adjuster: 1. # FIXME dummy value for now but not used at the moment cross_protection: wild_type: 1. @@ -207,7 +207,7 @@ time_from_onset_to_event: # using time of symptom onset as reference hospitalisation: distribution: gamma shape: 5. - mean: 7.7 # Estimates taken from ISARIC report 4th Oct 2020 (Mendeley citation key Pritchard2020} + mean: 3. icu_admission: # using time of hospitalisation as reference distribution: gamma shape: 5. @@ -223,7 +223,7 @@ hospital_stay: hospital_all: distribution: gamma shape: 5. - mean: 12.8 # Estimates taken from ISARIC report 4th Oct 2020 (Mendeley citation key Pritchard2020} + mean: 9 icu: distribution: gamma From 40e8b1f9006a9b162fe19563b7703cf8b1443ab5 Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Thu, 7 Sep 2023 15:33:35 +1000 Subject: [PATCH 17/18] Clean up notebooks --- .../output_plots/country_spec.py | 61 ++ .../School_Closure/analysis_plots.ipynb | 171 ----- .../School_Closure/multistart_opti.ipynb | 157 ----- .../School_Closure/output_plots.ipynb | 584 ------------------ .../School_Closure/post_mcmc_runner.ipynb | 117 ---- 5 files changed, 61 insertions(+), 1029 deletions(-) delete mode 100644 notebooks/user/rragonnet/project_specific/School_Closure/analysis_plots.ipynb delete mode 100644 notebooks/user/rragonnet/project_specific/School_Closure/multistart_opti.ipynb delete mode 100644 notebooks/user/rragonnet/project_specific/School_Closure/output_plots.ipynb delete mode 100644 notebooks/user/rragonnet/project_specific/School_Closure/post_mcmc_runner.ipynb diff --git a/autumn/projects/sm_covid2/common_school/output_plots/country_spec.py b/autumn/projects/sm_covid2/common_school/output_plots/country_spec.py index 1cb7ffa4f..93de9e869 100644 --- a/autumn/projects/sm_covid2/common_school/output_plots/country_spec.py +++ b/autumn/projects/sm_covid2/common_school/output_plots/country_spec.py @@ -430,6 +430,67 @@ def make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_f plt.close() +def plot_incidence_by_age(derived_outputs, ax, scenario, as_proportion: bool): + + colours = ["cornflowerblue", "slateblue", "mediumseagreen", "lightcoral", "purple"] + + update_rcparams() + y_label = "COVID-19 incidence proportion" if as_proportion else "COVID-19 incidence" + + times = derived_outputs["incidence", scenario].index.to_list() + running_total = [0] * len(derived_outputs["incidence", scenario]) + age_groups = base_params['age_groups'] + + y_max = 1. 
if as_proportion else max([derived_outputs["incidence", sc].max() for sc in [0, 1]]) + + for i_age, age_group in enumerate(age_groups): + output_name = f"incidenceXagegroup_{age_group}" + + if i_age < len(age_groups) - 1: + upper_age = age_groups[i_age + 1] - 1 if i_age < len(age_groups) - 1 else "" + age_group_name = f"{age_group}-{upper_age}" + else: + age_group_name = f"{age_group}+" + + age_group_incidence = derived_outputs[output_name, scenario] + + if as_proportion: + numerator, denominator = age_group_incidence, derived_outputs["incidence", scenario] + age_group_proportion = np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator!=0) + new_running_total = age_group_proportion + running_total + else: + new_running_total = age_group_incidence + running_total + + ax.fill_between(times, running_total, new_running_total, color=colours[i_age], label=age_group_name, zorder=2, alpha=.8) + running_total = copy(new_running_total) + + # y_max = max(new_running_total) + plot_ymax = y_max * 1.1 + add_school_closure_patches(ax, ISO3, ymax=plot_ymax) + + # work out first time with positive incidence + t_min = derived_outputs['incidence', 0].gt(0).idxmax() + ax.set_xlim((t_min, model_end)) + ax.set_ylim((0, plot_ymax)) + + ax.set_ylabel(y_label) + + if not as_proportion and scenario == 0: + handles, labels = ax.get_legend_handles_labels() + ax.legend( + reversed(handles), + reversed(labels), + # title="Age:", + # fontsize=12, + # title_fontsize=12, + labelspacing=.2, + handlelength=1., + handletextpad=.5, + columnspacing=1., + facecolor="white", + ncol=2, + + ) def test_tiling_plot(): from pathlib import Path diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/analysis_plots.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/analysis_plots.ipynb deleted file mode 100644 index 1fae05973..000000000 --- a/notebooks/user/rragonnet/project_specific/School_Closure/analysis_plots.ipynb +++ /dev/null @@ -1,171 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from matplotlib import pyplot as plt\n", - "import os\n", - "import pandas as pd \n", - "\n", - "from autumn.projects.sm_covid2.common_school.output_plots.country_spec import (\n", - " plot_model_fit,\n", - " plot_two_scenarios,\n", - " plot_final_size_compare,\n", - " plot_diff_outputs,\n", - " make_country_output_tiling\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "output_folder = \"test_outputs2\"\n", - "uncertainty_df = pd.read_parquet(os.path.join(output_folder, \"uncertainty_df.parquet\"))\n", - "diff_quantiles_df = pd.read_parquet(os.path.join(output_folder, \"diff_quantiles_df.parquet\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "iso3 = \"FRA\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n", - "plot_model_fit(ax, uncertainty_df, \"infection_deaths\", iso3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n", - "plot_model_fit(ax, 
uncertainty_df, \"transformed_random_process\", iso3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n", - "plot_two_scenarios(ax, uncertainty_df, \"infection_deaths\", iso3, True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(8, 4))\n", - "plot_two_scenarios(ax, uncertainty_df, \"hospital_occupancy\", iso3, True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(4, 4))\n", - "plot_final_size_compare(ax, uncertainty_df, \"peak_hospital_occupancy\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(4, 4))\n", - "plot_final_size_compare(ax, uncertainty_df, \"cumulative_infection_deaths\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(4, 4))\n", - "plot_final_size_compare(ax, uncertainty_df, \"cumulative_incidence\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(6, 4))\n", - "plot_diff_outputs(ax, diff_quantiles_df, [\"cases_averted_relative\", \"deaths_averted_relative\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig, ax = plt.subplots(1, 1, figsize=(4, 4))\n", - "plot_diff_outputs(ax, diff_quantiles_df, [\"delta_hospital_peak_relative\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "summer2", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/multistart_opti.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/multistart_opti.ipynb deleted file mode 100644 index f4cadc3aa..000000000 --- a/notebooks/user/rragonnet/project_specific/School_Closure/multistart_opti.ipynb +++ /dev/null @@ -1,157 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from autumn.projects.sm_covid2.common_school.calibration import get_bcm_object, plot_model_fit\n", - "from autumn.projects.sm_covid2.common_school.runner_tools import optimise_model_fit\n", - "from autumn.projects.sm_covid2.common_school.calibration_plots.opti_plots import plot_opti_params\n", - "\n", - "from scipy.stats import qmc" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "iso3 = \"FRA\"\n", - "n_samples = 2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "bcm = get_bcm_object(iso3, \"main\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - 
"lhs_sampled_params = [p for p in bcm.priors if p != \"random_process.delta_values\"] " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "d = len(lhs_sampled_params)\n", - "sampler = qmc.LatinHypercube(d=d)\n", - "regular_sample = sampler.random(n=n_samples)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "l_bounds = [bcm.priors[p].bounds()[0] for p in lhs_sampled_params]\n", - "u_bounds = [bcm.priors[p].bounds()[1] for p in lhs_sampled_params]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sample = qmc.scale(regular_sample, l_bounds, u_bounds)\n", - "sample_as_dicts = [{p: sample[i][j] for j, p in enumerate(lhs_sampled_params)} for i in range(n_samples)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import yaml\n", - "\n", - "with open('test_dump.yml', 'w') as outfile:\n", - " yaml.dump(sample_as_dicts, outfile)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "best_params = {}\n", - "for j, sample_dict in enumerate(sample_as_dicts):\n", - " best_p, _ = optimise_model_fit(bcm, 200, 500, suggested_start=sample_dict)\n", - " best_params[j] = best_p" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with open('test_dump_best.yml', 'w') as outfile:\n", - " yaml.dump(best_params, outfile)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "f = plot_model_fit(bcm, best_params[0], \"name\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig = plot_opti_params(sample_as_dicts, best_params, bcm)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "summer2", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/output_plots.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/output_plots.ipynb deleted file mode 100644 index 652a6e9e2..000000000 --- a/notebooks/user/rragonnet/project_specific/School_Closure/output_plots.ipynb +++ /dev/null @@ -1,584 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from matplotlib import pyplot as plt\n", - "import os\n", - "import numpy as np\n", - "import pandas as pd\n", - "import datetime\n", - "from matplotlib.patches import Rectangle\n", - "import matplotlib.gridspec as gridspec\n", - "from copy import copy\n", - "\n", - "\n", - "from summer.utils import ref_times_to_dti\n", - "from autumn.core.runs.managed import ManagedRun\n", - "from autumn.models.sm_covid import base_params\n", - "from autumn.settings.constants import COVID_BASE_DATETIME\n", - "from autumn.core import inputs\n", - "\n", - "from 
autumn.projects.sm_covid.common_school.project_maker import get_school_project_timeseries\n", - "from notebooks.user.rragonnet.project_specific.School_Closure.plotting_constants import (\n", - " SCHOOL_PROJECT_NOTEBOOK_PATH, \n", - " FIGURE_WIDTH,\n", - " RESOLUTION,\n", - " INCLUDED_COUNTRIES,\n", - " set_up_style\n", - ")\n", - "\n", - "set_up_style()\n", - "output_fig_path = os.path.join(SCHOOL_PROJECT_NOTEBOOK_PATH, \"output_figs\")\n", - "# xx-small, x-small, small, medium, large, x-large, xx-large, larger, or smaller" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def update_rcparams():\n", - " plt.rcParams.update(\n", - " {\n", - " 'font.size': 6,\n", - " 'axes.titlesize': \"large\",\n", - " 'axes.labelsize': \"x-large\",\n", - " 'xtick.labelsize': 'large',\n", - " 'ytick.labelsize': 'large',\n", - " 'legend.fontsize': 'large',\n", - " 'legend.title_fontsize': 'large',\n", - " 'lines.linewidth': 1.,\n", - "\n", - " 'xtick.major.size': 2.5,\n", - " 'xtick.major.width': 0.6,\n", - " 'xtick.major.pad': 2,\n", - "\n", - " 'ytick.major.size': 2.5,\n", - " 'ytick.major.width': 0.6,\n", - " 'ytick.major.pad': 2,\n", - "\n", - " 'axes.labelpad': 2.\n", - " }\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "run_id = \"sm_covid2/france/1665973291/63de5a1\"\n", - "ISO3 = \"FRA\"\n", - "COUNTRY_NAME = \"France\"\n", - "\n", - "mr = ManagedRun(run_id)\n", - "pbi = mr.powerbi.get_db()\n", - "targets = pbi.get_targets()\n", - "results = pbi.get_uncertainty()\n", - "\n", - "model_dates = pbi.get_derived_outputs().index\n", - "model_start, model_end = min(model_dates), max(model_dates)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "title_lookup = {\n", - " \"infection_deaths\": \"COVID-19 deaths\",\n", - " \"cumulative_infection_deaths\": \"Cumulative COVID-19 deaths\",\n", - " \"cumulative_incidence\": \"Cumulative COVID-19 incidence\",\n", - "\n", - " \"hospital_admissions\": \"new daily hospital admissions\",\n", - " \"icu_admissions\": \"new daily admissions to ICU\",\n", - " \"incidence\": \"daily new infections\",\n", - " \"hospital_admissions\": \"daily hospital admissions\",\n", - " \"hospital_occupancy\": \"total hospital beds\",\n", - " \"icu_admissions\": \"daily ICU admissions\",\n", - " \"icu_occupancy\": \"total ICU beds\",\n", - " \"prop_ever_infected\": \"ever infected with Delta or Omicron\",\n", - "\n", - " \"peak_hospital_occupancy\": \"Peak COVID-19 hospital occupancy\"\n", - "}\n", - "sc_colours = [\"black\", \"crimson\"]\n", - "unc_sc_colours = ((0.2, 0.2, 0.8), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2))\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "input_db = inputs.database.get_input_db()\n", - "unesco_data = input_db.query(\n", - " table_name='school_closure', \n", - " columns=[\"date\", \"status\", \"country_id\"],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "SCHOOL_COLORS = {\n", - " 'partial': 'azure',\n", - " 'full': 'thistle'\n", - "}\n", - "\n", - "def add_school_closure_patches(ax, iso3, ymax, school_colors=SCHOOL_COLORS):\n", - " data = unesco_data[unesco_data['country_id'] == iso3]\n", - " partial_dates = data[data['status'] == \"Partially 
open\"]['date'].to_list()\n", - " closed_dates = data[data['status'] == \"Closed due to COVID-19\"]['date'].to_list()\n", - " \n", - " # for date in partial_dates:\n", - " ax.vlines(partial_dates,ymin=0, ymax=ymax, lw=1, alpha=1., color=school_colors['partial'], zorder = 1)\n", - " ax.vlines(closed_dates, ymin=0, ymax=ymax, lw=1, alpha=1, color=school_colors['full'], zorder = 1)\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Model calibration" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "timeseries = get_school_project_timeseries(\"France\")\n", - "all_targets = {}\n", - "for k, v in timeseries.items():\n", - " all_targets[k] = pd.Series(data=v['values'], index=v['times'], name=v['output_key'])\n", - "for target in all_targets:\n", - " all_targets[target].index = ref_times_to_dti(COVID_BASE_DATETIME, all_targets[target].index)\n", - "\n", - " all_targets[target] = all_targets[target][model_start <= all_targets[target].index][all_targets[target].index <= model_end]\n", - "\n", - "\n", - "def plot_model_fit(axis, results, output_name):\n", - " update_rcparams() \n", - " \n", - " if output_name in all_targets and len(all_targets[output_name]) > 0:\n", - " axis.scatter(all_targets[output_name].index, all_targets[output_name], marker=\".\", color='black', label='observations', zorder=11, s=.5)\n", - "\n", - " colour = unc_sc_colours[0]\n", - " \n", - " results_df = results[(output_name, 0)]\n", - " indices = results_df.index\n", - "\n", - " axis.plot(indices, results_df[0.500], color=colour, zorder=10, label=\"model (median)\")\n", - "\n", - " axis.fill_between(\n", - " indices, \n", - " results_df[0.25], results_df[0.75], \n", - " color=colour, \n", - " alpha=0.5, \n", - " edgecolor=None,\n", - " label=\"model (IQR)\"\n", - " )\n", - " axis.fill_between(\n", - " indices, \n", - " results_df[0.025], results_df[0.975], \n", - " color=colour, \n", - " alpha=0.3,\n", - " edgecolor=None,\n", - " label=\"model (95% CI)\",\n", - " )\n", - " # axis.tick_params(axis=\"x\", labelrotation=45)\n", - " title = output_name if output_name not in title_lookup else title_lookup[output_name]\n", - " axis.set_ylabel(title)\n", - " axis.set_xlim((model_start, model_end))\n", - " # plt.tight_layout()\n", - "\n", - " plt.legend(markerscale=2.)\n", - "\n", - " # axis.set_ylim((0, 1500))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# for output_name in [\"abs_diff_cumulative_infection_deaths\", \"infection_deaths\", \"cumulative_infection_deaths\"]:\n", - "# fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH *.7))\n", - "# plot_model_fit(axis, results, output_name)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Scenario comparison over time" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "def plot_two_scenarios(axis, results, output_name, include_unc=False):\n", - " update_rcparams()\n", - "\n", - " ymax = 0.\n", - " for scenario in [0, 1]:\n", - " colour = unc_sc_colours[scenario]\n", - " results_df = results[(output_name, scenario)]\n", - " indices = results_df.index\n", - " label = \"baseline\" if scenario == 0 else \"schools open\"\n", - " scenario_zorder = 10 if scenario == 0 else scenario + 2\n", - "\n", - " if include_unc:\n", - " axis.fill_between(\n", - " indices, \n", - " results_df[0.25], results_df[0.75], 
\n", - " color=colour, alpha=0.7, \n", - " # label=interval_label,\n", - " zorder=scenario_zorder\n", - " )\n", - " ymax = max(ymax, max(results_df[0.75]))\n", - " else:\n", - " ymax = max(ymax, max(results_df[0.500]))\n", - "\n", - " axis.plot(indices, results_df[0.500], color=colour, label=label, lw=1.)\n", - " \n", - " plot_ymax = ymax * 1.1 \n", - " add_school_closure_patches(axis, ISO3, ymax=plot_ymax)\n", - "\n", - " # axis.tick_params(axis=\"x\", labelrotation=45)\n", - " title = output_name if output_name not in title_lookup else title_lookup[output_name]\n", - " axis.set_ylabel(title)\n", - " axis.set_xlim((model_start, model_end))\n", - " axis.set_ylim((0, plot_ymax))\n", - " axis.legend()\n", - " # plt.tight_layout()\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# for output_name in [\"infection_deaths\", \"cumulative_infection_deaths\", \"missed_school_death_ratio\"]:\n", - "# fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH *.7))\n", - "# plt.style.use(\"ggplot\")\n", - "# plot_two_scenarios(axis,results, output_name, True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Scenario comparison final size" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def plot_final_size_compare(axis, results, output_name):\n", - " update_rcparams()\n", - " # plt.rcParams.update({'font.size': 12}) \n", - " box_width = .7\n", - " color = 'black'\n", - " box_color= 'lightcoral'\n", - " y_max = 0\n", - " for i, label in enumerate([\"baseline\", \"schools open\"]):\n", - " quantiles = results[(output_name, i)].iloc[-1]\n", - " x = 1 + i\n", - "\n", - " # median\n", - " axis.hlines(y=quantiles[0.5], xmin=x - box_width / 2. 
, xmax= x + box_width / 2., lw=1., color=color, zorder=3) \n", - " \n", - " # IQR\n", - " height = quantiles[0.75] - quantiles[0.25]\n", - " rect = Rectangle(xy=(x - box_width / 2., quantiles[0.25]), width=box_width, height=height, zorder=2, facecolor=box_color)\n", - " axis.add_patch(rect)\n", - "\n", - " # 95% CI\n", - " axis.vlines(x=x, ymin=quantiles[0.025] , ymax=quantiles[0.975], lw=.7, color=color, zorder=1)\n", - "\n", - " y_max = max(y_max, quantiles[0.975])\n", - " title = output_name if output_name not in title_lookup else title_lookup[output_name]\n", - " axis.set_ylabel(title)\n", - " axis.set_xticks(ticks=[1, 2], labels=[\"baseline\", \"schools open\"]) #, fontsize=15)\n", - "\n", - " axis.set_xlim((0., 3.))\n", - " axis.set_ylim((0, y_max * 1.2))\n", - " # plt.tight_layout()\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# for output_name in [\"cumulative_infection_deaths\", \"cumulative_incidence\", \"peak_hospital_occupancy\"]: \n", - "# fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH * .6 , FIGURE_WIDTH *.7))\n", - "# plot_final_size_compare(axis, results, output_name)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Age-specific incidence" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# colours = [\"cornflowerblue\", \"darkorange\", \"mediumseagreen\", \"pink\", \"purple\"]\n", - "colours = [\"cornflowerblue\", \"slateblue\", \"mediumseagreen\", \"lightcoral\", \"purple\"]\n", - "\n", - "\n", - "def plot_incidence_by_age(derived_outputs, ax, scenario, as_proportion: bool):\n", - " update_rcparams()\n", - " y_label = \"COVID-19 incidence proportion\" if as_proportion else \"COVID-19 incidence\" \n", - "\n", - " times = derived_outputs[\"incidence\", scenario].index.to_list()\n", - " running_total = [0] * len(derived_outputs[\"incidence\", scenario])\n", - " age_groups = base_params['age_groups']\n", - "\n", - " y_max = 1. 
if as_proportion else max([derived_outputs[\"incidence\", sc].max() for sc in [0, 1]])\n", - "\n", - " for i_age, age_group in enumerate(age_groups):\n", - " output_name = f\"incidenceXagegroup_{age_group}\"\n", - " \n", - " if i_age < len(age_groups) - 1:\n", - " upper_age = age_groups[i_age + 1] - 1 if i_age < len(age_groups) - 1 else \"\"\n", - " age_group_name = f\"{age_group}-{upper_age}\"\n", - " else:\n", - " age_group_name = f\"{age_group}+\"\n", - "\n", - " age_group_incidence = derived_outputs[output_name, scenario]\n", - " \n", - " if as_proportion:\n", - " numerator, denominator = age_group_incidence, derived_outputs[\"incidence\", scenario]\n", - " age_group_proportion = np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator!=0)\n", - " new_running_total = age_group_proportion + running_total\n", - " else: \n", - " new_running_total = age_group_incidence + running_total \n", - "\n", - " ax.fill_between(times, running_total, new_running_total, color=colours[i_age], label=age_group_name, zorder=2, alpha=.8)\n", - " running_total = copy(new_running_total)\n", - "\n", - " # y_max = max(new_running_total)\n", - " plot_ymax = y_max * 1.1\n", - " add_school_closure_patches(ax, ISO3, ymax=plot_ymax)\n", - "\n", - " # work out first time with positive incidence\n", - " t_min = derived_outputs['incidence', 0].gt(0).idxmax() \n", - " ax.set_xlim((t_min, model_end))\n", - " ax.set_ylim((0, plot_ymax))\n", - "\n", - " ax.set_ylabel(y_label)\n", - "\n", - " if not as_proportion and scenario == 0:\n", - " handles, labels = ax.get_legend_handles_labels()\n", - " ax.legend(\n", - " reversed(handles),\n", - " reversed(labels),\n", - " # title=\"Age:\",\n", - " # fontsize=12,\n", - " # title_fontsize=12,\n", - " labelspacing=.2,\n", - " handlelength=1.,\n", - " handletextpad=.5,\n", - " columnspacing=1.,\n", - " facecolor=\"white\",\n", - " ncol=2,\n", - "\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# derived_outputs = pbi.get_derived_outputs()\n", - "\n", - "# fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH *.5))\n", - "# plot_incidence_by_age(derived_outputs, axis, 0, as_proportion=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH *.5))\n", - "# plot_incidence_by_age(derived_outputs, axis, 1, as_proportion=True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Make combined multi-panel figure" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "update_rcparams()\n", - "\n", - "fig = plt.figure(figsize=(8.3, 11.7), dpi=300) # crete an A4 figure\n", - "outer = gridspec.GridSpec(\n", - " 3, 1, hspace=.1, height_ratios=(3, 62, 35), \n", - " left=0.07, right=0.97, bottom=0.03, top =.97 # this affects the outer margins of the saved figure \n", - ")\n", - "\n", - "#### Top row with country name\n", - "ax1 = fig.add_subplot(outer[0, 0])\n", - "t = ax1.text(0.5,0.5, COUNTRY_NAME, fontsize=16)\n", - "t.set_ha('center')\n", - "t.set_va('center')\n", - "ax1.set_xticks([])\n", - "ax1.set_yticks([])\n", - "\n", - "#### Second row will need to be split\n", - "outer_cell = outer[1, 0]\n", - "# first split in left/right panels\n", - "inner_grid = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=outer_cell, wspace=.2, width_ratios=(70, 30))\n", - 
"left_grid = inner_grid[0, 0] # will contain timeseries plots\n", - "right_grid = inner_grid[0, 1] # will contain final size plots\n", - "\n", - "#### Split left panel into 3 panels\n", - "inner_left_grid = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=left_grid, hspace=.05, height_ratios=(1, 1, 1))\n", - "# calibration\n", - "ax2 = fig.add_subplot(inner_left_grid[0, 0])\n", - "plot_model_fit(ax2, results,\"infection_deaths\")\n", - "plt.setp(ax2.get_xticklabels(), visible=False)\n", - "# scenario compare deaths\n", - "ax3 = fig.add_subplot(inner_left_grid[1, 0], sharex=ax2)\n", - "plot_model_fit(ax3, results,\"prop_ever_infected\") \n", - "# plot_two_scenarios(ax3, results, \"infection_deaths\", True)\n", - "plt.setp(ax3.get_xticklabels(), visible=False)\n", - "# scenario compare hosp\n", - "ax4 = fig.add_subplot(inner_left_grid[2, 0], sharex=ax2)\n", - "# plot_two_scenarios(ax4, results, \"hospital_occupancy\", False)\n", - "plot_two_scenarios(ax4, results, \"cumulative_infection_deaths\", True)\n", - "\n", - "\n", - "## Split right panel into 3 panels\n", - "inner_right_grid = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=right_grid, hspace=.1, height_ratios=(1, 1, 1))\n", - "# final size deaths\n", - "ax5 = fig.add_subplot(inner_right_grid[0, 0])\n", - "plot_final_size_compare(ax5, results, \"cumulative_infection_deaths\")\n", - "# final size incidence\n", - "ax6 = fig.add_subplot(inner_right_grid[1, 0])\n", - "plot_final_size_compare(ax6, results, \"cumulative_incidence\")\n", - "# # hosp peak\n", - "ax7 = fig.add_subplot(inner_right_grid[2, 0])\n", - "plot_final_size_compare(ax7, results, \"peak_hospital_occupancy\")\n", - "\n", - "#### Third row will need to be split into 6 panels\n", - "derived_outputs = pbi.get_derived_outputs()\n", - "outer_cell = outer[2, 0]\n", - "inner_grid = gridspec.GridSpecFromSubplotSpec(3, 2, subplot_spec=outer_cell, wspace=.2, hspace=.05, width_ratios=(50, 50), height_ratios=(6, 42, 42))\n", - "\n", - "# top left\n", - "ax_tl = fig.add_subplot(inner_grid[0, 0])\n", - "t = ax_tl.text(0.5,0.5, \"Age-specific incidence (baseline scenario)\", fontsize=12)\n", - "t.set_ha('center')\n", - "t.set_va('center')\n", - "ax_tl.set_xticks([])\n", - "ax_tl.set_yticks([])\n", - "\n", - "# top right\n", - "ax_tr = fig.add_subplot(inner_grid[0, 1])\n", - "t = ax_tr.text(0.5,0.5, \"Age-specific incidence (schools open)\", fontsize=12)\n", - "t.set_ha('center')\n", - "t.set_va('center')\n", - "ax_tr.set_xticks([])\n", - "ax_tr.set_yticks([])\n", - "\n", - "# middle left\n", - "ax8 = fig.add_subplot(inner_grid[1, 0])\n", - "plot_incidence_by_age(derived_outputs, ax8, 0, as_proportion=False)\n", - "plt.setp(ax8.get_xticklabels(), visible=False)\n", - "\n", - "# middle right\n", - "ax9 = fig.add_subplot(inner_grid[1, 1])\n", - "plot_incidence_by_age(derived_outputs, ax9, 1, as_proportion=False)\n", - "plt.setp(ax9.get_xticklabels(), visible=False)\n", - "\n", - "# bottom left\n", - "ax10 = fig.add_subplot(inner_grid[2, 0], sharex=ax8)\n", - "plot_incidence_by_age(derived_outputs, ax10, 0, as_proportion=True)\n", - "# bottom right\n", - "ax11 = fig.add_subplot(inner_grid[2, 1], sharex=ax9)\n", - "plot_incidence_by_age(derived_outputs, ax11, 1, as_proportion=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "fig.savefig(\"out.png\", facecolor=\"white\")\n", - "fig.savefig(\"out.pdf\", facecolor=\"white\")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.10.3 
('autumn310')", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.3" - }, - "vscode": { - "interpreter": { - "hash": "7afc08b952f75bca94590012dd49682c815a0fa68720c270ce23d7ae27bf110a" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/post_mcmc_runner.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/post_mcmc_runner.ipynb deleted file mode 100644 index 209e1bf05..000000000 --- a/notebooks/user/rragonnet/project_specific/School_Closure/post_mcmc_runner.ipynb +++ /dev/null @@ -1,117 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.\n", - "c:\\Users\\rrag0004\\.conda\\envs\\summer2\\lib\\site-packages\\summer\\runner\\vectorized_runner.py:363: NumbaDeprecationWarning: \u001b[1mThe 'nopython' keyword argument was not supplied to the 'numba.jit' decorator. The implicit default value for this argument is currently False, but it will be changed to True in Numba 0.59.0. See https://numba.readthedocs.io/en/stable/reference/deprecation.html#deprecation-of-object-mode-fall-back-behaviour-when-using-jit for details.\u001b[0m\n", - " def get_strain_infection_values(\n" - ] - } - ], - "source": [ - "from autumn.projects.sm_covid2.common_school.runner_tools import (\n", - " get_sampled_results, \n", - " extract_sample_subset, \n", - " run_full_runs, \n", - " calculate_diff_outputs,\n", - " get_quantile_outputs\n", - ")\n", - "\n", - "import arviz as az\n", - "import os" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "calibration_file = os.path.join(\"test_outputs2\", \"idata.nc\")\n", - "idata = az.from_netcdf(calibration_file)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "sample_df = extract_sample_subset(idata, 100, 0)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "outputs_df = run_full_runs(sample_df, \"FRA\", \"main\")" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "diff_outputs_df = calculate_diff_outputs(outputs_df)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "uncertainty_df, diff_quantiles_df = get_quantile_outputs(outputs_df, diff_outputs_df)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "uncertainty_df.to_parquet(os.path.join(\"test_outputs2\", \"uncertainty_df.parquet\"))\n", - "diff_quantiles_df.to_parquet(os.path.join(\"test_outputs2\", \"diff_quantiles_df.parquet\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "summer2", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - 
"nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.11" - }, - "orig_nbformat": 4 - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 64a1915d30481d2f23ee0d2e6b7102008e08e903 Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Thu, 7 Sep 2023 16:57:47 +1000 Subject: [PATCH 18/18] Notebook for mcmc transforms --- notebooks/user/rragonnet/mcmc_transform.ipynb | 107 ++++++++++++++++++ 1 file changed, 107 insertions(+) create mode 100644 notebooks/user/rragonnet/mcmc_transform.ipynb diff --git a/notebooks/user/rragonnet/mcmc_transform.ipynb b/notebooks/user/rragonnet/mcmc_transform.ipynb new file mode 100644 index 000000000..6da818193 --- /dev/null +++ b/notebooks/user/rragonnet/mcmc_transform.ipynb @@ -0,0 +1,107 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from matplotlib import pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Prior bounds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a, b = 0, 1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Default transformation (standard with pymc, stan...)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_default_transform():\n", + " default_transform = lambda x: np.log(x - a) - np.log(b - x)\n", + "\n", + " x = np.linspace(a, b, num=1000)[1:-1] \n", + " y = default_transform(x)\n", + " plt.plot(x, y, label=\"Default transform\")\n", + " plt.xlim((a - .5 * (b-a), b + .5 * (b-a)))\n", + " plt.xlabel(\"Original parameter\")\n", + " plt.ylabel(\"Transformed parameter\")\n", + " plt.vlines(x=[a, b], ymin=plt.gca().get_ylim()[0], ymax=plt.gca().get_ylim()[1], linestyles=[\"--\", \"--\"], color='black')\n", + "\n", + "\n", + "plot_default_transform()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tweaked transformation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def plot_tweaked_transform(eps=.1):\n", + " a_prime = a - eps * (b - a)\n", + " b_prime = b + eps * (b - a)\n", + " tweaked_transform = lambda x: np.log(x - a_prime) - np.log(b_prime - x)\n", + "\n", + " x = np.linspace(a_prime, b_prime, num=1000)[1:-1] \n", + " y = tweaked_transform(x)\n", + " plot_default_transform()\n", + " plt.plot(x, y, label=\"Tweaked transform\")\n", + " plt.legend()\n", + "\n", + "plot_tweaked_transform(.1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "summer2", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +}