Skip to content

Commit

Permalink
Add a callback function that runs each iteration of the solver
Browse files Browse the repository at this point in the history
  • Loading branch information
timothy-nunn committed Aug 16, 2023
1 parent 02e4eae commit f6bb573
Showing 1 changed file with 20 additions and 15 deletions.
35 changes: 20 additions & 15 deletions src/pyvmcon/vmcon.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Optional, Dict, Any
from typing import Union, Optional, Dict, Any, Callable
import logging
import numpy as np
import cvxpy as cp
Expand All @@ -24,6 +24,7 @@ def solve(
epsilon: float = 1e-8,
qsp_options: Optional[Dict[str, Any]] = None,
initial_B: Optional[np.ndarray] = None,
callback: Optional[Callable[[int, np.ndarray, Result], None]],
):
"""The main solving loop of the VMCON non-linear constrained optimiser.
Expand Down Expand Up @@ -56,6 +57,12 @@ def solve(
Initial estimate of the Hessian matrix `B`. If `None`, `B` is the
identity matrix of shape `(max(n,m), max(n,m))`.
callback : Optional[Callable[[int, np.ndarray, Result], None]]
A callable which takes the current iteration, current design point,
the `Result` of the current design point, and the convergence parameter
as arguments and returns `None`. This callable is called each iteration
after the QSP is solved but before the convergence test.
Returns
-------
x : ndarray
Expand Down Expand Up @@ -103,6 +110,8 @@ def solve(
else:
B = initial_B

callback = callback or (lambda _i, _result, _x, _con: None)

# These two values being None allows the line
# search to realise that it is the first iteration
mu_equality = None
Expand Down Expand Up @@ -133,13 +142,13 @@ def solve(

# Exit to optimisation loop if the convergence
# criteria is met
if convergence_test(
result,
delta,
lamda_equality,
lamda_inequality,
epsilon,
):
convergence_info = convergence_value(
result, delta, lamda_equality, lamda_inequality
)

callback(i, result, x, convergence_info)

if convergence_info < epsilon:
break

# perform a linesearch along the search direction
Expand Down Expand Up @@ -275,13 +284,12 @@ def solve_qsp(
return delta.value, lamda_equality, lamda_inequality


def convergence_test(
def convergence_value(
result: Result,
delta_j: np.ndarray,
lamda_equality: np.ndarray,
lamda_inequality: np.ndarray,
epsilon: float,
) -> bool:
) -> float:
"""Test if the convergence criteria of VMCON have been met.
Equation 11 of the VMCON paper. Note this tests convergence at the
point (j-1)th evaluation point.
Expand All @@ -301,9 +309,6 @@ def convergence_test(
lambda_inequality : ndarray
The Lagrange multipliers for inequality constraints for the jth
evaluation point.
epsilon : float
The user-supplied error tolerance.
"""
abs_df_dot_delta = abs(np.dot(result.df, delta_j))
abs_equality_err = np.sum(
Expand All @@ -313,7 +318,7 @@ def convergence_test(
[abs(lamda * c) for lamda, c in zip(lamda_inequality, result.ie)]
)

return (abs_df_dot_delta + abs_equality_err + abs_inequality_err) < epsilon
return abs_df_dot_delta + abs_equality_err + abs_inequality_err


def _calculate_mu_i(mu_im1: Union[np.ndarray, None], lamda: np.ndarray):
Expand Down

0 comments on commit f6bb573

Please sign in to comment.