Skip to content

Commit

Permalink
Try to fix issues with figure mix up
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-ragonnet committed Aug 30, 2023
1 parent f3c7f7b commit 1ffa59e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def make_post_mc_plots(idata, burn_in, output_folder=None):
chain_length = idata.sample_stats.sizes['draw']

# Traces (including burn-in)
ax = az.plot_trace(idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False);
az.plot_trace(idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False);
plt.subplots_adjust(hspace=.7)
if output_folder:
plt.savefig(output_folder_path / "mc_traces.jpg", facecolor="white", bbox_inches='tight')
Expand All @@ -23,7 +23,7 @@ def make_post_mc_plots(idata, burn_in, output_folder=None):
burnt_idata = idata.sel(draw=range(burn_in, chain_length)) # Discard burn-in

# Traces (after burn-in)
ax = az.plot_trace(burnt_idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False);
az.plot_trace(burnt_idata, figsize=(16, 5.0 * len(idata.posterior)), compact=False);
plt.subplots_adjust(hspace=.7)
if output_folder:
plt.savefig(output_folder_path / "mc_traces_postburnin.jpg", facecolor="white", bbox_inches='tight')
Expand All @@ -48,9 +48,9 @@ def make_post_mc_plots(idata, burn_in, output_folder=None):
rhat_df = raw_rhat_df.drop(columns="random_process.delta_values").loc[0]
for i in range(len(raw_rhat_df)):
rhat_df[f"random_process.delta_values[{i}]"] = raw_rhat_df['random_process.delta_values'][i]
ax = rhat_df.plot.barh(xlim=(1.,1.105))
ax.vlines(x=1.05,ymin=-0.5, ymax=len(rhat_df), linestyles="--", color='orange')
ax.vlines(x=1.1,ymin=-0.5, ymax=len(rhat_df), linestyles="-",color='red')
axis = rhat_df.plot.barh(xlim=(1.,1.105))
axis.vlines(x=1.05,ymin=-0.5, ymax=len(rhat_df), linestyles="--", color='orange')
axis.vlines(x=1.1,ymin=-0.5, ymax=len(rhat_df), linestyles="-",color='red')
if output_folder:
plt.savefig(output_folder_path / "r_hats.jpg", facecolor="white", bbox_inches='tight')
plt.close()
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,8 @@ def plot_model_fit(bcm, params, iso3, outfile=None):
if outfile:
fig.savefig(outfile, facecolor="white")

return fig

plt.close()

def plot_model_fit_with_ll(bcm, params, outfile=None):
REF_DATE = datetime.date(2019,12,31)
Expand Down Expand Up @@ -241,4 +241,4 @@ def plot_multiple_model_fits(bcm, params_list, outfile=None):
if outfile:
fig.savefig(outfile, facecolor="white")

return fig
plt.close()

0 comments on commit 1ffa59e

Please sign in to comment.