Skip to content

Commit

Permalink
Cleaned up test case creation, skip tests for gfx908 and gfx90a, addr…
Browse files Browse the repository at this point in the history
…essed other comments
  • Loading branch information
xinlipn committed Jun 2, 2023
1 parent ae29c5d commit beb7fcb
Showing 1 changed file with 57 additions and 88 deletions.
145 changes: 57 additions & 88 deletions test/gtest/conv_embed_db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,12 @@

#include <miopen/miopen.h>
#include <gtest/gtest.h>
#include <miopen/miopen.h>
#include "conv_2d.hpp"
#include "get_handle.hpp"

using TestCase = std::tuple<std::vector<std::string>, std::string>;

enum class Precision
{
Float,
Half,
Int8,
BFloat16
};
#define MAKE_TEST_CASE std::make_tuple<std::vector<std::string>, std::string>

std::string GetFloatArg()
{
Expand All @@ -25,20 +19,6 @@ std::string GetFloatArg()
return tmp;
};

std::vector<std::string> GetEnvVars(const std::vector<std::string>& check_vars)
{
std::vector<std::string> vars = {};
for(const auto& cvar : check_vars)
{
static const auto tmp = std::getenv(cvar.c_str());
if(tmp != nullptr)
{
vars.push_back(cvar + "=0");
}
}
return vars;
};

void GetArgs(const TestCase& param, std::vector<std::string>& tokens)
{
auto env_vars = std::get<0>(param);
Expand Down Expand Up @@ -69,16 +49,23 @@ class Conv2dFloat : public testing::TestWithParam<std::vector<TestCase>>
{
};

void Run2dDriver(Precision prec)
void Run2dDriver(miopenDataType_t prec)
{

std::vector<TestCase> params;
switch(prec)
{
case Precision::Float: params = Conv2dFloat::GetParam(); break;
case Precision::Half: params = Conv2dHalf::GetParam(); break;
case Precision::Int8: params = Conv2dInt8::GetParam(); break;
case Precision::BFloat16: params = Conv2dBFloat16::GetParam(); break;
case miopenFloat: params = Conv2dFloat::GetParam(); break;
case miopenHalf: params = Conv2dHalf::GetParam(); break;
case miopenInt8: params = Conv2dInt8::GetParam(); break;
case miopenBFloat16: params = Conv2dBFloat16::GetParam(); break;
case miopenInt8x4:
case miopenInt32:
case miopenDouble:
MIOPEN_THROW(miopenStatusBadParm,
"miopenInt8x4, miopenInt32, miopenDouble data type not supported by "
"conv_embed_db test");

default: params = Conv2dFloat::GetParam();
}

Expand All @@ -88,9 +75,8 @@ void Run2dDriver(Precision prec)
GetArgs(test_value, tokens);
std::vector<const char*> ptrs;

std::transform(tokens.begin(), tokens.end(), std::back_inserter(ptrs), [](const auto& str) {
return str.data();
});
for(std::string const& str : tokens)
ptrs.push_back(str.data());

testing::internal::CaptureStderr();
test_drive<conv2d_driver>(ptrs.size(), ptrs.data());
Expand All @@ -104,13 +90,14 @@ TEST_P(Conv2dFloat, FloatTest)
#if MIOPEN_EMBED_DB

const auto& handle = get_handle();
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--float")
if(miopen::StartsWith(handle.GetDeviceName(), "gfx908") ||
miopen::StartsWith(handle.GetDeviceName(), "gfx90a") || GetFloatArg() != "--float")
{
GTEST_SKIP();
}
else
{
Run2dDriver(Precision::Float);
Run2dDriver(miopenFloat);
}

#else
Expand All @@ -123,13 +110,14 @@ TEST_P(Conv2dHalf, HalfTest)
#if MIOPEN_EMBED_DB

const auto& handle = get_handle();
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--half")
if(miopen::StartsWith(handle.GetDeviceName(), "gfx908") ||
miopen::StartsWith(handle.GetDeviceName(), "gfx90a") || GetFloatArg() != "--half")
{
GTEST_SKIP();
}
else
{
Run2dDriver(Precision::Half);
Run2dDriver(miopenHalf);
}

#else
Expand All @@ -142,13 +130,14 @@ TEST_P(Conv2dInt8, Int8Test)
#if MIOPEN_EMBED_DB

const auto& handle = get_handle();
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--int8")
if(miopen::StartsWith(handle.GetDeviceName(), "gfx908") ||
miopen::StartsWith(handle.GetDeviceName(), "gfx90a") || GetFloatArg() != "--int8")
{
GTEST_SKIP();
}
else
{
Run2dDriver(Precision::Int8);
Run2dDriver(miopenInt8);
}

#else
Expand All @@ -161,13 +150,14 @@ TEST_P(Conv2dBFloat16, BFloat16Test)
#if MIOPEN_EMBED_DB

const auto& handle = get_handle();
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--bfloat16")
if(miopen::StartsWith(handle.GetDeviceName(), "gfx908") ||
miopen::StartsWith(handle.GetDeviceName(), "gfx90a") || GetFloatArg() != "--bfloat16")
{
GTEST_SKIP();
}
else
{
Run2dDriver(Precision::BFloat16);
Run2dDriver(miopenBFloat16);
}

#else
Expand All @@ -178,59 +168,38 @@ TEST_P(Conv2dBFloat16, BFloat16Test)
std::vector<TestCase> GetTestCases(const std::string& precision)
{

std::vector<std::string> winograd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2"};
std::vector<std::string> igemm_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1"};
std::vector<std::string> igemm_fwd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1"};
std::vector<std::string> igemm_fwd_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1"};
std::vector<std::string> winograd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2=0"};
std::vector<std::string> igemm_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2=0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1=0"};
std::vector<std::string> igemm_fwd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2=0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1=0"};
std::vector<std::string> igemm_fwd_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2=0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1=0",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1=0"};

const std::vector<TestCase> test_cases = {
// clang-format off
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 1024 14 14 --weights 2048 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd), precision +
" --disable-validation --verbose --input 128 1024 14 14 --weights 256 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(winograd), precision +
" --disable-validation --verbose --input 128 128 28 28 --weights 128 128 3 3 --pads_strides_dilations 1 1 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 128 28 28 --weights 512 128 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 2048 7 7 --weights 512 2048 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd), precision +
" --disable-validation --verbose --input 128 256 14 14 --weights 1024 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 256 14 14 --weights 256 256 3 3 --pads_strides_dilations 1 1 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 256 56 56 --weights 128 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(winograd), precision +
" --disable-validation --verbose --input 128 256 56 56 --weights 512 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 256 56 56 --weights 64 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(winograd), precision +
" --disable-validation --verbose --input 128 3 230 230 --weights 64 3 7 7 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 512 28 28 --weights 1024 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 512 28 28 --weights 128 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd), precision +
" --disable-validation --verbose --input 128 512 28 28 --weights 256 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 512 7 7 --weights 2048 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 512 7 7 --weights 512 512 3 3 --pads_strides_dilations 1 1 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 64 56 56 --weights 256 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 64 56 56 --weights 64 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 64 56 56 --weights 64 64 3 3 --pads_strides_dilations 1 1 1 1 1 1")
MAKE_TEST_CASE(std::move(winograd), precision + " --disable-validation --verbose --input 128 128 28 28 --weights 128 128 3 3 --pads_strides_dilations 1 1 1 1 1 1"),
MAKE_TEST_CASE(std::move(winograd), precision + " --disable-validation --verbose --input 128 256 56 56 --weights 512 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
MAKE_TEST_CASE(std::move(winograd), precision + " --disable-validation --verbose --input 128 3 230 230 --weights 64 3 7 7 --pads_strides_dilations 0 0 2 2 1 1"),
MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 64 56 56 --weights 64 64 3 3 --pads_strides_dilations 1 1 1 1 1 1"),
MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 256 14 14 --weights 256 256 3 3 --pads_strides_dilations 1 1 1 1 1 1"),
MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 512 7 7 --weights 512 512 3 3 --pads_strides_dilations 1 1 1 1 1 1"),
MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"),

This comment has been minimized.

Copy link
@xinlipn

xinlipn Jun 2, 2023

Author Contributor

@atamazov , this test is the same as below, it also shows up twice in CTest. Please confirm two instances of the same test are needed

MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"),

This comment has been minimized.

Copy link
@xinlipn

xinlipn Jun 2, 2023

Author Contributor

@atamazov , Please see comments above. I keep both for now

MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 1024 14 14 --weights 2048 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd), precision + " --disable-validation --verbose --input 128 256 14 14 --weights 1024 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd), precision + " --disable-validation --verbose --input 128 512 28 28 --weights 256 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd), precision + " --disable-validation --verbose --input 128 1024 14 14 --weights 256 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 64 56 56 --weights 256 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 64 56 56 --weights 64 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 128 28 28 --weights 512 128 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 256 56 56 --weights 128 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 256 56 56 --weights 64 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 512 28 28 --weights 1024 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 512 28 28 --weights 128 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 512 7 7 --weights 2048 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 2048 7 7 --weights 512 2048 1 1 --pads_strides_dilations 0 0 1 1 1 1")
// clang-format on
};

Expand Down

0 comments on commit beb7fcb

Please sign in to comment.