diff --git a/ax/analysis/plotly/predicted_effects.py b/ax/analysis/plotly/predicted_effects.py
index 1ae82d2cea8..24c5557e285 100644
--- a/ax/analysis/plotly/predicted_effects.py
+++ b/ax/analysis/plotly/predicted_effects.py
@@ -6,31 +6,29 @@
from itertools import chain
from typing import Any, Optional
-import numpy as np
import pandas as pd
-import torch
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard
+from ax.analysis.plotly.utils import (
+ format_constraint_violated_probabilities,
+ get_constraint_violated_probabilities,
+)
from ax.core import OutcomeConstraint
from ax.core.base_trial import BaseTrial, TrialStatus
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.generator_run import GeneratorRun
from ax.core.observation import ObservationFeatures
-from ax.core.types import ComparisonOp
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.prediction_utils import predict_at_point
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.utils.common.typeutils import checked_cast
-from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds
from plotly import express as px, graph_objects as go, io as pio
from pyre_extensions import none_throws
-MINIMUM_CONTRAINT_VIOLATION_THRESHOLD = 0.01
-
class PredictedEffectsPlot(PlotlyAnalysis):
def __init__(self, metric_name: str) -> None:
@@ -114,111 +112,6 @@ def compute(
)
-def get_constraint_violated_probabilities(
- predictions: list[tuple[dict[str, float], dict[str, float]]],
- outcome_constraints: list[OutcomeConstraint],
-) -> dict[str, list[float]]:
- """Get the probability that each arm violates the outcome constraints.
-
- Args:
- predictions: List of predictions for each observation feature
- generated by predict_at_point. It should include predictions
- for all outcome constraint metrics.
- outcome_constraints: List of outcome constraints to check.
-
- Returns:
- A dict of probabilities that each arm violates the outcome
- constraint provided, and for "any_constraint_violated" the probability that
- the arm violates *any* outcome constraint provided.
- """
- if len(outcome_constraints) == 0:
- return {"any_constraint_violated": [0.0] * len(predictions)}
- if any(constraint.relative for constraint in outcome_constraints):
- raise UserInputError(
- "`get_constraint_violated_probabilities()` does not support relative "
- "outcome constraints. Use `Derelativize().transform_optimization_config()` "
- "before passing constraints to this method."
- )
-
- metrics = [constraint.metric.name for constraint in outcome_constraints]
- means = torch.as_tensor(
- [
- [prediction[0][metric_name] for metric_name in metrics]
- for prediction in predictions
- ]
- )
- sigmas = torch.as_tensor(
- [
- [prediction[1][metric_name] for metric_name in metrics]
- for prediction in predictions
- ]
- )
- feasibility_probabilities: dict[str, np.ndarray] = {}
- for constraint in outcome_constraints:
- if constraint.op == ComparisonOp.GEQ:
- con_lower_inds = torch.tensor([metrics.index(constraint.metric.name)])
- con_lower = torch.tensor([constraint.bound])
- con_upper_inds = torch.as_tensor([])
- con_upper = torch.as_tensor([])
- else:
- con_lower_inds = torch.as_tensor([])
- con_lower = torch.as_tensor([])
- con_upper_inds = torch.tensor([metrics.index(constraint.metric.name)])
- con_upper = torch.tensor([constraint.bound])
-
- feasibility_probabilities[constraint.metric.name] = (
- compute_log_prob_feas_from_bounds(
- means=means,
- sigmas=sigmas,
- con_lower_inds=con_lower_inds,
- con_upper_inds=con_upper_inds,
- con_lower=con_lower,
- con_upper=con_upper,
- # "both" can also be expressed by 2 separate constraints...
- con_both_inds=torch.as_tensor([]),
- con_both=torch.as_tensor([]),
- )
- .exp()
- .numpy()
- )
-
- feasibility_probabilities["any_constraint_violated"] = np.prod(
- list(feasibility_probabilities.values()), axis=0
- )
-
- return {
- metric_name: 1 - feasibility_probabilities[metric_name]
- for metric_name in feasibility_probabilities
- }
-
-
-def _format_constraint_violated_probabilities(
- constraints_violated: dict[str, float]
-) -> str:
- """Format the constraints violated for the tooltip."""
- max_metric_length = 70
- constraints_violated = {
- k: v
- for k, v in constraints_violated.items()
- if v > MINIMUM_CONTRAINT_VIOLATION_THRESHOLD
- }
-    constraints_violated_str = "<br />  ".join(
- [
- (
- f"{k[:max_metric_length]}{'...' if len(k) > max_metric_length else ''}"
- f": {v * 100:.1f}% chance violated"
- )
- for k, v in constraints_violated.items()
- ]
- )
- if len(constraints_violated_str) == 0:
- return "No constraints violated"
- else:
-        constraints_violated_str = "<br />  " + constraints_violated_str
-
- return constraints_violated_str
-
-
def _get_predictions(
model: ModelBridge,
metric_name: str,
@@ -283,7 +176,7 @@ def _get_predictions(
"mean": predictions[i][0][metric_name],
"sem": predictions[i][1][metric_name],
"error_margin": 1.96 * predictions[i][1][metric_name],
- "constraints_violated": _format_constraint_violated_probabilities(
+ "constraints_violated": format_constraint_violated_probabilities(
constraints_violated[i]
),
"size_column": 100 - probabilities_not_feasible[i] * 100,
diff --git a/ax/analysis/plotly/tests/test_utils.py b/ax/analysis/plotly/tests/test_utils.py
new file mode 100644
index 00000000000..c8626dd9af7
--- /dev/null
+++ b/ax/analysis/plotly/tests/test_utils.py
@@ -0,0 +1,106 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ax.analysis.plotly.utils import get_constraint_violated_probabilities
+from ax.core.metric import Metric
+from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
+from ax.exceptions.core import UserInputError
+from ax.utils.common.testutils import TestCase
+
+
+class TestUtils(TestCase):
+ def test_no_constraints_violates_none(self) -> None:
+ constraint_violated_probabilities = get_constraint_violated_probabilities(
+ # predictions for 2 observations on metrics a and b
+ predictions=[
+ (
+ {"a": 1.0, "b": 2.0},
+ {"a": 0.1, "b": 0.2},
+ ),
+ (
+ {"a": 1.1, "b": 2.1},
+ {"a": 0.1, "b": 0.2},
+ ),
+ ],
+ outcome_constraints=[],
+ )
+ self.assertEqual(
+ constraint_violated_probabilities, {"any_constraint_violated": [0.0, 0.0]}
+ )
+
+ def test_relative_constraints_are_not_accepted(self) -> None:
+ with self.assertRaisesRegex(
+ UserInputError,
+ "does not support relative outcome constraints",
+ ):
+ get_constraint_violated_probabilities(
+ predictions=[],
+ outcome_constraints=[
+ OutcomeConstraint(
+ metric=Metric("a"),
+ op=ComparisonOp.GEQ,
+ bound=0.0,
+ relative=True,
+ )
+ ],
+ )
+
+ def test_it_gives_a_result_per_constraint_plus_overall(self) -> None:
+ constraint_violated_probabilities = get_constraint_violated_probabilities(
+ # predictions for 2 observations on metrics a and b
+ predictions=[
+ (
+ {"a": 1.0, "b": 2.0},
+ {"a": 0.1, "b": 0.2},
+ ),
+ (
+ {"a": 1.1, "b": 2.1},
+ {"a": 0.1, "b": 0.2},
+ ),
+ ],
+ outcome_constraints=[
+ OutcomeConstraint(
+ metric=Metric("a"),
+ op=ComparisonOp.GEQ,
+ bound=0.9,
+ relative=False,
+ ),
+ OutcomeConstraint(
+ metric=Metric("b"),
+ op=ComparisonOp.LEQ,
+ bound=2.2,
+ relative=False,
+ ),
+ ],
+ )
+ self.assertEqual(
+ len(constraint_violated_probabilities.keys()),
+ 3,
+ )
+ self.assertIn("any_constraint_violated", constraint_violated_probabilities)
+ self.assertIn("a", constraint_violated_probabilities)
+ self.assertIn("b", constraint_violated_probabilities)
+ self.assertAlmostEqual(
+ constraint_violated_probabilities["any_constraint_violated"][0],
+ 0.292,
+ places=2,
+ )
+ self.assertAlmostEqual(
+ constraint_violated_probabilities["any_constraint_violated"][1],
+ 0.324,
+ places=2,
+ )
+ self.assertAlmostEqual(
+ constraint_violated_probabilities["a"][0], 0.158, places=2
+ )
+ self.assertAlmostEqual(
+ constraint_violated_probabilities["a"][1], 0.022, places=2
+ )
+ self.assertAlmostEqual(
+ constraint_violated_probabilities["b"][0], 0.158, places=2
+ )
+ self.assertAlmostEqual(
+ constraint_violated_probabilities["b"][1], 0.308, places=2
+ )
diff --git a/ax/analysis/plotly/utils.py b/ax/analysis/plotly/utils.py
new file mode 100644
index 00000000000..85cdac7309c
--- /dev/null
+++ b/ax/analysis/plotly/utils.py
@@ -0,0 +1,120 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+import torch
+from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
+from ax.exceptions.core import UserInputError
+from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds
+
+# Because normal distributions have long tails, every arm has a non-zero
+# probability of violating each constraint. Below this threshold, we treat the
+# probability of violation as negligible.
+MINIMUM_CONTRAINT_VIOLATION_THRESHOLD = 0.01
+
+
+def get_constraint_violated_probabilities(
+ predictions: list[tuple[dict[str, float], dict[str, float]]],
+ outcome_constraints: list[OutcomeConstraint],
+) -> dict[str, list[float]]:
+ """Get the probability that each arm violates the outcome constraints.
+
+ Args:
+ predictions: List of predictions for each observation feature
+ generated by predict_at_point. It should include predictions
+ for all outcome constraint metrics.
+ outcome_constraints: List of outcome constraints to check.
+
+    Returns:
+        A dict mapping each constraint's metric name to the per-arm probability
+        of violating that constraint, plus an "any_constraint_violated" entry
+        with the per-arm probability of violating *any* of the provided
+        constraints.
+ """
+ if len(outcome_constraints) == 0:
+ return {"any_constraint_violated": [0.0] * len(predictions)}
+ if any(constraint.relative for constraint in outcome_constraints):
+ raise UserInputError(
+ "`get_constraint_violated_probabilities()` does not support relative "
+ "outcome constraints. Use `Derelativize().transform_optimization_config()` "
+ "before passing constraints to this method."
+ )
+
+ metrics = [constraint.metric.name for constraint in outcome_constraints]
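+    # Stack the predicted means and SEMs for the constraint metrics into
+    # (n_observations x n_constraint_metrics) tensors.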
+ means = torch.as_tensor(
+ [
+ [prediction[0][metric_name] for metric_name in metrics]
+ for prediction in predictions
+ ]
+ )
+ sigmas = torch.as_tensor(
+ [
+ [prediction[1][metric_name] for metric_name in metrics]
+ for prediction in predictions
+ ]
+ )
+ feasibility_probabilities: dict[str, np.ndarray] = {}
+ for constraint in outcome_constraints:
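+        # A GEQ constraint (metric >= bound) acts as a lower bound on the
+        # feasible region; anything else is treated as an upper bound.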
+ if constraint.op == ComparisonOp.GEQ:
+ con_lower_inds = torch.tensor([metrics.index(constraint.metric.name)])
+ con_lower = torch.tensor([constraint.bound])
+ con_upper_inds = torch.as_tensor([])
+ con_upper = torch.as_tensor([])
+ else:
+ con_lower_inds = torch.as_tensor([])
+ con_lower = torch.as_tensor([])
+ con_upper_inds = torch.tensor([metrics.index(constraint.metric.name)])
+ con_upper = torch.tensor([constraint.bound])
+
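+        # compute_log_prob_feas_from_bounds returns log P(feasible) per
+        # observation; exponentiate to recover the probability itself.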
+ feasibility_probabilities[constraint.metric.name] = (
+ compute_log_prob_feas_from_bounds(
+ means=means,
+ sigmas=sigmas,
+ con_lower_inds=con_lower_inds,
+ con_upper_inds=con_upper_inds,
+ con_lower=con_lower,
+ con_upper=con_upper,
+ # "both" can also be expressed by 2 separate constraints...
+ con_both_inds=torch.as_tensor([]),
+ con_both=torch.as_tensor([]),
+ )
+ .exp()
+ .numpy()
+ )
+
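+    # Treat constraints as independent: the probability that all are satisfied
+    # is the product of the per-constraint feasibility probabilities.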
+ feasibility_probabilities["any_constraint_violated"] = np.prod(
+ list(feasibility_probabilities.values()), axis=0
+ )
+
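+    # Report violation probabilities, i.e. the complement of feasibility.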
+ return {
+ metric_name: 1 - feasibility_probabilities[metric_name]
+ for metric_name in feasibility_probabilities
+ }
+
+
+def format_constraint_violated_probabilities(
+ constraints_violated: dict[str, float]
+) -> str:
+ """Format the constraints violated for the tooltip."""
+ max_metric_length = 70
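+    # Drop metrics whose violation probability is below the reporting threshold.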
+ constraints_violated = {
+ k: v
+ for k, v in constraints_violated.items()
+ if v > MINIMUM_CONTRAINT_VIOLATION_THRESHOLD
+ }
+    constraints_violated_str = "<br />  ".join(
+ [
+ (
+ f"{k[:max_metric_length]}{'...' if len(k) > max_metric_length else ''}"
+ f": {v * 100:.1f}% chance violated"
+ )
+ for k, v in constraints_violated.items()
+ ]
+ )
+ if len(constraints_violated_str) == 0:
+ return "No constraints violated"
+ else:
+        constraints_violated_str = "<br />  " + constraints_violated_str
+
+ return constraints_violated_str
diff --git a/sphinx/source/analysis.rst b/sphinx/source/analysis.rst
index d7f02774ae2..f52e6681bbb 100644
--- a/sphinx/source/analysis.rst
+++ b/sphinx/source/analysis.rst
@@ -54,3 +54,11 @@ Predicted Effects Analysis
:members:
:undoc-members:
:show-inheritance:
+
+Plotly Analysis Utils
+~~~~~~~~~~~~~~~~~~~~~
+
+.. automodule:: ax.analysis.plotly.utils
+ :members:
+ :undoc-members:
+ :show-inheritance: