diff --git a/ax/analysis/plotly/predicted_effects.py b/ax/analysis/plotly/predicted_effects.py
index 24c5557e285..aa8e8a3e2a4 100644
--- a/ax/analysis/plotly/predicted_effects.py
+++ b/ax/analysis/plotly/predicted_effects.py
@@ -117,12 +117,18 @@ def _get_predictions(
     metric_name: str,
     outcome_constraints: list[OutcomeConstraint],
     gr: Optional[GeneratorRun] = None,
-    trial_index: Optional[int] = None,
 ) -> list[dict[str, Any]]:
+    trial_index = (
+        _get_max_observed_trial_index(model)
+        if model.status_quo is None
+        else model.status_quo.features.trial_index
+    )
     if gr is None:
         observations = model.get_training_data()
         features = [o.features for o in observations]
         arm_names = [o.arm_name for o in observations]
+        for feature in features:
+            feature.trial_index = trial_index
     else:
         features = [
             ObservationFeatures(parameters=arm.parameters, trial_index=trial_index)
@@ -244,7 +250,6 @@ def _prepare_data(
                     metric_name=metric_name,
                     outcome_constraints=outcome_constraints,
                     gr=gr,
-                    trial_index=trial_index,
                 )
                 for gr in candidate_trial.generator_runs
             ]
diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py
index c6e6829a329..e2b816b34ab 100644
--- a/ax/analysis/plotly/tests/test_predicted_effects.py
+++ b/ax/analysis/plotly/tests/test_predicted_effects.py
@@ -13,10 +13,12 @@
 from ax.core.observation import ObservationFeatures
 from ax.core.trial import Trial
 from ax.exceptions.core import UserInputError
+from ax.fb.modelbridge.generation_strategies import get_sequential_online_gs
 from ax.modelbridge.dispatch_utils import choose_generation_strategy
 from ax.modelbridge.generation_node import GenerationNode
 from ax.modelbridge.generation_strategy import GenerationStrategy
 from ax.modelbridge.model_spec import ModelSpec
+from ax.modelbridge.prediction_utils import predict_at_point
 from ax.modelbridge.registry import Models
 from ax.modelbridge.transition_criterion import MaxTrials
 from ax.utils.common.testutils import TestCase
@@ -122,9 +124,12 @@ def test_compute(self) -> None:
             get_branin_outcome_constraint(name="constraint_branin")
         ]
         experiment.add_tracking_metric(get_branin_metric(name="tracking_branin"))
-        generation_strategy = self.generation_strategy
+        generation_strategy = get_sequential_online_gs()
+        generation_strategy.experiment = experiment
         experiment.new_batch_trial(
-            generator_run=generation_strategy.gen(experiment=experiment, n=10)
+            generator_runs=generation_strategy.gen_with_multiple_nodes(
+                experiment=experiment, n=10
+            )
         ).set_status_quo_with_weight(
             status_quo=experiment.status_quo, weight=1.0
         ).mark_completed(
@@ -132,9 +137,13 @@ def test_compute(self) -> None:
         )
         experiment.fetch_data()
         experiment.new_batch_trial(
-            generator_run=generation_strategy.gen(experiment=experiment, n=10)
+            generator_runs=generation_strategy.gen_with_multiple_nodes(
+                experiment=experiment, n=10
+            )
         ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0)
         experiment.fetch_data()
+        # Ensure the current model is Botorch
+        self.assertEqual(none_throws(generation_strategy.model)._model_key, "BoTorch")
         for metric in experiment.metrics:
             with self.subTest(metric=metric):
                 # WHEN we compute the analysis for a metric
@@ -186,15 +195,23 @@ def test_compute(self) -> None:
     @fast_botorch_optimize
     def test_compute_multitask(self) -> None:
         # GIVEN an experiment with candidates generated with a multitask model
-        experiment = get_branin_experiment()
+        experiment = get_branin_experiment(with_status_quo=True)
         generation_strategy = self.generation_strategy
         experiment.new_batch_trial(
             generator_run=generation_strategy.gen(experiment=experiment, n=10)
-        ).mark_completed(unsafe=True)
+        ).set_status_quo_with_weight(
+            status_quo=experiment.status_quo, weight=1
+        ).mark_completed(
+            unsafe=True
+        )
         experiment.fetch_data()
         experiment.new_batch_trial(
             generator_run=generation_strategy.gen(experiment=experiment, n=10)
-        ).mark_completed(unsafe=True)
+        ).set_status_quo_with_weight(
+            status_quo=experiment.status_quo, weight=1
+        ).mark_completed(
+            unsafe=True
+        )
         experiment.fetch_data()
         # leave as a candidate
         experiment.new_batch_trial(
@@ -203,20 +220,24 @@ def test_compute_multitask(self) -> None:
                 n=10,
                 fixed_features=ObservationFeatures(parameters={}, trial_index=1),
             )
-        )
+        ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1)
         experiment.new_batch_trial(
             generator_run=generation_strategy.gen(
                 experiment=experiment,
                 n=10,
                 fixed_features=ObservationFeatures(parameters={}, trial_index=1),
             )
-        )
+        ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1)
         self.assertEqual(none_throws(generation_strategy.model)._model_key, "ST_MTGP")
         # WHEN we compute the analysis
         analysis = PredictedEffectsPlot(metric_name="branin")
-        card = analysis.compute(
-            experiment=experiment, generation_strategy=generation_strategy
-        )
+        with patch(
+            f"{PredictedEffectsPlot.__module__}.predict_at_point",
+            wraps=predict_at_point,
+        ) as predict_at_point_spy:
+            card = analysis.compute(
+                experiment=experiment, generation_strategy=generation_strategy
+            )
         # THEN it has the right rows for arms with data, as well as the latest trial
         arms_with_data = set(experiment.lookup_data().df["arm_name"].unique())
         max_trial_index = max(experiment.trials.keys())
@@ -235,6 +256,16 @@ def test_compute_multitask(self) -> None:
                 or arm.name in experiment.trials[max_trial_index].arms_by_name,
                 arm.name,
             )
+        # AND THEN it always predicts for the target trial
+        self.assertEqual(
+            len(
+                {
+                    call[1]["obsf"].trial_index
+                    for call in predict_at_point_spy.call_args_list
+                }
+            ),
+            1,
+        )
 
     @fast_botorch_optimize
     def test_it_does_not_plot_abandoned_trials(self) -> None:
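
Note on the behavior change in _get_predictions: the trial_index parameter is removed, and the helper now derives the target trial index itself, preferring the status quo's trial index and falling back to _get_max_observed_trial_index(model). Below is a minimal, standalone sketch of that selection logic only; SimpleModel, Observation, and Features are hypothetical stand-ins used for illustration, not the real Ax modelbridge API.

# Sketch only: hypothetical stand-in types, not the Ax modelbridge classes.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Features:
    trial_index: Optional[int]


@dataclass
class Observation:
    features: Features


@dataclass
class SimpleModel:
    status_quo: Optional[Observation]
    training_data: list[Observation]


def _get_max_observed_trial_index(model: SimpleModel) -> Optional[int]:
    # Largest trial index present in the training data, or None if there is none.
    indices = [
        o.features.trial_index
        for o in model.training_data
        if o.features.trial_index is not None
    ]
    return max(indices) if indices else None


def target_trial_index(model: SimpleModel) -> Optional[int]:
    # Mirrors the selection added in the diff: use the status quo's trial index
    # when a status quo exists, otherwise fall back to the max observed index.
    return (
        _get_max_observed_trial_index(model)
        if model.status_quo is None
        else model.status_quo.features.trial_index
    )


if __name__ == "__main__":
    model = SimpleModel(
        status_quo=None,
        training_data=[Observation(Features(0)), Observation(Features(2))],
    )
    assert target_trial_index(model) == 2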