Skip to content

Commit

Permalink
Extend set of WMMA operator shapes (#183)
Browse files Browse the repository at this point in the history
* Extend set of WMMA operator shapes

* Skip unsupported configurations
  • Loading branch information
thomasfaingnaert authored Jan 2, 2024
1 parent 3c328d1 commit bd718f5
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 53 deletions.
83 changes: 68 additions & 15 deletions configs/configs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ macro get_wmma_complex_config()

conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
operator = Operator.WMMAComplexOp{OP_M, OP_N, OP_K},
operator = Operator.WMMAComplexOp{OP_M, OP_N, OP_K, AB_type, CD_type},

global_a_layout = transpose_a ? Layout.InterleavedRowMajor{Float16} : Layout.InterleavedColMajor{Float16},
global_b_layout = transpose_b ? Layout.InterleavedRowMajor{Float16} : Layout.InterleavedColMajor{Float16},
Expand Down Expand Up @@ -391,7 +391,7 @@ macro get_wmma_dual_config()

conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
operator = Operator.WMMADualOp{OP_M, OP_N, OP_K},
operator = Operator.WMMADualOp{OP_M, OP_N, OP_K, AB_type, CD_type},

global_a_layout = Layout.InterleavedColMajor{Float16},
global_b_layout = Layout.InterleavedColMajor{Float16},
Expand Down Expand Up @@ -459,7 +459,11 @@ function get_configs()
# XXX: Should we do non-square matrices as well?
M = K = N

push!(rv, @get_fpu_config)
try
push!(rv, @get_fpu_config)
catch err
isa(err, GemmKernels.ConfigError) || rethrow()
end
end

# FPU Op shapes
Expand Down Expand Up @@ -488,7 +492,11 @@ function get_configs()
# We'll only test square matrices.
M = K = N

push!(rv, @get_fpu_config)
try
push!(rv, @get_fpu_config)
catch err
isa(err, GemmKernels.ConfigError) || rethrow()
end
end

# Tropical GEMM
Expand All @@ -503,7 +511,11 @@ function get_configs()
[1, 1, 2],
[2, 2, 2]]

push!(rv, @get_tropical_config)
try
push!(rv, @get_tropical_config)
catch err
isa(err, GemmKernels.ConfigError) || rethrow()
end
end

# WMMA GEMM
Expand All @@ -514,7 +526,11 @@ function get_configs()
transpose_b = [false, true],
(BLOCK_M, BLOCK_N, BLOCK_K) in [(128, 128, 64)],
(WARPS_M, WARPS_N) in [(4, 2)],
(OP_M, OP_N, OP_K) in [(16, 16, 16)],
(OP_M, OP_N, OP_K) in [
(16, 16, 16),
(8, 32, 16),
(32, 8, 16),
],
(M, N, K) in vcat(min_dimension .* [
[1, 1, 1],
[2, 2, 1],
Expand All @@ -537,57 +553,94 @@ function get_configs()
(OP_M, OP_N, OP_K) in [(16, 16, 16)],
kernel in [Kernel.matmul_singlestage, Kernel.matmul_pipelined]

push!(rv, @get_wmma_config)
try
push!(rv, @get_wmma_config)
catch err
isa(err, GemmKernels.ConfigError) || rethrow()
end
end

# WMMA GEMM + bias
for (AB_type, CD_type, min_dimension) in [
(Float16, Float32, 128)],
transpose_a = [false, true],
transpose_b = [false, true],
(OP_M, OP_N, OP_K) in [(16, 16, 16)],
(OP_M, OP_N, OP_K) in [
(16, 16, 16),
(8, 32, 16),
(32, 8, 16),
],
(M, N, K) in vcat(min_dimension .* [
[1, 1, 1],
[2, 2, 2]], [[4096, 4096, 4096]])
push!(rv, @get_wmma_bias_config)

try
push!(rv, @get_wmma_bias_config)
catch err
isa(err, GemmKernels.ConfigError) || rethrow()
end
end

# WMMA Diagonal GEMM
for (AB_type, CD_type, min_dimension) in [
(Float16, Float32, 128)],
transpose_b = [false, true],
(OP_M, OP_N, OP_K) in [(16, 16, 16)],
(OP_M, OP_N, OP_K) in [
(16, 16, 16),
(8, 32, 16),
(32, 8, 16),
],
(M, N, K) in vcat(min_dimension .* [
[1, 1, 1],
[2, 2, 2]], [[4096, 4096, 4096]])

push!(rv, @get_wmma_diagonal_config)
try
push!(rv, @get_wmma_diagonal_config)
catch err
isa(err, GemmKernels.ConfigError) || rethrow()
end
end

# WMMA Complex GEMM
for (AB_type, CD_type) in [(Float16, Float32)],
transpose_a = [false, true],
transpose_b = [false, true],
(OP_M, OP_N, OP_K) in [(16, 16, 16)],
(OP_M, OP_N, OP_K) in [
(16, 16, 16),
(8, 32, 16),
(32, 8, 16),
],
(M, N, K) in [
(128, 128, 128),
(256, 256, 256),
(2048, 2048, 2048)]

push!(rv, @get_wmma_complex_config)
try
push!(rv, @get_wmma_complex_config)
catch err
isa(err, GemmKernels.ConfigError) || rethrow()
end
end

# WMMA Dual GEMM
for (AB_type, CD_type) in [(Float16, Float32)],
transpose_a = [false],
transpose_b = [false],
(OP_M, OP_N, OP_K) in [(16, 16, 16)],
(OP_M, OP_N, OP_K) in [
(16, 16, 16),
(8, 32, 16),
(32, 8, 16),
],
(M, N, K) in [
(128, 128, 128),
(256, 256, 256),
(2048, 2048, 2048)]

push!(rv, @get_wmma_dual_config)
try
push!(rv, @get_wmma_dual_config)
catch err
isa(err, GemmKernels.ConfigError) || rethrow()
end
end

rv
Expand Down
7 changes: 6 additions & 1 deletion src/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ end
function check_wmma_shape(operator::Type)
op_shape = Operator.shape(operator)

if op_shape ∉ [(M=16, N=16, K=16)]
if op_shape ∉ [
(M=16, N=16, K=16),
(M=8, N=32, K=16),
(M=32, N=8, K=16),
]
throw(ConfigError("Unsupported WMMA Operator shape $(op_shape)!"))
end
end
Expand Down Expand Up @@ -257,6 +261,7 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw
check_tile_multiple(num, den, dims, msg) = all([num[dim] % den[dim] == 0 for dim in dims]) || throw(ConfigError(msg))

check_tile_multiple(block_shape, compute_warp, [:M, :N, :K], "block_shape must be a multiple of compute_warp!")
check_tile_multiple(compute_warp, op_shape, [:M, :N, :K], "compute_warp must be a multiple of op_shape!")
require_tile_sized_global(global_a_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :K], "gemm_shape.MK must be a multiple of block_shape.MK!")
require_tile_sized_global(global_b_layout) && check_tile_multiple(gemm_shape, block_shape, [:K, :N], "gemm_shape.KN must be a multiple of block_shape.KN!")
require_tile_sized_global(global_c_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :N], "gemm_shape.MN must be a multiple of block_shape.MN!")
Expand Down
90 changes: 53 additions & 37 deletions src/operator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,26 @@ struct WMMAOp{M, N, K, CT, AT} end

@inline shape(::Type{WMMAOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K)

# Generate the fragment-type accessors of `WMMAOp` for every supported operator shape
# and every supported layout. `ColMajor`/`UnsafeAlignedColMajor` both map to
# `WMMA.ColMajor`, and `RowMajor`/`UnsafeAlignedRowMajor` both map to `WMMA.RowMajor`.
for (op_m, op_n, op_k) in [(16, 16, 16), (8, 32, 16), (32, 8, 16)]
    for (gk_layout, frag_layout) in [
            (Layout.ColMajor, WMMA.ColMajor),
            (Layout.UnsafeAlignedColMajor, WMMA.ColMajor),
            (Layout.RowMajor, WMMA.RowMajor),
            (Layout.UnsafeAlignedRowMajor, WMMA.RowMajor),
        ]
        @eval begin
            # TODO: Have accessors in CUDA.jl to get the fragment sizes?
            # The fragment lengths (16 for A and B, 8 for the accumulator) are hard-coded
            # for the FP16 shapes (16, 16, 16), (8, 32, 16), and (32, 8, 16).

            # Fragment type holding a tile of the A operand.
            @inline function fragtype_a(::Type{WMMAOp{$op_m, $op_n, $op_k, CT, AT}}, ::Type{$gk_layout{CT}}) where {CT, AT}
                return WMMA.Fragment{$op_m, $op_n, $op_k, 16, CT, $frag_layout, WMMA.MatrixA}
            end

            # Fragment type holding a tile of the B operand.
            @inline function fragtype_b(::Type{WMMAOp{$op_m, $op_n, $op_k, CT, AT}}, ::Type{$gk_layout{CT}}) where {CT, AT}
                return WMMA.Fragment{$op_m, $op_n, $op_k, 16, CT, $frag_layout, WMMA.MatrixB}
            end

            # Fragment type of the accumulator; its layout is left unspecified.
            @inline function fragtype_accum(::Type{WMMAOp{$op_m, $op_n, $op_k, CT, AT}}, ::Type{$gk_layout{AT}}) where {CT, AT}
                return WMMA.Fragment{$op_m, $op_n, $op_k, 8, AT, WMMA.Unspecified, WMMA.Accumulator}
            end
        end
    end
end

# convert_index_func: function used to transpose the index in case of a row-major layout
for (layout_type, wmma_layout_type, convert_index_func) in [
(Layout.ColMajor, WMMA.ColMajor, identity),
Expand All @@ -164,10 +184,6 @@ for (layout_type, wmma_layout_type, convert_index_func) in [
(Layout.UnsafeAlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x))),
]
@eval begin
@inline fragtype_a(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 16, CT, $wmma_layout_type, WMMA.MatrixA}
@inline fragtype_b(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 16, CT, $wmma_layout_type, WMMA.MatrixB}
@inline fragtype_accum(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{AT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 8, AT, WMMA.Unspecified, WMMA.Accumulator}

@inline function load_a(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}

Expand Down Expand Up @@ -219,46 +235,46 @@ end
# WMMAComplex
# -----------

struct WMMAComplexOp{M, N, K} end
struct WMMAComplexOp{M, N, K, CT, AT} end

@inline shape(::Type{WMMAComplexOp{M, N, K}}) where {M, N, K} = (M = M, N = N, K = K)
@inline shape(::Type{WMMAComplexOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K)

# convert_index_func: function used to transpose the index in case of a row-major layout
for (layout_type, wmma_layout_type, convert_index_func) in [
(Layout.SplitColMajor, WMMA.ColMajor, identity),
(Layout.SplitRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x))),
for (layout_type, base_layout, wmma_layout_type, convert_index_func) in [
(Layout.SplitColMajor, Layout.UnsafeAlignedColMajor, WMMA.ColMajor, identity),
(Layout.SplitRowMajor, Layout.UnsafeAlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x))),
]
@eval begin
@inline fragtype_a(::Type{WMMAComplexOp{16, 16, 16}}, ::Type{$layout_type{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixA}}
@inline fragtype_b(::Type{WMMAComplexOp{16, 16, 16}}, ::Type{$layout_type{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixB}}
@inline fragtype_accum(::Type{WMMAComplexOp{16, 16, 16}}, ::Type{$layout_type{Float32}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}}
@inline fragtype_a(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_a(WMMAOp{M, N, K, CT, AT}, $base_layout{CT})}
@inline fragtype_b(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_b(WMMAOp{M, N, K, CT, AT}, $base_layout{CT})}
@inline fragtype_accum(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_accum(WMMAOp{M, N, K, CT, AT}, $base_layout{AT})}

@inline function load_a(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K}
conf = WMMA.Config{16, 16, 16, Float32}
@inline function load_a(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}
ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2]))

return (WMMA.load_a(pointer(workspace, ind), size(workspace)[1], $wmma_layout_type, conf),
WMMA.load_a(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], $wmma_layout_type, conf))
end

@inline function load_b(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K}
conf = WMMA.Config{16, 16, 16, Float32}
@inline function load_b(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}
ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2]))

return (WMMA.load_b(pointer(workspace, ind), size(workspace)[1], $wmma_layout_type, conf),
WMMA.load_b(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], $wmma_layout_type, conf))
end

@inline function load_c(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float32}}, workspace, tile::Tile) where {M, N, K}
conf = WMMA.Config{M, N, K, Float32}
@inline function load_c(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}
ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2]))

return (WMMA.load_c(pointer(workspace, ind), size(workspace)[1], $wmma_layout_type, conf),
WMMA.load_c(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], $wmma_layout_type, conf))
end

@inline function store_d(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
conf = WMMA.Config{M, N, K, Float32}
@inline function store_d(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}
ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2]))

WMMA.store_d(pointer(workspace, ind), frag[1], size(workspace)[1], $wmma_layout_type, conf)
Expand All @@ -269,8 +285,8 @@ end

using LLVM

@inline function mma(::Type{WMMAComplexOp{M, N, K}}, a_frag, b_frag, c_frag) where {M, N, K}
conf = WMMA.Config{16, 16, 16, Float32}
@inline function mma(::Type{WMMAComplexOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}

c_re = c_frag[1]
c_im = c_frag[2]
Expand All @@ -288,48 +304,48 @@ end
# WMMADual
# --------

struct WMMADualOp{M, N, K} end
struct WMMADualOp{M, N, K, CT, AT} end

@inline shape(::Type{WMMADualOp{M, N, K}}) where {M, N, K} = (M = M, N = N, K = K)
@inline shape(::Type{WMMADualOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K)

@inline fragtype_a(::Type{WMMADualOp{16, 16, 16}}, ::Type{Layout.SplitColMajor{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, WMMA.ColMajor, WMMA.MatrixA}}
@inline fragtype_b(::Type{WMMADualOp{16, 16, 16}}, ::Type{Layout.SplitColMajor{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, WMMA.ColMajor, WMMA.MatrixB}}
@inline fragtype_accum(::Type{WMMADualOp{16, 16, 16}}, ::Type{Layout.SplitColMajor{Float32}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}}
@inline fragtype_a(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_a(WMMAOp{M, N, K, CT, AT}, Layout.UnsafeAlignedColMajor{CT})}
@inline fragtype_b(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_b(WMMAOp{M, N, K, CT, AT}, Layout.UnsafeAlignedColMajor{CT})}
@inline fragtype_accum(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{AT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_accum(WMMAOp{M, N, K, CT, AT}, Layout.UnsafeAlignedColMajor{AT})}

@inline function load_a(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
conf = WMMA.Config{16, 16, 16, Float32}
@inline function load_a(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}
ind = linearise(tile.index, (size(workspace)[1], size(workspace)[2]))

return (WMMA.load_a(pointer(workspace, ind), size(workspace)[1], WMMA.ColMajor, conf),
WMMA.load_a(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], WMMA.ColMajor, conf))
end

@inline function load_b(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float16}}, workspace, tile::Tile) where {M, N, K}
conf = WMMA.Config{16, 16, 16, Float32}
@inline function load_b(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}
ind = linearise(tile.index, (size(workspace)[1], size(workspace)[2]))

return (WMMA.load_b(pointer(workspace, ind), size(workspace)[1], WMMA.ColMajor, conf),
WMMA.load_b(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], WMMA.ColMajor, conf))
end

@inline function load_c(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float32}}, workspace, tile::Tile) where {M, N, K}
conf = WMMA.Config{M, N, K, Float32}
@inline function load_c(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{AT}}, workspace, tile::Tile) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}
ind = linearise(tile.index, (size(workspace)[1], size(workspace)[2]))

return (WMMA.load_c(pointer(workspace, ind), size(workspace)[1], WMMA.ColMajor, conf),
WMMA.load_c(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], WMMA.ColMajor, conf))
end

@inline function store_d(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
conf = WMMA.Config{M, N, K, Float32}
@inline function store_d(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{AT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}
ind = linearise(tile.index, (size(workspace)[1], size(workspace)[2]))

WMMA.store_d(pointer(workspace, ind), frag[1], size(workspace)[1], WMMA.ColMajor, conf)
WMMA.store_d(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), frag[2], size(workspace)[1], WMMA.ColMajor, conf)
end

@inline function mma(::Type{WMMADualOp{M, N, K}}, a_frag, b_frag, c_frag) where {M, N, K}
conf = WMMA.Config{16, 16, 16, Float32}
@inline function mma(::Type{WMMADualOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT}
conf = WMMA.Config{M, N, K, AT}

c_re = c_frag[1]
c_du = c_frag[2]
Expand Down

0 comments on commit bd718f5

Please sign in to comment.