Skip to content

Commit

Permalink
remove encoding for newer transition criterion until finalized (#1957)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1957

i introduced some backwards compatability issues because i added transition criterion to the encoder/decoder that were still in flux. I forgot that those would be added to newly created exp and then cause backwards compatability issues. This diff removes those

Reviewed By: lena-kashtelyan

Differential Revision: D50890889

fbshipit-source-id: c20eff8b38519a09315452af979116de4599ac9c
  • Loading branch information
mgarrard authored and facebook-github-bot committed Nov 3, 2023
1 parent 52c9018 commit cd564e4
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 34 deletions.
2 changes: 2 additions & 0 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,8 @@ def _unset_non_persistent_state_fields(self) -> None:
self._model = None
for s in self._steps:
s._model_spec_to_gen_from = None
# TODO: @mgarrard remove once re-enabled criterion storage
s._transition_criteria = []

def __repr__(self) -> str:
"""String representation of this generation strategy."""
Expand Down
1 change: 1 addition & 0 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ def test_save_and_load_generation_strategy(self) -> None:
)
second_client = AxClient(db_settings=db_settings)
second_client.load_experiment_from_database("unique_test_experiment")
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(second_client.generation_strategy, generation_strategy)

@patch(
Expand Down
1 change: 1 addition & 0 deletions ax/service/tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ def test_sqa_storage(self) -> None:
# Check that experiment and GS were saved.
exp, gs = scheduler._load_experiment_and_generation_strategy(experiment.name)
self.assertEqual(exp, experiment)
self.two_sobol_steps_GS._unset_non_persistent_state_fields()
self.assertEqual(gs, self.two_sobol_steps_GS)
scheduler.run_all_trials()
# Check that experiment and GS were saved and test reloading with reduced state.
Expand Down
36 changes: 2 additions & 34 deletions ax/storage/json_store/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,8 @@
from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import _decode_callables_from_references
from ax.modelbridge.transition_criterion import (
MaxTrials,
MinimumPreferenceOccurances,
MinimumTrialsInStatus,
MinTrials,
TransitionCriterion,
)
from ax.models.torch.botorch_modular.model import SurrogateSpec
Expand Down Expand Up @@ -353,37 +351,7 @@ def transition_criteria_from_json(
criterion_list = []
for criterion_json in transition_criteria_json:
criterion_type = criterion_json.pop("__type")
if criterion_type == "MinTrials":
criterion_list.append(
MinTrials(
statuses=object_from_json(criterion_json.pop("statuses")),
threshold=criterion_json.pop("threshold"),
transition_to=criterion_json.pop("transition_to")
if "transition_to" in criterion_json.keys()
else None,
)
)
elif criterion_type == "MaxTrials":
criterion_list.append(
MaxTrials(
only_in_statuses=object_from_json(
criterion_json.pop("only_in_statuses")
)
if "only_in_statuses" in criterion_json.keys()
else None,
threshold=criterion_json.pop("threshold"),
enforce=criterion_json.pop("enforce"),
not_in_statuses=object_from_json(
criterion_json.pop("not_in_statuses")
)
if "not_in_statuses" in criterion_json.keys()
else None,
transition_to=criterion_json.pop("transition_to")
if "transition_to" in criterion_json.keys()
else None,
)
)
elif criterion_type == "MinimumTrialsInStatus":
if criterion_type == "MinimumTrialsInStatus":
criterion_list.append(
MinimumTrialsInStatus(
status=object_from_json(criterion_json.pop("status")),
Expand All @@ -393,7 +361,7 @@ def transition_criteria_from_json(
else None,
)
)
else:
elif criterion_type == "MinimumPreferenceOccurances":
criterion_list.append(
MinimumPreferenceOccurances(
metric_name=criterion_json.pop("metric_name"),
Expand Down
4 changes: 4 additions & 0 deletions ax/storage/json_store/tests/test_json_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import numpy as np
import torch
from ax.benchmark.benchmark_method import BenchmarkMethod
from ax.core.metric import Metric
from ax.core.runner import Runner
from ax.exceptions.core import AxStorageWarning
Expand Down Expand Up @@ -327,6 +328,8 @@ def test_EncodeDecode(self) -> None:
converted_object = converted_object.state_dict()
if isinstance(original_object, GenerationStrategy):
original_object._unset_non_persistent_state_fields()
if isinstance(original_object, BenchmarkMethod):
original_object.generation_strategy._unset_non_persistent_state_fields()
try:
self.assertEqual(
original_object,
Expand Down Expand Up @@ -399,6 +402,7 @@ def test_DecodeGenerationStrategy(self) -> None:
decoder_registry=CORE_DECODER_REGISTRY,
class_decoder_registry=CORE_CLASS_DECODER_REGISTRY,
)
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertGreater(len(new_generation_strategy._steps), 0)
self.assertIsInstance(new_generation_strategy._steps[0].model, Models)
Expand Down
1 change: 1 addition & 0 deletions ax/storage/sqa_store/tests/test_sqa_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,7 @@ def test_EncodeDecodeGenerationStrategy(self) -> None:
# pyre-fixme[6]: For 1st param expected `int` but got `Optional[int]`.
gs_id=generation_strategy._db_id
)
generation_strategy._unset_non_persistent_state_fields()
self.assertEqual(generation_strategy, new_generation_strategy)
self.assertIsNone(generation_strategy._experiment)

Expand Down

0 comments on commit cd564e4

Please sign in to comment.