From 648e4f6578eed7ca6784c3c6fd787070db02187e Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Wed, 22 Jan 2025 13:16:42 -0800 Subject: [PATCH] Don't discard abandoned arms on reduced state Summary: Reduced state does still have use for leaving out gen metadata, but I don't know of any use cases with such a huge amount of abandoned arms this will make a difference (considering it can never make more than a 50% difference). It also affects the contents/functionality of the experiment by making it look like arms are not abandoned and should be used on trials. And loading and saving with reduced state loses records that arms were abandoned, which creates problems in experimentation. See N6450800 Differential Revision: D68514688 --- ax/storage/sqa_store/decoder.py | 13 +++--- ax/storage/sqa_store/load.py | 1 - ax/storage/sqa_store/tests/test_sqa_store.py | 47 +++++++++----------- 3 files changed, 27 insertions(+), 34 deletions(-) diff --git a/ax/storage/sqa_store/decoder.py b/ax/storage/sqa_store/decoder.py index 4200d1015b0..c83c83636d4 100644 --- a/ax/storage/sqa_store/decoder.py +++ b/ax/storage/sqa_store/decoder.py @@ -956,13 +956,12 @@ def trial_from_sqa( new_generator_run_structs.append(struct) generator_run_structs = new_generator_run_structs trial._generator_run_structs = generator_run_structs - if not reduced_state: - trial._abandoned_arms_metadata = { - abandoned_arm_sqa.name: self.abandoned_arm_from_sqa( - abandoned_arm_sqa=abandoned_arm_sqa - ) - for abandoned_arm_sqa in trial_sqa.abandoned_arms - } + trial._abandoned_arms_metadata = { + abandoned_arm_sqa.name: self.abandoned_arm_from_sqa( + abandoned_arm_sqa=abandoned_arm_sqa + ) + for abandoned_arm_sqa in trial_sqa.abandoned_arms + } trial._refresh_arms_by_name() # Trigger cache build else: trial = Trial( diff --git a/ax/storage/sqa_store/load.py b/ax/storage/sqa_store/load.py index f418a7d6e62..58ebdcce10a 100644 --- a/ax/storage/sqa_store/load.py +++ b/ax/storage/sqa_store/load.py @@ -279,7 +279,6 @@ def _get_experiment_sqa_reduced_state( large experiments, in cases where model state history is not required. """ options = get_query_options_to_defer_immutable_duplicates() - options.append(lazyload("abandoned_arms")) options.extend(get_query_options_to_defer_large_model_cols()) return _get_experiment_sqa( diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 45f14c3f2cd..96e8745163e 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -480,21 +480,6 @@ def test_ExperimentSaveAndLoadReducedState( for skip_runners_and_metrics in [False, True]: # 1. No abandoned arms + no trials case, reduced state should be the # same as non-reduced state. - exp = get_experiment_with_multi_objective() - save_experiment(exp) - loaded_experiment = load_experiment( - exp.name, - reduced_state=True, - skip_runners_and_metrics=skip_runners_and_metrics, - ) - loaded_experiment.runner = exp.runner - self.assertEqual(loaded_experiment, exp) - # Make sure decoder function was called with `reduced_state=True`. - self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state")) - _mock_exp_from_sqa.reset_mock() - delete_experiment(exp_name=exp.name) - - # 2. Try case with abandoned arms. exp = get_experiment_with_batch_trial(constrain_search_space=False) save_experiment(exp) loaded_experiment = load_experiment( @@ -502,23 +487,15 @@ def test_ExperimentSaveAndLoadReducedState( reduced_state=True, skip_runners_and_metrics=skip_runners_and_metrics, ) - # Experiments are not the same, because one has abandoned arms info. - self.assertNotEqual(loaded_experiment, exp) - # Remove all abandoned arms and check that all else is equal as expected. - t = assert_is_instance(exp.trials[0], BatchTrial) - t._abandoned_arms_metadata = {} loaded_experiment.runner = exp.runner - loaded_experiment._trials[0]._runner = exp._trials[0]._runner + for i in loaded_experiment.trials.keys(): + loaded_experiment.trials[i]._runner = exp.trials[i].runner self.assertEqual(loaded_experiment, exp) - # Make sure that all relevant decoding functions were called with - # `reduced_state=True` and correct number of times. + # Make sure decoder function was called with `reduced_state=True`. self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state")) self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state")) - # 2 generator runs + regular and status quo. self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state")) _mock_exp_from_sqa.reset_mock() - _mock_trial_from_sqa.reset_mock() - _mock_gr_from_sqa.reset_mock() # 3. Try case with model state and search space + opt.config on a # generator run in the experiment. @@ -567,6 +544,24 @@ def test_ExperimentSaveAndLoadReducedState( self.assertEqual(loaded_experiment, exp) delete_experiment(exp_name=exp.name) + def test_load_and_save_reduced_state_does_not_lose_abandoned_arms(self) -> None: + exp = get_experiment_with_batch_trial(constrain_search_space=False) + exp.trials[0].mark_arm_abandoned(arm_name="0_0", reason="for this test") + save_experiment(exp) + loaded_experiment = load_experiment( + exp.name, reduced_state=True, skip_runners_and_metrics=True + ) + save_experiment(loaded_experiment) + reloaded_experiment = load_experiment(exp.name) + self.assertEqual( + reloaded_experiment.trials[0].abandoned_arms, + exp.trials[0].abandoned_arms, + ) + self.assertEqual( + len(reloaded_experiment.trials[0].abandoned_arms), + 1, + ) + def test_ExperimentSaveAndLoadGRWithOptConfig(self) -> None: exp = get_experiment_with_batch_trial(constrain_search_space=False) gr = Models.SOBOL(experiment=exp).gen(