Skip to content

Commit

Permalink
Simplify tests. (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt authored Jul 2, 2023
1 parent fa335a1 commit 781f1de
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 36 deletions.
12 changes: 6 additions & 6 deletions test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@ using LinearAlgebra
CUDA.CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.CUBLAS_TENSOR_OP_MATH)

@testset "BLAS API" begin
@testset "WMMA GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
@testset "WMMA GEMM $(AB_type)*$(AB_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
transpose_b = [false, true],
(A_type, B_type, CD_type, min_dimension) in [(Float16, Float16, Float16, 256), (Float16, Float16, Float32, 128)]
(AB_type, CD_type, min_dimension) in [(Float16, Float16, 256), (Float16, Float32, 128)]

@testcase "(M = $M, N = $N, K = $K)" for M in min_dimension .* [1, 2],
N in min_dimension .* [1, 2],
K in min_dimension .* [1, 2]

alpha = rand(A_type)
alpha = rand(AB_type)
beta = rand(CD_type)

a_h = rand(A_type, (M, K)) / sqrt(A_type(K))
b_h = rand(B_type, (K, N)) / sqrt(B_type(K))
a_h = rand(AB_type, (M, K)) / sqrt(AB_type(K))
b_h = rand(AB_type, (K, N)) / sqrt(AB_type(K))
c_h = rand(CD_type, (M, N))

# Transpose input if necessary
Expand All @@ -33,7 +33,7 @@ CUDA.CUBLAS.cublasSetMathMode(CUBLAS.handle(), CUBLAS.CUBLAS_TENSOR_OP_MATH)
c_cublas = CuArray(c_h)
CUDA.CUBLAS.gemmEx!(!transpose_a ? 'N' : 'T', !transpose_b ? 'N' : 'T', alpha, a, b, beta, c_cublas)

@test all(isapprox.(Array(c_gemmkernels), Array(c_cublas); rtol=sqrt(eps(A_type))));
@test Array(c_gemmkernels) ≈ Array(c_cublas) rtol=sqrt(eps(AB_type))
end
end
end
63 changes: 34 additions & 29 deletions test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using LinearAlgebra
################################################################################

@testset "Matmul API" begin
@testset "FPU GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for
@testset "FPU GEMM $(A_type)*$(B_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for
(A_type, B_type, CD_type, min_dimension) in [
(Float16, Float16, Float32, 128), (Float32, Float32, Float32, 128), (Float32, Float32, Float64, 128), (Float64, Float64, Float64, 128),
(Int16, Int16, Int16, 128), (Int32, Int32, Int32, 128), (Int64, Int64, Int64, 128),
Expand Down Expand Up @@ -63,10 +63,11 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(a_h) : a_h
new_b_h = transpose_b ? transpose(b_h) : b_h

mul!(c_h, new_a_h, new_b_h, alpha, beta)
if A_type <: Integer
@test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d)))
@test c_h ≈ Array(d)
else
@test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(A_type))))
@test c_h ≈ Array(d) rtol=sqrt(eps(A_type))
end
end
end
Expand Down Expand Up @@ -120,11 +121,12 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(a_h) : a_h
new_b_h = transpose_b ? transpose(b_h) : b_h

@test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(A_type))))
mul!(c_h, new_a_h, new_b_h, alpha, beta)
@test c_h ≈ Array(d) rtol=sqrt(eps(A_type))
end
end

@testset "TROPICAL GEMM $(A_type)*$(B_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for
@testset "TROPICAL GEMM $(A_type)*$(B_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' )) OP ($(OP_M), $(OP_N), $(OP_K))" for
(A_type, B_type, CD_type, min_dimension) in [(Float32, Float32, Float32, 128)],
transpose_a = [false, true],
transpose_b = [false, true],
Expand Down Expand Up @@ -172,12 +174,12 @@ using LinearAlgebra

GemmKernels.matmul(a, b, c, d, conf; kernel = Kernel.matmul_pipelined)

@test all(isapprox.(d_h, Array(d); rtol = sqrt(eps(A_type))))
@test d_h ≈ Array(d) rtol=sqrt(eps(A_type))
end
end


@testset "WMMA GEMM $(AB_type)*$(AB_type)+$(CD_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
@testset "WMMA GEMM $(AB_type)*$(AB_type)=$(CD_type) ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
transpose_b = [false, true],
(AB_type, CD_type, min_dimension) in [(Float16, Float16, 256), (Float16, Float32, 128)]
@testcase "(M = $M, N = $N, K = $K)" for (M, N, K) in vcat(min_dimension.*[[1,1,1], [2,2,1], [1,1,2], [2,2,2]], [[2048, 2048, 2048]])
Expand Down Expand Up @@ -220,7 +222,8 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(a_h) : a_h
new_b_h = transpose_b ? transpose(b_h) : b_h

@test all(isapprox.(alpha * CD_type.(new_a_h) * CD_type.(new_b_h) + beta * c_h, Array(d); rtol = sqrt(eps(AB_type))))
mul!(c_h, new_a_h, new_b_h, alpha, beta)
@test c_h ≈ Array(d) rtol=sqrt(eps(AB_type))
end
end

Expand Down Expand Up @@ -271,7 +274,8 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(a_h) : a_h
new_b_h = transpose_b ? transpose(b_h) : b_h

@test all(isapprox.(Float32.(new_a_h) * Float32.(new_b_h) + c_h .+ Array(bias), Array(d); rtol = sqrt(eps(Float16))))
mul!(c_h, new_a_h, new_b_h, true, true)
@test c_h .+ Array(bias) ≈ Array(d) rtol=sqrt(eps(Float16))
end
end

Expand All @@ -281,7 +285,7 @@ using LinearAlgebra

transpose_a = false

a_h = rand(Float16, M);
a_h = rand(Float16, M)
b_h = rand(Float16, (K, N)) / sqrt(Float16(K))
c_h = rand(Float32, (M, N))

Expand Down Expand Up @@ -315,26 +319,27 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(a_h) : a_h
new_b_h = transpose_b ? transpose(b_h) : b_h

@test all(isapprox.(Float32.(Diagonal(new_a_h)) * Float32.(new_b_h) + c_h, Array(d); rtol = sqrt(eps(Float16))))
mul!(c_h, Diagonal(new_a_h), new_b_h, true, true)
@test c_h ≈ Array(d) rtol=sqrt(eps(Float16))
end
end

@testset "WMMA Complex GEMM ($( !transpose_a ? 'N' : 'T' )$( !transpose_b ? 'N' : 'T' ))" for transpose_a = [false, true],
transpose_b = [false, true]

@testcase "(M = $M, N = $N, K = $K)" for (M, N, K) = [(128, 128, 128), (256, 256, 256), (2048, 2048, 2048)]
a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K));
b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K));
c_h = rand(Complex{Float32}, (M, N));
a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K))
b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K))
c_h = rand(Complex{Float32}, (M, N))

# Transpose input if necessary
a_h = transpose_a ? transpose(a_h) : a_h
b_h = transpose_b ? transpose(b_h) : b_h

a = CuArray(a_h);
b = CuArray(b_h);
c = CuArray(c_h);
d = similar(c);
a = CuArray(a_h)
b = CuArray(b_h)
c = CuArray(c_h)
d = similar(c)

conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
Expand Down Expand Up @@ -378,22 +383,21 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(new_a_h) : new_a_h
new_b_h = transpose_b ? transpose(new_b_h) : new_b_h

# TODO: Figure out why changing this to a * b + c = d instead of a * b = d - c
# makes tests fail for CC (see #19).
@test all(isapprox.(Complex{Float32}.(new_a_h) * Complex{Float32}.(new_b_h), Array(d) - c_h; rtol=sqrt(eps(Float16))));
mul!(c_h, new_a_h, new_b_h, true, true)
@test c_h ≈ Array(d) rtol=sqrt(eps(Float16))
end
end

@testset "WMMA Dual GEMM" begin
@testcase "(M = $M, N = $N, K = $K)" for (M, N, K) in [(128, 128, 128), (256, 256, 256), (2048, 2048, 2048)]
a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K));
b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K));
c_h = rand(Complex{Float32}, (M, N));
a_h = rand(Complex{Float16}, (M, K)) / sqrt(Float16(K))
b_h = rand(Complex{Float16}, (K, N)) / sqrt(Float16(K))
c_h = rand(Complex{Float32}, (M, N))

a = CuArray(a_h);
b = CuArray(b_h);
c = CuArray(c_h);
d = similar(c);
a = CuArray(a_h)
b = CuArray(b_h)
c = CuArray(c_h)
d = similar(c)

conf = GemmKernels.get_config(
gemm_shape = (M = M, N = N, K = K),
Expand Down Expand Up @@ -432,7 +436,8 @@ using LinearAlgebra
c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_h)
d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, Array(d))

@test all(isapprox.(a_dual * b_dual + c_dual, d_dual; rtol=sqrt(eps(Float16))));
mul!(c_dual, a_dual, b_dual, true, true)
@test c_dual ≈ d_dual rtol=sqrt(eps(Float16))
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@ withenv("JULIA_NUM_THREADS" => 1, "OPENBLAS_NUM_THREADS" => 1) do
end

@everywhere using XUnit
runtests("tests.jl")
runtests("tests.jl", ARGS...)

0 comments on commit 781f1de

Please sign in to comment.