diff --git a/ax/benchmark/benchmark.py b/ax/benchmark/benchmark.py index c657af1971f..3e7d295302c 100644 --- a/ax/benchmark/benchmark.py +++ b/ax/benchmark/benchmark.py @@ -125,7 +125,6 @@ def benchmark_replication( generation_strategy=method.generation_strategy.clone_reset(), options=method.scheduler_options, ) - timeout_hours = scheduler.options.timeout_hours # list of parameters for each trial best_params_by_trial: list[list[TParameterization]] = [] @@ -133,18 +132,20 @@ def benchmark_replication( is_mf_or_mt = len(problem.runner.target_fidelity_and_task) > 0 # Run the optimization loop. timeout_hours = scheduler.options.timeout_hours + remaining_hours = timeout_hours with with_rng_seed(seed=seed): start = monotonic() for _ in range(problem.num_trials): next( scheduler.run_trials_and_yield_results( - max_trials=1, timeout_hours=timeout_hours + max_trials=1, timeout_hours=remaining_hours ) ) if timeout_hours is not None: elapsed_hours = (monotonic() - start) / 3600 - timeout_hours = timeout_hours - elapsed_hours - if timeout_hours <= 0: + remaining_hours = timeout_hours - elapsed_hours + if remaining_hours <= 0.0: + logger.warning("The optimization loop timed out.") break if problem.is_moo or is_mf_or_mt: diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index 5985f9919c5..70b07a09e3d 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -7,6 +7,7 @@ import tempfile from itertools import product +from time import monotonic from unittest.mock import patch import numpy as np @@ -448,6 +449,7 @@ def test_timeout(self) -> None: num_sobol_trials=1000, # Ensures we don't use BO ).generation_strategy + timeout_seconds = 2.0 method = BenchmarkMethod( name=generation_strategy.name, generation_strategy=generation_strategy, @@ -455,13 +457,21 @@ def test_timeout(self) -> None: max_pending_trials=1, init_seconds_between_polls=0, min_seconds_before_poll=0, - timeout_hours=0.0001, # Strict timeout of 0.36 seconds + timeout_hours=timeout_seconds / 3600, ), ) # Each replication will have a different number of trials - result = benchmark_one_method_problem( - problem=problem, method=method, seeds=(0, 1) + + start = monotonic() + with self.assertLogs("ax.benchmark.benchmark", level="WARNING") as cm: + result = benchmark_one_method_problem( + problem=problem, method=method, seeds=(0, 1) + ) + elapsed = monotonic() - start + self.assertGreater(elapsed, timeout_seconds) + self.assertIn( + "WARNING:ax.benchmark.benchmark:The optimization loop timed out.", cm.output ) # Test the traces get composited correctly. The AggregatedResult's traces