From 4b0d0f3c53814da09e45e61fe0f236ad6f2582fb Mon Sep 17 00:00:00 2001
From: chaoming
Date: Tue, 23 May 2023 20:27:25 +0800
Subject: [PATCH 1/2] [bug] fix `brainpy.connect.FixedProb` bug

---
 brainpy/_src/connect/random_conn.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/brainpy/_src/connect/random_conn.py b/brainpy/_src/connect/random_conn.py
index 239bb41cf..cbcdeabaf 100644
--- a/brainpy/_src/connect/random_conn.py
+++ b/brainpy/_src/connect/random_conn.py
@@ -131,7 +131,7 @@ def build_mat(self):
     mat = (self._jaxrand.uniform(size=(self.pre_num, self.post_num)) < self.prob) * pre_state
     mat = bm.asarray(mat)
     if not self.include_self:
-      mat = bm.fill_diagonal(mat, False)
+      bm.fill_diagonal(mat, False)
     return mat.astype(MAT_DTYPE)


From 20819c4886f3ef5a9def878b44694545738ac77c Mon Sep 17 00:00:00 2001
From: chaoming
Date: Tue, 23 May 2023 20:42:11 +0800
Subject: [PATCH 2/2] [bug] fix BSR sparse operators and rename event csrmv
 tests

---
 brainpy/_src/layers/linear.py                 |  18 +
 .../{test_csrmv.py => test_event_csrmv.py}    |   0
 ...t_csrmv_gpu.py => test_event_csrmv_gpu.py} |   4 +-
 brainpy/_src/math/sparse/_bsr_mm.py           | 801 +++++++++---------
 brainpy/_src/math/sparse/_bsr_mv.py           |  80 +-
 5 files changed, 450 insertions(+), 453 deletions(-)
 rename brainpy/_src/math/event/tests/{test_csrmv.py => test_event_csrmv.py} (100%)
 rename brainpy/_src/math/event/tests/{test_csrmv_gpu.py => test_event_csrmv_gpu.py} (71%)

diff --git a/brainpy/_src/layers/linear.py b/brainpy/_src/layers/linear.py
index 3d0c2025d..492e8bd44 100644
--- a/brainpy/_src/layers/linear.py
+++ b/brainpy/_src/layers/linear.py
@@ -207,3 +207,21 @@ def __init__(self, *args, **kwargs) -> None:
 
   def update(self, x):
     return x
+
+
+class CSRLinear(Layer):
+  pass
+
+
+class CSCLinear(Layer):
+  pass
+
+
+class BSRLinear(Layer):
+  pass
+
+
+class MatLinear(Layer):
+  pass
+
+
diff --git a/brainpy/_src/math/event/tests/test_csrmv.py b/brainpy/_src/math/event/tests/test_event_csrmv.py
similarity index 100%
rename from brainpy/_src/math/event/tests/test_csrmv.py
rename to brainpy/_src/math/event/tests/test_event_csrmv.py
diff --git a/brainpy/_src/math/event/tests/test_csrmv_gpu.py b/brainpy/_src/math/event/tests/test_event_csrmv_gpu.py
similarity index 71%
rename from brainpy/_src/math/event/tests/test_csrmv_gpu.py
rename to brainpy/_src/math/event/tests/test_event_csrmv_gpu.py
index 5e2876968..a5b8df152 100644
--- a/brainpy/_src/math/event/tests/test_csrmv_gpu.py
+++ b/brainpy/_src/math/event/tests/test_event_csrmv_gpu.py
@@ -4,12 +4,12 @@
 import jax
 import pytest
-import test_csrmv
+import test_event_csrmv
 
 if jax.default_backend() != 'gpu':
   pytest.skip("No gpu available.", allow_module_level=True)
 
-class Test_event_csr_matvec_GPU(test_csrmv.Test_event_csr_matvec):
+class Test_event_csr_matvec_GPU(test_event_csrmv.Test_event_csr_matvec):
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs, platform='gpu')
diff --git a/brainpy/_src/math/sparse/_bsr_mm.py b/brainpy/_src/math/sparse/_bsr_mm.py
index 408991cdc..caba80a9b 100644
--- a/brainpy/_src/math/sparse/_bsr_mm.py
+++ b/brainpy/_src/math/sparse/_bsr_mm.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 
 from functools import partial
+from typing import Union, Tuple
 
 import jax.lax
 import numba
 import numpy as np
 from jax import numpy as jnp
 from jax.core import Primitive, ShapedArray
 from jax.interpreters import ad, xla
 from jax.lib import xla_client
 
 from brainpy._src.math.interoperability import as_jax
 from brainpy._src.math.op_registers import register_general_batching, XLACustomOp
 from brainpy.errors import GPUOperatorNotFound
 
 try:
-  from brainpylib import gpu_ops
+  from brainpylib import gpu_ops
 except ImportError:
-  gpu_ops = None
+  gpu_ops = None
 
 __all__ = [
-  'blocksparse_matmat',
-  'blocksparse_matmat_back'
+  'bcsrmm',
 ]
 
 
 def get_mask(dense_b, blockshape, blockcount):
-  mask = 
jnp.zeros(blockcount[0] * blockcount[1], dtype=jnp.bool_) + mask = jnp.zeros(blockcount[0] * blockcount[1], dtype=jnp.bool_) - for i in range(blockcount[1]): - for j in range(blockcount[0]): - if jnp.abs(dense_b[i * blockshape[1]: (i + 1) * blockshape[1], - j * blockshape[0]: (j + 1) * blockshape[0]]).sum() != 0: - mask = mask.at[i * blockcount[0] + j].set(True) - mask = mask.reshape(blockcount[1], blockcount[0]) - return mask + for i in range(blockcount[1]): + for j in range(blockcount[0]): + if jnp.abs(dense_b[i * blockshape[1]: (i + 1) * blockshape[1], + j * blockshape[0]: (j + 1) * blockshape[0]]).sum() != 0: + mask = mask.at[i * blockcount[0] + j].set(True) + mask = mask.reshape(blockcount[1], blockcount[0]) + return mask def get_mask_from_ptr_indices(ptr, indices, blockcount): - - mask = jnp.zeros((blockcount[1], blockcount[0]), dtype=jnp.bool_) - for idx, indice in enumerate(indices): - row_index = 0 - for ptr_ in ptr[1:]: - if idx < ptr_: - break - row_index += 1 - mask = mask.at[row_index, indice].set(True) - return mask + mask = jnp.zeros((blockcount[1], blockcount[0]), dtype=jnp.bool_) + for idx, indice in enumerate(indices): + row_index = 0 + for ptr_ in ptr[1:]: + if idx < ptr_: + break + row_index += 1 + mask = mask.at[row_index, indice].set(True) + return mask def get_data(dense_b, mask, blockshape, blockcount, n_blocks): - data = jnp.zeros( - shape=(n_blocks * blockshape[1], blockshape[0]), - dtype=jnp.float32 - ) - - assignment_count = 0 - for i in range(blockcount[1]): - for j in range(blockcount[0]): - if mask[i, j]: - data = data.at[assignment_count * blockshape[1]: (assignment_count + 1) * blockshape[1], - :].set(dense_b[i * blockshape[1]: (i + 1) * blockshape[1], - j * blockshape[0]: (j + 1) * blockshape[0]]) - assignment_count += 1 - return data + data = jnp.zeros( + shape=(n_blocks * blockshape[1], blockshape[0]), + dtype=jnp.float32 + ) + + assignment_count = 0 + for i in range(blockcount[1]): + for j in range(blockcount[0]): + if mask[i, j]: + data = data.at[assignment_count * blockshape[1]: (assignment_count + 1) * blockshape[1], + :].set(dense_b[i * blockshape[1]: (i + 1) * blockshape[1], + j * blockshape[0]: (j + 1) * blockshape[0]]) + assignment_count += 1 + return data def get_ptr_indices(mask, blockcount, n_blocks, block_ptr=None): - nnz = jnp.nonzero(mask) + nnz = jnp.nonzero(mask) - if block_ptr is None: - block_ptr = jnp.arange(0, len(nnz[0])) + if block_ptr is None: + block_ptr = jnp.arange(0, len(nnz[0])) - indices = jnp.argsort(block_ptr) - _ = jnp.take(block_ptr, indices) + indices = jnp.argsort(block_ptr) + _ = jnp.take(block_ptr, indices) - blocks = nnz[0][jnp.array(indices)], nnz[1][jnp.array(indices)] - blocks = jnp.stack([nnz[0][jnp.array(indices)], nnz[1][jnp.array(indices)]], axis=-1).astype( - dtype=jnp.int32 - ) - blocks = jnp.flip(blocks, axis=-1).flatten() + blocks = nnz[0][jnp.array(indices)], nnz[1][jnp.array(indices)] + blocks = jnp.stack([nnz[0][jnp.array(indices)], nnz[1][jnp.array(indices)]], axis=-1).astype( + dtype=jnp.int32 + ) + blocks = jnp.flip(blocks, axis=-1).flatten() - X = blockcount[1] - Y = blockcount[0] + X = blockcount[1] + Y = blockcount[0] - rows = nnz[0][:] - cols = nnz[1][:] + rows = nnz[0][:] + cols = nnz[1][:] - block_indices = jnp.zeros(X * Y, dtype=jnp.int32) - positions = rows * Y + cols - for position in positions: - block_indices = block_indices.at[positions].set(block_ptr + 1) - block_indices = block_indices.reshape(X, Y).transpose().reshape(X * Y) + block_indices = jnp.zeros(X * Y, dtype=jnp.int32) + 
positions = rows * Y + cols + block_indices = block_indices.at[positions].set(block_ptr + 1) + block_indices = block_indices.reshape(X, Y).transpose().reshape(X * Y) - block_ptr = block_indices[jnp.nonzero(block_indices)[0]] - 1 + block_ptr = block_indices[jnp.nonzero(block_indices)[0]] - 1 - X, Y = Y, X - rows = cols - nnztt = jnp.nonzero(mask.transpose()) - cols = nnztt[:][1] + X, Y = Y, X + rows = cols + nnztt = jnp.nonzero(mask.transpose()) + cols = nnztt[:][1] - rows.astype(jnp.int32) + rows.astype(jnp.int32) - ptr_b = jnp.zeros((X + 1,), dtype=jnp.int32) - for row in rows: - ptr_b = ptr_b.at[row + 1].set(ptr_b[row + 1] + 1) - ptr_b = ptr_b.cumsum(0).astype(dtype=jnp.int32) + ptr_b = jnp.zeros((X + 1,), dtype=jnp.int32) + for row in rows: + ptr_b = ptr_b.at[row + 1].set(ptr_b[row + 1] + 1) + ptr_b = ptr_b.cumsum(0).astype(dtype=jnp.int32) - indices_b = jnp.stack([cols, block_ptr], axis=1).astype(dtype=jnp.int32) + indices_b = jnp.stack([cols, block_ptr], axis=1).astype(dtype=jnp.int32) - return ptr_b, indices_b + return ptr_b, indices_b -def get_dense(ptr, indices, data, shape, blockshape): - mask = get_mask_from_ptr_indices(ptr, indices, blockshape) - dense_data = jnp.zeros(shape, dtype=jnp.float32) - mask_count = 0 - for i in range(mask.shape[1]): - for j in range(mask.shape[0]): - if mask[i, j]: - dense_data = dense_data.at[ - i * blockshape[0]: (i + 1) * blockshape[0], - j * blockshape[1]: (j + 1) * blockshape[1], - ].set(data[mask_count * blockshape[0]: (mask_count + 1) * blockshape[0], :]) - mask_count += 1 - return dense_data +def get_dense(ptr, indices, data, shape, blockshape): + mask = get_mask_from_ptr_indices(ptr, indices, blockshape) + dense_data = jnp.zeros(shape, dtype=jnp.float32) + mask_count = 0 + for i in range(mask.shape[1]): + for j in range(mask.shape[0]): + if mask[i, j]: + dense_data = dense_data.at[ + i * blockshape[0]: (i + 1) * blockshape[0], + j * blockshape[1]: (j + 1) * blockshape[1], + ].set(data[mask_count * blockshape[0]: (mask_count + 1) * blockshape[0], :]) + mask_count += 1 + return dense_data def blocksparse_matmat_multiply(dense_a, @@ -137,350 +135,329 @@ def blocksparse_matmat_multiply(dense_a, dense_b=None, blockshape=(32, 32), device='cpu'): - if dense_b is not None: - # m, n, k - m = dense_a.shape[0] - n = dense_b.shape[1] - k = dense_a.shape[1] - - # blockcount - blockcount = (n // blockshape[0], k // blockshape[1]) - - # mask - mask = get_mask(dense_b, blockshape, blockcount) - - # n_blocks - n_blocks = mask.sum() - - # data_b - data_b = get_data(dense_b, mask, blockshape, blockcount, n_blocks) - - # ptr_b, indices_b - ptr_b, indices_b = get_ptr_indices(mask, blockcount, n_blocks) - else: - # m, n, k - m = dense_a.shape[0] - n = shape_b[1] - k = dense_a.shape[1] - - # blockcount - blockcount = (n // blockshape[0], k // blockshape[1]) - - mask = get_mask_from_ptr_indices(ptr_b, indices_b, blockcount) - - n_blocks = mask.sum() - - ptr_b, indices_b = get_ptr_indices(mask, blockcount, n_blocks) - - # out - # out = jnp.zeros((n, m)) - - # verbose - print('data_b: ', data_b) - print('ptr:', ptr_b) - print('indices:', indices_b) - - '''out = blocksparse_matmat_cpu_test(dense_a, - ptr_b, - indices_b, - data_b, - out, - m=m, - n=n, - k=k, - block_size_rows_b=blockshape[0], - block_size_cols_b=blockshape[1]) - return out''' - - if device == 'cpu': - out = blocksparse_matmat( - dense_a, - ptr_b, - indices_b, - data_b, - m=m, - n=n, - k=k, - block_size_rows_b=blockshape[0], - block_size_cols_b=blockshape[1], - ) - return out - elif device == 'gpu': - 
out = blocksparse_matmat( - dense_a, - ptr_b, - indices_b, - data_b, - m=m, - n=n, - k=k, - block_size_rows_b=blockshape[0], - block_size_cols_b=blockshape[1], - ) - return out.transpose() - else: - raise Exception('Invalid device: ', device) - -def blocksparse_matmat( - A_data: jnp.ndarray, - B_ptr: jnp.ndarray, - B_indices: jnp.ndarray, - B_data: jnp.ndarray, - *, - m: int, - n: int, - k: int, - block_size_rows_b: int, - block_size_cols_b: int, -) -> jax.Array: - A_data = as_jax(A_data) - B_ptr = as_jax(B_ptr) - B_indices = as_jax(B_indices) - B_data = as_jax(B_data) - return blocksparse_matmat_p.bind(A_data, - B_ptr, - B_indices, - B_data, - m=m, - n=n, - k=k, - block_size_rows_b=block_size_rows_b, - block_size_cols_b=block_size_cols_b)[0] - -'''def blocksparse_matmat_cpu_test( - A_data, - B_ptr, - B_indices, - B_data, - m, - n, - k, - block_size_rows_b, - block_size_cols_b): - res_val = np.zeros((m, n)) - # index[0]为index, index[1]为该index对应的block的index - for idx, index in enumerate(B_indices): - # find the column - row_index = 0 - for ptr in B_ptr[1:]: - if ptr > idx: - break - row_index += 1 - row_start = row_index * block_size_cols_b - # find the row - col_start = index[0] * block_size_rows_b - # calculate the value and add to the res_val - for i in range(block_size_rows_b): - for j in range(block_size_cols_b): - if B_data[index[1] * block_size_rows_b + i, j] == 0: - continue - row_now = row_start + j - for c_m in range(m): - print('c{c_col}{c_row} = a{a_col}{a_row} * b{b_col}{b_row}'.format(c_col=c_m + 1,c_row=row_now + 1, - a_col=c_m + 1, a_row=col_start + i + 1, - b_col=col_start + i + 1, b_row=row_start + j + 1)) - res_val[c_m, row_now] += A_data[c_m, col_start + i] * B_data[index[1] * block_size_rows_b + i, j] - # res_val[:, row_now + j] += A_data[:, row_now + j] * B_data[index[1] * block_size_rows_b + i, j] - - return res_val''' - -# CPU implement -@numba.njit(fastmath=True, parallel=True, nogil=True) -def _blocksparse_matmat_numba_imp(outs, ins): - res_val = outs[0] - res_val.transpose() - res_val.fill(0) - A_data, B_ptr, B_indices, B_data, m, n, k, block_size_rows_b, block_size_cols_b = ins - m = np.int32(m) - n = np.int32(n) - k = np.int32(k) - block_size_rows_b = np.int32(block_size_rows_b) - block_size_cols_b = np.int32(block_size_cols_b) - - # index[0]为index, index[1]为该index对应的block的index - for idx, index in enumerate(B_indices): - # find the column - row_index = 0 - for ptr in B_ptr[1:]: - if ptr > idx: - break - row_index += 1 - row_start = row_index * block_size_cols_b - # find the row - col_start = index[0] * block_size_rows_b - # calculate the value and add to the res_val - for i in range(block_size_rows_b): - for j in range(block_size_cols_b): - if B_data[index[1] * block_size_rows_b + i, j] == 0: - continue - row_now = row_start + j - col_now = col_start + i - res_val[:, row_now] += A_data[:, col_now] * B_data[index[1] * block_size_rows_b + i, j] - '''for c_m in range(m): - print('c{c_col}{c_row} = a{a_col}{a_row} * b{b_col}{b_row}'.format(c_col=c_m + 1,c_row=row_now + 1, - a_col=c_m + 1, a_row=col_start + i + 1, - b_col=col_start + i + 1, b_row=row_start + j + 1)) - res_val[c_m, row_now] += A_data[c_m, col_start + i] * B_data[index[1] * block_size_rows_b + i, j]''' - # res_val[:, row_now + j] += A_data[:, row_now + j] * B_data[index[1] * block_size_rows_b + i, j] - - return res_val - - -def _blocksparse_matmat_cpu_translation(c, A_data, B_ptr, B_indices, B_data, *, m, n, k, block_size_rows_b, - block_size_cols_b): - inputs = (A_data, B_ptr, B_indices, B_data) - 
description = dict(m=m,
-                     n=n,
-                     k=k,
-                     block_size_rows_b=block_size_rows_b,
-                     block_size_cols_b=block_size_cols_b)
-  name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba(
-    c,
-    _blocksparse_matmat_numba_imp,
-    abs_eval_fn=_blocksparse_matmat_abstract,
-    multiple_results=True,
-    inputs=inputs,
-    description=description
-  )
-  return xla_client.ops.CustomCallWithLayout(
-    c, name,
-    operands=inputs,
-    operand_shapes_with_layout=in_layouts,
-    shape_with_layout=out_layouts,
-  )
-
-
-def _blocksparse_matmat_abstract(
-    A_data, B_ptr, B_indices, B_data, *, m, n, k, block_size_rows_b, block_size_cols_b
-):
-  shape = (n, m)
-  dtype = A_data.dtype
-  out = ShapedArray(dtype=dtype, shape=shape)
-  return [out]
-
-
-def _blocksparse_matmat_gpu_translation(
-    c, A_data, B_ptr, B_indices, B_data, *, m, n, k, block_size_rows_b, block_size_cols_b
-):
-  if gpu_ops is None:
-    raise GPUOperatorNotFound(blocksparse_matmat_p.name)
-
-  matrix_info = c.get_shape(A_data)
-  dtype = matrix_info.element_type()
-
-  opaque = gpu_ops.build_blocksparse_format_descriptor(m,
-                                                       n,
-                                                       k,
-                                                       block_size_rows_b,
-                                                       block_size_cols_b)
-
-  fn = b'gpu_blocksparse_matmat'
-
-  return xla_client.ops.CustomCallWithLayout(
-    c,
-    fn,
-    operands=(A_data, B_ptr, B_indices, B_data,),
-    operand_shapes_with_layout=(c.get_shape(A_data),
-                                c.get_shape(B_ptr),
-                                c.get_shape(B_indices),
-                                c.get_shape(B_data),),
-    shape_with_layout=xla_client.Shape.tuple_shape(
-      (xla_client.Shape.array_shape(dtype, (m, n), (1, 0)),)
-    ),
-    opaque=opaque
-  )
-
-
-def _blocksparse_matmat_jvp_dense_a(dense_a_dot, A_data, B_ptr, B_indices, B_data, *, m, n, k, block_size_rows_b, block_size_cols_b):
-  return blocksparse_matmat(dense_a_dot, B_ptr, B_indices, B_data, m=m, n=n, k=k, block_size_rows_b=block_size_rows_b, block_size_cols_b=block_size_cols_b)
-
-
-def _blocksparse_matmat_jvp_data_b(data_b_dot, A_data, B_ptr, B_indices, B_data, *, m, n, k, block_size_rows_b, block_size_cols_b):
-  return blocksparse_matmat(A_data, B_ptr, B_indices, data_b_dot, m=m, n=n, k=k, block_size_rows_b=block_size_rows_b, block_size_cols_b=block_size_cols_b)
-
-
-def _blocksparse_matmat_jvp_transpose():
-  # TODO: implement
-  pass
-
-
-blocksparse_matmat_p = Primitive('gpu_blocksparse_matmat')
-blocksparse_matmat_p.multiple_results = True
-blocksparse_matmat_p.def_abstract_eval(_blocksparse_matmat_abstract)
-blocksparse_matmat_p.def_impl(partial(xla.apply_primitive, blocksparse_matmat_p))
-xla.backend_specific_translations['cpu'][blocksparse_matmat_p] = _blocksparse_matmat_cpu_translation
-xla.backend_specific_translations['gpu'][blocksparse_matmat_p] = _blocksparse_matmat_gpu_translation
-# ad.defjvp(blocksparse_matmat_p, _blocksparse_matmat_jvp_dense_a, None, None, _blocksparse_matmat_jvp_data_b)
-ad.primitive_jvps[blocksparse_matmat_p] = _blocksparse_matmat_jvp_transpose
-register_general_batching(blocksparse_matmat)
-
-
-def blocksparse_matmat_back(
-    A_data: jnp.ndarray,
-    B_data: jnp.ndarray,
-    blocks: jnp.ndarray,
-    *,
-    m: int,
-    n: int,
-    k: int,
-    block_size_rows_b: int,
-    block_size_cols_b: int,
-    blocks_len: int,
-) -> jax.Array:
-  A_data = as_jax(A_data)
-  B_data = as_jax(B_data)
-  blocks = as_jax(blocks)
-  return blocksparse_matmat_back_p.bind(A_data,
-                                        B_data,
-                                        blocks,
-                                        m=m,
-                                        n=n,
-                                        k=k,
-                                        block_size_rows_b=block_size_rows_b,
-                                        block_size_cols_b=block_size_cols_b,
-                                        blocks_len=blocks_len)[0]
-
-
-def _blocksparse_matmat_back_abstract(
-    A_data, B_data, blocks, *, m, n, k, block_size_rows_b, block_size_cols_b, blocks_len
-):
-  shape = (n, k)
-  dtype = A_data.dtype
-  out = ShapedArray(dtype=dtype, shape=shape)
-  return [out]
-
-
-def _blocksparse_matmat_back_gpu_translation(
-    c, A_data, B_data, blocks, *, m, n, k, block_size_rows_b, block_size_cols_b, blocks_len
-):
-  if gpu_ops is None:
-    raise GPUOperatorNotFound(blocksparse_matmat_back_p.name)
-  matrix_info = c.get_shape(A_data)
-  dtype = matrix_info.element_type()
-
-  opaque = gpu_ops.build_blocksparse_back_format_descriptor(m,
-                                                            n,
-                                                            k,
-                                                            block_size_rows_b,
-                                                            block_size_cols_b,
-                                                            blocks_len)
-
-  fn = b'gpu_blocksparse_matmat_back'
-
-  return xla_client.ops.CustomCallWithLayout(
-    c,
-    fn,
-    operands=(A_data, B_data, blocks,),
-    operand_shape_with_layout=(c.get_shape(A_data),
-                               c.get_shape(B_data),
-                               c.get_shape(blocks),),
-    shape_with_layout=xla_client.Shape.tuple_shape(
-      (xla_client.Shape.array_shape(dtype, (k, n), (1, 0)),)
-    ),
-    opaque=opaque
-  )
-
-
-blocksparse_matmat_back_p = Primitive('gpu_blocksparse_matmat_back')
-blocksparse_matmat_back_p.multiple_results = True
-blocksparse_matmat_back_p.def_abstract_eval(_blocksparse_matmat_back_abstract)
-blocksparse_matmat_back_p.def_impl(partial(xla.apply_primitive, blocksparse_matmat_back_p))
-xla.backend_specific_translations['gpu'][blocksparse_matmat_back_p] = _blocksparse_matmat_back_gpu_translation
-
-register_general_batching(blocksparse_matmat_back)
\ No newline at end of file
+  if dense_b is not None:
+    # m, n, k
+    m = dense_a.shape[0]
+    k = dense_a.shape[1]
+    n = dense_b.shape[1]
+
+    # blockcount
+    blockcount = (n // blockshape[0], k // blockshape[1])
+
+    # mask
+    mask = get_mask(dense_b, blockshape, blockcount)
+
+    # n_blocks
+    n_blocks = mask.sum()
+
+    # data_b
+    data_b = get_data(dense_b, mask, blockshape, blockcount, n_blocks)
+
+    # ptr_b, indices_b
+    ptr_b, indices_b = get_ptr_indices(mask, blockcount, n_blocks)
+  else:
+    # m, n, k
+    m = dense_a.shape[0]
+    n = shape_b[1]
+    k = dense_a.shape[1]
+
+    # blockcount
+    blockcount = (n // blockshape[0], k // blockshape[1])
+
+    mask = get_mask_from_ptr_indices(ptr_b, indices_b, blockcount)
+
+    n_blocks = mask.sum()
+
+    ptr_b, indices_b = get_ptr_indices(mask, blockcount, n_blocks)
+
+  # out
+  # out = jnp.zeros((n, m))
+
+  # verbose
+  print('data_b: ', data_b)
+  print('ptr:', ptr_b)
+  print('indices:', indices_b)
+
+  '''out = blocksparse_matmat_cpu_test(dense_a,
+                                       ptr_b,
+                                       indices_b,
+                                       data_b,
+                                       out,
+                                       m=m,
+                                       n=n,
+                                       k=k,
+                                       block_size_k=blockshape[0],
+                                       block_size_n=blockshape[1])
+  return out'''
+
+  if device == 'cpu':
+    # the call matches the keyword-only signature of `bcsrmm` below
+    out = bcsrmm(
+      dense_a, data_b, indices_b, ptr_b,
+      shape=(k, n),
+      block_size=(blockshape[0], blockshape[1]),
+    )
+    return out
+  elif device == 'gpu':
+    out = bcsrmm(
+      dense_a, data_b, indices_b, ptr_b,
+      shape=(k, n),
+      block_size=(blockshape[0], blockshape[1]),
+    )
+    return out.transpose()
+  else:
+    raise Exception('Invalid device: ', device)
+
+
+def bcsrmm(
+    A_data: jax.Array,
+    B_data: jax.Array,
+    B_indices: jax.Array,
+    B_ptr: jax.Array,
+    *,
+    shape: Tuple[int, int],
+    block_size: Tuple[int, int],
+    transpose: bool = False,
+    method: str = 'cutlass'
+) -> jax.Array:
+  """Perform the matrix multiplication :math:`C = A @ B` with the BSR data structure.
+
+  Args:
+    A_data: The dense matrix :math:`A`.
+    B_data: The data at each block of :math:`B`.
+    B_indices: The sparse indices of :math:`B`.
+    B_ptr: The connection pointer of :math:`B`.
+    shape: a tuple of int, indicating the array shape of :math:`B`.
+    block_size: a tuple of int, indicating the block size for partitioning :math:`B`.
+    transpose: boolean. If True, perform :math:`A @ B^T`; otherwise, perform :math:`A @ B`.
+    method: a string denoting the BSR sparse computing method.
+
+  Returns:
+    The dense array :math:`C`.
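+
+  Example:
+    A minimal usage sketch (illustrative only: the shapes, values, and the
+    block partition of ``B`` below are assumptions, not part of the API):
+
+    >>> import jax.numpy as jnp
+    >>> # B is a (4, 4) block-sparse matrix with 2x2 blocks; its nonzero
+    >>> # blocks are stacked in B_data and addressed by B_indices / B_ptr.
+    >>> A = jnp.ones((2, 4))
+    >>> C = bcsrmm(A, B_data, B_indices, B_ptr,
+    ...            shape=(4, 4), block_size=(2, 2))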
+ """ + A_data = as_jax(A_data) + B_data = as_jax(B_data) + B_indices = as_jax(B_indices) + B_ptr = as_jax(B_ptr) + assert A_data.shape[1] == shape[0] + + if method == 'cutlass': + C = _bcsrmm_cutlass_p.bind(A_data, + B_data, + B_indices, + B_ptr, + m=A_data.shape[0], + k=shape[0], + n=shape[1], + transpose=transpose, + block_size_k=block_size[0], + block_size_n=block_size[1])[0] + return C.T + else: + raise ValueError -def _blocksparse_matmat_abstract( - A_data, B_ptr, B_indices, B_data, *, m, n, k, block_size_rows_b, block_size_cols_b -): - shape = (n, m) - dtype = A_data.dtype - out = ShapedArray(dtype=dtype, shape=shape) - return [out] +@numba.njit(fastmath=True, parallel=True, nogil=True) +def _bcsrmm_cutlass_imp_transpose(outs, ins): # dense(m, k) @ bcsr(n, k) -> dense(n, m) + res_val = outs[0] + # B_data: (num_block, block_size_k, block_size_n) + A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins + block_size_k = block_size_k[()] + block_size_n = block_size_n[()] + n_block = n // block_size_n + + for ni in numba.prange(n_block): + C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) + start, end = B_inptr[ni], B_inptr[ni + 1] + ns = ni * block_size_n + ne = ns + block_size_n + for i in range(start, end): + ki = B_indices[i, 0] + ks = ki * block_size_k + ke = ki + block_size_k + bi = B_indices[i, 1] + C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) + res_val[ns: ne] = C_tmp + return res_val -def _blocksparse_matmat_gpu_translation( - c, A_data, B_ptr, B_indices, B_data, *, m, n, k, block_size_rows_b, block_size_cols_b +@numba.njit(fastmath=True, parallel=True, nogil=True) +def _bcsrmm_cutlass_imp2(outs, ins): # dense(m, k) @ bcsr(k, n) -> dense(n, m) + res_val = outs[0] + # B_data: (num_block, block_size_n, block_size_k) + A_data, B_data, B_indices, B_inptr, m, k, n, block_size_k, block_size_n = ins + block_size_k = block_size_k[()] + block_size_n = block_size_n[()] + n_block = n // block_size_n + + for ni in numba.prange(n_block): + C_tmp = np.zeros((block_size_n, m), dtype=A_data.dtype) + start, end = B_inptr[ni], B_inptr[ni + 1] + ns = ni * block_size_n + ne = ns + block_size_n + for i in range(start, end): + ki = B_indices[i, 0] + ks = ki * block_size_k + ke = ki + block_size_k + bi = B_indices[i, 1] + C_tmp += np.matmul(B_data[bi], A_data[:, ks: ke].T) + res_val[ns: ne] = C_tmp + return res_val + + +def _bcsrmm_abstract( + A_data, B_data, B_indices, B_ptr, *, m, k, n, block_size_k, block_size_n ): - if gpu_ops is None: - raise GPUOperatorNotFound(blocksparse_matmat_p.name) - - matrix_info = c.get_shape(A_data) - dtype = matrix_info.element_type() - - opaque = gpu_ops.build_blocksparse_format_descriptor(m, - n, - k, - block_size_rows_b, - block_size_cols_b) - - fn = b'gpu_blocksparse_matmat' - - return xla_client.ops.CustomCallWithLayout( - c, - fn, - operands=(A_data, B_ptr, B_indices, B_data,), - operand_shapes_with_layout=(c.get_shape(A_data), - c.get_shape(B_ptr), - c.get_shape(B_indices), - c.get_shape(B_data), ), - shape_with_layout=xla_client.Shape.tuple_shape( - (xla_client.Shape.array_shape(dtype, (m, n), (1, 0)),) - ), - opaque=opaque - ) + assert B_indices.shape[0] * block_size_n == B_data.shape[0] + assert block_size_k == B_data.shape[1] + assert A_data.shape[0] == m + assert A_data.shape[1] == k + assert A_data.dtype == B_data.dtype + assert n // block_size_n + 1 == B_ptr.shape[0] + return [ShapedArray(dtype=A_data.dtype, shape=(n, m))] + + +def _bcsrmm_cpu_translation( + c, A_data, B_data, B_indices, B_ptr, *, + m, k, n, 
+
+
+def _bcsrmm_cpu_translation(
+    c, A_data, B_data, B_indices, B_ptr, *,
+    m, k, n, transpose, block_size_k, block_size_n
+):
+  # the operand and description order must match the `ins` unpacking of the
+  # numba kernels: (A_data, B_data, B_indices, B_ptr, m, k, n, bk, bn)
+  inputs = (A_data, B_data, B_indices, B_ptr)
+  description = dict(m=m,
+                     k=k,
+                     n=n,
+                     block_size_k=block_size_k,
+                     block_size_n=block_size_n)
+  name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba(
+    c,
+    _bcsrmm_cutlass_imp_transpose if transpose else _bcsrmm_cutlass_imp2,
+    abs_eval_fn=_bcsrmm_abstract,
+    multiple_results=True,
+    inputs=inputs,
+    description=description
+  )
+  return xla_client.ops.CustomCallWithLayout(
+    c, name,
+    operands=inputs,
+    operand_shapes_with_layout=in_layouts,
+    shape_with_layout=out_layouts,
+  )
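+
+# How the pieces fit together (a sketch, not executed here): binding the
+# primitive directly mirrors what `bcsrmm` does internally; the shapes are
+# illustrative assumptions:
+#
+#   C = _bcsrmm_cutlass_p.bind(A, B_data, B_indices, B_ptr,
+#                              m=2, k=4, n=4, transpose=False,
+#                              block_size_k=2, block_size_n=2)[0]  # (n, m)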
+
+
+def _bcsrmm_gpu_translation(
+    c, A_data, B_data, B_indices, B_ptr, *,
+    m, k, n, transpose, block_size_k, block_size_n
+):
+  if gpu_ops is None:
+    raise GPUOperatorNotFound(_bcsrmm_cutlass_p.name)
+
+  matrix_info = c.get_shape(A_data)
+  dtype = matrix_info.element_type()
+
+  opaque = gpu_ops.build_blocksparse_format_descriptor(m,
+                                                       n,
+                                                       k,
+                                                       block_size_k,
+                                                       block_size_n)
+
+  fn = b'gpu_blocksparse_matmat'
+
+  return xla_client.ops.CustomCallWithLayout(
+    c,
+    fn,
+    operands=(A_data, B_ptr, B_indices, B_data,),
+    operand_shapes_with_layout=(c.get_shape(A_data),
+                                c.get_shape(B_ptr),
+                                c.get_shape(B_indices),
+                                c.get_shape(B_data),),
+    shape_with_layout=xla_client.Shape.tuple_shape(
+      (xla_client.Shape.array_shape(dtype, (n, m), (1, 0)),)
+    ),
+    opaque=opaque
+  )
+
+
+def _bcsrmm_jvp_dense_a(dense_a_dot, A_data, B_data, B_indices, B_ptr, *,
+                        m, k, n, transpose, block_size_k, block_size_n):
+  # tangent w.r.t. the dense matrix A, adapted to the new `bcsrmm` signature
+  return bcsrmm(dense_a_dot, B_data, B_indices, B_ptr,
+                shape=(k, n), block_size=(block_size_k, block_size_n),
+                transpose=transpose)
+
+
+def _bcsrmm_jvp_data_b(data_b_dot, A_data, B_data, B_indices, B_ptr, *,
+                       m, k, n, transpose, block_size_k, block_size_n):
+  # tangent w.r.t. the sparse B data, adapted to the new `bcsrmm` signature
+  return bcsrmm(A_data, data_b_dot, B_indices, B_ptr,
+                shape=(k, n), block_size=(block_size_k, block_size_n),
+                transpose=transpose)
+
+
+def _bcsrmm_jvp_transpose():
+  # TODO: implement
+  pass
+
+
+_bcsrmm_cutlass_p = Primitive('bcsrmm_cutlass_prim')
+_bcsrmm_cutlass_p.multiple_results = True
+_bcsrmm_cutlass_p.def_abstract_eval(_bcsrmm_abstract)
+_bcsrmm_cutlass_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_p))
+xla.backend_specific_translations['cpu'][_bcsrmm_cutlass_p] = _bcsrmm_cpu_translation
+xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_p] = _bcsrmm_gpu_translation
+ad.primitive_jvps[_bcsrmm_cutlass_p] = _bcsrmm_jvp_transpose
+ad.primitive_transposes[_bcsrmm_cutlass_p] = _bcsrmm_jvp_transpose
+register_general_batching(bcsrmm)
+
+
+def _blocksparse_matmat_back_abstract(
+    A_data, B_data, blocks, *, m, n, k, transpose, block_size_k, block_size_n, blocks_len
+):
+  shape = (n, k)
+  dtype = A_data.dtype
+  out = ShapedArray(dtype=dtype, shape=shape)
+  return [out]
+
+
+def _blocksparse_matmat_back_gpu_translation(
+    c, A_data, B_data, blocks, *, m, n, k, transpose, block_size_k, block_size_n, blocks_len
+):
+  if gpu_ops is None:
+    raise GPUOperatorNotFound(_bcsrmm_cutlass_back_p.name)
+  matrix_info = c.get_shape(A_data)
+  dtype = matrix_info.element_type()
+
+  opaque = gpu_ops.build_blocksparse_back_format_descriptor(m,
+                                                            n,
+                                                            k,
+                                                            block_size_k,
+                                                            block_size_n,
+                                                            blocks_len)
+
+  fn = b'gpu_blocksparse_matmat_back'
+
+  return xla_client.ops.CustomCallWithLayout(
+    c,
+    fn,
+    operands=(A_data, B_data, blocks,),
+    operand_shapes_with_layout=(c.get_shape(A_data),
+                                c.get_shape(B_data),
+                                c.get_shape(blocks),),
+    shape_with_layout=xla_client.Shape.tuple_shape(
+      (xla_client.Shape.array_shape(dtype, (k, n), (1, 0)),)
+    ),
+    opaque=opaque
+  )
+
+
+_bcsrmm_cutlass_back_p = Primitive('bcsrmm_cutlass_back_prim')
+_bcsrmm_cutlass_back_p.multiple_results = True
+_bcsrmm_cutlass_back_p.def_abstract_eval(_blocksparse_matmat_back_abstract)
+_bcsrmm_cutlass_back_p.def_impl(partial(xla.apply_primitive, _bcsrmm_cutlass_back_p))
+xla.backend_specific_translations['gpu'][_bcsrmm_cutlass_back_p] = _blocksparse_matmat_back_gpu_translation +register_general_batching(_bcsrmm_cutlass_back_p) diff --git a/brainpy/_src/math/sparse/_bsr_mv.py b/brainpy/_src/math/sparse/_bsr_mv.py index a7a38c4de..331858c3b 100644 --- a/brainpy/_src/math/sparse/_bsr_mv.py +++ b/brainpy/_src/math/sparse/_bsr_mv.py @@ -1,4 +1,3 @@ - from functools import partial from typing import Union, Tuple @@ -20,25 +19,23 @@ except ImportError: gpu_ops = None - __all__ = [ 'cusparse_bcsr_matvec' ] + @numba.njit(fastmath=True, parallel=True, nogil=True) def _cusparse_bcsr_matvec_bsr_matvec_numba_imp(outs, ins): - data, indices, indptr, vector, blocksize , shape, nnzb, transpose = ins + data, indices, indptr, vector, blocksize, shape, nnzb, transpose = ins blocksize = blocksize[()] outs.fill(0) - for i in range(shape[0]): - tmp= np.zeros(blocksize, dtype=data.dtype) - - for j in range(indptr[i], indptr[i + 1]): - start = indices[j] * blocksize - end = start + blocksize - tmp += data[start: end] @ vector[start: end] - outs[i * blocksize: (i + 1) * blocksize] = tmp + tmp = np.zeros(blocksize, dtype=data.dtype) + for j in range(indptr[i], indptr[i + 1]): + start = indices[j] * blocksize + end = start + blocksize + tmp += data[start: end] @ vector[start: end] + outs[i * blocksize: (i + 1) * blocksize] = tmp # @numba.njit(fastmath=True, parallel=True, nogil=True) @@ -60,9 +57,9 @@ def _cusparse_bcsr_matvec_bsr_matvec_numba_imp(outs, ins): # cnt+=1 +def _cusprase_bcsr_matvec_values(values, indices, indptr, vector, *, blocksize, nnzb, shape, transpose): + return cusparse_bcsr_matvec(values, indices, indptr, vector, blocksize, nnzb=nnzb, shape=shape, transpose=transpose) -def _cusprase_bcsr_matvec_values( values, indices, indptr, vector, *,blocksize ,nnzb, shape, transpose): - return cusparse_bcsr_matvec(values, indices, indptr, vector,blocksize,nnzb=nnzb ,shape=shape,transpose=transpose) def cusparse_bcsr_matvec( data: Union[float, jnp.ndarray], @@ -76,7 +73,6 @@ def cusparse_bcsr_matvec( method: str = 'vector', transpose: bool = False ) -> jnp.ndarray: - data = as_jax(data) indices = as_jax(indices) indptr = as_jax(indptr) @@ -95,15 +91,17 @@ def cusparse_bcsr_matvec( f'But we got {data.dtype} != {vector.dtype}.') # assert data.ndim == indices.ndim == indptr.ndim == vector.ndim == 1 - return cusparse_bcsr_matvec_vector_p.bind(data, indices, indptr, vector,blocksize = blocksize, shape=shape,nnzb=nnzb,transpose=transpose) + return cusparse_bcsr_matvec_vector_p.bind(data, indices, indptr, vector, blocksize=blocksize, shape=shape, nnzb=nnzb, + transpose=transpose) -def _cusparse_bcsr_matvec_vector_cpu_translation(c, data, indices, indptr, vector, *, blocksize , shape, nnzb, transpose): +def _cusparse_bcsr_matvec_vector_cpu_translation(c, data, indices, indptr, vector, *, blocksize, shape, nnzb, + transpose): inputs = (data, indices, indptr, vector) print(c.get_shape(data)) - description = dict(blocksize=blocksize,shape=shape,nnzb=nnzb, transpose=transpose,) + description = dict(blocksize=blocksize, shape=shape, nnzb=nnzb, transpose=transpose, ) if transpose: - skip=1 + skip = 1 else: name, inputs, in_layouts, out_layouts = compile_cpu_signature_with_numba( c, @@ -120,20 +118,21 @@ def _cusparse_bcsr_matvec_vector_cpu_translation(c, data, indices, indptr, vecto shape_with_layout=out_layouts, ) -def _cusparse_bcsr_matvec_vector_gpu_translation(c, data, indices, indptr, vector, *,blocksize , shape,nnzb): + +def _cusparse_bcsr_matvec_vector_gpu_translation(c, data, 
indices, indptr, vector, *, blocksize, shape, nnzb):
   if gpu_ops is None:
     raise GPUOperatorNotFound(cusparse_bcsr_matvec_vector_p.name)
-  
+
   data_shape = c.get_shape(data)
   if data_shape.element_type() == np.float32:
-    type_name = b'float'
-  elif data_shape.element_type() == np.double:
+    type_name = b'float'
+  elif data_shape.element_type() == np.double:
     type_name = b'double'
   else:
     raise ValueError('data_type not support(except float/double)')
-  # 有可能不是这个
+  # NOTE: this may not be the right descriptor builder (translated from the
+  # original Chinese comment)
 
-  opaque = gpu_ops.build_bcsrcusparsespmv_descriptor(shape[0],shape[1],blocksize,nnzb)
+  opaque = gpu_ops.build_bcsrcusparsespmv_descriptor(shape[0], shape[1], blocksize, nnzb)
   return xla_client.ops.CustomCallWithLayout(
     c, b'gpu_bcsr_cusparse_spmv_' + type_name,
     operands=(data,
               indices,
               indptr,
               vector,
               ),
-    shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0]*blocksize,), (0,)),
+    shape_with_layout=xla_client.Shape.array_shape(data_shape.element_type(), (shape[0] * blocksize,), (0,)),
     opaque=opaque,
   )
 
+
 # def _bcsr_matvec_abstract(*args, **kwargs):
 #   data = args[0]
 #   assert len(kwargs) == 1
 #   shape = kwargs['shape']
 #   return ShapedArray(dtype=data.dtype, shape=shape)
 
 # def _bcsr_matvec_abstract(
 #     values, indices, indptr, vector, *, shape, transpose, batch_size=None
 # ):
 #   return ShapedArray(dtype=values.dtype, shape=(batch_size, shape[1] if transpose else shape[0]))
 
-def _cusparse_bcsr_matvec_abstract(data, indices, indptr, vector,*,blocksize, shape,nnzb,transpose=False):
-  return ShapedArray(dtype=data.dtype, shape=(shape[0]*blocksize,))
+def _cusparse_bcsr_matvec_abstract(data, indices, indptr, vector, *, blocksize, shape, nnzb, transpose=False):
+  return ShapedArray(dtype=data.dtype, shape=(shape[0] * blocksize,))
 
 
-def _cusparse_bcsr_matvec_jvp_values(data_dot, data, indices, indptr, vector, *,blocksize, shape,nnzb, transpose):
-  return cusparse_bcsr_matvec(data_dot, indices, indptr, vector,blocksize=blocksize,nnzb=nnzb, shape=shape, transpose=transpose)
+def _cusparse_bcsr_matvec_jvp_values(data_dot, data, indices, indptr, vector, *, blocksize, shape, nnzb, transpose):
+  return cusparse_bcsr_matvec(data_dot, indices, indptr, vector, blocksize=blocksize, nnzb=nnzb, shape=shape,
+                              transpose=transpose)
 
 
 def _cusparse_bcsr_transpose(ct, data, indices, indptr, vector, *, blocksize, shape, transpose):
     ct_values = ad.Zero(data)
   else:
     row, col = csr_to_coo(indices, indptr)
-    cnt=0
-    ct_values=[]
+    cnt = 0
+    ct_values = []
     for i in row:
       for j in col:
-        for p in range(0,blocksize):
-          cntq=0
-          for q in range(0,blocksize):
+        for p in range(0, blocksize):
+          cntq = 0
+          for q in range(0, blocksize):
             if transpose:
-              ct_values[cnt][cntq] = vector[i*blocksize+p]*ct[j*blocksize+q]
+              ct_values[cnt][cntq] = vector[i * blocksize + p] * ct[j * blocksize + q]
             else:
-              ct_values[cnt][cntq] = vector[j*blocksize+q]*ct[i*blocksize+p]
-            cntq+=1
-          cnt+=1
+              ct_values[cnt][cntq] = vector[j * blocksize + q] * ct[i * blocksize + p]
+            cntq += 1
+          cnt += 1
   return ct_values, indices, indptr, vector
-  
+
+
 cusparse_bcsr_matvec_vector_p = Primitive('cusparse_block_spmv')
 cusparse_bcsr_matvec_vector_p.def_abstract_eval(_cusparse_bcsr_matvec_abstract)
 cusparse_bcsr_matvec_vector_p.def_impl(partial(xla.apply_primitive, cusparse_bcsr_matvec_vector_p))
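+
+# What the BSR mat-vec kernel above computes, written out as pseudocode
+# (a sketch mirroring the numba loop; block layout is an assumption):
+#
+#   for each block-row i:
+#     tmp = zeros(blocksize)
+#     for j in range(indptr[i], indptr[i + 1]):
+#       start = indices[j] * blocksize
+#       tmp += data[start: start + blocksize] @ vector[start: start + blocksize]
+#     out[i * blocksize: (i + 1) * blocksize] = tmp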