Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[math] Add new customize operators with cupy #653

Merged
merged 9 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions brainpy/_src/dependency_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
'raise_taichi_not_found',
'import_numba',
'raise_numba_not_found',
'import_cupy',
'import_cupy_jit',
'raise_cupy_not_found',
'import_brainpylib_cpu_ops',
'import_brainpylib_gpu_ops',
]
Expand All @@ -17,6 +20,8 @@

numba = None
taichi = None
cupy = None
cupy_jit = None
brainpylib_cpu_ops = None
brainpylib_gpu_ops = None

Expand All @@ -25,6 +30,9 @@
'> pip install taichi==1.7.0')
numba_install_info = ('We need numba. Please install numba by pip . \n'
'> pip install numba')
cupy_install_info = ('We need cupy. Please install cupy by pip . \n'
'For CUDA v11.2 ~ 11.8 > pip install cupy-cuda11x\n'
'For CUDA v12.x > pip install cupy-cuda12x\n')
os.environ["TI_LOG_LEVEL"] = "error"


Expand Down Expand Up @@ -81,6 +89,48 @@ def raise_numba_not_found():
raise ModuleNotFoundError(numba_install_info)


def import_cupy(error_if_not_found=True):
    """
    Internal API to lazily import ``cupy`` and cache it in the module-level ``cupy`` global.

    Args:
      error_if_not_found: bool. When the import fails with ``ModuleNotFoundError``,
        call ``raise_cupy_not_found()`` if True; otherwise return ``None``.

    Returns:
      The imported ``cupy`` module, or ``None`` when it is missing and
      ``error_if_not_found`` is False.
    """
    global cupy
    # Already imported by an earlier call: reuse the cached module.
    if cupy is not None:
        return cupy
    try:
        import cupy as cupy
    except ModuleNotFoundError:
        if not error_if_not_found:
            return None
        raise_cupy_not_found()
    return cupy


def import_cupy_jit(error_if_not_found=True):
    """
    Internal API to lazily import ``cupyx.jit`` and cache it in the module-level ``cupy_jit`` global.

    Args:
      error_if_not_found: bool. When the import fails with ``ModuleNotFoundError``,
        call ``raise_cupy_not_found()`` if True; otherwise return ``None``.

    Returns:
      The imported ``cupyx.jit`` module, or ``None`` when it is missing and
      ``error_if_not_found`` is False.
    """
    global cupy_jit
    # Already imported by an earlier call: reuse the cached module.
    if cupy_jit is not None:
        return cupy_jit
    try:
        from cupyx import jit as cupy_jit
    except ModuleNotFoundError:
        if not error_if_not_found:
            return None
        raise_cupy_not_found()
    return cupy_jit


def raise_cupy_not_found():
    """Raise ``ModuleNotFoundError`` whose message is ``cupy_install_info``."""
    missing_cupy_error = ModuleNotFoundError(cupy_install_info)
    raise missing_cupy_error


def is_brainpylib_gpu_installed():
    """Return True when the module-level ``brainpylib_gpu_ops`` is not ``None``."""
    return brainpylib_gpu_ops is not None

Expand Down
53 changes: 20 additions & 33 deletions brainpy/_src/math/op_register/base.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
from functools import partial
from typing import Callable, Sequence, Tuple, Protocol, Optional
from typing import Callable, Sequence, Tuple, Protocol, Optional, Union

import jax
import numpy as np
from jax.interpreters import xla, batching, ad, mlir

from brainpy._src.dependency_check import import_numba
from brainpy._src.dependency_check import import_numba, import_cupy_jit
from brainpy._src.math.ndarray import Array
from brainpy._src.math.object_transform.base import BrainPyObject

if jax.__version__ >= '0.4.16':
from .numba_based import register_numba_mlir_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_aot_mlir_cpu_translation_rule as register_taichi_cpu_translation_rule,
register_taichi_aot_mlir_gpu_translation_rule as register_taichi_gpu_translation_rule)
from .cupy_based import (register_cupy_raw_module_mlir_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
register_cupy_jit_kernel_mlir_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
else:
from .numba_based import register_numba_xla_cpu_translation_rule as register_numba_cpu_translation_rule
from .taichi_aot_based import (register_taichi_aot_xla_cpu_translation_rule as register_taichi_cpu_translation_rule,
register_taichi_aot_xla_gpu_translation_rule as register_taichi_gpu_translation_rule)
from .cupy_based import (register_cupy_raw_module_xla_gpu_translation_rule as register_cupy_raw_module_gpu_translation_rule,
register_cupy_jit_kernel_xla_gpu_translation_rule as register_cupy_jit_kernel_gpu_translation_rule)
from .utils import register_general_batching
from brainpy._src.math.op_register.ad_support import defjvp

numba = import_numba(error_if_not_found=False)
cp_jit = import_cupy_jit(error_if_not_found=False)

__all__ = [
'XLACustomOp',
Expand All @@ -41,34 +46,10 @@ def dtype(self) -> np.dtype:
class XLACustomOp(BrainPyObject):
"""Creating a XLA custom call operator.

>>> import numba as nb
>>> import taichi as ti
>>> import numpy as np
>>> import jax
>>>
>>> @nb.njit
>>> def numba_cpu_fun(a, b, out_a, out_b):
>>> out_a[:] = a
>>> out_b[:] = b
>>>
>>> @ti.kernel
>>> def taichi_gpu_fun(a, b, out_a, out_b):
>>> for i in range(a.size):
>>> out_a[i] = a[i]
>>> for i in range(b.size):
>>> out_b[i] = b[i]
>>>
>>> # option 1
>>> prim = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun)
>>> a2, b2 = prim(np.random.random(1000), np.random.random(1000),
>>> outs=[jax.ShapeDtypeStruct(1000, dtype=np.float32),
>>> jax.ShapeDtypeStruct(1000, dtype=np.float32)])
>>>
>>> # option 2
>>> prim2 = XLACustomOp(cpu_kernel=numba_cpu_fun, gpu_kernel=taichi_gpu_fun,
>>> outs=lambda a, b, **kwargs: [jax.ShapeDtypeStruct(a.shape, dtype=a.dtype),
>>> jax.ShapeDtypeStruct(b.shape, dtype=b.dtype)])
>>> a3, b3 = prim2(np.random.random(1000), np.random.random(1000))
For more information, please refer to the tutorials below:
Numba Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_numba.html
Taichi Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_taichi.html
CuPy Custom Op: https://brainpy.tech/docs/tutorial_advanced/operator_custom_with_cupy.html

Args:
cpu_kernel: Callable. The function defines the computation on CPU backend.
Expand All @@ -83,7 +64,7 @@ class XLACustomOp(BrainPyObject):
def __init__(
self,
cpu_kernel: Callable = None,
gpu_kernel: Callable = None,
gpu_kernel: Union[Callable, str] = None,
batching_translation: Callable = None,
jvp_translation: Callable = None,
transpose_translation: Callable = None,
Expand Down Expand Up @@ -125,11 +106,17 @@ def __init__(
gpu_checked = False
if gpu_kernel is None:
gpu_checked = True
if hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi
elif hasattr(gpu_kernel, 'kernel'): # cupy RawModule
register_cupy_raw_module_gpu_translation_rule(self.primitive, gpu_kernel)
gpu_checked = True
elif hasattr(gpu_kernel, '_mode'): # cupy JIT Kernel
register_cupy_jit_kernel_gpu_translation_rule(self.primitive, gpu_kernel)
gpu_checked = True
elif hasattr(gpu_kernel, '_is_wrapped_kernel') and gpu_kernel._is_wrapped_kernel: # taichi
register_taichi_gpu_translation_rule(self.primitive, gpu_kernel)
gpu_checked = True
if not gpu_checked:
raise ValueError(f'"cpu_kernel" must be a taichi kernel function. But we got {gpu_kernel}')
raise ValueError(f'"gpu_kernel" must be a taichi kernel function, cupy raw module or cupy jit kernel. But we got {gpu_kernel}')

# batching rule
if batching_translation is None:
Expand Down
Loading
Loading