From e42a182936e6eaf95ef8ecf50931f9b6153b355a Mon Sep 17 00:00:00 2001 From: Sam Daulton Date: Fri, 24 Jan 2025 09:53:24 -0800 Subject: [PATCH] add fully bayesian linear model to MBM registry (#3264) Summary: This also ensures that we use a model list for multi-outcome models Reviewed By: saitcakmak Differential Revision: D68570546 --- ax/models/torch/botorch_modular/utils.py | 4 ++-- ax/storage/botorch_modular_registry.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index c316679fa79..5b8f04acf1b 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -26,7 +26,7 @@ qLogNoisyExpectedHypervolumeImprovement, ) from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll -from botorch.models.fully_bayesian import SaasFullyBayesianSingleTaskGP +from botorch.models.fully_bayesian import FullyBayesianSingleTaskGP from botorch.models.gp_regression import SingleTaskGP from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP from botorch.models.gp_regression_mixed import MixedSingleTaskGP @@ -149,7 +149,7 @@ def use_model_list( botorch_model_class = ( model_configs[0].botorch_model_class or botorch_model_class ) - if issubclass(botorch_model_class, SaasFullyBayesianSingleTaskGP): + if issubclass(botorch_model_class, FullyBayesianSingleTaskGP): # SAAS models do not support multiple outcomes. # Use model list if there are multiple outcomes. return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1 diff --git a/ax/storage/botorch_modular_registry.py b/ax/storage/botorch_modular_registry.py index 17242a98c25..739f0b9b353 100644 --- a/ax/storage/botorch_modular_registry.py +++ b/ax/storage/botorch_modular_registry.py @@ -51,6 +51,7 @@ from botorch.acquisition.preference import AnalyticExpectedUtilityOfBestOption from botorch.models import SaasFullyBayesianSingleTaskGP from botorch.models.contextual import LCEAGP +from botorch.models.fully_bayesian import FullyBayesianLinearSingleTaskGP from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP # BoTorch `Model` imports @@ -113,6 +114,7 @@ MultiTaskGP: "MultiTaskGP", SingleTaskGP: "SingleTaskGP", SingleTaskMultiFidelityGP: "SingleTaskMultiFidelityGP", + FullyBayesianLinearSingleTaskGP: "FullyBayesianLinearSingleTaskGP", SaasFullyBayesianSingleTaskGP: "SaasFullyBayesianSingleTaskGP", SaasFullyBayesianMultiTaskGP: "SaasFullyBayesianMultiTaskGP", LCEAGP: "LCEAGP",