diff --git a/mlos_bench/mlos_bench/config/schedulers/sync_scheduler.jsonc b/mlos_bench/mlos_bench/config/schedulers/sync_scheduler.jsonc
index c3b6438caa..daf95d56fa 100644
--- a/mlos_bench/mlos_bench/config/schedulers/sync_scheduler.jsonc
+++ b/mlos_bench/mlos_bench/config/schedulers/sync_scheduler.jsonc
@@ -6,6 +6,7 @@
 
     "config": {
         "trial_config_repeat_count": 3,
+        "max_trials": -1,   // Limited only in the Optimizer logic/config.
         "teardown": false
     }
 }
diff --git a/mlos_bench/mlos_bench/config/schemas/schedulers/scheduler-schema.json b/mlos_bench/mlos_bench/config/schemas/schedulers/scheduler-schema.json
index 53a3f02b09..81b2e79754 100644
--- a/mlos_bench/mlos_bench/config/schemas/schedulers/scheduler-schema.json
+++ b/mlos_bench/mlos_bench/config/schemas/schedulers/scheduler-schema.json
@@ -25,6 +25,12 @@
                     "description": "Whether to teardown the experiment after running it.",
                     "type": "boolean"
                 },
+                "max_trials": {
+                    "description": "Max. number of trials to run. Use -1 or 0 for unlimited.",
+                    "type": "integer",
+                    "minimum": -1,
+                    "examples": [50, -1]
+                },
                 "trial_config_repeat_count": {
                     "description": "Number of times to repeat a config.",
                     "type": "integer",
diff --git a/mlos_bench/mlos_bench/schedulers/base_scheduler.py b/mlos_bench/mlos_bench/schedulers/base_scheduler.py
index 77e7ccbc95..6e3da151e5 100644
--- a/mlos_bench/mlos_bench/schedulers/base_scheduler.py
+++ b/mlos_bench/mlos_bench/schedulers/base_scheduler.py
@@ -67,6 +67,8 @@ def __init__(self, *,
         self._experiment_id = config["experiment_id"].strip()
         self._trial_id = int(config["trial_id"])
         self._config_id = int(config.get("config_id", -1))
+        self._max_trials = int(config.get("max_trials", -1))
+        self._trial_count = 0
 
         self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
         if self._trial_config_repeat_count <= 0:
@@ -192,12 +194,12 @@ def _schedule_new_optimizer_suggestions(self) -> bool:
         self.optimizer.bulk_register(configs, scores, status)
         self._last_trial_id = max(trial_ids, default=self._last_trial_id)
 
-        not_converged = self.optimizer.not_converged()
-        if not_converged:
+        not_done = self.not_done()
+        if not_done:
             tunables = self.optimizer.suggest()
             self.schedule_trial(tunables)
 
-        return not_converged
+        return not_done
 
     def schedule_trial(self, tunables: TunableGroups) -> None:
         """
@@ -240,10 +242,20 @@ def _run_schedule(self, running: bool = False) -> None:
         for trial in self.experiment.pending_trials(datetime.now(UTC), running=running):
             self.run_trial(trial)
 
+    def not_done(self) -> bool:
+        """
+        Check the stopping conditions.
+        By default, stop when the optimizer converges or the maximum number of trials is reached.
+        """
+        return self.optimizer.not_converged() and (
+            self._trial_count < self._max_trials or self._max_trials <= 0
+        )
+
     @abstractmethod
     def run_trial(self, trial: Storage.Trial) -> None:
         """
         Set up and run a single trial. Save the results in the storage.
""" assert self.experiment is not None - _LOG.info("QUEUE: Execute trial: %s", trial) + self._trial_count += 1 + _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial) diff --git a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py index 557e64ffa4..a73a493533 100644 --- a/mlos_bench/mlos_bench/schedulers/sync_scheduler.py +++ b/mlos_bench/mlos_bench/schedulers/sync_scheduler.py @@ -33,11 +33,11 @@ def start(self) -> None: if not is_warm_up: _LOG.warning("Skip pending trials and warm-up: %s", self.optimizer) - not_converged = True - while not_converged: + not_done = True + while not_done: _LOG.info("Optimization loop: Last trial ID: %d", self._last_trial_id) self._run_schedule(is_warm_up) - not_converged = self._schedule_new_optimizer_suggestions() + not_done = self._schedule_new_optimizer_suggestions() is_warm_up = False def run_trial(self, trial: Storage.Trial) -> None: diff --git a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/good/full/sync_sched-full.jsonc b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/good/full/sync_sched-full.jsonc index 63694dac4a..c72e8f4d15 100644 --- a/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/good/full/sync_sched-full.jsonc +++ b/mlos_bench/mlos_bench/tests/config/schemas/schedulers/test-cases/good/full/sync_sched-full.jsonc @@ -6,6 +6,7 @@ "teardown": false, "experiment_id": "MyExperimentName", "config_id": 1, - "trial_id": 1 + "trial_id": 1, + "max_trials": 100 } } diff --git a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py index bddee5729e..90e52bb880 100644 --- a/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py +++ b/mlos_bench/mlos_bench/tests/launcher_parse_args_test.py @@ -109,6 +109,7 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None: # Check that we pick up the right scheduler config: assert isinstance(launcher.scheduler, SyncScheduler) assert launcher.scheduler._trial_config_repeat_count == 3 # pylint: disable=protected-access + assert launcher.scheduler._max_trials == -1 # pylint: disable=protected-access def test_launcher_args_parse_2(config_paths: List[str]) -> None: @@ -136,7 +137,8 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: ' --no-teardown' + \ ' --random-init' + \ ' --random-seed 1234' + \ - ' --trial-config-repeat-count 5' + ' --trial-config-repeat-count 5' + \ + ' --max_trials 200' launcher = Launcher(description=__name__, argv=cli_args.split()) # Check that the parent service assert isinstance(launcher.service, SupportsAuth) @@ -188,6 +190,7 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None: # Check that CLI parameter overrides JSON config: assert isinstance(launcher.scheduler, SyncScheduler) assert launcher.scheduler._trial_config_repeat_count == 5 # pylint: disable=protected-access + assert launcher.scheduler._max_trials == 200 # pylint: disable=protected-access # Check that the value from the file is overridden by the CLI arg. assert config['random_seed'] == 42