diff --git a/autumn/projects/sm_covid2/common_school/runner_tools.py b/autumn/projects/sm_covid2/common_school/runner_tools.py index 01d16b746..38e17cd85 100644 --- a/autumn/projects/sm_covid2/common_school/runner_tools.py +++ b/autumn/projects/sm_covid2/common_school/runner_tools.py @@ -14,6 +14,8 @@ from estival.wrappers import nevergrad as eng from estival.wrappers import pymc as epm +import nevergrad as ng + from estival.utils.parallel import map_parallel from autumn.core.runs import ManagedRun @@ -55,10 +57,10 @@ def sample_with_lhs(n_samples, bcm): return sample_as_dicts -def optimise_model_fit(bcm, num_workers: int = 8, warmup_iterations: int = 2000, search_iterations: int = 5000, suggested_start: dict = None): +def optimise_model_fit(bcm, num_workers: int = 8, warmup_iterations: int = 2000, search_iterations: int = 5000, suggested_start: dict = None, opt_class=ng.optimizers.NGOpt): # Build optimizer - opt = eng.optimize_model(bcm, obj_function=bcm.loglikelihood, suggested=suggested_start, num_workers=num_workers) + opt = eng.optimize_model(bcm, obj_function=bcm.loglikelihood, suggested=suggested_start, num_workers=num_workers, opt_class=opt_class) # Run warm-up iterations and if warmup_iterations > 0: @@ -71,11 +73,11 @@ def optimise_model_fit(bcm, num_workers: int = 8, warmup_iterations: int = 2000, return best_params, opt -def multi_country_optimise(iso3_list: list, analysis: str = "main", num_workers: int = 8, search_iterations: int = 7000, parallel_opti_jobs: int = 4, logger=None, out_path: str = None): +def multi_country_optimise(iso3_list: list, analysis: str = "main", num_workers: int = 8, search_iterations: int = 7000, parallel_opti_jobs: int = 4, logger=None, out_path: str = None, opt_class=ng.optimizers.NGOpt): def country_opti_wrapper(iso3): bcm = get_bcm_object(iso3, analysis) - best_p, _ = optimise_model_fit(bcm, num_workers=num_workers, warmup_iterations=0, search_iterations=search_iterations) + best_p, _ = optimise_model_fit(bcm, num_workers=num_workers, warmup_iterations=0, search_iterations=search_iterations, opt_class=opt_class) return best_p if logger: