From bd718f5908ba43848362d4fd7735b3c7981017bd Mon Sep 17 00:00:00 2001 From: Thomas Faingnaert Date: Tue, 2 Jan 2024 14:01:42 +0100 Subject: [PATCH] Extend set of WMMA operator shapes (#183) * Extend set of WMMA operator shapes * Skip unsupported configurations --- configs/configs.jl | 83 ++++++++++++++++++++++++++++++++++-------- src/config.jl | 7 +++- src/operator.jl | 90 +++++++++++++++++++++++++++------------------- 3 files changed, 127 insertions(+), 53 deletions(-) diff --git a/configs/configs.jl b/configs/configs.jl index ca2f9a62..1097746f 100644 --- a/configs/configs.jl +++ b/configs/configs.jl @@ -337,7 +337,7 @@ macro get_wmma_complex_config() conf = GemmKernels.get_config( gemm_shape = (M = M, N = N, K = K), - operator = Operator.WMMAComplexOp{OP_M, OP_N, OP_K}, + operator = Operator.WMMAComplexOp{OP_M, OP_N, OP_K, AB_type, CD_type}, global_a_layout = transpose_a ? Layout.InterleavedRowMajor{Float16} : Layout.InterleavedColMajor{Float16}, global_b_layout = transpose_b ? Layout.InterleavedRowMajor{Float16} : Layout.InterleavedColMajor{Float16}, @@ -391,7 +391,7 @@ macro get_wmma_dual_config() conf = GemmKernels.get_config( gemm_shape = (M = M, N = N, K = K), - operator = Operator.WMMADualOp{OP_M, OP_N, OP_K}, + operator = Operator.WMMADualOp{OP_M, OP_N, OP_K, AB_type, CD_type}, global_a_layout = Layout.InterleavedColMajor{Float16}, global_b_layout = Layout.InterleavedColMajor{Float16}, @@ -459,7 +459,11 @@ function get_configs() # XXX: Should we do non-square matrices as well? M = K = N - push!(rv, @get_fpu_config) + try + push!(rv, @get_fpu_config) + catch err + isa(err, GemmKernels.ConfigError) || rethrow() + end end # FPU Op shapes @@ -488,7 +492,11 @@ function get_configs() # We'll only test square matrices. M = K = N - push!(rv, @get_fpu_config) + try + push!(rv, @get_fpu_config) + catch err + isa(err, GemmKernels.ConfigError) || rethrow() + end end # Tropical GEMM @@ -503,7 +511,11 @@ function get_configs() [1, 1, 2], [2, 2, 2]] - push!(rv, @get_tropical_config) + try + push!(rv, @get_tropical_config) + catch err + isa(err, GemmKernels.ConfigError) || rethrow() + end end # WMMA GEMM @@ -514,7 +526,11 @@ function get_configs() transpose_b = [false, true], (BLOCK_M, BLOCK_N, BLOCK_K) in [(128, 128, 64)], (WARPS_M, WARPS_N) in [(4, 2)], - (OP_M, OP_N, OP_K) in [(16, 16, 16)], + (OP_M, OP_N, OP_K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16), + ], (M, N, K) in vcat(min_dimension .* [ [1, 1, 1], [2, 2, 1], @@ -537,7 +553,11 @@ function get_configs() (OP_M, OP_N, OP_K) in [(16, 16, 16)], kernel in [Kernel.matmul_singlestage, Kernel.matmul_pipelined] - push!(rv, @get_wmma_config) + try + push!(rv, @get_wmma_config) + catch err + isa(err, GemmKernels.ConfigError) || rethrow() + end end # WMMA GEMM + bias @@ -545,49 +565,82 @@ function get_configs() (Float16, Float32, 128)], transpose_a = [false, true], transpose_b = [false, true], - (OP_M, OP_N, OP_K) in [(16, 16, 16)], + (OP_M, OP_N, OP_K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16), + ], (M, N, K) in vcat(min_dimension .* [ [1, 1, 1], [2, 2, 2]], [[4096, 4096, 4096]]) - push!(rv, @get_wmma_bias_config) + + try + push!(rv, @get_wmma_bias_config) + catch err + isa(err, GemmKernels.ConfigError) || rethrow() + end end # WMMA Diagonal GEMM for (AB_type, CD_type, min_dimension) in [ (Float16, Float32, 128)], transpose_b = [false, true], - (OP_M, OP_N, OP_K) in [(16, 16, 16)], + (OP_M, OP_N, OP_K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16), + ], (M, N, K) in vcat(min_dimension .* [ [1, 1, 1], [2, 2, 2]], [[4096, 4096, 4096]]) - push!(rv, @get_wmma_diagonal_config) + try + push!(rv, @get_wmma_diagonal_config) + catch err + isa(err, GemmKernels.ConfigError) || rethrow() + end end # WMMA Complex GEMM for (AB_type, CD_type) in [(Float16, Float32)], transpose_a = [false, true], transpose_b = [false, true], - (OP_M, OP_N, OP_K) in [(16, 16, 16)], + (OP_M, OP_N, OP_K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16), + ], (M, N, K) in [ (128, 128, 128), (256, 256, 256), (2048, 2048, 2048)] - push!(rv, @get_wmma_complex_config) + try + push!(rv, @get_wmma_complex_config) + catch err + isa(err, GemmKernels.ConfigError) || rethrow() + end end # WMMA Dual GEMM for (AB_type, CD_type) in [(Float16, Float32)], transpose_a = [false], transpose_b = [false], - (OP_M, OP_N, OP_K) in [(16, 16, 16)], + (OP_M, OP_N, OP_K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16), + ], (M, N, K) in [ (128, 128, 128), (256, 256, 256), (2048, 2048, 2048)] - push!(rv, @get_wmma_dual_config) + try + push!(rv, @get_wmma_dual_config) + catch err + isa(err, GemmKernels.ConfigError) || rethrow() + end end rv diff --git a/src/config.jl b/src/config.jl index 4fd99afc..85c1b690 100644 --- a/src/config.jl +++ b/src/config.jl @@ -152,7 +152,11 @@ end function check_wmma_shape(operator::Type) op_shape = Operator.shape(operator) - if op_shape ∉ [(M=16, N=16, K=16)] + if op_shape ∉ [ + (M=16, N=16, K=16), + (M=8, N=32, K=16), + (M=32, N=8, K=16), + ] throw(ConfigError("Unsupported WMMA Operator shape $(op_shape)!")) end end @@ -257,6 +261,7 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw check_tile_multiple(num, den, dims, msg) = all([num[dim] % den[dim] == 0 for dim in dims]) || throw(ConfigError(msg)) check_tile_multiple(block_shape, compute_warp, [:M, :N, :K], "block_shape must be a multiple of compute_warp!") + check_tile_multiple(compute_warp, op_shape, [:M, :N, :K], "compute_warp must be a multiple of op_shape!") require_tile_sized_global(global_a_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :K], "gemm_shape.MK must be a multiple of block_shape.MK!") require_tile_sized_global(global_b_layout) && check_tile_multiple(gemm_shape, block_shape, [:K, :N], "gemm_shape.KN must be a multiple of block_shape.KN!") require_tile_sized_global(global_c_layout) && check_tile_multiple(gemm_shape, block_shape, [:M, :N], "gemm_shape.MN must be a multiple of block_shape.MN!") diff --git a/src/operator.jl b/src/operator.jl index 861999fd..295dbe3d 100644 --- a/src/operator.jl +++ b/src/operator.jl @@ -156,6 +156,26 @@ struct WMMAOp{M, N, K, CT, AT} end @inline shape(::Type{WMMAOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K) +for (M, N, K) in [ + (16, 16, 16), + (8, 32, 16), + (32, 8, 16) + ], + (layout_type, wmma_layout_type) in [ + (Layout.ColMajor, WMMA.ColMajor), + (Layout.UnsafeAlignedColMajor, WMMA.ColMajor), + (Layout.RowMajor, WMMA.RowMajor), + (Layout.UnsafeAlignedRowMajor, WMMA.RowMajor), + ] + @eval begin + # TODO: Have accessors in CUDA.jl to get the fragment sizes? + # FP16 (16, 16, 16), (8, 32, 16), and (32, 8, 16) + @inline fragtype_a(::Type{WMMAOp{$M, $N, $K, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{$M, $N, $K, 16, CT, $wmma_layout_type, WMMA.MatrixA} + @inline fragtype_b(::Type{WMMAOp{$M, $N, $K, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{$M, $N, $K, 16, CT, $wmma_layout_type, WMMA.MatrixB} + @inline fragtype_accum(::Type{WMMAOp{$M, $N, $K, CT, AT}}, ::Type{$layout_type{AT}}) where {CT, AT} = WMMA.Fragment{$M, $N, $K, 8, AT, WMMA.Unspecified, WMMA.Accumulator} + end +end + # convert_index_func: function used to transpose the index in case of a row-major layout for (layout_type, wmma_layout_type, convert_index_func) in [ (Layout.ColMajor, WMMA.ColMajor, identity), @@ -164,10 +184,6 @@ for (layout_type, wmma_layout_type, convert_index_func) in [ (Layout.UnsafeAlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x))), ] @eval begin - @inline fragtype_a(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 16, CT, $wmma_layout_type, WMMA.MatrixA} - @inline fragtype_b(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{CT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 16, CT, $wmma_layout_type, WMMA.MatrixB} - @inline fragtype_accum(::Type{WMMAOp{16, 16, 16, CT, AT}}, ::Type{$layout_type{AT}}) where {CT, AT} = WMMA.Fragment{16, 16, 16, 8, AT, WMMA.Unspecified, WMMA.Accumulator} - @inline function load_a(::Type{WMMAOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT} conf = WMMA.Config{M, N, K, AT} @@ -219,46 +235,46 @@ end # WMMAComplex # ----------- -struct WMMAComplexOp{M, N, K} end +struct WMMAComplexOp{M, N, K, CT, AT} end -@inline shape(::Type{WMMAComplexOp{M, N, K}}) where {M, N, K} = (M = M, N = N, K = K) +@inline shape(::Type{WMMAComplexOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K) # convert_index_func: function used to transpose the index in case of a row-major layout -for (layout_type, wmma_layout_type, convert_index_func) in [ - (Layout.SplitColMajor, WMMA.ColMajor, identity), - (Layout.SplitRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x))), +for (layout_type, base_layout, wmma_layout_type, convert_index_func) in [ + (Layout.SplitColMajor, Layout.UnsafeAlignedColMajor, WMMA.ColMajor, identity), + (Layout.SplitRowMajor, Layout.UnsafeAlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x))), ] @eval begin - @inline fragtype_a(::Type{WMMAComplexOp{16, 16, 16}}, ::Type{$layout_type{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixA}} - @inline fragtype_b(::Type{WMMAComplexOp{16, 16, 16}}, ::Type{$layout_type{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixB}} - @inline fragtype_accum(::Type{WMMAComplexOp{16, 16, 16}}, ::Type{$layout_type{Float32}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}} + @inline fragtype_a(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_a(WMMAOp{M, N, K, CT, AT}, $base_layout{CT})} + @inline fragtype_b(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_b(WMMAOp{M, N, K, CT, AT}, $base_layout{CT})} + @inline fragtype_accum(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_accum(WMMAOp{M, N, K, CT, AT}, $base_layout{AT})} - @inline function load_a(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} + @inline function load_a(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2])) return (WMMA.load_a(pointer(workspace, ind), size(workspace)[1], $wmma_layout_type, conf), WMMA.load_a(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], $wmma_layout_type, conf)) end - @inline function load_b(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} + @inline function load_b(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2])) return (WMMA.load_b(pointer(workspace, ind), size(workspace)[1], $wmma_layout_type, conf), WMMA.load_b(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], $wmma_layout_type, conf)) end - @inline function load_c(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float32}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{M, N, K, Float32} + @inline function load_c(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2])) return (WMMA.load_c(pointer(workspace, ind), size(workspace)[1], $wmma_layout_type, conf), WMMA.load_c(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], $wmma_layout_type, conf)) end - @inline function store_d(::Type{WMMAComplexOp{M, N, K}}, ::Type{$layout_type{Float32}}, workspace, frag, tile::Tile) where {M, N, K} - conf = WMMA.Config{M, N, K, Float32} + @inline function store_d(::Type{WMMAComplexOp{M, N, K, CT, AT}}, ::Type{$layout_type{AT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise($convert_index_func(tile.index), (size(workspace)[1], size(workspace)[2])) WMMA.store_d(pointer(workspace, ind), frag[1], size(workspace)[1], $wmma_layout_type, conf) @@ -269,8 +285,8 @@ end using LLVM -@inline function mma(::Type{WMMAComplexOp{M, N, K}}, a_frag, b_frag, c_frag) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} +@inline function mma(::Type{WMMAComplexOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} c_re = c_frag[1] c_im = c_frag[2] @@ -288,48 +304,48 @@ end # WMMADual # -------- -struct WMMADualOp{M, N, K} end +struct WMMADualOp{M, N, K, CT, AT} end -@inline shape(::Type{WMMADualOp{M, N, K}}) where {M, N, K} = (M = M, N = N, K = K) +@inline shape(::Type{WMMADualOp{M, N, K, CT, AT}}) where {M, N, K, CT, AT} = (M = M, N = N, K = K) -@inline fragtype_a(::Type{WMMADualOp{16, 16, 16}}, ::Type{Layout.SplitColMajor{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, WMMA.ColMajor, WMMA.MatrixA}} -@inline fragtype_b(::Type{WMMADualOp{16, 16, 16}}, ::Type{Layout.SplitColMajor{Float16}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 16, Float16, WMMA.ColMajor, WMMA.MatrixB}} -@inline fragtype_accum(::Type{WMMADualOp{16, 16, 16}}, ::Type{Layout.SplitColMajor{Float32}}) = NTuple{2, WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}} +@inline fragtype_a(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_a(WMMAOp{M, N, K, CT, AT}, Layout.UnsafeAlignedColMajor{CT})} +@inline fragtype_b(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_b(WMMAOp{M, N, K, CT, AT}, Layout.UnsafeAlignedColMajor{CT})} +@inline fragtype_accum(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{AT}}) where {M, N, K, CT, AT} = NTuple{2, fragtype_accum(WMMAOp{M, N, K, CT, AT}, Layout.UnsafeAlignedColMajor{AT})} -@inline function load_a(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float16}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} +@inline function load_a(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise(tile.index, (size(workspace)[1], size(workspace)[2])) return (WMMA.load_a(pointer(workspace, ind), size(workspace)[1], WMMA.ColMajor, conf), WMMA.load_a(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], WMMA.ColMajor, conf)) end -@inline function load_b(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float16}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} +@inline function load_b(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{CT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise(tile.index, (size(workspace)[1], size(workspace)[2])) return (WMMA.load_b(pointer(workspace, ind), size(workspace)[1], WMMA.ColMajor, conf), WMMA.load_b(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], WMMA.ColMajor, conf)) end -@inline function load_c(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float32}}, workspace, tile::Tile) where {M, N, K} - conf = WMMA.Config{M, N, K, Float32} +@inline function load_c(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{AT}}, workspace, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise(tile.index, (size(workspace)[1], size(workspace)[2])) return (WMMA.load_c(pointer(workspace, ind), size(workspace)[1], WMMA.ColMajor, conf), WMMA.load_c(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), size(workspace)[1], WMMA.ColMajor, conf)) end -@inline function store_d(::Type{WMMADualOp{M, N, K}}, ::Type{Layout.SplitColMajor{Float32}}, workspace, frag, tile::Tile) where {M, N, K} - conf = WMMA.Config{M, N, K, Float32} +@inline function store_d(::Type{WMMADualOp{M, N, K, CT, AT}}, ::Type{Layout.SplitColMajor{AT}}, workspace, frag, tile::Tile) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} ind = linearise(tile.index, (size(workspace)[1], size(workspace)[2])) WMMA.store_d(pointer(workspace, ind), frag[1], size(workspace)[1], WMMA.ColMajor, conf) WMMA.store_d(pointer(workspace, ind + size(workspace)[1] * size(workspace)[2]), frag[2], size(workspace)[1], WMMA.ColMajor, conf) end -@inline function mma(::Type{WMMADualOp{M, N, K}}, a_frag, b_frag, c_frag) where {M, N, K} - conf = WMMA.Config{16, 16, 16, Float32} +@inline function mma(::Type{WMMADualOp{M, N, K, CT, AT}}, a_frag, b_frag, c_frag) where {M, N, K, CT, AT} + conf = WMMA.Config{M, N, K, AT} c_re = c_frag[1] c_du = c_frag[2]