diff --git a/driver/driver.hpp b/driver/driver.hpp index b29517bc0e..566809503a 100644 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -149,7 +149,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) printf("Usage: ./driver *base_arg* *other_args*\n"); printf("Supported Base Arguments: conv[fp16|int8|bfp16|fp8|bfp8], CBAInfer[fp16], " "pool[fp16], lrn[fp16], " - "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm, ctc, dropout[fp16], " + "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], " "tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -169,7 +169,7 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "pool" && arg != "poolfp16" && arg != "lrn" && arg != "lrnfp16" && arg != "activ" && arg != "activfp16" && arg != "softmax" && arg != "softmaxfp16" && arg != "bnorm" && arg != "bnormfp16" && arg != "rnn" && arg != "rnnfp16" && arg != "rnn_seq" && - arg != "rnn_seqfp16" && arg != "gemm" /*&& arg != "gemmfp16"*/ && arg != "ctc" && + arg != "rnn_seqfp16" && arg != "gemm" && arg != "gemmfp16" && arg != "ctc" && arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" && arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "layernorm" && arg != "layernormfp16" && arg != "layernormbfp16" && arg != "sum" && arg != "sumfp16" && diff --git a/driver/gemm_driver.hpp b/driver/gemm_driver.hpp index 9b5e358074..f464d89270 100644 --- a/driver/gemm_driver.hpp +++ b/driver/gemm_driver.hpp @@ -182,7 +182,19 @@ int GemmDriver::ParseCmdLineArgs(int argc, char* argv[]) template int GemmDriver::GetandSetData() { - gemm_desc.dataType = data_type; + if constexpr(std::is_same_v) + { + gemm_desc.dataType = miopenFloat; + } + else if constexpr(std::is_same_v) + { + gemm_desc.dataType = miopenHalf; + } + else + { + static_assert(!"unsupported type"); + } + gemm_desc.a_cast_type = data_type; gemm_desc.b_cast_type = data_type; diff --git a/driver/main.cpp b/driver/main.cpp index 1e749efb77..9c629dd11e 100644 --- a/driver/main.cpp +++ b/driver/main.cpp @@ -136,11 +136,10 @@ int main(int argc, char* argv[]) { drv = new GemmDriver(); } -// TODO half is not supported in gemm -// else if(base_arg == "gemmfp16") -// { -// drv = new GemmDriver(); -// } + else if(base_arg == "gemmfp16") + { + drv = new GemmDriver(); + } #endif else if(base_arg == "bnorm") { diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 52fd341f1b..aec30828eb 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -205,11 +205,13 @@ if(MIOPEN_TEST_HALF) set(MIOPENDRIVER_MODE_CONV convfp16) set(MIOPENDRIVER_MODE_BN bnormfp16) set(MIOPENDRIVER_MODE_POOL poolfp16) + set(MIOPENDRIVER_MODE_GEMM gemmfp16) elseif(MIOPEN_TEST_INT8) set(MIOPEN_TEST_FLOAT_ARG --int8) set(MIOPENDRIVER_MODE_CONV convint8) set(MIOPENDRIVER_MODE_BN NOT_SUPPORTED) set(MIOPENDRIVER_MODE_POOL NOT_SUPPORTED) + set(MIOPENDRIVER_MODE_GEMM NOT_SUPPORTED) set(MIOPEN_TEST_CONV_INT8_OUTPUT_TYPE_INT8 --output_type int8) set(MIOPEN_TEST_CONV_INT8_OUTPUT_TYPE_INT32 --output_type int32) set(MIOPEN_TEST_CONV_INT8_OUTPUT_TYPE_FLOAT --output_type float) @@ -218,12 +220,14 @@ elseif(MIOPEN_TEST_BFLOAT16) set(MIOPENDRIVER_MODE_CONV convbfp16) set(MIOPENDRIVER_MODE_BN NOT_SUPPORTED) set(MIOPENDRIVER_MODE_POOL NOT_SUPPORTED) + set(MIOPENDRIVER_MODE_GEMM NOT_SUPPORTED) else() set(MIOPEN_TEST_FLOAT_ARG --float) set(MIOPEN_TEST_FLOAT TRUE) set(MIOPENDRIVER_MODE_CONV conv) set(MIOPENDRIVER_MODE_BN bnorm) set(MIOPENDRIVER_MODE_POOL pool) + set(MIOPENDRIVER_MODE_GEMM gemm) endif() message(STATUS "MIOPEN_TEST_FLOAT ${MIOPEN_TEST_FLOAT}") @@ -750,6 +754,10 @@ if(${MIOPEN_TEST_WITH_MIOPENDRIVER}) --pad_d 0 -p 0 -q 0 --conv_stride_d 1 -u 1 -v 1 --dilation_d 1 -l 1 -j 1 --spatial_dim 3 -m conv -g 1 -F 1 -i 1 -t 1 -w 1 ) + + add_custom_test(smoke_miopendriver_gemm GFX94X_ENABLED GFX103X_ENABLED GFX110X_ENABLED HALF_ENABLED + COMMAND $ ${MIOPENDRIVER_MODE_GEMM} -m 256 -n 512 -k 1024 -i 1 -V 1 + ) endif() set(IMPLICITGEMM_MLIR_ENV_BASE MIOPEN_FIND_MODE=normal)