From cd564e42f250f5eb37ac3c91dcd5ceb8af9b2a1e Mon Sep 17 00:00:00 2001 From: Mia Garrard Date: Fri, 3 Nov 2023 12:17:27 -0700 Subject: [PATCH] remove encoding for newer transition criterion until finalized (#1957) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/1957 i introduced some backwards compatability issues because i added transition criterion to the encoder/decoder that were still in flux. I forgot that those would be added to newly created exp and then cause backwards compatability issues. This diff removes those Reviewed By: lena-kashtelyan Differential Revision: D50890889 fbshipit-source-id: c20eff8b38519a09315452af979116de4599ac9c --- ax/modelbridge/generation_strategy.py | 2 ++ ax/service/tests/test_ax_client.py | 1 + ax/service/tests/test_scheduler.py | 1 + ax/storage/json_store/decoder.py | 36 ++----------------- .../json_store/tests/test_json_store.py | 4 +++ ax/storage/sqa_store/tests/test_sqa_store.py | 1 + 6 files changed, 11 insertions(+), 34 deletions(-) diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 06e9b3c8a24..e2c946ccc97 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -326,6 +326,8 @@ def _unset_non_persistent_state_fields(self) -> None: self._model = None for s in self._steps: s._model_spec_to_gen_from = None + # TODO: @mgarrard remove once re-enabled criterion storage + s._transition_criteria = [] def __repr__(self) -> str: """String representation of this generation strategy.""" diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 1eb6ed4b093..1e465989802 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -614,6 +614,7 @@ def test_save_and_load_generation_strategy(self) -> None: ) second_client = AxClient(db_settings=db_settings) second_client.load_experiment_from_database("unique_test_experiment") + generation_strategy._unset_non_persistent_state_fields() self.assertEqual(second_client.generation_strategy, generation_strategy) @patch( diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index fc35fcc5d10..49ea8ccda46 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -777,6 +777,7 @@ def test_sqa_storage(self) -> None: # Check that experiment and GS were saved. exp, gs = scheduler._load_experiment_and_generation_strategy(experiment.name) self.assertEqual(exp, experiment) + self.two_sobol_steps_GS._unset_non_persistent_state_fields() self.assertEqual(gs, self.two_sobol_steps_GS) scheduler.run_all_trials() # Check that experiment and GS were saved and test reloading with reduced state. diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index 153180b7242..dc993ee2721 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -33,10 +33,8 @@ from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy from ax.modelbridge.registry import _decode_callables_from_references from ax.modelbridge.transition_criterion import ( - MaxTrials, MinimumPreferenceOccurances, MinimumTrialsInStatus, - MinTrials, TransitionCriterion, ) from ax.models.torch.botorch_modular.model import SurrogateSpec @@ -353,37 +351,7 @@ def transition_criteria_from_json( criterion_list = [] for criterion_json in transition_criteria_json: criterion_type = criterion_json.pop("__type") - if criterion_type == "MinTrials": - criterion_list.append( - MinTrials( - statuses=object_from_json(criterion_json.pop("statuses")), - threshold=criterion_json.pop("threshold"), - transition_to=criterion_json.pop("transition_to") - if "transition_to" in criterion_json.keys() - else None, - ) - ) - elif criterion_type == "MaxTrials": - criterion_list.append( - MaxTrials( - only_in_statuses=object_from_json( - criterion_json.pop("only_in_statuses") - ) - if "only_in_statuses" in criterion_json.keys() - else None, - threshold=criterion_json.pop("threshold"), - enforce=criterion_json.pop("enforce"), - not_in_statuses=object_from_json( - criterion_json.pop("not_in_statuses") - ) - if "not_in_statuses" in criterion_json.keys() - else None, - transition_to=criterion_json.pop("transition_to") - if "transition_to" in criterion_json.keys() - else None, - ) - ) - elif criterion_type == "MinimumTrialsInStatus": + if criterion_type == "MinimumTrialsInStatus": criterion_list.append( MinimumTrialsInStatus( status=object_from_json(criterion_json.pop("status")), @@ -393,7 +361,7 @@ def transition_criteria_from_json( else None, ) ) - else: + elif criterion_type == "MinimumPreferenceOccurances": criterion_list.append( MinimumPreferenceOccurances( metric_name=criterion_json.pop("metric_name"), diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index 751b5249120..b6c5a78f950 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -10,6 +10,7 @@ import numpy as np import torch +from ax.benchmark.benchmark_method import BenchmarkMethod from ax.core.metric import Metric from ax.core.runner import Runner from ax.exceptions.core import AxStorageWarning @@ -327,6 +328,8 @@ def test_EncodeDecode(self) -> None: converted_object = converted_object.state_dict() if isinstance(original_object, GenerationStrategy): original_object._unset_non_persistent_state_fields() + if isinstance(original_object, BenchmarkMethod): + original_object.generation_strategy._unset_non_persistent_state_fields() try: self.assertEqual( original_object, @@ -399,6 +402,7 @@ def test_DecodeGenerationStrategy(self) -> None: decoder_registry=CORE_DECODER_REGISTRY, class_decoder_registry=CORE_CLASS_DECODER_REGISTRY, ) + generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) self.assertGreater(len(new_generation_strategy._steps), 0) self.assertIsInstance(new_generation_strategy._steps[0].model, Models) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 61e81e59ea0..17b4aa436ca 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -1226,6 +1226,7 @@ def test_EncodeDecodeGenerationStrategy(self) -> None: # pyre-fixme[6]: For 1st param expected `int` but got `Optional[int]`. gs_id=generation_strategy._db_id ) + generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) self.assertIsNone(generation_strategy._experiment)