diff --git a/configs/configs.jl b/configs/configs.jl index 1a0f2e81..eeb19b9f 100644 --- a/configs/configs.jl +++ b/configs/configs.jl @@ -91,21 +91,26 @@ end # Verify results. function verify(cf::Configuration, c_ref, d) - cf.verify(c_ref, d) + cf.verify(c_ref, d, cf.a_type) end -function verify_default(c_ref, d) - isapprox(c_ref, d) +compare(x, y, T) = error("Unimplemented compare(x, y, T) function for type $T") +compare(x, y, T::Type{<:AbstractFloat}) = isapprox(x, y; rtol=sqrt(eps(T))) +compare(x, y, T::Type{<:Integer}) = (x == y) +compare(x, y, T::Type{Complex{U}}) where {U} = compare(x, y, U) + +function verify_default(c_ref, d, T) + all(compare.(c_ref, d, T)) end -function verify_bias(c_ref, d, bias) - c_ref .+ bias ≈ d +function verify_bias(c_ref, d, bias, T) + all(compare.(c_ref .+ bias, d, T)) end -function verify_dual(c_ref, d) +function verify_dual(c_ref, d, T) c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_ref) d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, d) - isapprox(c_dual, d_dual) + all(compare.(c_dual, d_dual, T)) end function fpu_baseline(a, b, c, d, alpha, beta, transpose_a, transpose_b) @@ -282,7 +287,7 @@ macro get_wmma_bias_config() transpose_b, mul!, Epilogue.Bias(pointer(bias)), - (c_h, d) -> verify_bias(c_h, d, bias), + (c_h, d, T) -> verify_bias(c_h, d, bias, T), Kernel.matmul_pipelined, nothing) end end)