Skip to content

Commit

Permalink
Merge branch 'master' of github.com:monash-emu/AuTuMN
Browse files Browse the repository at this point in the history
  • Loading branch information
dshipman committed Sep 8, 2023
2 parents 76a9710 + 64a1915 commit fd5ea8b
Show file tree
Hide file tree
Showing 23 changed files with 596 additions and 1,552 deletions.
1 change: 1 addition & 0 deletions autumn/models/sm_covid2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,7 @@ def build_model(params: dict, build_options: dict = None, ret_builder=False) ->

if params.activate_random_process:
outputs_builder.request_random_process_outputs()
outputs_builder.request_random_process_auc()

# request extra output to store the number of students*weeks of school missed
outputs_builder.request_student_weeks_missed_output(student_weeks_missed)
Expand Down
15 changes: 15 additions & 0 deletions autumn/models/sm_covid2/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,21 @@ def array_max(x):
func=peak_func
)

def request_random_process_auc(self):
"""
Create an output to calculate the area between the (transformed) random process and the horizontal line y=1.
"""

def sum_diffs_to_one(x):
return jnp.repeat(jnp.sum(jnp.abs(1 - x)), jnp.size(x))

sum_diffs_to_one_func = Function(sum_diffs_to_one, [DerivedOutput("transformed_random_process")])

self.model.request_function_output(
"random_process_auc",
func=sum_diffs_to_one_func
)

# def request_icu_outputs(
# self,
# prop_icu_among_hospitalised: float,
Expand Down
30 changes: 15 additions & 15 deletions autumn/models/sm_covid2/params.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ infectious_seed_time: 60
seed_duration: 7

sojourns:
latent: 6.57 # Pooled estimate from Wu et al. doi:10.1001/jamanetworkopen.2022.28008
active: 8.
latent: 6.65 # wild-type
active: 6.5

mobility:
region: null
Expand Down Expand Up @@ -150,9 +150,9 @@ is_dynamic_mixing_matrix: true

# parameters related to immunity stratification (two strata: unvaccinated / vaccinated)
vaccine_effects:
ve_infection: .5 #FIXME: dummy data
ve_hospitalisation: .9 #FIXME: dummy data
ve_death: .9 #FIXME: dummy data
ve_infection: .7
ve_hospitalisation: .9
ve_death: .9

# voc_emergence: null
voc_emergence:
Expand All @@ -176,11 +176,11 @@ voc_emergence:
new_voc_seed:
time_from_gisaid_report: 0.
seed_duration: 10.
contact_rate_multiplier: 1.38 # DOI: 10.1503/cmaj.211248
contact_rate_multiplier: 1.5
incubation_overwrite_value: 4.41 # Wu et al. doi:10.1001/jamanetworkopen.2022.28008
vacc_immune_escape: .32 # DOI: 10.1503/cmaj.211248
hosp_risk_adjuster: 2.1 # Fisman et al. doi: 10.1503/cmaj.211248
death_risk_adjuster: 2.3 # Fisman et al. doi: 10.1503/cmaj.211248
vacc_immune_escape: .3
hosp_risk_adjuster: 2.0
death_risk_adjuster: 2.3
icu_risk_adjuster: 3.4 # Fisman et al. doi: 10.1503/cmaj.211248
cross_protection:
wild_type: 1.
Expand All @@ -192,11 +192,11 @@ voc_emergence:
new_voc_seed:
time_from_gisaid_report: 0.
seed_duration: 10.
contact_rate_multiplier: 1.92 # DOI: 10.1503/cmaj.211248
contact_rate_multiplier: 2. # DOI: 10.1503/cmaj.211248
incubation_overwrite_value: 3.42 # Wu et al. doi:10.1001/jamanetworkopen.2022.28008
vacc_immune_escape: .55 # DOI: 10.1503/cmaj.211248
hosp_risk_adjuster: 0.861 # 2.1 * .41 using HR of Omicron vs Delta (Nyberg doi.org/10.1016/S0140-6736(22)00462-7)
death_risk_adjuster: 0.713 # 2.3 * 0.31 using HR of Omicron vs Delta (Nyberg doi.org/10.1016/S0140-6736(22)00462-7)
vacc_immune_escape: .6 # DOI: 10.1503/cmaj.211248
hosp_risk_adjuster: 0.82 # Multiply Nyberg by Fisman = 0.41*2 = 0.82
death_risk_adjuster: 0.71 # Multiply Nyberg by Fisman =0.31*2.3=0.713
icu_risk_adjuster: 1. # FIXME dummy value for now but not used at the moment
cross_protection:
wild_type: 1.
Expand All @@ -207,7 +207,7 @@ time_from_onset_to_event: # using time of symptom onset as reference
hospitalisation:
distribution: gamma
shape: 5.
mean: 7.7 # Estimates taken from ISARIC report 4th Oct 2020 (Mendeley citation key Pritchard2020}
mean: 3.
icu_admission: # using time of hospitalisation as reference
distribution: gamma
shape: 5.
Expand All @@ -223,7 +223,7 @@ hospital_stay:
hospital_all:
distribution: gamma
shape: 5.
mean: 12.8 # Estimates taken from ISARIC report 4th Oct 2020 (Mendeley citation key Pritchard2020}
mean: 9

icu:
distribution: gamma
Expand Down
16 changes: 11 additions & 5 deletions autumn/projects/sm_covid2/common_school/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,17 @@
ANALYSES_NAMES = ["main", "no_google_mobility", "increased_hh_contacts"]


def get_estival_uniform_priors(autumn_priors):
def get_estival_uniform_priors(autumn_priors, _pymc_transform_eps_scale):
estival_priors = []
for prior_dict in autumn_priors:
assert prior_dict["distribution"] == "uniform", "Only uniform priors are currently supported"

if not "random_process.delta_values" in prior_dict["param_name"] and not "dispersion_param" in prior_dict["param_name"]:
p = esp.UniformPrior(prior_dict["param_name"], prior_dict["distri_params"])
p._pymc_transform_eps_scale = _pymc_transform_eps_scale

estival_priors.append(
esp.UniformPrior(prior_dict["param_name"], prior_dict["distri_params"]),
p
)

ndelta_values = len([prior_dict for prior_dict in autumn_priors if prior_dict["param_name"].startswith("random_process.delta_values")])
Expand All @@ -46,18 +49,21 @@ def rp_loglikelihood(params):
return rp_loglikelihood


def get_bcm_object(iso3, analysis="main"):
def get_bcm_object(iso3, analysis="main", _pymc_transform_eps_scale=.1):

assert analysis in ANALYSES_NAMES, "wrong analysis name requested"

project = get_school_project(iso3, analysis)
death_target_data = project.calibration.targets[0].data

dispersion_prior = esp.UniformPrior("infection_deaths_dispersion_param", (200, 250))
dispersion_prior._pymc_transform_eps_scale = _pymc_transform_eps_scale

targets = [
est.NegativeBinomialTarget(
"infection_deaths_ma7",
death_target_data,
dispersion_param=esp.UniformPrior("infection_deaths_dispersion_param", (200, 250))
dispersion_param=dispersion_prior
)
]
if len(project.calibration.targets) > 1:
Expand Down Expand Up @@ -90,7 +96,7 @@ def censored_func(modelled, data, parameters, time_weights):
default_configuration = project.param_set.baseline
m = project.build_model(default_configuration.to_dict())

priors = get_estival_uniform_priors(project.calibration.all_priors)
priors = get_estival_uniform_priors(project.calibration.all_priors, _pymc_transform_eps_scale)

default_params = m.builder.get_default_parameters()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def make_post_mc_plots(idata, burn_in, output_folder=None):
chain_length = idata.sample_stats.sizes['draw']

# Traces (including burn-in)
ax = az.plot_trace(idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False);
az.plot_trace(idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False);
plt.subplots_adjust(hspace=.7)
if output_folder:
plt.savefig(output_folder_path / "mc_traces.jpg", facecolor="white", bbox_inches='tight')
Expand All @@ -23,7 +23,7 @@ def make_post_mc_plots(idata, burn_in, output_folder=None):
burnt_idata = idata.sel(draw=range(burn_in, chain_length)) # Discard burn-in

# Traces (after burn-in)
ax = az.plot_trace(burnt_idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False);
az.plot_trace(burnt_idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False);
plt.subplots_adjust(hspace=.7)
if output_folder:
plt.savefig(output_folder_path / "mc_traces_postburnin.jpg", facecolor="white", bbox_inches='tight')
Expand All @@ -48,9 +48,9 @@ def make_post_mc_plots(idata, burn_in, output_folder=None):
rhat_df = raw_rhat_df.drop(columns="random_process.delta_values").loc[0]
for i in range(len(raw_rhat_df)):
rhat_df[f"random_process.delta_values[{i}]"] = raw_rhat_df['random_process.delta_values'][i]
ax = rhat_df.plot.barh(xlim=(1.,1.105))
ax.vlines(x=1.05,ymin=-0.5, ymax=len(rhat_df), linestyles="--", color='orange')
ax.vlines(x=1.1,ymin=-0.5, ymax=len(rhat_df), linestyles="-",color='red')
axis = rhat_df.plot.barh(xlim=(1.,1.105))
axis.vlines(x=1.05,ymin=-0.5, ymax=len(rhat_df), linestyles="--", color='orange')
axis.vlines(x=1.1,ymin=-0.5, ymax=len(rhat_df), linestyles="-",color='red')
if output_folder:
plt.savefig(output_folder_path / "r_hats.jpg", facecolor="white", bbox_inches='tight')
plt.close()
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def plot_model_fit(bcm, params, iso3, outfile=None):
if outfile:
fig.savefig(outfile, facecolor="white")

return fig

plt.close()

def plot_model_fit_with_ll(bcm, params, outfile=None):
REF_DATE = datetime.date(2019,12,31)
Expand Down Expand Up @@ -189,9 +189,7 @@ def plot_multiple_model_fits(bcm, params_list, outfile=None):
death_ax, rp_ax, sero_ax = axs[0], axs[1], axs[2]

# set up the three axes
death_ax.set_ylabel("COVID-19 deaths")
targets["infection_deaths_ma7"].plot(style='.', ax=death_ax, label="", zorder=20, color='black')

death_ax.set_ylabel("COVID-19 deaths")
rp_ax.set_ylabel("Random process")

if "prop_ever_infected_age_matched" in targets:
Expand Down Expand Up @@ -228,6 +226,8 @@ def plot_multiple_model_fits(bcm, params_list, outfile=None):
color=colors[i]
)

targets["infection_deaths_ma7"].plot(style='.', ax=death_ax, label="", zorder=20, color='black')

# Post plotting processes
# death
death_ax.legend(loc='best', ncols=2)
Expand All @@ -241,4 +241,4 @@ def plot_multiple_model_fits(bcm, params_list, outfile=None):
if outfile:
fig.savefig(outfile, facecolor="white")

return fig
plt.close()
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,67 @@ def make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_f

plt.close()

def plot_incidence_by_age(derived_outputs, ax, scenario, as_proportion: bool):

colours = ["cornflowerblue", "slateblue", "mediumseagreen", "lightcoral", "purple"]

update_rcparams()
y_label = "COVID-19 incidence proportion" if as_proportion else "COVID-19 incidence"

times = derived_outputs["incidence", scenario].index.to_list()
running_total = [0] * len(derived_outputs["incidence", scenario])
age_groups = base_params['age_groups']

y_max = 1. if as_proportion else max([derived_outputs["incidence", sc].max() for sc in [0, 1]])

for i_age, age_group in enumerate(age_groups):
output_name = f"incidenceXagegroup_{age_group}"

if i_age < len(age_groups) - 1:
upper_age = age_groups[i_age + 1] - 1 if i_age < len(age_groups) - 1 else ""
age_group_name = f"{age_group}-{upper_age}"
else:
age_group_name = f"{age_group}+"

age_group_incidence = derived_outputs[output_name, scenario]

if as_proportion:
numerator, denominator = age_group_incidence, derived_outputs["incidence", scenario]
age_group_proportion = np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator!=0)
new_running_total = age_group_proportion + running_total
else:
new_running_total = age_group_incidence + running_total

ax.fill_between(times, running_total, new_running_total, color=colours[i_age], label=age_group_name, zorder=2, alpha=.8)
running_total = copy(new_running_total)

# y_max = max(new_running_total)
plot_ymax = y_max * 1.1
add_school_closure_patches(ax, ISO3, ymax=plot_ymax)

# work out first time with positive incidence
t_min = derived_outputs['incidence', 0].gt(0).idxmax()
ax.set_xlim((t_min, model_end))
ax.set_ylim((0, plot_ymax))

ax.set_ylabel(y_label)

if not as_proportion and scenario == 0:
handles, labels = ax.get_legend_handles_labels()
ax.legend(
reversed(handles),
reversed(labels),
# title="Age:",
# fontsize=12,
# title_fontsize=12,
labelspacing=.2,
handlelength=1.,
handletextpad=.5,
columnspacing=1.,
facecolor="white",
ncol=2,

)

def test_tiling_plot():
from pathlib import Path
Expand Down
Loading

0 comments on commit fd5ea8b

Please sign in to comment.