diff --git a/grudge/array_context.py b/grudge/array_context.py index aca1edc08..3347ec70b 100644 --- a/grudge/array_context.py +++ b/grudge/array_context.py @@ -4,6 +4,7 @@ .. autoclass:: MPIBasedArrayContext .. autoclass:: MPIPyOpenCLArrayContext .. autoclass:: MPINumpyArrayContext +.. autoclass:: MPICupyArrayContext .. class:: MPIPytatoArrayContext .. autofunction:: get_reasonable_array_context_class """ @@ -104,6 +105,8 @@ ) +from arraycontext import CupyArrayContext + if TYPE_CHECKING: import pytato as pt from mpi4py import MPI @@ -426,6 +429,26 @@ def clone(self): # }}} +# {{{ + +class MPICupyArrayContext(CupyArrayContext, MPIBasedArrayContext): + """An array context for using distributed computation with :mod:`cupy` + eager evaluation. + + .. autofunction:: __init__ + """ + + def __init__(self, mpi_communicator): + super().__init__() + + self.mpi_communicator = mpi_communicator + + def clone(self): + return type(self)(self.mpi_communicator) + +# }}} + + # {{{ distributed + pyopencl class MPIPyOpenCLArrayContext(PyOpenCLArrayContext, MPIBasedArrayContext):