diff --git a/driver/gemm_driver.hpp b/driver/gemm_driver.hpp index d022f0f1c2..9b5e358074 100644 --- a/driver/gemm_driver.hpp +++ b/driver/gemm_driver.hpp @@ -93,10 +93,12 @@ void callCpuGemmStridedBatched(bool isColMajor, : a_offset + strideA * bi + lda * mi + ki; int bindex = transB ? b_offset + strideB * bi + ldb * ni + ki : b_offset + strideB * bi + ldb * ki + ni; - y += a_ptr[aindex] * b_ptr[bindex]; + y += static_cast(a_ptr[aindex]) * static_cast(b_ptr[bindex]); } - int cindex = c_offset + strideC * bi + ldc * mi + ni; - c_ptr[cindex] = alpha * y + beta * c_ptr[cindex]; + int cindex = c_offset + strideC * bi + ldc * mi + ni; + c_ptr[cindex] = + static_cast(static_cast(alpha) * y + + static_cast(beta) * static_cast(c_ptr[cindex])); } } } @@ -180,6 +182,10 @@ int GemmDriver::ParseCmdLineArgs(int argc, char* argv[]) template int GemmDriver::GetandSetData() { + gemm_desc.dataType = data_type; + gemm_desc.a_cast_type = data_type; + gemm_desc.b_cast_type = data_type; + gemm_desc.isColMajor = inflags.GetValueInt("isColMajor"); gemm_desc.m = inflags.GetValueInt("a_h"); gemm_desc.k = inflags.GetValueInt("a_w"); @@ -230,27 +236,28 @@ int GemmDriver::AllocateBuffersAndCopy() a = std::vector(a_sz); b = std::vector(b_sz); #if GEMM_DRIVER_DEBUG - c = std::vector(c_sz, 1.); + c = std::vector(c_sz, static_cast(1)); #else - c = std::vector(c_sz, 0.); + + c = std::vector(c_sz, static_cast(0)); #endif chost = c; for(int i = 0; i < a_sz; i++) { #if GEMM_DRIVER_DEBUG - a[i] = static_cast(i); + a[i] = static_cast(i); #else - a[i] = prng::gen_canonical(); + a[i] = prng::gen_canonical(); #endif } for(int i = 0; i < b_sz; i++) { #if GEMM_DRIVER_DEBUG - b[i] = static_cast(i); + b[i] = static_cast(i); #else - b[i] = prng::gen_A_to_B(-0.5, 0.5) * 0.001; + b[i] = prng::gen_A_to_B(static_cast(-0.5), static_cast(0.5)); #endif } #if MIOPEN_BACKEND_OPENCL