-
Notifications
You must be signed in to change notification settings - Fork 224
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
test_conv_embed_db (ctest -> gtest) #2168
Changes from 13 commits
89c8fe2
dc3b8b9
6f1d5b0
ba9cd2a
9f3fe10
251bb89
61e49ec
17aedb8
1b5e8c1
f5d3cf1
c5cbb1a
8a8176d
ae29c5d
beb7fcb
d86eb8d
4760a5a
29afc6d
2b656e5
77dc83d
70d1d9c
d087a00
6abced7
da3594e
79eb207
20ad346
bae1221
0979c73
c987d64
20532ed
45e89ca
a6e6f1a
444ee2b
ee3adb8
878aa00
35f72a8
e4b5296
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/******************************************************************************* | ||
* | ||
* MIT License | ||
* | ||
* Copyright (c) 2023 Advanced Micro Devices, Inc. | ||
* | ||
* Permission is hereby granted, free of charge, to any person obtaining a copy | ||
* of this software and associated documentation files (the "Software"), to deal | ||
* in the Software without restriction, including without limitation the rights | ||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
* copies of the Software, and to permit persons to whom the Software is | ||
* furnished to do so, subject to the following conditions: | ||
* | ||
* The above copyright notice and this permission notice shall be included in all | ||
* copies or substantial portions of the Software. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
* SOFTWARE. | ||
* | ||
*******************************************************************************/ | ||
#include "conv_common.hpp" | ||
|
||
template <class T> | ||
struct conv2d_driver : conv_driver<T> | ||
{ | ||
conv2d_driver() : conv_driver<T>() | ||
{ | ||
this->add(this->input_dims, "input"); | ||
this->add(this->weight_tensor_dims, "weights"); | ||
this->add(this->batch_size, | ||
"batch_size", | ||
this->generate_data_limited(this->get_batch_sizes(), 1)); | ||
this->add(this->input_channels, | ||
"input_channels", | ||
this->generate_data_limited(this->get_input_channels(), 1, {32})); | ||
this->add(this->output_channels, | ||
"output_channels", | ||
this->generate_data_limited(this->get_output_channels(), 1, {64})); | ||
this->add(this->spatial_dim_elements, | ||
"spatial_dim_elements", | ||
this->generate_data_limited(this->get_2d_spatial_dims(), 1, {28, 28})); | ||
this->add(this->filter_dims, | ||
"filter_dims", | ||
this->generate_data_limited(this->get_2d_filter_dims(), 2, {3, 3})); | ||
this->add(this->pads_strides_dilations, | ||
"pads_strides_dilations", | ||
this->generate_data_limited(this->get_2d_pads_strides_dilations(), 2)); | ||
this->add(this->trans_output_pads, | ||
"trans_output_pads", | ||
this->generate_data(this->get_2d_trans_output_pads())); | ||
this->add(this->in_layout, "in_layout", this->generate_data({"NCHW"})); | ||
this->add(this->fil_layout, "fil_layout", this->generate_data({"NCHW"})); | ||
this->add(this->out_layout, "out_layout", this->generate_data({"NCHW"})); | ||
this->add(this->deterministic, "deterministic", this->generate_data({false})); | ||
this->add(this->tensor_vect, "tensor_vect", this->generate_data({0})); | ||
this->add(this->vector_length, "vector_length", this->generate_data({1})); | ||
// Only valid for int8 input and weights | ||
this->add(this->output_type, "output_type", this->generate_data({"int32"})); | ||
this->add(this->int8_vectorize, "int8_vectorize", this->generate_data({false})); | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is partial copy of test/conv2d.cpp. Please do not make copies of existing code, refactor instead. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [notice] Copying the code is the worst thing that we can do from the maintenance point of view. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Refactored by moving common code to conv_common.hpp There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is 2D-specific and should not be moved there. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Reverted and re-refactored conv2d_driver related changes. This may look out of the place since the header is in test/gtest and this file is in test/. All previous conv2d CTest cases may be refactored by using this header as we continue moving toward gTest, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @xinlipn Yeah, Please rename this header to test/conv2d.hpp.
Moving code from test/ to test/gtest/ is a separate task. Let's do that later, separately. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Notice] Not resolved yet. Let's either fix or discuss. |
||
}; |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,243 @@ | ||
#include <tuple> | ||
JehandadKhan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
#include <miopen/miopen.h> | ||
#include <gtest/gtest.h> | ||
#include "conv_2d.hpp" | ||
#include "get_handle.hpp" | ||
|
||
using TestCase = std::tuple<std::vector<std::string>, std::string>; | ||
|
||
enum class Precision | ||
xinlipn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
Float, | ||
Half, | ||
Int8, | ||
BFloat16 | ||
}; | ||
|
||
std::string GetFloatArg() | ||
{ | ||
static const auto tmp = std::getenv("MIOPEN_TEST_FLOAT_ARG"); | ||
xinlipn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if(tmp == nullptr) | ||
{ | ||
return ""; | ||
} | ||
return tmp; | ||
}; | ||
xinlipn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
std::vector<std::string> GetEnvVars(const std::vector<std::string>& check_vars) | ||
{ | ||
std::vector<std::string> vars = {}; | ||
for(const auto& cvar : check_vars) | ||
{ | ||
static const auto tmp = std::getenv(cvar.c_str()); | ||
if(tmp != nullptr) | ||
{ | ||
vars.push_back(cvar + "=0"); | ||
} | ||
} | ||
return vars; | ||
}; | ||
|
||
void GetArgs(const TestCase& param, std::vector<std::string>& tokens) | ||
{ | ||
auto env_vars = std::get<0>(param); | ||
for(auto& elem : env_vars) | ||
{ | ||
putenv(elem.data()); | ||
} | ||
|
||
auto cmd = std::get<1>(param); | ||
|
||
std::stringstream ss(cmd); | ||
std::istream_iterator<std::string> begin(ss); | ||
std::istream_iterator<std::string> end; | ||
while(begin != end) | ||
tokens.push_back(*begin++); | ||
} | ||
|
||
class Conv2dHalf : public testing::TestWithParam<std::vector<TestCase>> | ||
{ | ||
}; | ||
class Conv2dInt8 : public testing::TestWithParam<std::vector<TestCase>> | ||
{ | ||
}; | ||
class Conv2dBFloat16 : public testing::TestWithParam<std::vector<TestCase>> | ||
{ | ||
}; | ||
class Conv2dFloat : public testing::TestWithParam<std::vector<TestCase>> | ||
{ | ||
}; | ||
|
||
void Run2dDriver(Precision prec) | ||
{ | ||
|
||
std::vector<TestCase> params; | ||
switch(prec) | ||
{ | ||
case Precision::Float: params = Conv2dFloat::GetParam(); break; | ||
case Precision::Half: params = Conv2dHalf::GetParam(); break; | ||
case Precision::Int8: params = Conv2dInt8::GetParam(); break; | ||
case Precision::BFloat16: params = Conv2dBFloat16::GetParam(); break; | ||
default: params = Conv2dFloat::GetParam(); | ||
} | ||
|
||
for(const auto& test_value : params) | ||
{ | ||
std::vector<std::string> tokens; | ||
GetArgs(test_value, tokens); | ||
std::vector<const char*> ptrs; | ||
|
||
std::transform(tokens.begin(), tokens.end(), std::back_inserter(ptrs), [](const auto& str) { | ||
return str.data(); | ||
}); | ||
|
||
testing::internal::CaptureStderr(); | ||
test_drive<conv2d_driver>(ptrs.size(), ptrs.data()); | ||
auto capture = testing::internal::GetCapturedStderr(); | ||
EXPECT_FALSE(capture.find("Perf Db: record not found") != std::string::npos); | ||
} | ||
}; | ||
|
||
TEST_P(Conv2dFloat, FloatTest) | ||
{ | ||
#if MIOPEN_EMBED_DB | ||
xinlipn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
const auto& handle = get_handle(); | ||
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--float") | ||
xinlipn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
{ | ||
GTEST_SKIP(); | ||
} | ||
else | ||
{ | ||
Run2dDriver(Precision::Float); | ||
} | ||
|
||
#else | ||
GTEST_SKIP(); | ||
#endif | ||
}; | ||
|
||
TEST_P(Conv2dHalf, HalfTest) | ||
{ | ||
#if MIOPEN_EMBED_DB | ||
|
||
const auto& handle = get_handle(); | ||
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--half") | ||
{ | ||
GTEST_SKIP(); | ||
} | ||
else | ||
{ | ||
Run2dDriver(Precision::Half); | ||
} | ||
|
||
#else | ||
GTEST_SKIP(); | ||
#endif | ||
}; | ||
|
||
TEST_P(Conv2dInt8, Int8Test) | ||
{ | ||
#if MIOPEN_EMBED_DB | ||
|
||
const auto& handle = get_handle(); | ||
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--int8") | ||
{ | ||
GTEST_SKIP(); | ||
} | ||
else | ||
{ | ||
Run2dDriver(Precision::Int8); | ||
} | ||
|
||
#else | ||
GTEST_SKIP(); | ||
#endif | ||
}; | ||
|
||
TEST_P(Conv2dBFloat16, BFloat16Test) | ||
{ | ||
#if MIOPEN_EMBED_DB | ||
|
||
const auto& handle = get_handle(); | ||
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--bfloat16") | ||
{ | ||
GTEST_SKIP(); | ||
} | ||
else | ||
{ | ||
Run2dDriver(Precision::BFloat16); | ||
} | ||
|
||
#else | ||
GTEST_SKIP(); | ||
#endif | ||
}; | ||
|
||
std::vector<TestCase> GetTestCases(const std::string& precision) | ||
{ | ||
|
||
std::vector<std::string> winograd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2"}; | ||
std::vector<std::string> igemm_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2", | ||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1"}; | ||
std::vector<std::string> igemm_fwd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2", | ||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1"}; | ||
std::vector<std::string> igemm_fwd_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2", | ||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1", | ||
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1"}; | ||
xinlipn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
const std::vector<TestCase> test_cases = { | ||
// clang-format off | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision + | ||
" --disable-validation --verbose --input 128 1024 14 14 --weights 2048 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd), precision + | ||
" --disable-validation --verbose --input 128 1024 14 14 --weights 256 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision + | ||
" --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(winograd), precision + | ||
" --disable-validation --verbose --input 128 128 28 28 --weights 128 128 3 3 --pads_strides_dilations 1 1 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision + | ||
" --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision + | ||
" --disable-validation --verbose --input 128 128 28 28 --weights 512 128 1 1 --pads_strides_dilations 0 0 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision + | ||
" --disable-validation --verbose --input 128 2048 7 7 --weights 512 2048 1 1 --pads_strides_dilations 0 0 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd), precision + | ||
" --disable-validation --verbose --input 128 256 14 14 --weights 1024 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision + | ||
" --disable-validation --verbose --input 128 256 14 14 --weights 256 256 3 3 --pads_strides_dilations 1 1 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision + | ||
" --disable-validation --verbose --input 128 256 56 56 --weights 128 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(winograd), precision + | ||
" --disable-validation --verbose --input 128 256 56 56 --weights 512 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision + | ||
" --disable-validation --verbose --input 128 256 56 56 --weights 64 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(winograd), precision + | ||
" --disable-validation --verbose --input 128 3 230 230 --weights 64 3 7 7 --pads_strides_dilations 0 0 2 2 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision + | ||
" --disable-validation --verbose --input 128 512 28 28 --weights 1024 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision + | ||
" --disable-validation --verbose --input 128 512 28 28 --weights 128 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd), precision + | ||
" --disable-validation --verbose --input 128 512 28 28 --weights 256 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision + | ||
" --disable-validation --verbose --input 128 512 7 7 --weights 2048 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision + | ||
" --disable-validation --verbose --input 128 512 7 7 --weights 512 512 3 3 --pads_strides_dilations 1 1 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision + | ||
" --disable-validation --verbose --input 128 64 56 56 --weights 256 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision + | ||
" --disable-validation --verbose --input 128 64 56 56 --weights 64 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"), | ||
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision + | ||
" --disable-validation --verbose --input 128 64 56 56 --weights 64 64 3 3 --pads_strides_dilations 1 1 1 1 1 1") | ||
// clang-format on | ||
}; | ||
|
||
return test_cases; | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P(Conv2dGroup, Conv2dFloat, testing::Values(GetTestCases("--float"))); | ||
INSTANTIATE_TEST_SUITE_P(Conv2dGroup, Conv2dHalf, testing::Values(GetTestCases("--half"))); | ||
INSTANTIATE_TEST_SUITE_P(Conv2dGroup, Conv2dInt8, testing::Values(GetTestCases("--int8"))); | ||
INSTANTIATE_TEST_SUITE_P(Conv2dGroup, Conv2dBFloat16, testing::Values(GetTestCases("--bfloat16"))); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test names should reflect that the test is intended for a library configuration with embedded databases. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@atamazov Test names have been updated. Could you resolve conversation if this has been resolved? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. [Resolved] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JehandadKhan [Notice] just in case - it seems that this code is working, but actually it should be
if(NOT MIOPEN_EMBED_DB STREQUAL "")
here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@atamazov , I have been thinking if this line would be necessary in this case
https://github.com/ROCmSoftwarePlatform/MIOpen/blob/29afc6d0d4d98a2e821a420ac3613d74d51b9b9b/test/gtest/conv_embed_db.cpp#L105