Skip to content

Commit

Permalink
Don't discard abandoned arms on reduced state (#3261)
Browse files Browse the repository at this point in the history
Summary:

Reduced state is still useful for leaving out gen metadata, but I don't know of any use case with such a large number of abandoned arms that discarding them would make a difference (considering it can never make more than a 50% difference).

Discarding abandoned arms also affects the contents and functionality of the experiment, by making it look as though the arms are not abandoned and should still be used on trials. In addition, loading and then saving with reduced state loses the record that arms were abandoned, which creates problems in experimentation. See N6450800.

Differential Revision: D68514688
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Jan 23, 2025
1 parent 369b487 commit aafd890
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 35 deletions.
13 changes: 6 additions & 7 deletions ax/storage/sqa_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,13 +956,12 @@ def trial_from_sqa(
new_generator_run_structs.append(struct)
generator_run_structs = new_generator_run_structs
trial._generator_run_structs = generator_run_structs
if not reduced_state:
trial._abandoned_arms_metadata = {
abandoned_arm_sqa.name: self.abandoned_arm_from_sqa(
abandoned_arm_sqa=abandoned_arm_sqa
)
for abandoned_arm_sqa in trial_sqa.abandoned_arms
}
trial._abandoned_arms_metadata = {
abandoned_arm_sqa.name: self.abandoned_arm_from_sqa(
abandoned_arm_sqa=abandoned_arm_sqa
)
for abandoned_arm_sqa in trial_sqa.abandoned_arms
}
trial._refresh_arms_by_name() # Trigger cache build
else:
trial = Trial(
Expand Down
1 change: 0 additions & 1 deletion ax/storage/sqa_store/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,6 @@ def _get_experiment_sqa_reduced_state(
large experiments, in cases where model state history is not required.
"""
options = get_query_options_to_defer_immutable_duplicates()
options.append(lazyload("abandoned_arms"))
options.extend(get_query_options_to_defer_large_model_cols())

return _get_experiment_sqa(
Expand Down
49 changes: 22 additions & 27 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import BatchTrial, LifecycleStage
from ax.core.batch_trial import LifecycleStage
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
Expand Down Expand Up @@ -480,45 +480,22 @@ def test_ExperimentSaveAndLoadReducedState(
for skip_runners_and_metrics in [False, True]:
# 1. No abandoned arms + no trials case, reduced state should be the
# same as non-reduced state.
exp = get_experiment_with_multi_objective()
save_experiment(exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
loaded_experiment.runner = exp.runner
self.assertEqual(loaded_experiment, exp)
# Make sure decoder function was called with `reduced_state=True`.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
delete_experiment(exp_name=exp.name)

# 2. Try case with abandoned arms.
exp = get_experiment_with_batch_trial(constrain_search_space=False)
save_experiment(exp)
loaded_experiment = load_experiment(
exp.name,
reduced_state=True,
skip_runners_and_metrics=skip_runners_and_metrics,
)
# Experiments are not the same, because one has abandoned arms info.
self.assertNotEqual(loaded_experiment, exp)
# Remove all abandoned arms and check that all else is equal as expected.
t = assert_is_instance(exp.trials[0], BatchTrial)
t._abandoned_arms_metadata = {}
loaded_experiment.runner = exp.runner
loaded_experiment._trials[0]._runner = exp._trials[0]._runner
for i in loaded_experiment.trials.keys():
loaded_experiment.trials[i]._runner = exp.trials[i].runner
self.assertEqual(loaded_experiment, exp)
# Make sure that all relevant decoding functions were called with
# `reduced_state=True` and correct number of times.
# Make sure decoder function was called with `reduced_state=True`.
self.assertTrue(_mock_exp_from_sqa.call_args[1].get("reduced_state"))
self.assertTrue(_mock_trial_from_sqa.call_args[1].get("reduced_state"))
# 2 generator runs + regular and status quo.
self.assertTrue(_mock_gr_from_sqa.call_args[1].get("reduced_state"))
_mock_exp_from_sqa.reset_mock()
_mock_trial_from_sqa.reset_mock()
_mock_gr_from_sqa.reset_mock()

# 3. Try case with model state and search space + opt.config on a
# generator run in the experiment.
Expand Down Expand Up @@ -567,6 +544,24 @@ def test_ExperimentSaveAndLoadReducedState(
self.assertEqual(loaded_experiment, exp)
delete_experiment(exp_name=exp.name)

def test_load_and_save_reduced_state_does_not_lose_abandoned_arms(self) -> None:
    """Check that a reduced-state load followed by a save keeps abandoned arms.

    The experiment is saved, loaded with ``reduced_state=True``, saved again
    from that reduced copy, and finally loaded in full; the abandoned arms
    on the first trial must match the ones on the original experiment.
    """
    # Set up: an experiment whose first trial has arm "0_0" marked abandoned.
    original = get_experiment_with_batch_trial(constrain_search_space=False)
    original.trials[0].mark_arm_abandoned(arm_name="0_0", reason="for this test")
    save_experiment(original)

    # Round-trip: load in reduced state, then write that copy back to storage.
    reduced_copy = load_experiment(
        original.name, reduced_state=True, skip_runners_and_metrics=True
    )
    save_experiment(reduced_copy)

    # Verify: a full reload still reports the same abandoned arms.
    fully_reloaded = load_experiment(original.name)
    self.assertEqual(
        fully_reloaded.trials[0].abandoned_arms,
        original.trials[0].abandoned_arms,
    )
    self.assertEqual(
        len(fully_reloaded.trials[0].abandoned_arms),
        1,
    )

def test_ExperimentSaveAndLoadGRWithOptConfig(self) -> None:
exp = get_experiment_with_batch_trial(constrain_search_space=False)
gr = Models.SOBOL(experiment=exp).gen(
Expand Down

0 comments on commit aafd890

Please sign in to comment.