Skip to content

Commit

Permalink
Always predict for status quo features trial index
Browse files Browse the repository at this point in the history
Summary:
Previously we had plots that looked like {F1904605463}
After the same plot looks like
{F1904608806}

Differential Revision: D63715134
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Oct 2, 2024
1 parent 3942bff commit 1e870e2
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 13 deletions.
9 changes: 7 additions & 2 deletions ax/analysis/plotly/predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,18 @@ def _get_predictions(
metric_name: str,
outcome_constraints: list[OutcomeConstraint],
gr: Optional[GeneratorRun] = None,
trial_index: Optional[int] = None,
) -> list[dict[str, Any]]:
trial_index = (
_get_max_observed_trial_index(model)
if model.status_quo is None
else model.status_quo.features.trial_index
)
if gr is None:
observations = model.get_training_data()
features = [o.features for o in observations]
arm_names = [o.arm_name for o in observations]
for feature in features:
feature.trial_index = trial_index
else:
features = [
ObservationFeatures(parameters=arm.parameters, trial_index=trial_index)
Expand Down Expand Up @@ -244,7 +250,6 @@ def _prepare_data(
metric_name=metric_name,
outcome_constraints=outcome_constraints,
gr=gr,
trial_index=trial_index,
)
for gr in candidate_trial.generator_runs
]
Expand Down
53 changes: 42 additions & 11 deletions ax/analysis/plotly/tests/test_predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@
from ax.core.observation import ObservationFeatures
from ax.core.trial import Trial
from ax.exceptions.core import UserInputError
from ax.fb.modelbridge.generation_strategies import get_sequential_online_gs
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.prediction_utils import predict_at_point
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import MaxTrials
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -122,19 +124,26 @@ def test_compute(self) -> None:
get_branin_outcome_constraint(name="constraint_branin")
]
experiment.add_tracking_metric(get_branin_metric(name="tracking_branin"))
generation_strategy = self.generation_strategy
generation_strategy = get_sequential_online_gs()
generation_strategy.experiment = experiment
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
generator_runs=generation_strategy.gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1.0
).mark_completed(
unsafe=True
)
experiment.fetch_data()
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
generator_runs=generation_strategy.gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0)
experiment.fetch_data()
# Ensure the current model is Botorch
self.assertEqual(none_throws(generation_strategy.model)._model_key, "BoTorch")
for metric in experiment.metrics:
with self.subTest(metric=metric):
# WHEN we compute the analysis for a metric
Expand Down Expand Up @@ -186,15 +195,23 @@ def test_compute(self) -> None:
@fast_botorch_optimize
def test_compute_multitask(self) -> None:
# GIVEN an experiment with candidates generated with a multitask model
experiment = get_branin_experiment()
experiment = get_branin_experiment(with_status_quo=True)
generation_strategy = self.generation_strategy
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
).mark_completed(unsafe=True)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1
).mark_completed(
unsafe=True
)
experiment.fetch_data()
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
).mark_completed(unsafe=True)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1
).mark_completed(
unsafe=True
)
experiment.fetch_data()
# leave as a candidate
experiment.new_batch_trial(
Expand All @@ -203,20 +220,24 @@ def test_compute_multitask(self) -> None:
n=10,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
)
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1)
experiment.new_batch_trial(
generator_run=generation_strategy.gen(
experiment=experiment,
n=10,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
)
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1)
self.assertEqual(none_throws(generation_strategy.model)._model_key, "ST_MTGP")
# WHEN we compute the analysis
analysis = PredictedEffectsPlot(metric_name="branin")
card = analysis.compute(
experiment=experiment, generation_strategy=generation_strategy
)
with patch(
f"{PredictedEffectsPlot.__module__}.predict_at_point",
wraps=predict_at_point,
) as predict_at_point_spy:
card = analysis.compute(
experiment=experiment, generation_strategy=generation_strategy
)
# THEN it has the right rows for arms with data, as well as the latest trial
arms_with_data = set(experiment.lookup_data().df["arm_name"].unique())
max_trial_index = max(experiment.trials.keys())
Expand All @@ -235,6 +256,16 @@ def test_compute_multitask(self) -> None:
or arm.name in experiment.trials[max_trial_index].arms_by_name,
arm.name,
)
# AND THEN it always predicts for the target trial
self.assertEqual(
len(
{
call[1]["obsf"].trial_index
for call in predict_at_point_spy.call_args_list
}
),
1,
)

@fast_botorch_optimize
def test_it_does_not_plot_abandoned_trials(self) -> None:
Expand Down

0 comments on commit 1e870e2

Please sign in to comment.