From 95291655c32e5c673f45236b3fe7e68053a5b36b Mon Sep 17 00:00:00 2001 From: Timothy <75321887+timothy-nunn@users.noreply.github.com> Date: Wed, 16 Aug 2023 15:02:19 +0100 Subject: [PATCH] Add callback option (#10) * Add a callback function that runs each iteration of the solver * Add a callback function that runs each iteration of the solver --- src/pyvmcon/vmcon.py | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/pyvmcon/vmcon.py b/src/pyvmcon/vmcon.py index d16b1f6..ac0ea2a 100644 --- a/src/pyvmcon/vmcon.py +++ b/src/pyvmcon/vmcon.py @@ -1,4 +1,4 @@ -from typing import Union, Optional, Dict, Any +from typing import Union, Optional, Dict, Any, Callable import logging import numpy as np import cvxpy as cp @@ -24,6 +24,7 @@ def solve( epsilon: float = 1e-8, qsp_options: Optional[Dict[str, Any]] = None, initial_B: Optional[np.ndarray] = None, + callback: Optional[Callable[[int, Result, np.ndarray, float], None]] = None, ): """The main solving loop of the VMCON non-linear constrained optimiser. @@ -56,6 +57,12 @@ def solve( Initial estimate of the Hessian matrix `B`. If `None`, `B` is the identity matrix of shape `(max(n,m), max(n,m))`. + callback : Optional[Callable[[int, Result, np.ndarray, float], None]] + A callable which takes the current iteration, the `Result` of the + current design point, the current design point, and the convergence parameter + as arguments and returns `None`. This callable is called each iteration + after the QSP is solved but before the convergence test. 
+ Returns ------- x : ndarray @@ -103,6 +110,8 @@ def solve( else: B = initial_B + callback = callback or (lambda _i, _result, _x, _con: None) + # These two values being None allows the line # search to realise that it is the first iteration mu_equality = None @@ -133,13 +142,13 @@ def solve( # Exit to optimisation loop if the convergence # criteria is met - if convergence_test( - result, - delta, - lamda_equality, - lamda_inequality, - epsilon, - ): + convergence_info = convergence_value( + result, delta, lamda_equality, lamda_inequality + ) + + callback(i, result, x, convergence_info) + + if convergence_info < epsilon: break # perform a linesearch along the search direction @@ -275,13 +284,12 @@ def solve_qsp( return delta.value, lamda_equality, lamda_inequality -def convergence_test( +def convergence_value( result: Result, delta_j: np.ndarray, lamda_equality: np.ndarray, lamda_inequality: np.ndarray, - epsilon: float, -) -> bool: +) -> float: """Test if the convergence criteria of VMCON have been met. Equation 11 of the VMCON paper. Note this tests convergence at the point (j-1)th evaluation point. @@ -301,9 +309,6 @@ def convergence_test( lambda_inequality : ndarray The Lagrange multipliers for inequality constraints for the jth evaluation point. - - epsilon : float - The user-supplied error tolerance. """ abs_df_dot_delta = abs(np.dot(result.df, delta_j)) abs_equality_err = np.sum( @@ -313,7 +318,7 @@ def convergence_test( [abs(lamda * c) for lamda, c in zip(lamda_inequality, result.ie)] ) - return (abs_df_dot_delta + abs_equality_err + abs_inequality_err) < epsilon + return abs_df_dot_delta + abs_equality_err + abs_inequality_err def _calculate_mu_i(mu_im1: Union[np.ndarray, None], lamda: np.ndarray):