Skip to content

Commit

Permalink
Configure and check shared memory automatically. (#112)
Browse files Browse the repository at this point in the history
* Configure and check shared memory automatically.

* Add @inbounds to avoid exception branches on shmem construction.

* Make shmem configuration depend on the kernel.
  • Loading branch information
maleadt authored Jun 28, 2023
1 parent 798e55b commit 2a6ad1d
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 13 deletions.
5 changes: 2 additions & 3 deletions src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ end

@inline function Base.getproperty(conf::Type{Config{MATMUL_SHAPE, BLOCK_SHAPE, WARPS_PER_BLOCK, MEM_A_WARP, MEM_A_THREAD, MEM_B_WARP, MEM_B_THREAD, MEM_CD_WARP, MEM_CD_THREAD, COMPUTE_WARP, COMPUTE_OP_SHAPE, GLOBAL_A_LAYOUT, GLOBAL_B_LAYOUT, GLOBAL_C_LAYOUT, GLOBAL_D_LAYOUT, SHARED_A_LAYOUT, SHARED_B_LAYOUT, SHARED_C_LAYOUT, SHARED_D_LAYOUT, OPERATOR, IS_A_COL_MAJOR, IS_B_COL_MAJOR}}, sym::Symbol) where {MATMUL_SHAPE, BLOCK_SHAPE, WARPS_PER_BLOCK, MEM_A_WARP, MEM_A_THREAD, MEM_B_WARP, MEM_B_THREAD, MEM_CD_WARP, MEM_CD_THREAD, COMPUTE_WARP, COMPUTE_OP_SHAPE, GLOBAL_A_LAYOUT, GLOBAL_B_LAYOUT, GLOBAL_C_LAYOUT, GLOBAL_D_LAYOUT, SHARED_A_LAYOUT, SHARED_B_LAYOUT, SHARED_C_LAYOUT, SHARED_D_LAYOUT, OPERATOR, IS_A_COL_MAJOR, IS_B_COL_MAJOR}
if sym == :launch_args
(threads = WARPS_PER_BLOCK * 32,
blocks = (MATMUL_SHAPE.M ÷ BLOCK_SHAPE.M, MATMUL_SHAPE.N ÷ BLOCK_SHAPE.N),
shmem = 64 * 1024)
(; threads = WARPS_PER_BLOCK * 32,
blocks = (MATMUL_SHAPE.M ÷ BLOCK_SHAPE.M, MATMUL_SHAPE.N ÷ BLOCK_SHAPE.N))

# convenience accessors for typevars
elseif sym == :matmul_shape
Expand Down
48 changes: 40 additions & 8 deletions src/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ function matmul_singlestage(a, b, c, d,
block_tile = Tile(conf.block_shape)

# (1) Cooperatively load a block_shape.M x block_shape.N tile of C from global to shared memory within one threadblock
shmem_c = CuDynamicSharedArray(Layout.eltype(conf.shared_c_layout), Layout.physical_size(conf.shared_c_layout, block_tile.MN.size))
shmem_c = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_c_layout), Layout.physical_size(conf.shared_c_layout, block_tile.MN.size))

@loopinfo unroll for warp_tile = parallellise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
@loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
Expand All @@ -55,8 +55,8 @@ function matmul_singlestage(a, b, c, d,
sync_threads()

# (3) Compute a block_shape.M x block_shape.N x block_shape.K matrix product within one threadblock
shmem_a = CuDynamicSharedArray(Layout.eltype(conf.shared_a_layout), Layout.physical_size(conf.shared_a_layout, block_tile.MK.size))
shmem_b = CuDynamicSharedArray(Layout.eltype(conf.shared_b_layout), Layout.physical_size(conf.shared_b_layout, block_tile.KN.size),
shmem_a = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_a_layout), Layout.physical_size(conf.shared_a_layout, block_tile.MK.size))
shmem_b = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_b_layout), Layout.physical_size(conf.shared_b_layout, block_tile.KN.size),
length(shmem_a) * sizeof(Layout.eltype(conf.shared_a_layout)))

@loopinfo unroll for block_k = 0 : block_tile.size.K : gemm_sz.size.K - 1
Expand Down Expand Up @@ -112,7 +112,7 @@ function matmul_singlestage(a, b, c, d,
end

# (4) Store the compute_warp.M x compute_warp.N tile of D from registers to shared memory
shmem_d = CuDynamicSharedArray(Layout.eltype(conf.shared_d_layout), Layout.physical_size(conf.shared_d_layout, block_tile.MN.size))
shmem_d = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_d_layout), Layout.physical_size(conf.shared_d_layout, block_tile.MN.size))

warp_tile = subdivide(block_tile.MN, Tile(conf.compute_warp).MN, warpId, conf.warps_per_block)

Expand All @@ -131,6 +131,22 @@ function matmul_singlestage(a, b, c, d,
return
end

function shmem_size(::Type{conf}, ::typeof(matmul_singlestage)) where {conf <: GemmKernels.Config}
size_a = sizeof(Layout.eltype(conf.shared_a_layout)) *
prod(Layout.physical_size(conf.shared_a_layout,
(; conf.block_shape.M, conf.block_shape.K)))
size_b = sizeof(Layout.eltype(conf.shared_b_layout)) *
prod(Layout.physical_size(conf.shared_b_layout,
(; conf.block_shape.K, conf.block_shape.N)))
size_c = sizeof(Layout.eltype(conf.shared_c_layout)) *
prod(Layout.physical_size(conf.shared_c_layout,
(; conf.block_shape.M, conf.block_shape.N)))
size_d = sizeof(Layout.eltype(conf.shared_d_layout)) *
prod(Layout.physical_size(conf.shared_d_layout,
(; conf.block_shape.M, conf.block_shape.N)))
max(size_c, size_a + size_b, size_d)
end

function matmul_pipelined(a, b, c, d,
transf_gl2sh_a, transf_gl2sh_b, transf_gl2sh_c, transf_sh2gl_d,
transf_sh2rf_a, transf_sh2rf_b, transf_sh2rf_c, transf_rf2sh_d,
Expand All @@ -151,7 +167,7 @@ function matmul_pipelined(a, b, c, d,
block_tile = Tile(conf.block_shape)

# (1) Cooperatively load a block_shape.M x block_shape.N tile of C from global to shared memory within one threadblock
shmem_c = CuDynamicSharedArray(Layout.eltype(conf.shared_c_layout), Layout.physical_size(conf.shared_c_layout, block_tile.MN.size))
shmem_c = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_c_layout), Layout.physical_size(conf.shared_c_layout, block_tile.MN.size))

@loopinfo unroll for warp_tile = parallellise(block_tile.MN, Tile(conf.mem_cd_warp), warpId, conf.warps_per_block)
@loopinfo unroll for thread_tile = parallellise(warp_tile, Tile(conf.mem_cd_thread), laneId, 32)
Expand All @@ -178,8 +194,8 @@ function matmul_pipelined(a, b, c, d,
sync_threads()

# (3) Compute a block_shape.M x block_shape.N x block_shape.K matrix product within one threadblock
shmem_a = CuDynamicSharedArray(Layout.eltype(conf.shared_a_layout), Layout.physical_size(conf.shared_a_layout, block_tile.MK.size))
shmem_b = CuDynamicSharedArray(Layout.eltype(conf.shared_b_layout), Layout.physical_size(conf.shared_b_layout, block_tile.KN.size),
shmem_a = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_a_layout), Layout.physical_size(conf.shared_a_layout, block_tile.MK.size))
shmem_b = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_b_layout), Layout.physical_size(conf.shared_b_layout, block_tile.KN.size),
length(shmem_a) * sizeof(Layout.eltype(conf.shared_a_layout)))

# Sizes of a_fragment and b_fragment
Expand Down Expand Up @@ -319,7 +335,7 @@ function matmul_pipelined(a, b, c, d,
end

# (4) Store the compute_warp.M x compute_warp.N tile of D from registers to shared memory
shmem_d = CuDynamicSharedArray(Layout.eltype(conf.shared_d_layout), Layout.physical_size(conf.shared_d_layout, block_tile.MN.size))
shmem_d = @inbounds CuDynamicSharedArray(Layout.eltype(conf.shared_d_layout), Layout.physical_size(conf.shared_d_layout, block_tile.MN.size))

warp_tile = subdivide(block_tile.MN, Tile(conf.compute_warp).MN, warpId, conf.warps_per_block)

Expand All @@ -338,4 +354,20 @@ function matmul_pipelined(a, b, c, d,
return
end

function shmem_size(::Type{conf}, ::typeof(matmul_pipelined)) where {conf <: GemmKernels.Config}
size_a = sizeof(Layout.eltype(conf.shared_a_layout)) *
prod(Layout.physical_size(conf.shared_a_layout,
(; conf.block_shape.M, conf.block_shape.K)))
size_b = sizeof(Layout.eltype(conf.shared_b_layout)) *
prod(Layout.physical_size(conf.shared_b_layout,
(; conf.block_shape.K, conf.block_shape.N)))
size_c = sizeof(Layout.eltype(conf.shared_c_layout)) *
prod(Layout.physical_size(conf.shared_c_layout,
(; conf.block_shape.M, conf.block_shape.N)))
size_d = sizeof(Layout.eltype(conf.shared_d_layout)) *
prod(Layout.physical_size(conf.shared_d_layout,
(; conf.block_shape.M, conf.block_shape.N)))
max(size_c, size_a + size_b, size_d)
end

end
10 changes: 8 additions & 2 deletions src/launch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ function matmul(a, b, c, d, conf;
epilogue,
conf]

shmem = Kernel.shmem_size(conf, kernel)
max_shmem = attribute(device(), CUDA.DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN)
if shmem > max_shmem
error("Requested too much shared memory: The current GPU can use at most $(Base.format_bytes(max_shmem)), while this configuration required $(Base.format_bytes(shmem))")
end

hostkernel = @cuda launch=false kernel(args...)
attributes(hostkernel.fun)[CUDA.FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES] = 64 * 1024
hostkernel(args...; conf.launch_args...)
attributes(hostkernel.fun)[CUDA.FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES] = shmem
hostkernel(args...; shmem, conf.launch_args...)
end

0 comments on commit 2a6ad1d

Please sign in to comment.