From 6a98621673b125d090a26f71ed0d0beaba8c1aae Mon Sep 17 00:00:00 2001
From: Daniel Cohen
Date: Thu, 3 Oct 2024 12:37:08 -0700
Subject: [PATCH] Always predict for status quo features trial index (#2810)

Summary:
Previously, we had plots that looked like {F1904605463}

After this change, the same plot looks like {F1904608806}

## What's happening

Because of nonstationarity, the same features get predicted differently in different trials. But the point of this plot is to show how well each parameterization does on the metric, so we standardize the trial index we predict on: the status quo's trial index when a status quo is present, and the maximum observed trial index otherwise. (A standalone sketch of this logic follows the diff below.)

Reviewed By: mgarrard

Differential Revision: D63715134
---
 ax/analysis/plotly/predicted_effects.py      | 14 ++++--
 .../plotly/tests/test_predicted_effects.py   | 50 +++++++++++++++----
 2 files changed, 49 insertions(+), 15 deletions(-)

diff --git a/ax/analysis/plotly/predicted_effects.py b/ax/analysis/plotly/predicted_effects.py
index d78c3b0f7f9..8e70ba4b269 100644
--- a/ax/analysis/plotly/predicted_effects.py
+++ b/ax/analysis/plotly/predicted_effects.py
@@ -4,7 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 
 from itertools import chain
-from typing import Any
+from typing import Any, Optional
 
 import pandas as pd
 from ax.analysis.analysis import AnalysisCardLevel
@@ -117,13 +117,19 @@ def _get_predictions(
     model: ModelBridge,
     metric_name: str,
     outcome_constraints: list[OutcomeConstraint],
-    gr: GeneratorRun | None = None,
-    trial_index: int | None = None,
+    gr: Optional[GeneratorRun] = None,
 ) -> list[dict[str, Any]]:
+    trial_index = (
+        _get_max_observed_trial_index(model)
+        if model.status_quo is None
+        else model.status_quo.features.trial_index
+    )
     if gr is None:
         observations = model.get_training_data()
         features = [o.features for o in observations]
         arm_names = [o.arm_name for o in observations]
+        for feature in features:
+            feature.trial_index = trial_index
     else:
         features = [
             ObservationFeatures(parameters=arm.parameters, trial_index=trial_index)
@@ -226,7 +232,6 @@ def _prepare_data(
         metric_name: Name of metric to plot
         candidate_trial: Trial to plot candidates for by generator run
     """
-    trial_index = _get_max_observed_trial_index(model)
     df = pd.DataFrame.from_records(
         list(
             chain(
@@ -245,7 +250,6 @@ def _prepare_data(
                         metric_name=metric_name,
                         outcome_constraints=outcome_constraints,
                         gr=gr,
-                        trial_index=trial_index,
                     )
                     for gr in candidate_trial.generator_runs
                 ]
diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py
index c6e6829a329..38c0a7691eb 100644
--- a/ax/analysis/plotly/tests/test_predicted_effects.py
+++ b/ax/analysis/plotly/tests/test_predicted_effects.py
@@ -17,6 +17,7 @@
 from ax.modelbridge.generation_node import GenerationNode
 from ax.modelbridge.generation_strategy import GenerationStrategy
 from ax.modelbridge.model_spec import ModelSpec
+from ax.modelbridge.prediction_utils import predict_at_point
 from ax.modelbridge.registry import Models
 from ax.modelbridge.transition_criterion import MaxTrials
 from ax.utils.common.testutils import TestCase
@@ -123,8 +124,11 @@ def test_compute(self) -> None:
         ]
         experiment.add_tracking_metric(get_branin_metric(name="tracking_branin"))
         generation_strategy = self.generation_strategy
+        generation_strategy.experiment = experiment
         experiment.new_batch_trial(
-            generator_run=generation_strategy.gen(experiment=experiment, n=10)
+            generator_runs=generation_strategy.gen_with_multiple_nodes(
+                experiment=experiment, n=10
+            )
         ).set_status_quo_with_weight(
             status_quo=experiment.status_quo, weight=1.0
         ).mark_completed(
@@ -132,9 +136,13 @@ def test_compute(self) -> None:
         )
         experiment.fetch_data()
         experiment.new_batch_trial(
-            generator_run=generation_strategy.gen(experiment=experiment, n=10)
+            generator_runs=generation_strategy.gen_with_multiple_nodes(
+                experiment=experiment, n=10
+            )
         ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0)
         experiment.fetch_data()
+        # Ensure the current model is BoTorch
+        self.assertEqual(none_throws(generation_strategy.model)._model_key, "BoTorch")
         for metric in experiment.metrics:
             with self.subTest(metric=metric):
                 # WHEN we compute the analysis for a metric
@@ -186,15 +194,23 @@
     @fast_botorch_optimize
     def test_compute_multitask(self) -> None:
         # GIVEN an experiment with candidates generated with a multitask model
-        experiment = get_branin_experiment()
+        experiment = get_branin_experiment(with_status_quo=True)
         generation_strategy = self.generation_strategy
         experiment.new_batch_trial(
             generator_run=generation_strategy.gen(experiment=experiment, n=10)
-        ).mark_completed(unsafe=True)
+        ).set_status_quo_with_weight(
+            status_quo=experiment.status_quo, weight=1
+        ).mark_completed(
+            unsafe=True
+        )
         experiment.fetch_data()
         experiment.new_batch_trial(
             generator_run=generation_strategy.gen(experiment=experiment, n=10)
-        ).mark_completed(unsafe=True)
+        ).set_status_quo_with_weight(
+            status_quo=experiment.status_quo, weight=1
+        ).mark_completed(
+            unsafe=True
+        )
         experiment.fetch_data()
         # leave as a candidate
         experiment.new_batch_trial(
@@ -203,20 +219,24 @@ def test_compute_multitask(self) -> None:
                 n=10,
                 fixed_features=ObservationFeatures(parameters={}, trial_index=1),
             )
-        )
+        ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1)
         experiment.new_batch_trial(
             generator_run=generation_strategy.gen(
                 experiment=experiment,
                 n=10,
                 fixed_features=ObservationFeatures(parameters={}, trial_index=1),
             )
-        )
+        ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1)
         self.assertEqual(none_throws(generation_strategy.model)._model_key, "ST_MTGP")
         # WHEN we compute the analysis
         analysis = PredictedEffectsPlot(metric_name="branin")
-        card = analysis.compute(
-            experiment=experiment, generation_strategy=generation_strategy
-        )
+        with patch(
+            f"{PredictedEffectsPlot.__module__}.predict_at_point",
+            wraps=predict_at_point,
+        ) as predict_at_point_spy:
+            card = analysis.compute(
+                experiment=experiment, generation_strategy=generation_strategy
+            )
         # THEN it has the right rows for arms with data, as well as the latest trial
         arms_with_data = set(experiment.lookup_data().df["arm_name"].unique())
         max_trial_index = max(experiment.trials.keys())
@@ -235,6 +255,16 @@ def test_compute_multitask(self) -> None:
                 or arm.name in experiment.trials[max_trial_index].arms_by_name,
                 arm.name,
             )
+        # AND THEN it always predicts for the target trial
+        self.assertEqual(
+            len(
+                {
+                    call[1]["obsf"].trial_index
+                    for call in predict_at_point_spy.call_args_list
+                }
+            ),
+            1,
+        )
 
     @fast_botorch_optimize
     def test_it_does_not_plot_abandoned_trials(self) -> None:
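
## Sketch: standardizing the trial index

Stripped of the Ax plumbing above, the heart of the change is: pick one trial index up front and stamp it onto every set of features before predicting. Below is a minimal sketch mirroring the `_get_predictions` hunk. The helper name `pin_features_to_one_trial` is ours, not Ax's, and the inlined fallback is an assumption about what `_get_max_observed_trial_index` computes (the patch only calls it).

```python
from ax.core.observation import ObservationFeatures
from ax.modelbridge.base import ModelBridge


def pin_features_to_one_trial(model: ModelBridge) -> list[ObservationFeatures]:
    """Return the model's training features, all pinned to one trial index.

    Prefer the status quo's trial index so every arm is predicted in the same
    trial as the baseline; otherwise fall back to the latest observed trial
    (assumed to match the patch's _get_max_observed_trial_index helper).
    """
    observations = model.get_training_data()
    if model.status_quo is not None:
        trial_index = model.status_quo.features.trial_index
    else:
        # Assumed fallback: the largest trial index seen in training data.
        trial_index = max(
            o.features.trial_index
            for o in observations
            if o.features.trial_index is not None
        )
    features = [o.features for o in observations]
    for feature in features:
        # With a nonstationary (e.g. multitask) model, the same parameters
        # predict differently in different trials. Pinning the index makes
        # arms differ on one axis only: the parameterization.
        feature.trial_index = trial_index
    return features
```

The new test asserts this invariant from the outside: it wraps `predict_at_point` in a spy and checks that every prediction call saw exactly one distinct `trial_index`.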