Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[math] Resolve encoding of source kernel when ti.func is nested in ti… #532

Merged
merged 2 commits into from
Nov 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions brainpy/_src/math/op_register/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
from .taichi_aot_based import (register_taichi_cpu_translation_rule,
register_taichi_gpu_translation_rule,
encode_md5,
preprocess_kernel_call_cpu, )
preprocess_kernel_call_cpu,
get_source_with_dependencies)
from .utils import register_general_batching


Expand Down Expand Up @@ -153,7 +154,7 @@ def __call__(self, *ins, outs: Optional[Sequence[ShapeDtype]] = None):
self.outs = tuple([_transform_to_shapedarray(o) for o in outs])
cpu_kernel = getattr(self, "cpu_kernel", None)
if hasattr(cpu_kernel, '_is_wrapped_kernel') and cpu_kernel._is_wrapped_kernel: # taichi
source_md5_encode = encode_md5('cpu' + inspect.getsource(cpu_kernel) + \
source_md5_encode = encode_md5('cpu' + get_source_with_dependencies(cpu_kernel) + \
str([(value.dtype, value.shape) for value in ins]) + \
str([(value.dtype, value.shape) for value in outs]))
new_ins = preprocess_kernel_call_cpu(source_md5_encode, ins, outs)
Expand Down
28 changes: 26 additions & 2 deletions brainpy/_src/math/op_register/taichi_aot_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
import os
import pathlib
import re
import sqlite3
from functools import partial, reduce
from typing import Any
Expand Down Expand Up @@ -35,6 +36,29 @@ def encode_md5(source: str) -> str:

return md5.hexdigest()

# get source with dependencies
def get_source_with_dependencies(func, visited=None):
if visited is None:
visited = set()

source = inspect.getsource(func)

if func in visited:
return ''

visited.add(func)

module = inspect.getmodule(func)

dependent_funcs = re.findall(r'(\w+)\(', source)

for func_name in dependent_funcs:
dependent_func = getattr(module, func_name, None)
if callable(dependent_func):
source += get_source_with_dependencies(dependent_func, visited)

return source


### VARIABLES ###
home_path = get_home_dir()
Expand Down Expand Up @@ -330,7 +354,7 @@ def _taichi_cpu_translation_rule(prim, kernel, c, *ins):
else:
outs_dict[name] = (output_dtypes[i - in_num], output_shapes[i - in_num])

source_md5_encode = encode_md5('cpu' + inspect.getsource(kernel) +
source_md5_encode = encode_md5('cpu' + get_source_with_dependencies(kernel) +
str([(value[0], value[1]) for value in ins_dict.values()]) +
str([(value[0], value[1]) for value in outs_dict.values()]))

Expand Down Expand Up @@ -373,7 +397,7 @@ def _taichi_gpu_translation_rule(prim, kernel, c, *ins):
out_names = names[in_num:]
ins_dict = {key: (dtype, shape) for key, shape, dtype in zip(in_names, input_shapes, input_dtypes)}
outs_dict = {key: (dtype, shape) for key, shape, dtype in zip(out_names, output_shapes, output_dtypes)}
source_md5_encode = encode_md5('gpu' + inspect.getsource(kernel) +
source_md5_encode = encode_md5('gpu' + get_source_with_dependencies(kernel) +
str([(value[0], value[1]) for value in ins_dict.values()]) +
str([(value[0], value[1]) for value in outs_dict.values()]))

Expand Down
45 changes: 35 additions & 10 deletions brainpy/_src/math/op_register/tests/test_taichi_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,41 @@

import brainpy.math as bm

bm.set_platform('cpu')

# @ti.kernel
# def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
# vector: ti.types.ndarray(ndim=1),
# weight: ti.types.ndarray(ndim=1),
# out: ti.types.ndarray(ndim=1)):
# weight_0 = weight[0]
# num_rows, num_cols = indices.shape
# ti.loop_config(serialize=True)
# for i in range(num_rows):
# if vector[i]:
# for j in range(num_cols):
# out[indices[i, j]] += weight_0

@ti.func
def get_weight(weight: ti.types.ndarray(ndim=1)) -> ti.f32:
return weight[0]

@ti.func
def update_output(out: ti.types.ndarray(ndim=1), index: ti.i32, weight_val: ti.f32):
out[index] += weight_val

@ti.kernel
def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
vector: ti.types.ndarray(ndim=1),
weight: ti.types.ndarray(ndim=1),
out: ti.types.ndarray(ndim=1)):
weight_0 = weight[0]
num_rows, num_cols = indices.shape
ti.loop_config(serialize=True)
for i in range(num_rows):
if vector[i]:
for j in range(num_cols):
out[indices[i, j]] += weight_0
weight_val = get_weight(weight)
num_rows, num_cols = indices.shape
ti.loop_config(serialize=True)
for i in range(num_rows):
if vector[i]:
for j in range(num_cols):
update_output(out, indices[i, j], weight_val)


prim = bm.XLACustomOp(cpu_kernel=event_ell_cpu)
Expand All @@ -27,10 +49,13 @@ def event_ell_cpu(indices: ti.types.ndarray(ndim=2),
# indices = bm.random.randint(0, s, (s, 1000))
# vector = bm.random.rand(s) < 0.1
# weight = bm.array([1.0])
#

# out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

# out = prim(indices, vector, weight, outs=[jax.ShapeDtypeStruct((s,), dtype=jnp.float32)])

# print(out)
# bm.clear_buffer_memory()
#
#


# test_taichi_op_register()