From cf5e9398fd74f50976d527702b2199987da80d61 Mon Sep 17 00:00:00 2001 From: james <81617086+je-cook@users.noreply.github.com> Date: Fri, 17 Nov 2023 11:52:03 +0000 Subject: [PATCH 1/7] =?UTF-8?q?=F0=9F=8F=B7=EF=B8=8F=20Typing?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pyvmcon/vmcon.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index 3a3df85..1b4d049 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -1,4 +1,4 @@ -from typing import Union, Optional, Dict, Any, Callable +from typing import List, Tuple, Union, Optional, Dict, Any, Callable import logging import numpy as np import cvxpy as cp @@ -9,7 +9,7 @@ _QspSolveException, QSPSolverException, ) -from .problem import AbstractProblem, Result +from .problem import AbstractProblem, Result, T logger = logging.getLogger(__name__) @@ -25,7 +25,7 @@ def solve( qsp_options: Optional[Dict[str, Any]] = None, initial_B: Optional[np.ndarray] = None, callback: Optional[Callable[[int, np.ndarray, Result], None]] = None, -): +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Result]: """The main solving loop of the VMCON non-linear constrained optimiser. Parameters @@ -200,7 +200,7 @@ def solve_qsp( lbs: Optional[np.ndarray], ubs: Optional[np.ndarray], options: Dict[str, Any], -): +) -> Tuple[np.ndarray, ...]: """Solves the quadratic programming problem detailed in equation 4 and 5 of the VMCON paper. @@ -318,7 +318,7 @@ def convergence_value( return abs_df_dot_delta + abs_equality_err + abs_inequality_err -def _calculate_mu_i(mu_im1: Union[np.ndarray, None], lamda: np.ndarray): +def _calculate_mu_i(mu_im1: Union[np.ndarray, None], lamda: np.ndarray) -> np.ndarray: if mu_im1 is None: return np.abs(lamda) @@ -335,7 +335,7 @@ def perform_linesearch( lamda_inequality: np.ndarray, delta: np.ndarray, x_jm1: np.ndarray, -): +) -> Tuple[float, np.ndarray, np.ndarray, Result]: """Performs the line search on equation 6 (to minimise phi). Parameters @@ -355,7 +355,7 @@ def perform_linesearch( mu_equality = _calculate_mu_i(mu_equality, lamda_equality) mu_inequality = _calculate_mu_i(mu_inequality, lamda_inequality) - def phi(result: Result): + def phi(result: Result) -> T: sum_equality = (mu_equality * np.abs(result.eq)).sum() sum_inequality = ( mu_inequality * np.abs(np.array([min(0, c) for c in result.ie])) @@ -413,7 +413,7 @@ def _derivative_lagrangian( return result.df - c_equality_prime - c_inequality_prime -def _powells_gamma(gamma: np.ndarray, ksi: np.ndarray, B: np.ndarray): +def _powells_gamma(gamma: np.ndarray, ksi: np.ndarray, B: np.ndarray) -> np.ndarray: ksiTBksi = ksi.T @ B @ ksi # used throughout eqn 10 ksiTgamma = ksi.T @ gamma # dito, to reduce amount of matmul @@ -432,7 +432,7 @@ def calculate_new_B( x_j: np.ndarray, lamda_equality: np.ndarray, lamda_inequality: np.ndarray, -): +) -> np.ndarray: # xi (the symbol name) would be a bit confusing in this context, # ksi is how its pronounced in modern greek # reshape ksi to be a matrix @@ -469,7 +469,7 @@ def calculate_new_B( return B -def _find_out_of_bounds_vars(higher: np.ndarray, lower: np.ndarray): +def _find_out_of_bounds_vars(higher: np.ndarray, lower: np.ndarray) -> List[str]: """Return the indices of the out of bounds variables""" out_of_bounds = [] From 83eab870548e311eedf03cb7042ab5de25f860cd Mon Sep 17 00:00:00 2001 From: james <81617086+je-cook@users.noreply.github.com> Date: Fri, 17 Nov 2023 11:55:27 +0000 Subject: [PATCH 2/7] =?UTF-8?q?=F0=9F=9A=A8=20ruff=20autofixes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pyvmcon/__init__.py | 4 ++-- src/pyvmcon/exceptions.py | 10 +++------- src/pyvmcon/problem.py | 5 ++--- src/pyvmcon/vmcon.py | 10 +++++----- tests/test_vmcon_paper.py | 4 ++-- 5 files changed, 14 insertions(+), 19 deletions(-) diff --git a/src/pyvmcon/__init__.py b/src/pyvmcon/__init__.py index 73b31ba..bc4b2d3 100644 --- a/src/pyvmcon/__init__.py +++ b/src/pyvmcon/__init__.py @@ -1,10 +1,10 @@ -from .vmcon import solve from .exceptions import ( - VMCONConvergenceException, LineSearchConvergenceException, QSPSolverException, + VMCONConvergenceException, ) from .problem import AbstractProblem, Problem, Result +from .vmcon import solve __all__ = [ "solve", diff --git a/src/pyvmcon/exceptions.py b/src/pyvmcon/exceptions.py index 116ab69..1980cc6 100644 --- a/src/pyvmcon/exceptions.py +++ b/src/pyvmcon/exceptions.py @@ -1,4 +1,5 @@ from typing import Optional + import numpy as np from .problem import Result @@ -7,7 +8,8 @@ class VMCONConvergenceException(Exception): """Base class for an exception that indicates VMCON has failed to converge. This exception allows certain diagnostics - to be passed and propogated with the exception.""" + to be passed and propogated with the exception. + """ def __init__( self, @@ -51,20 +53,14 @@ class _QspSolveException(Exception): to identify that the QSP has failed to solve. """ - pass - class QSPSolverException(VMCONConvergenceException): """Indicates VMCON failed to solve because the QSP Solver was unable to solve. """ - pass - class LineSearchConvergenceException(VMCONConvergenceException): """Indicates the line search portion of VMCON was unable to solve within a pre-defined number of iterations """ - - pass diff --git a/src/pyvmcon/problem.py b/src/pyvmcon/problem.py index 67023fa..e2243f7 100644 --- a/src/pyvmcon/problem.py +++ b/src/pyvmcon/problem.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod -from typing import NamedTuple, TypeVar, Callable, List +from typing import Callable, List, NamedTuple, TypeVar + import numpy as np T = TypeVar("T", np.ndarray, np.number, float) @@ -45,7 +46,6 @@ def __call__(self, x: np.ndarray) -> Result: @abstractmethod def num_equality(self) -> int: """Returns the number of equality constraints this problem has""" - pass @property def has_equality(self) -> bool: @@ -59,7 +59,6 @@ def has_inequality(self) -> bool: @abstractmethod def num_inequality(self) -> int: """Returns the number of inequality constraints this problem has""" - pass @property def total_constraints(self) -> int: diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index 1b4d049..8a84624 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -1,13 +1,14 @@ -from typing import List, Tuple, Union, Optional, Dict, Any, Callable import logging -import numpy as np +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + import cvxpy as cp +import numpy as np from .exceptions import ( - VMCONConvergenceException, LineSearchConvergenceException, - _QspSolveException, QSPSolverException, + VMCONConvergenceException, + _QspSolveException, ) from .problem import AbstractProblem, Result, T @@ -77,7 +78,6 @@ def solve( result : Result The result from running the solution vector through the problem. """ - if len(x.shape) != 1: raise ValueError("Input vector `x` is not a 1D array") diff --git a/tests/test_vmcon_paper.py b/tests/test_vmcon_paper.py index e3cf02d..ec4115b 100644 --- a/tests/test_vmcon_paper.py +++ b/tests/test_vmcon_paper.py @@ -1,7 +1,7 @@ -import pytest from typing import NamedTuple -import numpy as np +import numpy as np +import pytest from pyvmcon import solve from pyvmcon.exceptions import VMCONConvergenceException from pyvmcon.problem import Problem From e728ba46c656fc27c4f6a7318c9be395fb889e69 Mon Sep 17 00:00:00 2001 From: james <81617086+je-cook@users.noreply.github.com> Date: Fri, 17 Nov 2023 12:04:25 +0000 Subject: [PATCH 3/7] =?UTF-8?q?=E2=9A=A1=20Use=20numpy=20for=20loops?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pyvmcon/vmcon.py | 47 +++++++++++++++----------------------------- 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index 8a84624..3e0a320 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -307,13 +307,11 @@ def convergence_value( The Lagrange multipliers for inequality constraints for the jth evaluation point. """ + ind_eq = min(lamda_equality.shape[0], result.eq.shape[0]) + ind_ieq = min(lamda_inequality.shape[0], result.ie.shape[0]) abs_df_dot_delta = abs(np.dot(result.df, delta_j)) - abs_equality_err = np.sum( - [abs(lamda * c) for lamda, c in zip(lamda_equality, result.eq)] - ) - abs_inequality_err = np.sum( - [abs(lamda * c) for lamda, c in zip(lamda_inequality, result.ie)] - ) + abs_equality_err = abs(np.sum(lamda_equality[:ind_eq] * result.eq[:ind_eq])) + abs_inequality_err = abs(np.sum(lamda_inequality[:ind_ieq] * result.ie[:ind_ieq])) return abs_df_dot_delta + abs_equality_err + abs_inequality_err @@ -357,9 +355,7 @@ def perform_linesearch( def phi(result: Result) -> T: sum_equality = (mu_equality * np.abs(result.eq)).sum() - sum_inequality = ( - mu_inequality * np.abs(np.array([min(0, c) for c in result.ie])) - ).sum() + sum_inequality = (mu_inequality * np.abs(np.minimum(0, result.ie))).sum() return result.f + sum_equality + sum_inequality @@ -402,13 +398,11 @@ def _derivative_lagrangian( result: Result, lamda_equality: np.ndarray, lamda_inequality: np.ndarray, -): - c_equality_prime = sum( - [lamda * dc for lamda, dc in zip(lamda_equality, result.deq)] - ) - c_inequality_prime = sum( - [lamda * dc for lamda, dc in zip(lamda_inequality, result.die)] - ) +) -> np.ndarray: + ind_eq = min(lamda_equality.shape[0], result.deq.shape[0]) + ind_ieq = min(lamda_inequality.shape[0], result.die.shape[0]) + c_equality_prime = (lamda_equality[:ind_eq] * result.deq[:ind_eq]).sum(axis=0) + c_inequality_prime = (lamda_inequality[:ind_ieq] * result.die[:ind_ieq]).sum(axis=0) return result.df - c_equality_prime - c_inequality_prime @@ -448,9 +442,7 @@ def calculate_new_B( lamda_equality, lamda_inequality, ) - gamma = (g1 - g2).reshape((x_j.shape[0], 1)) - - gamma = _powells_gamma(gamma, ksi, B) + gamma = _powells_gamma((g1 - g2).reshape((x_j.shape[0], 1)), ksi, B) if (gamma == 0).all(): logger.warning("All gamma components are 0") @@ -460,21 +452,14 @@ def calculate_new_B( logger.warning("All xi (ksi) components are 0") ksi[:] = 1e-10 - B = ( - B - - ((B @ ksi @ ksi.T @ B) / (ksi.T @ B @ ksi)) - + ((gamma @ gamma.T) / (ksi.T @ gamma)) - ) # eqn 8 + # eqn 8 + B += (gamma @ gamma.T) / (ksi.T @ gamma) - ( + (B @ ksi @ ksi.T @ B) / (ksi.T @ B @ ksi) + ) return B def _find_out_of_bounds_vars(higher: np.ndarray, lower: np.ndarray) -> List[str]: """Return the indices of the out of bounds variables""" - - out_of_bounds = [] - for i, boolean in enumerate((higher - lower) < 0): - if boolean: - out_of_bounds.append(str(i)) - - return out_of_bounds + return np.nonzero((higher - lower) < 0)[0].astype(str).tolist() From ef1b2616e24c62d5c4c037851dcba0e859947bbe Mon Sep 17 00:00:00 2001 From: james <81617086+je-cook@users.noreply.github.com> Date: Fri, 17 Nov 2023 12:07:27 +0000 Subject: [PATCH 4/7] =?UTF-8?q?=F0=9F=9A=A8=20ruff=20SIM108?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/pyvmcon/vmcon.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index 3e0a320..e55f8df 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -102,10 +102,7 @@ def solve( # The paper uses the B matrix as the # running approximation of the Hessian - if initial_B is None: - B = np.identity(n) - else: - B = initial_B + B = np.identity(n) if initial_B is None else initial_B callback = callback or (lambda _i, _result, _x, _con: None) From ddd81b231ccce0113605dba1511a19bf788c7f5a Mon Sep 17 00:00:00 2001 From: james <81617086+je-cook@users.noreply.github.com> Date: Fri, 17 Nov 2023 12:51:24 +0000 Subject: [PATCH 5/7] reduce number of operations --- src/pyvmcon/vmcon.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index e55f8df..8c9c5e9 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -450,9 +450,8 @@ def calculate_new_B( ksi[:] = 1e-10 # eqn 8 - B += (gamma @ gamma.T) / (ksi.T @ gamma) - ( - (B @ ksi @ ksi.T @ B) / (ksi.T @ B @ ksi) - ) + B_ksi = B @ ksi + B += (gamma @ gamma.T) / (ksi.T @ gamma) - ((B_ksi @ ksi.T @ B) / (ksi.T @ B_ksi)) return B From ece912828a87196463f2741ca748818fb8972763 Mon Sep 17 00:00:00 2001 From: james <81617086+je-cook@users.noreply.github.com> Date: Fri, 17 Nov 2023 13:02:55 +0000 Subject: [PATCH 6/7] broadcasting fix --- src/pyvmcon/vmcon.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index 8c9c5e9..1687a32 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -398,8 +398,12 @@ def _derivative_lagrangian( ) -> np.ndarray: ind_eq = min(lamda_equality.shape[0], result.deq.shape[0]) ind_ieq = min(lamda_inequality.shape[0], result.die.shape[0]) - c_equality_prime = (lamda_equality[:ind_eq] * result.deq[:ind_eq]).sum(axis=0) - c_inequality_prime = (lamda_inequality[:ind_ieq] * result.die[:ind_ieq]).sum(axis=0) + c_equality_prime = (lamda_equality[:ind_eq, None] * result.deq[:ind_eq]).sum( + axis=None if ind_eq == 0 else 0 + ) + c_inequality_prime = (lamda_inequality[:ind_ieq, None] * result.die[:ind_ieq]).sum( + axis=None if ind_ieq == 0 else 0 + ) return result.df - c_equality_prime - c_inequality_prime From 9e9ccb91226372f318de38804ec15994758417dc Mon Sep 17 00:00:00 2001 From: james <81617086+je-cook@users.noreply.github.com> Date: Fri, 17 Nov 2023 14:10:09 +0000 Subject: [PATCH 7/7] spelling --- src/pyvmcon/exceptions.py | 2 +- src/pyvmcon/vmcon.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pyvmcon/exceptions.py b/src/pyvmcon/exceptions.py index 1980cc6..6ff44e1 100644 --- a/src/pyvmcon/exceptions.py +++ b/src/pyvmcon/exceptions.py @@ -8,7 +8,7 @@ class VMCONConvergenceException(Exception): """Base class for an exception that indicates VMCON has failed to converge. This exception allows certain diagnostics - to be passed and propogated with the exception. + to be passed and propagated with the exception. """ def __init__( diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index 1687a32..99066f3 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -160,7 +160,7 @@ def solve( # use alpha found during the linesearch to find xj. # Notice that the revision of matrix B needs the x^(j-1) - # so our running x is not overriden yet! + # so our running x is not overridden yet! xj = x + alpha * delta # Revise matrix B @@ -226,7 +226,7 @@ def solve_qsp( options : Dict[str, Any] Dictionary of keyword arguments that are passed to the - CVXPY `Probelem.solve` method. + CVXPY `Problem.solve` method. Notes ----- @@ -381,7 +381,7 @@ def phi(result: Result) -> T: else: raise LineSearchConvergenceException( - "Line search did not converge on an approimate minima", + "Line search did not converge on an approximate minima", x=x_jm1, result=result, lamda_equality=lamda_equality,