From 1868046efec571f0970d97bca29882f705c7cb1c Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Wed, 9 Jun 2021 22:37:30 +0000 Subject: [PATCH] [CuBLAS] Support implicit broadcast in batch_matmul --- src/runtime/contrib/cblas/gemm_common.h | 38 +++++++++++++---- src/runtime/contrib/cublas/cublas.cc | 30 ++++++++++---- tests/python/contrib/test_cublas.py | 55 +++++++++++++------------ 3 files changed, 79 insertions(+), 44 deletions(-) diff --git a/src/runtime/contrib/cblas/gemm_common.h b/src/runtime/contrib/cblas/gemm_common.h index 9ccfa5183cd6..4724b14bffa1 100644 --- a/src/runtime/contrib/cblas/gemm_common.h +++ b/src/runtime/contrib/cblas/gemm_common.h @@ -181,28 +181,48 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) { DLTensor* C = args[2]; bool transa = args[3]; bool transb = args[4]; + int bit_depth = sizeof(DType) * 8; + ICHECK_EQ(A->ndim, 3); ICHECK_EQ(B->ndim, 3); ICHECK_EQ(C->ndim, 3); - int batch_size = BatchCount3D(A); - ICHECK_EQ(BatchCount3D(B), batch_size); - ICHECK_EQ(BatchCount3D(C), batch_size); + + int batch_size = BatchCount3D(C); ICHECK_EQ(ElementStride(A), 1); ICHECK_EQ(ElementStride(B), 1); ICHECK_EQ(ElementStride(C), 1); + // C can never be transposed. ICHECK(!IsInPlaceTransposed3D(C)); // Reversed strides indicates an in-place transpose operation. transa = IsInPlaceTransposed3D(A) ? !transa : transa; transb = IsInPlaceTransposed3D(B) ? !transb : transb; + ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth)); ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth)); + double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - const int A_size = A->shape[1] * A->shape[2]; - const int B_size = B->shape[1] * B->shape[2]; - const int C_size = C->shape[1] * C->shape[2]; + + int A_stride = A->shape[1] * A->shape[2]; + int B_stride = B->shape[1] * B->shape[2]; + int C_stride = C->shape[1] * C->shape[2]; + + // Broadcast A or B by changing its stride. + int batch_size_a = BatchCount3D(A); + int batch_size_b = BatchCount3D(B); + if (batch_size_a != batch_size_b) { + if (batch_size_a == 1) { + A_stride = 0; + } else if (batch_size_b == 1) { + B_stride = 0; + } + } else { + ICHECK_EQ(batch_size_a, batch_size); + ICHECK_EQ(batch_size_b, batch_size); + } + DType* A_data = reinterpret_cast(static_cast(A->data) + A->byte_offset); DType* B_data = reinterpret_cast(static_cast(B->data) + @@ -210,9 +230,9 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) { DType* C_data = reinterpret_cast(static_cast(C->data) + C->byte_offset); op(batch_size, transb, transa, ColumnCount3D(B, transb), RowCount3D(A, transa), - ColumnCount3D(A, transa), static_cast(alpha), B_data, B_size, - ColumnStride3D(B), A_data, A_size, ColumnStride3D(A), - static_cast(beta), C_data, C_size, ColumnStride3D(C)); + ColumnCount3D(A, transa), static_cast(alpha), B_data, + B_stride, ColumnStride3D(B), A_data, A_stride, ColumnStride3D(A), + static_cast(beta), C_data, C_stride, ColumnStride3D(C)); } } // namespace contrib diff --git a/src/runtime/contrib/cublas/cublas.cc b/src/runtime/contrib/cublas/cublas.cc index 1216a63703bb..015d68aec819 100644 --- a/src/runtime/contrib/cublas/cublas.cc +++ b/src/runtime/contrib/cublas/cublas.cc @@ -275,9 +275,8 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) ICHECK_EQ(A->ndim, 3); ICHECK_EQ(B->ndim, 3); ICHECK_EQ(C->ndim, 3); - int batch_size = BatchCount3D(A); - ICHECK_EQ(BatchCount3D(B), batch_size); - ICHECK_EQ(BatchCount3D(C), batch_size); + + int batch_size = BatchCount3D(C); ICHECK_EQ(ElementStride(A), 1); ICHECK_EQ(ElementStride(B), 1); ICHECK_EQ(ElementStride(C), 1); @@ -299,9 +298,23 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) double alpha = args.size() > 5 ? args[5] : 1.0; double beta = args.size() > 6 ? args[6] : 0.0; - const int A_size = A->shape[1] * A->shape[2]; - const int B_size = B->shape[1] * B->shape[2]; - const int C_size = C->shape[1] * C->shape[2]; + int A_stride = A->shape[1] * A->shape[2]; + int B_stride = B->shape[1] * B->shape[2]; + int C_stride = C->shape[1] * C->shape[2]; + + // Broadcast A or B by changing its stride. + int batch_size_a = BatchCount3D(A); + int batch_size_b = BatchCount3D(B); + if (batch_size_a != batch_size_b) { + if (batch_size_a == 1) { + A_stride = 0; + } else if (batch_size_b == 1) { + B_stride = 0; + } + } else { + ICHECK_EQ(batch_size_a, batch_size); + ICHECK_EQ(batch_size_b, batch_size); + } cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype); cudaDataType_t cuda_out_type = GetCudaDataType(C->dtype); @@ -325,8 +338,9 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl) CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx( hdl, CUBLASBooleanToTranspose(transb), CUBLASBooleanToTranspose(transa), ColumnCount3D(B, transb), RowCount3D(A, transa), ColumnCount3D(A, transa), alpha_ptr, B_data, - cuda_in_type, ColumnStride3D(B), B_size, A_data, cuda_in_type, ColumnStride3D(A), A_size, - beta_ptr, C_data, cuda_out_type, ColumnStride3D(C), C_size, batch_size, cuda_out_type, algo)); + cuda_in_type, ColumnStride3D(B), B_stride, A_data, cuda_in_type, ColumnStride3D(A), A_stride, + beta_ptr, C_data, cuda_out_type, ColumnStride3D(C), C_stride, batch_size, cuda_out_type, + algo)); } // matrix multiplication for row major diff --git a/tests/python/contrib/test_cublas.py b/tests/python/contrib/test_cublas.py index a0f51ca7c9fc..648100a569d7 100644 --- a/tests/python/contrib/test_cublas.py +++ b/tests/python/contrib/test_cublas.py @@ -112,33 +112,23 @@ def verify(target="cuda"): verify() -def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5): - j = 16 - n = 1024 - l = 128 - m = 236 - A = te.placeholder((j, n, l), name="A", dtype=in_dtype) - B = te.placeholder((j, l, m), name="B", dtype=in_dtype) +def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5): + A = te.placeholder(Ashape, name="A", dtype=in_dtype) + B = te.placeholder(Bshape, name="B", dtype=in_dtype) C = cublas.batch_matmul(A, B, dtype=out_dtype) s = te.create_schedule(C.op) - def verify(target="cuda"): - if not tvm.get_global_func("tvm.contrib.cublas.matmul", True): - print("skip because extern function is not available") - return - dev = tvm.cuda(0) - f = tvm.build(s, [A, B, C], target) - a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), dev) - b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), dev) - c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), dev) - f(a, b, c) - tvm.testing.assert_allclose( - c.numpy(), - np.matmul(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)).astype(C.dtype), - rtol=rtol, - ) - - verify() + dev = tvm.cuda(0) + f = tvm.build(s, [A, B, C], "cuda") + a = tvm.nd.array(np.random.uniform(size=Ashape).astype(A.dtype), dev) + b = tvm.nd.array(np.random.uniform(size=Bshape).astype(B.dtype), dev) + c = tvm.nd.array(np.zeros(Cshape, dtype=C.dtype), dev) + f(a, b, c) + tvm.testing.assert_allclose( + c.numpy(), + np.matmul(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)).astype(C.dtype), + rtol=rtol, + ) @tvm.testing.requires_cuda @@ -156,9 +146,20 @@ def test_matmul_add_igemm(): @tvm.testing.requires_cuda def test_batch_matmul(): - verify_batch_matmul("float", "float") - verify_batch_matmul("float16", "float") - verify_batch_matmul("float16", "float16", rtol=1e-2) + if not tvm.get_global_func("tvm.contrib.cublas.matmul", True): + print("skip because extern function is not available") + return + + verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float", "float") + verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float", "float") + verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float") + verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float") + verify_batch_matmul( + (16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2 + ) + verify_batch_matmul( + (16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2 + ) if __name__ == "__main__":