Skip to content

Commit

Permalink
Add a checkpoint for the state variables
Browse files Browse the repository at this point in the history
  • Loading branch information
sblauth committed Mar 1, 2024
1 parent 99e7ac9 commit 2f1c577
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 0 deletions.
1 change: 1 addition & 0 deletions cashocs/_optimization/line_search/armijo_line_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,5 +199,6 @@ def _compute_objective_at_new_iterate(self, current_function_value: float) -> fl
raise error
else:
objective_step = 2.0 * abs(current_function_value)
self.state_problem.revert_to_checkpoint()

return objective_step
24 changes: 24 additions & 0 deletions cashocs/_pde_problems/state_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(

self.bcs_list: List[List[fenics.DirichletBC]] = self.state_form_handler.bcs_list
self.states = self.db.function_db.states
self.states_checkpoint = [fun.copy(True) for fun in self.states]

self.picard_rtol = self.config.getfloat("StateSystem", "picard_rtol")
self.picard_atol = self.config.getfloat("StateSystem", "picard_atol")
Expand Down Expand Up @@ -127,6 +128,7 @@ def solve(self) -> List[fenics.Function]:
"""
if not self.has_solution:
self.db.callback.call_pre()
self._generate_checkpoint()
if (
not self.config.getboolean("StateSystem", "picard_iteration")
or self.db.parameter_db.state_dim == 1
Expand Down Expand Up @@ -198,3 +200,25 @@ def solve(self) -> List[fenics.Function]:
self._update_cost_functionals()

return self.states

def _generate_checkpoint(self) -> None:
"""Generates a checkpoint of the state variables."""
for i in range(len(self.states)):
self.states_checkpoint[i].vector().vec().aypx(
0.0, self.states[i].vector().vec()
)
self.states_checkpoint[i].vector().apply("")

def revert_to_checkpoint(self) -> None:
"""Reverts the state variables to a checkpointed value.
This is useful when the solution of the state problem fails and another attempt
is made to solve it. Then, the perturbed solution of Newton's method should not
be the initial guess.
"""
for i in range(len(self.states)):
self.states[i].vector().vec().aypx(
0.0, self.states_checkpoint[i].vector().vec()
)
self.states[i].vector().apply("")

0 comments on commit 2f1c577

Please sign in to comment.