From 0c6d962e33089d9e2b396536d29b222ca26c8626 Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Thu, 3 Oct 2024 13:41:16 -0700 Subject: [PATCH] Always predict for status quo features trial index (#2810) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2810 Previously we had plots that looked like {F1904605463} After the same plot looks like {F1904608806} ##What's happening Because of nonstationarity, the same feature gets predicted differently in different trials. But the point of this plot is to show how well the parameterization does on the metric, so we want to standardize the trial index we're predicting on. Differential Revision: D63715134 Reviewed By: mgarrard --- ax/analysis/plotly/predicted_effects.py | 10 ++-- .../plotly/tests/test_predicted_effects.py | 49 +++++++++++++++---- ax/modelbridge/generation_strategy.py | 1 + 3 files changed, 47 insertions(+), 13 deletions(-) diff --git a/ax/analysis/plotly/predicted_effects.py b/ax/analysis/plotly/predicted_effects.py index d78c3b0f7f9..10ef4a3f7bb 100644 --- a/ax/analysis/plotly/predicted_effects.py +++ b/ax/analysis/plotly/predicted_effects.py @@ -118,12 +118,18 @@ def _get_predictions( metric_name: str, outcome_constraints: list[OutcomeConstraint], gr: GeneratorRun | None = None, - trial_index: int | None = None, ) -> list[dict[str, Any]]: + trial_index = ( + _get_max_observed_trial_index(model) + if model.status_quo is None + else model.status_quo.features.trial_index + ) if gr is None: observations = model.get_training_data() features = [o.features for o in observations] arm_names = [o.arm_name for o in observations] + for feature in features: + feature.trial_index = trial_index else: features = [ ObservationFeatures(parameters=arm.parameters, trial_index=trial_index) @@ -226,7 +232,6 @@ def _prepare_data( metric_name: Name of metric to plot candidate_trial: Trial to plot candidates for by generator run """ - trial_index = _get_max_observed_trial_index(model) df = pd.DataFrame.from_records( list( chain( @@ -245,7 +250,6 @@ def _prepare_data( metric_name=metric_name, outcome_constraints=outcome_constraints, gr=gr, - trial_index=trial_index, ) for gr in candidate_trial.generator_runs ] diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py index c6e6829a329..8374750cb3f 100644 --- a/ax/analysis/plotly/tests/test_predicted_effects.py +++ b/ax/analysis/plotly/tests/test_predicted_effects.py @@ -17,6 +17,7 @@ from ax.modelbridge.generation_node import GenerationNode from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.model_spec import ModelSpec +from ax.modelbridge.prediction_utils import predict_at_point from ax.modelbridge.registry import Models from ax.modelbridge.transition_criterion import MaxTrials from ax.utils.common.testutils import TestCase @@ -124,7 +125,9 @@ def test_compute(self) -> None: experiment.add_tracking_metric(get_branin_metric(name="tracking_branin")) generation_strategy = self.generation_strategy experiment.new_batch_trial( - generator_run=generation_strategy.gen(experiment=experiment, n=10) + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ) ).set_status_quo_with_weight( status_quo=experiment.status_quo, weight=1.0 ).mark_completed( @@ -132,9 +135,13 @@ def test_compute(self) -> None: ) experiment.fetch_data() experiment.new_batch_trial( - generator_run=generation_strategy.gen(experiment=experiment, n=10) + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ) ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0) experiment.fetch_data() + # Ensure the current model is Botorch + self.assertEqual(none_throws(generation_strategy.model)._model_key, "BoTorch") for metric in experiment.metrics: with self.subTest(metric=metric): # WHEN we compute the analysis for a metric @@ -186,15 +193,23 @@ def test_compute(self) -> None: @fast_botorch_optimize def test_compute_multitask(self) -> None: # GIVEN an experiment with candidates generated with a multitask model - experiment = get_branin_experiment() + experiment = get_branin_experiment(with_status_quo=True) generation_strategy = self.generation_strategy experiment.new_batch_trial( generator_run=generation_strategy.gen(experiment=experiment, n=10) - ).mark_completed(unsafe=True) + ).set_status_quo_with_weight( + status_quo=experiment.status_quo, weight=1 + ).mark_completed( + unsafe=True + ) experiment.fetch_data() experiment.new_batch_trial( generator_run=generation_strategy.gen(experiment=experiment, n=10) - ).mark_completed(unsafe=True) + ).set_status_quo_with_weight( + status_quo=experiment.status_quo, weight=1 + ).mark_completed( + unsafe=True + ) experiment.fetch_data() # leave as a candidate experiment.new_batch_trial( @@ -203,20 +218,24 @@ def test_compute_multitask(self) -> None: n=10, fixed_features=ObservationFeatures(parameters={}, trial_index=1), ) - ) + ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1) experiment.new_batch_trial( generator_run=generation_strategy.gen( experiment=experiment, n=10, fixed_features=ObservationFeatures(parameters={}, trial_index=1), ) - ) + ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1) self.assertEqual(none_throws(generation_strategy.model)._model_key, "ST_MTGP") # WHEN we compute the analysis analysis = PredictedEffectsPlot(metric_name="branin") - card = analysis.compute( - experiment=experiment, generation_strategy=generation_strategy - ) + with patch( + f"{PredictedEffectsPlot.__module__}.predict_at_point", + wraps=predict_at_point, + ) as predict_at_point_spy: + card = analysis.compute( + experiment=experiment, generation_strategy=generation_strategy + ) # THEN it has the right rows for arms with data, as well as the latest trial arms_with_data = set(experiment.lookup_data().df["arm_name"].unique()) max_trial_index = max(experiment.trials.keys()) @@ -235,6 +254,16 @@ def test_compute_multitask(self) -> None: or arm.name in experiment.trials[max_trial_index].arms_by_name, arm.name, ) + # AND THEN it always predicts for the target trial + self.assertEqual( + len( + { + call[1]["obsf"].trial_index + for call in predict_at_point_spy.call_args_list + } + ), + 1, + ) @fast_botorch_optimize def test_it_does_not_plot_abandoned_trials(self) -> None: diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 4d67de1c132..079ad9f0e57 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -424,6 +424,7 @@ def gen_with_multiple_nodes( Returns: A list of ``GeneratorRuns`` for a single trial. """ + self.experiment = experiment # TODO: @mgarrard merge into gen method, just starting here to derisk # Validate `arms_per_node` if specified, otherwise construct the default # behavior with keys being node names and values being 1 to represent