Skip to content

Commit

Permalink
Always predict for status quo features trial index (#2810)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2810

Previously we had plots that looked like {F1904605463}
After the same plot looks like
{F1904608806}

Differential Revision: D63715134
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Oct 2, 2024
1 parent 3942bff commit f936ff5
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 12 deletions.
9 changes: 7 additions & 2 deletions ax/analysis/plotly/predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,18 @@ def _get_predictions(
metric_name: str,
outcome_constraints: list[OutcomeConstraint],
gr: Optional[GeneratorRun] = None,
trial_index: Optional[int] = None,
) -> list[dict[str, Any]]:
trial_index = (
_get_max_observed_trial_index(model)
if model.status_quo is None
else model.status_quo.features.trial_index
)
if gr is None:
observations = model.get_training_data()
features = [o.features for o in observations]
arm_names = [o.arm_name for o in observations]
for feature in features:
feature.trial_index = trial_index
else:
features = [
ObservationFeatures(parameters=arm.parameters, trial_index=trial_index)
Expand Down Expand Up @@ -244,7 +250,6 @@ def _prepare_data(
metric_name=metric_name,
outcome_constraints=outcome_constraints,
gr=gr,
trial_index=trial_index,
)
for gr in candidate_trial.generator_runs
]
Expand Down
50 changes: 40 additions & 10 deletions ax/analysis/plotly/tests/test_predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.prediction_utils import predict_at_point
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import MaxTrials
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -123,18 +124,25 @@ def test_compute(self) -> None:
]
experiment.add_tracking_metric(get_branin_metric(name="tracking_branin"))
generation_strategy = self.generation_strategy
generation_strategy.experiment = experiment
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
generator_runs=generation_strategy.gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1.0
).mark_completed(
unsafe=True
)
experiment.fetch_data()
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
generator_runs=generation_strategy.gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0)
experiment.fetch_data()
# Ensure the current model is Botorch
self.assertEqual(none_throws(generation_strategy.model)._model_key, "BoTorch")
for metric in experiment.metrics:
with self.subTest(metric=metric):
# WHEN we compute the analysis for a metric
Expand Down Expand Up @@ -186,15 +194,23 @@ def test_compute(self) -> None:
@fast_botorch_optimize
def test_compute_multitask(self) -> None:
# GIVEN an experiment with candidates generated with a multitask model
experiment = get_branin_experiment()
experiment = get_branin_experiment(with_status_quo=True)
generation_strategy = self.generation_strategy
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
).mark_completed(unsafe=True)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1
).mark_completed(
unsafe=True
)
experiment.fetch_data()
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
).mark_completed(unsafe=True)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1
).mark_completed(
unsafe=True
)
experiment.fetch_data()
# leave as a candidate
experiment.new_batch_trial(
Expand All @@ -203,20 +219,24 @@ def test_compute_multitask(self) -> None:
n=10,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
)
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1)
experiment.new_batch_trial(
generator_run=generation_strategy.gen(
experiment=experiment,
n=10,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
)
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1)
self.assertEqual(none_throws(generation_strategy.model)._model_key, "ST_MTGP")
# WHEN we compute the analysis
analysis = PredictedEffectsPlot(metric_name="branin")
card = analysis.compute(
experiment=experiment, generation_strategy=generation_strategy
)
with patch(
f"{PredictedEffectsPlot.__module__}.predict_at_point",
wraps=predict_at_point,
) as predict_at_point_spy:
card = analysis.compute(
experiment=experiment, generation_strategy=generation_strategy
)
# THEN it has the right rows for arms with data, as well as the latest trial
arms_with_data = set(experiment.lookup_data().df["arm_name"].unique())
max_trial_index = max(experiment.trials.keys())
Expand All @@ -235,6 +255,16 @@ def test_compute_multitask(self) -> None:
or arm.name in experiment.trials[max_trial_index].arms_by_name,
arm.name,
)
# AND THEN it always predicts for the target trial
self.assertEqual(
len(
{
call[1]["obsf"].trial_index
for call in predict_at_point_spy.call_args_list
}
),
1,
)

@fast_botorch_optimize
def test_it_does_not_plot_abandoned_trials(self) -> None:
Expand Down

0 comments on commit f936ff5

Please sign in to comment.