diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 4200d1015b0..c83c83636d4 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -956,13 +956,12 @@ def trial_from_sqa( new_generator_run_structs.append(struct) generator_run_structs = new_generator_run_structs trial._generator_run_structs = generator_run_structs - if not reduced_state: - trial._abandoned_arms_metadata = { - abandoned_arm_sqa.name: self.abandoned_arm_from_sqa( - abandoned_arm_sqa=abandoned_arm_sqa - ) - for abandoned_arm_sqa in trial_sqa.abandoned_arms - } + trial._abandoned_arms_metadata = { + abandoned_arm_sqa.name: self.abandoned_arm_from_sqa( + abandoned_arm_sqa=abandoned_arm_sqa + ) + for abandoned_arm_sqa in trial_sqa.abandoned_arms + } trial._refresh_arms_by_name() # Trigger cache build else: trial = Trial( diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index f418a7d6e62..58ebdcce10a 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -279,7 +279,6 @@ def _get_experiment_sqa_reduced_state( large experiments, in cases where model state history is not required. """ options = get_query_options_to_defer_immutable_duplicates() - options.append(lazyload("abandoned_arms")) options.extend(get_query_options_to_defer_large_model_cols()) return _get_experiment_sqa( diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 45f14c3f2cd..6dfafd31637 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -22,7 +22,7 @@ from ax.core.arm import Arm from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose from ax.core.base_trial import TrialStatus -from ax.core.batch_trial import BatchTrial, LifecycleStage +from ax.core.batch_trial import LifecycleStage from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric @@ -480,21 +480,6 @@ def test_ExperimentSaveAndLoadReducedState( for skip_runners_and_metrics in [False, True]: # 1. No abandoned arms + no trials case, reduced state should be the # same as non-reduced state. - exp = get_experiment_with_multi_objective() - save_experiment(exp) - loaded_experiment = load_experiment( - exp.name, - reduced_state=True, - skip_runners_and_metrics=skip_runners_and_metrics, - ) - loaded_experiment.runner = exp.runner - self.assertEqual(loaded_experiment, exp) - # Make sure decoder function was called with `reduced_state=True`. - self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state")) - _mock_exp_from_sqa.reset_mock() - delete_experiment(exp_name=exp.name) - - # 2. Try case with abandoned arms. exp = get_experiment_with_batch_trial(constrain_search_space=False) save_experiment(exp) loaded_experiment = load_experiment( @@ -502,23 +487,15 @@ def test_ExperimentSaveAndLoadReducedState( reduced_state=True, skip_runners_and_metrics=skip_runners_and_metrics, ) - # Experiments are not the same, because one has abandoned arms info. - self.assertNotEqual(loaded_experiment, exp) - # Remove all abandoned arms and check that all else is equal as expected. - t = assert_is_instance(exp.trials[0], BatchTrial) - t._abandoned_arms_metadata = {} loaded_experiment.runner = exp.runner - loaded_experiment._trials[0]._runner = exp._trials[0]._runner + for i in loaded_experiment.trials.keys(): + loaded_experiment.trials[i]._runner = exp.trials[i].runner self.assertEqual(loaded_experiment, exp) - # Make sure that all relevant decoding functions were called with - # `reduced_state=True` and correct number of times. + # Make sure decoder function was called with `reduced_state=True`. self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state")) self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state")) - # 2 generator runs + regular and status quo. self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state")) _mock_exp_from_sqa.reset_mock() - _mock_trial_from_sqa.reset_mock() - _mock_gr_from_sqa.reset_mock() # 3. Try case with model state and search space + opt.config on a # generator run in the experiment. @@ -567,6 +544,24 @@ def test_ExperimentSaveAndLoadReducedState( self.assertEqual(loaded_experiment, exp) delete_experiment(exp_name=exp.name) + def test_load_and_save_reduced_state_does_not_lose_abandoned_arms(self) -> None: + exp = get_experiment_with_batch_trial(constrain_search_space=False) + exp.trials[0].mark_arm_abandoned(arm_name="0_0", reason="for this test") + save_experiment(exp) + loaded_experiment = load_experiment( + exp.name, reduced_state=True, skip_runners_and_metrics=True + ) + save_experiment(loaded_experiment) + reloaded_experiment = load_experiment(exp.name) + self.assertEqual( + reloaded_experiment.trials[0].abandoned_arms, + exp.trials[0].abandoned_arms, + ) + self.assertEqual( + len(reloaded_experiment.trials[0].abandoned_arms), + 1, + ) + def test_ExperimentSaveAndLoadGRWithOptConfig(self) -> None: exp = get_experiment_with_batch_trial(constrain_search_space=False) gr = Models.SOBOL(experiment=exp).gen(