diff --git a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py index 8e290fa35..3ac1e0ee2 100644 --- a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py +++ b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv.py @@ -12,7 +12,7 @@ import pandas as pd import taichi as ti -bm.set_platform('gpu') +bm.set_platform('cpu') s = [1000, 5000, 10000, 20000, 25000, 30000] p = [0.1, 0.2, 0.3, 0.4, 0.5] @@ -42,11 +42,29 @@ False ] +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + print(bm.get_platform()) -def test_event_csrmv_cpu(shape, values_type, events_type, transpose): +@partial(jax.jit, static_argnums=(4, 5)) +def event_csrmv_taichi(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)[0] + return r + +@partial(jax.jit, static_argnums=(4, 5)) +def event_csrmv(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose) + return r + +def test_event_csrmv(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.3)(*shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 weight = 1. @@ -57,477 +75,146 @@ def test_event_csrmv_cpu(shape, values_type, events_type, transpose): heter_data = bm.ones(indices.shape) * weight weight = heter_data - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # time.sleep(2) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # print(result1[0]) - # print(result2) - # print(groundtruth - result1[0]) - # print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_event_csrmv_gpu(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.3)(*shape).require('pre2post') - vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 - weight = 1. + time10 = time.time() + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time11 = time.time() - - if events_type == 'float': - vector = vector.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - - - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # print(result1[0]) - # print(result2) - # print(groundtruth - result1[0]) - # print(groundtruth - result2) - - print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - time12 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - - -def test_event_csrmv_square_cpu(s, p, values_type, events_type, transpose): - print('s: ', s, 'p: ', p) - k = int(s * p) - bm.random.seed(1234) - rng = bm.random.RandomState(seed=1234) - # init - indices = bm.random.randint(0, s, (s, k)) - vector = bm.random.rand(s) < 0.5 - weight = jnp.array([1.0]) - csr_indices = indices.flatten() - csr_indptr = np.cumsum(np.insert(np.ones(s, dtype=int) * k, 0, 0)) - - pre_indices = np.repeat(np.arange(s), k) - dense = np.zeros((s, s)) - dense[pre_indices, csr_indices] = 1.0 - - if events_type == 'float': - vector = vector.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.as_jax(rng.random(csr_indices.shape)) - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # print(result1[0]) - # print(result2) - # print(groundtruth - result1[0]) - # print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - time12 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time19 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 + time22 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time23 = time.time() - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - assert(jnp.allclose(result1[0], result2)) + time24 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time25 = time.time() - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time26 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time27 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_event_csrmv_square_gpu(s, p, values_type, events_type, transpose): - print('s: ', s, 'p: ', p) - k = int(s * p) - bm.random.seed(1234) - rng = bm.random.RandomState(seed=1234) - # init - indices = bm.random.randint(0, s, (s, k)) - vector = bm.random.rand(s) < 0.5 - weight = jnp.array([1.0]) - csr_indices = indices.flatten() - csr_indptr = np.cumsum(np.insert(np.ones(s, dtype=int) * k, 0, 0)) - pre_indices = np.repeat(np.arange(s), k) - dense = np.zeros((s, s)) - dense[pre_indices, csr_indices] = 1.0 - - if events_type == 'float': - vector = vector.astype(bm.float32) - if values_type == 'heter': - heter_data = bm.as_jax(rng.random(csr_indices.shape)) - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - - - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.event.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # print('--------------------result1[0]------------------') - # print(result1[0]) - # print('--------------------result2------------------') - # print(result2) - # print('--------------------gt------------------') - # print(groundtruth) - # print('--------------------gt - result1[0]------------------') - # print(groundtruth - result1[0]) - # print('--------------------gt - result2------------------') - # print(groundtruth - result2) + time28 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.event.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time21 = time.time() + time30 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + jax.block_until_ready(event_csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - print('s: ', s, 'p: ', p, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - - assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + # assert(jnp.allclose(result1[0], result2)) return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) - -### SQUARE MATRIX - -# if (bm.get_platform() == 'cpu'): -# for _s in s: -# for _p in p: -# for _values_type in values_type: -# for _events_type in events_type: -# for _transpose in transpose: -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_square_cpu(_s, _p, _values_type, _events_type, _transpose) -# # append to dataframe -# df.loc[df.shape[0]] = [_s, _p, _s, _s, 'cpu', _values_type, _events_type, _transpose, -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] -# df.to_csv(f'{PATH}/event_csrmv_square_cpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# for _values_type in values_type: -# for _events_type in events_type: -# for _transpose in transpose: -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_square_gpu(_s, _p, _values_type, _events_type, _transpose) -# # append to dataframe -# df.loc[df.shape[0]] = [_s, _p, _s, _s, 'gpu', _values_type, _events_type, _transpose, -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] -# df.to_csv(f'{PATH}/event_csrmv_square_gpu.csv', index=False) + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) ### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): @@ -537,11 +224,15 @@ def test_event_csrmv_square_gpu(s, p, values_type, events_type, transpose): for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_cpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2,'cpu', _values_type, _events_type, _transpose, + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/event_csrmv_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): @@ -551,25 +242,13 @@ def test_event_csrmv_square_gpu(s, p, values_type, events_type, transpose): for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_event_csrmv_gpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/event_csrmv_gpu.csv', index=False) - - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) diff --git a/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py new file mode 100644 index 000000000..98793e600 --- /dev/null +++ b/brainpy/_src/math/event/tests/event_csrmv_taichi_VS_event_csrmv_grad.py @@ -0,0 +1,271 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('cpu') + +s = [1000, 5000, 10000, 20000, 25000, 30000] +p = [0.1, 0.2, 0.3, 0.4, 0.5] + +shape = [ + 1000, + 2500, + 5000, + 10000, + 25000, + 37500, + 50000 +] + + + +values_type = [ + 'homo', + 'heter' + ] +events_type = [ + 'bool', + 'float', + ] +transpose = [ + True, + False + ] + +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + +print(bm.get_platform()) + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + +def sum_op2(op): + def func(*args, **kwargs): + r = op(*args, **kwargs)[0] + return r.sum() + + return func + +@partial(jax.jit, static_argnums=(4, 5)) +def event_csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op2(bm.event.csrmv_taichi), argnums=3)( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + return r + +@partial(jax.jit, static_argnums=(4, 5)) +def event_csrmv_grad(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.event.csrmv), argnums=3)( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + return r + + +def test_event_csrmv(shape, values_type, events_type, transpose): + rng = bm.random.RandomState(seed=1234) + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') + vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 + weight = 1. + + + if events_type == 'float': + vector = vector.astype(bm.float32) + if values_type == 'heter': + heter_data = bm.ones(indices.shape) * weight + weight = heter_data + + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + + time0 = time.time() + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time1 = time.time() + + time2 = time.time() + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time3 = time.time() + + time4 = time.time() + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time5 = time.time() + + time6 = time.time() + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time9 = time.time() + + time10 = time.time() + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time11 = time.time() + + time12 = time.time() + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time13 = time.time() + + time14 = time.time() + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time15 = time.time() + + time16 = time.time() + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time17 = time.time() + + time18 = time.time() + result = jax.block_until_ready(event_csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time19 = time.time() + + + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + + time20 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time21 = time.time() + + time22 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time23 = time.time() + + time24 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time25 = time.time() + + time26 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time27 = time.time() + + time28 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(event_csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time39 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') + + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape1 in shape: + for shape2 in shape: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/event_csrmv_grad_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape1 in shape: + for shape2 in shape: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_event_csrmv((shape1, shape2), _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/event_csrmv_grad_gpu.csv', index=False) diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py index 249438a48..21a246650 100644 --- a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py +++ b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec.py @@ -42,16 +42,63 @@ True, False ] -conn_prob = 0.1 +conn_prob = 0.05 homo_data = 1. w_low = 0. w_high = 1. w_mu = 0. w_sigma = 0.1 +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + print(bm.get_platform()) -def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel, bool_event): +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_event_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + return r + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_event_matvec_homo(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_homo(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)[0] + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_uniform(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_normal(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.event_mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + + +def test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 if not bool_event: @@ -59,607 +106,432 @@ def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel, bool_event): # groundtruth = bm.as_jax(events, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() -def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel, bool_event): + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + +def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 if not bool_event: events = events.astype(float) + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo_taichi(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_homo(events, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel, bool_event): +def test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 if not bool_event: events = events.astype(float) # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel, bool_event): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 - if not bool_event: - events = events.astype(float) - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.event_mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel, bool_event): +def test_jitconn_matvec(shape, _type, transpose, outdim_parallel, bool_event): print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) if _type == 'homo': - return test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel, bool_event) + return test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event) elif _type == 'uniform': - return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel, bool_event) + return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event) elif _type == 'normal': - return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel, bool_event) + return test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event) else: raise ValueError -def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event): - print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) - if _type == 'homo': - return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel, bool_event) - elif _type == 'uniform': - return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel, bool_event) - elif _type == 'normal': - return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel, bool_event) - else: - raise ValueError - PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) ### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): @@ -670,11 +542,15 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event for _transpose in transpose: for _bool_event in bool_event: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) # append to dataframe df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, _bool_event, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/jitconn_event_matvec_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): @@ -685,24 +561,13 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel, bool_event for _transpose in transpose: for _bool_event in bool_event: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) # append to dataframe df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, _bool_event, - taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/jitconn_event_matvec_gpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) diff --git a/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py new file mode 100644 index 000000000..ff4f01afc --- /dev/null +++ b/brainpy/_src/math/jitconn/tests/jitconn_event_matvec_taichi_VS_jitconn_event_matvec_grad.py @@ -0,0 +1,589 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('cpu') +# bm.disable_gpu_memory_preallocation() + +seed = 1234 + +shape = [ + 1000, + 2500, + 5000, + 10000, + 25000, + 37500, + 50000 + ] +types = [ + 'homo', + 'uniform', + 'normal' + ] +transpose = [ + True, + False + ] +outdim_parallel = [ + True, + False, + ] +bool_event = [ + True, + False + ] +conn_prob = 0.05 +homo_data = 1. +w_low = 0. +w_high = 1. +w_mu = 0. +w_sigma = 0.1 + +print(bm.get_platform()) + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs)[0] + return r.sum() + + return func + +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_event_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r +=jax.grad(sum_op(bm.jitconn.event_mv_prob_homo_taichi), argnums=0)( + vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_event_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.event_mv_prob_homo), argnums=0)( + vector.astype(float), homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform_taichi), argnums=0)( + vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.event_mv_prob_uniform), argnums=0)( + vector.astype(float), w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.event_mv_prob_normal_taichi), argnums=0)( + vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_event_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.event_mv_prob_normal), argnums=0)( + vector.astype(float), w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +def test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + + time0 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + + time2 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + + time4 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = time.time() + + time6 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + + time12 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + + time14 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + + time16 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + + time18 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_taichi_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time19 = time.time() + + + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + + time20 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() + + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() + + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() + + time28 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_homo_grad(events, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') + + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + +def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + + time0 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + + time2 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + + time4 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = time.time() + + time6 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + + time12 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + + time14 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + + time16 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + + time18 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time19 = time.time() + + + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + + time20 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() + + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() + + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() + + time28 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') + + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + +def test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) < 0.1 + if not bool_event: + events = events.astype(float) + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + + time0 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + + time2 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + + time4 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = time.time() + + time6 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + time10 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + + time12 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + + time14 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + + time16 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + + time18 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time19 = time.time() + + + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + + time20 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + time22 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() + + time24 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() + + time26 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() + + time28 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_event_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') + + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + +def test_jitconn_matvec(shape, _type, transpose, outdim_parallel, bool_event): + print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) + if _type == 'homo': + return test_jitconn_matvec_homo(shape, transpose, outdim_parallel, bool_event) + elif _type == 'uniform': + return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel, bool_event) + elif _type == 'normal': + return test_jitconn_matvec_normal(shape, transpose, outdim_parallel, bool_event) + else: + raise ValueError + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) + + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape1 in shape: + for shape2 in shape: + for _type in types: + for _outdim_parallel in outdim_parallel: + for _transpose in transpose: + for _bool_event in bool_event: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) + # append to dataframe + df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, _bool_event, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/jitconn_event_matvec_grad_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape1 in shape: + for shape2 in shape: + for _type in types: + for _outdim_parallel in outdim_parallel: + for _transpose in transpose: + for _bool_event in bool_event: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel, _bool_event) + # append to dataframe + df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, _bool_event, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/jitconn_event_matvec_grad_gpu.csv', index=False) diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py index 92def9be6..14a19aefb 100644 --- a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py +++ b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec.py @@ -38,616 +38,489 @@ True, False, ] -conn_prob = 0.1 +bool_event = False +conn_prob = 0.05 homo_data = 1. w_low = 0. w_high = 1. w_mu = 0. w_sigma = 0.1 +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + print(bm.get_platform()) -def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel): +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_uniform_taichi(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_uniform(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_uniform(vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_normal_taichi(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_normal(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += bm.jitconn.mv_prob_normal(vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel) + return r + +def test_jitconn_matvec_homo(shape, transpose, outdim_parallel): rng = bm.random.RandomState(seed=seed) vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo_taichi(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 + time22 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + time24 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time26 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_matvec_homo(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel): +def test_jitconn_matvec_uniform(shape, transpose, outdim_parallel): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform_taichi(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) + time22 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) + time24 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() + time26 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_homo_taichi(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_homo(vector, homo_data, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_matvec_uniform(events, w_low, w_high, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel): +def test_jitconn_matvec_normal(shape, transpose, outdim_parallel): rng = bm.random.RandomState(seed=seed) events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_uniform_taichi(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal_taichi(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time19 = time.time() + + + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_uniform(events, w_low=w_low, w_high=w_high, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + time22 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time23 = time.time() - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + time24 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time25 = time.time() - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + time26 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time27 = time.time() -def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel): - rng = bm.random.RandomState(seed=seed) - events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.jitconn.mv_prob_normal_taichi(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.jitconn.mv_prob_normal(events, w_mu=w_mu, w_sigma=w_sigma, conn_prob=conn_prob, shape=shape, seed=seed, outdim_parallel=outdim_parallel, transpose=transpose)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(jitconn_matvec_normal(events, w_mu, w_sigma, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - - -def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel): - print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) - if _type == 'homo': - return test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel) - elif _type == 'uniform': - return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel) - elif _type == 'normal': - return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel) - else: - raise ValueError - + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 -def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel): +def test_jitconn_matvec(shape, _type, transpose, outdim_parallel): print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) if _type == 'homo': - return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel) + return test_jitconn_matvec_homo(shape, transpose, outdim_parallel) elif _type == 'uniform': - return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel) + return test_jitconn_matvec_uniform(shape, transpose, outdim_parallel) elif _type == 'normal': - return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel) + return test_jitconn_matvec_normal(shape, transpose, outdim_parallel) else: raise ValueError PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe -df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', - 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', +df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', 'bool_event', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) ### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): @@ -657,11 +530,15 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel): for _outdim_parallel in outdim_parallel: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel) # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, + df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, bool_event, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/jitconn_matvec_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): @@ -671,24 +548,13 @@ def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel): for _outdim_parallel in outdim_parallel: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_jitconn_matvec((shape1, shape2), _type, _transpose, _outdim_parallel) # append to dataframe - df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, + df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, bool_event, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/jitconn_matvec_gpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) diff --git a/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py new file mode 100644 index 000000000..165c9b19b --- /dev/null +++ b/brainpy/_src/math/jitconn/tests/jitconn_matvec_taichi_VS_jitconn_matvec_grad.py @@ -0,0 +1,736 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('cpu') + +seed = 1234 + +shape = [ + 1000, + 2500, + 5000, + 10000, + 25000, + 37500, + 50000 + ] +bool_event = False +types = [ + 'homo', + 'uniform', + 'normal' + ] +transpose = [ + True, + False + ] +outdim_parallel = [ + True, + False, + ] +conn_prob = 0.05 +homo_data = 1. +w_low = 0. +w_high = 1. +w_mu = 0. +w_sigma = 0.1 + +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + +print(bm.get_platform()) + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs)[0] + return r.sum() + + return func + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_homo_taichi), argnums=0)( + vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +@partial(jax.jit, static_argnums=(4, 5, 6)) +def jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_homo), argnums=0)( + vector, homo_data, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_uniform_taichi_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform_taichi), argnums=0)( + vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_uniform_grad(vector, w_low, w_high, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_uniform), argnums=0)( + vector, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_normal_taichi_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_normal_taichi), argnums=0)( + vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +@partial(jax.jit, static_argnums=(5, 6, 7)) +def jitconn_matvec_normal_grad(vector, w_mu, w_sigma, conn_prob, seed, shape, transpose, outdim_parallel): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.jitconn.mv_prob_normal), argnums=0)( + vector, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel + ) + return r + +def test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time9 = time.time() + + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) +# print(result1[0]) +# print(result2) +# print(groundtruth - result1[0]) +# print(groundtruth - result2) + + # print(result1[0] - result2) + # print(bm.allclose(groundtruth, result1[0])) + # print(bm.allclose(groundtruth, result2)) + # assert bm.allclose(result1[0], result2) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time9 = time.time() + + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) +# print(result1[0]) +# print(result2) +# print(groundtruth - result1[0]) +# print(groundtruth - result2) + + # print(result1[0] - result2) + # print(bm.allclose(groundtruth, result1[0])) + # print(bm.allclose(groundtruth, result2)) + # assert bm.allclose(result1[0], result2) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_cpu_1: ', brainpy_time1, 'ms') + print('brainpylib_cpu_2: ', brainpy_time2, 'ms') + print('brainpylib_cpu_3: ', brainpy_time3, 'ms') + print('brainpylib_cpu_4: ', brainpy_time4, 'ms') + print('brainpylib_cpu_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + vector = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_homo_taichi_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time9 = time.time() + + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) +# print(result1[0]) +# print(result2) +# print(groundtruth - result1[0]) +# print(groundtruth - result2) + + # print(result1[0] - result2) + # print(bm.allclose(groundtruth, result1[0])) + # print(bm.allclose(groundtruth, result2)) + # assert bm.allclose(result1[0], result2) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_homo_grad(vector, homo_data, conn_prob, seed, shape=shape, outdim_parallel=outdim_parallel, transpose=transpose)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_2: ', brainpy_time2, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_4: ', brainpy_time4, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_uniform_taichi_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time9 = time.time() + + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) +# print(result1[0]) +# print(result2) +# print(groundtruth - result1[0]) +# print(groundtruth - result2) + + # print(result1[0] - result2) + # print(bm.allclose(groundtruth, result1[0])) + # print(bm.allclose(groundtruth, result2)) + # assert bm.allclose(result1[0], result2) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_uniform_grad(events, w_low, w_high, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_2: ', brainpy_time2, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_4: ', brainpy_time4, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + +def test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel): + rng = bm.random.RandomState(seed=seed) + events = bm.as_jax(rng.random(shape[0] if transpose else shape[1])) + + # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) + + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + # time.sleep(2) + + time0 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time1 = time.time() + # time.sleep(2) + + time2 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time3 = time.time() + # time.sleep(2) + + time4 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time5 = time.time() + # time.sleep(2) + + time6 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time7 = time.time() + + time8 = time.time() + result1 = jax.block_until_ready(jitconn_matvec_normal_taichi_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time9 = time.time() + + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) +# print(result1[0]) +# print(result2) +# print(groundtruth - result1[0]) +# print(groundtruth - result2) + + # print(result1[0] - result2) + # print(bm.allclose(groundtruth, result1[0])) + # print(bm.allclose(groundtruth, result2)) + # assert bm.allclose(result1[0], result2) + + time12 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time13 = time.time() + # time.sleep(2) + + time14 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time15 = time.time() + # time.sleep(2) + + time16 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time17 = time.time() + # time.sleep(2) + + time18 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time19 = time.time() + + time20 = time.time() + result2 = jax.block_until_ready(jitconn_matvec_normal_grad(events, w_mu, w_sigma, conn_prob, seed, shape=shape, transpose=transpose, outdim_parallel=outdim_parallel)) + time21 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + brainpy_time1 = (time13 - time12) * 1000 + brainpy_time2 = (time15 - time14) * 1000 + brainpy_time3 = (time17 - time16) * 1000 + brainpy_time4 = (time19 - time18) * 1000 + brainpy_time5 = (time21 - time20) * 1000 + + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_2: ', taichi_aot_time2, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_4: ', taichi_aot_time4, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_2: ', brainpy_time2, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_4: ', brainpy_time4, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + # assert(jnp.allclose(result1[0], result2)) + + speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ + (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + + +def test_jitconn_matvec_cpu(shape, _type, transpose, outdim_parallel): + print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) + if _type == 'homo': + return test_jitconn_matvec_homo_cpu(shape, transpose, outdim_parallel) + elif _type == 'uniform': + return test_jitconn_matvec_uniform_cpu(shape, transpose, outdim_parallel) + elif _type == 'normal': + return test_jitconn_matvec_normal_cpu(shape, transpose, outdim_parallel) + else: + raise ValueError + + +def test_jitconn_matvec_gpu(shape, _type, transpose, outdim_parallel): + print('shape: ', shape, ' type: ', _type, ' transpose: ', transpose, ' outdim_parallel: ', outdim_parallel) + if _type == 'homo': + return test_jitconn_matvec_homo_gpu(shape, transpose, outdim_parallel) + elif _type == 'uniform': + return test_jitconn_matvec_uniform_gpu(shape, transpose, outdim_parallel) + elif _type == 'normal': + return test_jitconn_matvec_normal_gpu(shape, transpose, outdim_parallel) + else: + raise ValueError + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame(columns=['shape[0]', 'shape[1]', 'backend', 'type', 'transpose', 'outdim_parallel', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'speedup']) + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape1 in shape: + for shape2 in shape: + for _type in types: + for _outdim_parallel in outdim_parallel: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_cpu((shape1, shape2), _type, _transpose, _outdim_parallel) + # append to dataframe + df.loc[df.shape[0]] = [shape1, shape2, 'cpu', _type, _transpose, _outdim_parallel, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + df.to_csv(f'{PATH}/jitconn_matvec_grad_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape1 in shape: + for shape2 in shape: + for _type in types: + for _outdim_parallel in outdim_parallel: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_jitconn_matvec_gpu((shape1, shape2), _type, _transpose, _outdim_parallel) + # append to dataframe + df.loc[df.shape[0]] = [shape1, shape2, 'gpu', _type, _transpose, _outdim_parallel, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + df.to_csv(f'{PATH}/jitconn_matvec_grad_gpu.csv', index=False) diff --git a/brainpy/_src/math/op_register/taichi_aot_based.py b/brainpy/_src/math/op_register/taichi_aot_based.py index 06d0508a1..ab7b98011 100644 --- a/brainpy/_src/math/op_register/taichi_aot_based.py +++ b/brainpy/_src/math/op_register/taichi_aot_based.py @@ -2,6 +2,7 @@ import inspect import os import pathlib +import platform import re from functools import partial, reduce from typing import Any, Sequence @@ -11,8 +12,8 @@ from jax.interpreters import xla from jax.lib import xla_client -from .utils import _shape_to_layout from brainpy._src.dependency_check import import_taichi, import_brainpylib_cpu_ops, import_brainpylib_gpu_ops +from .utils import _shape_to_layout ### UTILS ### @@ -36,33 +37,42 @@ def encode_md5(source: str) -> str: return md5.hexdigest() +# TODO +# not a very good way # get source with dependencies def get_source_with_dependencies(func, visited=None): if visited is None: visited = set() source = inspect.getsource(func) - if func in visited: return '' visited.add(func) - module = inspect.getmodule(func) - dependent_funcs = re.findall(r'(\w+)\(', source) for func_name in dependent_funcs: dependent_func = getattr(module, func_name, None) if callable(dependent_func): source += get_source_with_dependencies(dependent_func, visited) - return source +# check if Metal is supported +def is_metal_supported(): + # first check if we are on macOS + if platform.system() != 'Darwin': + return False + if platform.processor() != 'arm': + return False + return True + + ### VARIABLES ### home_path = get_home_dir() kernels_aot_path = os.path.join(home_path, '.brainpy', 'kernels') +is_metal_device = is_metal_supported() # check if a kernel exists in the database @@ -107,7 +117,9 @@ def _array_to_field(dtype, shape) -> Any: elif dtype == np.float64: dtype = ti.float64 else: - raise TypeError + raise NotImplementedError(f'Currently we do not support dtype {dtype} in Taichi. ' + f'If you think it is necessary, please open an issue at ' + f'https://github.com/brainpy/BrainPy/issues/new') return ti.field(dtype=dtype, shape=shape) @@ -122,11 +134,16 @@ def _build_kernel( ti = import_taichi() # init arch - arch = None if device == 'cpu': - arch = ti.x64 + if is_metal_device: + arch = ti.arm64 + device = 'arm64' + else: + arch = ti.x64 elif device == 'gpu': arch = ti.cuda + else: + raise ValueError(f'Unknown device: {device}') ti.init(arch=arch) @@ -328,9 +345,14 @@ def _compile_kernel(kernel, c, platform, *ins, **kwargs): def _taichi_cpu_translation_rule(kernel, c, *ins, **kwargs): in_out_info = _compile_kernel(kernel, c, 'cpu', *ins, **kwargs) ins = [xla_client.ops.Constant(c, v) for v in in_out_info] + list(ins) + if is_metal_device: + fn = b'taichi_kernel_aot_call_cpu_arm64' + else: + fn = b'taichi_kernel_aot_call_cpu' + return xla_client.ops.CustomCallWithLayout( c, - b'taichi_kernel_aot_call_cpu', + fn, operands=ins, operand_shapes_with_layout=tuple(c.get_shape(value) for value in ins), shape_with_layout=xla_client.Shape.tuple_shape( diff --git a/brainpy/_src/math/sparse/_csr_mv_taichi.py b/brainpy/_src/math/sparse/_csr_mv_taichi.py index 73812d44b..cd09af08e 100644 --- a/brainpy/_src/math/sparse/_csr_mv_taichi.py +++ b/brainpy/_src/math/sparse/_csr_mv_taichi.py @@ -61,8 +61,8 @@ def _sparse_csr_matvec_homo_cpu(values: ti.types.ndarray(ndim=1), for row_i in range(row_ptr.shape[0] - 1): r = 0. for j in range(row_ptr[row_i], row_ptr[row_i + 1]): - r += value * vector[col_indices[j]] - out[row_i] = r + r += vector[col_indices[j]] + out[row_i] = r * value @ti.kernel @@ -115,9 +115,9 @@ def _sparse_csr_matvec_homo_gpu(values: ti.types.ndarray(ndim=1), j = row_ptr[row_i] + index end_index = row_ptr[row_i + 1] while j < end_index: - r += value * vector[col_indices[j]] + r += vector[col_indices[j]] j += 32 - out[row_i] += r # TODO: warp-level primitive + out[row_i] += value * r @ti.kernel @@ -285,4 +285,4 @@ def _define_op(cpu_kernel, gpu_kernel): # no transpose heter _csr_matvec_heter_p = _define_op(cpu_kernel=_sparse_csr_matvec_heter_cpu, - gpu_kernel=_sparse_csr_matvec_heter_gpu) + gpu_kernel=_sparse_csr_matvec_heter_gpu) \ No newline at end of file diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py index 8ff6e1481..1db246212 100644 --- a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py +++ b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv.py @@ -12,7 +12,7 @@ import pandas as pd import taichi as ti -bm.set_platform('gpu') +bm.set_platform('cpu') s = [1000, 5000, 10000, 15000, 20000, 25000, 30000] p = [0.1, 0.2, 0.3, 0.4, 0.5] @@ -38,520 +38,213 @@ ] method = 'cusparse' -print(bm.get_platform()) - -def test_sparse_csrmv_cpu(shape, values_type, events_type, transpose): - rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.3)(*shape).require('pre2post') - vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 - weight = 1. - - if values_type == 'heter': - heter_data = bm.ones(indices.shape) * weight - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time7 = time.time() +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 - time8 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time9 = time.time() +print(bm.get_platform()) - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) +@partial(jax.jit, static_argnums=(4, 5)) +def csrmv_taichi(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)[0] + return r - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_sparse_csrmv_gpu(shape, values_type, events_type, transpose): +@partial(jax.jit, static_argnums=(4, 5)) +def csrmv(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose) + return r + +def test_sparse_csrmv(shape, values_type, events_type, transpose): rng = bm.random.RandomState(seed=1234) - indices, indptr = bp.conn.FixedProb(0.3)(*shape).require('pre2post') + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 weight = 1. + + if events_type == 'float': + vector = vector.astype(bm.float32) if values_type == 'heter': heter_data = bm.ones(indices.shape) * weight weight = heter_data - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - - - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - # time.sleep(2) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time0 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time1 = time.time() - # time.sleep(2) time2 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time3 = time.time() - # time.sleep(2) time4 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time5 = time.time() - # time.sleep(2) time6 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time7 = time.time() time8 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time9 = time.time() - - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - + time10 = time.time() + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time11 = time.time() + time12 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time13 = time.time() - # time.sleep(2) - + time14 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time15 = time.time() - # time.sleep(2) - + time16 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time17 = time.time() - # time.sleep(2) - + time18 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) - time21 = time.time() - - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - - # assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - - -def test_sparse_csrmv_square_cpu(s, p, values_type, events_type, transpose): - print('s: ', s, 'p: ', p) - k = int(s * p) - rng = bm.random.RandomState(seed=1234) - # init - indices = bm.random.randint(0, s, (s, k)) - vector = rng.random(s) - weight = jnp.array([1.0]) - csr_indices = indices.flatten() - csr_indptr = np.cumsum(np.insert(np.ones(s, dtype=int) * k, 0, 0)) - - pre_indices = np.repeat(np.arange(s), k) - dense = np.zeros((s, s)) - dense[pre_indices, csr_indices] = 1.0 - - if values_type == 'heter': - heter_data = bm.as_jax(rng.random(csr_indices.shape)) - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # time.sleep(2) - - time0 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time1 = time.time() - # time.sleep(2) - - time2 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time3 = time.time() - # time.sleep(2) - - time4 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) -# print(result1[0]) -# print(result2) -# print(groundtruth - result1[0]) -# print(groundtruth - result2) - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time13 = time.time() - # time.sleep(2) - time14 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time19 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time20 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) time21 = time.time() - taichi_aot_time1 = (time1 - time0) * 1000 - taichi_aot_time2 = (time3 - time2) * 1000 - taichi_aot_time3 = (time5 - time4) * 1000 - taichi_aot_time4 = (time7 - time6) * 1000 - taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - - print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') - print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') - print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_cpu_1: ', brainpy_time1, 'ms') - print('brainpylib_cpu_2: ', brainpy_time2, 'ms') - print('brainpylib_cpu_3: ', brainpy_time3, 'ms') - print('brainpylib_cpu_4: ', brainpy_time4, 'ms') - print('brainpylib_cpu_5: ', brainpy_time5, 'ms') - assert(jnp.allclose(result1[0], result2)) - - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 - - return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup - -def test_sparse_csrmv_square_gpu(s, p, values_type, events_type, transpose): - print('s: ', s, 'p: ', p) - k = int(s * p) - bm.random.seed(1234) - rng = bm.random.RandomState(seed=1234) - # init - indices = bm.random.randint(0, s, (s, k)) - vector = rng.random(s) - weight = jnp.array([1.0]) - csr_indices = indices.flatten() - csr_indptr = np.cumsum(np.insert(np.ones(s, dtype=int) * k, 0, 0)) - pre_indices = np.repeat(np.arange(s), k) - dense = np.zeros((s, s)) - dense[pre_indices, csr_indices] = 1.0 - - if values_type == 'heter': - heter_data = bm.as_jax(rng.random(csr_indices.shape)) - weight = heter_data - - # groundtruth = bm.as_jax(vector, dtype=float) @ bm.as_jax(dense) - - - - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - # time.sleep(2) + time22 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time23 = time.time() - time0 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time1 = time.time() - # time.sleep(2) + time24 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time25 = time.time() - time2 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time3 = time.time() - # time.sleep(2) + time26 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time27 = time.time() - time4 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time5 = time.time() - # time.sleep(2) - - time6 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time7 = time.time() - - time8 = time.time() - result1 = jax.block_until_ready(bm.sparse.csrmv_taichi(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose)) - time9 = time.time() - - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - # print('--------------------result1[0]------------------') - # print(result1[0]) - # print('--------------------result2------------------') - # print(result2) - # print('--------------------gt - result1[0]------------------') - # print(groundtruth - result1[0]) - # print('--------------------gt - result2------------------') - # print(groundtruth - result2) + time28 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time29 = time.time() - # print(result1[0] - result2) - # print(bm.allclose(groundtruth, result1[0])) - # print(bm.allclose(groundtruth, result2)) - # assert bm.allclose(result1[0], result2) - - time12 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - time13 = time.time() - # time.sleep(2) - - time14 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - time15 = time.time() - # time.sleep(2) - - time16 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - time17 = time.time() - # time.sleep(2) - - time18 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - time19 = time.time() - - time20 = time.time() - result2 = jax.block_until_ready(bm.sparse.csrmv(weight, csr_indices, csr_indptr, vector, shape=(s, s), transpose=transpose, method=method)) - time21 = time.time() + time30 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(csrmv(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time39 = time.time() taichi_aot_time1 = (time1 - time0) * 1000 taichi_aot_time2 = (time3 - time2) * 1000 taichi_aot_time3 = (time5 - time4) * 1000 taichi_aot_time4 = (time7 - time6) * 1000 taichi_aot_time5 = (time9 - time8) * 1000 - brainpy_time1 = (time13 - time12) * 1000 - brainpy_time2 = (time15 - time14) * 1000 - brainpy_time3 = (time17 - time16) * 1000 - brainpy_time4 = (time19 - time18) * 1000 - brainpy_time5 = (time21 - time20) * 1000 - + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) print('taichi_aot_1: ', taichi_aot_time1, 'ms') - print('taichi_aot_2: ', taichi_aot_time2, 'ms') print('taichi_aot_3: ', taichi_aot_time3, 'ms') - print('taichi_aot_4: ', taichi_aot_time4, 'ms') print('taichi_aot_5: ', taichi_aot_time5, 'ms') - print('brainpylib_gpu_1: ', brainpy_time1, 'ms') - print('brainpylib_gpu_2: ', brainpy_time2, 'ms') - print('brainpylib_gpu_3: ', brainpy_time3, 'ms') - print('brainpylib_gpu_4: ', brainpy_time4, 'ms') - print('brainpylib_gpu_5: ', brainpy_time5, 'ms') - - # assert(jnp.allclose(result1[0], result2)) + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') - speedup = (brainpy_time1 + brainpy_time2 + brainpy_time3 + brainpy_time4 + brainpy_time5) / \ - (taichi_aot_time1 + taichi_aot_time2 + taichi_aot_time3 + taichi_aot_time4 + taichi_aot_time5) - 1 return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ - brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, speedup + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + PATH = os.path.dirname(os.path.abspath(__file__)) # init dataframe df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', - 'speedup']) - -### SQUARE MATRIX -# if (bm.get_platform() == 'cpu'): -# for _s in s: -# for _p in p: -# for _values_type in values_type: -# for _events_type in events_type: -# for _transpose in transpose: -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_sparse_csrmv_square_cpu(_s, _p, _values_type, _events_type, _transpose) -# # append to dataframe -# df.loc[df.shape[0]] = [_s, _p, 'cpu', _values_type, _events_type, _transpose, -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] -# df.to_csv(f'{PATH}/csrmv_square_cpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# for _values_type in values_type: -# for _events_type in events_type: -# for _transpose in transpose: -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_sparse_csrmv_square_gpu(_s, _p, _values_type, _events_type, _transpose) -# # append to dataframe -# df.loc[df.shape[0]] = [_s, _p, 'gpu', _values_type, _events_type, _transpose, -# taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, -# brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] -# df.to_csv(f'{PATH}/csrmv_square_gpu.csv', index=False) + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) ### RECTANGULAR MATRIX if (bm.get_platform() == 'cpu'): for shape1 in shape: for shape2 in shape: - for _values_type in values_type: + for _values_type in values_type: for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_sparse_csrmv_cpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.3 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/csrmv_cpu.csv', index=False) if (bm.get_platform() == 'gpu'): for shape1 in shape: for shape2 in shape: - for _values_type in values_type: + for _values_type in values_type: for _events_type in events_type: for _transpose in transpose: taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup = test_sparse_csrmv_gpu((shape1, shape2), _values_type, _events_type, _transpose) + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) # append to dataframe - df.loc[df.shape[0]] = [(shape1, shape2), 0.3 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, - brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, speedup] + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] df.to_csv(f'{PATH}/csrmv_gpu.csv', index=False) - -# if (bm.get_platform() == 'gpu'): -# for _s in s: -# for _p in p: -# taichi_aot_avg_time = test_event_ell_gpu_taichi(_s, _p) -# df.loc[df.shape[0]] = [_s, _p, 'gpu', block_dim, taichi_aot_avg_time, 0] -# df.to_csv('event_ell_gpu.csv', index=False) - - # df = pd.read_csv('event_ell_gpu.csv') - # for _s in s: - # for _p in p: - # brainpy_avg_time = test_event_ell_gpu_brainpylib(_s, _p) - # # 找到对应的行 - # df.loc[(df['s'] == _s) & (df['p'] == _p) & (df['backend'] == 'gpu'), 'brainpy avg time(ms)'] = brainpy_avg_time - # df.to_csv('event_ell_gpu.csv', index=False) diff --git a/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py new file mode 100644 index 000000000..d902c9395 --- /dev/null +++ b/brainpy/_src/math/sparse/tests/csrmv_taichi_VS_csrmv_grad.py @@ -0,0 +1,273 @@ +# from jax_taichi import jax_taichi_call + +import time +from functools import partial +import os + +import brainpy as bp +import brainpy.math as bm +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +import taichi as ti + +bm.set_platform('cpu') + +s = [1000, + 5000, + 10000, + 15000, + 20000, + 25000, + 30000] +p = [0.1, 0.2, 0.3, 0.4, 0.5] + +shape = [ + 1000, + 2500, + 5000, + 10000, + 25000, + 37500, + 50000 +] + +values_type = [ + 'homo', + 'heter' + ] +events_type = ['float'] +transpose = [ + True, + False + ] +method = 'cusparse' + +ITERATION = 100 +if bm.get_platform() == 'cpu': + ITERATION = 10 + +print(bm.get_platform()) + +def sum_op(op): + def func(*args, **kwargs): + r = op(*args, **kwargs) + return r.sum() + + return func + + +def sum_op2(op): + def func(*args, **kwargs): + r = op(*args, **kwargs)[0] + return r.sum() + + return func + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmv_taichi_grad(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op2(bm.sparse.csrmv_taichi), argnums=3)( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + return r + +@partial(jax.jit, static_argnums=(4, 5)) +def csrmv_grad(weight, indices, indptr, vector, shape, transpose): + r = 0 + for i in range(ITERATION): + r += jax.grad(sum_op(bm.sparse.csrmv), argnums=3)( + weight, indices, indptr, vector.astype(float), shape=shape, transpose=transpose) + return r + +def test_sparse_csrmv(shape, values_type, events_type, transpose): + rng = bm.random.RandomState(seed=1234) + indices, indptr = bp.conn.FixedProb(0.05, seed=1234, allow_multi_conn=True)(*shape).require('pre2post') + vector = rng.random(shape[0] if transpose else shape[1]) < 0.1 + weight = 1. + + + if events_type == 'float': + vector = vector.astype(bm.float32) + if values_type == 'heter': + heter_data = bm.ones(indices.shape) * weight + weight = heter_data + + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + + time0 = time.time() + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time1 = time.time() + + time2 = time.time() + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time3 = time.time() + + time4 = time.time() + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time5 = time.time() + + time6 = time.time() + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time7 = time.time() + + time8 = time.time() + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time9 = time.time() + + time10 = time.time() + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time11 = time.time() + + time12 = time.time() + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time13 = time.time() + + time14 = time.time() + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time15 = time.time() + + time16 = time.time() + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time17 = time.time() + + time18 = time.time() + result = jax.block_until_ready(csrmv_taichi_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time19 = time.time() + + + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + + time20 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time21 = time.time() + + time22 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time23 = time.time() + + time24 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time25 = time.time() + + time26 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time27 = time.time() + + time28 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time29 = time.time() + + time30 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time31 = time.time() + + time32 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time33 = time.time() + + time34 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time35 = time.time() + + time36 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time37 = time.time() + + time38 = time.time() + result = jax.block_until_ready(csrmv_grad(weight, indices, indptr, vector, shape=shape, transpose=transpose)) + time39 = time.time() + + taichi_aot_time1 = (time1 - time0) * 1000 + taichi_aot_time2 = (time3 - time2) * 1000 + taichi_aot_time3 = (time5 - time4) * 1000 + taichi_aot_time4 = (time7 - time6) * 1000 + taichi_aot_time5 = (time9 - time8) * 1000 + taichi_aot_time6 = (time11 - time10) * 1000 + taichi_aot_time7 = (time13 - time12) * 1000 + taichi_aot_time8 = (time15 - time14) * 1000 + taichi_aot_time9 = (time17 - time16) * 1000 + taichi_aot_time10 = (time19 - time18) * 1000 + brainpy_time1 = (time21 - time20) * 1000 + brainpy_time2 = (time23 - time22) * 1000 + brainpy_time3 = (time25 - time24) * 1000 + brainpy_time4 = (time27 - time26) * 1000 + brainpy_time5 = (time29 - time28) * 1000 + brainpy_time6 = (time31 - time30) * 1000 + brainpy_time7 = (time33 - time32) * 1000 + brainpy_time8 = (time35 - time34) * 1000 + brainpy_time9 = (time37 - time36) * 1000 + brainpy_time10 = (time39 - time38) * 1000 + print('shape: ', shape, 'values_type: ', values_type, 'events_type: ', events_type, 'transpose: ', transpose) + print('taichi_aot_1: ', taichi_aot_time1, 'ms') + print('taichi_aot_3: ', taichi_aot_time3, 'ms') + print('taichi_aot_5: ', taichi_aot_time5, 'ms') + print('taichi_aot_7: ', taichi_aot_time7, 'ms') + print('taichi_aot_9: ', taichi_aot_time9, 'ms') + print('brainpylib_1: ', brainpy_time1, 'ms') + print('brainpylib_3: ', brainpy_time3, 'ms') + print('brainpylib_5: ', brainpy_time5, 'ms') + print('brainpylib_7: ', brainpy_time7, 'ms') + print('brainpylib_9: ', brainpy_time9, 'ms') + + + return taichi_aot_time1, taichi_aot_time2, taichi_aot_time3, taichi_aot_time4, taichi_aot_time5,\ + taichi_aot_time6, taichi_aot_time7, taichi_aot_time8, taichi_aot_time9, taichi_aot_time10,\ + brainpy_time1, brainpy_time2, brainpy_time3, brainpy_time4, brainpy_time5, \ + brainpy_time6, brainpy_time7, brainpy_time8, brainpy_time9, brainpy_time10 + +PATH = os.path.dirname(os.path.abspath(__file__)) + +# init dataframe +df = pd.DataFrame(columns=['s', 'p', 'shape[0]', 'shape[1]', 'backend', 'values type', 'events type', 'transpose', + 'taichi aot time1(ms)', 'taichi aot time2(ms)', 'taichi aot time3(ms)', 'taichi aot time4(ms)', 'taichi aot time5(ms)', + 'taichi aot time6(ms)', 'taichi aot time7(ms)', 'taichi aot time8(ms)', 'taichi aot time9(ms)', 'taichi aot time10(ms)', + 'brainpy time1(ms)', 'brainpy time2(ms)', 'brainpy time3(ms)', 'brainpy time4(ms)', 'brainpy time5(ms)', + 'brainpy time6(ms)', 'brainpy time7(ms)', 'brainpy time8(ms)', 'brainpy time9(ms)', 'brainpy time10(ms)']) + + +### RECTANGULAR MATRIX +if (bm.get_platform() == 'cpu'): + for shape1 in shape: + for shape2 in shape: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'cpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/csrmv_grad_cpu.csv', index=False) + +if (bm.get_platform() == 'gpu'): + for shape1 in shape: + for shape2 in shape: + for _values_type in values_type: + for _events_type in events_type: + for _transpose in transpose: + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5,\ + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10,\ + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, \ + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10 = test_sparse_csrmv((shape1, shape2), _values_type, _events_type, _transpose) + # append to dataframe + df.loc[df.shape[0]] = [(shape1, shape2), 0.5 , shape1, shape2, 'gpu', _values_type, _events_type, _transpose, + taichi_aot_time_1, taichi_aot_time_2, taichi_aot_time_3, taichi_aot_time_4, taichi_aot_time_5, + taichi_aot_time_6, taichi_aot_time_7, taichi_aot_time_8, taichi_aot_time_9, taichi_aot_time_10, + brainpy_time_1, brainpy_time_2, brainpy_time_3, brainpy_time_4, brainpy_time_5, + brainpy_time_6, brainpy_time_7, brainpy_time_8, brainpy_time_9, brainpy_time_10] + df.to_csv(f'{PATH}/csrmv_grad_gpu.csv', index=False)