-
Notifications
You must be signed in to change notification settings - Fork 224
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
test_conv_embed_db (ctest -> gtest) #2168
Merged
Merged
Changes from all commits
Commits
Show all changes
36 commits
Select commit
Hold shift + click to select a range
89c8fe2
conv_2d_wrapper gtest
alexandraBara dc3b8b9
pulling in env vars
alexandraBara 6f1d5b0
sending tuple values
alexandraBara ba9cd2a
updates to env vars
alexandraBara 9f3fe10
stdout test capture to fail based in stderr stream
alexandraBara 251bb89
clang tidy
alexandraBara 61e49ec
MIOPEN_FLOAT_TEST_ARG assigments to all gtests
alexandraBara 17aedb8
added all precision based tests
alexandraBara 1b5e8c1
some rework
alexandraBara f5d3cf1
renamed file
alexandraBara c5cbb1a
code cleanup
alexandraBara 8a8176d
Merge branch 'develop' into alex_gtest
alexandraBara ae29c5d
addressed reviews
alexandraBara beb7fcb
Cleaned up test case creation, skip tests for gfx908 and gfx90a, addr…
xinlipn d86eb8d
Bring back GetEnvVars(), removed duplicated wrw test
xinlipn 4760a5a
Replace trivial for loop with std::transform to fix hip tidy warning
xinlipn 29afc6d
Cleaned up code
xinlipn 2b656e5
Remove WA per PR2179, preserved test name
xinlipn 77dc83d
Refactor duplicated code
xinlipn 70d1d9c
Remove conditional compilation
xinlipn d087a00
Revert and refactor conv2d_driver changes
xinlipn 6abced7
Refactor code, specify relative path to make it pass static check
xinlipn da3594e
Resolve test CMakeLists.txt conflicts
xinlipn 79eb207
Merge branch 'develop' into alex_gtest
xinlipn 20ad346
Merge branch 'develop' into alex_gtest
xinlipn bae1221
Merge branch 'alex_gtest' of https://github.com/ROCmSoftwarePlatform/…
xinlipn 0979c73
Refactored code eg reusing existing API
xinlipn c987d64
Replace GetEnv() witn GetStringEnv(), MIOPEN_THROW() with FAIL()
xinlipn 20532ed
Refactor code by creating IsTestRunWith()
xinlipn 45e89ca
Moved and renamed conv_2d.hpp in gtest folder one level up
xinlipn a6e6f1a
Relocate conv2d.hpp
xinlipn 444ee2b
Fix static check error
xinlipn ee3adb8
Fix conflicts and merge branch 'develop' into alex_gtest
xinlipn 878aa00
Fix conflicts and merge branch 'develop' into alex_gtest
xinlipn 35f72a8
Fix Read the Docs build failed error in CI
xinlipn e4b5296
Fix conflic and merge branch 'develop' into alex_gtest
xinlipn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
/******************************************************************************* | ||
* | ||
* MIT License | ||
* | ||
* Copyright (c) 2023 Advanced Micro Devices, Inc. | ||
* | ||
* Permission is hereby granted, free of charge, to any person obtaining a copy | ||
* of this software and associated documentation files (the "Software"), to deal | ||
* in the Software without restriction, including without limitation the rights | ||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
* copies of the Software, and to permit persons to whom the Software is | ||
* furnished to do so, subject to the following conditions: | ||
* | ||
* The above copyright notice and this permission notice shall be included in all | ||
* copies or substantial portions of the Software. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
* SOFTWARE. | ||
* | ||
*******************************************************************************/ | ||
#pragma once | ||
|
||
#include "conv_common.hpp" | ||
|
||
template <class T> | ||
struct conv2d_driver : conv_driver<T> | ||
{ | ||
conv2d_driver() : conv_driver<T>() | ||
{ | ||
this->add(this->input_dims, "input"); | ||
this->add(this->weight_tensor_dims, "weights"); | ||
this->add(this->batch_size, | ||
"batch_size", | ||
this->generate_data_limited(this->get_batch_sizes(), 1)); | ||
this->add(this->input_channels, | ||
"input_channels", | ||
this->generate_data_limited(this->get_input_channels(), 1, {32})); | ||
this->add(this->output_channels, | ||
"output_channels", | ||
this->generate_data_limited(this->get_output_channels(), 1, {64})); | ||
this->add(this->spatial_dim_elements, | ||
"spatial_dim_elements", | ||
this->generate_data_limited(this->get_2d_spatial_dims(), 1, {28, 28})); | ||
this->add(this->filter_dims, | ||
"filter_dims", | ||
this->generate_data_limited(this->get_2d_filter_dims(), 2, {3, 3})); | ||
this->add(this->pads_strides_dilations, | ||
"pads_strides_dilations", | ||
this->generate_data_limited(this->get_2d_pads_strides_dilations(), 2)); | ||
this->add(this->trans_output_pads, | ||
"trans_output_pads", | ||
this->generate_data(this->get_2d_trans_output_pads())); | ||
this->add(this->in_layout, "in_layout", this->generate_data({"NCHW"})); | ||
this->add(this->fil_layout, "fil_layout", this->generate_data({"NCHW"})); | ||
this->add(this->out_layout, "out_layout", this->generate_data({"NCHW"})); | ||
this->add(this->deterministic, "deterministic", this->generate_data({false})); | ||
this->add(this->tensor_vect, "tensor_vect", this->generate_data({0})); | ||
this->add(this->vector_length, "vector_length", this->generate_data({1})); | ||
// Only valid for int8 input and weights | ||
this->add(this->output_type, "output_type", this->generate_data({"int32"})); | ||
this->add(this->int8_vectorize, "int8_vectorize", this->generate_data({false})); | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,229 @@ | ||
/******************************************************************************* | ||
* | ||
* MIT License | ||
* | ||
* Copyright (c) 2023 Advanced Micro Devices, Inc. | ||
* | ||
* Permission is hereby granted, free of charge, to any person obtaining a copy | ||
* of this software and associated documentation files (the "Software"), to deal | ||
* in the Software without restriction, including without limitation the rights | ||
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
* copies of the Software, and to permit persons to whom the Software is | ||
* furnished to do so, subject to the following conditions: | ||
* | ||
* The above copyright notice and this permission notice shall be included in all | ||
* copies or substantial portions of the Software. | ||
* | ||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
* SOFTWARE. | ||
* | ||
*******************************************************************************/ | ||
#include <tuple> | ||
JehandadKhan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
#include <miopen/miopen.h> | ||
#include <gtest/gtest.h> | ||
#include <miopen/miopen.h> | ||
#include <miopen/env.hpp> | ||
#include "../conv2d.hpp" | ||
#include "get_handle.hpp" | ||
|
||
MIOPEN_DECLARE_ENV_VAR(MIOPEN_TEST_FLOAT_ARG) | ||
|
||
static bool IsTestRunWith(const char* float_arg) | ||
{ | ||
assert(float_arg != nullptr); | ||
const char* const p_envVar = miopen::GetStringEnv(MIOPEN_TEST_FLOAT_ARG{}); | ||
return (p_envVar != nullptr && std::strcmp(p_envVar, float_arg) == 0); | ||
} | ||
|
||
void GetArgs(const std::string& param, std::vector<std::string>& tokens) | ||
{ | ||
std::stringstream ss(param); | ||
std::istream_iterator<std::string> begin(ss); | ||
std::istream_iterator<std::string> end; | ||
while(begin != end) | ||
tokens.push_back(*begin++); | ||
} | ||
|
||
class ConfigWithHalf : public testing::TestWithParam<std::vector<std::string>> | ||
{ | ||
}; | ||
class ConfigWithInt8 : public testing::TestWithParam<std::vector<std::string>> | ||
{ | ||
}; | ||
class ConfigWithBFloat16 : public testing::TestWithParam<std::vector<std::string>> | ||
{ | ||
}; | ||
class ConfigWithFloat : public testing::TestWithParam<std::vector<std::string>> | ||
{ | ||
}; | ||
JehandadKhan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
void Run2dDriver(miopenDataType_t prec) | ||
{ | ||
|
||
std::vector<std::string> params; | ||
switch(prec) | ||
{ | ||
case miopenFloat: params = ConfigWithFloat::GetParam(); break; | ||
case miopenHalf: params = ConfigWithHalf::GetParam(); break; | ||
case miopenInt8: params = ConfigWithInt8::GetParam(); break; | ||
case miopenBFloat16: params = ConfigWithBFloat16::GetParam(); break; | ||
case miopenInt8x4: | ||
case miopenInt32: | ||
case miopenDouble: | ||
FAIL() << "miopenInt8x4, miopenInt32, miopenDouble data type not supported by " | ||
"conv_embed_db test"; | ||
|
||
default: params = ConfigWithFloat::GetParam(); | ||
} | ||
|
||
for(const auto& test_value : params) | ||
{ | ||
std::vector<std::string> tokens; | ||
GetArgs(test_value, tokens); | ||
std::vector<const char*> ptrs; | ||
|
||
std::transform(tokens.begin(), tokens.end(), std::back_inserter(ptrs), [](const auto& str) { | ||
return str.data(); | ||
}); | ||
|
||
testing::internal::CaptureStderr(); | ||
test_drive<conv2d_driver>(ptrs.size(), ptrs.data()); | ||
auto capture = testing::internal::GetCapturedStderr(); | ||
EXPECT_FALSE(capture.find("Perf Db: record not found") != std::string::npos); | ||
} | ||
}; | ||
|
||
bool IsTestSupportedForDevice(const miopen::Handle& handle) | ||
{ | ||
std::string devName = handle.GetDeviceName(); | ||
if(devName == "gfx900" || devName == "gfx906") | ||
return true; | ||
else | ||
return false; | ||
} | ||
|
||
TEST_P(ConfigWithFloat, FloatTest) | ||
{ | ||
#if MIOPEN_EMBED_DB | ||
xinlipn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
const auto& handle = get_handle(); | ||
if(IsTestSupportedForDevice(handle) && IsTestRunWith("--float")) | ||
{ | ||
Run2dDriver(miopenFloat); | ||
} | ||
else | ||
{ | ||
GTEST_SKIP(); | ||
} | ||
|
||
#else | ||
GTEST_SKIP(); | ||
#endif | ||
}; | ||
|
||
TEST_P(ConfigWithHalf, HalfTest) | ||
{ | ||
#if MIOPEN_EMBED_DB | ||
|
||
const auto& handle = get_handle(); | ||
if(IsTestSupportedForDevice(handle) && IsTestRunWith("--half")) | ||
{ | ||
Run2dDriver(miopenHalf); | ||
} | ||
else | ||
{ | ||
GTEST_SKIP(); | ||
} | ||
|
||
#else | ||
GTEST_SKIP(); | ||
#endif | ||
}; | ||
|
||
TEST_P(ConfigWithInt8, Int8Test) | ||
{ | ||
#if MIOPEN_EMBED_DB | ||
|
||
const auto& handle = get_handle(); | ||
if(IsTestSupportedForDevice(handle) && IsTestRunWith("--int8")) | ||
{ | ||
Run2dDriver(miopenInt8); | ||
} | ||
else | ||
{ | ||
GTEST_SKIP(); | ||
} | ||
|
||
#else | ||
GTEST_SKIP(); | ||
#endif | ||
}; | ||
|
||
TEST_P(ConfigWithBFloat16, BFloat16Test) | ||
{ | ||
#if MIOPEN_EMBED_DB | ||
|
||
const auto& handle = get_handle(); | ||
if(IsTestSupportedForDevice(handle) && IsTestRunWith("--bfloat16")) | ||
{ | ||
Run2dDriver(miopenBFloat16); | ||
} | ||
else | ||
{ | ||
GTEST_SKIP(); | ||
} | ||
|
||
#else | ||
GTEST_SKIP(); | ||
#endif | ||
}; | ||
|
||
std::vector<std::string> GetTestCases(const std::string& precision) | ||
{ | ||
std::string flags = " --disable-validation --verbose "; | ||
|
||
// If precision env var is not set | ||
if(!(IsTestRunWith("--float") || IsTestRunWith("--half") || IsTestRunWith("--int8") || | ||
IsTestRunWith("--bfloat16"))) | ||
flags.insert(0, precision); | ||
|
||
const std::vector<std::string> test_cases = { | ||
// clang-format off | ||
{flags + "--input 128 128 28 28 --weights 128 128 3 3 --pads_strides_dilations 1 1 1 1 1 1"}, | ||
{flags + "--input 128 256 56 56 --weights 512 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"}, | ||
{flags + "--input 128 3 230 230 --weights 64 3 7 7 --pads_strides_dilations 0 0 2 2 1 1"}, | ||
{flags + "--input 128 64 56 56 --weights 64 64 3 3 --pads_strides_dilations 1 1 1 1 1 1"}, | ||
{flags + "--input 128 256 14 14 --weights 256 256 3 3 --pads_strides_dilations 1 1 1 1 1 1"}, | ||
{flags + "--input 128 512 7 7 --weights 512 512 3 3 --pads_strides_dilations 1 1 1 1 1 1"}, | ||
{flags + "--input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"}, | ||
{flags + "--input 128 1024 14 14 --weights 2048 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"}, | ||
{flags + "--input 128 256 14 14 --weights 1024 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"}, | ||
{flags + "--input 128 512 28 28 --weights 256 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"}, | ||
{flags + "--input 128 1024 14 14 --weights 256 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1"}, | ||
{flags + "--input 128 64 56 56 --weights 256 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"}, | ||
{flags + "--input 128 64 56 56 --weights 64 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"}, | ||
{flags + "--input 128 128 28 28 --weights 512 128 1 1 --pads_strides_dilations 0 0 1 1 1 1"}, | ||
{flags + "--input 128 256 56 56 --weights 128 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"}, | ||
{flags + "--input 128 256 56 56 --weights 64 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"}, | ||
{flags + "--input 128 512 28 28 --weights 1024 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"}, | ||
{flags + "--input 128 512 28 28 --weights 128 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"}, | ||
{flags + "--input 128 512 7 7 --weights 2048 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"}, | ||
{flags + "--input 128 2048 7 7 --weights 512 2048 1 1 --pads_strides_dilations 0 0 1 1 1 1"} | ||
// clang-format on | ||
}; | ||
|
||
return test_cases; | ||
} | ||
|
||
INSTANTIATE_TEST_SUITE_P(ConvEmbedDB, ConfigWithFloat, testing::Values(GetTestCases("--float"))); | ||
INSTANTIATE_TEST_SUITE_P(ConvEmbedDB, ConfigWithHalf, testing::Values(GetTestCases("--half"))); | ||
INSTANTIATE_TEST_SUITE_P(ConvEmbedDB, ConfigWithInt8, testing::Values(GetTestCases("--int8"))); | ||
INSTANTIATE_TEST_SUITE_P(ConvEmbedDB, | ||
ConfigWithBFloat16, | ||
testing::Values(GetTestCases("--bfloat16"))); |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@JehandadKhan [Notice] just in case - it seems that this code is working, but actually it should be
if(NOT MIOPEN_EMBED_DB STREQUAL "")
here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@atamazov , I have been thinking if this line would be necessary in this case
https://github.com/ROCmSoftwarePlatform/MIOpen/blob/29afc6d0d4d98a2e821a420ac3613d74d51b9b9b/test/gtest/conv_embed_db.cpp#L105