Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement CupyArrayContext #251

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
7ab5211
Implement CupyArrayContext
matthiasdiener Feb 16, 2024
96b7a3d
print device name in test
matthiasdiener Feb 16, 2024
8dee38d
pylint
matthiasdiener Feb 17, 2024
2c025eb
Merge branch 'main' into cupyactx
matthiasdiener May 24, 2024
d6e3136
Merge branch 'main' into cupyactx
matthiasdiener Sep 6, 2024
bfa648a
update with current numpy actx
matthiasdiener Sep 6, 2024
27e5a19
restore some tests
matthiasdiener Sep 6, 2024
6d507e1
ruff
matthiasdiener Sep 6, 2024
8fb4e0b
make cupy import optional
matthiasdiener Sep 6, 2024
be70b67
CI fixes
matthiasdiener Sep 6, 2024
677419b
remove a few spurious changes
matthiasdiener Sep 6, 2024
6250211
change CI cupy integration
matthiasdiener Sep 6, 2024
d61f0cf
simplify CI install slightly
matthiasdiener Sep 6, 2024
5871ae7
Merge branch 'main' into cupyactx
matthiasdiener Nov 14, 2024
9c56443
Merge branch 'main' into cupyactx
matthiasdiener Dec 3, 2024
fd95813
lint
matthiasdiener Dec 3, 2024
5f4c4d9
update docs
matthiasdiener Dec 3, 2024
a8fe272
Merge branch 'main' into cupyactx
matthiasdiener Jan 31, 2025
2296c6d
improve array container support in {to,from}_numpy
matthiasdiener Jan 31, 2025
8fd5488
WS fix
matthiasdiener Jan 31, 2025
6f3cd94
fixes
matthiasdiener Feb 3, 2025
ab8266d
allow optional device selection
matthiasdiener Feb 7, 2025
79b0bc3
try running cupy via gitlab
matthiasdiener Feb 7, 2025
70aff99
debug more
matthiasdiener Feb 7, 2025
8b5c6cf
Revert "debug more"
matthiasdiener Feb 7, 2025
8561b2f
only test pocl-cpu in cupy test
matthiasdiener Feb 7, 2025
904e061
fix conda build
matthiasdiener Feb 7, 2025
340f9dc
add to coverage table
matthiasdiener Feb 8, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ jobs:
run: |
USE_CONDA_BUILD=1
curl -L -O https://gitlab.tiker.net/inducer/ci-support/raw/master/prepare-and-run-pylint.sh

CONDA_ENVIRONMENT=.test-conda-env-py3.yml
echo "- cupy" >> "$CONDA_ENVIRONMENT"
Comment on lines +44 to +45
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These probably also need to be added to .gitlab-ci.yml


. ./prepare-and-run-pylint.sh "$(basename $GITHUB_REPOSITORY)" examples/*.py test/test_*.py

mypy:
Expand All @@ -52,6 +56,9 @@ jobs:
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0

CONDA_ENVIRONMENT=.test-conda-env-py3.yml
echo "- cupy" >> "$CONDA_ENVIRONMENT"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems that CI for this is not running on Github (and cannot run). Why install the package then? (Also above.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If cupy is not installed, these tests will fail, e.g. mypy with arraycontext/impl/cupy/__init__.py:57: error: Cannot find implementation or library stub for module named "cupy" [import-not-found]. We could type: ignore these, but at least in the case of pylint we would have to annotate every import cupy line with # pylint: disable=import-error. Not sure which way is better.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you adapt .gitlab-ci.yml and push a branch to Gitlab to show that CI for the actx succeeds? I've added you to the project there.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


build_py_project_in_conda_env
python -m pip install mypy pytest
./run-mypy.sh
Expand Down
19 changes: 19 additions & 0 deletions .gitlab-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,25 @@ Python 3 Nvidia Titan V:
reports:
junit: test/pytest.xml

Python 3 CuPy Nvidia Titan V:
script: |
curl -L -O https://tiker.net/ci-support-v0
. ./ci-support-v0
CONDA_ENVIRONMENT=.test-conda-env-py3.yml
echo "- cupy" >> "$CONDA_ENVIRONMENT"
export PYOPENCL_TEST=port:cpu
build_py_project_in_conda_env
test_py_project

tags:
- python3
- nvidia-titan-v
except:
- tags
artifacts:
reports:
junit: test/pytest.xml

Python 3 POCL Nvidia Titan V:
script: |
curl -L -O https://tiker.net/ci-support-v0
Expand Down
1 change: 1 addition & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ code to work with all of them? No problem! Comes with pre-made array context
implementations for:

- numpy
- cupy
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not make this into a link to cupy.dev or something?

- `PyOpenCL <https://documen.tician.de/pyopencl/array.html>`__
- `JAX <https://jax.readthedocs.io/en/latest/>`__
- `Pytato <https://documen.tician.de/pytato>`__ (for lazy/deferred evaluation)
Expand Down
2 changes: 2 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
ScalarLike,
tag_axes,
)
from .impl.cupy import CupyArrayContext
from .impl.jax import EagerJAXArrayContext
from .impl.numpy import NumpyArrayContext
from .impl.pyopencl import PyOpenCLArrayContext
Expand Down Expand Up @@ -116,6 +117,7 @@
"ArrayOrContainerT",
"ArrayT",
"CommonSubexpressionTag",
"CupyArrayContext",
"EagerJAXArrayContext",
"ElementwiseMapKernelTag",
"NotAnArrayContainerError",
Expand Down
207 changes: 207 additions & 0 deletions arraycontext/impl/cupy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
"""
.. currentmodule:: arraycontext

A :mod:`cupy`-based array context.

.. autoclass:: CupyArrayContext
"""

from __future__ import annotations


__copyright__ = """
Copyright (C) 2024 University of Illinois Board of Trustees
"""

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

from typing import Any, overload

import numpy as np

import loopy as lp
from pytools.tag import ToTagSetConvertible

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
ContainerOrScalarT,
NumpyOrContainerOrScalar,
UntransformedCodeWarning,
)


class CupyNonObjectArrayMetaclass(type):
def __instancecheck__(cls, instance: Any) -> bool:
import cupy as cp # type: ignore[import-untyped]
return isinstance(instance, cp.ndarray) and instance.dtype != object


class CupyNonObjectArray(metaclass=CupyNonObjectArrayMetaclass):
pass


class CupyArrayContext(ArrayContext):
"""
An :class:`ArrayContext` that uses :class:`cupy.ndarray` to represent arrays.

.. automethod:: __init__
"""

_loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase]

def __init__(self, device: int | None = None) -> None:
super().__init__()
self._loopy_transform_cache = {}

if device is not None:
import cupy as cp
cp.cuda.runtime.setDevice(device)
Comment on lines +78 to +80
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do devices work in cupy? Is there just one global device at a time and all array operations are done on that one?

jax seems to work in a vaguely similar fashion, and we don't mock around with the device selection there. Maybe this shouldn't either?


array_types = (CupyNonObjectArray,)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Put this before __init__?


def _get_fake_numpy_namespace(self):
from .fake_numpy import CupyFakeNumpyNamespace
return CupyFakeNumpyNamespace(self)

# {{{ ArrayContext interface

def clone(self):
return type(self)()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return type(self)()
return type(self)(self.device)

If the device setting sticks around.


@overload
def from_numpy(self, array: np.ndarray) -> Array:
Comment on lines +93 to +94
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this necessary here as well? I don't think any of the other implementations have it and mypy isn't complaining?

...

@overload
def from_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...

def from_numpy(self,
array: NumpyOrContainerOrScalar
) -> ArrayOrContainerOrScalar:
import cupy as cp

def _from_numpy(ary):
return cp.array(ary)

return with_array_context(rec_map_array_container(_from_numpy, array),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return with_array_context(rec_map_array_container(_from_numpy, array),
return with_array_context(rec_map_array_container(cp.array, array),

? Doesn't seem useful to wrap it if there are no additional arguments.

actx=self)

@overload
def to_numpy(self, array: Array) -> np.ndarray:
...

@overload
def to_numpy(self, array: ContainerOrScalarT) -> ContainerOrScalarT:
...

def to_numpy(self,
array: ArrayOrContainerOrScalar
) -> NumpyOrContainerOrScalar:
import cupy as cp

def _to_numpy(ary):
return cp.asnumpy(ary)

return with_array_context(rec_map_array_container(_to_numpy, array),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return with_array_context(rec_map_array_container(_to_numpy, array),
return with_array_context(rec_map_array_container(cp.asnumpy, array),

actx=None)

def call_loopy(
self,
t_unit: lp.TranslationUnit, **kwargs: Any
) -> dict[str, Array]:
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this supposed to work? i.e. do the cupy arrays meaningfully translate to some C code? I would expect this to error out like the jax array context.

try:
executor = self._loopy_transform_cache[t_unit]
except KeyError:
executor = self.transform_loopy_program(t_unit).executor()
self._loopy_transform_cache[t_unit] = executor

_, result = executor(**kwargs)

return result

def freeze(self, array):
import cupy as cp

def _freeze(ary):
return cp.asnumpy(ary)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does cupy have something like jax.block_until_ready or pyopencl.finish? This transfers the memory back to the host, right?


return with_array_context(rec_map_array_container(_freeze, array), actx=None)

def thaw(self, array):
import cupy as cp

def _thaw(ary):
return cp.array(ary)

return with_array_context(rec_map_array_container(_thaw, array), actx=self)

# }}}

def transform_loopy_program(self, t_unit):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be deleted.

from warnings import warn
warn("Using the base "
f"{type(self).__name__}.transform_loopy_program "
"to transform a translation unit. "
"This is a no-op and will result in unoptimized C code for"
"the requested optimization, all in a single statement."
"This will work, but is unlikely to be performant."
f"Instead, subclass {type(self).__name__} and implement "
"the specific transform logic required to transform the program "
"for your package or application. Check higher-level packages "
"(e.g. meshmode), which may already have subclasses you may want "
"to build on.",
UntransformedCodeWarning, stacklevel=2)

return t_unit

def tag(self,
tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
# Cupy (like numpy) doesn't support tagging
return array

def tag_axis(self,
iaxis: int, tags: ToTagSetConvertible,
array: ArrayOrContainerOrScalarT) -> ArrayOrContainerOrScalarT:
# Cupy (like numpy) doesn't support tagging
return array

def einsum(self, spec, *args, arg_names=None, tagged=()):
import cupy as cp
return cp.einsum(spec, *args)

@property
def permits_inplace_modification(self):
return True

@property
def supports_nonscalar_broadcasting(self):
return True

@property
def permits_advanced_indexing(self):
return True
Loading
Loading