Skip to content

Commit

Permalink
Extract constraint violation logic for common use (facebook#2811)
Browse files Browse the repository at this point in the history
Summary:

Move `get_constraint_violated_probabilities()` to a utils file and give it its own tests.  No changes were made to `get_constraint_violated_probabilities()`

Reviewed By: Cesar-Cardoso

Differential Revision: D63656156
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Oct 2, 2024
1 parent 78b0527 commit dd854d4
Show file tree
Hide file tree
Showing 4 changed files with 239 additions and 112 deletions.
117 changes: 5 additions & 112 deletions ax/analysis/plotly/predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,29 @@
from itertools import chain
from typing import Any, Optional

import numpy as np
import pandas as pd
import torch
from ax.analysis.analysis import AnalysisCardLevel

from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
from ax.analysis.plotly.utils import (
format_constraint_violated_probabilities,
get_constraint_violated_probabilities,
)
from ax.core import OutcomeConstraint
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
from ax.core.types import ComparisonOp
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.prediction_utils import predict_at_point
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.utils.common.typeutils import checked_cast
from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds
from plotly import express as px, graph_objects as go, io as pio
from pyre_extensions import none_throws

MINIMUM_CONTRAINT_VIOLATION_THRESHOLD = 0.01


class PredictedEffectsPlot(PlotlyAnalysis):
def __init__(self, metric_name: str) -> None:
Expand Down Expand Up @@ -114,111 +112,6 @@ def compute(
)


def get_constraint_violated_probabilities(
predictions: list[tuple[dict[str, float], dict[str, float]]],
outcome_constraints: list[OutcomeConstraint],
) -> dict[str, list[float]]:
"""Get the probability that each arm violates the outcome constraints.
Args:
predictions: List of predictions for each observation feature
generated by predict_at_point. It should include predictions
for all outcome constraint metrics.
outcome_constraints: List of outcome constraints to check.
Returns:
A dict of probabilities that each arm violates the outcome
constraint provided, and for "any_constraint_violated" the probability that
the arm violates *any* outcome constraint provided.
"""
if len(outcome_constraints) == 0:
return {"any_constraint_violated": [0.0] * len(predictions)}
if any(constraint.relative for constraint in outcome_constraints):
raise UserInputError(
"`get_constraint_violated_probabilities()` does not support relative "
"outcome constraints. Use `Derelativize().transform_optimization_config()` "
"before passing constraints to this method."
)

metrics = [constraint.metric.name for constraint in outcome_constraints]
means = torch.as_tensor(
[
[prediction[0][metric_name] for metric_name in metrics]
for prediction in predictions
]
)
sigmas = torch.as_tensor(
[
[prediction[1][metric_name] for metric_name in metrics]
for prediction in predictions
]
)
feasibility_probabilities: dict[str, np.ndarray] = {}
for constraint in outcome_constraints:
if constraint.op == ComparisonOp.GEQ:
con_lower_inds = torch.tensor([metrics.index(constraint.metric.name)])
con_lower = torch.tensor([constraint.bound])
con_upper_inds = torch.as_tensor([])
con_upper = torch.as_tensor([])
else:
con_lower_inds = torch.as_tensor([])
con_lower = torch.as_tensor([])
con_upper_inds = torch.tensor([metrics.index(constraint.metric.name)])
con_upper = torch.tensor([constraint.bound])

feasibility_probabilities[constraint.metric.name] = (
compute_log_prob_feas_from_bounds(
means=means,
sigmas=sigmas,
con_lower_inds=con_lower_inds,
con_upper_inds=con_upper_inds,
con_lower=con_lower,
con_upper=con_upper,
# "both" can also be expressed by 2 separate constraints...
con_both_inds=torch.as_tensor([]),
con_both=torch.as_tensor([]),
)
.exp()
.numpy()
)

feasibility_probabilities["any_constraint_violated"] = np.prod(
list(feasibility_probabilities.values()), axis=0
)

return {
metric_name: 1 - feasibility_probabilities[metric_name]
for metric_name in feasibility_probabilities
}


def _format_constraint_violated_probabilities(
constraints_violated: dict[str, float]
) -> str:
"""Format the constraints violated for the tooltip."""
max_metric_length = 70
constraints_violated = {
k: v
for k, v in constraints_violated.items()
if v > MINIMUM_CONTRAINT_VIOLATION_THRESHOLD
}
constraints_violated_str = "<br /> ".join(
[
(
f"{k[:max_metric_length]}{'...' if len(k) > max_metric_length else ''}"
f": {v * 100:.1f}% chance violated"
)
for k, v in constraints_violated.items()
]
)
if len(constraints_violated_str) == 0:
return "No constraints violated"
else:
constraints_violated_str = "<br /> " + constraints_violated_str

return constraints_violated_str


def _get_predictions(
model: ModelBridge,
metric_name: str,
Expand Down Expand Up @@ -283,7 +176,7 @@ def _get_predictions(
"mean": predictions[i][0][metric_name],
"sem": predictions[i][1][metric_name],
"error_margin": 1.96 * predictions[i][1][metric_name],
"constraints_violated": _format_constraint_violated_probabilities(
"constraints_violated": format_constraint_violated_probabilities(
constraints_violated[i]
),
"size_column": 100 - probabilities_not_feasible[i] * 100,
Expand Down
106 changes: 106 additions & 0 deletions ax/analysis/plotly/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from ax.analysis.plotly.utils import get_constraint_violated_probabilities
from ax.core.metric import Metric
from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
from ax.exceptions.core import UserInputError
from ax.utils.common.testutils import TestCase


class TestUtils(TestCase):
def test_no_constraints_violates_none(self) -> None:
constraint_violated_probabilities = get_constraint_violated_probabilities(
# predictions for 2 observations on metrics a and b
predictions=[
(
{"a": 1.0, "b": 2.0},
{"a": 0.1, "b": 0.2},
),
(
{"a": 1.1, "b": 2.1},
{"a": 0.1, "b": 0.2},
),
],
outcome_constraints=[],
)
self.assertEqual(
constraint_violated_probabilities, {"any_constraint_violated": [0.0, 0.0]}
)

def test_relative_constraints_are_not_accepted(self) -> None:
with self.assertRaisesRegex(
UserInputError,
"does not support relative outcome constraints",
):
get_constraint_violated_probabilities(
predictions=[],
outcome_constraints=[
OutcomeConstraint(
metric=Metric("a"),
op=ComparisonOp.GEQ,
bound=0.0,
relative=True,
)
],
)

def test_it_gives_a_result_per_constraint_plus_overall(self) -> None:
constraint_violated_probabilities = get_constraint_violated_probabilities(
# predictions for 2 observations on metrics a and b
predictions=[
(
{"a": 1.0, "b": 2.0},
{"a": 0.1, "b": 0.2},
),
(
{"a": 1.1, "b": 2.1},
{"a": 0.1, "b": 0.2},
),
],
outcome_constraints=[
OutcomeConstraint(
metric=Metric("a"),
op=ComparisonOp.GEQ,
bound=0.9,
relative=False,
),
OutcomeConstraint(
metric=Metric("b"),
op=ComparisonOp.LEQ,
bound=2.2,
relative=False,
),
],
)
self.assertEqual(
len(constraint_violated_probabilities.keys()),
3,
)
self.assertIn("any_constraint_violated", constraint_violated_probabilities)
self.assertIn("a", constraint_violated_probabilities)
self.assertIn("b", constraint_violated_probabilities)
self.assertAlmostEqual(
constraint_violated_probabilities["any_constraint_violated"][0],
0.292,
places=2,
)
self.assertAlmostEqual(
constraint_violated_probabilities["any_constraint_violated"][1],
0.324,
places=2,
)
self.assertAlmostEqual(
constraint_violated_probabilities["a"][0], 0.158, places=2
)
self.assertAlmostEqual(
constraint_violated_probabilities["a"][1], 0.022, places=2
)
self.assertAlmostEqual(
constraint_violated_probabilities["b"][0], 0.158, places=2
)
self.assertAlmostEqual(
constraint_violated_probabilities["b"][1], 0.308, places=2
)
120 changes: 120 additions & 0 deletions ax/analysis/plotly/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
from ax.exceptions.core import UserInputError
from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds

# Because normal distributions have long tails, every arm has a non-zero
# probability of violating the constraint. But below a certain threshold, we
# consider probability of violation to be negligible.
MINIMUM_CONTRAINT_VIOLATION_THRESHOLD = 0.01


def get_constraint_violated_probabilities(
predictions: list[tuple[dict[str, float], dict[str, float]]],
outcome_constraints: list[OutcomeConstraint],
) -> dict[str, list[float]]:
"""Get the probability that each arm violates the outcome constraints.
Args:
predictions: List of predictions for each observation feature
generated by predict_at_point. It should include predictions
for all outcome constraint metrics.
outcome_constraints: List of outcome constraints to check.
Returns:
A dict of probabilities that each arm violates the outcome
constraint provided, and for "any_constraint_violated" the probability that
the arm violates *any* outcome constraint provided.
"""
if len(outcome_constraints) == 0:
return {"any_constraint_violated": [0.0] * len(predictions)}
if any(constraint.relative for constraint in outcome_constraints):
raise UserInputError(
"`get_constraint_violated_probabilities()` does not support relative "
"outcome constraints. Use `Derelativize().transform_optimization_config()` "
"before passing constraints to this method."
)

metrics = [constraint.metric.name for constraint in outcome_constraints]
means = torch.as_tensor(
[
[prediction[0][metric_name] for metric_name in metrics]
for prediction in predictions
]
)
sigmas = torch.as_tensor(
[
[prediction[1][metric_name] for metric_name in metrics]
for prediction in predictions
]
)
feasibility_probabilities: dict[str, np.ndarray] = {}
for constraint in outcome_constraints:
if constraint.op == ComparisonOp.GEQ:
con_lower_inds = torch.tensor([metrics.index(constraint.metric.name)])
con_lower = torch.tensor([constraint.bound])
con_upper_inds = torch.as_tensor([])
con_upper = torch.as_tensor([])
else:
con_lower_inds = torch.as_tensor([])
con_lower = torch.as_tensor([])
con_upper_inds = torch.tensor([metrics.index(constraint.metric.name)])
con_upper = torch.tensor([constraint.bound])

feasibility_probabilities[constraint.metric.name] = (
compute_log_prob_feas_from_bounds(
means=means,
sigmas=sigmas,
con_lower_inds=con_lower_inds,
con_upper_inds=con_upper_inds,
con_lower=con_lower,
con_upper=con_upper,
# "both" can also be expressed by 2 separate constraints...
con_both_inds=torch.as_tensor([]),
con_both=torch.as_tensor([]),
)
.exp()
.numpy()
)

feasibility_probabilities["any_constraint_violated"] = np.prod(
list(feasibility_probabilities.values()), axis=0
)

return {
metric_name: 1 - feasibility_probabilities[metric_name]
for metric_name in feasibility_probabilities
}


def format_constraint_violated_probabilities(
constraints_violated: dict[str, float]
) -> str:
"""Format the constraints violated for the tooltip."""
max_metric_length = 70
constraints_violated = {
k: v
for k, v in constraints_violated.items()
if v > MINIMUM_CONTRAINT_VIOLATION_THRESHOLD
}
constraints_violated_str = "<br /> ".join(
[
(
f"{k[:max_metric_length]}{'...' if len(k) > max_metric_length else ''}"
f": {v * 100:.1f}% chance violated"
)
for k, v in constraints_violated.items()
]
)
if len(constraints_violated_str) == 0:
return "No constraints violated"
else:
constraints_violated_str = "<br /> " + constraints_violated_str

return constraints_violated_str
8 changes: 8 additions & 0 deletions sphinx/source/analysis.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,11 @@ Predicted Effects Analysis
:members:
:undoc-members:
:show-inheritance:

Plotly Anaylsis Utils
~~~~~~~~~~~~~~~

.. automodule:: ax.analysis.plotly.utils
:members:
:undoc-members:
:show-inheritance:

0 comments on commit dd854d4

Please sign in to comment.