From 4fc4fe4c06bbd0e6b141225080234200b0729984 Mon Sep 17 00:00:00 2001
From: Thomas Faingnaert
Date: Tue, 16 Nov 2021 20:21:01 +0100
Subject: [PATCH] Generalise WMMA Operator (#93)

---
 benchmarks/diagonal/benchmark.jl        |  2 +-
 benchmarks/operator-fusion/benchmark.jl | 12 ++++-----
 src/blas.jl                             |  2 +-
 src/operator.jl                         | 34 ++++++++++++-------------
 test/blas.jl                            | 23 +++++++++--------
 test/matmul.jl                          | 34 ++++++++++++-------------
 6 files changed, 54 insertions(+), 53 deletions(-)

diff --git a/benchmarks/diagonal/benchmark.jl b/benchmarks/diagonal/benchmark.jl
index 66c73552..16464aaf 100644
--- a/benchmarks/diagonal/benchmark.jl
+++ b/benchmarks/diagonal/benchmark.jl
@@ -55,7 +55,7 @@ end
 function bench_gemmkernels(a, b, c, M, N, K, transpose_a, transpose_b, num_iterations = 100)
     conf = GemmKernels.get_config(
         gemm_shape = (M = M, N = N, K = K),
-        operator = Operator.WMMAOp{16, 16, 16},
+        operator = Operator.WMMAOp{16, 16, 16, Float32},

         global_a_layout = Layout.Diagonal{Float16},
         global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
diff --git a/benchmarks/operator-fusion/benchmark.jl b/benchmarks/operator-fusion/benchmark.jl
index bc433998..b89f8297 100644
--- a/benchmarks/operator-fusion/benchmark.jl
+++ b/benchmarks/operator-fusion/benchmark.jl
@@ -114,7 +114,7 @@ function bench_gemmkernels(a, b, c, M, N, K, transpose_a, transpose_b)
     CUDA.@sync begin
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
-            operator = Operator.WMMAOp{16, 16, 16},
+            operator = Operator.WMMAOp{16, 16, 16, Float32},

             global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
@@ -136,7 +136,7 @@ function bench_gemmkernels_relu(a, b, c, M, N, K, transpose_a, transpose_b)
     CUDA.@sync begin
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
-            operator = Operator.WMMAOp{16, 16, 16},
+            operator = Operator.WMMAOp{16, 16, 16, Float32},

             global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
@@ -159,7 +159,7 @@ function bench_gemmkernels_bias(a, b, c, bias, M, N, K, transpose_a, transpose_b
     CUDA.@sync begin
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
-            operator = Operator.WMMAOp{16, 16, 16},
+            operator = Operator.WMMAOp{16, 16, 16, Float32},

             global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
@@ -182,7 +182,7 @@ function bench_gemmkernels_biasrelu(a, b, c, bias, M, N, K, transpose_a, transpo
     CUDA.@sync begin
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
-            operator = Operator.WMMAOp{16, 16, 16},
+            operator = Operator.WMMAOp{16, 16, 16, Float32},

             global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
@@ -206,7 +206,7 @@ function bench_gemmkernels_biasrelutwice(a, b, c, bias, M, N, K, transpose_a, tr
     CUDA.@sync begin
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
-            operator = Operator.WMMAOp{16, 16, 16},
+            operator = Operator.WMMAOp{16, 16, 16, Float32},

             global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
@@ -231,7 +231,7 @@ function bench_gemmkernels_biasrelutwice_ab_elop(a, b, c, bias, M, N, K, transpo
     CUDA.@sync begin
         conf = GemmKernels.get_config(
             gemm_shape = (M = M, N = N, K = K),
-            operator = Operator.WMMAOp{16, 16, 16},
+            operator = Operator.WMMAOp{16, 16, 16, Float32},

             global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
             global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
diff --git a/src/blas.jl b/src/blas.jl
index 77799239..55a5dca1 100644
--- a/src/blas.jl
+++ b/src/blas.jl
@@ -45,7 +45,7 @@ function gemmEx!(transA::Char, transB::Char, alpha::Number, A, B, beta::Number,

     conf = GemmKernels.get_config(
         gemm_shape = (M = m, N = n, K = k),
-        operator = Operator.WMMAOp{16, 16, 16},
+        operator = Operator.WMMAOp{16, 16, 16, eltype(C)},

         global_a_layout = a_layout,
         global_b_layout = b_layout,
diff --git a/src/operator.jl b/src/operator.jl
index 317b14a6..b78c5787 100644
--- a/src/operator.jl
+++ b/src/operator.jl
@@ -17,9 +17,9 @@ end
 # WMMA
 # ----

-struct WMMAOp{M, N, K} end
+struct WMMAOp{M, N, K, T} end

-@inline shape(::Type{WMMAOp{M, N, K}}) where {M, N, K} = (M = M, N = N, K = K)
+@inline shape(::Type{WMMAOp{M, N, K, T}}) where {M, N, K, T} = (M = M, N = N, K = K)

 # convert_index_func: function used to transpose the index in case of a row-major layout
 for (layout_type, wmma_layout_type, convert_index_func) in [
@@ -27,12 +27,12 @@
     (Layout.AlignedRowMajor, WMMA.RowMajor, x -> reverse(Tuple(x)))
 ]
     @eval begin
-        @inline fragtype_a(::Type{WMMAOp{16, 16, 16}}, ::Type{$layout_type{Float16}}) = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixA}
-        @inline fragtype_b(::Type{WMMAOp{16, 16, 16}}, ::Type{$layout_type{Float16}}) = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixB}
-        @inline fragtype_accum(::Type{WMMAOp{16, 16, 16}}, ::Type{$layout_type{Float32}}) = WMMA.Fragment{16, 16, 16, 8, Float32, WMMA.Unspecified, WMMA.Accumulator}
+        @inline fragtype_a(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixA}
+        @inline fragtype_b(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{Float16}}) where {T} = WMMA.Fragment{16, 16, 16, 16, Float16, $wmma_layout_type, WMMA.MatrixB}
+        @inline fragtype_accum(::Type{WMMAOp{16, 16, 16, T}}, ::Type{$layout_type{T}}) where {T} = WMMA.Fragment{16, 16, 16, 8, T, WMMA.Unspecified, WMMA.Accumulator}

-        @inline function load_a(::Type{WMMAOp{M, N, K}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K}
-            conf = WMMA.Config{M, N, K, Float32}
+        @inline function load_a(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
+            conf = WMMA.Config{M, N, K, T}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))
@@ -41,8 +41,8 @@
             return WMMA.load_a(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end

-        @inline function load_b(::Type{WMMAOp{M, N, K}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K}
-            conf = WMMA.Config{M, N, K, Float32}
+        @inline function load_b(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{Float16}}, workspace, tile::Tile) where {M, N, K, T}
+            conf = WMMA.Config{M, N, K, T}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))
@@ -51,30 +51,30 @@
             return WMMA.load_b(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end

-        @inline function load_c(::Type{WMMAOp{M, N, K}}, ::Type{$layout_type{Float32}}, workspace, tile::Tile) where {M, N, K}
-            conf = WMMA.Config{M, N, K, Float32}
+        @inline function load_c(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, tile::Tile) where {M, N, K, T}
+            conf = WMMA.Config{M, N, K, T}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float32)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
             return WMMA.load_c(ptr, size(workspace, 1), $wmma_layout_type, conf)
         end

-        @inline function store_d(::Type{WMMAOp{M, N, K}}, ::Type{$layout_type{Float32}}, workspace, frag, tile::Tile) where {M, N, K}
-            conf = WMMA.Config{M, N, K, Float32}
+        @inline function store_d(::Type{WMMAOp{M, N, K, T}}, ::Type{$layout_type{T}}, workspace, frag, tile::Tile) where {M, N, K, T}
+            conf = WMMA.Config{M, N, K, T}

             linear_base = linearise($convert_index_func(tile.base), size(workspace))
             linear_offset = linearise($convert_index_func(tile.offset), size(workspace))

-            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(Float32)
+            ptr = pointer(workspace, linear_base) + (linear_offset - 1) * sizeof(T)
             WMMA.store_d(ptr, frag, size(workspace, 1), $wmma_layout_type, conf)
         end
     end
 end

-function mma(::Type{WMMAOp{M, N, K}}, a_frag, b_frag, c_frag) where {M, N, K}
-    conf = WMMA.Config{M, N, K, Float32}
+function mma(::Type{WMMAOp{M, N, K, T}}, a_frag, b_frag, c_frag) where {M, N, K, T}
+    conf = WMMA.Config{M, N, K, T}

     return WMMA.mma(a_frag, b_frag, c_frag, conf)
 end
diff --git a/test/blas.jl b/test/blas.jl
index 16a3ef09..006202aa 100644
--- a/test/blas.jl
+++ b/test/blas.jl
@@ -5,19 +5,20 @@ using LinearAlgebra
 CUDA.CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.CUBLAS_TENSOR_OP_MATH)

 @test_if "blas" @testset "BLAS API" begin
-    @testset "WMMA GEMM ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
-        transpose_b = [false, true]
+    @testset "WMMA GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
+        transpose_b = [false, true],
+        (A_type, B_type, CD_type, min_dimension) in [(Float16, Float16, Float16, 256), (Float16, Float16, Float32, 128)]

-        @testset "(M = $M, N = $N, K = $K)" for M in [128, 256],
-            N in [128, 256],
-            K in [128, 256]
+        @testset "(M = $M, N = $N, K = $K)" for M in min_dimension .* [1, 2],
+            N in min_dimension .* [1, 2],
+            K in min_dimension .* [1, 2]

-            alpha = rand(Float32)
-            beta = rand(Float32)
+            alpha = rand(A_type)
+            beta = rand(CD_type)

-            a_h = rand(Float16, (M, K)) / sqrt(Float16(K))
-            b_h = rand(Float16, (K, N)) / sqrt(Float16(K))
-            c_h = rand(Float32, (M, N))
+            a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
+            b_h = rand(B_type, (K, N)) / sqrt(B_type(K))
+            c_h = rand(CD_type, (M, N))

             # Transpose input if necessary
             a_h = transpose_a ? transpose(a_h) : a_h
@@ -32,7 +33,7 @@ CUDA.CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.CUBLAS_TENSOR_OP_MATH)
             c_cublas = CuArray(c_h)
             CUDA.CUBLAS.gemmEx!(!transpose_a ? 'N' : 'T', !transpose_b ? 'N' : 'T', alpha, a, b, beta, c_cublas)

-            @test all(isapprox.(Array(c_gemmkernels), Array(c_cublas); rtol=sqrt(eps(Float16))));
+            @test all(isapprox.(Array(c_gemmkernels), Array(c_cublas); rtol=sqrt(eps(A_type))));
         end
     end

diff --git a/test/matmul.jl b/test/matmul.jl
index c359db18..ba7f3466 100644
--- a/test/matmul.jl
+++ b/test/matmul.jl
@@ -6,16 +6,16 @@ using LinearAlgebra
 ################################################################################

 @testset "Matmul API" begin
-    @test_if "wmma" @testset "WMMA GEMM ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
-        transpose_b = [false, true]
-
-        @testset "(M = $M, N = $N, K = $K)" for (M, N, K) in [(128, 128, 128), (256, 256, 128), (128, 128, 256), (256, 256, 256), (2048, 2048, 2048)]
-            alpha = 2
-            beta = 3
+    @test_if "wmma" @testset "WMMA GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
+        transpose_b = [false, true],
+        (A_type, B_type, CD_type, min_dimension) in [(Float16, Float16, Float16, 256), (Float16, Float16, Float32, 128)]
+        @testset "(M = $M, N = $N, K = $K)" for (M, N, K) in vcat(min_dimension.*[[1,1,1], [2,2,1], [1,1,2], [2,2,2]], [[2048, 2048, 2048]])
+            alpha = convert(A_type, 2)
+            beta = convert(CD_type, 3)

-            a_h = rand(Float16, (M, K)) / sqrt(Float16(K))
-            b_h = rand(Float16, (K, N)) / sqrt(Float16(K))
-            c_h = rand(Float32, (M, N))
+            a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
+            b_h = rand(B_type, (K, N)) / sqrt(B_type(K))
+            c_h = rand(CD_type, (M, N))

             # Transpose input if necessary
             a_h = transpose_a ? transpose(a_h) : a_h
@@ -28,12 +28,12 @@ using LinearAlgebra

             conf = GemmKernels.get_config(
                 gemm_shape = (M = M, N = N, K = K),
-                operator = Operator.WMMAOp{16, 16, 16},
-                global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
-                global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
+                operator = Operator.WMMAOp{16, 16, 16, CD_type},
+                global_a_layout = transpose_a ? Layout.AlignedRowMajor{A_type} : Layout.AlignedColMajor{A_type},
+                global_b_layout = transpose_b ? Layout.AlignedRowMajor{B_type} : Layout.AlignedColMajor{B_type},

-                global_c_layout = Layout.AlignedColMajor{Float32},
-                global_d_layout = Layout.AlignedColMajor{Float32},
+                global_c_layout = Layout.AlignedColMajor{CD_type},
+                global_d_layout = Layout.AlignedColMajor{CD_type},

                 is_a_col_major = !transpose_a,
                 is_b_col_major = !transpose_b,
@@ -49,7 +49,7 @@ using LinearAlgebra
             new_a_h = transpose_a ? transpose(a_h) : a_h
             new_b_h = transpose_b ? transpose(b_h) : b_h

-            @test all(isapprox.(alpha * Float32.(new_a_h) * Float32.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(Float16))))
+            @test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(A_type))))
         end
     end

@@ -80,7 +80,7 @@ using LinearAlgebra

             conf = GemmKernels.get_config(
                 gemm_shape = (M = M, N = N, K = K),
-                operator = Operator.WMMAOp{16, 16, 16},
+                operator = Operator.WMMAOp{16, 16, 16, Float32},

                 global_a_layout = transpose_a ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
                 global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
@@ -125,7 +125,7 @@ using LinearAlgebra

             conf = GemmKernels.get_config(
                 gemm_shape = (M = M, N = N, K = K),
-                operator = Operator.WMMAOp{16, 16, 16},
+                operator = Operator.WMMAOp{16, 16, 16, Float32},

                 global_a_layout = Layout.Diagonal{Float16},
                 global_b_layout = transpose_b ? Layout.AlignedRowMajor{Float16} : Layout.AlignedColMajor{Float16},
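
Usage sketch (illustration only, not part of the patch): a minimal mixed-precision GEMM using the generalised operator. The get_config keywords mirror the WMMA GEMM test in test/matmul.jl above; the GemmKernels.matmul(a, b, c, d, conf) entry point is assumed from the package's test suite and is not visible in the hunks above.

    using CUDA
    using GemmKernels
    using GemmKernels: Layout, Operator

    M = N = K = 256

    # Element types match the (Float16, Float16, Float32, 128) test case above.
    a = CuArray(rand(Float16, (M, K)))
    b = CuArray(rand(Float16, (K, N)))
    c = CuArray(rand(Float32, (M, N)))
    d = similar(c)

    conf = GemmKernels.get_config(
        gemm_shape = (M = M, N = N, K = K),
        # The new fourth type parameter selects the accumulator element type.
        operator = Operator.WMMAOp{16, 16, 16, Float32},

        global_a_layout = Layout.AlignedColMajor{Float16},
        global_b_layout = Layout.AlignedColMajor{Float16},
        global_c_layout = Layout.AlignedColMajor{Float32},
        global_d_layout = Layout.AlignedColMajor{Float32},

        is_a_col_major = true,
        is_b_col_major = true,
    )

    # Assumed entry point (as exercised by the test suite): computes d = a * b + c.
    GemmKernels.matmul(a, b, c, d, conf)

Passing Float16 as the fourth parameter instead selects the Float16 accumulator path enabled by this change, matching the (Float16, Float16, Float16, 256) test case added above.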