diff --git a/ax/analysis/plotly/predicted_effects.py b/ax/analysis/plotly/predicted_effects.py index d78c3b0f7f9..10ef4a3f7bb 100644 --- a/ax/analysis/plotly/predicted_effects.py +++ b/ax/analysis/plotly/predicted_effects.py @@ -118,12 +118,18 @@ def _get_predictions( metric_name: str, outcome_constraints: list[OutcomeConstraint], gr: GeneratorRun | None = None, - trial_index: int | None = None, ) -> list[dict[str, Any]]: + trial_index = ( + _get_max_observed_trial_index(model) + if model.status_quo is None + else model.status_quo.features.trial_index + ) if gr is None: observations = model.get_training_data() features = [o.features for o in observations] arm_names = [o.arm_name for o in observations] + for feature in features: + feature.trial_index = trial_index else: features = [ ObservationFeatures(parameters=arm.parameters, trial_index=trial_index) @@ -226,7 +232,6 @@ def _prepare_data( metric_name: Name of metric to plot candidate_trial: Trial to plot candidates for by generator run """ - trial_index = _get_max_observed_trial_index(model) df = pd.DataFrame.from_records( list( chain( @@ -245,7 +250,6 @@ def _prepare_data( metric_name=metric_name, outcome_constraints=outcome_constraints, gr=gr, - trial_index=trial_index, ) for gr in candidate_trial.generator_runs ] diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py index c6e6829a329..8374750cb3f 100644 --- a/ax/analysis/plotly/tests/test_predicted_effects.py +++ b/ax/analysis/plotly/tests/test_predicted_effects.py @@ -17,6 +17,7 @@ from ax.modelbridge.generation_node import GenerationNode from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.model_spec import ModelSpec +from ax.modelbridge.prediction_utils import predict_at_point from ax.modelbridge.registry import Models from ax.modelbridge.transition_criterion import MaxTrials from ax.utils.common.testutils import TestCase @@ -124,7 +125,9 @@ def test_compute(self) -> None: experiment.add_tracking_metric(get_branin_metric(name="tracking_branin")) generation_strategy = self.generation_strategy experiment.new_batch_trial( - generator_run=generation_strategy.gen(experiment=experiment, n=10) + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ) ).set_status_quo_with_weight( status_quo=experiment.status_quo, weight=1.0 ).mark_completed( @@ -132,9 +135,13 @@ def test_compute(self) -> None: ) experiment.fetch_data() experiment.new_batch_trial( - generator_run=generation_strategy.gen(experiment=experiment, n=10) + generator_runs=generation_strategy.gen_with_multiple_nodes( + experiment=experiment, n=10 + ) ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0) experiment.fetch_data() + # Ensure the current model is Botorch + self.assertEqual(none_throws(generation_strategy.model)._model_key, "BoTorch") for metric in experiment.metrics: with self.subTest(metric=metric): # WHEN we compute the analysis for a metric @@ -186,15 +193,23 @@ def test_compute(self) -> None: @fast_botorch_optimize def test_compute_multitask(self) -> None: # GIVEN an experiment with candidates generated with a multitask model - experiment = get_branin_experiment() + experiment = get_branin_experiment(with_status_quo=True) generation_strategy = self.generation_strategy experiment.new_batch_trial( generator_run=generation_strategy.gen(experiment=experiment, n=10) - ).mark_completed(unsafe=True) + ).set_status_quo_with_weight( + status_quo=experiment.status_quo, weight=1 + ).mark_completed( + unsafe=True + ) experiment.fetch_data() experiment.new_batch_trial( generator_run=generation_strategy.gen(experiment=experiment, n=10) - ).mark_completed(unsafe=True) + ).set_status_quo_with_weight( + status_quo=experiment.status_quo, weight=1 + ).mark_completed( + unsafe=True + ) experiment.fetch_data() # leave as a candidate experiment.new_batch_trial( @@ -203,20 +218,24 @@ def test_compute_multitask(self) -> None: n=10, fixed_features=ObservationFeatures(parameters={}, trial_index=1), ) - ) + ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1) experiment.new_batch_trial( generator_run=generation_strategy.gen( experiment=experiment, n=10, fixed_features=ObservationFeatures(parameters={}, trial_index=1), ) - ) + ).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1) self.assertEqual(none_throws(generation_strategy.model)._model_key, "ST_MTGP") # WHEN we compute the analysis analysis = PredictedEffectsPlot(metric_name="branin") - card = analysis.compute( - experiment=experiment, generation_strategy=generation_strategy - ) + with patch( + f"{PredictedEffectsPlot.__module__}.predict_at_point", + wraps=predict_at_point, + ) as predict_at_point_spy: + card = analysis.compute( + experiment=experiment, generation_strategy=generation_strategy + ) # THEN it has the right rows for arms with data, as well as the latest trial arms_with_data = set(experiment.lookup_data().df["arm_name"].unique()) max_trial_index = max(experiment.trials.keys()) @@ -235,6 +254,16 @@ def test_compute_multitask(self) -> None: or arm.name in experiment.trials[max_trial_index].arms_by_name, arm.name, ) + # AND THEN it always predicts for the target trial + self.assertEqual( + len( + { + call[1]["obsf"].trial_index + for call in predict_at_point_spy.call_args_list + } + ), + 1, + ) @fast_botorch_optimize def test_it_does_not_plot_abandoned_trials(self) -> None: diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 4d67de1c132..079ad9f0e57 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -424,6 +424,7 @@ def gen_with_multiple_nodes( Returns: A list of ``GeneratorRuns`` for a single trial. """ + self.experiment = experiment # TODO: @mgarrard merge into gen method, just starting here to derisk # Validate `arms_per_node` if specified, otherwise construct the default # behavior with keys being node names and values being 1 to represent