diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 2ade4839b1..c2b74eabee 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -37,4 +37,5 @@ The MIOpen API library is structured as follows: * :doc:`ReduceCalculation <../doxygen/html/group__ReduceCalculation>` (experimental) * :doc:`RotaryPositionalEmbeddings <../doxygen/html/group__RotaryPositionalEmbeddings>` (experimental) * :doc:`ReLU <../doxygen/html/group___re_l_u>` (experimental) + * :doc:`Kthvalue <../doxygen/html/group__kthvalue>` (experimental) * :doc:`GLU <../doxygen/html/group__glu>` (experimental) diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 256901aa94..0e06c8fc4a 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -46,6 +46,7 @@ add_executable(MIOpenDriver dm_getitem.cpp dm_glu.cpp dm_groupnorm.cpp + dm_kthvalue.cpp dm_layernorm.cpp dm_lrn.cpp dm_pool.cpp diff --git a/driver/dm_kthvalue.cpp b/driver/dm_kthvalue.cpp new file mode 100644 index 0000000000..71adfdc245 --- /dev/null +++ b/driver/dm_kthvalue.cpp @@ -0,0 +1,41 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "registry_driver_maker.hpp" +#include "kthvalue_driver.hpp" + +static Driver* makeDriver(const std::string& base_arg) +{ + if(base_arg == "kthvalue") + return new KthvalueDriver(); + else if(base_arg == "kthvaluefp16") + return new KthvalueDriver(); + else if(base_arg == "kthvaluebfp16") + return new KthvalueDriver(); + return nullptr; +} + +REGISTER_DRIVER_MAKER(makeDriver); diff --git a/driver/driver.hpp b/driver/driver.hpp index aa0b89f10a..c190299191 100644 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -176,7 +176,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) "t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16], " "adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, " "getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16], " - "prelu[bfp16|fp16], glu[bfp16|fp16]\n"); + "prelu[bfp16|fp16], kthvalue[bfp16|fp16], glu[bfp16|fp16]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -209,7 +209,8 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "getitemfp16" && arg != "getitembfp16" && arg != "reducecalculation" && arg != "reducecalculationfp16" && arg != "reducecalculationbfp16" && arg != "rope" && arg != "ropefp16" && arg != "ropebfp16" && arg != "prelu" && arg != "prelufp16" && - arg != "prelubfp16" && arg != "glu" && arg != "glufp16" && arg != "glubfp16" && + arg != "prelubfp16" && arg != 
"kthvalue" && arg != "kthvaluefp16" && + arg != "kthvaluebfp16" && arg != "glu" && arg != "glufp16" && arg != "glubfp16" && arg != "--version") { printf("FAILED: Invalid Base Input Argument\n"); diff --git a/driver/kthvalue_driver.hpp b/driver/kthvalue_driver.hpp new file mode 100644 index 0000000000..75f7e5b535 --- /dev/null +++ b/driver/kthvalue_driver.hpp @@ -0,0 +1,387 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#pragma once +#include "InputFlags.hpp" +#include "driver.hpp" +#include "tensor_driver.hpp" +#include "timer.hpp" +#include "random.hpp" + +#include <../test/tensor_holder.hpp> +#include <../test/verify.hpp> + +#include +#include +#include + +#include + +template +void mloKthvalueFwdRunHost(TIO* input, + miopenTensorDescriptor_t pInputDesc, + TIO* outputHost, + miopenTensorDescriptor_t outputDesc, + size_t* indices, + miopenTensorDescriptor_t indicesDesc, + size_t k, + int dim) +{ + auto inputDesc = miopen::deref(pInputDesc); + size_t inputSize = inputDesc.GetElementSize(); + size_t dimSize = inputDesc.GetLengths()[dim]; + size_t dimStride = inputDesc.GetStrides()[dim]; + auto inputTv = miopen::get_inner_expanded_tv<5>(miopen::deref(pInputDesc)); + auto inputTvWithoutDim = miopen::get_tv_without_dim<5>(inputTv, dim); + auto outputTv = miopen::get_inner_expanded_tv<5>(miopen::deref(outputDesc)); + auto indicesTv = miopen::get_inner_expanded_tv<5>(miopen::deref(indicesDesc)); + + size_t numSlice = inputSize / dimSize; // one slice per output element + + std::vector elements; + std::vector ids(dimSize); + for(int i = 0; i < dimSize; ++i) + { + ids[i] = i; + } + + for(int slideID = 0; slideID < numSlice; ++slideID) + { + elements.clear(); + tensor_layout_t<4> layout(inputTvWithoutDim, slideID); + auto idx = inputTvWithoutDim.get_tensor_view_idx(layout); + + for(int j = 0; j < dimSize; ++j) + { + elements.push_back(static_cast(input[idx + j * dimStride])); + } + + std::sort(ids.begin(), ids.end(), [&](size_t x, size_t y) -> bool { + return elements[x] < elements[y]; + }); // capture by reference: a by-value capture copies `elements` every time the comparator is copied + auto output_layout = tensor_layout_t<5>(outputTv, slideID); + auto indices_layout = tensor_layout_t<5>(indicesTv, slideID); + outputHost[outputTv.get_tensor_view_idx(output_layout)] = + static_cast(elements[ids[k - 1]]); // k is 1-based + indices[indicesTv.get_tensor_view_idx(indices_layout)] = ids[k - 1]; + } +} + +template +class KthvalueDriver : public 
Driver +{ +public: + KthvalueDriver() : Driver() + { + miopenCreateTensorDescriptor(&inputDesc); + miopenCreateTensorDescriptor(&indicesDesc); + miopenCreateTensorDescriptor(&outputDesc); + + data_type = miopen_type{}; + } + + std::vector ComputeStrides(std::vector input); + int AddCmdLineArgs() override; + int ParseCmdLineArgs(int argc, char* argv[]) override; + InputFlags& GetInputFlags() override { return inflags; } + + int GetandSetData() override; + + int AllocateBuffersAndCopy() override; + + int RunForwardGPU() override; + int RunForwardCPU(); + + int RunBackwardGPU() override; + int RunBackwardCPU(); + + int VerifyBackward() override; + int VerifyForward() override; + ~KthvalueDriver() override + { + miopenDestroyTensorDescriptor(inputDesc); + miopenDestroyTensorDescriptor(indicesDesc); + miopenDestroyTensorDescriptor(outputDesc); + } + +private: + InputFlags inflags; + + miopenTensorDescriptor_t inputDesc; + miopenTensorDescriptor_t indicesDesc; + miopenTensorDescriptor_t outputDesc; + + std::unique_ptr input_dev; + std::unique_ptr indices_dev; + std::unique_ptr output_dev; + + std::vector input; + std::vector indices; + std::vector indicesHost; + std::vector output; + std::vector outputHost; + + bool isContiguous; + int dim; + size_t k; + bool keepDim; +}; + +template +int KthvalueDriver::ParseCmdLineArgs(int argc, char* argv[]) +{ + inflags.Parse(argc, argv); + isContiguous = inflags.GetValueInt("is-contiguous") == 1 ? true : false; + k = inflags.GetValueInt("k"); + dim = inflags.GetValueInt("dim"); + keepDim = inflags.GetValueInt("keep-dim") == 1 ? 
true : false; + auto inDims = inflags.GetValueTensor("dim-lengths").lengths; + int num_dim = inDims.size(); + if(dim < -num_dim || dim >= num_dim) + { + MIOPEN_THROW(miopenStatusBadParm, "Kthvalue: dim doesn't not exist"); + } + + if(inflags.GetValueInt("time") == 1) + { + miopenEnableProfiling(GetHandle(), true); + } + return miopenStatusSuccess; +} + +template +int KthvalueDriver::GetandSetData() +{ + auto inDims = inflags.GetValueTensor("dim-lengths").lengths; + std::vector inStride = ComputeStrides(inDims); + auto outDims = inflags.GetValueTensor("dim-lengths").lengths; + + if(dim < 0) + { + dim += inDims.size(); + } + if(!keepDim) + { + outDims.erase(outDims.begin() + dim); + if(outDims.empty()) + outDims.push_back(1); + } + else + { + outDims[dim] = 1; + } + + SetTensorNd(inputDesc, inDims, inStride, data_type); + SetTensorNd(outputDesc, outDims, data_type); + SetTensorNd(indicesDesc, outDims, miopenInt64); + + return 0; +} + +// Equivalent to: tensor.tranpose(0, -1).contiguous().tranpose(0, -1) incase contiguous = False +template +std::vector KthvalueDriver::ComputeStrides(std::vector inputDim) +{ + if(!isContiguous) + std::swap(inputDim.front(), inputDim.back()); + std::vector strides(inputDim.size()); + strides.back() = 1; + for(int i = inputDim.size() - 2; i >= 0; --i) + strides[i] = strides[i + 1] * inputDim[i + 1]; + if(!isContiguous) + std::swap(strides.front(), strides.back()); + return strides; +} + +template +int KthvalueDriver::AddCmdLineArgs() +{ + inflags.AddInputFlag("forw", 'F', "1", "Run only Forward (Default=1)", "int"); + inflags.AddTensorFlag( + "dim-lengths", 'D', "256x4x2", "The dimensional lengths of the input tensor"); + inflags.AddInputFlag("k", 'k', "1", "k (Default=1)", "int"); + inflags.AddInputFlag("dim", 'd', "-1", "dim (Default=-1)", "int"); + inflags.AddInputFlag("keep-dim", + 'K', + "0", + "Whether the output tensor has dim retained or not (Default=0)", + "int"); + inflags.AddInputFlag("is-contiguous", 'c', "1", "is-contiguous 
(Default=1)", "int"); + inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); + inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int"); + inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int"); + inflags.AddInputFlag( + "wall", 'w', "0", "Wall-clock Time Each Layer, Requires time == 1 (Default=0)", "int"); + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::AllocateBuffersAndCopy() +{ + size_t in_sz = miopen::deref(inputDesc).GetElementSize(); + size_t idx_sz = miopen::deref(indicesDesc).GetElementSize(); + size_t out_sz = miopen::deref(outputDesc).GetElementSize(); + + uint32_t ctx = 0; + + input_dev = std::unique_ptr(new GPUMem(ctx, in_sz, sizeof(TIO))); + indices_dev = std::unique_ptr(new GPUMem(ctx, idx_sz, sizeof(size_t))); + output_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(TIO))); + + input = std::vector(in_sz, static_cast(0)); + indices = std::vector(idx_sz, 0); + indicesHost = std::vector(idx_sz, 0); + output = std::vector(out_sz, static_cast(0)); + outputHost = std::vector(out_sz, static_cast(0)); + + for(int i = 0; i < in_sz; i++) + { + input[i] = prng::gen_A_to_B(static_cast(-10), static_cast(10)); + } + + fill(output.begin(), output.end(), static_cast(0)); + fill(indices.begin(), indices.end(), static_cast(0)); + + if(input_dev->ToGPU(GetStream(), input.data()) != 0) + std::cerr << "Error copying (in) to GPU, size: " << input_dev->GetSize() << std::endl; + + if(indices_dev->ToGPU(GetStream(), indices.data()) != 0) + std::cerr << "Error copying (idx) to GPU, size: " << indices_dev->GetSize() << std::endl; + + if(output_dev->ToGPU(GetStream(), output.data()) != 0) + std::cerr << "Error copying (out) to GPU, size: " << output_dev->GetSize() << std::endl; + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::RunForwardGPU() +{ + float kernel_total_time = 0; + float kernel_first_time = 0; + + Timer t; + START_TIME + + for(int i = 0; i < 
inflags.GetValueInt("iter"); i++) + { + miopenKthvalueForward(GetHandle(), + inputDesc, + input_dev->GetMem(), + outputDesc, + output_dev->GetMem(), + indicesDesc, + (size_t*)indices_dev->GetMem(), + k, + dim, + keepDim); + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + std::cout << "Wall-clock Time Kthvalue Fwd Elapsed: " << t.gettime_ms() / iter << " ms" + << std::endl; + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + std::cout << "GPU Kernel Time Kthvalue Fwd Elapsed: " << kernel_average_time << " ms" + << std::endl; + } + + if(output_dev->FromGPU(GetStream(), output.data()) != 0) + std::cerr << "Error copying (out_dev) from GPU, size: " << output_dev->GetSize() + << std::endl; + + if(indices_dev->FromGPU(GetStream(), indices.data()) != 0) + std::cerr << "Error copying (indices_dev) from GPU, size: " << indices_dev->GetSize() + << std::endl; + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::RunForwardCPU() +{ + mloKthvalueFwdRunHost(input.data(), + inputDesc, + outputHost.data(), + outputDesc, + indicesHost.data(), + indicesDesc, + k, + dim); + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::RunBackwardGPU() +{ + return miopenStatusSuccess; +} + +template +int KthvalueDriver::RunBackwardCPU() +{ + return miopenStatusSuccess; +} + +template +int KthvalueDriver::VerifyForward() +{ + RunForwardCPU(); + + double tolerance = std::numeric_limits::epsilon() * 10; + auto errorOutput = miopen::rms_range(outputHost, output); + + if(!std::isfinite(errorOutput) || errorOutput > tolerance) + { + std::cout << "Forward Kthvalue output FAILED: " << errorOutput << " > " << tolerance + << std::endl; + return EC_VerifyFwd; + } + else + { + std::cout << "Forward Kthvalue 
Verifies OK on CPU reference (" << errorOutput << "< " + << tolerance << ')' << std::endl; + } + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::VerifyBackward() +{ + return miopenStatusSuccess; +} diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 7ed36c72a4..d9e88d1b84 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -7735,6 +7735,40 @@ MIOPEN_EXPORT miopenStatus_t miopenRoPEBackward(miopenHandle_t handle, void* dx); /** @} */ // CLOSEOUT ROPE DOXYGEN GROUP +// kthvalue APIs +/** @addtogroup kthvalue + * + * @{ + */ + +/*! @brief Execute a Kthvalue forward layer + * + * @param handle MIOpen handle (input) + * @param inputDesc Tensor descriptor for input tensor (input) + * @param input Data tensor input (input) + * @param outputDesc Tensor descriptor for output tensor (input) + * @param output Data tensor output (output) + * @param indices Data tensor indices (output) + * @param indicesDesc Tensor descriptor for indices tensor (input) + * @param k The k-th smallest element(input) + * @param dim The dimension to find the kth value along (Default = -1)(input) + * @param keepDim Whether the output tensor has dim retained or not (Default = + * false)(input) + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t outputDesc, + void* output, + miopenTensorDescriptor_t indicesDesc, + size_t* indices, + size_t k, + int32_t dim = -1, + bool keepDim = false); + +/** @} */ +// CLOSEOUT kthvalue DOXYGEN GROUP #endif // MIOPEN_BETA_API #ifdef MIOPEN_BETA_API diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c4ffeede18..870afab0f5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -152,6 +152,8 @@ set( MIOpen_Source getitem/problem_description.cpp kernel_build_params.cpp kernel_warnings.cpp + kthvalue/problem_description.cpp + kthvalue_api.cpp layernorm_api.cpp 
layernorm/problem_description.cpp load_file.cpp @@ -298,6 +300,7 @@ set( MIOpen_Source solver/glu/forward_glu.cpp solver/groupnorm/forward_groupnorm.cpp solver/getitem/backward_getitem.cpp + solver/kthvalue/forward_kthvalue.cpp solver/layernorm/backward_t5layernorm.cpp solver/layernorm/forward_addlayernorm.cpp solver/layernorm/forward_layernorm.cpp @@ -476,6 +479,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/miopen_type_traits.hpp kernels/miopen_utility.hpp kernels/neuron.inc + kernels/radix.hpp kernels/rocm_version.inc kernels/stride_array.hpp kernels/tensor_view.hpp @@ -520,6 +524,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenGLU.cpp kernels/MIOpenGroupNorm.cpp kernels/MIOpenGetitem.cpp + kernels/MIOpenKthvalue.cpp kernels/MIOpenLayerNorm.cpp kernels/MIOpenLRNBwd.cl kernels/MIOpenLRNFwd.cl @@ -656,6 +661,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN getitem.cpp glu.cpp kernel_cache.cpp + kthvalue.cpp layernorm.cpp lrn.cpp mlo_dir_conv.cpp diff --git a/src/include/miopen/kthvalue.hpp b/src/include/miopen/kthvalue.hpp new file mode 100644 index 0000000000..32cb008e0b --- /dev/null +++ b/src/include/miopen/kthvalue.hpp @@ -0,0 +1,48 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef MIOPEN_KTHVALUE_HPP_ +#define MIOPEN_KTHVALUE_HPP_ + +#include + +namespace miopen { + +struct Handle; +struct TensorDescriptor; + +MIOPEN_INTERNALS_EXPORT miopenStatus_t KthvalueForward(Handle& handle, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& outputDesc, + Data_t output, + const TensorDescriptor& indicesDesc, + size_t* indices, + size_t k, + int32_t dim, + bool keepDim); + +} // namespace miopen +#endif // MIOPEN_KTHVALUE_HPP_ diff --git a/src/include/miopen/kthvalue/invoke_params.hpp b/src/include/miopen/kthvalue/invoke_params.hpp new file mode 100644 index 0000000000..701538f9b9 --- /dev/null +++ b/src/include/miopen/kthvalue/invoke_params.hpp @@ -0,0 +1,55 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#pragma once + +#include +#include + +namespace miopen { + +namespace kthvalue { + +struct FwdInvokeParams : public miopen::InvokeParams +{ + FwdInvokeParams() = default; + + const TensorDescriptor* inputDesc = nullptr; + ConstData_t input = nullptr; + const TensorDescriptor* outputDesc = nullptr; + Data_t output = nullptr; + const TensorDescriptor* indicesDesc = nullptr; + size_t* indices = nullptr; + + size_t k = 1; + int32_t dim = 0; + bool keepDim = false; + std::size_t GetWorkspaceSize() const { return 0; } + Data_t GetWorkspace() const { return nullptr; } +}; + +} // namespace kthvalue + +} // namespace miopen diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp new file mode 100644 index 0000000000..d597ed0e45 --- /dev/null +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -0,0 +1,129 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#pragma once + +#include +#include + +namespace miopen { + +struct NetworkConfig; + +namespace kthvalue { + +struct FwdProblemDescription : ProblemDescriptionBase +{ + FwdProblemDescription(const TensorDescriptor& inputDesc_, + const TensorDescriptor& outputDesc_, + const TensorDescriptor& indicesDesc_, + int32_t dim_, + size_t k_, + bool keepDim_) + : inputDesc(inputDesc_), + outputDesc(outputDesc_), + indicesDesc(indicesDesc_), + dim(dim_), + k(k_), + keepDim(keepDim_) + { + if(dim < 0 || dim >= inputDesc.GetNumDims()) // validate dim first: it indexes GetLengths() in the k check below + { + MIOPEN_THROW(miopenStatusBadParm, "Kthvalue: dim doesn't exist"); + } + if(k < 1 || k > inputDesc.GetLengths()[dim]) // k is 1-based and bounded by the length of the selected dim + { + MIOPEN_THROW(miopenStatusBadParm, + "Kthvalue: selected number k out of range for dimension"); + } + if(inputDesc.GetType() != outputDesc.GetType()) + { + MIOPEN_THROW(miopenStatusBadParm, "Kthvalue: Input, output tensor types do not match."); + } + if(!IsRightLength()) + { + MIOPEN_THROW(miopenStatusBadParm, + "Kthvalue: Input and output tensor dimension lengths do not match."); + } + if(outputDesc.GetLengths() != indicesDesc.GetLengths()) + { + MIOPEN_THROW(miopenStatusBadParm, + "Kthvalue: Output and indices tensor dimension lengths do not match."); + } + } + + bool IsRightLength() const // true when output lengths equal input lengths with `dim` removed (or set to 1 if keepDim) + { + if(inputDesc.GetLengths().size() == 1) + return true; + + if(keepDim && inputDesc.GetNumDims() != outputDesc.GetNumDims()) + { + return false; + } + if(!keepDim && inputDesc.GetNumDims() != outputDesc.GetNumDims() + 1) + { + return false; + } + + int32_t posOut = 0; + for(int32_t i = 0; i < inputDesc.GetLengths().size(); i++) + { + if(i == dim) + { + if(!keepDim) + continue; + if(outputDesc.GetLengths()[posOut] != 1) + return false; + } + else if(inputDesc.GetLengths()[i] != outputDesc.GetLengths()[posOut]) + { + return false; + } + + posOut++; + } + return true; + } + + const TensorDescriptor& GetInputDesc() const { return inputDesc; } + const 
TensorDescriptor& GetOutputDesc() const { return outputDesc; } + const TensorDescriptor& GetIndicesDesc() const { return indicesDesc; } + int32_t GetDim() const { return dim; } + size_t GetK() const { return k; } + NetworkConfig MakeNetworkConfig() const override; + +public: + TensorDescriptor inputDesc; + TensorDescriptor outputDesc; + TensorDescriptor indicesDesc; + int32_t dim; + size_t k; + bool keepDim; +}; + +} // namespace kthvalue + +} // namespace miopen diff --git a/src/include/miopen/kthvalue/solvers.hpp b/src/include/miopen/kthvalue/solvers.hpp new file mode 100644 index 0000000000..9c58795730 --- /dev/null +++ b/src/include/miopen/kthvalue/solvers.hpp @@ -0,0 +1,55 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#pragma once + +#include +#include + +namespace miopen { + +namespace solver { + +namespace kthvalue { + +using KthvalueFwdSolverBase = + NonTunableSolverBase; + +struct KthvalueFwd final : KthvalueFwdSolverBase +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::kthvalue::FwdProblemDescription& problem) const override; + + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::kthvalue::FwdProblemDescription& problem) const override; +}; + +} // namespace kthvalue + +} // namespace solver + +} // namespace miopen diff --git a/src/include/miopen/reduce/problem_description.hpp b/src/include/miopen/reduce/problem_description.hpp index 7e327ba565..294c79777e 100644 --- a/src/include/miopen/reduce/problem_description.hpp +++ b/src/include/miopen/reduce/problem_description.hpp @@ -239,7 +239,7 @@ struct ProblemDescriptionCalculation : ProblemDescriptionBase bool IsValidDim() const { - if((dim < 0) || (dim > xDesc.GetLengths().size())) + if((dim < 0) || (dim >= xDesc.GetLengths().size())) { MIOPEN_THROW( miopenStatusBadParm, diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp index ab824faa32..177d2f9c71 100644 --- a/src/include/miopen/solver_id.hpp +++ b/src/include/miopen/solver_id.hpp @@ -61,7 +61,8 @@ enum class Primitive Adam, Item, RoPE, - ReLU + ReLU, + Kthvalue, }; struct MIOPEN_INTERNALS_EXPORT Id diff --git a/src/include/miopen/tensor_view_utils.hpp b/src/include/miopen/tensor_view_utils.hpp index 1b095affb7..b92f3b1b2d 100644 --- a/src/include/miopen/tensor_view_utils.hpp +++ b/src/include/miopen/tensor_view_utils.hpp @@ -42,7 +42,12 @@ inline tensor_view_t get_inner_expanded_tv(const TensorDescriptor Desc) tensor_view_t tensor_view{}; for(size_t i = 0; i < N; ++i) { - if(i < dims.size()) + if(dims.empty()) + { + 
tensor_view.stride[i] = 0; + tensor_view.size[i] = 0; + } + else if(i < dims.size()) { tensor_view.stride[i] = strides[i]; tensor_view.size[i] = dims[i]; @@ -76,6 +81,28 @@ inline void slice_tv(tensor_view_t& tensor_view, int32_t sliceCount, const in } } +template +inline tensor_view_t get_tv_without_dim(const tensor_view_t& origin_tv, int selected_dim) +{ + tensor_view_t res{}; + for(int i = 0; i < N; ++i) + { + if(i == selected_dim) + continue; + if(i < selected_dim) + { + res.size[i] = origin_tv.size[i]; + res.stride[i] = origin_tv.stride[i]; + } + else + { + res.size[i - 1] = origin_tv.size[i]; + res.stride[i - 1] = origin_tv.stride[i]; + } + } + return res; +} + } // namespace miopen #endif // MIOPEN_TENSOR_VIEW_UTIL_HPP_ diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp new file mode 100644 index 0000000000..624308b017 --- /dev/null +++ b/src/kernels/MIOpenKthvalue.cpp @@ -0,0 +1,197 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#include +#endif + +#include "float_types.h" +#include "tensor_view.hpp" +#include "radix.hpp" + +#ifndef IN_OUT_TYPE +#define IN_OUT_TYPE float +#endif + +#ifndef LOCAL_SIZE +#define LOCAL_SIZE 256 +#endif + +template +__device__ void kthvalueFwd(const DTYPE* input, + DTYPE* output, + size_t* indices, + size_t k, + size_t dim_size, + size_t dim_stride, + size_t output_size, + tensor_view_t<4> input_tv, + tensor_view_t<5> output_tv, + tensor_view_t<5> indices_tv) +{ + /* + * Input : {N, C, D, H, W}. Select dim: 2(D) + * Output/indices : {N, C, H, W} or {N, C, 1, H, W} (if keepDim param in miopen.h = True) + * Each lws handle dim_size elements to find the kth value. 
+ * Lws = {256 or 512, 1, 1} + * Gws = {A * B * D * E * lws.x, 1, 1}, + */ + using RADIX_TYPE = typename RadixType::type; + + const int RADIX_BITS = 2; + const int RADIX_SIZE = 1 << RADIX_BITS; + const int RADIX_MASK = RADIX_SIZE - 1; + + size_t lid = threadIdx.x; + size_t gid = blockIdx.x; + if(gid >= output_size) + { + return; + } + + __shared__ size_t lsum[LOCAL_SIZE][RADIX_SIZE]; + __shared__ DTYPE lval; + __shared__ size_t lidx; + size_t counts[RADIX_SIZE]; + RADIX_TYPE desired_mask = 0; + RADIX_TYPE desired = 0; + + tensor_layout_t<4> layout(input_tv, gid); + auto idx = input_tv.get_tensor_view_idx(layout); + + for(int pos = sizeof(RADIX_TYPE) * 8 - RADIX_BITS; pos >= 0; pos -= RADIX_BITS) + { + for(size_t& count : counts) + { + count = 0; + } + + for(size_t i = lid; i < dim_size; i += LOCAL_SIZE) + { + size_t input_idx = idx + i * dim_stride; + RADIX_TYPE val = encode(input[input_idx]); + if((val & desired_mask) == desired) + { + ++counts[GetBitFieldImpl(val, pos)]; + } + } + + for(size_t i = 0; i < RADIX_SIZE; ++i) + { + lsum[lid][i] = counts[i]; + } + __syncthreads(); + for(size_t i = LOCAL_SIZE >> 1; i > 0; i >>= 1) + { + if(lid < i) + { + for(size_t j = 0; j < RADIX_SIZE; ++j) + { + lsum[lid][j] += lsum[lid + i][j]; + } + } + __syncthreads(); + } + for(size_t i = 0; i < RADIX_SIZE; ++i) + { + counts[i] = lsum[0][i]; + } + __syncthreads(); + + bool found = false; + // Process in ascending order + for(size_t j = 0; j < RADIX_SIZE; ++j) + { + if(counts[j] < k) + { + k -= counts[j]; + continue; + } + // Answer is inside this count + if(counts[j] == 1 || pos == 0) + { + // 1. counts[j] == 1 + // We found an unique answer. + // 2. 
pos == 0 + // There are multiple answers so we return any of them + for(size_t i = lid; i < dim_size; i += LOCAL_SIZE) + { + size_t input_idx = idx + i * dim_stride; + DTYPE val_ori = input[input_idx]; + RADIX_TYPE val = encode(val_ori); + if((val & desired_mask) == desired && + GetBitFieldImpl(val, pos) == j) + { + // For case 2, this will be non-deterministic. + lval = val_ori; + lidx = i; + } + } + found = true; + break; + } + desired = SetBitFieldImpl(desired, j, pos); + desired_mask = SetBitFieldImpl(desired_mask, RADIX_MASK, pos); + break; + } + if(found) + break; + } + + __syncthreads(); + if(lid == 0) + { + auto output_layout = tensor_layout_t<5>(output_tv, gid); + auto indices_layout = tensor_layout_t<5>(indices_tv, gid); + output[output_tv.get_tensor_view_idx(output_layout)] = lval; + indices[indices_tv.get_tensor_view_idx(indices_layout)] = lidx; + } +} + +extern "C" __global__ void KthvalueFwd(const IN_OUT_TYPE* input, + IN_OUT_TYPE* output, + size_t* indices, + size_t k, + size_t dim_size, + size_t dim_stride, + size_t output_size, + tensor_view_t<4> input_tv, + tensor_view_t<5> output_tv, + tensor_view_t<5> indices_tv) +{ + kthvalueFwd(input, + output, + indices, + k, + dim_size, + dim_stride, + output_size, + input_tv, + output_tv, + indices_tv); +} diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp new file mode 100644 index 0000000000..f75443e149 --- /dev/null +++ b/src/kernels/radix.hpp @@ -0,0 +1,105 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef GUARD_RADIX_H +#define GUARD_RADIX_H + +#include +#include +#include + +#define DEFINE_RADIX_TYPE(DTYPE, cpp_type) \ + template <> \ + struct RadixType \ + { \ + using type = cpp_type; \ + }; + +template +struct RadixType +{ +}; + +DEFINE_RADIX_TYPE(int32_t, uint32_t) +DEFINE_RADIX_TYPE(int64_t, uint64_t) +DEFINE_RADIX_TYPE(bool, bool) +DEFINE_RADIX_TYPE(float, uint32_t) +DEFINE_RADIX_TYPE(__half, ushort) +DEFINE_RADIX_TYPE(ushort, ushort) // bfloat16 +#undef DEFINE_RADIX_TYPE + +template ::type> +__device__ inline Radix encode(DTYPE v) +{ + // convert negative number to positive representation in Radix type. 
+ if constexpr(std::is_same::value) + { + return v; + } + else if constexpr(std::is_same::value) + { + return static_cast(std::numeric_limits::max()) + v + 1; + } + else if constexpr(std::is_same::value) + { + return static_cast(std::numeric_limits::max()) + v + 1; + } + // bfloat16 is passed as ushort in kernel + else if constexpr(std::is_same::value) + { + Radix x = v; + Radix mask = (x & 0x8000) ? 0xffff : 0x8000; + return isnan(v) ? 0xffff : (x ^ mask); + } + else if constexpr(std::is_same<__half, DTYPE>::value) + { + Radix x = __half_as_ushort(v); + Radix mask = (x & 0x8000) ? 0xffff : 0x8000; + return __hisnan(v) ? 0xffff : (x ^ mask); + } + else if constexpr(std::is_same::value) + { + Radix x = __float_as_uint(v); + Radix mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; + return isnan(v) ? 0xffffffff : (x ^ mask); + } +} + +// returns x[pos+bits:pos] +template +__device__ inline Radix GetBitFieldImpl(Radix x, int pos) +{ + return (x >> pos) & ((1 << bits) - 1); +} + +// x[pos+bits:pos] = a +template +__device__ inline Radix SetBitFieldImpl(Radix x, Radix a, int pos) +{ + return x | (a << pos); +} + +#endif // GUARD_RADIX_H diff --git a/src/kthvalue.cpp b/src/kthvalue.cpp new file mode 100644 index 0000000000..a9f4e73067 --- /dev/null +++ b/src/kthvalue.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +namespace miopen { + +miopenStatus_t KthvalueForward(Handle& handle, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& outputDesc, + Data_t output, + const TensorDescriptor& indicesDesc, + size_t* indices, + size_t k, + int32_t dim, + bool keepDim) +{ + if(dim < 0) + { + dim += inputDesc.GetNumDims(); + } + + const auto problem = + kthvalue::FwdProblemDescription{inputDesc, outputDesc, indicesDesc, dim, k, keepDim}; + + const auto invoke_params = [&]() { + auto tmp = kthvalue::FwdInvokeParams{}; + tmp.inputDesc = &inputDesc; + tmp.outputDesc = &outputDesc; + tmp.indicesDesc = &indicesDesc; + tmp.input = input; + tmp.indices = indices; + tmp.output = output; + tmp.k = k; + tmp.dim = dim; + tmp.keepDim = keepDim; + return tmp; + }(); + + const auto algo = AlgorithmName{"KthvalueFwd"}; + const auto solvers = solver::SolverContainer{}; + + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +} // namespace miopen diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp new file mode 100644 index 0000000000..9f3cec4f64 --- /dev/null +++ b/src/kthvalue/problem_description.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include + +#include + +namespace miopen { + +namespace kthvalue { + +NetworkConfig FwdProblemDescription::MakeNetworkConfig() const +{ + auto input_dtype = inputDesc.GetType(); + auto size = inputDesc.GetElementSize(); + auto dim_size = inputDesc.GetLengths()[dim]; + auto dim_stride = inputDesc.GetStrides()[dim]; + int dim_num = inputDesc.GetNumDims(); + auto output_size = size / dim_size; + + std::ostringstream ss; + + ss << "kthvalue_fwd"; + ss << "i_dtype" << input_dtype; + ss << "dim_size" << dim_size; + ss << "dim_num" << dim_num; + ss << "dim_stride" << dim_stride; + ss << "output_size" << output_size; + + return NetworkConfig{ss.str()}; +} + +} // namespace kthvalue + +} // namespace miopen diff --git a/src/kthvalue_api.cpp b/src/kthvalue_api.cpp new file mode 100644 index 0000000000..03405e845e --- /dev/null +++ b/src/kthvalue_api.cpp @@ -0,0 +1,101 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +#include + +inline std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + os << '{'; + for(int i = 0; i < v.size(); ++i) + { + if(i != 0) + os << ','; + os << v[i]; + } + os << '}'; + return os; +} + +static void LogCmdKthvalue(const miopenTensorDescriptor_t inputDesc, bool is_fwd) +{ + if(miopen::IsLoggingCmd()) + { + std::stringstream ss; + auto dtype = miopen::deref(inputDesc).GetType(); + if(dtype == miopenHalf) + { + ss << "kthvaluefp16"; + } + else if(dtype == miopenFloat) + { + ss << "kthvaluefp32"; + } + else if(dtype == miopenBFloat16) + { + ss << "kthvaluebfp16"; + } + + MIOPEN_LOG_FUNCTION(inputDesc); + ss << " -n " << miopen::deref(inputDesc).GetLengths()[0]; + ss << " -T " << miopen::deref(inputDesc).GetLengths(); + ss << " -Si " << miopen::deref(inputDesc).GetStrides(); + ss << " -F " << ((is_fwd) ? 
"1" : "2"); + + MIOPEN_LOG_DRIVER_CMD(ss.str()); + } +} + +extern "C" miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t outputDesc, + void* output, + miopenTensorDescriptor_t indicesDesc, + size_t* indices, + size_t k, + int32_t dim, + bool keepDim) +{ + MIOPEN_LOG_FUNCTION( + handle, inputDesc, input, outputDesc, output, indicesDesc, indices, k, dim, keepDim); + + LogCmdKthvalue(inputDesc, true); + + return miopen::try_([&] { + miopen::KthvalueForward(miopen::deref(handle), + miopen::deref(inputDesc), + DataCast(input), + miopen::deref(outputDesc), + DataCast(output), + miopen::deref(indicesDesc), + indices, + k, + dim, + keepDim); + }); +} diff --git a/src/solver.cpp b/src/solver.cpp index 1149255363..07723008fc 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -679,6 +680,7 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::RoPE, rope::RoPEBackward{}.SolverDbId()); Register(registry, ++id, Primitive::ReLU, prelu::MultiWeightsBackward{}.SolverDbId()); Register(registry, ++id, Primitive::ReLU, prelu::SingleWeightBackward{}.SolverDbId()); + Register(registry, ++id, Primitive::Kthvalue, kthvalue::KthvalueFwd{}.SolverDbId()); Register(registry, ++id, Primitive::Activation, glu::GLUForward{}.SolverDbId()); Register(registry, ++id, Primitive::Activation, glu::GLUBackward{}.SolverDbId()); diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp new file mode 100644 index 0000000000..2639a41b01 --- /dev/null +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -0,0 +1,143 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace miopen { + +namespace solver { + +namespace kthvalue { + +bool IsImprovementOverROCm(const miopen::kthvalue::FwdProblemDescription& problem) +{ + TensorDescriptor inputDesc = problem.GetInputDesc(); + size_t dimSize = inputDesc.GetLengths()[problem.GetDim()]; + size_t dimStride = inputDesc.GetStrides()[problem.GetDim()]; + size_t dimNum = inputDesc.GetLengths().size(); + + return dimNum >= 2 && dimStride == 1 && dimSize >= 300; +} + +bool KthvalueFwd::IsApplicable(const ExecutionContext& /*context*/, + const miopen::kthvalue::FwdProblemDescription& problem) const +{ + if(!IsImprovementOverROCm(problem)) + return false; + if(problem.GetInputDesc().GetNumDims() > 5) + return false; + return true; +} + +ConvSolution KthvalueFwd::GetSolution(const ExecutionContext& context, + const miopen::kthvalue::FwdProblemDescription& problem) const +{ + std::ignore = context; + auto result = ConvSolution{miopenStatusSuccess}; + + auto input_desc = problem.GetInputDesc(); + auto in_dtype = miopen::GetDataType(input_desc.GetType()); + auto dtype = problem.GetOutputDesc().GetType(); + auto size = input_desc.GetElementSize(); + auto dim_size = input_desc.GetLengths()[problem.GetDim()]; + size_t output_size = size / dim_size; + + size_t xlocalsize = 256; + if(dim_size >= 8192) + { + xlocalsize = 512; + } + size_t xgridsize = output_size * xlocalsize; + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + + kernel.kernel_file = "MIOpenKthvalue.cpp"; + kernel.kernel_name = "KthvalueFwd"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"IN_OUT_TYPE", 
in_dtype == "bfloat16" ? "ushort" : in_dtype}, + {"LOCAL_SIZE", xlocalsize}, + }; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + + result.invoker_factory = [=](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + size_t dim_stride = params.inputDesc->GetStrides()[params.dim]; + + auto input_tv = get_inner_expanded_tv<5>(deref(params.inputDesc)); + auto input_tv_without_selected_dim = get_tv_without_dim<5>(input_tv, params.dim); + + auto output_tv = get_inner_expanded_tv<5>(deref(params.outputDesc)); + auto indices_tv = get_inner_expanded_tv<5>(deref(params.indicesDesc)); + + kernel(params.input, + params.output, + params.indices, + params.k, + dim_size, + dim_stride, + output_size, + input_tv_without_selected_dim, + output_tv, + indices_tv); + }; + }; + + return result; +} + +} // namespace kthvalue + +} // namespace solver + +} // namespace miopen diff --git a/test/cpu_kthvalue.hpp b/test/cpu_kthvalue.hpp new file mode 100644 index 0000000000..e1260a29f1 --- /dev/null +++ b/test/cpu_kthvalue.hpp @@ -0,0 +1,81 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#pragma once + +#include "tensor_holder.hpp" +#include "tensor_view.hpp" + +#include + +#include + +template +void cpu_kthvalue(tensor input, + tensor& outputHost, + std::vector& indices, + miopen::TensorDescriptor indiceDesc, + size_t k, + int dim) +{ + size_t inputSize = input.desc.GetElementSize(); + size_t dimSize = input.desc.GetLengths()[dim]; + size_t dimStride = input.desc.GetStrides()[dim]; + auto inputTv = miopen::get_inner_expanded_tv<5>(input.desc); + auto inputTvWithoutDim = miopen::get_tv_without_dim<5>(inputTv, dim); + auto outputTv = miopen::get_inner_expanded_tv<5>(outputHost.desc); + auto indicesTv = miopen::get_inner_expanded_tv<5>(indiceDesc); + + size_t numSlice = inputSize / dimSize; + + std::vector elements; + std::vector ids(dimSize); + for(int i = 0; i < dimSize; ++i) + { + ids[i] = i; + } + + for(int slideID = 0; slideID < numSlice; ++slideID) + { + elements.clear(); + tensor_layout_t<4> layout(inputTvWithoutDim, slideID); + auto idx = inputTvWithoutDim.get_tensor_view_idx(layout); + + for(int j = 0; j < dimSize; ++j) + { + elements.push_back(static_cast(input[idx + j * dimStride])); + } + + std::sort(ids.begin(), ids.end(), [=](size_t x, size_t y) -> bool { + return elements[x] < elements[y]; + }); + auto output_layout = tensor_layout_t<5>(outputTv, slideID); + auto indices_layout = tensor_layout_t<5>(indicesTv, slideID); + outputHost[outputTv.get_tensor_view_idx(output_layout)] = + static_cast(elements[ids[k - 1]]); + indices[indicesTv.get_tensor_view_idx(indices_layout)] = ids[k - 1]; + } +} diff --git a/test/gtest/kthvalue.cpp b/test/gtest/kthvalue.cpp new file mode 100644 index 0000000000..0a08a25288 --- /dev/null +++ b/test/gtest/kthvalue.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#include "kthvalue.hpp" +#include "tensor_holder.hpp" +#include +#include + +MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL) + +namespace kthvalue { + +std::string GetFloatArg() +{ + const auto& tmp = env::value(MIOPEN_TEST_FLOAT_ARG); + if(tmp.empty()) + { + return ""; + } + return tmp; +} + +bool CheckFloatArg(std::string arg) +{ + if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && GetFloatArg() == arg)) + { + return true; + } + return false; +} + +struct GPU_Kthvalue_fwd_FP32 : KthvalueFwdTest +{ +}; + +struct GPU_Kthvalue_fwd_FP16 : KthvalueFwdTest +{ +}; + +struct GPU_Kthvalue_fwd_BFP16 : KthvalueFwdTest +{ +}; +}; // namespace kthvalue + +using namespace kthvalue; + +TEST_P(GPU_Kthvalue_fwd_FP32, Test) +{ + if(CheckFloatArg("--float")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +TEST_P(GPU_Kthvalue_fwd_FP16, Test) +{ + if(CheckFloatArg("--half")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +TEST_P(GPU_Kthvalue_fwd_BFP16, Test) +{ + if(CheckFloatArg("--bfloat16")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP32, testing::ValuesIn(KthvalueTestConfigs())); +INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP16, testing::ValuesIn(KthvalueTestConfigs())); +INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_BFP16, testing::ValuesIn(KthvalueTestConfigs())); diff --git a/test/gtest/kthvalue.hpp b/test/gtest/kthvalue.hpp new file mode 100644 index 0000000000..2aa7e6fd41 --- /dev/null +++ b/test/gtest/kthvalue.hpp @@ -0,0 +1,194 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
// test/gtest/kthvalue.hpp

/// One parameter set for the kthvalue forward test.
struct KthvalueTestCase
{
    std::vector<size_t> dims;  // input lengths
    bool isContiguous = true;  // false: strides are built with first and last dimension swapped
    int32_t dim       = -1;    // dimension to select along; negative counts from the back
    size_t k          = 1;     // 1-based rank of the element to select
    bool keepDim      = false; // keep the selected dimension in the output with length 1

    friend std::ostream& operator<<(std::ostream& os, const KthvalueTestCase& tc)
    {
        os << "dims: ";
        for(auto dim_size : tc.dims)
        {
            os << dim_size << " ";
        }
        return os << "is_contiguous " << tc.isContiguous << " selected_dim " << tc.dim << " k "
                  << tc.k << " keepDim " << tc.keepDim;
    }

    std::vector<size_t> GetDims() const { return dims; }

    // The in-class initializers keep a default-constructed case fully defined, so printing it
    // never reads indeterminate members.
    KthvalueTestCase() = default;

    KthvalueTestCase(std::vector<size_t> dims_,
                     size_t k_,
                     int32_t dim_       = -1,
                     bool isContiguous_ = true,
                     bool keepDim_      = false)
        : dims(std::move(dims_)), isContiguous(isContiguous_), dim(dim_), k(k_), keepDim(keepDim_)
    {
    }

    /// Returns strides for `inputDim`: packed row-major when the case is contiguous. Otherwise
    /// packed strides are computed with the first and last lengths swapped and the first and
    /// last strides are swapped back. An empty `inputDim` yields an empty result.
    std::vector<size_t> ComputeStrides(std::vector<size_t> inputDim) const
    {
        if(inputDim.empty())
            return {};

        if(!isContiguous)
            std::swap(inputDim.front(), inputDim.back());

        std::vector<size_t> strides(inputDim.size());
        strides.back() = 1;
        for(size_t i = inputDim.size() - 1; i-- > 0;)
            strides[i] = strides[i + 1] * inputDim[i + 1];

        if(!isContiguous)
            std::swap(strides.front(), strides.back());
        return strides;
    }
};

/// Smoke configurations: 2D to 5D inputs, contiguous and non-contiguous, plus one keepDim case.
inline std::vector<KthvalueTestCase> KthvalueTestConfigs()
{
    return {
        KthvalueTestCase({100, 500}, 10, 1, true, true),     // test keep dim
        KthvalueTestCase({100, 500}, 10),                    // 2D cont
        KthvalueTestCase({400, 10}, 10, 0, false),           // 2D non-cont
        KthvalueTestCase({10, 20, 300}, 1),                  // 3D cont
        KthvalueTestCase({350, 10, 20}, 5, 0, false),        // 3D non-cont
        KthvalueTestCase({8, 3, 10, 2000}, 2000),            // 4D cont
        KthvalueTestCase({1000, 3, 10, 15}, 1000, 0, false), // 4D non-cont
        KthvalueTestCase({2, 2, 4, 10, 3000}, 120),          // 5D cont
        KthvalueTestCase({3000, 8, 2, 4, 20}, 9, 0, false),  // 5D non-cont
    };
}
+template +struct KthvalueFwdTest : public ::testing::TestWithParam +{ +protected: + void SetUp() override + { + auto&& handle = get_handle(); + config = GetParam(); + + auto inDims = config.GetDims(); + auto inStrides = config.ComputeStrides(inDims); + if(config.dim < 0) + { + config.dim += inDims.size(); + } + EXPECT_TRUE(config.dim >= 0 and config.dim < inDims.size()); + auto outDims = config.GetDims(); + if(!config.keepDim) + { + outDims.erase(outDims.begin() + config.dim); + if(outDims.empty()) + outDims.push_back(1); + } + else + { + outDims[config.dim] = 1; + } + + auto in_gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(0.1, 200); }; + input = tensor{inDims, inStrides}.generate(in_gen_value); + + output = tensor{outDims}; + std::fill(output.begin(), output.end(), 0); + + outputHost = tensor{outDims}; + std::fill(outputHost.begin(), outputHost.end(), 0); + + // miopenDataType_t doesn't support size_t, I use double instead (both types use 64 bits) + indicesDesc = miopen::TensorDescriptor(miopenDouble, outDims); + size_t outputSize = indicesDesc.GetElementSize(); + indices.resize(outputSize); + indicesHost.resize(outputSize); + + input_dev = handle.Write(input.data); + indices_dev = handle.Write(indices); + output_dev = handle.Write(output.data); + } + + void RunTest() + { + auto&& handle = get_handle(); + + miopenStatus_t status; + + status = miopen::KthvalueForward(handle, + input.desc, + input_dev.get(), + output.desc, + output_dev.get(), + indicesDesc, + (size_t*)indices_dev.get(), + config.k, + config.dim, + config.keepDim); + cpu_kthvalue(input, outputHost, indicesHost, indicesDesc, config.k, config.dim); + + EXPECT_EQ(status, miopenStatusSuccess); + output.data = handle.Read(output_dev, output.data.size()); + indices = handle.Read(indices_dev, indices.size()); + } + + void Verify() + { + double threshold = std::numeric_limits::epsilon(); + + auto error = miopen::rms_range(outputHost, output); + + 
EXPECT_TRUE(miopen::range_distance(outputHost) == miopen::range_distance(output)); + EXPECT_TRUE(error < threshold * 10) << "Error output beyond tolerance Error: " << error + << ", Thresholdx10: " << threshold * 10; + } + KthvalueTestCase config; + + tensor input; + // tensor holder doesn't support size_t, so I use vector instead + miopen::TensorDescriptor indicesDesc; + std::vector indices; + tensor output; + + tensor outputHost; + std::vector indicesHost; + + miopen::Allocator::ManageDataPtr input_dev; + miopen::Allocator::ManageDataPtr indices_dev; + miopen::Allocator::ManageDataPtr output_dev; +};