Skip to content

Commit

Permalink
Add max_trials scheduler config parameter (#719)
Browse files Browse the repository at this point in the history
maybe merge after #713 (they might conflict)

Part of issue #715

---------

Co-authored-by: Brian Kroth <bpkroth@users.noreply.github.com>
  • Loading branch information
motus and bpkroth authored Mar 19, 2024
1 parent d2e7f05 commit f25e09f
Show file tree
Hide file tree
Showing 6 changed files with 32 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

"config": {
"trial_config_repeat_count": 3,
"max_trials": -1, // Limited only in hte Optimizer logic/config.
"teardown": false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
"description": "Whether to teardown the experiment after running it.",
"type": "boolean"
},
"max_trials": {
"description": "Max. number of trials to run. Use -1 or 0 for unlimited.",
"type": "integer",
"minimum": -1,
"examples": [50, -1]
},
"trial_config_repeat_count": {
"description": "Number of times to repeat a config.",
"type": "integer",
Expand Down
20 changes: 16 additions & 4 deletions mlos_bench/mlos_bench/schedulers/base_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def __init__(self, *,
self._experiment_id = config["experiment_id"].strip()
self._trial_id = int(config["trial_id"])
self._config_id = int(config.get("config_id", -1))
self._max_trials = int(config.get("max_trials", -1))
self._trial_count = 0

self._trial_config_repeat_count = int(config.get("trial_config_repeat_count", 1))
if self._trial_config_repeat_count <= 0:
Expand Down Expand Up @@ -192,12 +194,12 @@ def _schedule_new_optimizer_suggestions(self) -> bool:
self.optimizer.bulk_register(configs, scores, status)
self._last_trial_id = max(trial_ids, default=self._last_trial_id)

not_converged = self.optimizer.not_converged()
if not_converged:
not_done = self.not_done()
if not_done:
tunables = self.optimizer.suggest()
self.schedule_trial(tunables)

return not_converged
return not_done

def schedule_trial(self, tunables: TunableGroups) -> None:
"""
Expand Down Expand Up @@ -240,10 +242,20 @@ def _run_schedule(self, running: bool = False) -> None:
for trial in self.experiment.pending_trials(datetime.now(UTC), running=running):
self.run_trial(trial)

def not_done(self) -> bool:
    """
    Check the stopping conditions.

    By default, stop when the optimizer has converged or when the maximum
    number of trials has been reached.

    Returns
    -------
    bool
        True if the scheduler should keep going, False to stop.
    """
    if not self.optimizer.not_converged():
        return False
    # A non-positive max_trials means "no limit on the number of trials".
    return self._max_trials <= 0 or self._trial_count < self._max_trials

@abstractmethod
def run_trial(self, trial: Storage.Trial) -> None:
    """
    Set up and run a single trial. Save the results in the storage.

    Subclasses must extend this method with the actual trial execution;
    this base implementation only performs the shared bookkeeping.

    Parameters
    ----------
    trial : Storage.Trial
        The trial to set up and run.
    """
    assert self.experiment is not None
    # Count the trial before running it so the log shows progress vs. the
    # configured cap (max_trials <= 0 means unlimited).
    self._trial_count += 1
    _LOG.info("QUEUE: Execute trial # %d/%d :: %s", self._trial_count, self._max_trials, trial)
6 changes: 3 additions & 3 deletions mlos_bench/mlos_bench/schedulers/sync_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def start(self) -> None:
if not is_warm_up:
_LOG.warning("Skip pending trials and warm-up: %s", self.optimizer)

not_converged = True
while not_converged:
not_done = True
while not_done:
_LOG.info("Optimization loop: Last trial ID: %d", self._last_trial_id)
self._run_schedule(is_warm_up)
not_converged = self._schedule_new_optimizer_suggestions()
not_done = self._schedule_new_optimizer_suggestions()
is_warm_up = False

def run_trial(self, trial: Storage.Trial) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"teardown": false,
"experiment_id": "MyExperimentName",
"config_id": 1,
"trial_id": 1
"trial_id": 1,
"max_trials": 100
}
}
5 changes: 4 additions & 1 deletion mlos_bench/mlos_bench/tests/launcher_parse_args_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def test_launcher_args_parse_1(config_paths: List[str]) -> None:
# Check that we pick up the right scheduler config:
assert isinstance(launcher.scheduler, SyncScheduler)
assert launcher.scheduler._trial_config_repeat_count == 3 # pylint: disable=protected-access
assert launcher.scheduler._max_trials == -1 # pylint: disable=protected-access


def test_launcher_args_parse_2(config_paths: List[str]) -> None:
Expand Down Expand Up @@ -136,7 +137,8 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None:
' --no-teardown' + \
' --random-init' + \
' --random-seed 1234' + \
' --trial-config-repeat-count 5'
' --trial-config-repeat-count 5' + \
' --max_trials 200'
launcher = Launcher(description=__name__, argv=cli_args.split())
# Check that the parent service
assert isinstance(launcher.service, SupportsAuth)
Expand Down Expand Up @@ -188,6 +190,7 @@ def test_launcher_args_parse_2(config_paths: List[str]) -> None:
# Check that CLI parameter overrides JSON config:
assert isinstance(launcher.scheduler, SyncScheduler)
assert launcher.scheduler._trial_config_repeat_count == 5 # pylint: disable=protected-access
assert launcher.scheduler._max_trials == 200 # pylint: disable=protected-access

# Check that the value from the file is overridden by the CLI arg.
assert config['random_seed'] == 42
Expand Down

0 comments on commit f25e09f

Please sign in to comment.