Skip to content

Commit

Permalink
Merge pull request #181 from sblauth/hotfix/mpi_communicators
Browse files Browse the repository at this point in the history
Correctly use MPI communicators
  • Loading branch information
sblauth authored Feb 1, 2023
2 parents ecbc56c + ea84dcf commit ed5c30e
Show file tree
Hide file tree
Showing 17 changed files with 74 additions and 22 deletions.
3 changes: 2 additions & 1 deletion cashocs/_constraints/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,9 @@ def _project_pointwise_multiplier(
lhs = trial * test * measure
rhs = project_term * test * measure

comm = self.cg_function_space.mesh().mpi_comm()
_utils.assemble_and_solve_linear(
lhs, rhs, A=A_tensor, b=b_tensor, fun=multiplier
lhs, rhs, A=A_tensor, b=b_tensor, fun=multiplier, comm=comm
)

def _update_cost_functional(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions cashocs/_database/geometry_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ def __init__(self, function_db: function_database.FunctionDatabase) -> None:
"""
self.mesh: fenics.Mesh = function_db.state_spaces[0].mesh()
self.dx: fenics.Measure = fenics.Measure("dx", self.mesh)
self.mpi_comm = self.mesh.mpi_comm()
1 change: 1 addition & 0 deletions cashocs/_forms/shape_form_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def compute(self) -> None:
A=self.A_mu_matrix,
b=self.b_mu,
ksp_options=self.options_mu,
comm=self.mesh.mpi_comm(),
)

if self.config.getboolean("ShapeGradient", "use_sqrt_mu"):
Expand Down
1 change: 1 addition & 0 deletions cashocs/_forms/shape_regularization.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,7 @@ def _compute_curvature(self) -> None:
A=self.a_curvature_matrix.mat(),
b=self.b_curvature.vec(),
fun=self.kappa_curvature,
comm=self.db.geometry_db.mpi_comm,
)

def scale(self) -> None:
Expand Down
10 changes: 7 additions & 3 deletions cashocs/_pde_problems/adjoint_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,17 @@ def __init__(

# pylint: disable=invalid-name
self.A_tensors = [
fenics.PETScMatrix() for _ in range(self.db.parameter_db.state_dim)
fenics.PETScMatrix(db.geometry_db.mpi_comm)
for _ in range(self.db.parameter_db.state_dim)
]
self.b_tensors = [
fenics.PETScVector() for _ in range(self.db.parameter_db.state_dim)
fenics.PETScVector(db.geometry_db.mpi_comm)
for _ in range(self.db.parameter_db.state_dim)
]

self.res_j_tensors = [
fenics.PETScVector() for _ in range(self.db.parameter_db.state_dim)
fenics.PETScVector(db.geometry_db.mpi_comm)
for _ in range(self.db.parameter_db.state_dim)
]

self._number_of_solves = 0
Expand Down Expand Up @@ -119,6 +122,7 @@ def solve(self) -> List[fenics.Function]:
b=self.b_tensors[-1 - i],
fun=self.adjoints[-1 - i],
ksp_options=self.db.parameter_db.adjoint_ksp_options[-1 - i],
comm=self.db.geometry_db.mpi_comm,
)

else:
Expand Down
1 change: 1 addition & 0 deletions cashocs/_pde_problems/control_gradient_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def solve(self) -> List[fenics.Function]:
b=self.b_tensors[i].vec(),
fun=self.db.function_db.gradient[i],
ksp_options=self.riesz_ksp_options[i],
comm=self.db.geometry_db.mpi_comm,
)

self.has_solution = True
Expand Down
3 changes: 3 additions & 0 deletions cashocs/_pde_problems/hessian_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def hessian_application(
self.bcs_list_ad[i],
fun=self.db.function_db.states_prime[i],
ksp_options=self.db.parameter_db.state_ksp_options[i],
comm=self.db.geometry_db.mpi_comm,
)

for i in range(self.state_dim):
Expand All @@ -185,6 +186,7 @@ def hessian_application(
self.bcs_list_ad[-1 - i],
fun=self.db.function_db.adjoints_prime[-1 - i],
ksp_options=self.db.parameter_db.adjoint_ksp_options[-1 - i],
comm=self.db.geometry_db.mpi_comm,
)

else:
Expand Down Expand Up @@ -234,6 +236,7 @@ def hessian_application(
b=b,
fun=out[i],
ksp_options=self.riesz_ksp_options[i],
comm=self.db.geometry_db.mpi_comm,
)

self.no_sensitivity_solves += 2
Expand Down
1 change: 1 addition & 0 deletions cashocs/_pde_problems/shape_gradient_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ def solve(self) -> List[fenics.Function]:
b=self.form_handler.fe_shape_derivative_vector.vec(),
fun=self.db.function_db.gradient[0],
ksp_options=self.ksp_options,
comm=self.db.geometry_db.mpi_comm,
)

self.has_solution = True
Expand Down
1 change: 1 addition & 0 deletions cashocs/_pde_problems/state_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def solve(self) -> List[fenics.Function]:
b=self.b_tensors[i],
fun=self.states[i],
ksp_options=self.db.parameter_db.state_ksp_options[i],
comm=self.db.geometry_db.mpi_comm,
)

else:
Expand Down
28 changes: 27 additions & 1 deletion cashocs/_utils/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from cashocs._utils import forms as forms_module

if TYPE_CHECKING:
from mpi4py import MPI

from cashocs import _typing

iterative_ksp_options: _typing.KspOption = {
Expand Down Expand Up @@ -218,13 +220,32 @@ def setup_fieldsplit_preconditioner(
pc.setFieldSplitIS(*idx_tuples)


def _initialize_comm(comm: Optional[MPI.Comm] = None) -> MPI.Comm:
"""Initializes the MPI communicator.
If the supplied communicator is `None`, return MPI.comm_world.
Args:
comm: The supplied communicator or `None`
Returns:
The resulting communicator.
"""
if comm is None:
comm = fenics.MPI.comm_world

return comm


def solve_linear_problem(
A: Optional[PETSc.Mat] = None, # pylint: disable=invalid-name
b: Optional[PETSc.Vec] = None,
fun: Optional[fenics.Function] = None,
ksp_options: Optional[_typing.KspOption] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
comm: Optional[MPI.Comm] = None,
) -> PETSc.Vec:
"""Solves a finite dimensional linear problem.
Expand All @@ -245,12 +266,14 @@ def solve_linear_problem(
atol: The absolute tolerance used in case an iterative solver is used for
solving the linear problem. Overrides the specification in the ksp object
and ksp_options.
comm: The MPI communicator for the problem.
Returns:
The solution vector.
"""
ksp = PETSc.KSP().create()
comm = _initialize_comm(comm)
ksp = PETSc.KSP().create(comm=comm)

if A is not None:
ksp.setOperators(A)
Expand Down Expand Up @@ -306,6 +329,7 @@ def assemble_and_solve_linear(
ksp_options: Optional[_typing.KspOption] = None,
rtol: Optional[float] = None,
atol: Optional[float] = None,
comm: Optional[MPI.Comm] = None,
) -> PETSc.Vec:
"""Assembles and solves a linear system.
Expand All @@ -325,6 +349,7 @@ def assemble_and_solve_linear(
atol: The absolute tolerance used in case an iterative solver is used for
solving the linear problem. Overrides the specification in the ksp object
and ksp_options.
comm: The MPI communicator for solving the problem.
Returns:
A PETSc vector containing the solution x.
Expand All @@ -341,6 +366,7 @@ def assemble_and_solve_linear(
ksp_options=ksp_options,
rtol=rtol,
atol=atol,
comm=comm,
)

return solution
Expand Down
8 changes: 6 additions & 2 deletions cashocs/geometry/boundary_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def compute_boundary_distance(
function_space = fenics.FunctionSpace(mesh, "CG", 1)
dx = measure.NamedMeasure("dx", mesh)

comm = mesh.mpi_comm()

ksp_options = copy.deepcopy(_utils.linalg.iterative_ksp_options)

u = fenics.TrialFunction(function_space)
Expand Down Expand Up @@ -95,7 +97,9 @@ def compute_boundary_distance(
lhs = fenics.dot(fenics.grad(u), fenics.grad(v)) * dx
rhs = fenics.Constant(1.0) * v * dx

_utils.assemble_and_solve_linear(lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options)
_utils.assemble_and_solve_linear(
lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options, comm=comm
)

rhs = fenics.dot(fenics.grad(u_prev) / norm_u_prev, fenics.grad(v)) * dx

Expand All @@ -114,7 +118,7 @@ def compute_boundary_distance(
u_prev.vector().vec().aypx(0.0, u_curr.vector().vec())
u_prev.vector().apply("")
_utils.assemble_and_solve_linear(
lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options
lhs, rhs, bcs, fun=u_curr, ksp_options=ksp_options, comm=comm
)
res = np.sqrt(fenics.assemble(residual_form))

Expand Down
1 change: 1 addition & 0 deletions cashocs/geometry/mesh_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def compute_decreases(
self.a_frobenius,
self.l_frobenius,
ksp_options=self.options_frobenius,
comm=self.db.geometry_db.mpi_comm,
)

frobenius_norm = x.max()[1]
Expand Down
5 changes: 2 additions & 3 deletions cashocs/geometry/mesh_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,13 @@ def test(self, transformation: fenics.Function, volume_change: float) -> bool:
A boolean that indicates whether the desired transformation is feasible.
"""
comm = self.transformation_container.function_space().mesh().mpi_comm()
self.transformation_container.vector().vec().aypx(
0.0, transformation.vector().vec()
)
self.transformation_container.vector().apply("")
x = _utils.assemble_and_solve_linear(
self.A_prior,
self.l_prior,
ksp_options=self.options_prior,
self.A_prior, self.l_prior, ksp_options=self.options_prior, comm=comm
)

min_det = float(x.min()[1])
Expand Down
16 changes: 9 additions & 7 deletions cashocs/geometry/quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray:
The element wise skewness of the mesh on process 0.
"""
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
skewness_array = self._quality_object.skewness(mesh).array()
skewness_list: np.ndarray = comm.gather(skewness_array, root=0)
if comm.rank == 0:
Expand Down Expand Up @@ -329,7 +329,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray:
The maximum angle quality measure for each element on process 0.
"""
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
maximum_angle_array = self._quality_object.maximum_angle(mesh).array()
maximum_angle_list: np.ndarray = comm.gather(maximum_angle_array, root=0)
if comm.rank == 0:
Expand Down Expand Up @@ -361,7 +361,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray:
The radius ratios of the mesh elements on process 0.
"""
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
radius_ratios_array = fenics.MeshQuality.radius_ratios(mesh).array()
radius_ratios_list: np.ndarray = comm.gather(radius_ratios_array, root=0)
if comm.rank == 0:
Expand All @@ -388,7 +388,7 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray:
The condition numbers of the elements on process 0.
"""
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
function_space_dg0 = fenics.FunctionSpace(mesh, "DG", 0)
jac = ufl.Jacobian(mesh)
inv = ufl.JacobianInverse(mesh)
Expand Down Expand Up @@ -417,7 +417,9 @@ def compute(self, mesh: fenics.Mesh) -> np.ndarray:

cond = fenics.Function(function_space_dg0)

_utils.assemble_and_solve_linear(lhs, rhs, fun=cond, ksp_options=options)
_utils.assemble_and_solve_linear(
lhs, rhs, fun=cond, ksp_options=options, comm=comm
)
cond.vector().vec().reciprocal()
cond.vector().apply("")
cond.vector().vec().scale(np.sqrt(mesh.geometric_dimension()))
Expand Down Expand Up @@ -454,7 +456,7 @@ def min(cls, calculator: MeshQualityCalculator, mesh: fenics.Mesh) -> float:
"""
quality_list = calculator.compute(mesh)
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
if comm.rank == 0:
qual = float(np.min(quality_list))
else:
Expand All @@ -477,7 +479,7 @@ def avg(cls, calculator: MeshQualityCalculator, mesh: fenics.Mesh) -> float:
"""
quality_list = calculator.compute(mesh)
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()

if comm.rank == 0:
qual = float(np.average(quality_list))
Expand Down
2 changes: 1 addition & 1 deletion cashocs/io/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def gather_coordinates(mesh: fenics.Mesh) -> np.ndarray:
A numpy array which contains the vertex coordinates of the mesh
"""
comm = fenics.MPI.comm_world
comm = mesh.mpi_comm()
rank = comm.Get_rank()
top = mesh.topology()
global_vertex_indices = top.global_indices(0)
Expand Down
10 changes: 7 additions & 3 deletions cashocs/nonlinear_solvers/newton_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,11 @@ def __init__(
self.derivative, -self.nonlinear_form, self.bcs_hom
)
self.assembler.keep_diagonal = True

self.comm = self.u.function_space().mesh().mpi_comm()
# pylint: disable=invalid-name
self.A_fenics = self.A_tensor or fenics.PETScMatrix()
self.residual = self.b_tensor or fenics.PETScVector()
self.A_fenics = self.A_tensor or fenics.PETScMatrix(self.comm)
self.residual = self.b_tensor or fenics.PETScVector(self.comm)
self.b = fenics.as_backend_type(self.residual).vec()
self.A_matrix = fenics.as_backend_type(self.A_fenics).mat()

Expand All @@ -160,7 +162,7 @@ def __init__(
self.assembler_shift = fenics.SystemAssembler(
self.derivative, self.shift, self.bcs_hom
)
self.residual_shift = fenics.PETScVector()
self.residual_shift = fenics.PETScVector(self.comm)

self.breakdown = False
self.res = 1.0
Expand Down Expand Up @@ -266,6 +268,7 @@ def solve(self) -> fenics.Function:
ksp_options=self.ksp_options,
rtol=self.eta,
atol=self.atol / 10.0,
comm=self.comm,
)

if self.is_linear:
Expand Down Expand Up @@ -370,6 +373,7 @@ def _backtracking_line_search(self) -> None:
ksp_options=self.ksp_options,
rtol=self.eta,
atol=self.atol / 10.0,
comm=self.comm,
)

if (
Expand Down
4 changes: 3 additions & 1 deletion cashocs/nonlinear_solvers/picard_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,11 @@ def picard_iteration(
bcs_list = _utils.check_and_enlist_bcs(bcs_list)
bcs_list_hom = _create_homogenized_bcs(bcs_list)

comm = u_list[0].function_space().mesh().mpi_comm()

prefix = "Picard iteration: "

res_tensor = [fenics.PETScVector() for _ in range(len(u_list))]
res_tensor = [fenics.PETScVector(comm) for _ in u_list]
eta_max = 0.9
gamma = 0.9
res_0 = 1.0
Expand Down

0 comments on commit ed5c30e

Please sign in to comment.