Skip to content

Commit

Permalink
Link libtaichi_c_api.so when import brainpylib
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Nov 1, 2023
1 parent f45a454 commit 86e6eca
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 1 deletion.
11 changes: 11 additions & 0 deletions brainpy/_src/math/brainpylib_check.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
from jax.lib import xla_client
import taichi as ti
import os


taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
import ctypes
try:
ctypes.CDLL(taichi_c_api_install_dir + '/lib/libtaichi_c_api.so')
except OSError:
print('taichi aot custom call, Only support linux now.')

# Register the CPU XLA custom calls
try:
Expand Down
14 changes: 13 additions & 1 deletion brainpy/_src/math/tests/test_taichi_op_register.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
import unittest
import jax
import jax.numpy as jnp
import brainpy.math as bm
import taichi as ti
import os
taichi_path = ti.__path__[0]
taichi_c_api_install_dir = os.path.join(taichi_path, '_lib', 'c_api')
taichi_lib_dir = os.path.join(taichi_path, '_lib', 'runtime')
os.environ.update({
'TAICHI_C_API_INSTALL_DIR': taichi_c_api_install_dir,
'TI_LIB_DIR': taichi_lib_dir
})

import brainpy.math as bm

# from brainpylib import cpu_ops
# print(cpu_ops.registrations().items())

bm.set_platform('gpu')


@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(ndim=2), vector: ti.types.ndarray(ndim=1), weight: ti.types.ndarray(ndim=1), out: ti.types.ndarray(ndim=1)):
weight_0 = weight[0]
Expand Down

0 comments on commit 86e6eca

Please sign in to comment.