From ba08467b9b78a2ceecc5a8a59f11cef143cb014c Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Wed, 5 Jul 2023 18:48:56 +0200
Subject: [PATCH] Put the BLAS interface directly in the GemmKernels.jl module.

---
 benchmarks/blas.jl         |  2 +-
 src/GemmKernels.jl         | 12 ++++---
 src/launch.jl              | 30 ----------------
 src/{blas.jl => matmul.jl} | 73 +++++++++++++++++++++++++++++---------
 test/blas.jl               |  8 ++---
 5 files changed, 70 insertions(+), 55 deletions(-)
 delete mode 100644 src/launch.jl
 rename src/{blas.jl => matmul.jl} (68%)

diff --git a/benchmarks/blas.jl b/benchmarks/blas.jl
index e62f2899..d020bd54 100644
--- a/benchmarks/blas.jl
+++ b/benchmarks/blas.jl
@@ -29,7 +29,7 @@ function blas_benchmark(group, a_type, b_type, cd_type, N, M=N, K=N; alpha=true,
         # influence from the Julia scheduler
         group[name] = @benchmarkable(
             begin
-                GemmKernels.BLAS.matmatmul!(c, $a_layout, $b_layout, a, b, $alpha, $beta; $(kwargs)...)
+                GemmKernels.matmatmul!(c, $a_layout, $b_layout, a, b, $alpha, $beta; $(kwargs)...)
                 CUDA.cuStreamSynchronize(stream())
             end,
             setup=(a=CuArray($a_h); b=CuArray($b_h); c=CuArray($c_h);
diff --git a/src/GemmKernels.jl b/src/GemmKernels.jl
index 3449c233..3e5c5ad6 100644
--- a/src/GemmKernels.jl
+++ b/src/GemmKernels.jl
@@ -1,17 +1,21 @@
 module GemmKernels
 
+using CUDA
+using LinearAlgebra
+
+# utilities
 include("tiling.jl")
+include("array.jl")
 
+# framework
 include("config.jl")
 include("epilogue.jl")
-include("array.jl")
 include("kernel.jl")
 include("layout.jl")
 include("operator.jl")
 include("transform.jl")
 
-include("launch.jl")
-
-include("blas.jl")
+# instantiations
+include("matmul.jl")
 
 end
diff --git a/src/launch.jl b/src/launch.jl
deleted file mode 100644
index 665ac1b6..00000000
--- a/src/launch.jl
+++ /dev/null
@@ -1,30 +0,0 @@
-using CUDA
-
-function matmul(a, b, c, d, conf;
-                transform_global_to_shared_a = Transform.Elementwise(),
-                transform_global_to_shared_b = Transform.Elementwise(),
-                transform_global_to_shared_c = Transform.Elementwise(),
-                transform_shared_to_global_d = Transform.Elementwise(),
-                transform_shared_to_regs_a = Transform.Elementwise(),
-                transform_shared_to_regs_b = Transform.Elementwise(),
-                transform_shared_to_regs_c = Transform.Elementwise(),
-                transform_regs_to_shared_d = Transform.Elementwise(),
-                epilogue = Epilogue.Default(),
-                kernel = Kernel.matmul_singlestage)
-
-    args = [a, b, c, d,
-            transform_global_to_shared_a, transform_global_to_shared_b, transform_global_to_shared_c, transform_shared_to_global_d,
-            transform_shared_to_regs_a, transform_shared_to_regs_b, transform_shared_to_regs_c, transform_regs_to_shared_d,
-            epilogue,
-            conf]
-
-    shmem = Kernel.shmem_size(conf, kernel)
-    max_shmem = attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN)
-    if shmem > max_shmem
-        error("Requested too much shared memory: The current GPU can use at most $(Base.format_bytes(max_shmem)), while this configuration required $(Base.format_bytes(shmem))")
-    end
-
-    hostkernel = @cuda launch=false kernel(args...)
-    attributes(hostkernel.fun)[CUDA.FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES] = shmem
-    hostkernel(args...; shmem, conf.launch_args...)
-end
diff --git a/src/blas.jl b/src/matmul.jl
similarity index 68%
rename from src/blas.jl
rename to src/matmul.jl
index 2a454acb..25508d25 100644
--- a/src/blas.jl
+++ b/src/matmul.jl
@@ -1,8 +1,40 @@
-module BLAS
+#
+# low-level
+#
+
+function matmul(a, b, c, d, conf;
+                transform_global_to_shared_a = Transform.Elementwise(),
+                transform_global_to_shared_b = Transform.Elementwise(),
+                transform_global_to_shared_c = Transform.Elementwise(),
+                transform_shared_to_global_d = Transform.Elementwise(),
+                transform_shared_to_regs_a = Transform.Elementwise(),
+                transform_shared_to_regs_b = Transform.Elementwise(),
+                transform_shared_to_regs_c = Transform.Elementwise(),
+                transform_regs_to_shared_d = Transform.Elementwise(),
+                epilogue = Epilogue.Default(),
+                kernel = Kernel.matmul_singlestage)
+
+    args = [a, b, c, d,
+            transform_global_to_shared_a, transform_global_to_shared_b, transform_global_to_shared_c, transform_shared_to_global_d,
+            transform_shared_to_regs_a, transform_shared_to_regs_b, transform_shared_to_regs_c, transform_regs_to_shared_d,
+            epilogue,
+            conf]
+
+    shmem = Kernel.shmem_size(conf, kernel)
+    max_shmem = attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN)
+    if shmem > max_shmem
+        error("Requested too much shared memory: The current GPU can use at most $(Base.format_bytes(max_shmem)), while this configuration required $(Base.format_bytes(shmem))")
+    end
+
+    hostkernel = @cuda launch=false kernel(args...)
+    attributes(hostkernel.fun)[CUDA.FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES] = shmem
+    hostkernel(args...; shmem, conf.launch_args...)
+end
+
 
-using CUDA
-using GemmKernels
-using LinearAlgebra
+#
+# BLAS-like
+#
 
 # Select the best kernel
 kernel(layout_a, layout_b) = Kernel.matmul_singlestage
@@ -44,11 +76,11 @@ end
         # TODO: more, and device-capability dependent
     ]
     compute_type = promote_type(eltype(A), eltype(B))
-    supports_wmma = something(wmma, (compute_type, compute_type, eltype(C)) in wmma_types)
+    use_wmma = something(wmma, (compute_type, compute_type, eltype(C)) in wmma_types)
 
     # determine shared memory layouts
     ## padded to avoid bank conflicts
-    if supports_wmma
+    if use_wmma
         # in the case of WMMA, the shared memory needs to have the correct type already,
         # as we'll use WMMA intrinsics to load from it.
         shared_a_layout = Layout.Padded{a_aligned_layout_base{compute_type}, 8}
@@ -62,8 +94,8 @@ end
 
     # determine block shape
    # XXX: heuristic should take much more into account (GEMM size, at least)
-    block_shape = if supports_wmma
-        GemmKernels.heuristic_block_shape(shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout)
+    block_shape = if use_wmma
+        heuristic_block_shape(shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout)
     else
         # XXX: heuristic for FPU
         (M = 128, N = 128, K = 32)
@@ -106,8 +138,8 @@ end
         Layout.ColMajor{eltype(C)}
     end
 
-    conf = if supports_wmma
-        GemmKernels.get_config(;
+    conf = if use_wmma
+        get_config(;
             gemm_shape = (M = m, N = n, K = k), block_shape,
             operator = Operator.WMMAOp{16, 16, 16, compute_type, eltype(C)},
 
@@ -118,7 +150,7 @@ end
             is_b_col_major = !transB
         )
     else
-        GemmKernels.get_config(;
+        get_config(;
             gemm_shape = (M = m, N = n, K = k), block_shape,
             operator = Operator.FPUOp{8, 8, 1, compute_type, eltype(C)},
 
@@ -146,12 +178,21 @@ function matmatmul!(C::CuArray, transA::Char, transB::Char, A::CuArray, B::CuArr
     alpha = convert(compute_type, alpha)
     beta = convert(eltype(C), beta)
 
-    GemmKernels.matmul(parent(A), parent(B), parent(C), parent(C), conf;
-                       transform_shared_to_regs_a = Transform.Elementwise(x -> x * alpha),
-                       transform_shared_to_regs_c = Transform.Elementwise(x -> x * beta),
-                       kernel
-                      )
+    matmul(A, B, C, C, conf;
+           transform_shared_to_regs_a = Transform.Elementwise(x -> x * alpha),
+           transform_shared_to_regs_c = Transform.Elementwise(x -> x * beta),
+           kernel
+          )
     C
 end
 
+# convenience function
+function mul!(C::CuArray,
+              A::Union{CuArray, Adjoint{<:Any,<:CuArray}, Transpose{<:Any,<:CuArray}},
+              B::Union{CuArray, Adjoint{<:Any,<:CuArray}, Transpose{<:Any,<:CuArray}},
+              alpha=true, beta=false)
+    transA = A isa Adjoint || A isa Transpose
+    transB = B isa Adjoint || B isa Transpose
+    matmatmul!(C, transA ? 'T' : 'N', transB ? 'T' : 'N',
+               parent(A), parent(B), alpha, beta)
 end
diff --git a/test/blas.jl b/test/blas.jl
index 2951660f..f2280aeb 100644
--- a/test/blas.jl
+++ b/test/blas.jl
@@ -29,8 +29,8 @@ end
         b = CuArray(parent(b_h))
         c = CuArray(c_h)
 
-        GemmKernels.BLAS.matmatmul!(c, transpose_a ? 'T' : 'N', transpose_b ? 'T' : 'N',
-                                    a, b, alpha, beta; wmma=true)
+        GemmKernels.matmatmul!(c, transpose_a ? 'T' : 'N', transpose_b ? 'T' : 'N',
+                               a, b, alpha, beta; wmma=true)
 
         mul!(c_h, a_h, b_h, alpha, beta)
         @test c_h ≈ Array(c) rtol=sqrt(eps(AB_type))
@@ -58,8 +58,8 @@ end
        b = CuArray(parent(b_h))
        c = CuArray(c_h)
 
-        GemmKernels.BLAS.matmatmul!(c, transpose_a ? 'T' : 'N', transpose_b ? 'T' : 'N',
-                                    a, b, alpha, beta; wmma=false)
+        GemmKernels.matmatmul!(c, transpose_a ? 'T' : 'N', transpose_b ? 'T' : 'N',
+                               a, b, alpha, beta; wmma=false)
 
        mul!(c_h, a_h, b_h, alpha, beta)
        @test c_h ≈ Array(c) rtol=sqrt(eps(compute_type))
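A minimal usage sketch of the relocated entry point follows. The sizes, element types, and the wmma=true choice below are illustrative assumptions (a WMMA-capable GPU, Float16 inputs, Float32 accumulator), not values taken from the patch; the call shape itself matches the benchmark and test call sites above.

    using CUDA, GemmKernels

    # illustrative problem size and types (assumed, not from the patch)
    M = N = K = 256
    a = CuArray(rand(Float16, M, K))
    b = CuArray(rand(Float16, K, N))
    c = CuArray(zeros(Float32, M, N))

    # previously GemmKernels.BLAS.matmatmul!; now exposed directly from the module:
    # c = 1 * a * b + 0 * c, using the WMMA path
    GemmKernels.matmatmul!(c, 'N', 'N', a, b, true, false; wmma=true)

The mul!-style convenience wrapper added at the end of src/matmul.jl forwards to this same matmatmul!, translating Adjoint/Transpose wrappers into 'T' flags and unwrapping them with parent before the call.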