-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement Priors from "Vanilla Bayesian Optimization Performs Great i…
…n High Dimensions" (#402)
- Loading branch information
Showing
15 changed files
with
297 additions
and
54 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,39 +1,37 @@ | ||
from typing import Dict, Optional | ||
from typing import Dict, List, Optional | ||
|
||
import numpy as np | ||
import plotly.express as px | ||
import plotly.graph_objects as go | ||
import torch | ||
|
||
import bofire.priors.api as priors | ||
from bofire.data_models.priors.api import AnyPrior | ||
from gpytorch.priors import Prior | ||
|
||
|
||
def plot_prior_pdf_plotly( | ||
prior: AnyPrior, | ||
priors: List[Prior], | ||
lower: float, | ||
upper: float, | ||
layout_options: Optional[Dict] = None, | ||
labels: Optional[List[str]] = None, | ||
): | ||
"""Plot the probability density function of the prior with plotly. | ||
"""Plot the probability density function of a gyptorch prior with plotly. | ||
Args: | ||
prior (AnyPrior): The prior that should be plotted. | ||
lower (float): lower bound for computing the prior pdf. | ||
upper (float): upper bound for computing the prior pdf. | ||
layout_options (Dict, optional): Layout options passed to plotly. Defaults to {}. | ||
prior: The prior that should be plotted. | ||
lower: lower bound for computing the prior pdf. | ||
upper: upper bound for computing the prior pdf. | ||
layout_options: Layout options passed to plotly. Defaults to None. | ||
labels: Labels for the priors, that are shown in the plot. Defaults to None. | ||
Returns: | ||
fig, ax objects of the plot. | ||
""" | ||
|
||
use_labels = labels is not None and len(labels) == len(priors) | ||
x = np.linspace(lower, upper, 1000) | ||
|
||
fig = px.line( | ||
x=x, | ||
y=np.exp(priors.map(prior).log_prob(torch.from_numpy(x)).numpy()), | ||
) | ||
|
||
fig = go.Figure() | ||
for i, prior in enumerate(priors): | ||
y = np.exp(prior.log_prob(torch.from_numpy(x)).numpy()) | ||
label = labels[i] if use_labels else prior.__class__.__name__ # type: ignore | ||
fig.add_trace(go.Scatter(x=x, y=y, mode="lines", name=label)) | ||
if layout_options is not None: | ||
fig.update_layout(layout_options) | ||
|
||
return fig |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from bofire.priors.mapper import map # noqa: F401 | ||
from bofire.priors.mapper import map |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,32 +1,56 @@ | ||
import math | ||
|
||
import gpytorch | ||
|
||
import bofire.data_models.priors.api as data_models | ||
|
||
|
||
def map_NormalPrior(data_model: data_models.NormalPrior) -> gpytorch.priors.NormalPrior: | ||
def map_NormalPrior( | ||
data_model: data_models.NormalPrior, **kwargs | ||
) -> gpytorch.priors.NormalPrior: | ||
return gpytorch.priors.NormalPrior(loc=data_model.loc, scale=data_model.scale) | ||
|
||
|
||
def map_GammaPrior(data_model: data_models.GammaPrior) -> gpytorch.priors.GammaPrior: | ||
def map_GammaPrior( | ||
data_model: data_models.GammaPrior, **kwargs | ||
) -> gpytorch.priors.GammaPrior: | ||
return gpytorch.priors.GammaPrior( | ||
concentration=data_model.concentration, rate=data_model.rate | ||
) | ||
|
||
|
||
def map_LKJPrior(data_model: data_models.LKJPrior) -> gpytorch.priors.LKJPrior: | ||
def map_LKJPrior( | ||
data_model: data_models.LKJPrior, **kwargs | ||
) -> gpytorch.priors.LKJPrior: | ||
return gpytorch.priors.LKJCovariancePrior( | ||
n=data_model.n_tasks, eta=data_model.shape, sd_prior=map(data_model.sd_prior) | ||
) | ||
|
||
|
||
def map_LogNormalPrior( | ||
data_model: data_models.LogNormalPrior, | ||
**kwargs, | ||
) -> gpytorch.priors.LogNormalPrior: | ||
return gpytorch.priors.LogNormalPrior(loc=data_model.loc, scale=data_model.scale) | ||
|
||
|
||
def map_DimensionalityScaledLogNormalPrior( | ||
data_model: data_models.DimensionalityScaledLogNormalPrior, d: int | ||
) -> gpytorch.priors.LogNormalPrior: | ||
return gpytorch.priors.LogNormalPrior( | ||
loc=data_model.loc + math.log(d) * data_model.loc_scaling, | ||
scale=(data_model.scale**2 + math.log(d) * data_model.scale_scaling) ** 0.5, | ||
) | ||
|
||
|
||
PRIOR_MAP = { | ||
data_models.NormalPrior: map_NormalPrior, | ||
data_models.GammaPrior: map_GammaPrior, | ||
data_models.LKJPrior: map_LKJPrior, | ||
data_models.LogNormalPrior: map_LogNormalPrior, | ||
data_models.DimensionalityScaledLogNormalPrior: map_DimensionalityScaledLogNormalPrior, | ||
} | ||
|
||
|
||
def map( | ||
data_model: data_models.AnyPrior, | ||
) -> gpytorch.priors.Prior: | ||
return PRIOR_MAP[data_model.__class__](data_model) | ||
def map(data_model: data_models.AnyPrior, **kwargs) -> gpytorch.priors.Prior: | ||
return PRIOR_MAP[data_model.__class__](data_model, **kwargs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
import bofire.priors.api as priors | ||
from bofire.data_models.priors.api import BOTORCH_LENGTHCALE_PRIOR | ||
from bofire.plot.api import plot_prior_pdf_plotly | ||
|
||
|
||
def test_plot_prior_pdf_plotly(): | ||
plot_prior_pdf_plotly(BOTORCH_LENGTHCALE_PRIOR(), lower=0, upper=10) | ||
plot_prior_pdf_plotly([priors.map(BOTORCH_LENGTHCALE_PRIOR())], lower=0, upper=10) |
Oops, something went wrong.