Extract constraint violation logic for common use (#2811)
Summary: Move `get_constraint_violated_probabilities()` to a utils file and give it its own tests. No changes were made to `get_constraint_violated_probabilities()` itself.

Reviewed By: Cesar-Cardoso

Differential Revision: D63656156
1 parent 78b0527 · commit 265f678

Showing 3 changed files with 231 additions and 112 deletions.
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ax.analysis.plotly.utils import get_constraint_violated_probabilities
from ax.core.metric import Metric
from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase


class TestUtils(TestCase):
    def test_no_constraints_violates_none(self) -> None:
        constraint_violated_probabilities = get_constraint_violated_probabilities(
            # predictions for 2 observations on metrics a and b
            predictions=[
                (
                    {"a": 1.0, "b": 2.0},
                    {"a": 0.1, "b": 0.2},
                ),
                (
                    {"a": 1.1, "b": 2.1},
                    {"a": 0.1, "b": 0.2},
                ),
            ],
            outcome_constraints=[],
        )
        self.assertEqual(
            constraint_violated_probabilities, {"any_constraint_violated": [0.0, 0.0]}
        )

    def test_relative_constraints_are_not_accepted(self) -> None:
        with self.assertRaisesRegex(
            UserInputError,
            "does not support relative outcome constraints",
        ):
            get_constraint_violated_probabilities(
                predictions=[],
                outcome_constraints=[
                    OutcomeConstraint(
                        metric=Metric("a"),
                        op=ComparisonOp.GEQ,
                        bound=0.0,
                        relative=True,
                    )
                ],
            )

    def test_it_gives_a_result_per_constraint_plus_overall(self) -> None:
        constraint_violated_probabilities = get_constraint_violated_probabilities(
            # predictions for 2 observations on metrics a and b
            predictions=[
                (
                    {"a": 1.0, "b": 2.0},
                    {"a": 0.1, "b": 0.2},
                ),
                (
                    {"a": 1.1, "b": 2.1},
                    {"a": 0.1, "b": 0.2},
                ),
            ],
            outcome_constraints=[
                OutcomeConstraint(
                    metric=Metric("a"),
                    op=ComparisonOp.GEQ,
                    bound=0.9,
                    relative=False,
                ),
                OutcomeConstraint(
                    metric=Metric("b"),
                    op=ComparisonOp.LEQ,
                    bound=2.2,
                    relative=False,
                ),
            ],
        )
        self.assertEqual(
            len(constraint_violated_probabilities.keys()),
            3,
        )
        self.assertIn("any_constraint_violated", constraint_violated_probabilities)
        self.assertIn("a", constraint_violated_probabilities)
        self.assertIn("b", constraint_violated_probabilities)
        self.assertAlmostEqual(
            constraint_violated_probabilities["any_constraint_violated"][0],
            0.292,
            places=2,
        )
        self.assertAlmostEqual(
            constraint_violated_probabilities["any_constraint_violated"][1],
            0.324,
            places=2,
        )
        self.assertAlmostEqual(
            constraint_violated_probabilities["a"][0], 0.158, places=2
        )
        self.assertAlmostEqual(
            constraint_violated_probabilities["a"][1], 0.022, places=2
        )
        self.assertAlmostEqual(
            constraint_violated_probabilities["b"][0], 0.158, places=2
        )
        self.assertAlmostEqual(
            constraint_violated_probabilities["b"][1], 0.308, places=2
        )
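The expected values in the last test follow directly from normal tail probabilities. As a sanity check (not part of the commit, and assuming `scipy` is available), the asserted numbers can be reproduced with plain normal CDFs:

# Sanity check, not part of the commit: reproduce the asserted values with
# plain normal CDFs (assumes scipy is installed).
from scipy.stats import norm

# Constraint a >= 0.9 is violated when a < 0.9.
p_a_0 = norm.cdf(0.9, loc=1.0, scale=0.1)  # ~0.1587 -> test asserts 0.158
p_a_1 = norm.cdf(0.9, loc=1.1, scale=0.1)  # ~0.0228 -> test asserts 0.022

# Constraint b <= 2.2 is violated when b > 2.2.
p_b_0 = 1 - norm.cdf(2.2, loc=2.0, scale=0.2)  # ~0.1587 -> test asserts 0.158
p_b_1 = 1 - norm.cdf(2.2, loc=2.1, scale=0.2)  # ~0.3085 -> test asserts 0.308

# "any_constraint_violated" is 1 minus the product of the per-constraint
# feasibility probabilities (treating the constraints as independent).
any_0 = 1 - (1 - p_a_0) * (1 - p_b_0)  # ~0.292 -> test asserts 0.292
any_1 = 1 - (1 - p_a_1) * (1 - p_b_1)  # ~0.324 -> test asserts 0.324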
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
from ax.exceptions.core import UserInputError
from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds


# Because normal distributions have long tails, every arm has a non-zero
# probability of violating the constraint. But below a certain threshold, we
# consider the probability of violation to be negligible.
MINIMUM_CONTRAINT_VIOLATION_THRESHOLD = 0.01


def get_constraint_violated_probabilities(
    predictions: list[tuple[dict[str, float], dict[str, float]]],
    outcome_constraints: list[OutcomeConstraint],
) -> dict[str, list[float]]:
    """Get the probability that each arm violates the outcome constraints.

    Args:
        predictions: List of predictions for each observation feature
            generated by predict_at_point. It should include predictions
            for all outcome constraint metrics.
        outcome_constraints: List of outcome constraints to check.

    Returns:
        A dict of probabilities that each arm violates the outcome
        constraint provided, and for "any_constraint_violated" the
        probability that the arm violates *any* outcome constraint provided.
    """
    if len(outcome_constraints) == 0:
        return {"any_constraint_violated": [0.0] * len(predictions)}
    if any(constraint.relative for constraint in outcome_constraints):
        raise UserInputError(
            "`get_constraint_violated_probabilities()` does not support relative "
            "outcome constraints. Use `Derelativize().transform_optimization_config()` "
            "before passing constraints to this method."
        )

    metrics = [constraint.metric.name for constraint in outcome_constraints]
    means = torch.as_tensor(
        [
            [prediction[0][metric_name] for metric_name in metrics]
            for prediction in predictions
        ]
    )
    sigmas = torch.as_tensor(
        [
            [prediction[1][metric_name] for metric_name in metrics]
            for prediction in predictions
        ]
    )
    feasibility_probabilities: dict[str, np.ndarray] = {}
    for constraint in outcome_constraints:
        if constraint.op == ComparisonOp.GEQ:
            con_lower_inds = torch.tensor([metrics.index(constraint.metric.name)])
            con_lower = torch.tensor([constraint.bound])
            con_upper_inds = torch.as_tensor([])
            con_upper = torch.as_tensor([])
        else:
            con_lower_inds = torch.as_tensor([])
            con_lower = torch.as_tensor([])
            con_upper_inds = torch.tensor([metrics.index(constraint.metric.name)])
            con_upper = torch.tensor([constraint.bound])

        feasibility_probabilities[constraint.metric.name] = (
            compute_log_prob_feas_from_bounds(
                means=means,
                sigmas=sigmas,
                con_lower_inds=con_lower_inds,
                con_upper_inds=con_upper_inds,
                con_lower=con_lower,
                con_upper=con_upper,
                # "both" can also be expressed by 2 separate constraints...
                con_both_inds=torch.as_tensor([]),
                con_both=torch.as_tensor([]),
            )
            .exp()
            .numpy()
        )

    feasibility_probabilities["any_constraint_violated"] = np.prod(
        list(feasibility_probabilities.values()), axis=0
    )

    return {
        metric_name: 1 - feasibility_probabilities[metric_name]
        for metric_name in feasibility_probabilities
    }


def format_constraint_violated_probabilities(
    constraints_violated: dict[str, float]
) -> str:
    """Format the constraints violated for the tooltip."""
    max_metric_length = 70
    constraints_violated = {
        k: v
        for k, v in constraints_violated.items()
        if v > MINIMUM_CONTRAINT_VIOLATION_THRESHOLD
    }
    constraints_violated_str = "<br /> ".join(
        [
            (
                f"{k[:max_metric_length]}{'...' if len(k) > max_metric_length else ''}"
                f": {v * 100:.1f}% chance violated"
            )
            for k, v in constraints_violated.items()
        ]
    )
    if len(constraints_violated_str) == 0:
        return "No constraints violated"
    else:
        constraints_violated_str = "<br /> " + constraints_violated_str

    return constraints_violated_str
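For context, here is a minimal sketch tying the two helpers together (illustrative only, not part of the commit; `probs`, `arm_idx`, and `tooltip` are hypothetical names). Note that `format_constraint_violated_probabilities()` takes a single arm's probabilities, so one observation is selected before formatting:

# Illustrative usage sketch, not part of the commit; variable names are
# hypothetical.
from ax.analysis.plotly.utils import (
    format_constraint_violated_probabilities,
    get_constraint_violated_probabilities,
)
from ax.core.metric import Metric
from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint

probs = get_constraint_violated_probabilities(
    # One observation: predicted mean 1.0 and sigma 0.1 for metric "a".
    predictions=[({"a": 1.0}, {"a": 0.1})],
    outcome_constraints=[
        OutcomeConstraint(
            metric=Metric("a"), op=ComparisonOp.GEQ, bound=0.9, relative=False
        )
    ],
)

# format_constraint_violated_probabilities() expects dict[str, float] for a
# single arm, so pick out one observation's probabilities before formatting.
arm_idx = 0
tooltip = format_constraint_violated_probabilities(
    {name: float(ps[arm_idx]) for name, ps in probs.items()}
)
print(tooltip)  # e.g. "<br /> a: 15.9% chance violated<br /> any_constraint_violated: ..."

The aggregate "any_constraint_violated" entry is included in the formatted output here; a caller that only wants per-metric lines would presumably filter that key out first.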