Skip to content

Commit

Permalink
[tests] bg/ck_gfx_white_list: start using ck_utility::is_ck_whitelist to restrict tests to applicable platforms (ROCm#2458)
Browse files Browse the repository at this point in the history

* bg/ck_gfx_white_list :  start using ck_utility::is_ck_whitelist function for all CK solvers

* bg/ck_gfx_white_list: fix review comments
  • Loading branch information
bghimireamd authored Oct 18, 2023
1 parent 877d94b commit 23ac5ab
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 16 deletions.
21 changes: 14 additions & 7 deletions src/include/miopen/solver/ck_utility_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ namespace miopen {
namespace solver {
namespace ck_utility {

// Disclaimer: Currently CK is only supported in MI100, MI200 and MI300.
// Please use is_ck_whitelist instead of this function.
static inline bool is_ck_supported_hardware(const Handle& handle)
{
return (StartsWith(handle.GetDeviceName(), "gfx803") && handle.GetMaxComputeUnits() == 64) ||
Expand All @@ -63,14 +65,19 @@ static inline bool is_ck_supported_hardware(const Handle& handle)
StartsWith(handle.GetDeviceName(), "gfx1102");
}

static inline bool is_conv_ck_supported_hardware(const std::string& device_name, bool is_wrw)
// MI100 : gfx908
// MI200 : gfx90a
// MI300 : gfx940, gfx941, gfx942
static inline bool is_ck_whitelist(const std::string& device_name)
{
auto res_wrw = StartsWith(device_name, "gfx908") || StartsWith(device_name, "gfx90a") ||
StartsWith(device_name, "gfx940") || StartsWith(device_name, "gfx941") ||
StartsWith(device_name, "gfx942");
return is_wrw ? res_wrw
: (res_wrw || StartsWith(device_name, "gfx900") ||
StartsWith(device_name, "gfx906"));
return (StartsWith(device_name, "gfx908") || StartsWith(device_name, "gfx90a") ||
StartsWith(device_name, "gfx940") || StartsWith(device_name, "gfx941") ||
StartsWith(device_name, "gfx942"));
}

// Convenience overload: applies the CK whitelist check to the device
// name reported by the given handle.
static inline bool is_ck_whitelist(const Handle& handle)
{
    const auto& device_name = handle.GetDeviceName();
    return is_ck_whitelist(device_name);
}

static inline bool is_support_amd_buffer_atomic_fadd(const std::string& device_name)
Expand Down
2 changes: 1 addition & 1 deletion src/solver/batchnorm/backward_ck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ bool BnCKBwdBackward::IsApplicable(
return false;
if(!bn_problem.IsLayoutNHWC())
return false;
if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
if(!ck_utility::is_ck_whitelist(context.GetStream()))
return false;
if(bn_problem.GetXDesc().GetType() != bn_problem.GetScaleBiasDiffDesc().GetType())
return false;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/batchnorm/forward_inference_ck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ bool BnCKFwdInference::IsApplicable(const ExecutionContext& context,
return false;
if(!bn_problem.IsLayoutNHWC())
return false;
if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
if(!ck_utility::is_ck_whitelist(context.GetStream()))
return false;

switch(bn_problem.GetXDesc().GetType())
Expand Down
2 changes: 1 addition & 1 deletion src/solver/batchnorm/forward_training_ck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ bool BnCKFwdTraining::IsApplicable(
return false;
if(!bn_problem.IsLayoutNHWC())
return false;
if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
if(!ck_utility::is_ck_whitelist(context.GetStream()))
return false;

switch(bn_problem.GetXDesc().GetType())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ bool ConvHipImplicitGemm3DGroupBwdXdlops::IsApplicable(
return false;
if(!problem.IsLayoutNHWC())
return false;
if(!ck_utility::is_conv_ck_supported_hardware(ctx.GetStream().GetDeviceName(), false))
if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName()))
return false;
switch(problem.GetInDataType())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ bool ConvHipImplicitGemm3DGroupFwdXdlops::IsApplicable(
return false;
if(!problem.IsLayoutNHWC())
return false;
if(!ck_utility::is_conv_ck_supported_hardware(ctx.GetStream().GetDeviceName(), false))
if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName()))
return false;
switch(problem.GetInDataType())
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ bool ConvHipImplicitGemm3DGroupWrwXdlops::IsApplicable(
return false;
if(!problem.IsLayoutNHWC())
return false;
if(!ck_utility::is_conv_ck_supported_hardware(ctx.GetStream().GetDeviceName(), true))
if(!ck_utility::is_ck_whitelist(ctx.GetStream().GetDeviceName()))
return false;
switch(problem.GetInDataType())
{
Expand Down
25 changes: 22 additions & 3 deletions test/gtest/bn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

#include <miopen/miopen.h>
#include <gtest/gtest.h>
#include <miopen/solver/ck_utility_common.hpp>

#include "bn_test_data.hpp"
#include "test_operations.hpp"
Expand All @@ -41,11 +42,15 @@ struct BNInferTest : public ::testing::TestWithParam<std::tuple<BNTestCase, miop
protected:
void SetUp() override
{
test_skipped = false;
std::tie(bn_config, tensor_layout) = GetParam();
bn_infer_test_data.SetUpImpl(bn_config, tensor_layout);

auto&& handle = get_handle();
if(!miopen::solver::ck_utility::is_ck_whitelist(handle.GetStream()))
{
test_skipped = true;
GTEST_SKIP() << "Not Applicable on " << handle.GetDeviceName() << " Architecture";
}
miopenBatchNormalizationForwardInference(&handle,
bn_config.mode,
&bn_infer_test_data.alpha,
Expand All @@ -69,7 +74,9 @@ struct BNInferTest : public ::testing::TestWithParam<std::tuple<BNTestCase, miop
void TearDown() override
{
if(test_skipped)
{
return;
}
auto&& handle = get_handle();
bn_infer_test_data.output.data = handle.Read<YDataType>(
bn_infer_test_data.out_dev, bn_infer_test_data.output.data.size());
Expand All @@ -96,11 +103,15 @@ struct BNBwdTest : public ::testing::TestWithParam<std::tuple<BNTestCase, miopen
protected:
void SetUp() override
{
test_skipped = false;
std::tie(bn_config, tensor_layout) = GetParam();
bn_bwd_test_data.SetUpImpl(bn_config, tensor_layout);

auto&& handle = get_handle();
if(!miopen::solver::ck_utility::is_ck_whitelist(handle.GetStream()))
{
test_skipped = true;
GTEST_SKIP() << "Not Applicable on " << handle.GetDeviceName() << " Architecture";
}
miopenBatchNormalizationBackward(&handle,
bn_config.mode,
&bn_bwd_test_data.alphaDataDiff,
Expand Down Expand Up @@ -129,7 +140,9 @@ struct BNBwdTest : public ::testing::TestWithParam<std::tuple<BNTestCase, miopen
void TearDown() override
{
if(test_skipped)
{
return;
}
auto&& handle = get_handle();
bn_bwd_test_data.output.data =
handle.Read<DyDataType>(bn_bwd_test_data.out_dev, bn_bwd_test_data.output.data.size());
Expand Down Expand Up @@ -177,11 +190,15 @@ struct BNFwdTrainTest
protected:
void SetUp() override
{
test_skipped = false;
std::tie(bn_config, tensor_layout) = GetParam();
bn_fwd_train_test_data.SetUpImpl(bn_config, tensor_layout);

auto&& handle = get_handle();
if(!miopen::solver::ck_utility::is_ck_whitelist(handle.GetStream()))
{
test_skipped = true;
GTEST_SKIP() << "Not Applicable on " << handle.GetDeviceName() << " Architecture";
}
miopenBatchNormalizationForwardTraining(&handle,
bn_config.mode,
&bn_fwd_train_test_data.alpha,
Expand Down Expand Up @@ -214,7 +231,9 @@ struct BNFwdTrainTest
void TearDown() override
{
if(test_skipped)
{
return;
}
auto&& handle = get_handle();
bn_fwd_train_test_data.output.data = handle.Read<YDataType>(
bn_fwd_train_test_data.out_dev, bn_fwd_train_test_data.output.data.size());
Expand Down

0 comments on commit 23ac5ab

Please sign in to comment.