diff --git a/autumn/projects/sm_covid2/common_school/calibration_plots/mc_plots.py b/autumn/projects/sm_covid2/common_school/calibration_plots/mc_plots.py index 27ce249ab..7cfcfa24f 100644 --- a/autumn/projects/sm_covid2/common_school/calibration_plots/mc_plots.py +++ b/autumn/projects/sm_covid2/common_school/calibration_plots/mc_plots.py @@ -13,7 +13,7 @@ def make_post_mc_plots(idata, burn_in, output_folder=None): chain_length = idata.sample_stats.sizes['draw'] # Traces (including burn-in) - ax = az.plot_trace(idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False); + az.plot_trace(idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False); plt.subplots_adjust(hspace=.7) if output_folder: plt.savefig(output_folder_path / "mc_traces.jpg", facecolor="white", bbox_inches='tight') @@ -23,7 +23,7 @@ def make_post_mc_plots(idata, burn_in, output_folder=None): burnt_idata = idata.sel(draw=range(burn_in, chain_length)) # Discard burn-in # Traces (after burn-in) - ax = az.plot_trace(burnt_idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False); + az.plot_trace(burnt_idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False); plt.subplots_adjust(hspace=.7) if output_folder: plt.savefig(output_folder_path / "mc_traces_postburnin.jpg", facecolor="white", bbox_inches='tight') @@ -48,9 +48,9 @@ def make_post_mc_plots(idata, burn_in, output_folder=None): rhat_df = raw_rhat_df.drop(columns="random_process.delta_values").loc[0] for i in range(len(raw_rhat_df)): rhat_df[f"random_process.delta_values[{i}]"] = raw_rhat_df['random_process.delta_values'][i] - ax = rhat_df.plot.barh(xlim=(1.,1.105)) - ax.vlines(x=1.05,ymin=-0.5, ymax=len(rhat_df), linestyles="--", color='orange') - ax.vlines(x=1.1,ymin=-0.5, ymax=len(rhat_df), linestyles="-",color='red') + axis = rhat_df.plot.barh(xlim=(1.,1.105)) + axis.vlines(x=1.05,ymin=-0.5, ymax=len(rhat_df), linestyles="--", color='orange') + axis.vlines(x=1.1,ymin=-0.5, ymax=len(rhat_df), linestyles="-",color='red') if output_folder: plt.savefig(output_folder_path / "r_hats.jpg", facecolor="white", bbox_inches='tight') plt.close() diff --git a/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py b/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py index fff1a1473..1e7471653 100644 --- a/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py +++ b/autumn/projects/sm_covid2/common_school/calibration_plots/opti_plots.py @@ -133,8 +133,8 @@ def plot_model_fit(bcm, params, iso3, outfile=None): if outfile: fig.savefig(outfile, facecolor="white") - return fig - + plt.close() + def plot_model_fit_with_ll(bcm, params, outfile=None): REF_DATE = datetime.date(2019,12,31) @@ -241,4 +241,4 @@ def plot_multiple_model_fits(bcm, params_list, outfile=None): if outfile: fig.savefig(outfile, facecolor="white") - return fig \ No newline at end of file + plt.close() \ No newline at end of file