From e05fba3c5ac95e849014f474268828f20b8fc7b9 Mon Sep 17 00:00:00 2001 From: Romain Ragonnet Date: Tue, 12 Sep 2023 16:03:32 +1000 Subject: [PATCH] Refactor to use estival sampling functionalities --- .../sm_covid2/common_school/calibration.py | 7 +- .../sm_covid2/common_school/project_maker.py | 39 ++-- .../sm_covid2/common_school/runner_tools.py | 190 ++++++------------ .../School_Closure/full_analysis_runner.ipynb | 21 +- 4 files changed, 108 insertions(+), 149 deletions(-) diff --git a/autumn/projects/sm_covid2/common_school/calibration.py b/autumn/projects/sm_covid2/common_school/calibration.py index 9e755c6d0..0e66317a7 100644 --- a/autumn/projects/sm_covid2/common_school/calibration.py +++ b/autumn/projects/sm_covid2/common_school/calibration.py @@ -49,11 +49,12 @@ def rp_loglikelihood(params): return rp_loglikelihood -def get_bcm_object(iso3, analysis="main", _pymc_transform_eps_scale=.1): +def get_bcm_object(iso3, analysis="main", scenario='baseline', _pymc_transform_eps_scale=.1): assert analysis in ANALYSES_NAMES, "wrong analysis name requested" - - project = get_school_project(iso3, analysis) + assert scenario in ['baseline', 'scenario_1'], f"Requested scenario {scenario} not currently supported" + + project = get_school_project(iso3, analysis, scenario) death_target_data = project.calibration.targets[0].data dispersion_prior = esp.UniformPrior("infection_deaths_dispersion_param", (200, 250)) diff --git a/autumn/projects/sm_covid2/common_school/project_maker.py b/autumn/projects/sm_covid2/common_school/project_maker.py index 21a1e403f..15d5fb9d3 100644 --- a/autumn/projects/sm_covid2/common_school/project_maker.py +++ b/autumn/projects/sm_covid2/common_school/project_maker.py @@ -55,7 +55,7 @@ param_path = Path(__file__).parent.resolve() / "params" -def get_school_project(iso3, analysis="main"): +def get_school_project(iso3, analysis="main", scenario='baseline'): # read seroprevalence data (needed to specify the sero age range params and then to define the calibration targets) positive_prop, sero_target_sd, midpoint_as_int, sero_age_min, sero_age_max = get_sero_estimate(iso3) @@ -80,10 +80,10 @@ def get_school_project(iso3, analysis="main"): first_date_with_death = infection_deaths_ma7[round(infection_deaths_ma7) >= 1].index[0] # Get parameter set - param_set = get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min, sero_age_max, analysis) + param_set = get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min, sero_age_max, analysis, scenario) # Define priors - priors = get_school_project_priors(first_date_with_death) + priors = get_school_project_priors(first_date_with_death, scenario) # define calibration targets model_end_time = param_set.baseline.to_dict()["time"]["end"] @@ -166,7 +166,7 @@ def calc_death_missed_school_ratio(deaths_averted, student_weeks_missed): return project -def get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min, sero_age_max, analysis="main"): +def get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min, sero_age_max, analysis="main", scenario='baseline'): """ Get the country-specific parameter sets. @@ -177,6 +177,16 @@ def get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min, Returns: param_set: A ParameterSet object containing parameter sets for baseline and scenarios """ + scenario_params = { + "baseline": {}, + "scenario_1": { + "mobility": { + "unesco_partial_opening_value": 1., + "unesco_full_closure_value": 1. 
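+                # NB: with both UNESCO mobility scalers pinned to 1, scenario_1 represents schools remaining fully open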
+ } + } + } + # get common parameters base_params = get_base_params() @@ -219,22 +229,15 @@ def get_school_project_parameter_set(iso3, first_date_with_death, sero_age_min, {"infectious_seed_time": first_date_with_death - 40.} ) - # update using MLE params, if available - # mle_path= param_path / "mle_files" / f"mle_{iso3}.yml" - # if exists(mle_path): - # baseline_params = baseline_params.update(mle_path, calibration_format=True) - # update using potential Sensitivity Analysis params sa_params_path = param_path / "SA_analyses" / f"{analysis}.yml" baseline_params = baseline_params.update(sa_params_path, calibration_format=True) - # get scenario parameters - scenario_dir_path = param_path - scenario_paths = get_all_available_scenario_paths(scenario_dir_path) - scenario_params = [baseline_params.update(p) for p in scenario_paths] + # update using scenario parameters + baseline_params = baseline_params.update(scenario_params[scenario]) # build ParameterSet object - param_set = ParameterSet(baseline=baseline_params, scenarios=scenario_params) + param_set = ParameterSet(baseline=baseline_params) return param_set @@ -354,7 +357,7 @@ def remove_death_outliers(iso3, data): return data -def get_school_project_priors(first_date_with_death): +def get_school_project_priors(first_date_with_death, scenario): """ Get the list of calibration priors. This depends on the first date with death which is used to define the prior around the infection seed. @@ -380,10 +383,14 @@ def get_school_project_priors(first_date_with_death): # UniformPrior("voc_emergence.omicron.new_voc_seed.time_from_gisaid_report", [-30, 30]), # Account for mixing matrix uncertainty - UniformPrior("mobility.unesco_partial_opening_value", [0.1, 0.5]), UniformPrior("school_multiplier", [0.8, 1.2]), ] + if scenario == 'baseline': + priors.append( + UniformPrior("mobility.unesco_partial_opening_value", [0.1, 0.5]) + ) + return priors diff --git a/autumn/projects/sm_covid2/common_school/runner_tools.py b/autumn/projects/sm_covid2/common_school/runner_tools.py index 25c3ba452..e8745098b 100644 --- a/autumn/projects/sm_covid2/common_school/runner_tools.py +++ b/autumn/projects/sm_covid2/common_school/runner_tools.py @@ -17,6 +17,7 @@ import nevergrad as ng from estival.utils.parallel import map_parallel +from estival.sampling import tools as esamp from autumn.core.runs import ManagedRun from autumn.infrastructure.remote import springboard @@ -174,14 +175,7 @@ def extract_sample_subset(idata, n_samples, burn_in, chain_filter: list = None): chain_length = idata.sample_stats.sizes['draw'] burnt_idata = idata.sel(draw=range(burn_in, chain_length)) # Discard burn-in - if chain_filter: - burnt_idata = burnt_idata.sel(chain=chain_filter) - - param_names = list(burnt_idata.posterior.data_vars.keys()) - sampled_idata = az.extract(burnt_idata, num_samples=n_samples) # Sample from the inference data - sampled_df = sampled_idata.to_dataframe()[param_names] - - return sampled_df.sort_index(level="draw").sort_index(level="chain") + return az.extract(burnt_idata, num_samples=n_samples) def get_sampled_results(sampled_df, output_names): @@ -204,121 +198,59 @@ def get_sampled_results(sampled_df, output_names): return sampled_results -def run_full_runs(sampled_df, iso3, analysis): - - output_names=[ - "infection_deaths_ma7", "prop_ever_infected_age_matched", "prop_ever_infected", "cumulative_incidence", - "cumulative_infection_deaths", "hospital_occupancy", 'peak_hospital_occupancy', 'student_weeks_missed', - "transformed_random_process", "random_process_auc" - 
-    ]
-
-    project = get_school_project(iso3, analysis)
-    default_params = project.param_set.baseline
-    model = project.build_model(default_params.to_dict())
-
-    scenario_params = {
-        "baseline": {},
-        "scenario_1": {
-            "mobility": {
-                "unesco_partial_opening_value": 1.,
-                "unesco_full_closure_value": 1.
-            }
-        }
-    }
-
-    d2_index = pd.Index([index[:2] for index in sampled_df.index]).unique()
+def run_full_runs(sampled_params, iso3, analysis):
 
-    outputs_df = pd.DataFrame(columns=output_names + ["scenario", "urun"])
-    for chain, draw in d2_index:
-        # read rp delta values
-        update_params = sampled_df.loc[chain, draw, 0].to_dict()
-        delta_values = sampled_df.loc[chain, draw]['random_process.delta_values']
-        update_params["random_process.delta_values"] = np.array(delta_values)
-
-        params_dict = default_params.update(update_params, calibration_format=True)
-
-        for sc_name, sc_params in scenario_params.items():
-            params = params_dict.update(sc_params)
-            model.run(params.to_dict())
-            derived_df = model.get_derived_outputs_df()[output_names]
-            derived_df["scenario"] = [sc_name] * len(derived_df)
-            derived_df["urun"] = [f"{chain}_{draw}"] * len(derived_df)
-
-            outputs_df = outputs_df.append(derived_df)
-
-    return outputs_df
+    full_runs = {}
+    for scenario in ["baseline", "scenario_1"]:
+        bcm = get_bcm_object(iso3, analysis, scenario)
+        full_run_params = sampled_params[list(bcm.priors)]  # keep only this scenario's priors, dropping calibrated parameters that the scenario overrides
+        full_runs[scenario] = esamp.model_results_for_samples(full_run_params, bcm)
+    return full_runs
 
 
 def diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, column, relative=False):
     if not relative:
         return outputs_df_latest_0[column] - outputs_df_latest_1[column]
     else:
         return (outputs_df_latest_0[column] - outputs_df_latest_1[column]) / outputs_df_latest_1[column]
-    
+
 
-def calculate_diff_outputs(outputs_df):
+def get_uncertainty_dfs(full_runs, quantiles=[.025, .25, .5, .75, .975]):
+    unc_dfs = {}
+    for scenario in full_runs:
+        unc_df = esamp.quantiles_for_results(full_runs[scenario].results, quantiles)
+        unc_df.columns.set_levels([str(q) for q in unc_df.columns.levels[1]], level=1, inplace=True)  # to avoid using floats as column names (not parquet-compatible)
+        unc_dfs[scenario] = unc_df
 
-    index = outputs_df['urun'].unique()
-    latest_time = outputs_df.index.max()
+    return unc_dfs
 
-    outputs_df_latest_0 = outputs_df[outputs_df['scenario'] == "baseline"].loc[latest_time]
-    outputs_df_latest_0.index = outputs_df_latest_0['urun']
-    outputs_df_latest_1 = outputs_df[outputs_df['scenario'] == "scenario_1"].loc[latest_time]
-    outputs_df_latest_1.index = outputs_df_latest_1['urun']
 
+def calculate_diff_output_quantiles(full_runs, quantiles=[.025, .25, .5, .75, .975]):
+    diff_names = {
+        "cases_averted": "cumulative_incidence",
+        "deaths_averted": "cumulative_infection_deaths",
+        "delta_hospital_peak": "peak_hospital_occupancy",
+        "delta_student_weeks_missed": "student_weeks_missed"
+    }
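+    # Differences below are evaluated at the final model time, as baseline minus scenario_1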
 
-    diff_outputs_df = pd.DataFrame(index=index)
-    diff_outputs_df.index.name = "urun"
-
-    diff_outputs_df["cases_averted"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "cumulative_incidence")
-    diff_outputs_df["cases_averted_relative"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "cumulative_incidence", relative=True)
-
-    diff_outputs_df["deaths_averted"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "cumulative_infection_deaths")
-    diff_outputs_df["deaths_averted_relative"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "cumulative_infection_deaths", relative=True)
-
-    diff_outputs_df["delta_student_weeks_missed"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "student_weeks_missed")
-
-    diff_outputs_df["delta_hospital_peak"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "peak_hospital_occupancy")
-    diff_outputs_df["delta_hospital_peak_relative"] = diff_latest_output(outputs_df_latest_0, outputs_df_latest_1, "peak_hospital_occupancy", relative=True)
-
-    return diff_outputs_df
-
-
-def get_quantile_outputs(outputs_df, diff_outputs_df, quantiles=[.025, .25, .5, .75, .975]):
-
-    times = sorted(outputs_df.index.unique())
-    scenarios = outputs_df["scenario"].unique()
-    unc_output_names = [
-        "infection_deaths_ma7", "prop_ever_infected_age_matched", "prop_ever_infected", "transformed_random_process", "random_process_auc",
-        "cumulative_incidence", "cumulative_infection_deaths", "peak_hospital_occupancy", "hospital_occupancy"
-    ]
-
-    uncertainty_data = []
-    for scenario in scenarios:
-        scenario_mask = outputs_df["scenario"] == scenario
-        scenario_df = outputs_df[scenario_mask]
-
-        for time in times:
-            masked_df = scenario_df.loc[time]
-            if masked_df.empty:
-                continue
-            for output_name in unc_output_names:
-                quantile_vals = np.quantile(masked_df[output_name], quantiles)
-                for q_idx, q_value in enumerate(quantile_vals):
-                    datum = {
-                        "scenario": scenario,
-                        "type": output_name,
-                        "time": time,
-                        "quantile": quantiles[q_idx],
-                        "value": q_value,
-                    }
-                    uncertainty_data.append(datum)
-
-    uncertainty_df = pd.DataFrame(uncertainty_data)
-
-    diff_quantiles_df = pd.DataFrame(index=quantiles, data={col: np.quantile(diff_outputs_df[col], quantiles) for col in diff_outputs_df.columns})
+    latest_time = full_runs['baseline'].results.index.max()
 
-    return uncertainty_df, diff_quantiles_df
+    runs_0_latest = full_runs['baseline'].results.loc[latest_time]
+    runs_1_latest = full_runs['scenario_1'].results.loc[latest_time]
+
+    abs_diff = runs_0_latest - runs_1_latest
+    rel_diff = (runs_0_latest - runs_1_latest) / runs_1_latest
+
+    diff_quantiles_df_abs = pd.DataFrame(
+        index=quantiles,
+        data={colname: abs_diff[output_name].quantile(quantiles) for colname, output_name in diff_names.items()}
+    )
+    diff_quantiles_df_rel = pd.DataFrame(
+        index=quantiles,
+        data={f"{colname}_relative": rel_diff[output_name].quantile(quantiles) for colname, output_name in diff_names.items()}
+    )
+
+    return pd.concat([diff_quantiles_df_abs, diff_quantiles_df_rel], axis=1)
 
 
 """
@@ -360,11 +292,14 @@ def opti_func(sample_dict):
         best_p, _ = optimise_model_fit(bcm, num_workers=n_opti_workers, search_iterations=opti_params['search_iterations'], suggested_start=sample_dict)
         return best_p
 
-    best_params = map_parallel(opti_func, sample_as_dicts, n_workers=2 * n_cores / n_opti_workers)  # oversubscribing
+    best_params = map_parallel(opti_func, sample_as_dicts, n_workers=int(2 * n_cores / n_opti_workers))  # oversubscribing
 
     # Store optimal solutions
     with open(out_path / "best_params.yml", "w") as f:
         yaml.dump(best_params, f)
 
+    if logger:
+        logger.info("... optimisation completed")
optimisation completed") + # Keep only n_chains best solutions loglikelihoods = [bcm.loglikelihood(**p) for p in best_params] ll_cutoff = sorted(loglikelihoods, reverse=True)[mcmc_params['chains'] - 1] @@ -379,19 +314,11 @@ def opti_func(sample_dict): with open(out_path / "retained_best_params.yml", "w") as f: yaml.dump(retained_best_params, f) - # Plot optimal solutions and starting points + # Plot optimal solutions and matching starting points plot_opti_params(retained_init_points, retained_best_params, bcm, output_folder) - # # Plot optimal model fits - # opt_fits_path = out_path / "optimised_fits" - # opt_fits_path.mkdir(exist_ok=True) - # for j, best_p in enumerate(retained_best_params): - # plot_model_fit(bcm, best_p, iso3, opt_fits_path / f"best_fit_{j}.png") - + # Plot optimised model fits on a same figure plot_multiple_model_fits(bcm, retained_best_params, out_path / "optimal_fits.png") - - if logger: - logger.info("... optimisation completed") # Early return if MCMC not requested if mcmc_params['draws'] == 0: @@ -415,22 +342,27 @@ def opti_func(sample_dict): """ Post-MCMC processes """ - sample_df = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in']) + sampled_params = extract_sample_subset(idata, full_run_params['samples'], full_run_params['burn_in']) + if logger: logger.info(f"Perform full runs for {full_run_params['samples']} samples") - outputs_df = run_full_runs(sample_df, iso3, analysis) + full_runs = run_full_runs(sampled_params, iso3, analysis) + if logger: - logger.info("Calculate differential outputs") - diff_outputs_df = calculate_diff_outputs(outputs_df) + logger.info("Calculate uncertainty quantiles") + unc_dfs = get_uncertainty_dfs(full_runs) + for scenario, unc_df in unc_dfs.items(): + unc_df.to_parquet(out_path / f"uncertainty_df_{scenario}.parquet") + if logger: - logger.info("Calculate quantiles") - uncertainty_df, diff_quantiles_df = get_quantile_outputs(outputs_df, diff_outputs_df) - uncertainty_df.to_parquet(out_path / "uncertainty_df.parquet") + logger.info("Calculate differential output quantiles") + diff_quantiles_df = calculate_diff_output_quantiles(full_runs) diff_quantiles_df.to_parquet(out_path / "diff_quantiles_df.parquet") - make_country_output_tiling(iso3, uncertainty_df, diff_quantiles_df, output_folder) + # Make multi-panel figure #FIXME: not compatible with new unc_dfs format + # make_country_output_tiling(iso3, unc_dfs, diff_quantiles_df, output_folder) - return idata, uncertainty_df, diff_quantiles_df + return idata, unc_dfs, diff_quantiles_df """ Helper functions for remote runs diff --git a/notebooks/user/rragonnet/project_specific/School_Closure/full_analysis_runner.ipynb b/notebooks/user/rragonnet/project_specific/School_Closure/full_analysis_runner.ipynb index 8518fd74f..477252f3d 100644 --- a/notebooks/user/rragonnet/project_specific/School_Closure/full_analysis_runner.ipynb +++ b/notebooks/user/rragonnet/project_specific/School_Closure/full_analysis_runner.ipynb @@ -36,6 +36,8 @@ "ISO3 = \"FRA\"\n", "ANALYSIS = \"main\"\n", "\n", + "N_CORES = 8\n", + "\n", "N_CHAINS = 2 # DEFAULT_RUN_CONFIG['N_CHAINS']\n", "N_OPTI_SEARCHES = 2 # DEFAULT_RUN_CONFIG['N_OPTI_SEARCHES']\n", "OPTI_BUDGET = 100 # DEFAULT_RUN_CONFIG['OPTI_BUDGET']\n", @@ -54,15 +56,32 @@ "metadata": {}, "outputs": [], "source": [ - "idata, uncertainty_df, diff_quantiles_df = run_full_analysis(\n", + "from estival.utils.parallel import map_parallel" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], 
+ "source": [ + "idata, unc_dfs, diff_quantiles_df = run_full_analysis(\n", " ISO3,\n", " analysis=ANALYSIS, \n", + " n_cores=N_CORES,\n", " opti_params={'n_searches': N_OPTI_SEARCHES, 'search_iterations': OPTI_BUDGET},\n", " mcmc_params={'draws': METROPOLIS_DRAWS, 'tune': METROPOLIS_TUNE, 'chains': N_CHAINS, 'method': METROPOLIS_METHOD},\n", " full_run_params={'samples': FULL_RUNS_SAMPLES, 'burn_in': BURN_IN},\n", " output_folder=\"test_full\",\n", ")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {