Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test_conv_embed_db (ctest -> gtest) #2168

Merged
merged 36 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
89c8fe2
conv_2d_wrapper gtest
alexandraBara May 13, 2023
dc3b8b9
pulling in env vars
alexandraBara May 16, 2023
6f1d5b0
sending tuple values
alexandraBara May 16, 2023
ba9cd2a
updates to env vars
alexandraBara May 17, 2023
9f3fe10
stdout test capture to fail based in stderr stream
alexandraBara May 19, 2023
251bb89
clang tidy
alexandraBara May 20, 2023
61e49ec
MIOPEN_FLOAT_TEST_ARG assigments to all gtests
alexandraBara May 22, 2023
17aedb8
added all precision based tests
alexandraBara May 23, 2023
1b5e8c1
some rework
alexandraBara May 24, 2023
f5d3cf1
renamed file
alexandraBara May 25, 2023
c5cbb1a
code cleanup
alexandraBara May 25, 2023
8a8176d
Merge branch 'develop' into alex_gtest
alexandraBara May 25, 2023
ae29c5d
addressed reviews
alexandraBara May 30, 2023
beb7fcb
Cleaned up test case creation, skip tests for gfx908 and gfx90a, addr…
xinlipn Jun 2, 2023
d86eb8d
Bring back GetEnvVars(), removed duplicated wrw test
xinlipn Jun 5, 2023
4760a5a
Replace trivial for loop with std::transform to fix hip tidy warning
xinlipn Jun 5, 2023
29afc6d
Cleaned up code
xinlipn Jun 6, 2023
2b656e5
Remove WA per PR2179, preserved test name
xinlipn Jun 8, 2023
77dc83d
Refactor duplicated code
xinlipn Jun 9, 2023
70d1d9c
Remove conditional compilation
xinlipn Jun 9, 2023
d087a00
Revert and refactor conv2d_driver changes
xinlipn Jun 9, 2023
6abced7
Refactor code, specify relative path to make it pass static check
xinlipn Jun 9, 2023
da3594e
Resolve test CMakeLists.txt conflicts
xinlipn Jun 15, 2023
79eb207
Merge branch 'develop' into alex_gtest
xinlipn Jun 15, 2023
20ad346
Merge branch 'develop' into alex_gtest
xinlipn Jun 15, 2023
bae1221
Merge branch 'alex_gtest' of https://github.com/ROCmSoftwarePlatform/…
xinlipn Jun 20, 2023
0979c73
Refactored code eg reusing existing API
xinlipn Jun 23, 2023
c987d64
Replace GetEnv() witn GetStringEnv(), MIOPEN_THROW() with FAIL()
xinlipn Jun 27, 2023
20532ed
Refactor code by creating IsTestRunWith()
xinlipn Jun 28, 2023
45e89ca
Moved and renamed conv_2d.hpp in gtest folder one level up
xinlipn Jul 5, 2023
a6e6f1a
Relocate conv2d.hpp
xinlipn Jul 5, 2023
444ee2b
Fix static check error
xinlipn Jul 6, 2023
ee3adb8
Fix conflicts and merge branch 'develop' into alex_gtest
xinlipn Jul 7, 2023
878aa00
Fix conflicts and merge branch 'develop' into alex_gtest
xinlipn Jul 11, 2023
35f72a8
Fix Read the Docs build failed error in CI
xinlipn Jul 13, 2023
e4b5296
Fix conflic and merge branch 'develop' into alex_gtest
xinlipn Jul 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -760,41 +760,6 @@ if(${MIOPEN_TEST_WITH_MIOPENDRIVER})
)
endif()

# ./bin/MIOpenDriver conv -n 128 -c 1024 -H 14 -W 14 -k 2048 -y 1 -x 1 -p 0 -q 0 -u 2 -v 2 -l 1 -j 1 -m conv -g 1 -F 1 -t 1
# MIOPEN_DEBUG_CONV_IMMED_FALLBACK=0
if(MIOPEN_EMBED_DB)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JehandadKhan [Notice] just in case - it seems that this code is working, but actually it should be if(NOT MIOPEN_EMBED_DB STREQUAL "") here.

Copy link
Contributor

@xinlipn xinlipn Jun 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JehandadKhan [Notice] just in case - it seems that this code is working, but actually it should be if(NOT MIOPEN_EMBED_DB STREQUAL "") here.

@atamazov , I have been thinking if this line would be necessary in this case

https://github.com/ROCmSoftwarePlatform/MIOpen/blob/29afc6d0d4d98a2e821a420ac3613d74d51b9b9b/test/gtest/conv_embed_db.cpp#L105

set(MIOPEN_EMBED_TEST_ARG ${MIOPEN_TEST_FLOAT_ARG} --disable-validation --verbose)
# WORKAROUND for issue #874
set(MIOPEN_WA_ISSUE_874_F MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1=0)
set(MIOPEN_WA_ISSUE_874_W MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R1=0)
set(MIOPEN_WA_ISSUE_874_FW MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_V4R1=0 MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_V4R1=0)
# WORKAROUND for issue #1008
set(MIOPEN_WA_ISSUE_1008 MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2=0)
xinlipn marked this conversation as resolved.
Show resolved Hide resolved
add_custom_test(test_conv_embed_db TEST_PERF_DB_RECORD_NOT_FOUND GFX908_DISABLED GFX90A_DISABLED
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_W} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 1024 14 14 --weights 2048 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_F} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 1024 14 14 --weights 256 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_W} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 128 28 28 --weights 128 128 3 3 --pads_strides_dilations 1 1 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_W} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_FW} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 128 28 28 --weights 512 128 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_FW} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 2048 7 7 --weights 512 2048 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_F} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 256 14 14 --weights 1024 256 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_W} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 256 14 14 --weights 256 256 3 3 --pads_strides_dilations 1 1 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_FW} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 256 56 56 --weights 128 256 1 1 --pads_strides_dilations 0 0 2 2 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 256 56 56 --weights 512 256 1 1 --pads_strides_dilations 0 0 2 2 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_FW} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 256 56 56 --weights 64 256 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 3 230 230 --weights 64 3 7 7 --pads_strides_dilations 0 0 2 2 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_FW} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 512 28 28 --weights 1024 512 1 1 --pads_strides_dilations 0 0 2 2 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_FW} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 512 28 28 --weights 128 512 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_FW} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 512 28 28 --weights 256 512 1 1 --pads_strides_dilations 0 0 2 2 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_FW} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 512 7 7 --weights 2048 512 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_W} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 512 7 7 --weights 512 512 3 3 --pads_strides_dilations 1 1 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_FW} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 64 56 56 --weights 256 64 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_FW} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 64 56 56 --weights 64 64 1 1 --pads_strides_dilations 0 0 1 1 1 1
COMMAND ${MIOPEN_WA_ISSUE_1008} ${MIOPEN_WA_ISSUE_874_W} $<TARGET_FILE:test_conv2d> ${MIOPEN_EMBED_TEST_ARG} --input 128 64 56 56 --weights 64 64 3 3 --pads_strides_dilations 1 1 1 1 1 1
)
endif()

set(IMPLICITGEMM_MLIR_ENV_BASE MIOPEN_FIND_MODE=normal)
set(IMPLICITGEMM_MLIR_ENV_F ${IMPLICITGEMM_MLIR_ENV_BASE} MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvMlirIgemmFwd)
set(IMPLICITGEMM_MLIR_ENV_B ${IMPLICITGEMM_MLIR_ENV_BASE} MIOPEN_DEBUG_FIND_ONLY_SOLVER=ConvMlirIgemmBwd)
Expand Down
2 changes: 1 addition & 1 deletion test/gtest/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ function(add_gtest TEST_NAME)
target_link_libraries(test_${TEST_NAME} gtest_main MIOpen ${Boost_LIBRARIES} hip::host $<BUILD_INTERFACE:roc::rocblas>)
endif()
# Enable CMake to discover the test binary
gtest_discover_tests(test_${TEST_NAME} PROPERTIES ENVIRONMENT "MIOPEN_USER_DB_PATH=${CMAKE_CURRENT_BINARY_DIR}")
gtest_discover_tests(test_${TEST_NAME} PROPERTIES ENVIRONMENT "MIOPEN_USER_DB_PATH=${CMAKE_CURRENT_BINARY_DIR};MIOPEN_TEST_FLOAT_ARG=${MIOPEN_TEST_FLOAT_ARG}")

endif()
endfunction()
Expand Down
66 changes: 66 additions & 0 deletions test/gtest/conv_2d.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2023 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/
#include "conv_common.hpp"

template <class T>
struct conv2d_driver : conv_driver<T>
{
conv2d_driver() : conv_driver<T>()
{
this->add(this->input_dims, "input");
this->add(this->weight_tensor_dims, "weights");
this->add(this->batch_size,
"batch_size",
this->generate_data_limited(this->get_batch_sizes(), 1));
this->add(this->input_channels,
"input_channels",
this->generate_data_limited(this->get_input_channels(), 1, {32}));
this->add(this->output_channels,
"output_channels",
this->generate_data_limited(this->get_output_channels(), 1, {64}));
this->add(this->spatial_dim_elements,
"spatial_dim_elements",
this->generate_data_limited(this->get_2d_spatial_dims(), 1, {28, 28}));
this->add(this->filter_dims,
"filter_dims",
this->generate_data_limited(this->get_2d_filter_dims(), 2, {3, 3}));
this->add(this->pads_strides_dilations,
"pads_strides_dilations",
this->generate_data_limited(this->get_2d_pads_strides_dilations(), 2));
this->add(this->trans_output_pads,
"trans_output_pads",
this->generate_data(this->get_2d_trans_output_pads()));
this->add(this->in_layout, "in_layout", this->generate_data({"NCHW"}));
this->add(this->fil_layout, "fil_layout", this->generate_data({"NCHW"}));
this->add(this->out_layout, "out_layout", this->generate_data({"NCHW"}));
this->add(this->deterministic, "deterministic", this->generate_data({false}));
this->add(this->tensor_vect, "tensor_vect", this->generate_data({0}));
this->add(this->vector_length, "vector_length", this->generate_data({1}));
// Only valid for int8 input and weights
this->add(this->output_type, "output_type", this->generate_data({"int32"}));
this->add(this->int8_vectorize, "int8_vectorize", this->generate_data({false}));
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is partial copy of test/conv2d.cpp. Please do not make copies of existing code, refactor instead.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[notice] Copying the code is the worst thing that we can do from the maintenance point of view.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[notice] Copying the code is the worst thing that we can do from the maintenance point of view.

Refactored by moving common code to conv_common.hpp

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is 2D-specific and should not be moved there.

Copy link
Contributor

@xinlipn xinlipn Jun 9, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is 2D-specific and should not be moved there.

Reverted and re-refactored conv2d_driver related changes. This may look out of the place since the header is in test/gtest and this file is in test/. All previous conv2d CTest cases may be refactored by using this header as we continue moving toward gTest,

https://github.com/ROCmSoftwarePlatform/MIOpen/blob/d087a00396879dededb1bc0b13409d85060e7fc7/test/conv2d.cpp#L26

Copy link
Contributor

@atamazov atamazov Jun 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@xinlipn Yeah, test/ should not depend on test/gtest/ (Right now we are reusing code from test/.)

Please rename this header to test/conv2d.hpp.

All previous conv2d CTest cases may be refactored by using this header as we continue moving toward gTest,

Moving code from test/ to test/gtest/ is a separate task. Let's do that later, separately.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Notice] Not resolved yet. Let's either fix or discuss.

};
243 changes: 243 additions & 0 deletions test/gtest/conv_embed_db.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,243 @@
#include <tuple>
JehandadKhan marked this conversation as resolved.
Show resolved Hide resolved

#include <miopen/miopen.h>
#include <gtest/gtest.h>
#include "conv_2d.hpp"
#include "get_handle.hpp"

using TestCase = std::tuple<std::vector<std::string>, std::string>;

enum class Precision
xinlipn marked this conversation as resolved.
Show resolved Hide resolved
{
Float,
Half,
Int8,
BFloat16
};

std::string GetFloatArg()
{
static const auto tmp = std::getenv("MIOPEN_TEST_FLOAT_ARG");
xinlipn marked this conversation as resolved.
Show resolved Hide resolved
if(tmp == nullptr)
{
return "";
}
return tmp;
};
xinlipn marked this conversation as resolved.
Show resolved Hide resolved

std::vector<std::string> GetEnvVars(const std::vector<std::string>& check_vars)
{
std::vector<std::string> vars = {};
for(const auto& cvar : check_vars)
{
static const auto tmp = std::getenv(cvar.c_str());
if(tmp != nullptr)
{
vars.push_back(cvar + "=0");
}
}
return vars;
};

void GetArgs(const TestCase& param, std::vector<std::string>& tokens)
{
auto env_vars = std::get<0>(param);
for(auto& elem : env_vars)
{
putenv(elem.data());
}

auto cmd = std::get<1>(param);

std::stringstream ss(cmd);
std::istream_iterator<std::string> begin(ss);
std::istream_iterator<std::string> end;
while(begin != end)
tokens.push_back(*begin++);
}

class Conv2dHalf : public testing::TestWithParam<std::vector<TestCase>>
{
};
class Conv2dInt8 : public testing::TestWithParam<std::vector<TestCase>>
{
};
class Conv2dBFloat16 : public testing::TestWithParam<std::vector<TestCase>>
{
};
class Conv2dFloat : public testing::TestWithParam<std::vector<TestCase>>
{
};

void Run2dDriver(Precision prec)
{

std::vector<TestCase> params;
switch(prec)
{
case Precision::Float: params = Conv2dFloat::GetParam(); break;
case Precision::Half: params = Conv2dHalf::GetParam(); break;
case Precision::Int8: params = Conv2dInt8::GetParam(); break;
case Precision::BFloat16: params = Conv2dBFloat16::GetParam(); break;
default: params = Conv2dFloat::GetParam();
}

for(const auto& test_value : params)
{
std::vector<std::string> tokens;
GetArgs(test_value, tokens);
std::vector<const char*> ptrs;

std::transform(tokens.begin(), tokens.end(), std::back_inserter(ptrs), [](const auto& str) {
return str.data();
});

testing::internal::CaptureStderr();
test_drive<conv2d_driver>(ptrs.size(), ptrs.data());
auto capture = testing::internal::GetCapturedStderr();
EXPECT_FALSE(capture.find("Perf Db: record not found") != std::string::npos);
}
};

TEST_P(Conv2dFloat, FloatTest)
{
#if MIOPEN_EMBED_DB
xinlipn marked this conversation as resolved.
Show resolved Hide resolved

const auto& handle = get_handle();
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--float")
xinlipn marked this conversation as resolved.
Show resolved Hide resolved
{
GTEST_SKIP();
}
else
{
Run2dDriver(Precision::Float);
}

#else
GTEST_SKIP();
#endif
};

TEST_P(Conv2dHalf, HalfTest)
{
#if MIOPEN_EMBED_DB

const auto& handle = get_handle();
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--half")
{
GTEST_SKIP();
}
else
{
Run2dDriver(Precision::Half);
}

#else
GTEST_SKIP();
#endif
};

TEST_P(Conv2dInt8, Int8Test)
{
#if MIOPEN_EMBED_DB

const auto& handle = get_handle();
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--int8")
{
GTEST_SKIP();
}
else
{
Run2dDriver(Precision::Int8);
}

#else
GTEST_SKIP();
#endif
};

TEST_P(Conv2dBFloat16, BFloat16Test)
{
#if MIOPEN_EMBED_DB

const auto& handle = get_handle();
if(!miopen::StartsWith(handle.GetDeviceName(), "gfx906") || GetFloatArg() != "--bfloat16")
{
GTEST_SKIP();
}
else
{
Run2dDriver(Precision::BFloat16);
}

#else
GTEST_SKIP();
#endif
};

std::vector<TestCase> GetTestCases(const std::string& precision)
{

std::vector<std::string> winograd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2"};
std::vector<std::string> igemm_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1"};
std::vector<std::string> igemm_fwd = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1"};
std::vector<std::string> igemm_fwd_wrw = {"MIOPEN_DEBUG_AMD_WINOGRAD_RXS_F3X2",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_FWD_4R1",
"MIOPEN_DEBUG_CONV_IMPLICIT_GEMM_HIP_WRW_4R1"};
xinlipn marked this conversation as resolved.
Show resolved Hide resolved

const std::vector<TestCase> test_cases = {
// clang-format off
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 1024 14 14 --weights 2048 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd), precision +
" --disable-validation --verbose --input 128 1024 14 14 --weights 256 1024 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(winograd), precision +
" --disable-validation --verbose --input 128 128 28 28 --weights 128 128 3 3 --pads_strides_dilations 1 1 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 1024 14 14 --weights 512 1024 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 128 28 28 --weights 512 128 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 2048 7 7 --weights 512 2048 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd), precision +
" --disable-validation --verbose --input 128 256 14 14 --weights 1024 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 256 14 14 --weights 256 256 3 3 --pads_strides_dilations 1 1 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 256 56 56 --weights 128 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(winograd), precision +
" --disable-validation --verbose --input 128 256 56 56 --weights 512 256 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 256 56 56 --weights 64 256 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(winograd), precision +
" --disable-validation --verbose --input 128 3 230 230 --weights 64 3 7 7 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 512 28 28 --weights 1024 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 512 28 28 --weights 128 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd), precision +
" --disable-validation --verbose --input 128 512 28 28 --weights 256 512 1 1 --pads_strides_dilations 0 0 2 2 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 512 7 7 --weights 2048 512 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 512 7 7 --weights 512 512 3 3 --pads_strides_dilations 1 1 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 64 56 56 --weights 256 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_fwd_wrw), precision +
" --disable-validation --verbose --input 128 64 56 56 --weights 64 64 1 1 --pads_strides_dilations 0 0 1 1 1 1"),
std::make_tuple<std::vector<std::string>, std::string>(GetEnvVars(igemm_wrw), precision +
" --disable-validation --verbose --input 128 64 56 56 --weights 64 64 3 3 --pads_strides_dilations 1 1 1 1 1 1")
// clang-format on
};

return test_cases;
}

INSTANTIATE_TEST_SUITE_P(Conv2dGroup, Conv2dFloat, testing::Values(GetTestCases("--float")));
INSTANTIATE_TEST_SUITE_P(Conv2dGroup, Conv2dHalf, testing::Values(GetTestCases("--half")));
INSTANTIATE_TEST_SUITE_P(Conv2dGroup, Conv2dInt8, testing::Values(GetTestCases("--int8")));
INSTANTIATE_TEST_SUITE_P(Conv2dGroup, Conv2dBFloat16, testing::Values(GetTestCases("--bfloat16")));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test names should reflect that the test is intended for a library configuration with embedded databases.

Copy link
Contributor

@xinlipn xinlipn Jun 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test names should reflect that the test is intended for a library configuration with embedded databases.

@atamazov Test names have been updated. Could you resolve conversation if this has been resolved?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Resolved]