Skip to content

Commit

Permalink
Impl Kthvalue operation (#3204)
Browse files Browse the repository at this point in the history
  • Loading branch information
BuiChiTrung authored Oct 1, 2024
1 parent 2d69aeb commit f0384a5
Show file tree
Hide file tree
Showing 24 changed files with 1,862 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ The MIOpen API library is structured as follows:
* :doc:`ReduceCalculation <../doxygen/html/group__ReduceCalculation>` (experimental)
* :doc:`RotaryPositionalEmbeddings <../doxygen/html/group__RotaryPositionalEmbeddings>` (experimental)
* :doc:`ReLU <../doxygen/html/group___re_l_u>` (experimental)
* :doc:`Kthvalue <../doxygen/html/group__kthvalue>` (experimental)
* :doc:`GLU <../doxygen/html/group__glu>` (experimental)
1 change: 1 addition & 0 deletions driver/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ add_executable(MIOpenDriver
dm_getitem.cpp
dm_glu.cpp
dm_groupnorm.cpp
dm_kthvalue.cpp
dm_layernorm.cpp
dm_lrn.cpp
dm_pool.cpp
Expand Down
41 changes: 41 additions & 0 deletions driver/dm_kthvalue.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*******************************************************************************
*
* MIT License
*
* Copyright (c) 2024 Advanced Micro Devices, Inc.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in all
* copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*
*******************************************************************************/

#include "registry_driver_maker.hpp"
#include "kthvalue_driver.hpp"

static Driver* makeDriver(const std::string& base_arg)
{
if(base_arg == "kthvalue")
return new KthvalueDriver<float>();
else if(base_arg == "kthvaluefp16")
return new KthvalueDriver<float16>();
else if(base_arg == "kthvaluebfp16")
return new KthvalueDriver<bfloat16>();
return nullptr;
}

REGISTER_DRIVER_MAKER(makeDriver);
5 changes: 3 additions & 2 deletions driver/driver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
"t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16], "
"adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, "
"getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16], "
"prelu[bfp16|fp16], glu[bfp16|fp16]\n");
"prelu[bfp16|fp16], kthvalue[bfp16|fp16], glu[bfp16|fp16]\n");
exit(0); // NOLINT (concurrency-mt-unsafe)
}

Expand Down Expand Up @@ -209,7 +209,8 @@ inline std::string ParseBaseArg(int argc, char* argv[])
arg != "getitemfp16" && arg != "getitembfp16" && arg != "reducecalculation" &&
arg != "reducecalculationfp16" && arg != "reducecalculationbfp16" && arg != "rope" &&
arg != "ropefp16" && arg != "ropebfp16" && arg != "prelu" && arg != "prelufp16" &&
arg != "prelubfp16" && arg != "glu" && arg != "glufp16" && arg != "glubfp16" &&
arg != "prelubfp16" && arg != "kthvalue" && arg != "kthvaluefp16" &&
arg != "kthvaluebfp16" && arg != "glu" && arg != "glufp16" && arg != "glubfp16" &&
arg != "--version")
{
printf("FAILED: Invalid Base Input Argument\n");
Expand Down
Loading

0 comments on commit f0384a5

Please sign in to comment.