Skip to content

Commit

Permalink
[Test] fix gemm driver dataType initialization (#2558)
Browse files Browse the repository at this point in the history
fix gemm dataType initialization and make gemm driver more dataType friendly
  • Loading branch information
CAHEK7 authored Nov 30, 2023
1 parent 79b4470 commit 68966ea
Showing 1 changed file with 16 additions and 9 deletions.
25 changes: 16 additions & 9 deletions driver/gemm_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,12 @@ void callCpuGemmStridedBatched(bool isColMajor,
: a_offset + strideA * bi + lda * mi + ki;
int bindex = transB ? b_offset + strideB * bi + ldb * ni + ki
: b_offset + strideB * bi + ldb * ki + ni;
y += a_ptr[aindex] * b_ptr[bindex];
y += static_cast<double>(a_ptr[aindex]) * static_cast<double>(b_ptr[bindex]);
}
int cindex = c_offset + strideC * bi + ldc * mi + ni;
c_ptr[cindex] = alpha * y + beta * c_ptr[cindex];
int cindex = c_offset + strideC * bi + ldc * mi + ni;
c_ptr[cindex] =
static_cast<T>(static_cast<double>(alpha) * y +
static_cast<double>(beta) * static_cast<double>(c_ptr[cindex]));
}
}
}
Expand Down Expand Up @@ -180,6 +182,10 @@ int GemmDriver<T>::ParseCmdLineArgs(int argc, char* argv[])
template <typename T>
int GemmDriver<T>::GetandSetData()
{
gemm_desc.dataType = data_type;
gemm_desc.a_cast_type = data_type;
gemm_desc.b_cast_type = data_type;

gemm_desc.isColMajor = inflags.GetValueInt("isColMajor");
gemm_desc.m = inflags.GetValueInt("a_h");
gemm_desc.k = inflags.GetValueInt("a_w");
Expand Down Expand Up @@ -230,27 +236,28 @@ int GemmDriver<T>::AllocateBuffersAndCopy()
a = std::vector<T>(a_sz);
b = std::vector<T>(b_sz);
#if GEMM_DRIVER_DEBUG
c = std::vector<T>(c_sz, 1.);
c = std::vector<T>(c_sz, static_cast<T>(1));
#else
c = std::vector<T>(c_sz, 0.);

c = std::vector<T>(c_sz, static_cast<T>(0));
#endif
chost = c;

for(int i = 0; i < a_sz; i++)
{
#if GEMM_DRIVER_DEBUG
a[i] = static_cast<double>(i);
a[i] = static_cast<T>(i);
#else
a[i] = prng::gen_canonical<double>();
a[i] = prng::gen_canonical<T>();
#endif
}

for(int i = 0; i < b_sz; i++)
{
#if GEMM_DRIVER_DEBUG
b[i] = static_cast<double>(i);
b[i] = static_cast<T>(i);
#else
b[i] = prng::gen_A_to_B(-0.5, 0.5) * 0.001;
b[i] = prng::gen_A_to_B(static_cast<T>(-0.5), static_cast<T>(0.5));
#endif
}
#if MIOPEN_BACKEND_OPENCL
Expand Down

0 comments on commit 68966ea

Please sign in to comment.