Skip to content

Commit

Permalink
update fmm interface for sumpy
Browse files (browse the repository at this point in the history)
  • Loading branch information
alexfikl committed Sep 18, 2022
1 parent 8c04036 commit 31f92d3
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 73 deletions.
35 changes: 26 additions & 9 deletions boxtree/constant_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
"""

import numpy as np

from boxtree.array_context import PyOpenCLArrayContext
from boxtree.fmm import TreeIndependentDataForWrangler, ExpansionWranglerInterface
from boxtree.timing import DummyTimingFuture

Expand Down Expand Up @@ -83,7 +85,9 @@ def local_expansions_view(self, local_exps, level):
def timing_future(ops):
return DummyTimingFuture.from_op_count(ops)

def form_multipoles(self, level_start_source_box_nrs, source_boxes,
def form_multipoles(self, actx: PyOpenCLArrayContext,
level_start_source_box_nrs,
source_boxes,
src_weight_vecs):
src_weights, = src_weight_vecs
mpoles = self.multipole_expansion_zeros()
Expand All @@ -96,8 +100,10 @@ def form_multipoles(self, level_start_source_box_nrs, source_boxes,

return mpoles, self.timing_future(ops)

def coarsen_multipoles(self, level_start_source_parent_box_nrs,
source_parent_boxes, mpoles):
def coarsen_multipoles(self, actx: PyOpenCLArrayContext,
level_start_source_parent_box_nrs,
source_parent_boxes,
mpoles):
tree = self.tree
ops = 0

Expand All @@ -119,7 +125,8 @@ def coarsen_multipoles(self, level_start_source_parent_box_nrs,

return mpoles, self.timing_future(ops)

def eval_direct(self, target_boxes, neighbor_sources_starts,
def eval_direct(self, actx: PyOpenCLArrayContext,
target_boxes, neighbor_sources_starts,
neighbor_sources_lists, src_weight_vecs):
src_weights, = src_weight_vecs
pot = self.output_zeros()
Expand All @@ -144,6 +151,7 @@ def eval_direct(self, target_boxes, neighbor_sources_starts,
return pot, self.timing_future(ops)

def multipole_to_local(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes,
starts, lists, mpole_exps):
Expand All @@ -164,7 +172,9 @@ def multipole_to_local(self,
return local_exps, self.timing_future(ops)

def eval_multipoles(self,
target_boxes_by_source_level, from_sep_smaller_nonsiblings_by_level,
actx: PyOpenCLArrayContext,
target_boxes_by_source_level,
from_sep_smaller_nonsiblings_by_level,
mpole_exps):
pot = self.output_zeros()
ops = 0
Expand All @@ -186,8 +196,10 @@ def eval_multipoles(self,
return pot, self.timing_future(ops)

def form_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes, starts, lists, src_weight_vecs):
target_or_target_parent_boxes,
starts, lists, src_weight_vecs):
src_weights, = src_weight_vecs
local_exps = self.local_expansion_zeros()
ops = 0
Expand All @@ -209,7 +221,9 @@ def form_locals(self,

return local_exps, self.timing_future(ops)

def refine_locals(self, level_start_target_or_target_parent_box_nrs,
def refine_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes, local_exps):
ops = 0

Expand All @@ -222,7 +236,10 @@ def refine_locals(self, level_start_target_or_target_parent_box_nrs,

return local_exps, self.timing_future(ops)

def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps):
def eval_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_box_nrs,
target_boxes, local_exps):
pot = self.output_zeros()
ops = 0

Expand All @@ -233,7 +250,7 @@ def eval_locals(self, level_start_target_box_nrs, target_boxes, local_exps):

return pot, self.timing_future(ops)

def finalize_potentials(self, potentials, template_ary):
def finalize_potentials(self, actx: PyOpenCLArrayContext, potentials):
return potentials

# }}}
Expand Down
9 changes: 6 additions & 3 deletions boxtree/distributed/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,14 @@ def __init__(self, array_context: PyOpenCLArrayContext, global_tree,
array_context, global_tree, traversal_builder, wrangler_factory,
calibration_params, comm)

def drive_dfmm(self, source_weights, timing_data=None):
"""Calculate potentials at target points.
"""
def drive_dfmm(self,
actx: PyOpenCLArrayContext,
source_weights,
timing_data=None):
"""Calculate potentials at target points."""
from boxtree.fmm import drive_fmm
return drive_fmm(
actx,
self.wrangler, source_weights,
timing_data=timing_data,
global_src_idx_all_ranks=self.src_idx_all_ranks,
Expand Down
9 changes: 6 additions & 3 deletions boxtree/distributed/calculation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def mpi_size(self):
def is_mpi_root(self):
return self.mpi_rank == 0

def distribute_source_weights(self, src_weight_vecs, src_idx_all_ranks):
def distribute_source_weights(self,
actx: PyOpenCLArrayContext, src_weight_vecs, src_idx_all_ranks):
if self.is_mpi_root:
distribute_weight_req = []
local_src_weight_vecs = np.empty((self.mpi_size,), dtype=object)
Expand All @@ -98,7 +99,8 @@ def distribute_source_weights(self, src_weight_vecs, src_idx_all_ranks):

return local_src_weight_vecs

def gather_potential_results(self, potentials, tgt_idx_all_ranks):
def gather_potential_results(self,
actx: PyOpenCLArrayContext, potentials, tgt_idx_all_ranks):
from boxtree.distributed import dtype_to_mpi
potentials_mpi_type = dtype_to_mpi(potentials.dtype)
gathered_potentials = None
Expand Down Expand Up @@ -254,7 +256,8 @@ def find_boxes_used_by_subrange(

return box_in_subrange

def communicate_mpoles(self, mpole_exps, return_stats=False):
def communicate_mpoles(self,
actx: PyOpenCLArrayContext, mpole_exps, return_stats=False):
"""Based on Algorithm 3: Reduce and Scatter in Lashuk et al. [1]_.
The main idea is to mimic an allreduce as done on a hypercube network, but to
Expand Down
54 changes: 42 additions & 12 deletions boxtree/fmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from boxtree.tree import Tree
from boxtree.traversal import FMMTraversalInfo
from boxtree.array_context import PyOpenCLArrayContext

import logging
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -155,6 +156,7 @@ def local_expansions_view(self, local_exps, level):

@abstractmethod
def form_multipoles(self,
actx: PyOpenCLArrayContext,
level_start_source_box_nrs, source_boxes,
src_weight_vecs):
"""Return an expansions array
Expand All @@ -167,6 +169,7 @@ def form_multipoles(self,

@abstractmethod
def coarsen_multipoles(self,
actx: PyOpenCLArrayContext,
level_start_source_parent_box_nrs,
source_parent_boxes, mpoles):
"""For each box in *source_parent_boxes*,
Expand All @@ -179,6 +182,7 @@ def coarsen_multipoles(self,

@abstractmethod
def eval_direct(self,
actx: PyOpenCLArrayContext,
target_boxes, neighbor_sources_starts,
neighbor_sources_lists, src_weight_vecs):
"""For each box in *target_boxes*, evaluate the influence of the
Expand All @@ -191,6 +195,7 @@ def eval_direct(self,

@abstractmethod
def multipole_to_local(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes,
starts, lists, mpole_exps):
Expand All @@ -205,6 +210,7 @@ def multipole_to_local(self,

@abstractmethod
def eval_multipoles(self,
actx: PyOpenCLArrayContext,
target_boxes_by_source_level, from_sep_smaller_by_level, mpole_exps):
"""For a level *i*, each box in *target_boxes_by_source_level[i]*, evaluate
the multipole expansion in *mpole_exps* in the nearby boxes given in
Expand All @@ -218,6 +224,7 @@ def eval_multipoles(self,

@abstractmethod
def form_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes, starts, lists, src_weight_vecs):
"""For each box in *target_or_target_parent_boxes*, form local
Expand All @@ -232,6 +239,7 @@ def form_locals(self,

@abstractmethod
def refine_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_or_target_parent_box_nrs,
target_or_target_parent_boxes, local_exps):
"""For each box in *child_boxes*,
Expand All @@ -243,6 +251,7 @@ def refine_locals(self,

@abstractmethod
def eval_locals(self,
actx: PyOpenCLArrayContext,
level_start_target_box_nrs, target_boxes, local_exps):
"""For each box in *target_boxes*, evaluate the local expansion in
*local_exps* and return a new potential array.
Expand All @@ -254,7 +263,7 @@ def eval_locals(self,
# }}}

@abstractmethod
def finalize_potentials(self, potentials, template_ary):
def finalize_potentials(self, actx: PyOpenCLArrayContext, potentials):
"""
Postprocess the reordered potentials. This is where global scaling
factors could be applied. This is distinct from :meth:`reorder_potentials`
Expand All @@ -268,7 +277,9 @@ def finalize_potentials(self, potentials, template_ary):
type.
"""

def distribute_source_weights(self, src_weight_vecs, src_idx_all_ranks):
def distribute_source_weights(self,
actx: PyOpenCLArrayContext,
src_weight_vecs, src_idx_all_ranks):
"""Used by the distributed implementation for transferring needed source
weights from root rank to each worker rank in the communicator.
Expand All @@ -288,7 +299,9 @@ def distribute_source_weights(self, src_weight_vecs, src_idx_all_ranks):
"""
return src_weight_vecs

def gather_potential_results(self, potentials, tgt_idx_all_ranks):
def gather_potential_results(self,
actx: PyOpenCLArrayContext,
potentials, tgt_idx_all_ranks):
"""Used by the distributed implementation for gathering calculated potentials
from all worker ranks in the communicator to the root rank.
Expand All @@ -305,7 +318,9 @@ def gather_potential_results(self, potentials, tgt_idx_all_ranks):
"""
return potentials

def communicate_mpoles(self, mpole_exps, return_stats=False):
def communicate_mpoles(self,
actx: PyOpenCLArrayContext,
mpole_exps, return_stats=False):
"""Used by the distributed implementation for forming the complete multipole
expansions from the partial multipole expansions.
Expand All @@ -324,9 +339,12 @@ def communicate_mpoles(self, mpole_exps, return_stats=False):
# }}}


def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
def drive_fmm(actx: PyOpenCLArrayContext,
wrangler: ExpansionWranglerInterface,
src_weight_vecs, *,
timing_data=None,
global_src_idx_all_ranks=None, global_tgt_idx_all_ranks=None):
global_src_idx_all_ranks=None,
global_tgt_idx_all_ranks=None):
"""Top-level driver routine for a fast multipole calculation.
In part, this is intended as a template for custom FMMs, in the sense that
Expand Down Expand Up @@ -373,15 +391,17 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
from boxtree.timing import TimingRecorder
recorder = TimingRecorder()

src_weight_vecs = [wrangler.reorder_sources(weight) for
weight in src_weight_vecs]
src_weight_vecs = [
wrangler.reorder_sources(weight) for weight in src_weight_vecs]

src_weight_vecs = wrangler.distribute_source_weights(
src_weight_vecs, global_src_idx_all_ranks)
actx,
src_weight_vecs, global_src_idx_all_ranks)

# {{{ "Step 2.1:" Construct local multipoles

mpole_exps, timing_future = wrangler.form_multipoles(
actx,
traversal.level_start_source_box_nrs,
traversal.source_boxes,
src_weight_vecs)
Expand All @@ -393,6 +413,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Step 2.2:" Propagate multipoles upward

mpole_exps, timing_future = wrangler.coarsen_multipoles(
actx,
traversal.level_start_source_parent_box_nrs,
traversal.source_parent_boxes,
mpole_exps)
Expand All @@ -403,11 +424,12 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,

# }}}

wrangler.communicate_mpoles(mpole_exps)
wrangler.communicate_mpoles(actx, mpole_exps)

# {{{ "Stage 3:" Direct evaluation from neighbor source boxes ("list 1")

potentials, timing_future = wrangler.eval_direct(
actx,
traversal.target_boxes,
traversal.neighbor_source_boxes_starts,
traversal.neighbor_source_boxes_lists,
Expand All @@ -422,6 +444,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Stage 4:" translate separated siblings' ("list 2") mpoles to local

local_exps, timing_future = wrangler.multipole_to_local(
actx,
traversal.level_start_target_or_target_parent_box_nrs,
traversal.target_or_target_parent_boxes,
traversal.from_sep_siblings_starts,
Expand All @@ -440,6 +463,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# contribution *out* of the downward-propagating local expansions)

mpole_result, timing_future = wrangler.eval_multipoles(
actx,
traversal.target_boxes_sep_smaller_by_source_level,
traversal.from_sep_smaller_by_level,
mpole_exps)
Expand All @@ -455,6 +479,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
"('list 3 close')")

direct_result, timing_future = wrangler.eval_direct(
actx,
traversal.target_boxes,
traversal.from_sep_close_smaller_starts,
traversal.from_sep_close_smaller_lists,
Expand All @@ -469,6 +494,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Stage 6:" form locals for separated bigger source boxes ("list 4")

local_result, timing_future = wrangler.form_locals(
actx,
traversal.level_start_target_or_target_parent_box_nrs,
traversal.target_or_target_parent_boxes,
traversal.from_sep_bigger_starts,
Expand All @@ -481,6 +507,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,

if traversal.from_sep_close_bigger_starts is not None:
direct_result, timing_future = wrangler.eval_direct(
actx,
traversal.target_boxes,
traversal.from_sep_close_bigger_starts,
traversal.from_sep_close_bigger_lists,
Expand All @@ -495,6 +522,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Stage 7:" propagate local_exps downward

local_exps, timing_future = wrangler.refine_locals(
actx,
traversal.level_start_target_or_target_parent_box_nrs,
traversal.target_or_target_parent_boxes,
local_exps)
Expand All @@ -506,6 +534,7 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# {{{ "Stage 8:" evaluate locals

local_result, timing_future = wrangler.eval_locals(
actx,
traversal.level_start_target_box_nrs,
traversal.target_boxes,
local_exps)
Expand All @@ -517,11 +546,12 @@ def drive_fmm(wrangler: ExpansionWranglerInterface, src_weight_vecs,
# }}}

potentials = wrangler.gather_potential_results(
potentials, global_tgt_idx_all_ranks)
actx,
potentials, global_tgt_idx_all_ranks)

result = wrangler.reorder_potentials(potentials)

result = wrangler.finalize_potentials(result, template_ary=src_weight_vecs[0])
result = wrangler.finalize_potentials(actx, result)

fmm_proc.done()

Expand Down
Loading

0 comments on commit 31f92d3

Please sign in to comment.