From 0dcbc1eda07f84a5204c82f1c34a33a4c62d1483 Mon Sep 17 00:00:00 2001 From: Timothy <75321887+timothy-nunn@users.noreply.github.com> Date: Tue, 15 Aug 2023 09:53:39 +0100 Subject: [PATCH] Remove QSP tolerance in favour of generic options dict (#7) --- src/pyvmcon/vmcon.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index c1d3728..d16b1f6 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -1,4 +1,4 @@ -from typing import Union, Optional +from typing import Union, Optional, Dict, Any import logging import numpy as np import cvxpy as cp @@ -22,7 +22,7 @@ def solve( *, max_iter: int = 10, epsilon: float = 1e-8, - qsp_tolerence: float = 1e-4, + qsp_options: Optional[Dict[str, Any]] = None, initial_B: Optional[np.ndarray] = None, ): """The main solving loop of the VMCON non-linear constrained optimiser. @@ -47,6 +47,11 @@ def solve( epsilon : float The tolerance used to test if VMCON has converged + qsp_options : Optional[Dict[str, Any]] + Dictionary of keyword arguments that are passed to the + CVXPY `Probelem.solve` method. `None` will pass no + additional arguments to the solver. + initial_B : ndarray Initial estimate of the Hessian matrix `B`. If `None`, `B` is the identity matrix of shape `(max(n,m), max(n,m))`. @@ -115,7 +120,7 @@ def solve( # for our constraints try: delta, lamda_equality, lamda_inequality = solve_qsp( - problem, result, x, B, lbs, ubs, qsp_tolerence + problem, result, x, B, lbs, ubs, qsp_options or {} ) except _QspSolveException as e: raise QSPSolverException( @@ -188,7 +193,7 @@ def solve_qsp( B: np.ndarray, lbs: Optional[np.ndarray], ubs: Optional[np.ndarray], - tolerance: float, + options: Dict[str, Any], ): """Solves the quadratic programming problem detailed in equation 4 and 5 of the VMCON paper. @@ -216,10 +221,15 @@ def solve_qsp( ubs : ndarray The upper bounds of `x`. - tolerance : float - The relative tolerance of the QSP solver. - See https://www.cvxpy.org/tutorial/advanced/index.html#setting-solver-options - `eps_rel`. + options : Dict[str, Any] + Dictionary of keyword arguments that are passed to the + CVXPY `Probelem.solve` method. + + Notes + ----- + * By default, OSQP (https://osqp.org/) is the `solver` used in + the `solve` method however this can be changed by specifying a + different `solver` in the `options` dictionary. """ delta = cp.Variable(x.shape) problem_statement = cp.Minimize( @@ -244,7 +254,7 @@ def solve_qsp( constraints.append((result.deq @ delta) + result.eq == 0) qsp = cp.Problem(problem_statement, constraints or None) - qsp.solve(verbose=False, solver=cp.OSQP, eps_rel=tolerance) + qsp.solve(**{"solver": cp.OSQP, **options}) if delta.value is None: raise _QspSolveException(f"QSP failed to solve: {qsp.status}")