diff --git a/botorch/optim/optimize.py b/botorch/optim/optimize.py index 88879c1102..935720c7e2 100644 --- a/botorch/optim/optimize.py +++ b/botorch/optim/optimize.py @@ -1438,7 +1438,7 @@ def optimize_acqf_discrete_local_search( X_avoid = torch.zeros(0, dim, device=device, dtype=dtype) inequality_constraints = inequality_constraints or [] - for i in range(q): + for _ in range(q): # generate some starting points X0 = _gen_starting_points_local_search( discrete_choices=discrete_choices, diff --git a/botorch/optim/parameter_constraints.py b/botorch/optim/parameter_constraints.py index 069ad2f5e7..edb8709efb 100644 --- a/botorch/optim/parameter_constraints.py +++ b/botorch/optim/parameter_constraints.py @@ -11,7 +11,6 @@ from __future__ import annotations from collections.abc import Callable - from functools import partial from typing import Union @@ -26,7 +25,7 @@ ScipyConstraintDict = dict[ str, Union[str, Callable[[np.ndarray], float], Callable[[np.ndarray], np.ndarray]] ] -NLC_TOL = -1e-6 +CONST_TOL = 1e-6 def make_scipy_bounds( @@ -511,9 +510,12 @@ def f_grad(X): def nonlinear_constraint_is_feasible( - nonlinear_inequality_constraint: Callable, is_intrapoint: bool, x: Tensor + nonlinear_inequality_constraint: Callable, + is_intrapoint: bool, + x: Tensor, + tolerance: float = CONST_TOL, ) -> Tensor: - """Checks if a nonlinear inequality constraint is fulfilled. + """Checks if a nonlinear inequality constraint is fulfilled (within tolerance). Args: nonlinear_inequality_constraint: Callable to evaluate the @@ -523,6 +525,9 @@ def nonlinear_constraint_is_feasible( constraint has to evaluated over the whole q-batch and is a an inter-point constraint. x: Tensor of shape (batch x q x d). + tolerance: Rather than using the exact `const(x) >= 0` constraint, this helper + checks feasibility of `const(x) >= -tolerance`. This avoids marking the + candidates as infeasible due to tiny violations. Returns: A boolean tensor of shape (batch) indicating if the constraint is @@ -530,7 +535,7 @@ def nonlinear_constraint_is_feasible( """ def check_x(x: Tensor) -> bool: - return _arrayify(nonlinear_inequality_constraint(x)).item() >= NLC_TOL + return _arrayify(nonlinear_inequality_constraint(x)).item() >= -tolerance x_flat = x.view(-1, *x.shape[-2:]) is_feasible = torch.ones(x_flat.shape[0], dtype=torch.bool, device=x.device) @@ -603,3 +608,82 @@ def make_scipy_nonlinear_inequality_constraints( shapeX=shapeX, ) return scipy_nonlinear_inequality_constraints + + +def evaluate_feasibility( + X: Tensor, + inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None, + nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None, + tolerance: float = CONST_TOL, +) -> Tensor: + r"""Evaluate feasibility of candidate points (within a tolerance). + + Args: + X: The candidate tensor of shape `batch x q x d`. + inequality_constraints: A list of tuples (indices, coefficients, rhs), + with each tuple encoding an inequality constraint of the form + `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and + `coefficients` should be torch tensors. See the docstring of + `make_scipy_linear_constraints` for an example. When q=1, or when + applying the same constraint to each candidate in the batch + (intra-point constraint), `indices` should be a 1-d tensor. + For inter-point constraints, in which the constraint is applied to the + whole batch of candidates, `indices` must be a 2-d tensor, where + in each row `indices[i] =(k_i, l_i)` the first index `k_i` corresponds + to the `k_i`-th element of the `q`-batch and the second index `l_i` + corresponds to the `l_i`-th feature of that element. + equality_constraints: A list of tuples (indices, coefficients, rhs), + with each tuple encoding an equality constraint of the form + `\sum_i (X[indices[i]] * coefficients[i]) = rhs`. See the docstring of + `make_scipy_linear_constraints` for an example. + nonlinear_inequality_constraints: A list of tuples representing the nonlinear + inequality constraints. The first element in the tuple is a callable + representing a constraint of the form `callable(x) >= 0`. In case of an + intra-point constraint, `callable()`takes in an one-dimensional tensor of + shape `d` and returns a scalar. In case of an inter-point constraint, + `callable()` takes a two dimensional tensor of shape `q x d` and again + returns a scalar. The second element is a boolean, indicating if it is an + intra-point or inter-point constraint (`True` for intra-point. `False` for + inter-point). For more information on intra-point vs inter-point + constraints, see the docstring of the `inequality_constraints` argument. + tolerance: The tolerance used to check the feasibility of equality constraints + and non-linear inequality constraints. For equality constraints, we check + if `abs(const(X) - rhs) < tolerance`. For non-linear inequality constraints, + we check if `const(X) >= -tolerance`. This avoids marking the candidates as + infeasible due to tiny violations. + + Returns: + A boolean tensor of shape `batch` indicating if the corresponding candidate of + shape `q x d` is feasible. + """ + is_feasible = torch.ones(X.shape[:-2], device=X.device, dtype=torch.bool) + if inequality_constraints is not None: + for idx, coef, rhs in inequality_constraints: + if idx.ndim == 1: + # Intra-point constraints. + is_feasible &= ((X[..., idx] * coef).sum(dim=-1) >= rhs).all(dim=-1) + else: + # Inter-point constraints. + is_feasible &= (X[..., idx[:, 0], idx[:, 1]] * coef).sum(dim=-1) >= rhs + if equality_constraints is not None: + for idx, coef, rhs in equality_constraints: + if idx.ndim == 1: + # Intra-point constraints. + is_feasible &= ( + ((X[..., idx] * coef).sum(dim=-1) - rhs).abs() < tolerance + ).all(dim=-1) + else: + # Inter-point constraints. + is_feasible &= ( + (X[..., idx[:, 0], idx[:, 1]] * coef).sum(dim=-1) - rhs + ).abs() < tolerance + if nonlinear_inequality_constraints is not None: + for const, intra in nonlinear_inequality_constraints: + is_feasible &= nonlinear_constraint_is_feasible( + nonlinear_inequality_constraint=const, + is_intrapoint=intra, + x=X, + tolerance=tolerance, + ) + return is_feasible diff --git a/test/optim/test_parameter_constraints.py b/test/optim/test_parameter_constraints.py index df0cad98b7..574d7c82e5 100644 --- a/test/optim/test_parameter_constraints.py +++ b/test/optim/test_parameter_constraints.py @@ -18,6 +18,7 @@ _make_linear_constraints, _make_nonlinear_constraints, eval_lin_constraint, + evaluate_feasibility, lin_constraint_jac, make_scipy_bounds, make_scipy_linear_constraints, @@ -529,6 +530,142 @@ def test_generate_unfixed_lin_constraints(self): eq=eq, ) + def test_evaluate_feasibility(self) -> None: + # Check that the feasibility is evaluated correctly. + X = torch.tensor( # 3 x 2 x 3 -> leads to output of shape 3. + [ + [[1.0, 1.0, 1.0], [1.0, 1.0, 3.0]], + [[2.0, 2.0, 1.0], [2.0, 2.0, 5.0]], + [[3.0, 3.0, 3.0], [3.0, 3.0, 3.0]], + ], + device=self.device, + ) + # X[..., 2] * 4 >= 5. + inequality_constraints = [ + ( + torch.tensor([2], device=self.device), + torch.tensor([4], device=self.device), + 5.0, + ) + ] + # X[..., 0] + X[..., 1] == 4. + equality_constraints = [ + ( + torch.tensor([0, 1], device=self.device), + torch.ones(2, device=self.device), + 4.0, + ) + ] + + # sum(X, dim=-1) < 5. + def nlc1(x): + return 5 - x.sum(dim=-1) + + # Only inequality. + self.assertAllClose( + evaluate_feasibility( + X=X, + inequality_constraints=inequality_constraints, + ), + torch.tensor([False, False, True], device=self.device), + ) + # Only equality. + self.assertAllClose( + evaluate_feasibility( + X=X, + equality_constraints=equality_constraints, + ), + torch.tensor([False, True, False], device=self.device), + ) + # Both inequality and equality. + self.assertAllClose( + evaluate_feasibility( + X=X, + inequality_constraints=inequality_constraints, + equality_constraints=equality_constraints, + ), + torch.tensor([False, False, False], device=self.device), + ) + # Nonlinear inequality. + self.assertAllClose( + evaluate_feasibility( + X=X, + nonlinear_inequality_constraints=[(nlc1, True)], + ), + torch.tensor([True, False, False], device=self.device), + ) + # No constraints. + self.assertAllClose( + evaluate_feasibility( + X=X, + ), + torch.ones(3, device=self.device, dtype=torch.bool), + ) + + def test_evaluate_feasibility_inter_point(self) -> None: + # Check that inter-point constraints evaluate correctly. + X = torch.tensor( # 3 x 2 x 3 -> leads to output of shape 3. + [ + [[1.0, 1.0, 1.0], [0.0, 1.0, 3.0]], + [[1.0, 1.0, 1.0], [2.0, 1.0, 3.0]], + [[2.0, 2.0, 1.0], [2.0, 2.0, 5.0]], + ], + dtype=torch.double, + device=self.device, + ) + linear_inter_cons = ( # X[..., 0, 0] - X[..., 1, 0] >= / == 0. + torch.tensor([[0, 0], [1, 0]], device=self.device), + torch.tensor([1.0, -1.0], device=self.device), + 0, + ) + # Linear inequality. + self.assertAllClose( + evaluate_feasibility( + X=X, + inequality_constraints=[linear_inter_cons], + ), + torch.tensor([True, False, True], device=self.device), + ) + # Linear equality. + self.assertAllClose( + evaluate_feasibility( + X=X, + equality_constraints=[linear_inter_cons], + ), + torch.tensor([False, False, True], device=self.device), + ) + # Linear equality with too high of a tolerance. + self.assertAllClose( + evaluate_feasibility( + X=X, + equality_constraints=[linear_inter_cons], + tolerance=100, + ), + torch.tensor([True, True, True], device=self.device), + ) + + # Nonlinear inequality. + def nlc1(x): # X.sum(over q & d) >= 10.0 + return x.sum() - 10.0 + + self.assertEqual( + evaluate_feasibility( + X=X, + nonlinear_inequality_constraints=[(nlc1, False)], + ).tolist(), + [False, False, True], + ) + # All together. + self.assertEqual( + evaluate_feasibility( + X=X, + inequality_constraints=[linear_inter_cons], + equality_constraints=[linear_inter_cons], + nonlinear_inequality_constraints=[(nlc1, False)], + ).tolist(), + [False, False, True], + ) + class TestMakeScipyBounds(BotorchTestCase): def test_make_scipy_bounds(self):