simplify a few things
smnbl committed Feb 18, 2022
Parent 102f7b1, commit 113357e
Showing 4 changed files with 53 additions and 84 deletions.
10 changes: 6 additions & 4 deletions src/config.jl
@@ -169,17 +169,19 @@ function get_config(; gemm_shape, operator, global_a_layout, global_c_layout, kw
     block_shape = get(params, :block_shape,
                       heuristic_block_shape(shared_a_layout, shared_b_layout, shared_c_layout, shared_d_layout))

+    # make sure block shape fits grid
+    block_shape = (M = min(block_shape.M, gemm_shape.M)
+                  , N = min(block_shape.N, gemm_shape.N)
+                  , K = min(block_shape.K, gemm_shape.K))
+
     # 8 warps in a 4 x 2 arrangement usually works well
-    # TODO: shouldn't this be determined based on compute_warp ??
+    # TODO: base this on the compute shape?
     warps_per_block = get(params, :warps_per_block, 8)
     op_shape = Operator.shape(operator)
-    # this is incorrect if someone only changes warps per block!
+    # TODO: calculate warps_per_block based on compute_warp
     compute_warp = get(params, :compute_warp,
                        (M = block_shape.M ÷ 4, N = block_shape.N ÷ 2, K = op_shape.K))

     # Is the layout col-major or not? This is needed to find good values for mem_a_warp, mem_b_warp, etc.
     # TODO: Let the layouts handle this?
     is_a_col_major = get(params, :is_a_col_major, true)
     is_b_col_major = get(params, :is_b_col_major, true)
     is_cd_col_major = get(params, :is_cd_col_major, true)
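For reference, a minimal sketch of what the clamp and the default 4 x 2 warp
split compute, written with plain NamedTuples rather than the package API (all
concrete numbers here are illustrative assumptions, not values from the commit):

    # A GEMM that is smaller than the heuristic block in M forces the clamp to act.
    gemm_shape  = (M = 128, N = 128, K = 2048)
    block_shape = (M = 256, N = 128, K = 32)

    # Clamp so a single block never exceeds the problem size.
    block_shape = (M = min(block_shape.M, gemm_shape.M),
                   N = min(block_shape.N, gemm_shape.N),
                   K = min(block_shape.K, gemm_shape.K))

    # 8 warps in a 4 x 2 arrangement: each warp covers an (M/4) x (N/2) subtile.
    op_shape = (M = 8, N = 4, K = 1)    # e.g. Operator.shape(SIMTOp)
    compute_warp = (M = block_shape.M ÷ 4, N = block_shape.N ÷ 2, K = op_shape.K)

    @assert compute_warp == (M = 32, N = 64, K = 1)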
1 change: 1 addition & 0 deletions src/layout.jl
@@ -68,6 +68,7 @@ end
 abstract type AlignedColMajor{T} <: LayoutBase{T} end

 @inline physical_size(::Type{<:Padded{AlignedColMajor{T}, P}}, logical_size::NamedTuple) where {T, P} = (logical_size[1] + P, logical_size[2])
+@inline physical_size(::Type{<:AlignedColMajor{T}}, logical_size::NamedTuple) where {T} = (logical_size[1], logical_size[2])

 @inline fragtype(::Type{<:AlignedColMajor{T}}, tile_size::NamedTuple) where {T} = NTuple{16 ÷ sizeof(T), VecElement{T}}

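The added method completes the pair above: Padded layouts reserve extra rows,
plain AlignedColMajor does not. A standalone sketch of the difference, with
stand-in values rather than the package's types (P and the sizes are assumed):

    logical = (M = 128, K = 32)   # logical tile size
    P = 8                         # padding, as in Padded{AlignedColMajor{T}, P}

    # Padded col-major adds P extra rows per column (commonly done to avoid
    # shared-memory bank conflicts); plain col-major stores the logical extent.
    physical_padded   = (logical[1] + P, logical[2])
    physical_unpadded = (logical[1], logical[2])

    @assert physical_padded   == (136, 32)
    @assert physical_unpadded == (128, 32)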
89 changes: 25 additions & 64 deletions src/operator.jl
@@ -17,11 +17,9 @@ end
 # ----
 # SIMT Op
 # ----
-struct SIMTOp{M, N, K} end
+struct SIMTOp end

-@inline shape(::Type{SIMTOp{M, N, K}}) where {M, N, K} = (M = M*8, N = N*4, K = K*1)
-
-@inline tuple_len(::Type{NTuple{N, T}}) where {N, T} = N
+@inline shape(::Type{SIMTOp}) = (M = 8, N = 4, K = 1)

 # convert_index_func: function used to transpose the index in case of a row-major layout
 # 2 types of optimizations:

@@ -34,88 +32,51 @@ for (layout_type, convert_index_func) in [

     @eval begin
         # fragtype: type of per thread input/result of mma operations
-        @inline fragtype_a(::Type{SIMTOp{M, N, K}}, ::Type{$layout_type{T}}) where {M, N, K, T} = NTuple{M, T}
+        @inline fragtype_a(::Type{SIMTOp}, ::Type{$layout_type{T}}) where {T} = T

-        @inline fragtype_b(::Type{SIMTOp{M, N, K}}, ::Type{$layout_type{T}}) where {M, N, K, T} = NTuple{N, T}
-        @inline fragtype_accum(::Type{SIMTOp{M, N, K}}, ::Type{$layout_type{T}}) where {M, N, K, T} = NTuple{M * N, T}
+        @inline fragtype_b(::Type{SIMTOp}, ::Type{$layout_type{T}}) where {T} = T
+        @inline fragtype_accum(::Type{SIMTOp}, ::Type{$layout_type{T}}) where {T} = T

         # define load / stores based on layout types
-        @inline function load_a(::Type{SIMTOp{M, N, K}}, ::Type{$layout_type{T}}, workspace, tile::Tile) where {M, N, K, T}
+        @inline function load_a(::Type{SIMTOp}, ::Type{$layout_type{T}}, workspace, tile::Tile) where {T}
             laneId = (threadIdx().x - 1) % 32 + 1
-            ret = ntuple(i -> zero(T), Val(M))
-
-            @unroll for tile_row = 1:M
-                row = ((laneId - 1) % 8)*M + tile_row - 1
-
-                y, x = $convert_index_func((tile.base.M + tile.offset.M + 1 + row, tile.base.K + tile.offset.K + 1))
-                @inbounds val = workspace[y, x]
-                ret = Base.setindex(ret, val, tile_row)
-            end
+            row = ((laneId - 1) % 8)

-            return ret
+            y, x = $convert_index_func((tile.base.M + tile.offset.M + 1 + row, tile.base.K + tile.offset.K + 1))
+            @inbounds return workspace[y, x]
         end

-        @inline function load_b(::Type{SIMTOp{M, N, K}}, ::Type{$layout_type{T}}, workspace, tile::Tile) where {M, N, K, T}
+        @inline function load_b(::Type{SIMTOp}, ::Type{$layout_type{T}}, workspace, tile::Tile) where {T}
             laneId = (threadIdx().x - 1) % 32 + 1
-            ret = ntuple(i -> zero(T), Val(N))
+            col = ((laneId - 1) ÷ 8)

-            @unroll for tile_col = 1:N
-                col = ((laneId - 1) ÷ 8)*N + tile_col - 1
-
-                y, x = $convert_index_func((tile.base.K + tile.offset.K + 1, tile.base.N + tile.offset.N + 1 + col))
-                @inbounds val = workspace[y, x]
-                ret = Base.setindex(ret, val, tile_col)
-            end
-
-            return ret
+            y, x = $convert_index_func((tile.base.K + tile.offset.K + 1, tile.base.N + tile.offset.N + 1 + col))
+            @inbounds return workspace[y, x]
         end

-        @inline function load_c(op::Type{SIMTOp{M, N, K}}, layout::Type{$layout_type{T}}, workspace, tile::Tile) where {M, N, K, T}
+        @inline function load_c(op::Type{SIMTOp}, layout::Type{$layout_type{T}}, workspace, tile::Tile) where {T}
             laneId = (threadIdx().x - 1) % 32 + 1
-            ret = Base.ntuple(i -> zero(T), Val(M * N))
-
-            @unroll for tile_row = 1:M
-                @unroll for tile_col = 1:N
-                    row = ((laneId - 1) % 8)*M + tile_row - 1
-                    col = ((laneId - 1) ÷ 8)*N + tile_col - 1
+            row = ((laneId - 1) % 8)
+            col = ((laneId - 1) ÷ 8)

-                    y, x = $convert_index_func((tile.base.M + tile.offset.M + 1 + row, tile.base.N + tile.offset.N + 1 + col))
-                    @inbounds val = workspace[y, x]
-                    ret = Base.setindex(ret, val, (tile_col - 1)*M + tile_row)
-                end
-            end
-
-            return ret
+            y, x = $convert_index_func((tile.base.M + tile.offset.M + 1 + row, tile.base.N + tile.offset.N + 1 + col))
+            @inbounds return workspace[y, x]
         end

-        @inline function store_d(::Type{SIMTOp{M, N, K}}, ::Type{$layout_type{T}}, workspace, frag, tile::Tile) where {M, N, K, T}
+        @inline function store_d(::Type{SIMTOp}, ::Type{$layout_type{T}}, workspace, frag, tile::Tile) where {T}
             laneId = (threadIdx().x - 1) % 32 + 1

-            @unroll for tile_row = 1:M
-                @unroll for tile_col = 1:N
-                    row = ((laneId - 1) % 8)*M + tile_row - 1
-                    col = ((laneId - 1) ÷ 8)*N + tile_col - 1
+            row = ((laneId - 1) % 8)
+            col = ((laneId - 1) ÷ 8)

-                    y, x = $convert_index_func((tile.base.M + tile.offset.M + 1 + row, tile.base.N + tile.offset.N + 1 + col))
-                    @inbounds workspace[y, x] = frag[(tile_col - 1)*M + tile_row]
-                end
-            end
+            y, x = $convert_index_func((tile.base.M + tile.offset.M + 1 + row, tile.base.N + tile.offset.N + 1 + col))
+            @inbounds workspace[y, x] = frag
         end
     end
 end

-@inline function mma(op::Type{SIMTOp{M, N, K}}, a_frag, b_frag, c_frag) where {M, N, K}
-    ret = Base.ntuple(i -> zero(eltype(c_frag)), Val(M * N))
-
-    @unroll for tile_row = 1:M
-        @unroll for tile_col = 1:N
-            @inbounds val = a_frag[tile_row]*b_frag[tile_col] + c_frag[(tile_col - 1)*M + tile_row]
-            ret = Base.setindex(ret, val, (tile_col - 1)*M + tile_row)
-        end
-    end
-
-    return ret
+@inline function mma(op::Type{SIMTOp}, a_frag, b_frag, c_frag)
+    @inbounds return a_frag * b_frag + c_frag
 end

 # ----
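The net effect of this simplification: the operator is fixed at one 8 x 4 x 1
tile per warp, each of the 32 lanes owns exactly one element of the tile, and
every fragment shrinks from an NTuple to a scalar. A standalone sketch of the
lane-to-element mapping and the per-lane mma (plain Julia, no CUDA needed):

    # Lane l (1-based) owns row (l - 1) % 8 and column (l - 1) ÷ 8 of the 8 x 4 tile.
    owners = [((laneId - 1) % 8, (laneId - 1) ÷ 8) for laneId in 1:32]
    @assert allunique(owners)      # each element is covered exactly once
    @assert length(owners) == 8 * 4

    # With scalar fragments, mma collapses to a single fused multiply-add.
    simt_mma(a_frag, b_frag, c_frag) = a_frag * b_frag + c_frag
    @assert simt_mma(2.0f0, 3.0f0, 1.0f0) == 7.0f0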
37 changes: 21 additions & 16 deletions test/matmul.jl
@@ -1,20 +1,27 @@
 using CUDA
 using CUDA: unsafe_free!
 using ForwardDiff
 using GemmKernels
 using LinearAlgebra

 ################################################################################

 @testset "Matmul API" begin
-    @test_if "simt" @testset "SIMT GEMM" begin
-        for dtype = [Int, Float32, Complex], transpose_a = [false, true], transpose_b = [false, true],
-            (M, N, K) in [(128, 128, 128), (256, 256, 128), (128, 128, 256), (256, 256, 256), (2048, 2048, 2048)]
-            @testset "$( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ); M = $M, N = $N, K = $K" begin
-                alpha = 2
-                beta = 3
-
-                a_h = rand(dtype, (M, K)) / sqrt(dtype(K))
-                b_h = rand(dtype, (K, N)) / sqrt(dtype(K))
+    @test_if "simt" @testset "SIMT GEMM $(dtype)x$(dtype)+$(dtype)=$(dtype) - $( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ); M = $M, N = $N, K = $K" for
+        dtype = [Int16, Int32, Int64, Float16, Float32, Float64, ComplexF16, ComplexF32],
+        transpose_a = [false, true], transpose_b = [false, true],
+        (M, N, K) in [(128, 128, 128), (256, 256, 128), (128, 128, 256), (256, 256, 256), (1024, 1024, 1024)]
+
+        if real(dtype) <: AbstractFloat
+            # floating point types & derivatives
+            a_h = rand(dtype, (M, K)) / sqrt(dtype(K))
+            b_h = rand(dtype, (K, N)) / sqrt(dtype(K))
+        else
+            # integer types & derivatives
+            a_h = floor.(dtype, rand(dtype, (M, K)) / sqrt(dtype(K)))
+            b_h = floor.(dtype, rand(dtype, (K, N)) / sqrt(dtype(K)))
+        end

         c_h = rand(dtype, (M, N))

         # Transpose input if necessary
@@ -28,30 +35,28 @@ using LinearAlgebra

         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
-            operator = Operator.SIMTOp{8, 4, 1},
+            operator = Operator.SIMTOp,
             global_a_layout = transpose_a ? Layout.AlignedRowMajor{eltype(a)} : Layout.AlignedColMajor{eltype(a)},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{eltype(b)} : Layout.AlignedColMajor{eltype(b)},

             global_c_layout = Layout.AlignedColMajor{eltype(c)},
             global_d_layout = Layout.AlignedColMajor{eltype(d)},

             is_a_col_major = !transpose_a,
-
             is_b_col_major = !transpose_b,
         )

         GemmKernels.matmul(a, b, c, d, conf;
-            transform_shared_to_regs_a = Transform.Elementwise(x -> x * alpha),
-            transform_shared_to_regs_c = Transform.Elementwise(x -> x * beta),
             kernel = Kernel.matmul_pipelined
         )

         # Transpose outputs, if necessary
         new_a_h = transpose_a ? transpose(a_h) : a_h
         new_b_h = transpose_b ? transpose(b_h) : b_h

-        @test all(isapprox.(alpha * new_a_h * new_b_h + beta * c_h, Array(d); rtol = sqrt(eps(dtype))))
-        end
-    end
+        rtol = (real(dtype) <: AbstractFloat) ? 1.0 : 0
+        @test all(isapprox.(new_a_h * new_b_h + c_h, Array(d); rtol = rtol))
     end

@test_if "wmma" @testset "WMMA GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
Expand Down Expand Up @@ -153,7 +158,7 @@ using LinearAlgebra
     end

     @test_if "diagonal" @testset "WMMA GEMM (A = diagonal, B = $( !transpose_b ? 'N' : 'T' ))" for transpose_b = [false, true]
-        @testset "(M = $M, N = $N, K = $K)" for (M, N, K) in [(128, 128, 128), (256, 256, 256), (4096, 4096, 4096)]
+        @testset "(M = $M, N = $N, K = $K)" for (M, N, K) in [(128, 128, 128), (256, 256, 256), (2048, 2048, 2048)]
             @assert M == K "Diagonal only supports square A matrix (M == K)"

             transpose_a = false
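A side note on the reworked SIMT test: real(dtype) works on types, so one
predicate covers integer, floating-point, and complex element types at once.
A quick illustration of the tolerance selection (the dtype values are examples):

    @assert real(ComplexF32) == Float32   # complex types unwrap to their real part
    @assert real(Float16)    == Float16
    @assert real(Int32)      == Int32

    # Integer GEMM is exact, so integer dtypes compare with rtol = 0;
    # floating-point dtypes get a loose relative tolerance instead.
    for dtype in (Int32, Float32, ComplexF32)
        rtol = (real(dtype) <: AbstractFloat) ? 1.0 : 0
        @assert (rtol == 0) == (dtype == Int32)
    end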
