diff --git a/test/gtest/conv_embed_db.cpp b/test/gtest/conv_embed_db.cpp index df2243c704..e294baa7a6 100644 --- a/test/gtest/conv_embed_db.cpp +++ b/test/gtest/conv_embed_db.cpp @@ -2,18 +2,12 @@ #include #include +#include #include "conv_2d.hpp" #include "get_handle.hpp" using TestCase = std::tuple, std::string>; - -enum class Precision -{ - Float, - Half, - Int8, - BFloat16 -}; +#define MAKE_TEST_CASE std::make_tuple, std::string> std::string GetFloatArg() { @@ -25,20 +19,6 @@ std::string GetFloatArg() return tmp; }; -std::vector GetEnvVars(const std::vector& check_vars) -{ - std::vector vars = {}; - for(const auto& cvar : check_vars) - { - static const auto tmp = std::getenv(cvar.c_str()); - if(tmp != nullptr) - { - vars.push_back(cvar + "=0"); - } - } - return vars; -}; - void GetArgs(const TestCase& param, std::vector& tokens) { auto env_vars = std::get<0>(param); @@ -69,16 +49,23 @@ class Conv2dFloat : public testing::TestWithParam> { }; -void Run2dDriver(Precision prec) +void Run2dDriver(miopenDataType_t prec) { std::vector params; switch(prec) { - case Precision::Float: params = Conv2dFloat::GetParam(); break; - case Precision::Half: params = Conv2dHalf::GetParam(); break; - case Precision::Int8: params = Conv2dInt8::GetParam(); break; - case Precision::BFloat16: params = Conv2dBFloat16::GetParam(); break; + case miopenFloat: params = Conv2dFloat::GetParam(); break; + case miopenHalf: params = Conv2dHalf::GetParam(); break; + case miopenInt8: params = Conv2dInt8::GetParam(); break; + case miopenBFloat16: params = Conv2dBFloat16::GetParam(); break; + case miopenInt8x4: + case miopenInt32: + case miopenDouble: + MIOPEN_THROW(miopenStatusBadParm, + "miopenInt8x4, miopenInt32, miopenDouble data type not supported by " + "conv_embed_db test"); + default: params = Conv2dFloat::GetParam(); } @@ -88,9 +75,8 @@ void Run2dDriver(Precision prec) GetArgs(test_value, tokens); std::vector ptrs; - std::transform(tokens.begin(), tokens.end(), std::back_inserter(ptrs), [](const auto& str) { - return str.data(); - }); + for(std::string const& str : tokens) + ptrs.push_back(str.data()); testing::internal::CaptureStderr(); test_drive(ptrs.size(), ptrs.data()); @@ -104,13 +90,14 @@ TEST_P(Conv2dFloat, FloatTest) #if MIOPEN_EMBED_DB const auto& handle = get_handle(); - if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--float") + if(miopen::StartsWith(handle.GetDeviceName(), "gfx908") || + miopen::StartsWith(handle.GetDeviceName(), "gfx90a") || GetFloatArg() != "--float") { GTEST_SKIP(); } else { - Run2dDriver(Precision::Float); + Run2dDriver(miopenFloat); } #else @@ -123,13 +110,14 @@ TEST_P(Conv2dHalf, HalfTest) #if MIOPEN_EMBED_DB const auto& handle = get_handle(); - if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--half") + if(miopen::StartsWith(handle.GetDeviceName(), "gfx908") || + miopen::StartsWith(handle.GetDeviceName(), "gfx90a") || GetFloatArg() != "--half") { GTEST_SKIP(); } else { - Run2dDriver(Precision::Half); + Run2dDriver(miopenHalf); } #else @@ -142,13 +130,14 @@ TEST_P(Conv2dInt8, Int8Test) #if MIOPEN_EMBED_DB const auto& handle = get_handle(); - if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--int8") + if(miopen::StartsWith(handle.GetDeviceName(), "gfx908") || + miopen::StartsWith(handle.GetDeviceName(), "gfx90a") || GetFloatArg() != "--int8") { GTEST_SKIP(); } else { - Run2dDriver(Precision::Int8); + Run2dDriver(miopenInt8); } #else @@ -161,13 +150,14 @@ TEST_P(Conv2dBFloat16, BFloat16Test) #if MIOPEN_EMBED_DB const auto& handle = get_handle(); - if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--bfloat16") + if(miopen::StartsWith(handle.GetDeviceName(), "gfx908") || + miopen::StartsWith(handle.GetDeviceName(), "gfx90a") || GetFloatArg() != "--bfloat16") { GTEST_SKIP(); } else { - Run2dDriver(Precision::BFloat16); + Run2dDriver(miopenBFloat16); } #else @@ -178,59 +168,38 @@ TEST_P(Conv2dBFloat16, BFloat16Test) std::vector GetTestCases(const std::string& precision) { - std::vector winograd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2"}; - std::vector igemm_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2", - "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1"}; - std::vector igemm_fwd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2", - "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1"}; - std::vector igemm_fwd_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2", - "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1", - "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1"}; + std::vector winograd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2=0"}; + std::vector igemm_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2=0", + "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1=0"}; + std::vector igemm_fwd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2=0", + "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1=0"}; + std::vector igemm_fwd_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2=0", + "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1=0", + "MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1=0"}; const std::vector test_cases = { // clang-format off - std::make_tuple, std::string>(GetEnvVars(igemm_wrw), precision + - " --disable-validation --verbose --input 128 1024 14 14 --weights 2048 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd), precision + - " --disable-validation --verbose --input 128 1024 14 14 --weights 256 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_wrw), precision + - " --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"), - std::make_tuple, std::string>(GetEnvVars(winograd), precision + - " --disable-validation --verbose --input 128 128 28 28 --weights 128 128 3 3 --pads_strides_dilations 1 1 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_wrw), precision + - " --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd_wrw), precision + - " --disable-validation --verbose --input 128 128 28 28 --weights 512 128 1 1 --pads_strides_dilations 0 0 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd_wrw), precision + - " --disable-validation --verbose --input 128 2048 7 7 --weights 512 2048 1 1 --pads_strides_dilations 0 0 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd), precision + - " --disable-validation --verbose --input 128 256 14 14 --weights 1024 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_wrw), precision + - " --disable-validation --verbose --input 128 256 14 14 --weights 256 256 3 3 --pads_strides_dilations 1 1 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd_wrw), precision + - " --disable-validation --verbose --input 128 256 56 56 --weights 128 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"), - std::make_tuple, std::string>(GetEnvVars(winograd), precision + - " --disable-validation --verbose --input 128 256 56 56 --weights 512 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd_wrw), precision + - " --disable-validation --verbose --input 128 256 56 56 --weights 64 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(winograd), precision + - " --disable-validation --verbose --input 128 3 230 230 --weights 64 3 7 7 --pads_strides_dilations 0 0 2 2 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd_wrw), precision + - " --disable-validation --verbose --input 128 512 28 28 --weights 1024 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd_wrw), precision + - " --disable-validation --verbose --input 128 512 28 28 --weights 128 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd), precision + - " --disable-validation --verbose --input 128 512 28 28 --weights 256 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd_wrw), precision + - " --disable-validation --verbose --input 128 512 7 7 --weights 2048 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_wrw), precision + - " --disable-validation --verbose --input 128 512 7 7 --weights 512 512 3 3 --pads_strides_dilations 1 1 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd_wrw), precision + - " --disable-validation --verbose --input 128 64 56 56 --weights 256 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_fwd_wrw), precision + - " --disable-validation --verbose --input 128 64 56 56 --weights 64 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"), - std::make_tuple, std::string>(GetEnvVars(igemm_wrw), precision + - " --disable-validation --verbose --input 128 64 56 56 --weights 64 64 3 3 --pads_strides_dilations 1 1 1 1 1 1") + MAKE_TEST_CASE(std::move(winograd), precision + " --disable-validation --verbose --input 128 128 28 28 --weights 128 128 3 3 --pads_strides_dilations 1 1 1 1 1 1"), + MAKE_TEST_CASE(std::move(winograd), precision + " --disable-validation --verbose --input 128 256 56 56 --weights 512 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"), + MAKE_TEST_CASE(std::move(winograd), precision + " --disable-validation --verbose --input 128 3 230 230 --weights 64 3 7 7 --pads_strides_dilations 0 0 2 2 1 1"), + MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 64 56 56 --weights 64 64 3 3 --pads_strides_dilations 1 1 1 1 1 1"), + MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 256 14 14 --weights 256 256 3 3 --pads_strides_dilations 1 1 1 1 1 1"), + MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 512 7 7 --weights 512 512 3 3 --pads_strides_dilations 1 1 1 1 1 1"), + MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"), + MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"), + MAKE_TEST_CASE(std::move(igemm_wrw), precision + " --disable-validation --verbose --input 128 1024 14 14 --weights 2048 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd), precision + " --disable-validation --verbose --input 128 256 14 14 --weights 1024 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd), precision + " --disable-validation --verbose --input 128 512 28 28 --weights 256 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd), precision + " --disable-validation --verbose --input 128 1024 14 14 --weights 256 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 64 56 56 --weights 256 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 64 56 56 --weights 64 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 128 28 28 --weights 512 128 1 1 --pads_strides_dilations 0 0 1 1 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 256 56 56 --weights 128 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 256 56 56 --weights 64 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 512 28 28 --weights 1024 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 512 28 28 --weights 128 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 512 7 7 --weights 2048 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"), + MAKE_TEST_CASE(std::move(igemm_fwd_wrw), precision + " --disable-validation --verbose --input 128 2048 7 7 --weights 512 2048 1 1 --pads_strides_dilations 0 0 1 1 1 1") // clang-format on };