Skip to content

Commit

Permalink
Update gen_with_multiple_nodes to handle pending points (facebook#2817)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebook#2817

In extending gen_multiple_trials_with_multiple models, we realized that gen_with_multiple_nodes doesn't currently handle pending points gracefully. This diff ensure that points generated from one node are considered pending for the next node in a single trial's gen loop (and so forth for any number of nodes).

Reviewed By: saitcakmak

Differential Revision: D63785539

fbshipit-source-id: ae9bc91a34edc86495eac8c9bd3ec578799192e4
  • Loading branch information
mgarrard authored and facebook-github-bot committed Oct 3, 2024
1 parent 85ae2b9 commit 331da3a
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 9 deletions.
31 changes: 24 additions & 7 deletions ax/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -393,14 +393,31 @@ def get_pending_observation_features_based_on_trial_status(
def extend_pending_observations(
experiment: Experiment,
pending_observations: dict[str, list[ObservationFeatures]],
generator_run: GeneratorRun,
) -> None:
generator_runs: list[GeneratorRun],
) -> dict[str, list[ObservationFeatures]]:
"""Extend given pending observations dict (from metric name to observations
that are pending for that metric), with arms in a given generator run.
Args:
experiment: Experiment, for which the generation strategy is producing
``GeneratorRun``s.
pending_observations: Dict from metric name to pending observations for
that metric, used to avoid resuggesting arms that will be explored soon.
generator_runs: List of ``GeneratorRun``s currently produced by the
``GenerationStrategy``.
Returns:
A new dictionary of pending observations to avoid in-place modification
"""
# TODO: T203665729 @mgarrard add arm signature to ObservationFeatures and then use
# that to compare to arm signature in GR to speed up this method
extended_obs = deepcopy(pending_observations)
for m in experiment.metrics:
if m not in pending_observations:
pending_observations[m] = []
pending_observations[m].extend(
ObservationFeatures.from_arm(a) for a in generator_run.arms
)
if m not in extended_obs:
extended_obs[m] = []
for generator_run in generator_runs:
for a in generator_run.arms:
ob_ft = ObservationFeatures.from_arm(a)
if ob_ft not in extended_obs[m]:
extended_obs[m].append(ob_ft)
return extended_obs
14 changes: 12 additions & 2 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,7 @@ def gen_with_multiple_nodes(
"grs_this_gen": grs,
"n": n,
}
pending_observations = deepcopy(pending_observations) or {}
while continue_gen_for_trial:
gen_kwargs["grs_this_gen"] = grs
should_transition, node_to_gen_from_name = (
Expand Down Expand Up @@ -477,6 +478,15 @@ def gen_with_multiple_nodes(
pending_observations=pending_observations,
)
)
# ensure that the points generated from each node are marked as pending
# points for future calls to gen
pending_observations = extend_pending_observations(
experiment=experiment,
pending_observations=pending_observations,
# only pass in the most recent generator run to avoid unnecessary
# deduplication in extend_pending_observations
generator_runs=[grs[-1]],
)
continue_gen_for_trial = self._should_continue_gen_for_trial()
return grs

Expand Down Expand Up @@ -856,10 +866,10 @@ def _gen_multiple(

# Extend the `pending_observation` with newly generated point(s)
# to avoid repeating them.
extend_pending_observations(
pending_observations = extend_pending_observations(
experiment=experiment,
pending_observations=pending_observations,
generator_run=generator_run,
generator_runs=[generator_run],
)
return generator_runs

Expand Down
102 changes: 102 additions & 0 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,108 @@ def test_gs_setup_with_nodes(self) -> None:
logger.output,
)

def test_gen_with_multiple_nodes_pending_points(self) -> None:
exp = get_experiment_with_multi_objective()
gs = GenerationStrategy(
nodes=[
GenerationNode(
node_name="sobol_1",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGen(
transition_to="sobol_2",
)
],
),
GenerationNode(
node_name="sobol_2",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGen(transition_to="sobol_3")
],
),
GenerationNode(
node_name="sobol_3",
model_specs=[self.sobol_model_spec],
transition_criteria=[
AutoTransitionAfterGen(
transition_to="sobol_1",
block_transition_if_unmet=True,
continue_trial_generation=False,
),
],
),
]
)
gs.experiment = exp
arms_per_node = {
"sobol_1": 2,
"sobol_2": 1,
"sobol_3": 3,
}
with mock_patch_method_original(
mock_path=f"{ModelSpec.__module__}.ModelSpec.gen",
original_method=ModelSpec.gen,
) as model_spec_gen_mock:
# Generate a trial that should be composed of arms from 3 nodes
grs = gs.gen_with_multiple_nodes(
experiment=exp, arms_per_node=arms_per_node
)

self.assertEqual(len(grs), 3) # len == 3 due to 3 nodes contributing
pending_in_each_gen = enumerate(
call_kwargs.get("pending_observations")
for _, call_kwargs in model_spec_gen_mock.call_args_list
)

# for each call to gen after the first call to gen, which should have no
# pending points the number of pending points should be equal to the sum of
# the number of arms we suspect from the previous nodes
expected_pending_per_call = [2, 3]
for idx, pending in pending_in_each_gen:
# the first pending call will be empty because we didn't pass in any
# additional points, start checking after the first position
# that the pending points we expect are present
if idx > 0:
self.assertEqual(
len(pending["m2"]), expected_pending_per_call[idx - 1]
)
prev_gr = grs[idx - 1]
for arm in prev_gr.arms:
for m in pending:
self.assertIn(ObservationFeatures.from_arm(arm), pending[m])

exp.new_batch_trial(generator_runs=grs).mark_running(
no_runner_required=True
)
model_spec_gen_mock.reset_mock()

# check that the pending points line up
original_pending = not_none(get_pending(experiment=exp))
first_3_trials_obs_feats = [
ObservationFeatures.from_arm(arm=a, trial_index=idx)
for idx, trial in exp.trials.items()
for a in trial.arms
]
for m in original_pending:
self.assertTrue(
same_elements(original_pending[m], first_3_trials_obs_feats)
)

# check that we can pass in pending points
grs = gs.gen_with_multiple_nodes(
experiment=exp,
arms_per_node=arms_per_node,
pending_observations=original_pending,
)
self.assertEqual(len(grs), 3) # len == 3 due to 3 nodes contributing
pending_in_each_gen = enumerate(
call_kwargs.get("pending_observations")
for _, call_kwargs in model_spec_gen_mock.call_args_list
)
# check first call is 6 (from the previous trial having 6 arms)
self.assertEqual(len(list(pending_in_each_gen)[0][1]["m1"]), 6)

def test_gs_initializes_all_previous_node_to_none(self) -> None:
"""Test that all previous nodes are initialized to None"""
node_1 = GenerationNode(
Expand Down

0 comments on commit 331da3a

Please sign in to comment.