Generalise WMMA Operator
thomasfaingnaert committed Nov 16, 2021
1 parent 7ada93e commit 56e815d
Showing 6 changed files with 54 additions and 53 deletions.
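
In effect, the WMMA operator type now carries the compute/accumulator element type as a fourth parameter instead of hard-coding Float32. A minimal configuration sketch, mirroring the get_config keyword set visible in test/matmul.jl below; the explicit submodule imports and the concrete sizes are assumptions, not part of this commit:

using GemmKernels
using GemmKernels: Operator, Layout   # explicit imports, in case the submodules are not re-exported

# FP16 inputs accumulated in FP32: the only behaviour available before this commit.
conf_f32 = GemmKernels.get_config(
    gemm_shape = (M = 256, N = 256, K = 256),
    operator = Operator.WMMAOp{16, 16, 16, Float32},
    global_a_layout = Layout.AlignedColMajor{Float16},
    global_b_layout = Layout.AlignedColMajor{Float16},
    global_c_layout = Layout.AlignedColMajor{Float32},
    global_d_layout = Layout.AlignedColMajor{Float32},
    is_a_col_major = true,
    is_b_col_major = true,
)

# FP16 inputs accumulated in FP16: newly expressible through the extra type parameter.
conf_f16 = GemmKernels.get_config(
    gemm_shape = (M = 256, N = 256, K = 256),
    operator = Operator.WMMAOp{16, 16, 16, Float16},
    global_a_layout = Layout.AlignedColMajor{Float16},
    global_b_layout = Layout.AlignedColMajor{Float16},
    global_c_layout = Layout.AlignedColMajor{Float16},
    global_d_layout = Layout.AlignedColMajor{Float16},
    is_a_col_major = true,
    is_b_col_major = true,
)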
2 changes: 1 addition & 1 deletion benchmarks/diagonal/benchmark.jl
@@ -55,7 +55,7 @@ end
function bench_gemmkernels(a, b, c, M, N, K, transpose_a, transpose_b, num_iterations = 100)
conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
-operator = Operator.WMMAOp{16, 16, 16},
+operator = Operator.WMMAOp{16, 16, 16, Float32},
global_a_layout = Layout.Diagonal{Float16},
global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},

12 changes: 6 additions & 6 deletions benchmarks/operator-fusion/benchmark.jl
@@ -114,7 +114,7 @@ function bench_gemmkernels(a, b, c, M, N, K, transpose_a, transpose_b)
CUDA.@sync begin
conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
-operator = Operator.WMMAOp{16, 16, 16},
+operator = Operator.WMMAOp{16, 16, 16, Float32},
global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},

@@ -136,7 +136,7 @@ function bench_gemmkernels_relu(a, b, c, M, N, K, transpose_a, transpose_b)
CUDA.@sync begin
conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
-operator = Operator.WMMAOp{16, 16, 16},
+operator = Operator.WMMAOp{16, 16, 16, Float32},
global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},

@@ -159,7 +159,7 @@ function bench_gemmkernels_bias(a, b, c, bias, M, N, K, transpose_a, transpose_b
CUDA.@sync begin
conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
-operator = Operator.WMMAOp{16, 16, 16},
+operator = Operator.WMMAOp{16, 16, 16, Float32},
global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},

@@ -182,7 +182,7 @@ function bench_gemmkernels_biasrelu(a, b, c, bias, M, N, K, transpose_a, transpo
CUDA.@sync begin
conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
-operator = Operator.WMMAOp{16, 16, 16},
+operator = Operator.WMMAOp{16, 16, 16, Float32},
global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},

@@ -206,7 +206,7 @@ function bench_gemmkernels_biasrelutwice(a, b, c, bias, M, N, K, transpose_a, tr
CUDA.@sync begin
conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
-operator = Operator.WMMAOp{16, 16, 16},
+operator = Operator.WMMAOp{16, 16, 16, Float32},
global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},

@@ -231,7 +231,7 @@ function bench_gemmkernels_biasrelutwice_ab_elop(a, b, c, bias, M, N, K, transpo
CUDA.@sync begin
conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
-operator = Operator.WMMAOp{16, 16, 16},
+operator = Operator.WMMAOp{16, 16, 16, Float32},
global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},

2 changes: 1 addition & 1 deletion src/blas.jl
@@ -45,7 +45,7 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A, B, beta::Number,

conf = GemmKernels.get_config(
gemm_shape = (M = m, N = n, K = k),
-operator = Operator.WMMAOp{16, 16, 16},
+operator = Operator.WMMAOp{16, 16, 16, eltype(C)},

global_a_layout = a_layout,
global_b_layout = b_layout,
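
The BLAS wrapper now derives the accumulator type from the output array, so a Float32 C selects WMMAOp{16, 16, 16, Float32} and a Float16 C selects WMMAOp{16, 16, 16, Float16}. A hedged call sketch: the positional signature comes from the hunk header above, while the GemmKernels.BLAS module path is an assumption based on the file name and may differ:

using CUDA, GemmKernels

A = CuArray(rand(Float16, 256, 256))
B = CuArray(rand(Float16, 256, 256))
C = CuArray(zeros(Float32, 256, 256))   # eltype(C) == Float32, so the kernel accumulates in FP32

# C := 1.0 * A * B + 0.0 * C; module path assumed, see note above.
GemmKernels.BLAS.gemmEx!('N', 'N', 1f0, A, B, 0f0, C)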
34 changes: 17 additions & 17 deletions src/operator.jl
@@ -17,22 +17,22 @@ end
# WMMA
# ----

-struct WMMAOp{M, N, K} end
+struct WMMAOp{M, N, K, T} end

-@inline shape(::Type{WMMAOp{M, N, K}}) where {M, N, K} = (M = M, N = N, K = K)
+@inline shape(::Type{WMMAOp{M, N, K, T}}) where {M, N, K, T} = (M = M, N = N, K = K)

# convert_index_func: function used to transpose the index in case of a row-major layout
for (layout_type, wmma_layout_type, convert_index_func) in [
(Layout.AlignedColMajor, WMMA.ColMajor, identity),
(Layout.AlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x)))
]
@eval begin
-@inline fragtype_a(::Type{WMMAOp{16, 16, 16}}, ::Type{$layout_type{Float16}}) = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixA}
-@inline fragtype_b(::Type{WMMAOp{16, 16, 16}}, ::Type{$layout_type{Float16}}) = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixB}
-@inline fragtype_accum(::Type{WMMAOp{16, 16, 16}}, ::Type{$layout_type{Float32}}) = WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}
+@inline fragtype_a(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixA}
+@inline fragtype_b(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixB}
+@inline fragtype_accum(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{T}}) where {T} = WMMA.Fragment{16, 16, 16, 8, T, WMMA.Unspecified, WMMA.Accumulator}

-@inline function load_a(::Type{WMMAOp{M, N, K}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K}
-conf = WMMA.Config{M, N, K, Float32}
+@inline function load_a(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
+conf = WMMA.Config{M, N, K, T}

linear_base = linearise($convert_index_func(tile.base), size(workspace))
linear_offset = linearise($convert_index_func(tile.offset), size(workspace))
@@ -41,8 +41,8 @@ for (layout_type, wmma_layout_type, convert_index_func) in [
return WMMA.load_a(ptr, size(workspace, 1), $wmma_layout_type, conf)
end

-@inline function load_b(::Type{WMMAOp{M, N, K}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K}
-conf = WMMA.Config{M, N, K, Float32}
+@inline function load_b(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
+conf = WMMA.Config{M, N, K, T}

linear_base = linearise($convert_index_func(tile.base), size(workspace))
linear_offset = linearise($convert_index_func(tile.offset), size(workspace))
@@ -51,30 +51,30 @@ for (layout_type, wmma_layout_type, convert_index_func) in [
return WMMA.load_b(ptr, size(workspace, 1), $wmma_layout_type, conf)
end

-@inline function load_c(::Type{WMMAOp{M, N, K}}, ::Type{$layout_type{Float32}}, workspace, tile::Tile) where {M, N, K}
-conf = WMMA.Config{M, N, K, Float32}
+@inline function load_c(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, tile::Tile) where {M, N, K, T}
+conf = WMMA.Config{M, N, K, T}

linear_base = linearise($convert_index_func(tile.base), size(workspace))
linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float32)
+ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
return WMMA.load_c(ptr, size(workspace, 1), $wmma_layout_type, conf)
end

-@inline function store_d(::Type{WMMAOp{M, N, K}}, ::Type{$layout_type{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
-conf = WMMA.Config{M, N, K, Float32}
+@inline function store_d(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, frag, tile::Tile) where {M, N, K, T}
+conf = WMMA.Config{M, N, K, T}

linear_base = linearise($convert_index_func(tile.base), size(workspace))
linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float32)
+ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
WMMA.store_d(ptr, frag, size(workspace, 1), $wmma_layout_type, conf)
end
end
end

-function mma(::Type{WMMAOp{M, N, K}}, a_frag, b_frag, c_frag) where {M, N, K}
-conf = WMMA.Config{M, N, K, Float32}
+function mma(::Type{WMMAOp{M, N, K, T}}, a_frag, b_frag, c_frag) where {M, N, K, T}
+conf = WMMA.Config{M, N, K, T}
return WMMA.mma(a_frag, b_frag, c_frag, conf)
end

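
The net effect of the operator.jl changes is that every WMMA.Config and accumulator fragment is parameterised on T rather than pinned to Float32. An illustration of the new dispatch, assuming the methods defined in the hunk above and explicit submodule imports:

using GemmKernels: Operator, Layout

# FP32 accumulation, i.e. the behaviour that used to be hard-coded:
Operator.fragtype_accum(Operator.WMMAOp{16, 16, 16, Float32}, Layout.AlignedColMajor{Float32})
# returns WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}

# FP16 accumulation, newly expressible via the fourth type parameter:
Operator.fragtype_accum(Operator.WMMAOp{16, 16, 16, Float16}, Layout.AlignedColMajor{Float16})
# returns WMMA.Fragment{16, 16, 16, 8, Float16, WMMA.Unspecified, WMMA.Accumulator}

Internally, load_a/load_b/load_c/store_d and mma all build WMMA.Config{M, N, K, T} from the same parameter, so the accumulator precision is decided once, at the operator type.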
23 changes: 12 additions & 11 deletions test/blas.jl
@@ -5,19 +5,20 @@ using LinearAlgebra
CUDA.CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.CUBLAS_TENSOR_OP_MATH)

@test_if "blas" @testset "BLAS API" begin
-@testset "WMMA GEMM ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
-transpose_b = [false, true]
+@testset "WMMA GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
+transpose_b = [false, true],
+(A_type, B_type, CD_type, min_dimension) in [(Float16, Float16, Float16, 256), (Float16, Float16, Float32, 128)]

-@testset "(M = $M, N = $N, K = $K)" for M in [128, 256],
-N in [128, 256],
-K in [128, 256]
+@testset "(M = $M, N = $N, K = $K)" for M in min_dimension .* [1, 2],
+N in min_dimension .* [1, 2],
+K in min_dimension .* [1, 2]

-alpha = rand(Float32)
-beta = rand(Float32)
+alpha = rand(A_type)
+beta = rand(CD_type)

-a_h = rand(Float16, (M, K)) / sqrt(Float16(K))
-b_h = rand(Float16, (K, N)) / sqrt(Float16(K))
-c_h = rand(Float32, (M, N))
+a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
+b_h = rand(B_type, (K, N)) / sqrt(B_type(K))
+c_h = rand(CD_type, (M, N))

# Transpose input if necessary
a_h = transpose_a ? transpose(a_h) : a_h
@@ -32,7 +33,7 @@ CUDA.CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.CUBLAS_TENSOR_OP_MATH)
c_cublas = CuArray(c_h)
CUDA.CUBLAS.gemmEx!(!transpose_a ? 'N' : 'T', !transpose_b ? 'N' : 'T', alpha, a, b, beta, c_cublas)

-@test all(isapprox.(Array(c_gemmkernels), Array(c_cublas); rtol=sqrt(eps(Float16))));
+@test all(isapprox.(Array(c_gemmkernels), Array(c_cublas); rtol=sqrt(eps(A_type))));
end
end

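
A quick arithmetic check of the new shape parameterisation (standalone Julia, not package code): the Float32-accumulated case keeps exactly the sizes the old hard-coded loops covered, while the new Float16-accumulated case starts at 256:

min_dimension = 128                             # the Float16*Float16=Float32 case
@assert min_dimension .* [1, 2] == [128, 256]   # identical to the old hard-coded [128, 256]

min_dimension = 256                             # the new Float16*Float16=Float16 case
@assert min_dimension .* [1, 2] == [256, 512]   # FP16 accumulation is only tested from 256 upwards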
34 changes: 17 additions & 17 deletions test/matmul.jl
@@ -6,16 +6,16 @@ using LinearAlgebra
################################################################################

@testset "Matmul API" begin
-@test_if "wmma" @testset "WMMA GEMM ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
-transpose_b = [false, true]
-
-@testset "(M = $M, N = $N, K = $K)" for (M, N, K) in [(128, 128, 128), (256, 256, 128), (128, 128, 256), (256, 256, 256), (2048, 2048, 2048)]
-alpha = 2
-beta = 3
+@test_if "wmma" @testset "WMMA GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
+transpose_b = [false, true],
+(A_type, B_type, CD_type, min_dimension) in [(Float16, Float16, Float16, 256), (Float16, Float16, Float32, 128)]
+@testset "(M = $M, N = $N, K = $K)" for (M, N, K) in vcat(min_dimension.*[[1,1,1], [2,2,1], [1,1,2], [2,2,2]], [[2048, 2048, 2048]])
+alpha = convert(A_type, 2)
+beta = convert(CD_type, 3)

-a_h = rand(Float16, (M, K)) / sqrt(Float16(K))
-b_h = rand(Float16, (K, N)) / sqrt(Float16(K))
-c_h = rand(Float32, (M, N))
+a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
+b_h = rand(B_type, (K, N)) / sqrt(B_type(K))
+c_h = rand(CD_type, (M, N))

# Transpose input if necessary
a_h = transpose_a ? transpose(a_h) : a_h
@@ -28,12 +28,12 @@

conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
-operator = Operator.WMMAOp{16, 16, 16},
-global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
-global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
+operator = Operator.WMMAOp{16, 16, 16, CD_type},
+global_a_layout = transpose_a ? Layout.AlignedRowMajor{A_type} : Layout.AlignedColMajor{A_type},
+global_b_layout = transpose_b ? Layout.AlignedRowMajor{B_type} : Layout.AlignedColMajor{B_type},

-global_c_layout = Layout.AlignedColMajor{Float32},
-global_d_layout = Layout.AlignedColMajor{Float32},
+global_c_layout = Layout.AlignedColMajor{CD_type},
+global_d_layout = Layout.AlignedColMajor{CD_type},

is_a_col_major = !transpose_a,
is_b_col_major = !transpose_b,
@@ -49,7 +49,7 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(a_h) : a_h
new_b_h = transpose_b ? transpose(b_h) : b_h

-@test all(isapprox.(alpha * Float32.(new_a_h) * Float32.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(Float16))))
+@test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(A_type))))
end
end

@@ -80,7 +80,7 @@

conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
-operator = Operator.WMMAOp{16, 16, 16},
+operator = Operator.WMMAOp{16, 16, 16, Float32},
global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},

@@ -125,7 +125,7 @@ using LinearAlgebra

conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
-operator = Operator.WMMAOp{16, 16, 16},
+operator = Operator.WMMAOp{16, 16, 16, Float32},
global_a_layout = Layout.Diagonal{Float16},
global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
