Skip to content

Commit

Permalink
Can now switch between nevergrad optimisers
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-ragonnet committed Jul 12, 2023
1 parent d0903eb commit d68b86e
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions autumn/projects/sm_covid2/common_school/runner_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from estival.wrappers import nevergrad as eng
from estival.wrappers import pymc as epm

import nevergrad as ng

from estival.utils.parallel import map_parallel

from autumn.core.runs import ManagedRun
Expand Down Expand Up @@ -55,10 +57,10 @@ def sample_with_lhs(n_samples, bcm):
return sample_as_dicts


def optimise_model_fit(bcm, num_workers: int = 8, warmup_iterations: int = 2000, search_iterations: int = 5000, suggested_start: dict = None):
def optimise_model_fit(bcm, num_workers: int = 8, warmup_iterations: int = 2000, search_iterations: int = 5000, suggested_start: dict = None, opt_class=ng.optimizers.NGOpt):

# Build optimizer
opt = eng.optimize_model(bcm, obj_function=bcm.loglikelihood, suggested=suggested_start, num_workers=num_workers)
opt = eng.optimize_model(bcm, obj_function=bcm.loglikelihood, suggested=suggested_start, num_workers=num_workers, opt_class=opt_class)

# Run warm-up iterations and
if warmup_iterations > 0:
Expand All @@ -71,11 +73,11 @@ def optimise_model_fit(bcm, num_workers: int = 8, warmup_iterations: int = 2000,
return best_params, opt


def multi_country_optimise(iso3_list: list, analysis: str = "main", num_workers: int = 8, search_iterations: int = 7000, parallel_opti_jobs: int = 4, logger=None, out_path: str = None):
def multi_country_optimise(iso3_list: list, analysis: str = "main", num_workers: int = 8, search_iterations: int = 7000, parallel_opti_jobs: int = 4, logger=None, out_path: str = None, opt_class=ng.optimizers.NGOpt):

def country_opti_wrapper(iso3):
bcm = get_bcm_object(iso3, analysis)
best_p, _ = optimise_model_fit(bcm, num_workers=num_workers, warmup_iterations=0, search_iterations=search_iterations)
best_p, _ = optimise_model_fit(bcm, num_workers=num_workers, warmup_iterations=0, search_iterations=search_iterations, opt_class=opt_class)
return best_p

if logger:
Expand Down

0 comments on commit d68b86e

Please sign in to comment.