
Commit

[math] Resolve encoding of source kernel when ti.func is nested in ti.kernel
Routhleck committed Nov 3, 2023
1 parent 1064116 commit a74accd
Showing 3 changed files with 66 additions and 14 deletions.
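The change keys the Taichi AOT kernel cache on the kernel source plus the source of every function it calls, instead of on the kernel source alone. A minimal sketch of the idea, not part of this commit, using plain strings and hashlib; `kernel_src`, `helper_v1`, and `helper_v2` are made-up stand-ins for a ti.kernel and a ti.func it calls:

# Minimal sketch (not part of this commit): why the cache key must include the
# source of nested helpers.
import hashlib

kernel_src = "def kernel(x):\n    return helper(x)\n"
helper_v1 = "def helper(x):\n    return x + 1\n"
helper_v2 = "def helper(x):\n    return x + 2\n"

# Hashing only the kernel source: both helper versions map to the same key,
# so a stale compiled kernel would be reused after editing the helper.
key_kernel_only = hashlib.md5(kernel_src.encode()).hexdigest()

# Hashing kernel + dependency sources: the key tracks helper changes.
key_v1 = hashlib.md5((kernel_src + helper_v1).encode()).hexdigest()
key_v2 = hashlib.md5((kernel_src + helper_v2).encode()).hexdigest()
assert key_v1 != key_v2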
5 changes: 3 additions & 2 deletions brainpy/_src/math/op_register/base.py
@@ -19,7 +19,8 @@
 from .taichi_aot_based import (register_taichi_cpu_translation_rule,
                                register_taichi_gpu_translation_rule,
                                encode_md5,
-                               preprocess_kernel_call_cpu, )
+                               preprocess_kernel_call_cpu,
+                               get_source_with_dependencies)
 from .utils import register_general_batching


@@ -153,7 +154,7 @@ def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None):
       self.outs = tuple([_transform_to_shapedarray(o) for o in outs])
     cpu_kernel = getattr(self, "cpu_kernel", None)
     if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi
-      source_md5_encode = encode_md5('cpu' + inspect.getsource(cpu_kernel) + \
+      source_md5_encode = encode_md5('cpu' + get_source_with_dependencies(cpu_kernel) + \
                                      str([(value.dtype, value.shape) for value in ins]) + \
                                      str([(value.dtype, value.shape) for value in outs]))
       new_ins = preprocess_kernel_call_cpu(source_md5_encode, ins, outs)
30 changes: 28 additions & 2 deletions brainpy/_src/math/op_register/taichi_aot_based.py
@@ -2,6 +2,7 @@
 import inspect
 import os
 import pathlib
+import re
 import sqlite3
 from functools import partial, reduce
 from typing import Any
@@ -35,6 +36,31 @@ def encode_md5(source: str) -> str:

   return md5.hexdigest()
 
+# get source with dependencies
+def get_source_with_dependencies(func, visited=None):
+  if visited is None:
+    visited = set()
+
+  source = inspect.getsource(func)
+
+  if func in visited:
+    return ''
+
+  visited.add(func)
+
+  module = inspect.getmodule(func)
+
+  dependent_funcs = re.findall(r'(\w+)\(', source)
+
+  # recursively collect the source code of every function this one calls
+  for func_name in dependent_funcs:
+    # use getattr to look the function up in its defining module
+    dependent_func = getattr(module, func_name, None)
+    if callable(dependent_func):
+      source += get_source_with_dependencies(dependent_func, visited)
+
+  return source
+
 
 ### VARIABLES ###
 home_path = get_home_dir()
@@ -330,7 +356,7 @@ def _taichi_cpu_translation_rule(prim, kernel, c, *ins):
     else:
       outs_dict[name] = (output_dtypes[i - in_num], output_shapes[i - in_num])
 
-  source_md5_encode = encode_md5('cpu' + inspect.getsource(kernel) +
+  source_md5_encode = encode_md5('cpu' + get_source_with_dependencies(kernel) +
                                  str([(value[0], value[1]) for value in ins_dict.values()]) +
                                  str([(value[0], value[1]) for value in outs_dict.values()]))
 
@@ -373,7 +399,7 @@ def _taichi_gpu_translation_rule(prim, kernel, c, *ins):
   out_names = names[in_num:]
   ins_dict = {key: (dtype, shape) for key, shape, dtype in zip(in_names, input_shapes, input_dtypes)}
   outs_dict = {key: (dtype, shape) for key, shape, dtype in zip(out_names, output_shapes, output_dtypes)}
-  source_md5_encode = encode_md5('gpu' + inspect.getsource(kernel) +
+  source_md5_encode = encode_md5('gpu' + get_source_with_dependencies(kernel) +
                                  str([(value[0], value[1]) for value in ins_dict.values()]) +
                                  str([(value[0], value[1]) for value in outs_dict.values()]))
 
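A hedged usage sketch of the new helper on plain Python functions, assuming the import path brainpy._src.math.op_register.taichi_aot_based introduced in this diff and an environment with brainpy and taichi installed; `inner` and `outer` are made-up names. It should be run as a file, since inspect.getsource needs functions defined in a module:

# sketch.py : illustrative only, not part of this commit.
from brainpy._src.math.op_register.taichi_aot_based import get_source_with_dependencies

def inner(x):
  return x * 2

def outer(x):
  return inner(x) + 1

# Prints the source of `outer` followed by the source of `inner`: the regex
# finds `inner(` in outer's body, getattr resolves it in this module, and the
# recursion appends its source, so an MD5 over the result changes whenever
# either function changes.
print(get_source_with_dependencies(outer))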
45 changes: 35 additions & 10 deletions brainpy/_src/math/op_register/tests/test_taichi_based.py
@@ -4,19 +4,41 @@

 import brainpy.math as bm
 
+bm.set_platform('cpu')
+
+# @ti.kernel
+# def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
+#                   vector: ti.types.ndarray(ndim=1),
+#                   weight: ti.types.ndarray(ndim=1),
+#                   out: ti.types.ndarray(ndim=1)):
+#   weight_0 = weight[0]
+#   num_rows, num_cols = indices.shape
+#   ti.loop_config(serialize=True)
+#   for i in range(num_rows):
+#     if vector[i]:
+#       for j in range(num_cols):
+#         out[indices[i, j]] += weight_0
+
+@ti.func
+def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:
+  return weight[0]
+
+@ti.func
+def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
+  out[index] += weight_val
+
 @ti.kernel
 def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
                   vector: ti.types.ndarray(ndim=1),
                   weight: ti.types.ndarray(ndim=1),
                   out: ti.types.ndarray(ndim=1)):
-  weight_0 = weight[0]
-  num_rows, num_cols = indices.shape
-  ti.loop_config(serialize=True)
-  for i in range(num_rows):
-    if vector[i]:
-      for j in range(num_cols):
-        out[indices[i, j]] += weight_0
+  weight_val = get_weight(weight)
+  num_rows, num_cols = indices.shape
+  ti.loop_config(serialize=True)
+  for i in range(num_rows):
+    if vector[i]:
+      for j in range(num_cols):
+        update_output(out, indices[i, j], weight_val)
 
 
 prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
@@ -27,10 +49,13 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
# indices = bm.random.randint(0, s, (s, 1000))
# vector = bm.random.rand(s) < 0.1
# weight = bm.array([1.0])
#

# out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

# print(out)
# bm.clear_buffer_memory()
#
#


# test_taichi_op_register()
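For completeness, a hedged sketch of exercising the registered op, reconstructed from the commented-out test body above; it assumes the `event_ell_cpu` kernel and its nested ti.func helpers from this diff are in scope, and `s = 1000` is an assumed size (the original definition sits in a collapsed hunk):

# Illustrative only: mirrors the commented-out lines above; requires brainpy,
# taichi and jax. `s = 1000` is an assumption, not taken from the commit.
import jax
import jax.numpy as jnp
import brainpy.math as bm

prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)

s = 1000
indices = bm.random.randint(0, s, (s, 1000))
vector = bm.random.rand(s) < 0.1
weight = bm.array([1.0])

# Because the cache key now hashes the nested ti.func sources as well, editing
# get_weight or update_output forces the AOT kernel to be rebuilt here.
out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])
print(out)
bm.clear_buffer_memory()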
