Skip to content

Commit

Permalink
Revise LHS sampling code
Browse files Browse the repository at this point in the history
  • Loading branch information
romain-ragonnet committed Sep 14, 2023
1 parent b23ae7c commit 4fe8be8
Showing 1 changed file with 2 additions and 19 deletions.
21 changes: 2 additions & 19 deletions autumn/projects/sm_covid2/common_school/runner_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,24 +75,6 @@
Functions related to model calibration
"""

def sample_with_lhs(n_samples, bcm):

# sample using LHS in the right dimension
lhs_sampled_params = [p for p in bcm.priors if p != "random_process.delta_values"]
d = len(lhs_sampled_params)
sampler = qmc.LatinHypercube(d=d)
regular_sample = sampler.random(n=n_samples)

# scale the data cube to match parameter bounds
l_bounds = [bcm.priors[p].bounds()[0] for p in lhs_sampled_params]
u_bounds = [bcm.priors[p].bounds()[1] for p in lhs_sampled_params]
sample = qmc.scale(regular_sample, l_bounds, u_bounds)

sample_as_dicts = [{p: sample[i][j] for j, p in enumerate(lhs_sampled_params)} for i in range(n_samples)]

return sample_as_dicts


def optimise_model_fit(bcm, num_workers: int = 8, warmup_iterations: int = 0, search_iterations: int = 5000, suggested_start: dict = None, opt_class=ng.optimizers.CMA):

# Build optimizer
Expand Down Expand Up @@ -251,7 +233,8 @@ def run_full_analysis(
# Sample optimisation starting points with LHS
if logger:
logger.info("Perform LHS sampling")
sample_as_dicts = sample_with_lhs(run_config['n_opti_searches'], bcm)
sample_as_dicts = bcm.sample.lhs(run_config['n_opti_searches'], out_type="list_of_dicts")
sample_as_dicts = [{p: val for p, val in d.items() if p != 'random_process.delta_values'} for d in sample_as_dicts] # remove random process from LHS sample

# Store starting points
with open(out_path / "LHS_init_points.yml", "w") as f:
Expand Down

0 comments on commit 4fe8be8

Please sign in to comment.