Skip to content

Commit

Permalink
Always predict for status quo features trial index (facebook#2810)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#2810

Previously we had plots that looked like {F1904605463}
After the same plot looks like
{F1904608806}

##What's happening
Because of nonstationarity, the same feature gets predicted differently in different trials.  But the point of this plot is to show how well the parameterization does on the metric, so we want to standardize the trial index we're predicting on.

Differential Revision: D63715134

Reviewed By: mgarrard
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Oct 4, 2024
1 parent 331da3a commit 4064b66
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 13 deletions.
10 changes: 7 additions & 3 deletions ax/analysis/plotly/predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,18 @@ def _get_predictions(
metric_name: str,
outcome_constraints: list[OutcomeConstraint],
gr: GeneratorRun | None = None,
trial_index: int | None = None,
) -> list[dict[str, Any]]:
trial_index = (
_get_max_observed_trial_index(model)
if model.status_quo is None
else model.status_quo.features.trial_index
)
if gr is None:
observations = model.get_training_data()
features = [o.features for o in observations]
arm_names = [o.arm_name for o in observations]
for feature in features:
feature.trial_index = trial_index
else:
features = [
ObservationFeatures(parameters=arm.parameters, trial_index=trial_index)
Expand Down Expand Up @@ -226,7 +232,6 @@ def _prepare_data(
metric_name: Name of metric to plot
candidate_trial: Trial to plot candidates for by generator run
"""
trial_index = _get_max_observed_trial_index(model)
df = pd.DataFrame.from_records(
list(
chain(
Expand All @@ -245,7 +250,6 @@ def _prepare_data(
metric_name=metric_name,
outcome_constraints=outcome_constraints,
gr=gr,
trial_index=trial_index,
)
for gr in candidate_trial.generator_runs
]
Expand Down
49 changes: 39 additions & 10 deletions ax/analysis/plotly/tests/test_predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.prediction_utils import predict_at_point
from ax.modelbridge.registry import Models
from ax.modelbridge.transition_criterion import MaxTrials
from ax.utils.common.testutils import TestCase
Expand Down Expand Up @@ -124,17 +125,23 @@ def test_compute(self) -> None:
experiment.add_tracking_metric(get_branin_metric(name="tracking_branin"))
generation_strategy = self.generation_strategy
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
generator_runs=generation_strategy.gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1.0
).mark_completed(
unsafe=True
)
experiment.fetch_data()
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
generator_runs=generation_strategy.gen_with_multiple_nodes(
experiment=experiment, n=10
)
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1.0)
experiment.fetch_data()
# Ensure the current model is Botorch
self.assertEqual(none_throws(generation_strategy.model)._model_key, "BoTorch")
for metric in experiment.metrics:
with self.subTest(metric=metric):
# WHEN we compute the analysis for a metric
Expand Down Expand Up @@ -186,15 +193,23 @@ def test_compute(self) -> None:
@fast_botorch_optimize
def test_compute_multitask(self) -> None:
# GIVEN an experiment with candidates generated with a multitask model
experiment = get_branin_experiment()
experiment = get_branin_experiment(with_status_quo=True)
generation_strategy = self.generation_strategy
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
).mark_completed(unsafe=True)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1
).mark_completed(
unsafe=True
)
experiment.fetch_data()
experiment.new_batch_trial(
generator_run=generation_strategy.gen(experiment=experiment, n=10)
).mark_completed(unsafe=True)
).set_status_quo_with_weight(
status_quo=experiment.status_quo, weight=1
).mark_completed(
unsafe=True
)
experiment.fetch_data()
# leave as a candidate
experiment.new_batch_trial(
Expand All @@ -203,20 +218,24 @@ def test_compute_multitask(self) -> None:
n=10,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
)
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1)
experiment.new_batch_trial(
generator_run=generation_strategy.gen(
experiment=experiment,
n=10,
fixed_features=ObservationFeatures(parameters={}, trial_index=1),
)
)
).set_status_quo_with_weight(status_quo=experiment.status_quo, weight=1)
self.assertEqual(none_throws(generation_strategy.model)._model_key, "ST_MTGP")
# WHEN we compute the analysis
analysis = PredictedEffectsPlot(metric_name="branin")
card = analysis.compute(
experiment=experiment, generation_strategy=generation_strategy
)
with patch(
f"{PredictedEffectsPlot.__module__}.predict_at_point",
wraps=predict_at_point,
) as predict_at_point_spy:
card = analysis.compute(
experiment=experiment, generation_strategy=generation_strategy
)
# THEN it has the right rows for arms with data, as well as the latest trial
arms_with_data = set(experiment.lookup_data().df["arm_name"].unique())
max_trial_index = max(experiment.trials.keys())
Expand All @@ -235,6 +254,16 @@ def test_compute_multitask(self) -> None:
or arm.name in experiment.trials[max_trial_index].arms_by_name,
arm.name,
)
# AND THEN it always predicts for the target trial
self.assertEqual(
len(
{
call[1]["obsf"].trial_index
for call in predict_at_point_spy.call_args_list
}
),
1,
)

@fast_botorch_optimize
def test_it_does_not_plot_abandoned_trials(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ def gen_with_multiple_nodes(
Returns:
A list of ``GeneratorRuns`` for a single trial.
"""
self.experiment = experiment
# TODO: @mgarrard merge into gen method, just starting here to derisk
# Validate `arms_per_node` if specified, otherwise construct the default
# behavior with keys being node names and values being 1 to represent
Expand Down

0 comments on commit 4064b66

Please sign in to comment.