Skip to content

Commit

Permalink
Merge pull request ROCm#396 from cgmillette/fix-test-params
Browse files Browse the repository at this point in the history
Fix rocWMMA preprocessor symbol usage
  • Loading branch information
cgmillette authored May 8, 2024
2 parents f358adf + 1b4b5ad commit 89b6a31
Show file tree
Hide file tree
Showing 23 changed files with 173 additions and 219 deletions.
4 changes: 2 additions & 2 deletions test/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
}
#endif

#ifdef ROCWMMA_BENCHMARK_TESTS
#if ROCWMMA_BENCHMARK_TESTS
#ifndef CHECK_RSMI_ERROR
#define CHECK_RSMI_ERROR(expression, smiErrorFlag) \
if(auto status = (expression); status != RSMI_STATUS_SUCCESS) \
Expand All @@ -68,7 +68,7 @@
smiErrorFlag = true; \
}
#endif
#endif
#endif // ROCWMMA_BENCHMARK_TESTS

namespace rocwmma
{
Expand Down
30 changes: 15 additions & 15 deletions test/dlrm/dlrm_kernel_base_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@

// Library includes

#ifdef ROCWMMA_VALIDATION_TESTS
#if ROCWMMA_VALIDATION_TESTS
#include "reference.hpp" // Vanilla CPU kernel
#endif // ROCWMMA_VALIDATION_TESTS

Expand Down Expand Up @@ -156,11 +156,11 @@ namespace rocwmma
mM = mK = mB = 0;
mMPadded = mKPadded = 0;
mRepeats =
#ifdef ROCWMMA_VALIDATION_TESTS
#if ROCWMMA_VALIDATION_TESTS
1;
#else
5;
#endif
#endif // ROCWMMA_VALIDATION_TESTS

mRunFlag = true;

Expand All @@ -187,10 +187,10 @@ namespace rocwmma
<< "DataT, "
<< "Direction, "
<< "MatM, MatK, MatB, "
#if defined(ROCWMMA_VALIDATION_TESTS)
#if ROCWMMA_VALIDATION_TESTS
<< "maxRelativeDiff, "
<< "tolerance, "
#endif
#endif // ROCWMMA_VALIDATION_TESTS
<< "elapsedMs, "
<< "Problem Size(GFlops), "
<< "TFlops/s, "
Expand All @@ -206,9 +206,9 @@ namespace rocwmma
<< (passDirection == DlrmDirection_t::Forward ? "Forwards" : "Backwards")
<< ", " << mM << ", " << mK << ", " << mB << ", "

#if defined(ROCWMMA_VALIDATION_TESTS)
#if ROCWMMA_VALIDATION_TESTS
<< "n/a, "
#endif
#endif // ROCWMMA_VALIDATION_TESTS
<< "n/a, n/a, n/a, n/a, SKIPPED" << std::endl;
}
else
Expand All @@ -217,12 +217,12 @@ namespace rocwmma
<< (passDirection == DlrmDirection_t::Forward ? "Forwards" : "Backwards")
<< ", " << mM << ", " << mK << ", " << mB << ", "

#if defined(ROCWMMA_VALIDATION_TESTS)
#if ROCWMMA_VALIDATION_TESTS
<< mMaxRelativeError << ", "
#endif
#endif // ROCWMMA_VALIDATION_TESTS
<< mElapsedTimeMs << ", " << mTotalGFlops << ", " << mMeasuredTFlopsPerSec
<< ", " << mEfficiency << ", "
#if defined(ROCWMMA_VALIDATION_TESTS)
#if ROCWMMA_VALIDATION_TESTS
<< (mValidationResult ? "PASSED" : "FAILED")
#else
<< "BENCH"
Expand Down Expand Up @@ -275,7 +275,7 @@ namespace rocwmma
{
MatrixUtil<row_major>::fillLaunchKernel(
dataInstance->deviceInput().get(), mM, mK, mB);
#if defined(ROCWMMA_VALIDATION_TESTS)
#if ROCWMMA_VALIDATION_TESTS
dataInstance->copyDeviceToHostFwdInput();
#endif // ROCWMMA_VALIDATION_TESTS
}
Expand All @@ -286,7 +286,7 @@ namespace rocwmma
dataInstance->deviceInput().get(), mM, mK, mB);
MatrixUtil<row_major>::fillLaunchKernel(
dataInstance->deviceUpstreamGrad().get(), 1, gradSize, mB);
#if defined(ROCWMMA_VALIDATION_TESTS)
#if ROCWMMA_VALIDATION_TESTS
dataInstance->copyDeviceToHostBwdInput();
#endif // ROCWMMA_VALIDATION_TESTS
}
Expand Down Expand Up @@ -418,7 +418,7 @@ namespace rocwmma
CHECK_HIP_ERROR(hipEventDestroy(startEvent));
CHECK_HIP_ERROR(hipEventDestroy(stopEvent));

#if defined(ROCWMMA_VALIDATION_TESTS)
#if ROCWMMA_VALIDATION_TESTS

// Run reference CPU kernel
std::function<void()> cpuKernel;
Expand Down Expand Up @@ -447,14 +447,14 @@ namespace rocwmma
};
}
cpuKernel();
#endif
#endif // ROCWMMA_VALIDATION_TESTS
}
}

template <uint32_t TileSize, typename DataT>
void DlrmKernelBase<TileSize, DataT>::validateResults()
{
#ifdef ROCWMMA_VALIDATION_TESTS
#if ROCWMMA_VALIDATION_TESTS
if(mRunFlag)
{
auto& dataInstance = DataStorage::instance();
Expand Down
41 changes: 20 additions & 21 deletions test/gemm/gemm_common_test_params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace rocwmma

// Native int8
using TestTypesI8 = std::tuple<
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
std::tuple<int8_t, int8_t, int32_t>,
#endif // ROCWMMA_EXTENDED_TESTS
std::tuple<int8_t, int32_t, int32_t>>;
Expand All @@ -62,15 +62,15 @@ namespace rocwmma

// Non-native bfloat16_t
using TestTypesBF16 = std::tuple<
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
std::tuple<bfloat16_t, bfloat16_t, bfloat16_t>,
std::tuple<bfloat16_t, bfloat16_t, float32_t>,
#endif // ROCWMMA_EXTENDED_TESTS
std::tuple<bfloat16_t, float32_t, float32_t>>;

// Native f16
using TestTypesF16 = std::tuple<
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
std::tuple<float16_t, float16_t, float16_t>,
std::tuple<float16_t, float16_t, float32_t>,
#endif // ROCWMMA_EXTENDED_TESTS
Expand All @@ -79,7 +79,7 @@ namespace rocwmma
#if !ROCWMMA_TESTS_NO_HALF
// Non-native hfloat16_t (i.e. __half)
using TestTypesH16 = std::tuple<
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
std::tuple<hfloat16_t, hfloat16_t, hfloat16_t>,
std::tuple<hfloat16_t, hfloat16_t, float32_t>,
#endif // ROCWMMA_EXTENDED_TESTS
Expand Down Expand Up @@ -140,28 +140,28 @@ namespace rocwmma
///

using TestLayoutsNN =
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
typename CombineOne<std::tuple<col_major, col_major>, TestDataLayouts>::Result;
#else
std::tuple<std::tuple<col_major, col_major, col_major>>;
#endif // ROCWMMA_EXTENDED_TESTS

using TestLayoutsNT =
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
typename CombineOne<std::tuple<col_major, row_major>, TestDataLayouts>::Result;
#else
std::tuple<std::tuple<col_major, row_major, col_major>>;
#endif // ROCWMMA_EXTENDED_TESTS

using TestLayoutsTN =
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
typename CombineOne<std::tuple<row_major, col_major>, TestDataLayouts>::Result;
#else
std::tuple<std::tuple<row_major, col_major, col_major>>;
#endif // ROCWMMA_EXTENDED_TESTS

using TestLayoutsTT =
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
typename CombineOne<std::tuple<row_major, row_major>, TestDataLayouts>::Result;
#else
std::tuple<std::tuple<row_major, row_major, col_major>>;
Expand All @@ -177,7 +177,7 @@ namespace rocwmma

// Aggregate combinations BlockK <= 32
using TestBlockSizes16x16SmallBlockK = std::tuple<std::tuple<I<16>, I<16>, I<16>>
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
,
std::tuple<I<16>, I<16>, I<32>>
#endif // ROCWMMA_EXTENDED_TESTS
Expand All @@ -186,7 +186,7 @@ namespace rocwmma
// Aggregate combinations BlockK <= 64
using TestBlockSizes16x16MediumBlockK = std::tuple<std::tuple<I<16>, I<16>, I<16>>,
std::tuple<I<16>, I<16>, I<32>>
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
,
std::tuple<I<16>, I<16>, I<64>>
#endif // ROCWMMA_EXTENDED_TESTS
Expand All @@ -196,7 +196,7 @@ namespace rocwmma
using TestBlockSizes16x16LargeBlockK = std::tuple<std::tuple<I<16>, I<16>, I<16>>,
std::tuple<I<16>, I<16>, I<32>>,
std::tuple<I<16>, I<16>, I<64>>
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
,
std::tuple<I<16>, I<16>, I<128>>
#endif // ROCWMMA_EXTENDED_TESTS
Expand All @@ -206,7 +206,7 @@ namespace rocwmma
using TestBlockSizes16x16HugeBlockK = std::tuple<std::tuple<I<16>, I<16>, I<16>>,
std::tuple<I<16>, I<16>, I<32>>,
std::tuple<I<16>, I<16>, I<64>>
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
,
std::tuple<I<16>, I<16>, I<128>>,
std::tuple<I<16>, I<16>, I<256>>
Expand All @@ -217,7 +217,7 @@ namespace rocwmma

// Aggregate combinations BlockK <= 16
using TestBlockSizes32x32SmallBlockK = std::tuple<std::tuple<I<32>, I<32>, I<8>>
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
,
std::tuple<I<32>, I<32>, I<16>>
#endif // ROCWMMA_EXTENDED_TESTS
Expand All @@ -226,7 +226,7 @@ namespace rocwmma
// Aggregate combinations BlockK <= 32
using TestBlockSizes32x32MediumBlockK = std::tuple<std::tuple<I<32>, I<32>, I<8>>,
std::tuple<I<32>, I<32>, I<16>>
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
,
std::tuple<I<32>, I<32>, I<32>>
#endif // ROCWMMA_EXTENDED_TESTS
Expand All @@ -236,7 +236,7 @@ namespace rocwmma
using TestBlockSizes32x32LargeBlockK = std::tuple<std::tuple<I<32>, I<32>, I<8>>,
std::tuple<I<32>, I<32>, I<16>>,
std::tuple<I<32>, I<32>, I<32>>
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
,
std::tuple<I<32>, I<32>, I<64>>
#endif // ROCWMMA_EXTENDED_TESTS
Expand All @@ -246,7 +246,7 @@ namespace rocwmma
using TestBlockSizes32x32HugeBlockK = std::tuple<std::tuple<I<32>, I<32>, I<8>>,
std::tuple<I<32>, I<32>, I<16>>,
std::tuple<I<32>, I<32>, I<32>>
#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
,
std::tuple<I<32>, I<32>, I<64>>,
std::tuple<I<32>, I<32>, I<128>>
Expand Down Expand Up @@ -298,10 +298,10 @@ namespace rocwmma
{
// clang-format off
// Don't benchmark wg less than 4 waves by default
#if defined(ROCWMMA_VALIDATION_TESTS) || defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_VALIDATION_TESTS || ROCWMMA_EXTENDED_TESTS
{warpSize, 1}, // 1 wave
{warpSize, 2}, {warpSize * 2, 1}, // 2 wave
#endif // ROCWMMA_VALIDATION_TESTS
#endif // ROCWMMA_VALIDATION_TESTS || ROCWMMA_EXTENDED_TESTS
{warpSize, 4}, {warpSize * 2, 2}, // 4 wave
{warpSize * 4, 1} // 4 wave
// clang-format on
Expand All @@ -310,7 +310,6 @@ namespace rocwmma

static inline std::vector<ProblemSizeT> problemSizes()
{

return
{
// clang-format off
Expand All @@ -323,15 +322,15 @@ namespace rocwmma
{512, 512, 512},
// Skip validation on larger sizes
// due to very slow.
#if !defined(ROCWMMA_VALIDATION_TESTS)
#if !ROCWMMA_VALIDATION_TESTS
{1024, 1024, 1024},
{2048, 2048, 2048},
{2560, 2560, 2560},
{3072, 3072, 3072},
{3584, 3584, 3584},
{4096, 4096, 4096},
{5120, 5120, 5120},
#ifdef ROCWMMA_EXTENDED_TESTS
#if ROCWMMA_EXTENDED_TESTS
{6144, 6144, 6144},
{7168, 7168, 7168},
{8192, 8192, 8192},
Expand Down
2 changes: 1 addition & 1 deletion test/gemm/gemm_kernel_base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ namespace rocwmma
ROCWMMA_INSTANTIATE_GEMM_KERNEL_BASE(xfloat32_t, float32_t, float32_t);
ROCWMMA_INSTANTIATE_GEMM_KERNEL_BASE(float64_t, float64_t, float64_t);

#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
ROCWMMA_INSTANTIATE_GEMM_KERNEL_BASE(int8_t, int8_t, int32_t);
ROCWMMA_INSTANTIATE_GEMM_KERNEL_BASE(bfloat16_t, bfloat16_t, bfloat16_t);
ROCWMMA_INSTANTIATE_GEMM_KERNEL_BASE(bfloat16_t, bfloat16_t, float32_t);
Expand Down
4 changes: 2 additions & 2 deletions test/gemm/gemm_kernel_base_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

#if ROCWMMA_ROCBLAS_INTEGRATION
#include "rocblas_reference.hpp" // rocBLAS GPU kernel
#endif // ROCWMMA_VALIDATE_WITH_ROCBLAS || ROCWMMA_BENCHMARK_WITH_ROCBLAS
#endif // ROCWMMA_ROCBLAS_INTEGRATION

namespace rocwmma
{
Expand Down Expand Up @@ -496,7 +496,7 @@ namespace rocwmma

if(!mRunFlag)
{
stream << "n/a, "
stream << "n/a"
<< ", "
<< "n/a"
<< ", "
Expand Down
2 changes: 1 addition & 1 deletion test/gemm/gemm_resource.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace rocwmma
template struct GemmResource<xfloat32_t, float32_t>;
template struct GemmResource<float64_t, float64_t>;

#if defined(ROCWMMA_EXTENDED_TESTS)
#if ROCWMMA_EXTENDED_TESTS
template struct GemmResource<int8_t, int8_t>;
template struct GemmResource<bfloat16_t, bfloat16_t>;
template struct GemmResource<float16_t, float16_t>;
Expand Down
8 changes: 4 additions & 4 deletions test/hip_device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ namespace rocwmma
mMaxFreqMhz = static_cast<int>(static_cast<double>(mProps.clockRate) / 1000.0);
mCurFreqMhz = mMaxFreqMhz;

#ifdef ROCWMMA_BENCHMARK_TESTS
#if ROCWMMA_BENCHMARK_TESTS
bool smiErrorFlag = false;
CHECK_RSMI_ERROR(rsmi_init(0), smiErrorFlag);
if(!smiErrorFlag)
Expand Down Expand Up @@ -134,7 +134,7 @@ namespace rocwmma
}
}
}
#endif
#endif // ROCWMMA_BENCHMARK_TESTS
}

hipDevice_t HipDevice::getDeviceHandle() const
Expand Down Expand Up @@ -184,10 +184,10 @@ namespace rocwmma

HipDevice::~HipDevice()
{
#ifdef ROCWMMA_BENCHMARK_TESTS
#if ROCWMMA_BENCHMARK_TESTS
bool smiErrorFlag = false;
CHECK_RSMI_ERROR(rsmi_shut_down(), smiErrorFlag);
#endif
#endif // ROCWMMA_BENCHMARK_TESTS
}

// Need to check the host device target support statically before hip modules attempt
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,14 @@ namespace rocwmma

static inline std::vector<Base::Param2T> param2s()
{
return {
0.0,
1.0,
2.0,
3.0 // 1 - 4 waves
#ifdef ROCWMMA_EXTENDED_TESTS
,
4.0,
5.0,
6.0,
7.0 // 8 waves
return
{
0.0, 1.0, 2.0,
3.0 // 1 - 4 waves
#if ROCWMMA_EXTENDED_TESTS
,
4.0, 5.0, 6.0,
7.0 // 8 waves
#endif // ROCWMMA_EXTENDED_TESTS
};
}
Expand Down
Loading

0 comments on commit 89b6a31

Please sign in to comment.