Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't discard abandoned arms on reduced state #3261

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,13 +956,12 @@ def trial_from_sqa(
new_generator_run_structs.append(struct)
generator_run_structs = new_generator_run_structs
trial._generator_run_structs = generator_run_structs
if not reduced_state:
trial._abandoned_arms_metadata = {
abandoned_arm_sqa.name: self.abandoned_arm_from_sqa(
abandoned_arm_sqa=abandoned_arm_sqa
)
for abandoned_arm_sqa in trial_sqa.abandoned_arms
}
trial._abandoned_arms_metadata = {
abandoned_arm_sqa.name: self.abandoned_arm_from_sqa(
abandoned_arm_sqa=abandoned_arm_sqa
)
for abandoned_arm_sqa in trial_sqa.abandoned_arms
}
trial._refresh_arms_by_name() # Trigger cache build
else:
trial = Trial(
Expand Down
1 change: 0 additions & 1 deletion ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@ def _get_experiment_sqa_reduced_state(
large experiments, in cases where model state history is not required.
"""
options = get_query_options_to_defer_immutable_duplicates()
options.append(lazyload("abandoned_arms"))
options.extend(get_query_options_to_defer_large_model_cols())

return _get_experiment_sqa(
Expand Down
49 changes: 22 additions & 27 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import BatchTrial, LifecycleStage
from ax.core.batch_trial import LifecycleStage
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
Expand Down Expand Up @@ -480,45 +480,22 @@ def test_ExperimentSaveAndLoadReducedState(
for skip_runners_and_metrics in [False, True]:
# 1. No abandoned arms + no trials case, reduced state should be the
# same as non-reduced state.
exp = get_experiment_with_multi_objective()
save_experiment(exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
loaded_experiment.runner = exp.runner
self.assertEqual(loaded_experiment, exp)
# Make sure decoder function was called with `reduced_state=True`.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
delete_experiment(exp_name=exp.name)

# 2. Try case with abandoned arms.
exp = get_experiment_with_batch_trial(constrain_search_space=False)
save_experiment(exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
# Experiments are not the same, because one has abandoned arms info.
self.assertNotEqual(loaded_experiment, exp)
# Remove all abandoned arms and check that all else is equal as expected.
t = assert_is_instance(exp.trials[0], BatchTrial)
t._abandoned_arms_metadata = {}
loaded_experiment.runner = exp.runner
loaded_experiment._trials[0]._runner = exp._trials[0]._runner
for i in loaded_experiment.trials.keys():
loaded_experiment.trials[i]._runner = exp.trials[i].runner
self.assertEqual(loaded_experiment, exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
# Make sure decoder function was called with `reduced_state=True`.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs + regular and status quo.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
_mock_trial_from_sqa.reset_mock()
_mock_gr_from_sqa.reset_mock()

# 3. Try case with model state and search space + opt.config on a
# generator run in the experiment.
Expand Down Expand Up @@ -567,6 +544,24 @@ def test_ExperimentSaveAndLoadReducedState(
self.assertEqual(loaded_experiment, exp)
delete_experiment(exp_name=exp.name)

def test_load_and_save_reduced_state_does_not_lose_abandoned_arms(self) -> None:
exp = get_experiment_with_batch_trial(constrain_search_space=False)
exp.trials[0].mark_arm_abandoned(arm_name="0_0", reason="for this test")
save_experiment(exp)
loaded_experiment = load_experiment(
exp.name, reduced_state=True, skip_runners_and_metrics=True
)
save_experiment(loaded_experiment)
reloaded_experiment = load_experiment(exp.name)
self.assertEqual(
reloaded_experiment.trials[0].abandoned_arms,
exp.trials[0].abandoned_arms,
)
self.assertEqual(
len(reloaded_experiment.trials[0].abandoned_arms),
1,
)

def test_ExperimentSaveAndLoadGRWithOptConfig(self) -> None:
exp = get_experiment_with_batch_trial(constrain_search_space=False)
gr = Models.SOBOL(experiment=exp).gen(
Expand Down
Loading