From 86e6eca9fbf3c72b2fe538e87e4a5fcd6ddc2193 Mon Sep 17 00:00:00 2001 From: routhleck <1310722434@qq.com> Date: Wed, 1 Nov 2023 17:02:55 +0800 Subject: [PATCH] Link libtaichi_c_api.so when import brainpylib --- brainpy/_src/math/brainpylib_check.py | 11 +++++++++++ brainpy/_src/math/tests/test_taichi_op_register.py | 14 +++++++++++++- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/brainpy/_src/math/brainpylib_check.py b/brainpy/_src/math/brainpylib_check.py index 95e029471..74036f8ec 100644 --- a/brainpy/_src/math/brainpylib_check.py +++ b/brainpy/_src/math/brainpylib_check.py @@ -1,4 +1,15 @@ from jax.lib import xla_client +import taichi as ti +import os + + +taichi_path = ti.__path__[0] +taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api') +import ctypes +try: + ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so') +except OSError: + print('taichi aot custom call, Only support linux now.') # Register the CPU XLA custom calls try: diff --git a/brainpy/_src/math/tests/test_taichi_op_register.py b/brainpy/_src/math/tests/test_taichi_op_register.py index 9990a4a62..1de56056e 100644 --- a/brainpy/_src/math/tests/test_taichi_op_register.py +++ b/brainpy/_src/math/tests/test_taichi_op_register.py @@ -1,12 +1,24 @@ import unittest import jax import jax.numpy as jnp -import brainpy.math as bm import taichi as ti import os +taichi_path = ti.__path__[0] +taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api') +taichi_lib_dir = os.path.join(taichi_path, '_lib', 'runtime') +os.environ.update({ +'TAICHI_C_API_INSTALL_DIR': taichi_c_api_install_dir, +'TI_LIB_DIR': taichi_lib_dir +}) + +import brainpy.math as bm + +# from brainpylib import cpu_ops +# print(cpu_ops.registrations().items()) bm.set_platform('gpu') + @ti.kernel def event_ell_cpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), weight: ti.types.ndarray(ndim=1), out: ti.types.ndarray(ndim=1)): weight_0 = weight[0]