Skip to content

Commit

Permalink
[MIOpenDriver] Enabled gemmfp16. [tests] Added smoke test for fp16 an…
Browse files Browse the repository at this point in the history
…d fp32 gemm. (#2592)

* fix-gemmfp16(01) [MIOpenDriver] Enable gemmfp16 in the driver

* fix-gemmfp16(02) [tests] Add smoke test for fp16 gemm
  • Loading branch information
atamazov authored Dec 13, 2023
1 parent 91f87d4 commit 53c39c3
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 8 deletions.
4 changes: 2 additions & 2 deletions driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
printf("Usage: ./driver *base_arg* *other_args*\n");
printf("Supported Base Arguments: conv[fp16|int8|bfp16|fp8|bfp8], CBAInfer[fp16], "
"pool[fp16], lrn[fp16], "
"activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm, ctc, dropout[fp16], "
"activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], "
"tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}
Expand All @@ -169,7 +169,7 @@ inline std::string ParseBaseArg(int argc, char* argv[])
arg != "pool" && arg != "poolfp16" && arg != "lrn" && arg != "lrnfp16" && arg != "activ" &&
arg != "activfp16" && arg != "softmax" && arg != "softmaxfp16" && arg != "bnorm" &&
arg != "bnormfp16" && arg != "rnn" && arg != "rnnfp16" && arg != "rnn_seq" &&
arg != "rnn_seqfp16" && arg != "gemm" /*&& arg != "gemmfp16"*/ && arg != "ctc" &&
arg != "rnn_seqfp16" && arg != "gemm" && arg != "gemmfp16" && arg != "ctc" &&
arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" &&
arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "layernorm" &&
arg != "layernormfp16" && arg != "layernormbfp16" && arg != "sum" && arg != "sumfp16" &&
Expand Down
14 changes: 13 additions & 1 deletion driver/gemm_driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,19 @@ int GemmDriver<T>::ParseCmdLineArgs(int argc, char* argv[])
template <typename T>
int GemmDriver<T>::GetandSetData()
{
gemm_desc.dataType = data_type;
if constexpr(std::is_same_v<T, float>)
{
gemm_desc.dataType = miopenFloat;
}
else if constexpr(std::is_same_v<T, float16>)
{
gemm_desc.dataType = miopenHalf;
}
else
{
static_assert(!"unsupported type");
}

gemm_desc.a_cast_type = data_type;
gemm_desc.b_cast_type = data_type;

Expand Down
9 changes: 4 additions & 5 deletions driver/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,10 @@ int main(int argc, char* argv[])
{
drv = new GemmDriver<float>();
}
// TODO half is not supported in gemm
// else if(base_arg == "gemmfp16")
// {
// drv = new GemmDriver<float16>();
// }
else if(base_arg == "gemmfp16")
{
drv = new GemmDriver<float16>();
}
#endif
else if(base_arg == "bnorm")
{
Expand Down
8 changes: 8 additions & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,13 @@ if(MIOPEN_TEST_HALF)
set(MIOPENDRIVER_MODE_CONV convfp16)
set(MIOPENDRIVER_MODE_BN bnormfp16)
set(MIOPENDRIVER_MODE_POOL poolfp16)
set(MIOPENDRIVER_MODE_GEMM gemmfp16)
elseif(MIOPEN_TEST_INT8)
set(MIOPEN_TEST_FLOAT_ARG --int8)
set(MIOPENDRIVER_MODE_CONV convint8)
set(MIOPENDRIVER_MODE_BN NOT_SUPPORTED)
set(MIOPENDRIVER_MODE_POOL NOT_SUPPORTED)
set(MIOPENDRIVER_MODE_GEMM NOT_SUPPORTED)
set(MIOPEN_TEST_CONV_INT8_OUTPUT_TYPE_INT8 --output_type int8)
set(MIOPEN_TEST_CONV_INT8_OUTPUT_TYPE_INT32 --output_type int32)
set(MIOPEN_TEST_CONV_INT8_OUTPUT_TYPE_FLOAT --output_type float)
Expand All @@ -218,12 +220,14 @@ elseif(MIOPEN_TEST_BFLOAT16)
set(MIOPENDRIVER_MODE_CONV convbfp16)
set(MIOPENDRIVER_MODE_BN NOT_SUPPORTED)
set(MIOPENDRIVER_MODE_POOL NOT_SUPPORTED)
set(MIOPENDRIVER_MODE_GEMM NOT_SUPPORTED)
else()
set(MIOPEN_TEST_FLOAT_ARG --float)
set(MIOPEN_TEST_FLOAT TRUE)
set(MIOPENDRIVER_MODE_CONV conv)
set(MIOPENDRIVER_MODE_BN bnorm)
set(MIOPENDRIVER_MODE_POOL pool)
set(MIOPENDRIVER_MODE_GEMM gemm)
endif()

message(STATUS "MIOPEN_TEST_FLOAT ${MIOPEN_TEST_FLOAT}")
Expand Down Expand Up @@ -750,6 +754,10 @@ if(${MIOPEN_TEST_WITH_MIOPENDRIVER})
--pad_d 0 -p 0 -q 0 --conv_stride_d 1 -u 1 -v 1 --dilation_d 1 -l 1 -j 1
--spatial_dim 3 -m conv -g 1 -F 1 -i 1 -t 1 -w 1
)

add_custom_test(smoke_miopendriver_gemm GFX94X_ENABLED GFX103X_ENABLED GFX110X_ENABLED HALF_ENABLED
COMMAND $<TARGET_FILE:MIOpenDriver> ${MIOPENDRIVER_MODE_GEMM} -m 256 -n 512 -k 1024 -i 1 -V 1
)
endif()

set(IMPLICITGEMM_MLIR_ENV_BASE MIOPEN_FIND_MODE=normal)
Expand Down

0 comments on commit 53c39c3

Please sign in to comment.