Commit 716162a

FPUOp: Decouple shared memory type from register types.

maleadt committed Jun 29, 2023
1 parent 33381f2 · commit 716162a

Showing 2 changed files with 55 additions and 49 deletions.
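In short: before this commit, the FPU operator carried both a data type DT and a compute type CT, and the shared-memory staging buffers for A and B had to hold CT values already. Afterwards, the operator only names CT; the element type of shared memory comes from the shared layouts, and values are converted to CT as they are loaded into registers. A minimal sketch of the two configurations (the Float16-storage/Float32-compute pairing is illustrative, not taken from the commit):

    # before: the shared memory type is coupled to the operator's type parameters
    shared_a_layout = Layout.Padded{Layout.AlignedColMajor{Float32}, 8}  # must match CT
    operator = Operator.FPUOp{8, 8, 1, Float32, Float32}                 # DT, CT

    # after: shared memory keeps the (possibly narrower) storage type; the
    # operator only names the compute type used in registers
    shared_a_layout = Layout.Padded{Layout.AlignedColMajor{Float16}, 8}  # storage type
    operator = Operator.FPUOp{8, 8, 1, Float32}                          # CT only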
25 changes: 13 additions & 12 deletions src/blas.jl
@@ -43,23 +43,25 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A::CuMatrix, B::CuMa
     end
     global_d_layout = Layout.AlignedColMajor{eltype(C)}

-    # determine shared memory layouts (padded to avoid bank conflicts)
+    # determine shared memory layouts
     compute_type = promote_type(eltype(A), eltype(B))
-    ## padded to avoid bank conflicts, and converting values to the compute type domain
-    shared_a_layout = Layout.Padded{a_layout_base{compute_type}, 8}
-    shared_b_layout = Layout.Padded{b_layout_base{compute_type}, 8}
-    ## outputs are never transposed
-    ## XXX: why not padded?
-    storage_type = eltype(C)
-    shared_c_layout = shared_d_layout = Layout.AlignedColMajor{storage_type}
+    ## padded to avoid bank conflicts
+    shared_a_layout = Layout.Padded{a_layout_base{eltype(A)}, 8}
+    shared_b_layout = Layout.Padded{b_layout_base{eltype(B)}, 8}
+    ## outputs are never transposed, and padding them doesn't seem worth it
+    shared_c_layout = shared_d_layout = Layout.AlignedColMajor{eltype(C)}

     wmma_types = [
         (Float16, Float16, Float16),
         (Float16, Float16, Float32),
         # TODO: more, and device-capability dependent
     ]
-    conf = if something(wmma, (eltype(A), eltype(B), eltype(C)) in wmma_types)
-        println("WMMA")
+    conf = if something(wmma, (compute_type, compute_type, eltype(C)) in wmma_types)
+        # in the case of WMMA, the shared memory needs to have the correct type already,
+        # as we'll use WMMA intrinsics to load from it.
+        shared_a_layout = Layout.Padded{a_layout_base{compute_type}, 8}
+        shared_b_layout = Layout.Padded{b_layout_base{compute_type}, 8}
+
         GemmKernels.get_config(;
             gemm_shape = (M = m, N = n, K = k),
             operator = Operator.WMMAOp{16, 16, 16, eltype(C)},
@@ -71,11 +73,10 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A::CuMatrix, B::CuMa
             is_b_col_major = !transpose_b
         )
     else
-        println("FPU")
         GemmKernels.get_config(;
             gemm_shape = (M = m, N = n, K = k),
             block_shape = (M = 128, N = 128, K = 32),
-            operator = Operator.FPUOp{8, 8, 1, storage_type, compute_type},
+            operator = Operator.FPUOp{8, 8, 1, compute_type},

             global_a_layout, global_b_layout, global_c_layout, global_d_layout,
             shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout,
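Note the condition change: the WMMA check now tests the promoted compute type rather than the raw input element types. A self-contained sketch of the selection logic (plain Julia; `wmma = nothing` stands for "let the library decide"):

    wmma = nothing                                   # user override; `nothing` means auto
    compute_type = promote_type(Float16, Float16)    # promote_type(eltype(A), eltype(B))
    wmma_types = [(Float16, Float16, Float16),
                  (Float16, Float16, Float32)]

    # `something(wmma, x)` returns `wmma` unless it is `nothing`, in which case `x`
    use_wmma = something(wmma, (compute_type, compute_type, Float32) in wmma_types)
    @assert use_wmma   # Float16 inputs with a Float32 accumulator take the WMMA path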
79 changes: 42 additions & 37 deletions src/operator.jl

@@ -21,23 +21,25 @@ end
 # FPU
 # ---

-abstract type GeneralFPUOp{M, N, K, DT, CT} end
+# CT is the compute type used to perform scalar operations in.
+# only a single type is used, as GPUs don't expose native mixed-mode arithmetic.
+abstract type GeneralFPUOp{M, N, K, CT} end

-@inline shape(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}) where {M, N, K, DT, CT} = (M = M, N = N, K = K)
+@inline shape(::Type{<:GeneralFPUOp{M, N, K, CT}}) where {M, N, K, CT} = (M = M, N = N, K = K)

 for (layout_type, convert_index_func) in [
     (Layout.AlignedColMajor, identity),
     (Layout.AlignedRowMajor, x -> reverse(Tuple(x)))
 ]
     @eval begin
-        @inline fragtype_a(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}) where {M, N, K, DT, CT} = NTuple{M * K ÷ 4, CT}
-        @inline fragtype_b(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}) where {M, N, K, DT, CT} = NTuple{K * N ÷ 8, CT}
+        @inline fragtype_a(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, DT} = NTuple{M * K ÷ 4, CT}
+        @inline fragtype_b(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, DT} = NTuple{K * N ÷ 8, CT}

-        @inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}) where {M, N, K, DT, CT}
-            return NTuple{M * N ÷ 32, DT}
+        @inline function fragtype_accum(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}) where {M, N, K, CT, DT}
+            return NTuple{M * N ÷ 32, CT}
         end

-        @inline function load_a(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_a(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1

             op_y = (laneId - 1) % 4 + 1
@@ -54,7 +56,7 @@ for (layout_type, convert_index_func) in [
             return NTuple{M * K ÷ 4, CT}(frag)
         end

-        @inline function load_b(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{CT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_b(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1

             op_x = (laneId - 1) ÷ 4 + 1
@@ -71,33 +73,33 @@ for (layout_type, convert_index_func) in [
             return NTuple{K * N ÷ 8, CT}(frag)
         end

-        @inline function load_c(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, DT, CT}
+        @inline function load_c(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}, workspace, tile::Tile) where {M, N, K, CT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1

             op_y = (laneId - 1) % 4 + 1
             op_x = (laneId - 1) ÷ 4 + 1

             y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)

-            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(undef)
+            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, CT}(undef)
             @loopinfo unroll for m = 1 : M ÷ 4
                 @loopinfo unroll for n = 1 : N ÷ 8
                     @inbounds frag = setindex(frag, workspace[y + 4 * (m - 1), x + 8 * (n - 1)], m, n)
                 end
             end

-            return NTuple{M * N ÷ 32, DT}(frag)
+            return NTuple{M * N ÷ 32, CT}(frag)
         end

-        @inline function store_d(::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, DT, CT}
+        @inline function store_d(::Type{<:GeneralFPUOp{M, N, K, CT}}, ::Type{$layout_type{DT}}, workspace, frag, tile::Tile) where {M, N, K, CT, DT}
             laneId = (threadIdx().x - 1) % 32 + 1

             op_y = (laneId - 1) % 4 + 1
             op_x = (laneId - 1) ÷ 4 + 1

             y, x = (tile.base.M + tile.offset.M + op_y, tile.base.N + tile.offset.N + op_x)

-            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(frag)
+            frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, CT}(frag)
             @loopinfo unroll for m = 1 : M ÷ 4
                 @loopinfo unroll for n = 1 : N ÷ 8
                     @inbounds workspace[y + 4 * (m - 1), x + 8 * (n - 1)] = frag[m, n]
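The indexing in these load/store functions is the same throughout: the 32 lanes of a warp are arranged as a 4 x 8 grid, and each lane covers an (M ÷ 4) x (N ÷ 8) slice of the operator tile with strides of 4 and 8. A worked sketch of that mapping, checkable in plain Julia:

    for laneId in 1:32
        op_y = (laneId - 1) % 4 + 1   # row within the warp, 1..4, cycles fastest
        op_x = (laneId - 1) ÷ 4 + 1   # column within the warp, 1..8
        @assert 1 <= op_y <= 4 && 1 <= op_x <= 8
    end
    # for M = N = 8, as in FPUOp{8, 8, 1, CT}, each lane therefore holds
    # (M ÷ 4) * (N ÷ 8) = 2 * 1 = 2 accumulator values: rows op_y and op_y + 4
    # of column op_x, matching the NTuple{M * N ÷ 32, CT} fragment type.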
@@ -107,20 +109,20 @@ for (layout_type, convert_index_func) in [
     end
 end

-abstract type FPUOp{M, N, K, DT, CT} <: GeneralFPUOp{M, N, K, DT, CT} end
-function operator_fma(::Type{FPUOp{M, N, K, DT, CT}}, a::CT, b::CT, c::DT) where {M, N, K, DT, CT}
+abstract type FPUOp{M, N, K, CT} <: GeneralFPUOp{M, N, K, CT} end
+function operator_fma(::Type{FPUOp{M, N, K, CT}}, a::CT, b::CT, c::CT) where {M, N, K, CT}
     return fma(a, b, c)
 end

-abstract type TropicalFPUOp{M, N, K, DT, CT} <: GeneralFPUOp{M, N, K, DT, CT} end
-function operator_fma(::Type{TropicalFPUOp{M, N, K, DT, CT}}, a::CT, b::CT, c::DT) where {M, N, K, DT, CT}
+abstract type TropicalFPUOp{M, N, K, CT} <: GeneralFPUOp{M, N, K, CT} end
+function operator_fma(::Type{TropicalFPUOp{M, N, K, CT}}, a::CT, b::CT, c::CT) where {M, N, K, CT}
     return max(a + b, c)
 end

-@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, DT, CT}}, a_frag, b_frag, c_frag) where {M, N, K, DT, CT}
+@inline function mma(operator_type::Type{<:GeneralFPUOp{M, N, K, CT}}, a_frag, b_frag, c_frag) where {M, N, K, CT}
     a_frag = LocalArray{Tuple{M ÷ 4, K}, CT}(a_frag)
     b_frag = LocalArray{Tuple{K, N ÷ 8}, CT}(b_frag)
-    c_frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, DT}(c_frag)
+    c_frag = LocalArray{Tuple{M ÷ 4, N ÷ 8}, CT}(c_frag)

     @loopinfo unroll for m = 1 : M ÷ 4
         @loopinfo unroll for n = 1 : N ÷ 8
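With DT gone, both scalar operators take all three operands in the compute type; what distinguishes them is only the semiring they implement. A tiny illustration of the two steps (values arbitrary):

    a, b, c = 2.0, 3.0, 10.0
    fma(a, b, c)      # 16.0, FPUOp:         c = a * b + c
    max(a + b, c)     # 10.0, TropicalFPUOp: c = max(a + b, c), the max-plus semiring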
@@ -134,29 +136,32 @@ end
         end
     end

-    return NTuple{M * N ÷ 32, DT}(c_frag)
+    return NTuple{M * N ÷ 32, CT}(c_frag)
 end

 # ----
 # WMMA
 # ----

-struct WMMAOp{M, N, K, T} end
+# AT is the element type of the accumulator. the compute type is the same as the data type
+# derived from the shared memory layout. this is because we use WMMA intrinsics to
+# load/store the shared memory, and thus cannot perform a conversion like the FPU operator.
+struct WMMAOp{M, N, K, AT} end

-@inline shape(::Type{WMMAOp{M, N, K, T}}) where {M, N, K, T} = (M = M, N = N, K = K)
+@inline shape(::Type{WMMAOp{M, N, K, AT}}) where {M, N, K, AT} = (M = M, N = N, K = K)

 # convert_index_func: function used to transpose the index in case of a row-major layout
 for (layout_type, wmma_layout_type, convert_index_func) in [
     (Layout.AlignedColMajor, WMMA.ColMajor, identity),
     (Layout.AlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x)))
 ]
     @eval begin
-        @inline fragtype_a(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixA}
-        @inline fragtype_b(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixB}
-        @inline fragtype_accum(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{T}}) where {T} = WMMA.Fragment{16, 16, 16, 8, T, WMMA.Unspecified, WMMA.Accumulator}
+        @inline fragtype_a(::Type{WMMAOp{16, 16, 16, AT}}, ::Type{$layout_type{DT}}) where {AT, DT} = WMMA.Fragment{16, 16, 16, 16, DT, $wmma_layout_type, WMMA.MatrixA}
+        @inline fragtype_b(::Type{WMMAOp{16, 16, 16, AT}}, ::Type{$layout_type{DT}}) where {AT, DT} = WMMA.Fragment{16, 16, 16, 16, DT, $wmma_layout_type, WMMA.MatrixB}
+        @inline fragtype_accum(::Type{WMMAOp{16, 16, 16, AT}}, ::Type{$layout_type{AT}}) where {AT} = WMMA.Fragment{16, 16, 16, 8, AT, WMMA.Unspecified, WMMA.Accumulator}

-        @inline function load_a(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_a(::Type{WMMAOp{M, N, K, AT}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, AT}
+            conf = WMMA.Config{M, N, K, AT}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))
@@ -165,8 +170,8 @@ for (layout_type, wmma_layout_type, convert_index_func) in [
             return WMMA.load_a(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end

-        @inline function load_b(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_b(::Type{WMMAOp{M, N, K, AT}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, AT}
+            conf = WMMA.Config{M, N, K, AT}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))
@@ -175,30 +180,30 @@ for (layout_type, wmma_layout_type, convert_index_func) in [
             return WMMA.load_b(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end

-        @inline function load_c(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function load_c(::Type{WMMAOp{M, N, K, AT}}, ::Type{$layout_type{AT}}, workspace, tile::Tile) where {M, N, K, AT}
+            conf = WMMA.Config{M, N, K, AT}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(AT)
             return WMMA.load_c(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end

-        @inline function store_d(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, frag, tile::Tile) where {M, N, K, T}
-            conf = WMMA.Config{M, N, K, T}
+        @inline function store_d(::Type{WMMAOp{M, N, K, AT}}, ::Type{$layout_type{AT}}, workspace, frag, tile::Tile) where {M, N, K, AT}
+            conf = WMMA.Config{M, N, K, AT}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(AT)
             WMMA.store_d(ptr, frag, size(workspace, 1), $wmma_layout_type, conf)
         end
     end
 end

-function mma(::Type{WMMAOp{M, N, K, T}}, a_frag, b_frag, c_frag) where {M, N, K, T}
-    conf = WMMA.Config{M, N, K, T}
+function mma(::Type{WMMAOp{M, N, K, AT}}, a_frag, b_frag, c_frag) where {M, N, K, AT}
+    conf = WMMA.Config{M, N, K, AT}
     return WMMA.mma(a_frag, b_frag, c_frag, conf)
 end
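The asymmetry called out in the WMMAOp comment deserves spelling out. The FPU operator copies shared memory element by element into a CT-typed LocalArray and rebuilds an NTuple in the compute type, so the DT-to-CT conversion happens for free during the load; WMMA fragments are filled by opaque intrinsics (WMMA.load_a and friends) that read shared memory directly, leaving no point at which to convert. A standalone sketch of the conversion idiom the FPU path relies on:

    frag = (Float16(1.5), Float16(2.5))     # values as stored in shared memory (DT)
    converted = NTuple{2, Float32}(frag)    # tuple construction converts each element to CT
    @assert converted === (1.5f0, 2.5f0)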