Skip to content

Commit

Permalink
Add concept of recoverable errors respected by Scheduler (#3262)
Browse files Browse the repository at this point in the history
Summary:

The motivation is that some metrics are flaky and we don't want to fail the trial just because we encountered one exception fetching.  Especially trials with `period_of_new_data_after_trial_completion()` > 0.

This alternative to implementing this on the metric is that the set of recoverable errors should be a scheduler option, and it's more a matter of scheduler use case than metric.

Reviewed By: Cesar-Cardoso

Differential Revision: D68273328
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Jan 23, 2025
1 parent 369b487 commit abc2381
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 25 deletions.
16 changes: 16 additions & 0 deletions ax/core/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,11 @@ class Metric(SortableBase, SerializationMixin):
"""

data_constructor: type[Data] = Data
# The set of exception types stored in a ``MetchFetchE.exception`` that are
# recoverable ``Scheduler._fetch_and_process_trials_data_results()``.
# Exception may be a subclass of any of these types. If you want your metric
# to never fail the trial, set this to ``{Exception}`` in your metric subclass.
recoverable_exceptions: set[type[Exception]] = set()

def __init__(
self,
Expand Down Expand Up @@ -138,6 +143,17 @@ def period_of_new_data_after_trial_completion(cls) -> timedelta:
"""
return timedelta(0)

@classmethod
def is_reconverable_fetch_e(cls, metric_fetch_e: MetricFetchE) -> bool:
"""Checks whether the given MetricFetchE is recoverable for this metric class
in ``Scheduler._fetch_and_process_trials_data_results``.
"""
if metric_fetch_e.exception is None:
return False
return any(
isinstance(metric_fetch_e.exception, e) for e in cls.recoverable_exceptions
)

# NOTE: This is rarely overridden –– oonly if you want to fetch data in groups
# consisting of multiple different metric classes, for data to be fetched together.
# This makes sense only if `fetch_trial data_multi` or `fetch_experiment_data_multi`
Expand Down
5 changes: 4 additions & 1 deletion ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2041,11 +2041,14 @@ def _fetch_and_process_trials_data_results(
)

# If the fetch failure was for a metric in the optimization config (an
# objective or constraint) the trial as failed
# objective or constraint) mark the trial as failed
optimization_config = self.experiment.optimization_config
if (
optimization_config is not None
and metric_name in optimization_config.metrics.keys()
and not self.experiment.metrics[
metric_name
].is_reconverable_fetch_e(metric_fetch_e=metric_fetch_e)
):
status = self._mark_err_trial_status(
trial=self.experiment.trials[trial_index],
Expand Down
147 changes: 123 additions & 24 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,12 @@
get_pending_observation_features_based_on_trial_status,
)
from ax.early_stopping.strategies import BaseEarlyStoppingStrategy
from ax.exceptions.core import OptimizationComplete, UnsupportedError, UserInputError
from ax.exceptions.core import (
AxError,
OptimizationComplete,
UnsupportedError,
UserInputError,
)
from ax.exceptions.generation_strategy import AxGenerationException
from ax.metrics.branin import BraninMetric
from ax.metrics.branin_map import BraninTimestampMapMetric
Expand Down Expand Up @@ -1981,43 +1986,137 @@ def test_fetch_and_process_trials_data_results_failed_objective(self) -> None:
experiment=self.branin_experiment,
generation_strategy=self.two_sobol_steps_GS,
)
scheduler = Scheduler(
experiment=self.branin_experiment,
generation_strategy=gs,
options=SchedulerOptions(
**self.scheduler_options_kwargs,
),
db_settings=self.db_settings_if_always_needed,
)
with patch(
f"{BraninMetric.__module__}.BraninMetric.f", side_effect=Exception("yikes!")
), patch(
f"{BraninMetric.__module__}.BraninMetric.is_available_while_running",
return_value=False,
), self.assertLogs(logger="ax.service.scheduler") as lg:
scheduler = Scheduler(
experiment=self.branin_experiment,
generation_strategy=gs,
options=SchedulerOptions(
**self.scheduler_options_kwargs,
),
db_settings=self.db_settings_if_always_needed,
# This trial will fail
with self.assertRaises(FailureRateExceededError):
scheduler.run_n_trials(max_trials=1)
self.assertTrue(
any(
re.search(r"Failed to fetch (branin|m1) for trial 0", warning)
is not None
for warning in lg.output
)
)
self.assertTrue(
any(
re.search(
r"Because (branin|m1) is an objective, marking trial 0 as "
"TrialStatus.FAILED",
warning,
)
is not None
for warning in lg.output
)
)
self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.FAILED)

def test_fetch_and_process_trials_data_results_failed_objective_but_recoverable(
self,
) -> None:
gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
generation_strategy=self.two_sobol_steps_GS,
)
scheduler = Scheduler(
experiment=self.branin_experiment,
generation_strategy=gs,
options=SchedulerOptions(
enforce_immutable_search_space_and_opt_config=False,
**self.scheduler_options_kwargs,
),
db_settings=self.db_settings_if_always_needed,
)
BraninMetric.recoverable_exceptions = {AxError, TypeError}
# we're throwing a recoverable exception because UserInputError
# is a subclass of AxError
with patch(
f"{BraninMetric.__module__}.BraninMetric.f",
side_effect=UserInputError("yikes!"),
), patch(
f"{BraninMetric.__module__}.BraninMetric.is_available_while_running",
return_value=False,
), self.assertLogs(logger="ax.service.scheduler") as lg:
scheduler.run_n_trials(max_trials=1)
self.assertTrue(
any(
re.search(r"Failed to fetch (branin|m1) for trial 0", warning)
is not None
for warning in lg.output
),
lg.output,
)
self.assertTrue(
any(
re.search(
"MetricFetchE INFO: Continuing optimization even though "
"MetricFetchE encountered",
warning,
)
is not None
for warning in lg.output
)
)
self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.COMPLETED)

def test_fetch_and_process_trials_data_results_failed_objective_not_recoverable(
self,
) -> None:
gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
generation_strategy=self.two_sobol_steps_GS,
)
scheduler = Scheduler(
experiment=self.branin_experiment,
generation_strategy=gs,
options=SchedulerOptions(
**self.scheduler_options_kwargs,
),
db_settings=self.db_settings_if_always_needed,
)
# we're throwing a unrecoverable exception because Exception is not subclass
# of either error type in recoverable_exceptions
BraninMetric.recoverable_exceptions = {AxError, TypeError}
with patch(
f"{BraninMetric.__module__}.BraninMetric.f", side_effect=Exception("yikes!")
), patch(
f"{BraninMetric.__module__}.BraninMetric.is_available_while_running",
return_value=False,
), self.assertLogs(logger="ax.service.scheduler") as lg:
# This trial will fail
with self.assertRaises(FailureRateExceededError):
scheduler.run_n_trials(max_trials=1)
self.assertTrue(
any(
re.search(r"Failed to fetch (branin|m1) for trial 0", warning)
is not None
for warning in lg.output
)
self.assertTrue(
any(
re.search(r"Failed to fetch (branin|m1) for trial 0", warning)
is not None
for warning in lg.output
)
self.assertTrue(
any(
re.search(
r"Because (branin|m1) is an objective, marking trial 0 as "
"TrialStatus.FAILED",
warning,
)
is not None
for warning in lg.output
)
self.assertTrue(
any(
re.search(
r"Because (branin|m1) is an objective, marking trial 0 as "
"TrialStatus.FAILED",
warning,
)
is not None
for warning in lg.output
)
self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.FAILED)
)
self.assertEqual(scheduler.experiment.trials[0].status, TrialStatus.FAILED)

def test_should_consider_optimization_complete(self) -> None:
# Tests non-GSS parts of the completion criterion.
Expand Down

0 comments on commit abc2381

Please sign in to comment.