Skip to content

Commit

Permalink
cum_inc by age and get_derived_outputs
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-ragonnet committed Nov 30, 2023
1 parent 20c2dd8 commit 40da43b
Show file tree
Hide file tree
Showing 3 changed files with 216 additions and 44 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,10 @@ def add_variant_emergence(ax, iso3):
d = get_first_variant_report_date(voc_name, iso3)
ax.vlines(x=d, ymin=plot_ymin, ymax=plot_ymax, linestyle=linestyles[voc_name], color="grey")

AGE_COLOURS = ["cornflowerblue", "slateblue", "mediumseagreen", "lightcoral", "purple"]

def _plot_incidence_by_age(derived_outputs, ax, scenario, as_proportion: bool, legend=False):

colours = ["cornflowerblue", "slateblue", "mediumseagreen", "lightcoral", "purple"]

y_label = "Incidence prop." if as_proportion else "Daily infections"

scenario_name = {"baseline": "historical", "scenario_1": "counterfactual"}
Expand Down Expand Up @@ -92,7 +92,7 @@ def _plot_incidence_by_age(derived_outputs, ax, scenario, as_proportion: bool, l
else:
new_running_total = age_group_incidence + running_total

ax.fill_between(times, running_total, new_running_total, color=colours[i_age], label=age_group_name, zorder=2, alpha=.8)
ax.fill_between(times, running_total, new_running_total, color=AGE_COLOURS[i_age], label=age_group_name, zorder=2, alpha=.8)
running_total = copy(new_running_total)

y_max = max(new_running_total)
Expand Down Expand Up @@ -125,8 +125,8 @@ def _plot_incidence_by_age(derived_outputs, ax, scenario, as_proportion: bool, l


def plot_inc_by_strain(derived_outputs, ax, as_prop=False, legend=False):

y_label = "Infection proportion" if as_prop else "N infections"
y_label = "Infection proportion" if as_prop else "N infections by strain"

output_name = "cumulative_incidence_prop" if as_prop else "cumulative_incidence"
strain_data = {s: [derived_outputs[sc][f"{output_name}Xstrain_{s}"].iloc[-1] for sc in ["baseline", "scenario_1"]] for s in ["wild_type", "delta", "omicron"]}
Expand Down Expand Up @@ -157,6 +157,86 @@ def plot_inc_by_strain(derived_outputs, ax, as_prop=False, legend=False):
ax.get_legend().remove()



def cum_over50_fmt(tick_val):
if tick_val >= 1000000000:
val = round(tick_val/1000000000, 1)
if val.is_integer():
val = int(val)
return f"{val}G"
elif tick_val >= 1000000:
val = round(tick_val/1000000, 1)
if val.is_integer():
val = int(val)
return f"{val}M"
elif tick_val >= 1000:
val = round(tick_val / 1000, 1)
if val.is_integer():
val = int(val)
return f"{val}K"
elif 0. < tick_val < 1.:
return round(tick_val, 2)
else:
val = tick_val
if val.is_integer():
val = int(val)
return val


def plot_cum_incidence_by_age(derived_outputs, ax, legend=False):
y_label = "N infections by age"
age_groups = [0, 15, 25, 50, 70]

age_inc_data = {str(agegroup): [derived_outputs[sc][f"incidenceXagegroup_{agegroup}"].sum() for sc in ["baseline", "scenario_1"]] for agegroup in age_groups}

df = pd.DataFrame(age_inc_data, index = ["Historical", "Counterfactual"])

column_names, colors = {}, {}
for i_age, age_group in enumerate(age_groups):

if i_age < len(age_groups) - 1:
upper_age = age_groups[i_age + 1] - 1 if i_age < len(age_groups) - 1 else ""
c_name = f"{age_group}-{upper_age}"
else:
c_name = f"{age_group}+"

column_names[str(age_group)] = c_name
colors[c_name] = AGE_COLOURS[i_age]

df = df.rename(columns=column_names)

df.plot.bar(
stacked=True,
ax=ax,
color=colors,
alpha=.8,
rot=0
)

ax.set_ylabel(y_label)

for i_sc, sc in enumerate(["baseline", "scenario_1"]):
total_inc = derived_outputs[sc]["incidence"].sum()
n_under_50 = (derived_outputs[sc][f"incidenceXagegroup_0"] + derived_outputs[sc][f"incidenceXagegroup_15"] + derived_outputs[sc][f"incidenceXagegroup_25"]).sum()
n_over_50 = total_inc - n_under_50

ax.plot([i_sc + .28, i_sc + .28], [n_under_50, total_inc], marker='_', lw=0, color='black', ms=3.)
ax.text(x=i_sc + .34, y=0.5 * (n_under_50 + total_inc), s=cum_over50_fmt(n_over_50), fontsize=7, rotation=90, va='center')

if legend:
ax.legend(
labelspacing=.2,
handlelength=1.,
handletextpad=.5,
columnspacing=1.,
facecolor="white",
ncol=2,
# bbox_to_anchor=(1.05, 1.05)
)
else:
ax.get_legend().remove()


def _plot_two_scenarios(axis, uncertainty_dfs, output_name, iso3, include_unc=False, include_legend=True):

ymax = 0.
Expand Down Expand Up @@ -306,11 +386,15 @@ def make_country_highlight_figure(iso3, uncertainty_dfs, diff_quantiles_df, deri
# Top Left: deaths fit
death_fit_ax = fig.add_subplot(inner_grid[0, 0])
plot_model_fit_with_uncertainty(death_fit_ax, uncertainty_dfs['baseline'], "infection_deaths_ma7", iso3)
add_variant_emergence(death_fit_ax, iso3)
add_variant_emergence(death_fit_ax, iso3)
format_date_axis(death_fit_ax)
remove_axes_box(death_fit_ax)
ad_panel_number(death_fit_ax, "A")

# if iso3 in INCLUDED_COUNTRIES['national_sero']:
# insert_inside_plot_for_sero(death_fit_ax, uncertainty_dfs['baseline'], iso3)


# Middle Left: incidence by age
inner_inner_grid = gridspec.GridSpecFromSubplotSpec(2, 1, subplot_spec=inner_grid[1, 0], hspace=.05)
age_inc_baseline_ax = fig.add_subplot(inner_inner_grid[0, 0])
Expand All @@ -331,14 +415,15 @@ def make_country_highlight_figure(iso3, uncertainty_dfs, diff_quantiles_df, deri

# Bottom Left: Inc prop by strain
inner_inner_grid = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=inner_grid[2, 0], wspace=.5)
for i_as_prop, as_prop in enumerate([False, True]):
inc_prop_strain_ax = fig.add_subplot(inner_inner_grid[0, i_as_prop])
plot_inc_by_strain(derived_outputs, inc_prop_strain_ax, as_prop, legend=as_prop)

if not as_prop:
inc_prop_strain_ax.yaxis.set_major_formatter(tick.FuncFormatter(y_fmt))
ad_panel_number(inc_prop_strain_ax, "C", x=-0.25)
inc_prop_age_ax = fig.add_subplot(inner_inner_grid[0, 0])
plot_cum_incidence_by_age(derived_outputs, inc_prop_age_ax)
inc_prop_age_ax.yaxis.set_major_formatter(tick.FuncFormatter(y_fmt))
ad_panel_number(inc_prop_age_ax, "C", x=-0.25)

inc_prop_strain_ax = fig.add_subplot(inner_inner_grid[0, 1])
plot_inc_by_strain(derived_outputs, inc_prop_strain_ax, False, legend=True)
inc_prop_strain_ax.yaxis.set_major_formatter(tick.FuncFormatter(y_fmt))

# MIDDLE Column
outer_cell = outer[0, 1]
Expand Down
50 changes: 18 additions & 32 deletions user/rragonnet/remote_run_outputs/country_highlights.ipynb

Large diffs are not rendered by default.

101 changes: 101 additions & 0 deletions user/rragonnet/remote_run_outputs/get_derived_outputs.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from pathlib import Path\n",
"from autumn.projects.sm_covid2.common_school.calibration import get_bcm_object\n",
"from estival.sampling import tools as esamp\n",
"import yaml\n",
"from autumn.projects.sm_covid2.common_school.runner_tools import INCLUDED_COUNTRIES\n",
"import pickle\n",
"\n",
"full_iso3_list = list(INCLUDED_COUNTRIES['all'].keys())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"analysis_folders = {\n",
" \"main\": Path.cwd() / \"31747883_full_analysis_26Sep2023_main\",\n",
" \"increased_hh_contacts\": Path.cwd() /\"31902886_full_analysis_05Oct2023_increased_hh_contacts\",\n",
" \"no_google_mobility\": Path.cwd() /\"31915437_full_analysis_05Oct2023_no_google_mobility\"\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_mle_derived_outputs(iso3, analysis, save=True):\n",
" best_params_path = analysis_folders[analysis] / iso3 / \"retained_best_params.yml\"\n",
" with open(best_params_path, \"r\") as f:\n",
" best_params = yaml.unsafe_load(f)\n",
" mle_params = best_params[0]\n",
" \n",
" derived_outputs = {}\n",
" for sc in [\"baseline\", \"scenario_1\"]:\n",
" bcm = get_bcm_object(iso3, analysis=analysis, scenario=sc)\n",
" res = esamp.model_results_for_samples([mle_params], bcm)\n",
" derived_outputs[sc] = res.results.xs(0, level=\"sample\", axis=1)\n",
"\n",
" if save:\n",
" with open(analysis_folders[analysis] / iso3 / \"derived_outputs.pickle\", \"wb\") as f:\n",
" pickle.dump(derived_outputs, f)\n",
"\n",
" return derived_outputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for analysis in analysis_folders:\n",
" for iso3 in full_iso3_list:\n",
" print(f\"{analysis}: {iso3}\")\n",
" derived_outputs = get_mle_derived_outputs(iso3, analysis)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dos = pd.read_pickle(analysis_folders['main'] / 'FRA' / \"derived_outputs.pickle\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "summer2",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 40da43b

Please sign in to comment.