diff --git a/cashocs/_optimization/line_search/armijo_line_search.py b/cashocs/_optimization/line_search/armijo_line_search.py index f4a03fcf..adf8b88a 100644 --- a/cashocs/_optimization/line_search/armijo_line_search.py +++ b/cashocs/_optimization/line_search/armijo_line_search.py @@ -199,5 +199,6 @@ def _compute_objective_at_new_iterate(self, current_function_value: float) -> fl raise error else: objective_step = 2.0 * abs(current_function_value) + self.state_problem.revert_to_checkpoint() return objective_step diff --git a/cashocs/_pde_problems/state_problem.py b/cashocs/_pde_problems/state_problem.py index 19c28e98..189a6f6b 100755 --- a/cashocs/_pde_problems/state_problem.py +++ b/cashocs/_pde_problems/state_problem.py @@ -73,6 +73,7 @@ def __init__( self.bcs_list: List[List[fenics.DirichletBC]] = self.state_form_handler.bcs_list self.states = self.db.function_db.states + self.states_checkpoint = [fun.copy(True) for fun in self.states] self.picard_rtol = self.config.getfloat("StateSystem", "picard_rtol") self.picard_atol = self.config.getfloat("StateSystem", "picard_atol") @@ -127,6 +128,7 @@ def solve(self) -> List[fenics.Function]: """ if not self.has_solution: self.db.callback.call_pre() + self._generate_checkpoint() if ( not self.config.getboolean("StateSystem", "picard_iteration") or self.db.parameter_db.state_dim == 1 @@ -198,3 +200,25 @@ def solve(self) -> List[fenics.Function]: self._update_cost_functionals() return self.states + + def _generate_checkpoint(self) -> None: + """Generates a checkpoint of the state variables.""" + for i in range(len(self.states)): + self.states_checkpoint[i].vector().vec().aypx( + 0.0, self.states[i].vector().vec() + ) + self.states_checkpoint[i].vector().apply("") + + def revert_to_checkpoint(self) -> None: + """Reverts the state variables to a checkpointed value. + + This is useful when the solution of the state problem fails and another attempt + is made to solve it. Then, the perturbed solution of Newton's method should not + be the initial guess. + + """ + for i in range(len(self.states)): + self.states[i].vector().vec().aypx( + 0.0, self.states_checkpoint[i].vector().vec() + ) + self.states[i].vector().apply("")