Skip to content

Commit

Permalink
Fix plotting functions following refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-ragonnet committed Sep 13, 2023
1 parent 4f06d95 commit 6b7996a
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def plot_model_fit_with_uncertainty(axis, uncertainty_df, output_name, iso3, inc

# update_rcparams()

df = uncertainty_df[(uncertainty_df["scenario"] == "baseline") & (uncertainty_df["type"] == output_name)]
df = uncertainty_df[output_name]

if output_name in bcm.targets:
t = copy(bcm.targets[output_name].data)
Expand All @@ -129,22 +129,20 @@ def plot_model_fit_with_uncertainty(axis, uncertainty_df, output_name, iso3, inc

colour = unc_sc_colours[0]

median_df = df[df["quantile"] == .5]

time = median_df['time']
axis.plot(time, median_df['value'], color=colour, zorder=10, label="model (median)")
time = df.index
axis.plot(time, df['0.5'], color=colour, zorder=10, label="model (median)")

axis.fill_between(
time,
df[df["quantile"] == .25]['value'], df[df["quantile"] == .75]['value'],
df['0.25'], df['0.75'],
color=colour,
alpha=0.5,
edgecolor=None,
label="model (IQR)"
)
axis.fill_between(
time,
df[df["quantile"] == .025]['value'], df[df["quantile"] == .975]['value'],
df['0.025'], df['0.975'],
color=colour,
alpha=0.3,
edgecolor=None,
Expand Down Expand Up @@ -172,14 +170,14 @@ def plot_model_fit_with_uncertainty(axis, uncertainty_df, output_name, iso3, inc

return x_min

def plot_two_scenarios(axis, uncertainty_df, output_name, iso3, include_unc=False, include_legend=True):
def plot_two_scenarios(axis, uncertainty_dfs, output_name, iso3, include_unc=False, include_legend=True):
# update_rcparams()

ymax = 0.
for i_sc, scenario in enumerate(["baseline", "scenario_1"]):
df = uncertainty_df[(uncertainty_df["scenario"] == scenario) & (uncertainty_df["type"] == output_name)]
median_df = df[df["quantile"] == .5]
time = median_df['time']
df = uncertainty_dfs[scenario][output_name]
median_df = df['0.5']
time = df.index

colour = unc_sc_colours[i_sc]
label = "baseline" if i_sc == 0 else "schools open"
Expand All @@ -188,17 +186,17 @@ def plot_two_scenarios(axis, uncertainty_df, output_name, iso3, include_unc=Fals
if include_unc:
axis.fill_between(
time,
df[df["quantile"] == .25]['value'], df[df["quantile"] == .75]['value'],
df['0.25'], df['0.75'],
color=colour, alpha=0.7,
edgecolor=None,
# label=interval_label,
zorder=scenario_zorder
)
ymax = max(ymax, df[df["quantile"] == .75]['value'].max())
ymax = max(ymax, df['0.75'].max())
else:
ymax = median_df['value'].max()
ymax = median_df.max()

axis.plot(time, median_df['value'], color=colour, label=label, lw=1.)
axis.plot(time, median_df, color=colour, label=label, lw=1.)

plot_ymax = ymax * 1.1
add_school_closure_patches(axis, iso3, ymax=plot_ymax)
Expand All @@ -217,30 +215,29 @@ def plot_two_scenarios(axis, uncertainty_df, output_name, iso3, include_unc=Fals
plt.tight_layout()


def plot_final_size_compare(axis, uncertainty_df, output_name):
def plot_final_size_compare(axis, uncertainty_dfs, output_name):
# update_rcparams()
# plt.rcParams.update({'font.size': 12})
box_width = .5
color = 'black'
box_color= 'lightcoral'
y_max = 0
for i, scenario in enumerate(["baseline", "scenario_1"]):

df = uncertainty_df[(uncertainty_df["scenario"] == scenario) & (uncertainty_df["type"] == output_name) & (uncertainty_df["time"] == uncertainty_df["time"].max())]
for i, scenario in enumerate(["baseline", "scenario_1"]):
df = uncertainty_dfs[scenario][output_name].iloc[-1]

x = 1 + i
# median
axis.hlines(y=df[df["quantile"] == .5]['value'], xmin=x - box_width / 2. , xmax= x + box_width / 2., lw=1., color=color, zorder=3)
axis.hlines(y=df['0.5'], xmin=x - box_width / 2. , xmax= x + box_width / 2., lw=1., color=color, zorder=3)

# IQR
q_75 = float(df[df["quantile"] == .75]['value'])
q_25 = float(df[df["quantile"] == .25]['value'])
q_75 = float(df['0.75'])
q_25 = float(df['0.25'])
rect = Rectangle(xy=(x - box_width / 2., q_25), width=box_width, height=q_75 - q_25, zorder=2, facecolor=box_color)
axis.add_patch(rect)

# 95% CI
q_025 = float(df[df["quantile"] == .025]['value'])
q_975 = float(df[df["quantile"] == .975]['value'])
q_025 = float(df['0.025'])
q_975 = float(df['0.975'])
axis.vlines(x=x, ymin=q_025 , ymax=q_975, lw=.7, color=color, zorder=1)

y_max = max(y_max, q_975)
Expand Down Expand Up @@ -322,7 +319,7 @@ def remove_axes_box(axis):
axis.spines['right'].set_visible(False)


def make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder):
def make_country_output_tiling(iso3, uncertainty_dfs, diff_quantiles_df, output_folder):
country_name = INCLUDED_COUNTRIES['all'][iso3]

update_rcparams()
Expand Down Expand Up @@ -352,22 +349,23 @@ def make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_f
inner_left_grid = gridspec.GridSpecFromSubplotSpec(4, 1, subplot_spec=left_grid, hspace=.15, height_ratios=(1, 1, 1, 1))
# calibration, deaths
ax2 = fig.add_subplot(inner_left_grid[0, 0])
x_min = plot_model_fit_with_uncertainty(ax2, uncertainty_df, "infection_deaths_ma7", iso3)
x_min = plot_model_fit_with_uncertainty(ax2, uncertainty_dfs['baseline'], "infection_deaths_ma7", iso3)
format_date_axis(ax2)
remove_axes_box(ax2)
# seropos prop over time
ax_sero = fig.add_subplot(inner_left_grid[1, 0])
plot_model_fit_with_uncertainty(ax_sero, uncertainty_df, "prop_ever_infected_age_matched", iso3, include_legend=False)
plot_model_fit_with_uncertainty(ax_sero, uncertainty_dfs['baseline'], "prop_ever_infected_age_matched", iso3, include_legend=False)
format_date_axis(ax_sero)
remove_axes_box(ax_sero)

# scenario compare deaths
ax3 = fig.add_subplot(inner_left_grid[2, 0]) #, sharex=ax2)
plot_two_scenarios(ax3, uncertainty_df, "infection_deaths_ma7", iso3, True)
plot_two_scenarios(ax3, uncertainty_dfs, "infection_deaths_ma7", iso3, True)
format_date_axis(ax3)
remove_axes_box(ax3)
# scenario compare hosp
ax4 = fig.add_subplot(inner_left_grid[3, 0]) #, sharex=ax2)
plot_two_scenarios(ax4, uncertainty_df, "hospital_occupancy", iso3, True, include_legend=False)
plot_two_scenarios(ax4, uncertainty_dfs, "hospital_occupancy", iso3, True, include_legend=False)
format_date_axis(ax4)
remove_axes_box(ax4)

Expand All @@ -378,17 +376,17 @@ def make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_f
inner_right_grid = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=right_grid, hspace=.15, height_ratios=(1, 1, 1))
# final size incidence
ax5 = fig.add_subplot(inner_right_grid[0, 0])
plot_final_size_compare(ax5, uncertainty_df, "cumulative_incidence")
plot_final_size_compare(ax5, uncertainty_dfs, "cumulative_incidence")
remove_axes_box(ax5)

# final size deaths
ax6 = fig.add_subplot(inner_right_grid[1, 0]) #, sharex=ax5)
plot_final_size_compare(ax6, uncertainty_df, "cumulative_infection_deaths")
plot_final_size_compare(ax6, uncertainty_dfs, "cumulative_infection_deaths")
remove_axes_box(ax6)

# # hosp peak
ax7 = fig.add_subplot(inner_right_grid[2, 0]) #, sharex=ax5)
plot_final_size_compare(ax7, uncertainty_df, "peak_hospital_occupancy")
plot_final_size_compare(ax7, uncertainty_dfs, "peak_hospital_occupancy")
# ax7.set_xticks(ticks=[1, 2], labels=["baseline", "schools\nopen"]) #, fontsize=15)
remove_axes_box(ax7)

Expand Down
6 changes: 3 additions & 3 deletions autumn/projects/sm_covid2/common_school/runner_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def get_uncertainty_dfs(full_runs, quantiles=[.025, .25, .5, .75, .975]):
def calculate_diff_output_quantiles(full_runs, quantiles=[.025, .25, .5, .75, .975]):
diff_names = {
"cases_averted": "cumulative_incidence",
"death_averted": "cumulative_infection_deaths",
"deaths_averted": "cumulative_infection_deaths",
"delta_hospital_peak": "peak_hospital_occupancy",
"delta_student_weeks_missed": "student_weeks_missed"
}
Expand Down Expand Up @@ -317,8 +317,8 @@ def opti_func(sample_dict):
diff_quantiles_df = calculate_diff_output_quantiles(full_runs)
diff_quantiles_df.to_parquet(out_path / "diff_quantiles_df.parquet")

# Make multi-panel figure #FIXME: not compatible with new unc_dfs format
# make_country_output_tiling(iso3, unc_dfs, diff_quantiles_df, output_folder)
# Make multi-panel figure
make_country_output_tiling(iso3, unc_dfs, diff_quantiles_df, output_folder)

return idata, unc_dfs, diff_quantiles_df

Expand Down

0 comments on commit 6b7996a

Please sign in to comment.