Skip to content

Commit

Permalink
[CuBLAS] Support implicit broadcast in batch_matmul (#8229)
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac authored Jun 10, 2021
1 parent d97d8d3 commit b93e56e
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 44 deletions.
38 changes: 29 additions & 9 deletions src/runtime/contrib/cblas/gemm_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,38 +181,58 @@ inline void CallBatchGemm(TVMArgs args, TVMRetValue* ret, TBatchGemmOp op) {
DLTensor* C = args[2];
bool transa = args[3];
bool transb = args[4];

int bit_depth = sizeof(DType) * 8;

ICHECK_EQ(A->ndim, 3);
ICHECK_EQ(B->ndim, 3);
ICHECK_EQ(C->ndim, 3);
int batch_size = BatchCount3D(A);
ICHECK_EQ(BatchCount3D(B), batch_size);
ICHECK_EQ(BatchCount3D(C), batch_size);

int batch_size = BatchCount3D(C);
ICHECK_EQ(ElementStride(A), 1);
ICHECK_EQ(ElementStride(B), 1);
ICHECK_EQ(ElementStride(C), 1);

// C can never be transposed.
ICHECK(!IsInPlaceTransposed3D(C));
// Reversed strides indicates an in-place transpose operation.
transa = IsInPlaceTransposed3D(A) ? !transa : transa;
transb = IsInPlaceTransposed3D(B) ? !transb : transb;

ICHECK(TypeMatch(B->dtype, kDLFloat, bit_depth));
ICHECK(TypeMatch(C->dtype, kDLFloat, bit_depth));

double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;
const int A_size = A->shape[1] * A->shape[2];
const int B_size = B->shape[1] * B->shape[2];
const int C_size = C->shape[1] * C->shape[2];

int A_stride = A->shape[1] * A->shape[2];
int B_stride = B->shape[1] * B->shape[2];
int C_stride = C->shape[1] * C->shape[2];

// Broadcast A or B by changing its stride.
int batch_size_a = BatchCount3D(A);
int batch_size_b = BatchCount3D(B);
if (batch_size_a != batch_size_b) {
if (batch_size_a == 1) {
A_stride = 0;
} else if (batch_size_b == 1) {
B_stride = 0;
}
} else {
ICHECK_EQ(batch_size_a, batch_size);
ICHECK_EQ(batch_size_b, batch_size);
}

DType* A_data = reinterpret_cast<typename TBatchGemmOp::TDatatype*>(static_cast<char*>(A->data) +
A->byte_offset);
DType* B_data = reinterpret_cast<typename TBatchGemmOp::TDatatype*>(static_cast<char*>(B->data) +
B->byte_offset);
DType* C_data = reinterpret_cast<typename TBatchGemmOp::TDatatype*>(static_cast<char*>(C->data) +
C->byte_offset);
op(batch_size, transb, transa, ColumnCount3D(B, transb), RowCount3D(A, transa),
ColumnCount3D(A, transa), static_cast<typename TBatchGemmOp::TDatatype>(alpha), B_data, B_size,
ColumnStride3D(B), A_data, A_size, ColumnStride3D(A),
static_cast<typename TBatchGemmOp::TDatatype>(beta), C_data, C_size, ColumnStride3D(C));
ColumnCount3D(A, transa), static_cast<typename TBatchGemmOp::TDatatype>(alpha), B_data,
B_stride, ColumnStride3D(B), A_data, A_stride, ColumnStride3D(A),
static_cast<typename TBatchGemmOp::TDatatype>(beta), C_data, C_stride, ColumnStride3D(C));
}

} // namespace contrib
Expand Down
30 changes: 22 additions & 8 deletions src/runtime/contrib/cublas/cublas.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,8 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl)
ICHECK_EQ(A->ndim, 3);
ICHECK_EQ(B->ndim, 3);
ICHECK_EQ(C->ndim, 3);
int batch_size = BatchCount3D(A);
ICHECK_EQ(BatchCount3D(B), batch_size);
ICHECK_EQ(BatchCount3D(C), batch_size);

int batch_size = BatchCount3D(C);
ICHECK_EQ(ElementStride(A), 1);
ICHECK_EQ(ElementStride(B), 1);
ICHECK_EQ(ElementStride(C), 1);
Expand All @@ -299,9 +298,23 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl)
double alpha = args.size() > 5 ? args[5] : 1.0;
double beta = args.size() > 6 ? args[6] : 0.0;

const int A_size = A->shape[1] * A->shape[2];
const int B_size = B->shape[1] * B->shape[2];
const int C_size = C->shape[1] * C->shape[2];
int A_stride = A->shape[1] * A->shape[2];
int B_stride = B->shape[1] * B->shape[2];
int C_stride = C->shape[1] * C->shape[2];

// Broadcast A or B by changing its stride.
int batch_size_a = BatchCount3D(A);
int batch_size_b = BatchCount3D(B);
if (batch_size_a != batch_size_b) {
if (batch_size_a == 1) {
A_stride = 0;
} else if (batch_size_b == 1) {
B_stride = 0;
}
} else {
ICHECK_EQ(batch_size_a, batch_size);
ICHECK_EQ(batch_size_b, batch_size);
}

cudaDataType_t cuda_in_type = GetCudaDataType(A->dtype);
cudaDataType_t cuda_out_type = GetCudaDataType(C->dtype);
Expand All @@ -325,8 +338,9 @@ inline void CallBatchGemmEx(TVMArgs args, TVMRetValue* ret, cublasHandle_t hdl)
CHECK_CUBLAS_ERROR(cublasGemmStridedBatchedEx(
hdl, CUBLASBooleanToTranspose(transb), CUBLASBooleanToTranspose(transa),
ColumnCount3D(B, transb), RowCount3D(A, transa), ColumnCount3D(A, transa), alpha_ptr, B_data,
cuda_in_type, ColumnStride3D(B), B_size, A_data, cuda_in_type, ColumnStride3D(A), A_size,
beta_ptr, C_data, cuda_out_type, ColumnStride3D(C), C_size, batch_size, cuda_out_type, algo));
cuda_in_type, ColumnStride3D(B), B_stride, A_data, cuda_in_type, ColumnStride3D(A), A_stride,
beta_ptr, C_data, cuda_out_type, ColumnStride3D(C), C_stride, batch_size, cuda_out_type,
algo));
}

// matrix multiplication for row major
Expand Down
55 changes: 28 additions & 27 deletions tests/python/contrib/test_cublas.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,33 +112,23 @@ def verify(target="cuda"):
verify()


def verify_batch_matmul(in_dtype, out_dtype, rtol=1e-5):
j = 16
n = 1024
l = 128
m = 236
A = te.placeholder((j, n, l), name="A", dtype=in_dtype)
B = te.placeholder((j, l, m), name="B", dtype=in_dtype)
def verify_batch_matmul(Ashape, Bshape, Cshape, in_dtype, out_dtype, rtol=1e-5):
A = te.placeholder(Ashape, name="A", dtype=in_dtype)
B = te.placeholder(Bshape, name="B", dtype=in_dtype)
C = cublas.batch_matmul(A, B, dtype=out_dtype)
s = te.create_schedule(C.op)

def verify(target="cuda"):
if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
print("skip because extern function is not available")
return
dev = tvm.cuda(0)
f = tvm.build(s, [A, B, C], target)
a = tvm.nd.array(np.random.uniform(size=(j, n, l)).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=(j, l, m)).astype(B.dtype), dev)
c = tvm.nd.array(np.zeros((j, n, m), dtype=C.dtype), dev)
f(a, b, c)
tvm.testing.assert_allclose(
c.numpy(),
np.matmul(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)).astype(C.dtype),
rtol=rtol,
)

verify()
dev = tvm.cuda(0)
f = tvm.build(s, [A, B, C], "cuda")
a = tvm.nd.array(np.random.uniform(size=Ashape).astype(A.dtype), dev)
b = tvm.nd.array(np.random.uniform(size=Bshape).astype(B.dtype), dev)
c = tvm.nd.array(np.zeros(Cshape, dtype=C.dtype), dev)
f(a, b, c)
tvm.testing.assert_allclose(
c.numpy(),
np.matmul(a.numpy().astype(C.dtype), b.numpy().astype(C.dtype)).astype(C.dtype),
rtol=rtol,
)


@tvm.testing.requires_cuda
Expand All @@ -156,9 +146,20 @@ def test_matmul_add_igemm():

@tvm.testing.requires_cuda
def test_batch_matmul():
verify_batch_matmul("float", "float")
verify_batch_matmul("float16", "float")
verify_batch_matmul("float16", "float16", rtol=1e-2)
if not tvm.get_global_func("tvm.contrib.cublas.matmul", True):
print("skip because extern function is not available")
return

verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float", "float")
verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float", "float")
verify_batch_matmul((16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float")
verify_batch_matmul((16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float")
verify_batch_matmul(
(16, 1024, 128), (16, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2
)
verify_batch_matmul(
(16, 1024, 128), (1, 128, 236), (16, 1024, 236), "float16", "float16", rtol=1e-2
)


if __name__ == "__main__":
Expand Down

0 comments on commit b93e56e

Please sign in to comment.