From dd854d45b2d675ecd66d6d737f32e0bb7e5d3fb9 Mon Sep 17 00:00:00 2001 From: Daniel Cohen Date: Wed, 2 Oct 2024 15:25:24 -0700 Subject: [PATCH] Extract constraint violation logic for common use (#2811) Summary: Move `get_constraint_violated_probabilities()` to a utils file and give it its own tests. No changes were made to `get_constraint_violated_probabilities()` Reviewed By: Cesar-Cardoso Differential Revision: D63656156 --- ax/analysis/plotly/predicted_effects.py | 117 +---------------------- ax/analysis/plotly/tests/test_utils.py | 106 +++++++++++++++++++++ ax/analysis/plotly/utils.py | 120 ++++++++++++++++++++++++ sphinx/source/analysis.rst | 8 ++ 4 files changed, 239 insertions(+), 112 deletions(-) create mode 100644 ax/analysis/plotly/tests/test_utils.py create mode 100644 ax/analysis/plotly/utils.py diff --git a/ax/analysis/plotly/predicted_effects.py b/ax/analysis/plotly/predicted_effects.py index 1ae82d2cea8..24c5557e285 100644 --- a/ax/analysis/plotly/predicted_effects.py +++ b/ax/analysis/plotly/predicted_effects.py @@ -6,31 +6,29 @@ from itertools import chain from typing import Any, Optional -import numpy as np import pandas as pd -import torch from ax.analysis.analysis import AnalysisCardLevel from ax.analysis.plotly.plotly_analysis import PlotlyAnalysis, PlotlyAnalysisCard +from ax.analysis.plotly.utils import ( + format_constraint_violated_probabilities, + get_constraint_violated_probabilities, +) from ax.core import OutcomeConstraint from ax.core.base_trial import BaseTrial, TrialStatus from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.generator_run import GeneratorRun from ax.core.observation import ObservationFeatures -from ax.core.types import ComparisonOp from ax.exceptions.core import UserInputError from ax.modelbridge.base import ModelBridge from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.prediction_utils import predict_at_point from ax.modelbridge.transforms.derelativize import Derelativize from ax.utils.common.typeutils import checked_cast -from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds from plotly import express as px, graph_objects as go, io as pio from pyre_extensions import none_throws -MINIMUM_CONTRAINT_VIOLATION_THRESHOLD = 0.01 - class PredictedEffectsPlot(PlotlyAnalysis): def __init__(self, metric_name: str) -> None: @@ -114,111 +112,6 @@ def compute( ) -def get_constraint_violated_probabilities( - predictions: list[tuple[dict[str, float], dict[str, float]]], - outcome_constraints: list[OutcomeConstraint], -) -> dict[str, list[float]]: - """Get the probability that each arm violates the outcome constraints. - - Args: - predictions: List of predictions for each observation feature - generated by predict_at_point. It should include predictions - for all outcome constraint metrics. - outcome_constraints: List of outcome constraints to check. - - Returns: - A dict of probabilities that each arm violates the outcome - constraint provided, and for "any_constraint_violated" the probability that - the arm violates *any* outcome constraint provided. - """ - if len(outcome_constraints) == 0: - return {"any_constraint_violated": [0.0] * len(predictions)} - if any(constraint.relative for constraint in outcome_constraints): - raise UserInputError( - "`get_constraint_violated_probabilities()` does not support relative " - "outcome constraints. Use `Derelativize().transform_optimization_config()` " - "before passing constraints to this method." - ) - - metrics = [constraint.metric.name for constraint in outcome_constraints] - means = torch.as_tensor( - [ - [prediction[0][metric_name] for metric_name in metrics] - for prediction in predictions - ] - ) - sigmas = torch.as_tensor( - [ - [prediction[1][metric_name] for metric_name in metrics] - for prediction in predictions - ] - ) - feasibility_probabilities: dict[str, np.ndarray] = {} - for constraint in outcome_constraints: - if constraint.op == ComparisonOp.GEQ: - con_lower_inds = torch.tensor([metrics.index(constraint.metric.name)]) - con_lower = torch.tensor([constraint.bound]) - con_upper_inds = torch.as_tensor([]) - con_upper = torch.as_tensor([]) - else: - con_lower_inds = torch.as_tensor([]) - con_lower = torch.as_tensor([]) - con_upper_inds = torch.tensor([metrics.index(constraint.metric.name)]) - con_upper = torch.tensor([constraint.bound]) - - feasibility_probabilities[constraint.metric.name] = ( - compute_log_prob_feas_from_bounds( - means=means, - sigmas=sigmas, - con_lower_inds=con_lower_inds, - con_upper_inds=con_upper_inds, - con_lower=con_lower, - con_upper=con_upper, - # "both" can also be expressed by 2 separate constraints... - con_both_inds=torch.as_tensor([]), - con_both=torch.as_tensor([]), - ) - .exp() - .numpy() - ) - - feasibility_probabilities["any_constraint_violated"] = np.prod( - list(feasibility_probabilities.values()), axis=0 - ) - - return { - metric_name: 1 - feasibility_probabilities[metric_name] - for metric_name in feasibility_probabilities - } - - -def _format_constraint_violated_probabilities( - constraints_violated: dict[str, float] -) -> str: - """Format the constraints violated for the tooltip.""" - max_metric_length = 70 - constraints_violated = { - k: v - for k, v in constraints_violated.items() - if v > MINIMUM_CONTRAINT_VIOLATION_THRESHOLD - } - constraints_violated_str = "
".join( - [ - ( - f"{k[:max_metric_length]}{'...' if len(k) > max_metric_length else ''}" - f": {v * 100:.1f}% chance violated" - ) - for k, v in constraints_violated.items() - ] - ) - if len(constraints_violated_str) == 0: - return "No constraints violated" - else: - constraints_violated_str = "
" + constraints_violated_str - - return constraints_violated_str - - def _get_predictions( model: ModelBridge, metric_name: str, @@ -283,7 +176,7 @@ def _get_predictions( "mean": predictions[i][0][metric_name], "sem": predictions[i][1][metric_name], "error_margin": 1.96 * predictions[i][1][metric_name], - "constraints_violated": _format_constraint_violated_probabilities( + "constraints_violated": format_constraint_violated_probabilities( constraints_violated[i] ), "size_column": 100 - probabilities_not_feasible[i] * 100, diff --git a/ax/analysis/plotly/tests/test_utils.py b/ax/analysis/plotly/tests/test_utils.py new file mode 100644 index 00000000000..c8626dd9af7 --- /dev/null +++ b/ax/analysis/plotly/tests/test_utils.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from ax.analysis.plotly.utils import get_constraint_violated_probabilities +from ax.core.metric import Metric +from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint +from ax.exceptions.core import UserInputError +from ax.utils.common.testutils import TestCase + + +class TestUtils(TestCase): + def test_no_constraints_violates_none(self) -> None: + constraint_violated_probabilities = get_constraint_violated_probabilities( + # predictions for 2 observations on metrics a and b + predictions=[ + ( + {"a": 1.0, "b": 2.0}, + {"a": 0.1, "b": 0.2}, + ), + ( + {"a": 1.1, "b": 2.1}, + {"a": 0.1, "b": 0.2}, + ), + ], + outcome_constraints=[], + ) + self.assertEqual( + constraint_violated_probabilities, {"any_constraint_violated": [0.0, 0.0]} + ) + + def test_relative_constraints_are_not_accepted(self) -> None: + with self.assertRaisesRegex( + UserInputError, + "does not support relative outcome constraints", + ): + get_constraint_violated_probabilities( + predictions=[], + outcome_constraints=[ + OutcomeConstraint( + metric=Metric("a"), + op=ComparisonOp.GEQ, + bound=0.0, + relative=True, + ) + ], + ) + + def test_it_gives_a_result_per_constraint_plus_overall(self) -> None: + constraint_violated_probabilities = get_constraint_violated_probabilities( + # predictions for 2 observations on metrics a and b + predictions=[ + ( + {"a": 1.0, "b": 2.0}, + {"a": 0.1, "b": 0.2}, + ), + ( + {"a": 1.1, "b": 2.1}, + {"a": 0.1, "b": 0.2}, + ), + ], + outcome_constraints=[ + OutcomeConstraint( + metric=Metric("a"), + op=ComparisonOp.GEQ, + bound=0.9, + relative=False, + ), + OutcomeConstraint( + metric=Metric("b"), + op=ComparisonOp.LEQ, + bound=2.2, + relative=False, + ), + ], + ) + self.assertEqual( + len(constraint_violated_probabilities.keys()), + 3, + ) + self.assertIn("any_constraint_violated", constraint_violated_probabilities) + self.assertIn("a", constraint_violated_probabilities) + self.assertIn("b", constraint_violated_probabilities) + self.assertAlmostEqual( + constraint_violated_probabilities["any_constraint_violated"][0], + 0.292, + places=2, + ) + self.assertAlmostEqual( + constraint_violated_probabilities["any_constraint_violated"][1], + 0.324, + places=2, + ) + self.assertAlmostEqual( + constraint_violated_probabilities["a"][0], 0.158, places=2 + ) + self.assertAlmostEqual( + constraint_violated_probabilities["a"][1], 0.022, places=2 + ) + self.assertAlmostEqual( + constraint_violated_probabilities["b"][0], 0.158, places=2 + ) + self.assertAlmostEqual( + constraint_violated_probabilities["b"][1], 0.308, places=2 + ) diff --git a/ax/analysis/plotly/utils.py b/ax/analysis/plotly/utils.py new file mode 100644 index 00000000000..85cdac7309c --- /dev/null +++ b/ax/analysis/plotly/utils.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint +from ax.exceptions.core import UserInputError +from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds + +# Because normal distributions have long tails, every arm has a non-zero +# probability of violating the constraint. But below a certain threshold, we +# consider probability of violation to be negligible. +MINIMUM_CONTRAINT_VIOLATION_THRESHOLD = 0.01 + + +def get_constraint_violated_probabilities( + predictions: list[tuple[dict[str, float], dict[str, float]]], + outcome_constraints: list[OutcomeConstraint], +) -> dict[str, list[float]]: + """Get the probability that each arm violates the outcome constraints. + + Args: + predictions: List of predictions for each observation feature + generated by predict_at_point. It should include predictions + for all outcome constraint metrics. + outcome_constraints: List of outcome constraints to check. + + Returns: + A dict of probabilities that each arm violates the outcome + constraint provided, and for "any_constraint_violated" the probability that + the arm violates *any* outcome constraint provided. + """ + if len(outcome_constraints) == 0: + return {"any_constraint_violated": [0.0] * len(predictions)} + if any(constraint.relative for constraint in outcome_constraints): + raise UserInputError( + "`get_constraint_violated_probabilities()` does not support relative " + "outcome constraints. Use `Derelativize().transform_optimization_config()` " + "before passing constraints to this method." + ) + + metrics = [constraint.metric.name for constraint in outcome_constraints] + means = torch.as_tensor( + [ + [prediction[0][metric_name] for metric_name in metrics] + for prediction in predictions + ] + ) + sigmas = torch.as_tensor( + [ + [prediction[1][metric_name] for metric_name in metrics] + for prediction in predictions + ] + ) + feasibility_probabilities: dict[str, np.ndarray] = {} + for constraint in outcome_constraints: + if constraint.op == ComparisonOp.GEQ: + con_lower_inds = torch.tensor([metrics.index(constraint.metric.name)]) + con_lower = torch.tensor([constraint.bound]) + con_upper_inds = torch.as_tensor([]) + con_upper = torch.as_tensor([]) + else: + con_lower_inds = torch.as_tensor([]) + con_lower = torch.as_tensor([]) + con_upper_inds = torch.tensor([metrics.index(constraint.metric.name)]) + con_upper = torch.tensor([constraint.bound]) + + feasibility_probabilities[constraint.metric.name] = ( + compute_log_prob_feas_from_bounds( + means=means, + sigmas=sigmas, + con_lower_inds=con_lower_inds, + con_upper_inds=con_upper_inds, + con_lower=con_lower, + con_upper=con_upper, + # "both" can also be expressed by 2 separate constraints... + con_both_inds=torch.as_tensor([]), + con_both=torch.as_tensor([]), + ) + .exp() + .numpy() + ) + + feasibility_probabilities["any_constraint_violated"] = np.prod( + list(feasibility_probabilities.values()), axis=0 + ) + + return { + metric_name: 1 - feasibility_probabilities[metric_name] + for metric_name in feasibility_probabilities + } + + +def format_constraint_violated_probabilities( + constraints_violated: dict[str, float] +) -> str: + """Format the constraints violated for the tooltip.""" + max_metric_length = 70 + constraints_violated = { + k: v + for k, v in constraints_violated.items() + if v > MINIMUM_CONTRAINT_VIOLATION_THRESHOLD + } + constraints_violated_str = "
".join( + [ + ( + f"{k[:max_metric_length]}{'...' if len(k) > max_metric_length else ''}" + f": {v * 100:.1f}% chance violated" + ) + for k, v in constraints_violated.items() + ] + ) + if len(constraints_violated_str) == 0: + return "No constraints violated" + else: + constraints_violated_str = "
" + constraints_violated_str + + return constraints_violated_str diff --git a/sphinx/source/analysis.rst b/sphinx/source/analysis.rst index d7f02774ae2..f52e6681bbb 100644 --- a/sphinx/source/analysis.rst +++ b/sphinx/source/analysis.rst @@ -54,3 +54,11 @@ Predicted Effects Analysis :members: :undoc-members: :show-inheritance: + +Plotly Anaylsis Utils +~~~~~~~~~~~~~~~ + +.. automodule:: ax.analysis.plotly.utils + :members: + :undoc-members: + :show-inheritance: