Skip to content

Commit

Permalink
Allow additional convergence criteria to be defined
Browse files Browse the repository at this point in the history
  • Loading branch information
timothy-nunn committed May 10, 2024
1 parent ba0ab1f commit 26c45dc
Showing 1 changed file with 31 additions and 2 deletions.
33 changes: 31 additions & 2 deletions src/pyvmcon/vmcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ def solve(
qsp_options: Optional[Dict[str, Any]] = None,
initial_B: Optional[np.ndarray] = None,
callback: Optional[Callable[[int, Result, np.ndarray, float], None]] = None,
additional_convergence: Optional[
Callable[[Result, np.ndarray, np.ndarray, np.ndarray, np.ndarray], None]
] = None,
overwrite_convergence_criteria: bool = False,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Result]:
"""The main solving loop of the VMCON non-linear constrained optimiser.
Expand Down Expand Up @@ -58,12 +62,26 @@ def solve(
Initial estimate of the Hessian matrix `B`. If `None`, `B` is the
identity matrix of shape `(n, n)`.
callback : Optional[Callable[[int, np.ndarray, Result], None]]
callback : Optional[Callable[[int, ndarray, Result], None]]
A callable which takes the current iteration, the `Result` of the
current design point, current design point, and the convergence parameter
as arguments and returns `None`. This callable is called each iteration
after the QSP is solved but before the convergence test.
additional_convergence : Optional[Callable[[Result, ndarray, ndarray, ndarray, ndarray], None]]
A callabale which takes: the `Result` of the current design point,
the current design point, the proposed search direction for the
next design point, the equality Lagrange multipliers, and the
inequality Lagrange multipliers. The callable returns a boolean
indicating whether VMCON should be allowed to converge. Note that
the original VMCON convergence criteria being `False` will stop
convergence even if this callable returns `True` unless we
`overwrite_convergence_criteria`.
overwrite_convergence_criteria : bool
Ignore original VMCON convergence criteria and only
evaluate convergence using `additional_convergence`.
Returns
-------
x : ndarray
Expand Down Expand Up @@ -95,6 +113,12 @@ def solve(
)
raise ValueError(msg)

if overwrite_convergence_criteria and additional_convergence is None:
raise ValueError(
"Cannot overwrite convergence criteria without "
"providing an 'additional_convergence' callable."
)

# n is denoted in the VMCON paper
# as the number of inputs the function
# and the constraints take
Expand All @@ -105,6 +129,9 @@ def solve(
B = np.identity(n) if initial_B is None else initial_B

callback = callback or (lambda _i, _result, _x, _con: None)
additional_convergence = additional_convergence or (
lambda _result, _x, _delta, _lambda_eq, _lambda_in: True
)

# These two values being None allows the line
# search to realise that it is the first iteration
Expand Down Expand Up @@ -142,7 +169,9 @@ def solve(

callback(i, result, x, convergence_info)

if convergence_info < epsilon:
if additional_convergence and (
overwrite_convergence_criteria or convergence_info < epsilon
):
break

# perform a linesearch along the search direction
Expand Down

0 comments on commit 26c45dc

Please sign in to comment.