From 716162a5f221d7b4c1437b34a1dcc1d9f9117cdd Mon Sep 17 00:00:00 2001
From: Tim Besard
Date: Thu, 29 Jun 2023 11:06:06 +0200
Subject: [PATCH] FPUOp: Decouple shared memory type from register types.

---
 src/blas.jl     | 25 ++++++++--------
 src/operator.jl | 79 ++++++++++++++++++++++++++-----------------------
 2 files changed, 55 insertions(+), 49 deletions(-)

diff --git a/src/blas.jl b/src/blas.jl
index d30ab307..ae6dd1d2 100644
--- a/src/blas.jl
+++ b/src/blas.jl
@@ -43,23 +43,25 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A::CuMatrix, B::CuMa
     end
     global_d_layout = Layout.AlignedColMajor{eltype(C)}
 
-    # determine shared memory layouts (padded to avoid bank conflicts)
+    # determine shared memory layouts
     compute_type = promote_type(eltype(A), eltype(B))
-    ## padded to avoid bank conflicts, and converting values to the compute type domain
-    shared_a_layout = Layout.Padded{a_layout_base{compute_type}, 8}
-    shared_b_layout = Layout.Padded{b_layout_base{compute_type}, 8}
-    ## outputs are never transposed
-    ## XXX: why not padded?
-    storage_type = eltype(C)
-    shared_c_layout = shared_d_layout = Layout.AlignedColMajor{storage_type}
+    ## padded to avoid bank conflicts
+    shared_a_layout = Layout.Padded{a_layout_base{eltype(A)}, 8}
+    shared_b_layout = Layout.Padded{b_layout_base{eltype(B)}, 8}
+    ## outputs are never transposed, and padding them doesn't seem worth it
+    shared_c_layout = shared_d_layout = Layout.AlignedColMajor{eltype(C)}
 
     wmma_types = [
         (Float16, Float16, Float16),
         (Float16, Float16, Float32),
         # TODO: more, and device-capability dependent
     ]
-    conf = if something(wmma, (eltype(A), eltype(B), eltype(C)) in wmma_types)
-        println("WMMA")
+    conf = if something(wmma, (compute_type, compute_type, eltype(C)) in wmma_types)
+        # in the case of WMMA, the shared memory needs to have the correct type already,
+        # as we'll use WMMA intrinsics to load from it.
+        shared_a_layout = Layout.Padded{a_layout_base{compute_type}, 8}
+        shared_b_layout = Layout.Padded{b_layout_base{compute_type}, 8}
+
         GemmKernels.get_config(;
             gemm_shape = (M = m, N = n, K = k),
             operator = Operator.WMMAOp{16, 16, 16, eltype(C)},
@@ -71,11 +73,10 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A::CuMatrix, B::CuMa
             is_b_col_major = !transpose_b
         )
     else
-        println("FPU")
         GemmKernels.get_config(;
             gemm_shape = (M = m, N = n, K = k),
             block_shape = (M = 128, N = 128, K = 32),
-            operator = Operator.FPUOp{8, 8, 1, storage_type, compute_type},
+            operator = Operator.FPUOp{8, 8, 1, compute_type},
 
             global_a_layout, global_b_layout, global_c_layout, global_d_layout,
             shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
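Illustration (not part of the patch): a minimal standalone sketch of the operator
selection that gemmEx! performs above. select_operator is a hypothetical helper;
the wmma_types table is copied from the hunk, and promote_type is plain Base Julia.

    wmma_types = [
        (Float16, Float16, Float16),
        (Float16, Float16, Float32),
    ]

    function select_operator(TA, TB, TC)
        # scalar type the FPU operator computes in, derived from the inputs
        compute_type = promote_type(TA, TB)
        if (compute_type, compute_type, TC) in wmma_types
            # WMMA path: shared memory must already hold compute_type values
            (:WMMAOp, (16, 16, 16, TC))
        else
            # FPU path: shared memory keeps eltype(A)/eltype(B); values are
            # converted to compute_type when fragments are loaded into registers
            (:FPUOp, (8, 8, 1, compute_type))
        end
    end

    select_operator(Float16, Float16, Float32)  # (:WMMAOp, (16, 16, 16, Float32))
    select_operator(Float32, Float32, Float32)  # (:FPUOp, (8, 8, 1, Float32))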
diff --git a/src/operator.jl b/src/operator.jl
index 0c6785b9..e1523f04 100644
--- a/src/operator.jl
+++ b/src/operator.jl
@@ -21,23 +21,25 @@ end
 # FPU
 # ---
 
-abstract type GeneralFPUOp{M, N, K, DT, CT} end
+# CT is the compute type in which scalar operations are performed.
+# only a single type is used, as GPUs don't expose native mixed-mode arithmetic.
+abstract type GeneralFPUOp{M, N, K, CT} end
 
-@inline shape(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}) where {M, N, K, DT, CT} = (M = M, N = N, K = K)
+@inline shape(::Type{<:GeneralFPUOp{M, N, K, CT}}) where {M, N, K, CT} = (M = M, N = N, K = K)
 
 for (layout_type, convert_index_func) in [
         (Layout.AlignedColMajor, identity),
         (Layout.AlignedRowMajor, x -> reverse(Tuple(x)))
     ]
     @eval begin
-        @inline fragtype_a(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}) where {M, N, K, DT, CT} = NTuple{M * K ÷ 4, CT}
-        @inline fragtype_b(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}) where {M, N, K, DT, CT} = NTuple{K * N ÷ 8, CT}
+        @inline fragtype_a(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, DT} = NTuple{M * K ÷ 4, CT}
+        @inline fragtype_b(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, DT} = NTuple{K * N ÷ 8, CT}
 
-        @inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}) where {M, N, K, DT, CT}
-            return NTuple{M * N ÷ 32, DT}
+        @inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, DT}
+            return NTuple{M * N ÷ 32, CT}
         end
 
-        @inline function load_a(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_a(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1
 
             op_y = (laneId - 1) % 4 + 1
@@ -54,7 +56,7 @@ for (layout_type, convert_index_func) in [
             return NTuple{M * K ÷ 4, CT}(frag)
         end
 
-        @inline function load_b(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_b(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1
 
             op_x = (laneId - 1) ÷ 4 + 1
@@ -71,7 +73,7 @@ for (layout_type, convert_index_func) in [
             return NTuple{K * N ÷ 8, CT}(frag)
        end
 
-        @inline function load_c(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_c(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1
 
             op_y = (laneId - 1) % 4 + 1
@@ -79,17 +81,17 @@ for (layout_type, convert_index_func) in [
             op_x = (laneId - 1) ÷ 4 + 1
 
             y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)
-            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(undef)
+            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, CT}(undef)
             @loopinfo unroll for m = 1 : M ÷ 4
                 @loopinfo unroll for n = 1 : N ÷ 8
                     @inbounds frag = setindex(frag, workspace[y + 4 * (m - 1), x + 8 * (n - 1)], m, n)
                 end
             end
 
-            return NTuple{M * N ÷ 32, DT}(frag)
+            return NTuple{M * N ÷ 32, CT}(frag)
         end
 
-        @inline function store_d(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, DT, CT}
+        @inline function store_d(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, CT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1
 
             op_y = (laneId - 1) % 4 + 1
@@ -97,7 +99,7 @@ for (layout_type, convert_index_func) in [
             op_x = (laneId - 1) ÷ 4 + 1
 
             y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)
-            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(frag)
+            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, CT}(frag)
             @loopinfo unroll for m = 1 : M ÷ 4
                 @loopinfo unroll for n = 1 : N ÷ 8
                     @inbounds workspace[y + 4 * (m - 1), x + 8 * (n - 1)] = frag[m, n]
@@ -107,20 +109,20 @@ for (layout_type, convert_index_func) in [
     end
 end
 
-abstract type FPUOp{M, N, K, DT, CT} <: GeneralFPUOp{M, N, K, DT, CT} end
-function operator_fma(::Type{FPUOp{M, N, K, DT, CT}}, a::CT, b::CT, c::DT) where {M, N, K, DT, CT}
+abstract type FPUOp{M, N, K, CT} <: GeneralFPUOp{M, N, K, CT} end
+function operator_fma(::Type{FPUOp{M, N, K, CT}}, a::CT, b::CT, c::CT) where {M, N, K, CT}
     return fma(a, b, c)
 end
 
-abstract type TropicalFPUOp{M, N, K, DT, CT} <: GeneralFPUOp{M, N, K, DT, CT} end
-function operator_fma(::Type{TropicalFPUOp{M, N, K, DT, CT}}, a::CT, b::CT, c::DT) where {M, N, K, DT, CT}
+abstract type TropicalFPUOp{M, N, K, CT} <: GeneralFPUOp{M, N, K, CT} end
+function operator_fma(::Type{TropicalFPUOp{M, N, K, CT}}, a::CT, b::CT, c::CT) where {M, N, K, CT}
     return max(a + b, c)
 end
 
-@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, a_frag, b_frag, c_frag) where {M, N, K, DT, CT}
+@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, CT}}, a_frag, b_frag, c_frag) where {M, N, K, CT}
     a_frag = LocalArray{Tuple{M ÷ 4, K}, CT}(a_frag)
     b_frag = LocalArray{Tuple{K, N ÷ 8}, CT}(b_frag)
-    c_frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(c_frag)
+    c_frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, CT}(c_frag)
 
     @loopinfo unroll for m = 1 : M ÷ 4
         @loopinfo unroll for n = 1 : N ÷ 8
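Illustration (not part of the patch): with the DT parameter gone, both operator_fma
methods above work in a single scalar type CT; they differ only in the semiring they
implement. A standalone sketch with hypothetical names:

    plus_times_fma(a, b, c) = fma(a, b, c)    # FPUOp: ordinary GEMM, a * b + c
    max_plus_fma(a, b, c)   = max(a + b, c)   # TropicalFPUOp: tropical semiring

    plus_times_fma(2.0f0, 3.0f0, 1.0f0)  # 7.0f0; all operands share one Float32 type
    max_plus_fma(2.0f0, 3.0f0, 1.0f0)    # 5.0f0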
@@ -134,16 +136,19 @@ end
         end
     end
 
-    return NTuple{M * N ÷ 32, DT}(c_frag)
+    return NTuple{M * N ÷ 32, CT}(c_frag)
 end
 
 # ----
 # WMMA
 # ----
 
-struct WMMAOp{M, N, K, T} end
+# AT is the element type of the accumulator. the compute type is the same as the data type
+# derived from the shared memory layout. this is because we use WMMA intrinsics to
+# load/store the shared memory, and thus cannot perform a conversion like the FPU operator.
+struct WMMAOp{M, N, K, AT} end
 
-@inline shape(::Type{WMMAOp{M, N, K, T}}) where {M, N, K, T} = (M = M, N = N, K = K)
+@inline shape(::Type{WMMAOp{M, N, K, AT}}) where {M, N, K, AT} = (M = M, N = N, K = K)
 
 # convert_index_func: function used to transpose the index in case of a row-major layout
 for (layout_type, wmma_layout_type, convert_index_func) in [
@@ -151,12 +156,12 @@ for (layout_type, wmma_layout_type, convert_index_func) in [
         (Layout.AlignedColMajor, WMMA.ColMajor, identity),
         (Layout.AlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x)))
     ]
     @eval begin
-        @inline fragtype_a(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixA}
-        @inline fragtype_b(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixB}
-        @inline fragtype_accum(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{T}}) where {T} = WMMA.Fragment{16, 16, 16, 8, T, WMMA.Unspecified, WMMA.Accumulator}
+        @inline fragtype_a(::Type{WMMAOp{16, 16, 16, AT}}, ::Type{$layout_type{DT}}) where {AT, DT} = WMMA.Fragment{16, 16, 16, 16, DT, $wmma_layout_type, WMMA.MatrixA}
+        @inline fragtype_b(::Type{WMMAOp{16, 16, 16, AT}}, ::Type{$layout_type{DT}}) where {AT, DT} = WMMA.Fragment{16, 16, 16, 16, DT, $wmma_layout_type, WMMA.MatrixB}
+        @inline fragtype_accum(::Type{WMMAOp{16, 16, 16, AT}}, ::Type{$layout_type{AT}}) where {AT} = WMMA.Fragment{16, 16, 16, 8, AT, WMMA.Unspecified, WMMA.Accumulator}
 
-        @inline function load_a(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_a(::Type{WMMAOp{M, N, K, AT}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, AT}
+            conf = WMMA.Config{M, N, K, AT}
 
             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))
@@ -165,8 +170,8 @@ for (layout_type, wmma_layout_type, convert_index_func) in [
             return WMMA.load_a(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end
 
-        @inline function load_b(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_b(::Type{WMMAOp{M, N, K, AT}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, AT}
+            conf = WMMA.Config{M, N, K, AT}
 
             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))
@@ -175,30 +180,30 @@ for (layout_type, wmma_layout_type, convert_index_func) in [
             return WMMA.load_b(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end
 
-        @inline function load_c(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_c(::Type{WMMAOp{M, N, K, AT}}, ::Type{$layout_type{AT}}, workspace, tile::Tile) where {M, N, K, AT}
+            conf = WMMA.Config{M, N, K, AT}
 
             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))
 
-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(AT)
             return WMMA.load_c(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end
 
-        @inline function store_d(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, frag, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function store_d(::Type{WMMAOp{M, N, K, AT}}, ::Type{$layout_type{AT}}, workspace, frag, tile::Tile) where {M, N, K, AT}
+            conf = WMMA.Config{M, N, K, AT}
 
             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))
 
-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(AT)
             WMMA.store_d(ptr, frag, size(workspace, 1), $wmma_layout_type, conf)
         end
     end
 end
 
-function mma(::Type{WMMAOp{M, N, K, T}}, a_frag, b_frag, c_frag) where {M, N, K, T}
-    conf = WMMA.Config{M, N, K, T}
+function mma(::Type{WMMAOp{M, N, K, AT}}, a_frag, b_frag, c_frag) where {M, N, K, AT}
+    conf = WMMA.Config{M, N, K, AT}
 
     return WMMA.mma(a_frag, b_frag, c_frag, conf)
 end
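Illustration (not part of the patch): the decoupling lets the FPU path read shared
memory in one element type (DT, from the layout) and keep register fragments in
another (CT). A standalone sketch of that per-element conversion, with a plain array
standing in for the shared-memory workspace and an NTuple for the register fragment:

    DT, CT = Float16, Float32                 # shared-memory type vs. compute type
    workspace = rand(DT, 8, 4)                # stand-in for a DT-typed shared tile
    frag = ntuple(i -> CT(workspace[i]), 8)   # loading converts DT -> CT per element
    typeof(frag)                              # NTuple{8, Float32}

The WMMA path cannot do this conversion on the fly, which is why the blas.jl hunk
switches the shared memory layouts to the compute type before selecting WMMAOp.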