Refactor to use estival sampling functionalities
romain-ragonnet committed Sep 12, 2023
1 parent 3047952 commit e05fba3
Showing 4 changed files with 108 additions and 149 deletions.
7 changes: 4 additions & 3 deletions autumn/projects/sm_covid2/common_school/calibration.py
@@ -49,11 +49,12 @@ def rp_loglikelihood(params):
return rp_loglikelihood


def get_bcm_object(iso3, analysis="main", _pymc_transform_eps_scale=.1):
def get_bcm_object(iso3, analysis="main", scenario='baseline', _pymc_transform_eps_scale=.1):

assert analysis in ANALYSES_NAMES, "wrong analysis name requested"

-    project = get_school_project(iso3, analysis)
+    assert scenario in ['baseline', 'scenario_1'], f"Requested scenario {scenario} not currently supported"
+
+    project = get_school_project(iso3, analysis, scenario)
death_target_data = project.calibration.targets[0].data

dispersion_prior = esp.UniformPrior("infection_deaths_dispersion_param", (200, 250))
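Note: with the new scenario argument, one calibration (BCM) object can now be built per scenario. A minimal usage sketch, assuming this module layout; the ISO3 code and analysis name below are illustrative, not from this commit:

from autumn.projects.sm_covid2.common_school.calibration import get_bcm_object

bcm_baseline = get_bcm_object("FRA", analysis="main", scenario="baseline")
bcm_scenario_1 = get_bcm_object("FRA", analysis="main", scenario="scenario_1")

# Any other scenario name now fails fast:
# get_bcm_object("FRA", scenario="scenario_2")  # raises AssertionError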
39 changes: 23 additions & 16 deletions autumn/projects/sm_covid2/common_school/project_maker.py
@@ -55,7 +55,7 @@
param_path = Path(__file__).parent.resolve() / "params"


def get_school_project(iso3, analysis="main"):
def get_school_project(iso3, analysis="main", scenario='baseline'):

# read seroprevalence data (needed to specify the sero age range params and then to define the calibration targets)
positive_prop, sero_target_sd, midpoint_as_int, sero_age_min, sero_age_max = get_sero_estimate(iso3)
@@ -80,10 +80,10 @@ def get_school_project(iso3, analysis="main"):
first_date_with_death = infection_deaths_ma7[round(infection_deaths_ma7) >= 1].index[0]

# Get parameter set
-    param_set = get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min, sero_age_max, analysis)
+    param_set = get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min, sero_age_max, analysis, scenario)

# Define priors
-    priors = get_school_project_priors(first_date_with_death)
+    priors = get_school_project_priors(first_date_with_death, scenario)

# define calibration targets
model_end_time = param_set.baseline.to_dict()["time"]["end"]
@@ -166,7 +166,7 @@ def calc_death_missed_school_ratio(deaths_averted, student_weeks_missed):
return project


-def get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min, sero_age_max, analysis="main"):
+def get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min, sero_age_max, analysis="main", scenario='baseline'):
"""
Get the country-specific parameter sets.
@@ -177,6 +177,16 @@ def get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min,
Returns:
param_set: A ParameterSet object containing parameter sets for baseline and scenarios
"""
+    scenario_params = {
+        "baseline": {},
+        "scenario_1": {
+            "mobility": {
+                "unesco_partial_opening_value": 1.,
+                "unesco_full_closure_value": 1.
+            }
+        }
+    }

# get common parameters
base_params = get_base_params()

@@ -219,22 +229,15 @@ def get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min,
{"infectious_seed_time": first_date_with_death - 40.}
)

-    # update using MLE params, if available
-    # mle_path= param_path / "mle_files" / f"mle_{iso3}.yml"
-    # if exists(mle_path):
-    #     baseline_params = baseline_params.update(mle_path, calibration_format=True)

# update using potential Sensitivity Analysis params
sa_params_path = param_path / "SA_analyses" / f"{analysis}.yml"
baseline_params = baseline_params.update(sa_params_path, calibration_format=True)

-    # get scenario parameters
-    scenario_dir_path = param_path
-    scenario_paths = get_all_available_scenario_paths(scenario_dir_path)
-    scenario_params = [baseline_params.update(p) for p in scenario_paths]
+    # update using scenario parameters
+    baseline_params = baseline_params.update(scenario_params[scenario])

# build ParameterSet object
-    param_set = ParameterSet(baseline=baseline_params, scenarios=scenario_params)
+    param_set = ParameterSet(baseline=baseline_params)

return param_set

@@ -354,7 +357,7 @@ def remove_death_outliers(iso3, data):
return data


-def get_school_project_priors(first_date_with_death):
+def get_school_project_priors(first_date_with_death, scenario):
"""
Get the list of calibration priors. This depends on the first date with death which is used to define the
prior around the infection seed.
@@ -380,10 +383,14 @@ def get_school_project_priors(first_date_with_death,
# UniformPrior("voc_emergence.omicron.new_voc_seed.time_from_gisaid_report", [-30, 30]),

# Account for mixing matrix uncertainty
UniformPrior("mobility.unesco_partial_opening_value", [0.1, 0.5]),
UniformPrior("school_multiplier", [0.8, 1.2]),
]

+    if scenario == 'baseline':
+        priors.append(
+            UniformPrior("mobility.unesco_partial_opening_value", [0.1, 0.5])
+        )

return priors


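Note the interplay of the two changes above: under scenario_1 the school-mobility values are pinned to 1.0 and mobility.unesco_partial_opening_value is no longer calibrated, so the scenario's calibration object carries one fewer prior than the baseline's. The override itself relies on the parameter set's update method performing a nested merge; a toy sketch of that behaviour with plain dicts (deep_merge is a stand-in for illustration, not the actual API):

def deep_merge(base: dict, override: dict) -> dict:
    # Recursively merge override into base, returning a new dict
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

baseline = {"mobility": {"unesco_partial_opening_value": 0.3, "unesco_full_closure_value": 0.2}}
override = {"mobility": {"unesco_partial_opening_value": 1.0, "unesco_full_closure_value": 1.0}}

assert deep_merge(baseline, override)["mobility"]["unesco_full_closure_value"] == 1.0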
190 changes: 61 additions & 129 deletions autumn/projects/sm_covid2/common_school/runner_tools.py
@@ -17,6 +17,7 @@
import nevergrad as ng

from estival.utils.parallel import map_parallel
+from estival.sampling import tools as esamp

from autumn.core.runs import ManagedRun
from autumn.infrastructure.remote import springboard
@@ -174,14 +175,7 @@ def extract_sample_subset(idata, n_samples, burn_in, chain_filter: list = None):
chain_length = idata.sample_stats.sizes['draw']
burnt_idata = idata.sel(draw=range(burn_in, chain_length)) # Discard burn-in

-    if chain_filter:
-        burnt_idata = burnt_idata.sel(chain=chain_filter)
-
-    param_names = list(burnt_idata.posterior.data_vars.keys())
-    sampled_idata = az.extract(burnt_idata, num_samples=n_samples) # Sample from the inference data
-    sampled_df = sampled_idata.to_dataframe()[param_names]
-
-    return sampled_df.sort_index(level="draw").sort_index(level="chain")
+    return az.extract(burnt_idata, num_samples=n_samples)
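The trimmed helper now hands back the stacked xarray Dataset from ArviZ rather than a re-sorted pandas DataFrame, since estival's sampling tools can consume it directly. A toy illustration of what az.extract returns, assuming a recent ArviZ version; the posterior values are made up:

import arviz as az

idata = az.from_dict({"x": [[0.1, 0.2, 0.3, 0.4]]})  # toy posterior: 1 chain, 4 draws
sampled = az.extract(idata, num_samples=2)  # Dataset with a stacked "sample" dimension
print(sampled["x"].values)  # two randomly drawn posterior values for x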


def get_sampled_results(sampled_df, output_names):
Expand All @@ -204,121 +198,59 @@ def get_sampled_results(sampled_df, output_names):
return sampled_results


-def run_full_runs(sampled_df, iso3, analysis):
-
-    output_names=[
-        "infection_deaths_ma7", "prop_ever_infected_age_matched", "prop_ever_infected", "cumulative_incidence",
-        "cumulative_infection_deaths", "hospital_occupancy", 'peak_hospital_occupancy', 'student_weeks_missed',
-        "transformed_random_process", "random_process_auc"
-    ]
-
-    project = get_school_project(iso3, analysis)
-    default_params = project.param_set.baseline
-    model = project.build_model(default_params.to_dict())
-
-    scenario_params = {
-        "baseline": {},
-        "scenario_1": {
-            "mobility": {
-                "unesco_partial_opening_value": 1.,
-                "unesco_full_closure_value": 1.
-            }
-        }
-    }
-
-    d2_index = pd.Index([index[:2] for index in sampled_df.index]).unique()
+def run_full_runs(sampled_params, iso3, analysis):

-    outputs_df = pd.DataFrame(columns=output_names + ["scenario", "urun"])
-    for chain, draw in d2_index:
-        # read rp delta values
-        update_params = sampled_df.loc[chain, draw, 0].to_dict()
-        delta_values = sampled_df.loc[chain, draw]['random_process.delta_values']
-        update_params["random_process.delta_values"] = np.array(delta_values)
-
-        params_dict = default_params.update(update_params, calibration_format=True)
-
-        for sc_name, sc_params in scenario_params.items():
-            params = params_dict.update(sc_params)
-            model.run(params.to_dict())
-            derived_df = model.get_derived_outputs_df()[output_names]
-            derived_df["scenario"] = [sc_name] * len(derived_df)
-            derived_df["urun"] = [f"{chain}_{draw}"] * len(derived_df)
-
-            outputs_df = outputs_df.append(derived_df)
-
-    return outputs_df
+    full_runs = {}
+    for scenario in ["baseline", "scenario_1"]:
+        bcm = get_bcm_object(iso3, analysis, scenario)
+        full_run_params = sampled_params[list(bcm.priors)] # drop parameters for which scenario value should override calibrated value
+        full_runs[scenario] = esamp.model_results_for_samples(full_run_params, bcm)

+    return full_runs
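run_full_runs now delegates the per-sample model evaluations to estival instead of looping over draws and scenarios by hand. A sketch of the call pattern, with illustrative inputs (the ISO3 code is a stand-in; sampled_params is the Dataset returned by extract_sample_subset above):

bcm = get_bcm_object("FRA", "main", scenario="scenario_1")
full_run_params = sampled_params[list(bcm.priors)]  # keep only this scenario's calibrated parameters
full_run = esamp.model_results_for_samples(full_run_params, bcm)
print(full_run.results.head())  # derived outputs indexed by time, one column group per sample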

-def diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, column, relative=False):
-    if not relative:
-        return outputs_df_latest_0[column] - outputs_df_latest_1[column]
-    else:
-        return (outputs_df_latest_0[column] - outputs_df_latest_1[column]) / outputs_df_latest_1[column]


-def calculate_diff_outputs(outputs_df):
+def get_uncertainty_dfs(full_runs, quantiles=[.025, .25, .5, .75, .975]):
+    unc_dfs = {}
+    for scenario in full_runs:
+        unc_df = esamp.quantiles_for_results(full_runs[scenario].results, quantiles)
+        unc_df.columns.set_levels([str(q) for q in unc_df.columns.levels[1]], level=1, inplace=True) # to avoid using floats as column names (not parquet-compatible)
+        unc_dfs[scenario] = unc_df

-    index = outputs_df['urun'].unique()
-    latest_time = outputs_df.index.max()

+    return unc_dfs
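quantiles_for_results yields a DataFrame whose columns form an (output, quantile) MultiIndex, and the quantile level is stringified here because parquet requires string column names. A toy illustration of that constraint using plain pandas (the file write is left commented out):

import pandas as pd

df = pd.DataFrame({("deaths", 0.5): [1.0], ("deaths", 0.975): [2.0]})
# df.to_parquet("unc.parquet")  # would fail: float column labels are not parquet-compatible
df.columns = df.columns.set_levels([str(q) for q in df.columns.levels[1]], level=1)
print(df.columns.tolist())  # [('deaths', '0.5'), ('deaths', '0.975')]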

outputs_df_latest_0 = outputs_df[outputs_df['scenario'] == "baseline"].loc[latest_time]
outputs_df_latest_0.index = outputs_df_latest_0['urun']

outputs_df_latest_1 = outputs_df[outputs_df['scenario'] == "scenario_1"].loc[latest_time]
outputs_df_latest_1.index = outputs_df_latest_1['urun']
def calculate_diff_output_quantiles(full_runs, quantiles=[.025, .25, .5, .75, .975]):
diff_names = {
"cases_averted": "cumulative_incidence",
"death_averted": "cumulative_infection_deaths",
"delta_hospital_peak": "peak_hospital_occupancy",
"delta_student_weeks_missed": "student_weeks_missed"
}

diff_outputs_df = pd.DataFrame(index=index)
diff_outputs_df.index.name = "urun"

diff_outputs_df["cases_averted"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "cumulative_incidence")
diff_outputs_df["cases_averted_relative"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "cumulative_incidence", relative=True)

diff_outputs_df["deaths_averted"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "cumulative_infection_deaths")
diff_outputs_df["deaths_averted_relative"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "cumulative_infection_deaths", relative=True)

diff_outputs_df["delta_student_weeks_missed"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "student_weeks_missed")

diff_outputs_df["delta_hospital_peak"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "peak_hospital_occupancy")
diff_outputs_df["delta_hospital_peak_relative"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "peak_hospital_occupancy", relative=True)

return diff_outputs_df


-def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5, .75, .975]):
-
-    times = sorted(outputs_df.index.unique())
-    scenarios = outputs_df["scenario"].unique()
-    unc_output_names = [
-        "infection_deaths_ma7", "prop_ever_infected_age_matched", "prop_ever_infected", "transformed_random_process", "random_process_auc",
-        "cumulative_incidence", "cumulative_infection_deaths", "peak_hospital_occupancy", "hospital_occupancy"
-    ]
-
-    uncertainty_data = []
-    for scenario in scenarios:
-        scenario_mask = outputs_df["scenario"] == scenario
-        scenario_df = outputs_df[scenario_mask]
-
-        for time in times:
-            masked_df = scenario_df.loc[time]
-            if masked_df.empty:
-                continue
-            for output_name in unc_output_names:
-                quantile_vals = np.quantile(masked_df[output_name], quantiles)
-                for q_idx, q_value in enumerate(quantile_vals):
-                    datum = {
-                        "scenario": scenario,
-                        "type": output_name,
-                        "time": time,
-                        "quantile": quantiles[q_idx],
-                        "value": q_value,
-                    }
-                    uncertainty_data.append(datum)
-
-    uncertainty_df = pd.DataFrame(uncertainty_data)
-
-    diff_quantiles_df = pd.DataFrame(index=quantiles, data={col: np.quantile(diff_outputs_df[col], quantiles) for col in diff_outputs_df.columns})
+    latest_time = full_runs['baseline'].results.index.max()

-    return uncertainty_df, diff_quantiles_df
+    runs_0_latest = full_runs['baseline'].results.loc[latest_time]
+    runs_1_latest = full_runs['scenario_1'].results.loc[latest_time]

+    abs_diff = runs_0_latest - runs_1_latest
+    rel_diff = (runs_0_latest - runs_1_latest) / runs_1_latest

+    diff_quantiles_df_abs = pd.DataFrame(
+        index=quantiles,
+        data={colname: abs_diff[output_name].quantile(quantiles) for colname, output_name in diff_names.items()}
+    )
+    diff_quantiles_df_rel = pd.DataFrame(
+        index=quantiles,
+        data={f"{colname}_relative": rel_diff[output_name].quantile(quantiles) for colname, output_name in diff_names.items()}
+    )

+    return pd.concat([diff_quantiles_df_abs, diff_quantiles_df_rel], axis=1)
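The differential outputs are now computed directly from the final-time rows of the two scenarios' full runs, with quantiles taken across samples. Reading the returned frame, assuming full_runs as produced by run_full_runs above:

diff_df = calculate_diff_output_quantiles(full_runs)
print(diff_df.loc[0.5, "cases_averted"])           # median of baseline minus scenario_1 cumulative incidence
print(diff_df.loc[0.5, "cases_averted_relative"])  # same difference, relative to the scenario_1 value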


"""
@@ -360,11 +292,14 @@ def opti_func(sample_dict):
best_p, _ = optimise_model_fit(bcm, num_workers=n_opti_workers, search_iterations=opti_params['search_iterations'], suggested_start=sample_dict)
return best_p

-    best_params = map_parallel(opti_func, sample_as_dicts, n_workers=2 * n_cores / n_opti_workers) # oversubscribing
+    best_params = map_parallel(opti_func, sample_as_dicts, n_workers=int(2 * n_cores / n_opti_workers)) # oversubscribing
# Store optimal solutions
with open(out_path / "best_params.yml", "w") as f:
yaml.dump(best_params, f)

+    if logger:
+        logger.info("... optimisation completed")

# Keep only n_chains best solutions
loglikelihoods = [bcm.loglikelihood(**p) for p in best_params]
ll_cutoff = sorted(loglikelihoods, reverse=True)[mcmc_params['chains'] - 1]
@@ -379,19 +314,11 @@ def opti_func(sample_dict):
with open(out_path / "retained_best_params.yml", "w") as f:
yaml.dump(retained_best_params, f)

-    # Plot optimal solutions and starting points
+    # Plot optimal solutions and matching starting points
plot_opti_params(retained_init_points, retained_best_params, bcm, output_folder)

-    # # Plot optimal model fits
-    # opt_fits_path = out_path / "optimised_fits"
-    # opt_fits_path.mkdir(exist_ok=True)
-    # for j, best_p in enumerate(retained_best_params):
-    #     plot_model_fit(bcm, best_p, iso3, opt_fits_path / f"best_fit_{j}.png")

# Plot optimised model fits on a same figure
plot_multiple_model_fits(bcm, retained_best_params, out_path / "optimal_fits.png")

-    if logger:
-        logger.info("... optimisation completed")

# Early return if MCMC not requested
if mcmc_params['draws'] == 0:
@@ -415,22 +342,27 @@ def opti_func(sample_dict):
"""
Post-MCMC processes
"""
-    sample_df = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in'])
+    sampled_params = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in'])

if logger:
logger.info(f"Perform full runs for {full_run_params['samples']} samples")
-    outputs_df = run_full_runs(sample_df, iso3, analysis)
+    full_runs = run_full_runs(sampled_params, iso3, analysis)

if logger:
logger.info("Calculate differential outputs")
diff_outputs_df = calculate_diff_outputs(outputs_df)
logger.info("Calculate uncertainty quantiles")
unc_dfs = get_uncertainty_dfs(full_runs)
for scenario, unc_df in unc_dfs.items():
unc_df.to_parquet(out_path / f"uncertainty_df_{scenario}.parquet")

if logger:
logger.info("Calculate quantiles")
uncertainty_df, diff_quantiles_df = get_quantile_outputs(outputs_df, diff_outputs_df)
uncertainty_df.to_parquet(out_path / "uncertainty_df.parquet")
logger.info("Calculate differential output quantiles")
diff_quantiles_df = calculate_diff_output_quantiles(full_runs)
diff_quantiles_df.to_parquet(out_path / "diff_quantiles_df.parquet")

-    make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder)
+    # Make multi-panel figure #FIXME: not compatible with new unc_dfs format
+    # make_country_output_tiling(iso3, unc_dfs, diff_quantiles_df, output_folder)

-    return idata, uncertainty_df, diff_quantiles_df
+    return idata, unc_dfs, diff_quantiles_df

"""
Helper functions for remote runs
(diff for the fourth changed file did not load on this page)
