From 121c7aaa07a24ecdf5da29b9207611b6b77f5e5c Mon Sep 17 00:00:00 2001 From: chaoming Date: Wed, 1 Nov 2023 22:44:34 +0800 Subject: [PATCH] [math] remove the hard requirement of `taichi` --- brainpy/_src/math/brainpylib_check.py | 58 +++++++++++++------ .../_src/math/op_register/taichi_aot_based.py | 51 +++++++--------- brainpy/_src/tools/package.py | 18 ------ requirements.txt | 1 - 4 files changed, 62 insertions(+), 66 deletions(-) diff --git a/brainpy/_src/math/brainpylib_check.py b/brainpy/_src/math/brainpylib_check.py index b84511784..4944027e3 100644 --- a/brainpy/_src/math/brainpylib_check.py +++ b/brainpy/_src/math/brainpylib_check.py @@ -2,25 +2,49 @@ import platform import ctypes -import taichi as ti from jax.lib import xla_client -taichi_path = ti.__path__[0] -taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api') -os.environ['TAICHI_C_API_INSTALL_DIR'] = taichi_c_api_install_dir -os.environ['TI_LIB_DIR'] = os.path.join(taichi_c_api_install_dir, 'runtime') - -# link DLL -if platform.system() == 'Windows': - try: - ctypes.CDLL(taichi_c_api_install_dir + '/bin/taichi_c_api.dll') - except OSError: - raise OSError(f'Does not found {taichi_c_api_install_dir + "/bin/taichi_c_api.dll"}') -elif platform.system() == 'Linux': - try: - ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so') - except OSError: - raise OSError(f'Does not found {taichi_c_api_install_dir + "/lib/taichi_c_api.dll"}') + +try: + import taichi as ti +except (ImportError, ModuleNotFoundError): + ti = None + + +def import_taichi(): + if ti is None: + raise ModuleNotFoundError( + 'Taichi is needed. Please install taichi through:\n\n' + '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly' + ) + if ti.__version__ < (1, 7, 0): + raise RuntimeError( + 'We need taichi>=1.7.0. Currently you can install taichi>=1.7.0 through taichi-nightly:\n\n' + '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly' + ) + return ti + + +if ti is None: + is_taichi_installed = False +else: + is_taichi_installed = True + taichi_path = ti.__path__[0] + taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api') + os.environ['TAICHI_C_API_INSTALL_DIR'] = taichi_c_api_install_dir + os.environ['TI_LIB_DIR'] = os.path.join(taichi_c_api_install_dir, 'runtime') + + # link DLL + if platform.system() == 'Windows': + try: + ctypes.CDLL(taichi_c_api_install_dir + '/bin/taichi_c_api.dll') + except OSError: + raise OSError(f'Can not find {taichi_c_api_install_dir + "/bin/taichi_c_api.dll"}') + elif platform.system() == 'Linux': + try: + ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so') + except OSError: + raise OSError(f'Can not find {taichi_c_api_install_dir + "/lib/taichi_c_api.dll"}') # Register the CPU XLA custom calls try: diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index bf6f6bf48..328252845 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -6,14 +6,14 @@ from functools import partial, reduce from typing import Any -import jax.numpy as jnp import numpy as np -import taichi as ti from jax.interpreters import xla from jax.lib import xla_client import brainpy.math as bm from .utils import _shape_to_layout +from ..brainpylib_check import import_taichi + ### UTILS ### @@ -122,25 +122,25 @@ def check_kernel_exist(source_md5_encode: str) -> bool: ### KERNEL AOT BUILD ### -# jnp dtype to taichi type -type_map4template = { - jnp.dtype("bool"): bool, - jnp.dtype("int8"): ti.int8, - jnp.dtype("int16"): ti.int16, - jnp.dtype("int32"): ti.int32, - jnp.dtype("int64"): ti.int64, - jnp.dtype("uint8"): ti.uint8, - jnp.dtype("uint16"): ti.uint16, - jnp.dtype("uint32"): ti.uint32, - jnp.dtype("uint64"): ti.uint64, - jnp.dtype("float16"): ti.float16, - jnp.dtype("float32"): ti.float32, - jnp.dtype("float64"): ti.float64, -} - def _array_to_field(dtype, shape) -> Any: - return ti.field(dtype=type_map4template[dtype], shape=shape) + ti = import_taichi() + if dtype == np.bool_: + dtype = bool + elif dtype == np.int8: dtype= ti.int8 + elif dtype == np.int16: dtype= ti.int16 + elif dtype == np.int32: dtype= ti.int32 + elif dtype == np.int64: dtype= ti.int64 + elif dtype == np.uint8: dtype= ti.uint8 + elif dtype == np.uint16: dtype= ti.uint16 + elif dtype == np.uint32: dtype= ti.uint32 + elif dtype == np.uint64: dtype= ti.uint64 + elif dtype == np.float16: dtype= ti.float16 + elif dtype == np.float32: dtype= ti.float32 + elif dtype == np.float64: dtype= ti.float64 + else: + raise TypeError + return ti.field(dtype=dtype, shape=shape) # build aot kernel @@ -151,6 +151,8 @@ def build_kernel( outs: dict, device: str ): + ti = import_taichi() + # init arch arch = None if device == 'cpu': @@ -191,17 +193,6 @@ def build_kernel( int: 0, float: 1, bool: 2, - ti.int32: 0, - ti.float32: 1, - ti.u8: 3, - ti.u16: 4, - ti.u32: 5, - ti.u64: 6, - ti.i8: 7, - ti.i16: 8, - ti.i64: 9, - ti.f16: 10, - ti.f64: 11, np.dtype('int32'): 0, np.dtype('float32'): 1, np.dtype('bool'): 2, diff --git a/brainpy/_src/tools/package.py b/brainpy/_src/tools/package.py index 89b384c70..7415a1cca 100644 --- a/brainpy/_src/tools/package.py +++ b/brainpy/_src/tools/package.py @@ -12,10 +12,6 @@ except (ImportError, ModuleNotFoundError): brainpylib = None -try: - import taichi as ti -except (ImportError, ModuleNotFoundError): - ti = None __all__ = [ 'import_numba', @@ -27,20 +23,6 @@ ] -def import_taichi(): - if ti is None: - raise ModuleNotFoundError( - 'Taichi is needed. Please install taichi through:\n\n' - '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly' - ) - if ti.__version__ < (1, 7, 0): - raise RuntimeError( - 'We need taichi>=1.7.0. Currently you can install taichi>=1.7.0 through taichi-nightly:\n\n' - '> pip install -i https://pypi.taichi.graphics/simple/ taichi-nightly' - ) - return ti - - def import_numba(): if numba is None: raise ModuleNotFoundError('Numba is needed. Please install numba through:\n\n' diff --git a/requirements.txt b/requirements.txt index a329d8ca8..44025f5f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,3 @@ jax tqdm msgpack numba -taichi