Put the BLAS interface directly in the GemmKernels.jl module. #132

Merged · 1 commit · Jul 5, 2023
2 changes: 1 addition & 1 deletion benchmarks/blas.jl
@@ -29,7 +29,7 @@ function blas_benchmark(group, a_type, b_type, cd_type, N, M=N, K=N; alpha=true,
     # influence from the Julia scheduler
     group[name] = @benchmarkable(
         begin
-            GemmKernels.BLAS.matmatmul!(c, $a_layout, $b_layout, a, b, $alpha, $beta; $(kwargs)...)
+            GemmKernels.matmatmul!(c, $a_layout, $b_layout, a, b, $alpha, $beta; $(kwargs)...)
             CUDA.cuStreamSynchronize(stream())
         end,
         setup=(a=CuArray($a_h); b=CuArray($b_h); c=CuArray($c_h);
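For reference, a minimal sketch of calling the renamed entry point directly, outside the benchmark harness. The shapes and element types are illustrative, and the wmma keyword is optional (it defaults to an eltype-based heuristic, as seen in src/matmul.jl below):

    using CUDA, GemmKernels

    A = CUDA.rand(Float16, 128, 128)
    B = CUDA.rand(Float16, 128, 128)
    C = CUDA.zeros(Float32, 128, 128)

    # C = alpha * A * B + beta * C, with 'N'/'T' selecting plain or transposed operands.
    # Before this PR, the same call was spelled GemmKernels.BLAS.matmatmul!.
    GemmKernels.matmatmul!(C, 'N', 'N', A, B, true, false)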
12 changes: 8 additions & 4 deletions src/GemmKernels.jl
@@ -1,17 +1,21 @@
 module GemmKernels

 using CUDA
+using LinearAlgebra

+# utilities
 include("tiling.jl")
+include("array.jl")

+# framework
 include("config.jl")
 include("epilogue.jl")
-include("array.jl")
 include("kernel.jl")
 include("layout.jl")
 include("operator.jl")
 include("transform.jl")

-include("launch.jl")
-
-include("blas.jl")
+# instantiations
+include("matmul.jl")

 end
30 changes: 0 additions & 30 deletions src/launch.jl

This file was deleted.

73 changes: 57 additions & 16 deletions src/blas.jl → src/matmul.jl
@@ -1,8 +1,40 @@
-module BLAS
+#
+# low-level
+#

+function matmul(a, b, c, d, conf;
+                transform_global_to_shared_a = Transform.Elementwise(),
+                transform_global_to_shared_b = Transform.Elementwise(),
+                transform_global_to_shared_c = Transform.Elementwise(),
+                transform_shared_to_global_d = Transform.Elementwise(),
+                transform_shared_to_regs_a = Transform.Elementwise(),
+                transform_shared_to_regs_b = Transform.Elementwise(),
+                transform_shared_to_regs_c = Transform.Elementwise(),
+                transform_regs_to_shared_d = Transform.Elementwise(),
+                epilogue = Epilogue.Default(),
+                kernel = Kernel.matmul_singlestage)
+
+    args = [a, b, c, d,
+            transform_global_to_shared_a, transform_global_to_shared_b, transform_global_to_shared_c, transform_shared_to_global_d,
+            transform_shared_to_regs_a, transform_shared_to_regs_b, transform_shared_to_regs_c, transform_regs_to_shared_d,
+            epilogue,
+            conf]
+
+    shmem = Kernel.shmem_size(conf, kernel)
+    max_shmem = attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN)
+    if shmem > max_shmem
+        error("Requested too much shared memory: The current GPU can use at most $(Base.format_bytes(max_shmem)), while this configuration required $(Base.format_bytes(shmem))")
+    end
+
+    hostkernel = @cuda launch=false kernel(args...)
+    attributes(hostkernel.fun)[CUDA.FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES] = shmem
+    hostkernel(args...; shmem, conf.launch_args...)
+end

-using CUDA
-using GemmKernels
-using LinearAlgebra
+#
+# BLAS-like
+#

 # Select the best kernel
 kernel(layout_a, layout_b) = Kernel.matmul_singlestage
@@ -44,11 +76,11 @@ end
         # TODO: more, and device-capability dependent
     ]
     compute_type = promote_type(eltype(A), eltype(B))
-    supports_wmma = something(wmma, (compute_type, compute_type, eltype(C)) in wmma_types)
+    use_wmma = something(wmma, (compute_type, compute_type, eltype(C)) in wmma_types)

     # determine shared memory layouts
     ## padded to avoid bank conflicts
-    if supports_wmma
+    if use_wmma
         # in the case of WMMA, the shared memory needs to have the correct type already,
         # as we'll use WMMA intrinsics to load from it.
         shared_a_layout = Layout.Padded{a_aligned_layout_base{compute_type}, 8}
@@ -62,8 +94,8 @@

     # determine block shape
     # XXX: heuristic should take much more into account (GEMM size, at least)
-    block_shape = if supports_wmma
-        GemmKernels.heuristic_block_shape(shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout)
+    block_shape = if use_wmma
+        heuristic_block_shape(shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout)
     else
         # XXX: heuristic for FPU
         (M = 128, N = 128, K = 32)
@@ -106,8 +138,8 @@ end
         Layout.ColMajor{eltype(C)}
     end

-    conf = if supports_wmma
-        GemmKernels.get_config(;
+    conf = if use_wmma
+        get_config(;
             gemm_shape = (M = m, N = n, K = k), block_shape,
             operator = Operator.WMMAOp{16, 16, 16, compute_type, eltype(C)},

@@ -118,7 +150,7 @@
             is_b_col_major = !transB
         )
     else
-        GemmKernels.get_config(;
+        get_config(;
             gemm_shape = (M = m, N = n, K = k), block_shape,
             operator = Operator.FPUOp{8, 8, 1, compute_type, eltype(C)},

@@ -146,12 +178,21 @@ function matmatmul!(C::CuArray, transA::Char, transB::Char, A::CuArray, B::CuArr

     alpha = convert(compute_type, alpha)
     beta = convert(eltype(C), beta)
-    GemmKernels.matmul(parent(A), parent(B), parent(C), parent(C), conf;
-        transform_shared_to_regs_a = Transform.Elementwise(x -> x * alpha),
-        transform_shared_to_regs_c = Transform.Elementwise(x -> x * beta),
-        kernel
-    )
+    matmul(A, B, C, C, conf;
+        transform_shared_to_regs_a = Transform.Elementwise(x -> x * alpha),
+        transform_shared_to_regs_c = Transform.Elementwise(x -> x * beta),
+        kernel
+    )
     C
 end

+# convenience function
+function mul!(C::CuArray,
+              A::Union{CuArray, Adjoint{<:Any,<:CuArray}, Transpose{<:Any,<:CuArray}},
+              B::Union{CuArray, Adjoint{<:Any,<:CuArray}, Transpose{<:Any,<:CuArray}},
+              alpha=true, beta=false)
+    transA = A isa Adjoint || A isa Transpose
+    transB = B isa Adjoint || B isa Transpose
+    matmatmul!(C, transA ? 'T' : 'N', transB ? 'T' : 'N',
+               parent(A), parent(B), alpha, beta)
+end
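Splitting the file this way also makes the low-level path usable on its own: build a configuration with get_config and pass it, together with optional transforms, to matmul. A rough sketch, assuming the get_config keywords shown in the project README (global_a_layout, global_c_layout) plus b/d analogues; the shapes, element types, and the folded-in alpha = 2 are chosen purely for illustration:

    using CUDA, GemmKernels
    using GemmKernels: Layout, Operator, Transform, Kernel

    M = N = K = 256
    A = CUDA.rand(Float16, M, K)
    B = CUDA.rand(Float16, K, N)
    C = CUDA.zeros(Float32, M, N)

    conf = GemmKernels.get_config(;
        gemm_shape = (M = M, N = N, K = K),
        operator = Operator.WMMAOp{16, 16, 16, Float16, Float32},
        global_a_layout = Layout.AlignedColMajor{Float16},
        global_b_layout = Layout.AlignedColMajor{Float16},
        global_c_layout = Layout.AlignedColMajor{Float32},
        global_d_layout = Layout.AlignedColMajor{Float32})

    # computes d = A * B + c, with d aliased to C; the alpha scale is folded
    # into the shared-to-registers transform, mirroring what matmatmul! does
    GemmKernels.matmul(A, B, C, C, conf;
        transform_shared_to_regs_a = Transform.Elementwise(x -> 2 * x),
        kernel = Kernel.matmul_singlestage)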
8 changes: 4 additions & 4 deletions test/blas.jl
@@ -29,8 +29,8 @@ end
     b = CuArray(parent(b_h))
     c = CuArray(c_h)

-    GemmKernels.BLAS.matmatmul!(c, transpose_a ? 'T' : 'N', transpose_b ? 'T' : 'N',
-                                a, b, alpha, beta; wmma=true)
+    GemmKernels.matmatmul!(c, transpose_a ? 'T' : 'N', transpose_b ? 'T' : 'N',
+                           a, b, alpha, beta; wmma=true)
     mul!(c_h, a_h, b_h, alpha, beta)

     @test c_h ≈ Array(c) rtol=sqrt(eps(AB_type))
@@ -58,8 +58,8 @@ end
     b = CuArray(parent(b_h))
     c = CuArray(c_h)

-    GemmKernels.BLAS.matmatmul!(c, transpose_a ? 'T' : 'N', transpose_b ? 'T' : 'N',
-                                a, b, alpha, beta; wmma=false)
+    GemmKernels.matmatmul!(c, transpose_a ? 'T' : 'N', transpose_b ? 'T' : 'N',
+                           a, b, alpha, beta; wmma=false)
     mul!(c_h, a_h, b_h, alpha, beta)

     @test c_h ≈ Array(c) rtol=sqrt(eps(compute_type))
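The new convenience wrapper accepts the standard Adjoint/Transpose wrappers, deriving the 'T'/'N' flags and unwrapping with parent before calling matmatmul!. A sketch; mul! is called fully qualified here since it is accessed through the GemmKernels module rather than being exported:

    using CUDA, GemmKernels
    using LinearAlgebra

    A = CUDA.rand(Float32, 128, 128)
    B = CUDA.rand(Float32, 128, 128)
    C = CUDA.zeros(Float32, 128, 128)

    # transpose(A) becomes transA = 'T'; plain arrays become 'N'
    GemmKernels.mul!(C, transpose(A), B, true, false)

    # host-side reference check, in the spirit of test/blas.jl
    @assert Array(C) ≈ transpose(Array(A)) * Array(B)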