From e154ccaf1868dd420a26449436370ae6f9565598 Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Tue, 20 Aug 2024 09:42:51 +0700 Subject: [PATCH 01/28] rebase skeleton code with develop --- include/miopen/miopen.h | 33 +++ src/CMakeLists.txt | 6 + src/include/miopen/kthvalue.hpp | 48 +++++ src/include/miopen/kthvalue/invoke_params.hpp | 67 ++++++ .../miopen/kthvalue/problem_description.hpp | 64 ++++++ src/include/miopen/kthvalue/solvers.hpp | 65 ++++++ src/include/miopen/solver_id.hpp | 3 +- src/include/miopen/tensor_view_utils.hpp | 29 ++- src/kernels/MIOpenKthvalue.cpp | 193 ++++++++++++++++++ src/kernels/radix.hpp | 136 ++++++++++++ src/kthvalue.cpp | 87 ++++++++ src/kthvalue/problem_description.cpp | 54 +++++ src/kthvalue_api.cpp | 112 ++++++++++ src/solver.cpp | 2 + src/solver/kthvalue/forward_kthvalue.cpp | 133 ++++++++++++ 15 files changed, 1030 insertions(+), 2 deletions(-) create mode 100644 src/include/miopen/kthvalue.hpp create mode 100644 src/include/miopen/kthvalue/invoke_params.hpp create mode 100644 src/include/miopen/kthvalue/problem_description.hpp create mode 100644 src/include/miopen/kthvalue/solvers.hpp create mode 100644 src/kernels/MIOpenKthvalue.cpp create mode 100644 src/kernels/radix.hpp create mode 100644 src/kthvalue.cpp create mode 100644 src/kthvalue/problem_description.cpp create mode 100644 src/kthvalue_api.cpp create mode 100644 src/solver/kthvalue/forward_kthvalue.cpp diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 6c66f04867..2c6819181c 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -7675,6 +7675,39 @@ MIOPEN_EXPORT miopenStatus_t miopenRoPEBackward(miopenHandle_t handle, void* dx); /** @} */ // CLOSEOUT ROPE DOXYGEN GROUP +// kthvalue APIs +/** @addtogroup kthvalue + * + * @{ + */ + +/*! @brief Execute a Kthvalue forward layer + * + * @param handle MIOpen handle (input) + * @param workspace Address of the allocated workspace data (input) + * @param workspaceSizeInBytes Size in bytes of the allocated workspace data (input) + * @param inputDesc Tensor descriptor for input tensor (input) + * @param input Data tensor input (input) + * @param outputDesc Tensor descriptor for output tensor (input) + * @param output Data tensor output (output) + * @param indices Data tensor index (output) + * @param k The k-th smallest element(input) + * @param dim The dimension to find the kth value along(input) + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, + void* workspace, + size_t workspaceSizeInBytes, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t outputDesc, + void* output, + size_t* indices, + size_t k, + int32_t dim); + +/** @} */ +// CLOSEOUT kthvalue DOXYGEN GROUP #endif // MIOPEN_BETA_API #ifdef __cplusplus diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index d2f84f43b8..ef06c9383f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -150,6 +150,8 @@ set( MIOpen_Source getitem/problem_description.cpp kernel_build_params.cpp kernel_warnings.cpp + kthvalue/problem_description.cpp + kthvalue_api.cpp layernorm_api.cpp layernorm/problem_description.cpp load_file.cpp @@ -289,6 +291,7 @@ set( MIOpen_Source solver/conv_winoRxS_fused.cpp solver/groupnorm/forward_groupnorm.cpp solver/getitem/backward_getitem.cpp + solver/kthvalue/forward_kthvalue.cpp solver/layernorm/backward_t5layernorm.cpp solver/layernorm/forward_addlayernorm.cpp solver/layernorm/forward_layernorm.cpp @@ -462,6 +465,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/miopen_type_traits.hpp kernels/miopen_utility.hpp kernels/neuron.inc + kernels/radix.hpp kernels/rocm_version.inc kernels/stride_array.hpp kernels/tensor_view.hpp @@ -504,6 +508,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenConvDirGenFwd.cl kernels/MIOpenGroupNorm.cpp kernels/MIOpenGetitem.cpp + kernels/MIOpenKthvalue.cpp kernels/MIOpenLayerNorm.cpp kernels/MIOpenLRNBwd.cl kernels/MIOpenLRNFwd.cl @@ -660,6 +665,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN pooling.cpp t5layernorm.cpp ocl/fusionopconvocl.cpp + kthvalue.cpp ocl/fusionopbiasbnactivocl.cpp reducecalculation.cpp reduceextreme.cpp diff --git a/src/include/miopen/kthvalue.hpp b/src/include/miopen/kthvalue.hpp new file mode 100644 index 0000000000..2ebcbd0847 --- /dev/null +++ b/src/include/miopen/kthvalue.hpp @@ -0,0 +1,48 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef MIOPEN_KTHVALUE_HPP_ +#define MIOPEN_KTHVALUE_HPP_ + +#include + +namespace miopen { + +struct Handle; +struct TensorDescriptor; + +miopenStatus_t KthvalueForward(Handle& handle, + Data_t workspace, + size_t workspaceSizeInBytes, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& outputDesc, + Data_t output, + size_t* indices, + size_t k, + int32_t dim); + +} // namespace miopen +#endif // MIOPEN_KTHVALUE_HPP_ diff --git a/src/include/miopen/kthvalue/invoke_params.hpp b/src/include/miopen/kthvalue/invoke_params.hpp new file mode 100644 index 0000000000..318f28b415 --- /dev/null +++ b/src/include/miopen/kthvalue/invoke_params.hpp @@ -0,0 +1,67 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#pragma once + +#include "miopen/common.hpp" +#include "miopen/miopen.h" +#include +#include + +#include + +namespace miopen { + +namespace kthvalue { + +struct KthvalueInvokeParams : public miopen::InvokeParams +{ + KthvalueInvokeParams() = default; + + const TensorDescriptor* inputDesc = nullptr; + + Data_t workspace = nullptr; + std::size_t workspace_size = 0; + ConstData_t input = nullptr; + size_t* indices = nullptr; + + size_t k = 1; + int32_t dim = 0; + + std::size_t GetWorkspaceSize() const { return workspace_size; } + Data_t GetWorkspace() const { return workspace; } +}; + +struct FwdInvokeParams : KthvalueInvokeParams +{ + FwdInvokeParams() = default; + + const TensorDescriptor* outputDesc = nullptr; + Data_t output = nullptr; +}; + +} // namespace kthvalue + +} // namespace miopen diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp new file mode 100644 index 0000000000..92ff989ea6 --- /dev/null +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -0,0 +1,64 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#pragma once + +#include "miopen/errors.hpp" +#include "miopen/miopen.h" +#include +#include + +#include +#include + +namespace miopen { + +struct NetworkConfig; + +namespace kthvalue { + +struct KthvalueFwdProblemDescription : ProblemDescriptionBase +{ + KthvalueFwdProblemDescription(const TensorDescriptor& inputDesc_, + const TensorDescriptor& outputDesc_, + int32_t dim_) + : inputDesc(inputDesc_), outputDesc(outputDesc_), dim(dim_) + { + } + + const TensorDescriptor& GetInputDesc() const { return inputDesc; } + const TensorDescriptor& GetOutputDesc() const { return outputDesc; } + int32_t GetDim() const { return dim; } + NetworkConfig MakeNetworkConfig() const override; + +public: + TensorDescriptor inputDesc; + TensorDescriptor outputDesc; + int32_t dim; +}; + +} // namespace kthvalue + +} // namespace miopen diff --git a/src/include/miopen/kthvalue/solvers.hpp b/src/include/miopen/kthvalue/solvers.hpp new file mode 100644 index 0000000000..ab17daae83 --- /dev/null +++ b/src/include/miopen/kthvalue/solvers.hpp @@ -0,0 +1,65 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace miopen { + +namespace solver { + +namespace kthvalue { + +using KthvalueFwdSolverBase = + NonTunableSolverBase; + +struct KthvalueFwd final : KthvalueFwdSolverBase +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool + IsApplicable(const ExecutionContext& context, + const miopen::kthvalue::KthvalueFwdProblemDescription& problem) const override; + + ConvSolution + GetSolution(const ExecutionContext& context, + const miopen::kthvalue::KthvalueFwdProblemDescription& problem) const override; + + std::size_t + GetWorkspaceSize(const ExecutionContext& context, + const miopen::kthvalue::KthvalueFwdProblemDescription& problem) const override; + + bool MayNeedWorkspace() const override { return true; } +}; + +} // namespace kthvalue + +} // namespace solver + +} // namespace miopen diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp index fb81bafb5c..b76450646e 100644 --- a/src/include/miopen/solver_id.hpp +++ b/src/include/miopen/solver_id.hpp @@ -60,7 +60,8 @@ enum class Primitive Softmax, Adam, Item, - RoPE + RoPE, + Kthvalue, }; struct MIOPEN_INTERNALS_EXPORT Id diff --git a/src/include/miopen/tensor_view_utils.hpp b/src/include/miopen/tensor_view_utils.hpp index 9f7430ba8a..12b57e2055 100644 --- a/src/include/miopen/tensor_view_utils.hpp +++ b/src/include/miopen/tensor_view_utils.hpp @@ -41,7 +41,12 @@ inline tensor_view_t get_inner_expanded_tv(const TensorDescriptor Desc) tensor_view_t tensor_view; for(size_t i = 0; i < N; ++i) { - if(i < dims.size()) + if(dims.empty()) + { + tensor_view.stride[i] = 0; + tensor_view.size[i] = 0; + } + else if(i < dims.size()) { tensor_view.stride[i] = strides[i]; tensor_view.size[i] = dims[i]; @@ -75,6 +80,28 @@ inline void slice_tv(tensor_view_t& tensor_view, int32_t sliceCount, const in } } +template +inline tensor_view_t get_tv_without_dim(const tensor_view_t& origin_tv, int selected_dim) +{ + tensor_view_t res; + for(int i = 0; i < N; ++i) + { + if(i == selected_dim) + continue; + if(i < selected_dim) + { + res.size[i] = origin_tv.size[i]; + res.stride[i] = origin_tv.stride[i]; + } + else + { + res.size[i - 1] = origin_tv.size[i]; + res.stride[i - 1] = origin_tv.stride[i]; + } + } + return res; +} + } // namespace miopen #endif // MIOPEN_TENSOR_REORDER_UTIL_HPP_ diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp new file mode 100644 index 0000000000..f1dec6c75d --- /dev/null +++ b/src/kernels/MIOpenKthvalue.cpp @@ -0,0 +1,193 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACDTYPEN OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECDTYPEN WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include +// #ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +// #endif + +#include "float_types.h" +#include "radix.hpp" +#include "tensor_view.hpp" + +#ifndef IN_OUT_TYPE +#define IN_OUT_TYPE float +#endif + +#ifndef CVT_ACCUM2FLOAT +#define CVT_ACCUM2FLOAT(x) (float_to_bfloat16(x)) +#endif + +#ifndef CVT_FLOAT2ACCUM +#define CVT_FLOAT2ACCUM(x) (bfloat16_to_float(x)) +#endif + +#ifndef LOCAL_SIZE +#define LOCAL_SIZE 256 +#endif + +#ifndef RADIX_BITS +#define RADIX_BITS 2 +#endif + +#ifndef RADIX_SIZE +#define RADIX_SIZE (1 << RADIX_BITS) +#endif + +#ifndef RADIX_MASK +#define RADIX_MASK (RADIX_SIZE - 1) +#endif + +template +__device__ void kthvalueFwd(const DTYPE* input, + DTYPE* output, + size_t* indices, + size_t k, + size_t dim_size, + size_t dim_stride, + tensor_view_t<4> input_tv) +{ + /* + * Example) + * input : {A, B, C, D, E} + * output/indices : {A, B, 1, D, E} or {A, B, D, E} + * dim = 2 (C) + * => gws = {LOCAL_SIZE, A * B * D * E}, lws = {LOCAL_SIZE, 1} + */ + + size_t lid = threadIdx.x; + size_t gid = blockIdx.y * blockDim.y + threadIdx.y; + + __shared__ size_t lsum[LOCAL_SIZE][RADIX_SIZE]; + __shared__ DTYPE lval; + __shared__ long lidx; + size_t counts[RADIX_SIZE]; + RADIX_TYPE desired_mask = 0; + RADIX_TYPE desired = 0; + + tensor_layout_t<4> layout(input_tv, gid); + auto idx = input_tv.get_tensor_view_idx(layout); + + for(size_t pos = sizeof(RADIX_TYPE) * 8 - RADIX_BITS; pos >= 0; pos -= RADIX_BITS) + { + for(size_t i = 0; i < RADIX_SIZE; ++i) + { + counts[i] = 0; + } + + for(size_t i = lid; i < dim_size; i += LOCAL_SIZE) + { + size_t input_idx = idx + i * dim_stride; + RADIX_TYPE val = ENCODE(input[input_idx]); + if((val & desired_mask) == desired) + { + ++counts[GetBitField(val, pos, RADIX_BITS)]; + } + } + + for(size_t i = 0; i < RADIX_SIZE; ++i) + { + lsum[lid][i] = counts[i]; + } + __syncthreads(); + // warp shuffle + for(size_t i = LOCAL_SIZE >> 1; i > 0; i >>= 1) + { + if(lid < i) + { + for(size_t j = 0; j < RADIX_SIZE; ++j) + { + lsum[lid][j] += lsum[lid + i][j]; + } + } + __syncthreads(); + } + // remove use share mem + for(size_t i = 0; i < RADIX_SIZE; ++i) + { + counts[i] = lsum[0][i]; + } + __syncthreads(); + + bool found = false; + // Process in ascending order + for(size_t j = 0; j < RADIX_SIZE; ++j) + { + if(counts[j] >= k) + { + // Answer is inside this count + if(counts[j] == 1 || pos == 0) + { + // 1. counts[j] == 1 + // We found an unique answer. + // 2. pos == 0 + // There are multiple answers so we return any of them + for(size_t i = lid; i < dim_size; i += LOCAL_SIZE) + { + size_t input_idx = idx + i * dim_stride; + DTYPE val_ori = input[input_idx]; + RADIX_TYPE val = ENCODE(val_ori); + if((val & desired_mask) == desired && + GetBitField(val, pos, RADIX_BITS) == j) + { + // For case 2, this will be non-deterministic. + lval = val_ori; + lidx = i; + } + } + found = true; + break; + } + desired = SetBitField(desired, j, pos, RADIX_BITS); + desired_mask = SetBitField(desired_mask, RADIX_MASK, pos, RADIX_BITS); + break; + } + k -= counts[j]; + } + if(found) + break; + } + + __syncthreads(); + if(lid == 0) + { + output[gid] = lval; + indices[gid] = lidx; + } +} + +extern "C" __global__ void KthvalueUnreducedFwd(const IN_OUT_TYPE* input, + IN_OUT_TYPE* output, + size_t* indices, + size_t k, + size_t dim_size, + size_t dim_stride, + tensor_view_t<4> input_tv) +{ + kthvalueFwd(input, output, indices, k, dim_size, dim_stride, input_tv); +} diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp new file mode 100644 index 0000000000..2c74a16712 --- /dev/null +++ b/src/kernels/radix.hpp @@ -0,0 +1,136 @@ +#pragma once +#include + +// #include "backend/modnn/core/op/device/hip/utils/hip_atomic.h" + +#define ENCODE encode +#define RADIX_TYPE typename RadixType::type +#define GetBitField(x, pos, bits) GetBitFieldImpl(x, pos, bits) +#define SetBitField(x, a, pos, bits) SetBitFieldImpl(x, a, pos, bits) + +#define DEFINE_RADIX_TYPE(DTYPE, cpp_type) \ + template <> \ + struct RadixType \ + { \ + using type = cpp_type; \ + }; + +template +struct RadixType +{ +}; + +DEFINE_RADIX_TYPE(uint8_t, uint32_t) +DEFINE_RADIX_TYPE(int8_t, uint32_t) +DEFINE_RADIX_TYPE(int16_t, uint32_t) +DEFINE_RADIX_TYPE(int32_t, uint32_t) +DEFINE_RADIX_TYPE(int64_t, uint64_t) +DEFINE_RADIX_TYPE(bool, bool) +DEFINE_RADIX_TYPE(_Float16, uint16_t) +DEFINE_RADIX_TYPE(float, uint32_t) +DEFINE_RADIX_TYPE(double, uint64_t) + +template ::type> +__device__ inline Radix encode(DTYPE v) +{ + if constexpr(std::is_same::value) + { + return v; + } + else if constexpr(std::is_same::value) + { + return v; + } + else if constexpr(std::is_same::value) + { + return 128u + v; + } + else if constexpr(std::is_same::value) + { + return 32768u + v; + } + else if constexpr(std::is_same::value) + { + return 2147483648u + v; + } + else if constexpr(std::is_same::value) + { + return 9223372036854775808ull + v; + } + else if constexpr(std::is_same<_Float16, DTYPE>::value) + { + Radix x = __half_as_ushort(v); + Radix mask = (x & 0x8000) ? 0xffff : 0x8000; + return (v == v) ? (x ^ mask) : 0xffff; + } + else if constexpr(std::is_same::value) + { + Radix x = __float_as_uint(v); + Radix mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; + return (v == v) ? (x ^ mask) : 0xffffffff; + } + else if constexpr(std::is_same::value) + { + Radix x = __double_as_ulonglong(v); + Radix mask = -((x >> 63)) | 0x8000000000000000; + return (v == v) ? (x ^ mask) : 0xffffffffffffffff; + } +} + +template +__device__ inline DTYPE decode(Radix v) +{ + if constexpr(std::is_same::value) + { + return v; + } + else if constexpr(std::is_same::value) + { + return v; + } + else if constexpr(std::is_same::value) + { + return v - 128u; + } + else if constexpr(std::is_same::value) + { + return v - 32768u; + } + else if constexpr(std::is_same::value) + { + return v - 2147483648u; + } + else if constexpr(std::is_same::value) + { + return v - 9223372036854775808ull; + } + else if constexpr(std::is_same<_Float16, DTYPE>::value) + { + Radix mask = (v & 0x8000) ? 0x8000 : 0xffff; + return __ushort_as_half((ushort)(v ^ mask)); + } + else if constexpr(std::is_same::value) + { + Radix mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff; + return __uint_as_float(v ^ mask); + } + else if constexpr(std::is_same::value) + { + Radix mask = ((v >> 63) - 1) | 0x8000000000000000; + return __ulonglong_as_double(v ^ mask); + } +} + +// returns x[pos+bits:pos] +template +__device__ inline Radix GetBitFieldImpl(Radix x, int pos, int bits) +{ + return (x >> pos) & ((1 << bits) - 1); +} + +// x[pos+bits:pos] = a +template +__device__ inline Radix SetBitFieldImpl(Radix x, Radix a, int pos, int bits) +{ + return x | (a << pos); +} \ No newline at end of file diff --git a/src/kthvalue.cpp b/src/kthvalue.cpp new file mode 100644 index 0000000000..98f6e0d01b --- /dev/null +++ b/src/kthvalue.cpp @@ -0,0 +1,87 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "miopen/miopen.h" +#include "miopen/kthvalue/invoke_params.hpp" +#include "miopen/kthvalue/problem_description.hpp" +#include "miopen/kthvalue/solvers.hpp" +#include +#include +#include +#include +#include + +namespace miopen { + +miopenStatus_t KthvalueForward(Handle& handle, + Data_t workspace, + size_t workspaceSizeInBytes, + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& outputDesc, + Data_t output, + size_t* indices, + size_t k, + int32_t dim) +{ + // const auto problem = + // kthvalue::KthvalueFwdProblemDescription{inputDesc, targetDesc, outputDesc, reduction}; + + // const auto invoke_params = [&]() { + // auto tmp = kthvalue::FwdInvokeParams{}; + // tmp.inputDesc = &inputDesc; + // tmp.targetDesc = &targetDesc; + // tmp.outputDesc = &outputDesc; + // tmp.input = input; + // tmp.target = target; + // tmp.output = output; + // tmp.workspace = workspace; + // tmp.workspace_size = workspaceSizeInBytes; + // tmp.alpha = alpha; + // tmp.gamma = gamma; + // tmp.reduction = reduction; + // return tmp; + // }(); + + // if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + // { + // const auto algo = AlgorithmName{"KthvalueUnreducedFwd"}; + // const auto solvers = solver::SolverContainer{}; + + // solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + // } + // else + // { + // const auto algo = AlgorithmName{"KthvalueFwd"}; + // const auto solvers = solver::SolverContainer{}; + + // solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + // } + + return miopenStatusSuccess; +} + +} // namespace miopen diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp new file mode 100644 index 0000000000..96e7f2ed67 --- /dev/null +++ b/src/kthvalue/problem_description.cpp @@ -0,0 +1,54 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include + +#include + +namespace miopen { + +namespace kthvalue { + +NetworkConfig KthvalueFwdProblemDescription::MakeNetworkConfig() const +{ + auto input_dtype = inputDesc.GetType(); + auto size = inputDesc.GetElementSize(); + auto dim_num = inputDesc.GetSize(); + + std::ostringstream ss; + + ss << "kthvalue_fwd"; + ss << "i_dtype" << input_dtype; + ss << "dim_num" << dim_num; + ss << "size" << size; + + return NetworkConfig{ss.str()}; +} + +} // namespace kthvalue + +} // namespace miopen diff --git a/src/kthvalue_api.cpp b/src/kthvalue_api.cpp new file mode 100644 index 0000000000..20408ecea5 --- /dev/null +++ b/src/kthvalue_api.cpp @@ -0,0 +1,112 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "miopen/miopen.h" +#include +#include +#include +#include +#include + +inline std::ostream& operator<<(std::ostream& os, const std::vector& v) +{ + os << '{'; + for(int i = 0; i < v.size(); ++i) + { + if(i != 0) + os << ','; + os << v[i]; + } + os << '}'; + return os; +} + +static void LogCmdKthvalue(const miopenTensorDescriptor_t inputDesc, bool is_fwd) +{ + if(miopen::IsLoggingCmd()) + { + std::stringstream ss; + auto dtype = miopen::deref(inputDesc).GetType(); + if(dtype == miopenHalf) + { + ss << "kthvaluefp16"; + } + else if(dtype == miopenFloat) + { + ss << "kthvaluefp32"; + } + else if(dtype == miopenBFloat16) + { + ss << "kthvaluebfp16"; + } + + MIOPEN_LOG_FUNCTION(inputDesc); + ss << " -n " << miopen::deref(inputDesc).GetLengths()[0]; + ss << " -T " << miopen::deref(inputDesc).GetLengths(); + ss << " -Si " << miopen::deref(inputDesc).GetStrides(); + ss << " -F " << ((is_fwd) ? "1" : "2"); + + MIOPEN_LOG_DRIVER_CMD(ss.str()); + } +} + +extern "C" miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, + void* workspace, + size_t workspaceSizeInBytes, + miopenTensorDescriptor_t inputDesc, + const void* input, + miopenTensorDescriptor_t outputDesc, + void* output, + size_t* indices, + size_t k, + int32_t dim) +{ + MIOPEN_LOG_FUNCTION(handle, + workspace, + workspaceSizeInBytes, + inputDesc, + input, + outputDesc, + output, + indices, + k, + dim); + + LogCmdKthvalue(inputDesc, true); + + return miopen::try_([&] { + miopen::KthvalueForward(miopen::deref(handle), + DataCast(workspace), + workspaceSizeInBytes, + miopen::deref(inputDesc), + DataCast(input), + miopen::deref(outputDesc), + DataCast(output), + indices, + k, + dim); + }); +} diff --git a/src/solver.cpp b/src/solver.cpp index bbd13bd89f..85f47f9c81 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -24,6 +24,7 @@ * *******************************************************************************/ +#include "miopen/kthvalue/solvers.hpp" #include #include @@ -676,6 +677,7 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::RoPE, rope::RoPEForward{}.SolverDbId()); Register(registry, ++id, Primitive::RoPE, rope::RoPEBackward{}.SolverDbId()); + Register(registry, ++id, Primitive::Kthvalue, kthvalue::KthvalueFwd{}.SolverDbId()); // IMPORTANT: New solvers should be added to the end of the function! } diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp new file mode 100644 index 0000000000..07b2c2f5de --- /dev/null +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -0,0 +1,133 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "miopen/errors.hpp" +#include "miopen/kthvalue/problem_description.hpp" +#include "miopen/miopen.h" +#include "miopen/tensor_view_utils.hpp" +#include +#include +#include +#include +#include +#include + +#define LOCAL_SIZE 256 + +namespace miopen { + +namespace solver { + +namespace kthvalue { + +bool KthvalueFwd::IsApplicable(const ExecutionContext& /*context*/, + const miopen::kthvalue::KthvalueFwdProblemDescription& problem) const +{ + if(problem.GetInputDesc().GetSize() > 5) + return false; + return true; +} + +ConvSolution +KthvalueFwd::GetSolution(const ExecutionContext& context, + const miopen::kthvalue::KthvalueFwdProblemDescription& problem) const +{ + std::ignore = context; + auto result = ConvSolution{miopenStatusSuccess}; + + auto input_desc = problem.GetInputDesc(); + auto in_dtype = miopen::GetDataType(input_desc.GetType()); + auto dtype = problem.GetOutputDesc().GetType(); + auto size = input_desc.GetElementSize(); + auto dim_size = input_desc.GetLengths()[problem.GetDim()]; + + size_t xlocalsize = LOCAL_SIZE; + size_t xgridsize = AlignUp(size / dim_size, xlocalsize); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + + kernel.kernel_file = "MIOpenKthvalue.cpp"; + kernel.kernel_name = "KthvalueForward"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + {"IN_OUT_TYPE", in_dtype == "bfloat16" ? "ushort" : in_dtype}, + {"LOCAL_SIZE", LOCAL_SIZE}, + }; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + + result.invoker_factory = [=](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + size_t dimSize = params.inputDesc->GetLengths()[params.dim]; + size_t dimStride = params.inputDesc->GetStrides()[params.dim]; + + auto input_tv = get_inner_expanded_tv<5>(deref(params.inputDesc)); + auto input_tv_without_reduced_dim = get_tv_without_dim<5, 4>(input_tv, params.dim); + + kernel(params.input, + params.output, + params.indices, + params.k, + dimSize, + dimStride, + input_tv_without_reduced_dim); + }; + }; + + return result; +} + +std::size_t KthvalueFwd::GetWorkspaceSize( + const ExecutionContext& /*context*/, + const miopen::kthvalue::KthvalueFwdProblemDescription& /*problem*/) const +{ + return 0; +} + +} // namespace kthvalue + +} // namespace solver + +} // namespace miopen From 6e9898c355d8549aa6f4926792a38b30c38ab4ae Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Wed, 26 Jun 2024 16:55:49 +0700 Subject: [PATCH 02/28] add fwd kernel with driver --- driver/CMakeLists.txt | 1 + driver/dm_kthvalue.cpp | 41 ++ driver/driver.hpp | 5 +- driver/kthvalue_driver.hpp | 531 ++++++++++++++++++ include/miopen/miopen.h | 3 +- src/CMakeLists.txt | 1 + src/include/miopen/kthvalue.hpp | 1 + src/include/miopen/kthvalue/invoke_params.hpp | 7 +- .../miopen/kthvalue/problem_description.hpp | 25 +- src/kernels/MIOpenKthvalue.cpp | 56 +- src/kernels/radix.hpp | 59 +- src/kernels/warp_shuffle.hpp | 76 +++ src/kthvalue.cpp | 53 +- src/kthvalue/problem_description.cpp | 2 - src/kthvalue_api.cpp | 3 + src/solver/kthvalue/forward_kthvalue.cpp | 14 +- 16 files changed, 787 insertions(+), 91 deletions(-) create mode 100644 driver/dm_kthvalue.cpp create mode 100644 driver/kthvalue_driver.hpp create mode 100644 src/kernels/warp_shuffle.hpp diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt index 19abd61597..f91e3e23c8 100644 --- a/driver/CMakeLists.txt +++ b/driver/CMakeLists.txt @@ -45,6 +45,7 @@ add_executable(MIOpenDriver dm_gemm.cpp dm_getitem.cpp dm_groupnorm.cpp + dm_kthvalue.cpp dm_layernorm.cpp dm_lrn.cpp dm_pool.cpp diff --git a/driver/dm_kthvalue.cpp b/driver/dm_kthvalue.cpp new file mode 100644 index 0000000000..71adfdc245 --- /dev/null +++ b/driver/dm_kthvalue.cpp @@ -0,0 +1,41 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "registry_driver_maker.hpp" +#include "kthvalue_driver.hpp" + +static Driver* makeDriver(const std::string& base_arg) +{ + if(base_arg == "kthvalue") + return new KthvalueDriver(); + else if(base_arg == "kthvaluefp16") + return new KthvalueDriver(); + else if(base_arg == "kthvaluebfp16") + return new KthvalueDriver(); + return nullptr; +} + +REGISTER_DRIVER_MAKER(makeDriver); diff --git a/driver/driver.hpp b/driver/driver.hpp index 196d48c2b6..cc05841b9f 100644 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -175,7 +175,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) "groupnorm[bfp16|fp16], cat[bfp16|fp16], addlayernorm[bfp16|fp16], " "t5layernorm[bfp16|fp16], adam[fp16], ampadam, reduceextreme[bfp16|fp16], " "adamw[fp16], ampadamw, transformersadamw[fp16], transformersampadamw, " - "getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16]\n"); + "getitem[bfp16|fp16], reducecalculation[bfp16|fp16], rope[bfp16|fp16], kthvalue[bfp16|fp16]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -207,7 +207,8 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "transformersadamwfp16" && arg != "transformersampadamw" && arg != "getitem" && arg != "getitemfp16" && arg != "getitembfp16" && arg != "reducecalculation" && arg != "reducecalculationfp16" && arg != "reducecalculationbfp16" && arg != "rope" && - arg != "ropefp16" && arg != "ropebfp16" && arg != "--version") + arg != "ropefp16" && arg != "ropebfp16" && arg != "kthvalue" && arg != "kthvaluefp16" && + arg != "kthvaluebfp16" && arg != "--version") { printf("FAILED: Invalid Base Input Argument\n"); Usage(); diff --git a/driver/kthvalue_driver.hpp b/driver/kthvalue_driver.hpp new file mode 100644 index 0000000000..b6e9b46574 --- /dev/null +++ b/driver/kthvalue_driver.hpp @@ -0,0 +1,531 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#pragma once +#include "InputFlags.hpp" +#include "driver.hpp" +#include "miopen/errors.hpp" +#include +#include +#include +#include "miopen/miopen.h" +#include "tensor_driver.hpp" +#include "timer.hpp" +#include "random.hpp" +#include <../test/tensor_holder.hpp> +#include <../test/verify.hpp> +#include +#include +#include +#include + +template +void mloKthvalueFwdRunHost(TIO* input, + miopenTensorDescriptor_t pInputDesc, + TIO* outputHost, + miopenTensorDescriptor_t outputDesc, + size_t* indices, + size_t k, + int dim) +{ + auto inputDesc = miopen::deref(pInputDesc); + size_t inputSize = inputDesc.GetElementSize(); + size_t dimSize = inputDesc.GetLengths()[dim]; + size_t dimStride = inputDesc.GetStrides()[dim]; + auto inputTv = miopen::get_inner_expanded_tv<5>(miopen::deref(pInputDesc)); + auto inputTvWithoutDim = miopen::get_tv_without_dim<5, 4>(inputTv, dim); + // auto outputTv = miopen::get_inner_expanded_tv<5>(miopen::deref(outputDesc)); + + size_t numSlice = inputSize / dimSize; + + std::vector elements; + std::vector ids(dimSize); + for(int i = 0; i < dimSize; ++i) + { + ids[i] = i; + } + + for(int slideID = 0; slideID < numSlice; ++slideID) + { + elements.clear(); + tensor_layout_t<4> layout(inputTvWithoutDim, slideID); + auto idx = inputTvWithoutDim.get_tensor_view_idx(layout); + + for(int j = 0; j < dimSize; ++j) + { + elements.push_back(input[idx + j * dimStride]); + } + + std::sort(ids.begin(), ids.end(), [=](size_t x, size_t y) -> bool { + return elements[x] < elements[y]; + }); + outputHost[slideID] = elements[ids[k - 1]]; + indices[slideID] = ids[k - 1]; + } +} + +template +class KthvalueDriver : public Driver +{ +public: + KthvalueDriver() : Driver() + { + miopenCreateTensorDescriptor(&inputDesc); + miopenCreateTensorDescriptor(&indicesDesc); + miopenCreateTensorDescriptor(&outputDesc); + miopenCreateTensorDescriptor(&doutputDesc); + miopenCreateTensorDescriptor(&dinputDesc); + + data_type = miopen_type{}; + } + + std::vector ComputeStrides(std::vector input); + int AddCmdLineArgs() override; + int ParseCmdLineArgs(int argc, char* argv[]) override; + InputFlags& GetInputFlags() override { return inflags; } + + int GetandSetData() override; + + int AllocateBuffersAndCopy() override; + + int RunForwardGPU() override; + int RunForwardCPU(); + + int RunBackwardGPU() override; + int RunBackwardCPU(); + + int VerifyBackward() override; + int VerifyForward() override; + ~KthvalueDriver() override + { + miopenDestroyTensorDescriptor(inputDesc); + miopenDestroyTensorDescriptor(indicesDesc); + miopenDestroyTensorDescriptor(outputDesc); + miopenDestroyTensorDescriptor(doutputDesc); + miopenDestroyTensorDescriptor(dinputDesc); + } + +private: + InputFlags inflags; + + miopenTensorDescriptor_t inputDesc; + miopenTensorDescriptor_t indicesDesc; + miopenTensorDescriptor_t outputDesc; + miopenTensorDescriptor_t doutputDesc; + miopenTensorDescriptor_t dinputDesc; + + std::unique_ptr input_dev; + std::unique_ptr indices_dev; + std::unique_ptr output_dev; + std::unique_ptr doutput_dev; + std::unique_ptr dinput_dev; + std::unique_ptr workspace_dev; + + std::vector input; + std::vector indices; + std::vector indicesHost; + std::vector output; + std::vector outputHost; + std::vector doutput; + std::vector dinput; + std::vector dinputHost; + std::vector workspace; + + bool isContiguous; + int dim; + size_t k; + + size_t workSpaceSizeInBytes; +}; + +template +int KthvalueDriver::ParseCmdLineArgs(int argc, char* argv[]) +{ + inflags.Parse(argc, argv); + isContiguous = inflags.GetValueInt("is-contiguous") == 1 ? true : false; + k = inflags.GetValueInt("k"); + dim = inflags.GetValueInt("dim"); + auto inDims = inflags.GetValueTensor("dim-lengths").lengths; + int num_dim = inDims.size(); + if(dim < -num_dim || dim >= num_dim) + { + MIOPEN_THROW(miopenStatusBadParm, "Kthvalue: dim doesn't not exist"); + } + + if(inflags.GetValueInt("time") == 1) + { + miopenEnableProfiling(GetHandle(), true); + } + return miopenStatusSuccess; +} + +template +int KthvalueDriver::GetandSetData() +{ + auto inDims = inflags.GetValueTensor("dim-lengths").lengths; + std::vector inStride = ComputeStrides(inDims); + auto outDims = inflags.GetValueTensor("dim-lengths").lengths; + + if(dim < 0) + { + dim += inDims.size(); + } + outDims.erase(outDims.begin() + dim); + + SetTensorNd(inputDesc, inDims, inStride, data_type); + SetTensorNd(doutputDesc, outDims, data_type); + SetTensorNd(dinputDesc, inDims, data_type); + SetTensorNd(outputDesc, outDims, data_type); + // miopenDataType_t doesn't support size_t tensor, I use double instead (both types use 64 bits) + SetTensorNd(indicesDesc, outDims, miopen_type{}); + + return 0; +} + +// Equivalent to: tensor.tranpose(0, -1).contiguous().tranpose(0, -1) incase contiguous = False +template +std::vector KthvalueDriver::ComputeStrides(std::vector inputDim) +{ + if(!isContiguous) + std::swap(inputDim.front(), inputDim.back()); + std::vector strides(inputDim.size()); + strides.back() = 1; + for(int i = inputDim.size() - 2; i >= 0; --i) + strides[i] = strides[i + 1] * inputDim[i + 1]; + if(!isContiguous) + std::swap(strides.front(), strides.back()); + return strides; +} + +template +int KthvalueDriver::AddCmdLineArgs() +{ + inflags.AddInputFlag("forw", 'F', "1", "Run only Forward (Default=1)", "int"); + inflags.AddTensorFlag( + "dim-lengths", 'D', "256x4x2", "The dimensional lengths of the input tensor"); + inflags.AddInputFlag("k", 'k', "1", "dim (Default=1)", "int"); + inflags.AddInputFlag("dim", 'd', "-1", "dim (Default=-1)", "int"); + inflags.AddInputFlag("is-contiguous", 'c', "1", "is-contiguous (Default=1)", "int"); + inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); + inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int"); + inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int"); + inflags.AddInputFlag( + "wall", 'w', "0", "Wall-clock Time Each Layer, Requires time == 1 (Default=0)", "int"); + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::AllocateBuffersAndCopy() +{ + size_t in_sz = miopen::deref(inputDesc).GetElementSize(); + size_t idx_sz = miopen::deref(indicesDesc).GetElementSize(); + size_t out_sz = miopen::deref(outputDesc).GetElementSize(); + size_t dO_sz = miopen::deref(doutputDesc).GetElementSize(); + size_t dI_sz = miopen::deref(dinputDesc).GetElementSize(); + + uint32_t ctx = 0; + + input_dev = std::unique_ptr(new GPUMem(ctx, in_sz, sizeof(TIO))); + indices_dev = std::unique_ptr(new GPUMem(ctx, idx_sz, sizeof(size_t))); + output_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(TIO))); + doutput_dev = std::unique_ptr(new GPUMem(ctx, dO_sz, sizeof(TIO))); + dinput_dev = std::unique_ptr(new GPUMem(ctx, dI_sz, sizeof(TIO))); + + // miopenGetKthvalueForwardWorkspaceSize(handle, inputDesc, outputDesc, &workSpaceSizeInBytes); + workSpaceSizeInBytes = 0; + workspace_dev = + std::unique_ptr(new GPUMem(ctx, workSpaceSizeInBytes / sizeof(TIO), sizeof(TIO))); + + input = std::vector(in_sz, static_cast(0)); + indices = std::vector(idx_sz, 0); + indicesHost = std::vector(idx_sz, 0); + output = std::vector(out_sz, static_cast(0)); + outputHost = std::vector(out_sz, static_cast(0)); + doutput = std::vector(dO_sz, static_cast(0)); + dinput = std::vector(dI_sz, static_cast(0)); + dinputHost = std::vector(dI_sz, static_cast(0)); + workspace = std::vector(workSpaceSizeInBytes / sizeof(TIO), static_cast(0)); + + for(int i = 0; i < in_sz; i++) + { + input[i] = prng::gen_A_to_B(static_cast(-10), static_cast(10)); + } + for(int i = 0; i < dO_sz; ++i) + { + doutput[i] = prng::gen_A_to_B(static_cast(-2), static_cast(2)); + } + + fill(output.begin(), output.end(), static_cast(0)); + fill(indices.begin(), indices.end(), static_cast(0)); + fill(dinput.begin(), dinput.end(), static_cast(0)); + + if(input_dev->ToGPU(GetStream(), input.data()) != 0) + std::cerr << "Error copying (in) to GPU, size: " << input_dev->GetSize() << std::endl; + + if(indices_dev->ToGPU(GetStream(), indices.data()) != 0) + std::cerr << "Error copying (idx) to GPU, size: " << indices_dev->GetSize() << std::endl; + + if(output_dev->ToGPU(GetStream(), output.data()) != 0) + std::cerr << "Error copying (out) to GPU, size: " << output_dev->GetSize() << std::endl; + + if(doutput_dev->ToGPU(GetStream(), doutput.data()) != 0) + std::cerr << "Error copying (dO) to GPU, size: " << doutput_dev->GetSize() << std::endl; + + if(dinput_dev->ToGPU(GetStream(), dinput.data()) != 0) + std::cerr << "Error copying (dI) to GPU, size: " << dinput_dev->GetSize() << std::endl; + + if(workspace_dev->ToGPU(GetStream(), workspace.data()) != 0) + std::cerr << "Error copying (dI) to GPU, size: " << workspace_dev->GetSize() << std::endl; + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::RunForwardGPU() +{ + float kernel_total_time = 0; + float kernel_first_time = 0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenKthvalueForward(GetHandle(), + workspace_dev->GetMem(), + workSpaceSizeInBytes, + inputDesc, + input_dev->GetMem(), + outputDesc, + output_dev->GetMem(), + indicesDesc, + (size_t*)indices_dev->GetMem(), + k, + dim); + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + std::cout << "Wall-clock Time Kthvalue Fwd Elapsed: " << t.gettime_ms() / iter << " ms" + << std::endl; + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + std::cout << "GPU Kernel Time Kthvalue Fwd Elapsed: " << kernel_average_time << " ms" + << std::endl; + } + + if(output_dev->FromGPU(GetStream(), output.data()) != 0) + std::cerr << "Error copying (out_dev) from GPU, size: " << output_dev->GetSize() + << std::endl; + + if(indices_dev->FromGPU(GetStream(), indices.data()) != 0) + std::cerr << "Error copying (indices_dev) from GPU, size: " << indices_dev->GetSize() + << std::endl; + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::RunForwardCPU() +{ + mloKthvalueFwdRunHost( + input.data(), inputDesc, outputHost.data(), outputDesc, indicesHost.data(), k, dim); + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::RunBackwardGPU() +{ + // float kernel_total_time = 0; + // float kernel_first_time = 0; + + // Timer t; + // START_TIME + + // for(int i = 0; i < inflags.GetValueInt("iter"); i++) + // { + // void* p_dtarget = nullptr; + // if(isTargetGradientComputed) + // { + // p_dtarget = dtarget_dev->GetMem(); + // } + + // miopenKthvalueBackward(GetHandle(), + // inputDesc, + // input_dev->GetMem(), + // targetDesc, + // target_dev->GetMem(), + // doutputDesc, + // doutput_dev->GetMem(), + // dinputDesc, + // dinput_dev->GetMem(), + // dtargetDesc, + // p_dtarget, + // alpha, + // gamma, + // reduction); + + // float time = 0.0; + // miopenGetKernelTime(GetHandle(), &time); + // kernel_total_time += time; + // if(i == 0) + // kernel_first_time = time; + // } + + // if(inflags.GetValueInt("time") == 1) + // { + // STOP_TIME + // int iter = inflags.GetValueInt("iter"); + // if(WALL_CLOCK) + // std::cout << "Wall-clock Time Sigmoid Focal Loss Bwd Elapsed: " << t.gettime_ms() / + // iter + // << " ms" << std::endl; + + // float kernel_average_time = + // iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + // std::cout << "GPU Kernel Time Sigmoid Focal Loss Bwd Elapsed: " << kernel_average_time + // << " ms" << std::endl; + // } + + // if(dinput_dev->FromGPU(GetStream(), dinput.data()) != 0) + // std::cerr << "Error copying (dI_dev) from GPU, size: " << dinput_dev->GetSize() + // << std::endl; + // if(isTargetGradientComputed && dtarget_dev->FromGPU(GetStream(), dtarget.data()) != 0) + // std::cerr << "Error copying (dT_dev) from GPU, size: " << dtarget_dev->GetSize() + // << std::endl; + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::RunBackwardCPU() +{ + // TIO* p_dtarget = nullptr; + // if(isTargetGradientComputed) + // { + // p_dtarget = dtargetHost.data(); + // } + // if(reduction == MIOPEN_LOSS_REDUCTION_NONE) + // { + + // mloKthvalueUnreducedBwdRunHost(input.data(), + // inputDesc, + // target.data(), + // targetDesc, + // doutput.data(), + // doutputDesc, + // dinputHost.data(), + // dinputDesc, + // p_dtarget, + // dtargetDesc, + // alpha, + // gamma); + // } + // else + // { + // mloKthvalueBwdRunHost(input.data(), + // inputDesc, + // target.data(), + // targetDesc, + // doutput.data(), + // doutputDesc, + // dinputHost.data(), + // dinputDesc, + // p_dtarget, + // dtargetDesc, + // alpha, + // gamma, + // divisor); + // } + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::VerifyForward() +{ + RunForwardCPU(); + + double tolerance = std::numeric_limits::epsilon() * 10; + auto errorOutput = miopen::rms_range(outputHost, output); + + if(!std::isfinite(errorOutput) || errorOutput > tolerance) + { + std::cout << "Forward Kthvalue output FAILED: " << errorOutput << " > " << tolerance + << std::endl; + return EC_VerifyFwd; + } + else + { + std::cout << "Forward Kthvalue Verifies OK on CPU reference (" << errorOutput << "< " + << tolerance << ')' << std::endl; + } + + return miopenStatusSuccess; +} + +template +int KthvalueDriver::VerifyBackward() +{ + // RunBackwardCPU(); + + // double tolerance = std::numeric_limits::epsilon() * 10; + // auto dinputError = miopen::rms_range(dinputHost, dinput); + // auto dtargetError = miopen::rms_range(dtargetHost, dtarget); + + // if(!std::isfinite(dinputError) || dinputError > tolerance) + // { + // std::cout << "Backward " << reduction << " Sigmoid Focal Loss FAILED: " << dinputError + // << " > " << tolerance << std::endl; + // return EC_VerifyFwd; + // } + // else if(isTargetGradientComputed && (!std::isfinite(dtargetError) || dtargetError > + // tolerance)) + // { + // std::cout << "Backward " << reduction << " Sigmoid Focal Loss FAILED: " << dtargetError + // << " > " << tolerance << std::endl; + // return EC_VerifyFwd; + // } + // else + // { + // std::cout << "Backward " << reduction + // << " Sigmoid Focal Loss Verifies OK on CPU reference (dinput: " << dinputError + // << ", dtarget: " << dtargetError << "< " << tolerance << ')' << std::endl; + // } + + return miopenStatusSuccess; +} diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 2c6819181c..ce5fa5d4d2 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -7702,9 +7702,10 @@ MIOPEN_EXPORT miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, const void* input, miopenTensorDescriptor_t outputDesc, void* output, + miopenTensorDescriptor_t indicesDesc, size_t* indices, size_t k, - int32_t dim); + int32_t dim = -1); /** @} */ // CLOSEOUT kthvalue DOXYGEN GROUP diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index ef06c9383f..a595ffac1d 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -470,6 +470,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/stride_array.hpp kernels/tensor_view.hpp kernels/utilities.inc + kernels/warp_shuffle.hpp kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1536vgprs_fp16_fp16acc_f2x3_c16_stride1.inc kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1536vgprs_fp16_fp16acc_f2x3_c32_stride1.inc kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1024vgprs_fp16_fp16acc_f2x3_c16_stride1.inc diff --git a/src/include/miopen/kthvalue.hpp b/src/include/miopen/kthvalue.hpp index 2ebcbd0847..d7dd6f1332 100644 --- a/src/include/miopen/kthvalue.hpp +++ b/src/include/miopen/kthvalue.hpp @@ -40,6 +40,7 @@ miopenStatus_t KthvalueForward(Handle& handle, ConstData_t input, const TensorDescriptor& outputDesc, Data_t output, + const TensorDescriptor& indicesDesc, size_t* indices, size_t k, int32_t dim); diff --git a/src/include/miopen/kthvalue/invoke_params.hpp b/src/include/miopen/kthvalue/invoke_params.hpp index 318f28b415..6cd6896d9b 100644 --- a/src/include/miopen/kthvalue/invoke_params.hpp +++ b/src/include/miopen/kthvalue/invoke_params.hpp @@ -45,7 +45,6 @@ struct KthvalueInvokeParams : public miopen::InvokeParams Data_t workspace = nullptr; std::size_t workspace_size = 0; ConstData_t input = nullptr; - size_t* indices = nullptr; size_t k = 1; int32_t dim = 0; @@ -58,8 +57,10 @@ struct FwdInvokeParams : KthvalueInvokeParams { FwdInvokeParams() = default; - const TensorDescriptor* outputDesc = nullptr; - Data_t output = nullptr; + const TensorDescriptor* outputDesc = nullptr; + Data_t output = nullptr; + const TensorDescriptor* indicesDesc = nullptr; + size_t* indices = nullptr; }; } // namespace kthvalue diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp index 92ff989ea6..0f3ea7eb39 100644 --- a/src/include/miopen/kthvalue/problem_description.hpp +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -27,6 +27,7 @@ #include "miopen/errors.hpp" #include "miopen/miopen.h" +#include #include #include @@ -43,20 +44,40 @@ struct KthvalueFwdProblemDescription : ProblemDescriptionBase { KthvalueFwdProblemDescription(const TensorDescriptor& inputDesc_, const TensorDescriptor& outputDesc_, - int32_t dim_) - : inputDesc(inputDesc_), outputDesc(outputDesc_), dim(dim_) + const TensorDescriptor& indicesDesc_, + int32_t dim_, + size_t k_) + : inputDesc(inputDesc_), + outputDesc(outputDesc_), + indicesDesc(indicesDesc_), + dim(dim_), + k(k_) { + if(k > inputDesc.GetLengths()[dim]) + { + MIOPEN_THROW(miopenStatusBadParm, + "Kthvalue: k must be less than the size of the dimension"); + } + int num_dim = inputDesc.GetSize(); + if(dim < -num_dim || dim >= num_dim) + { + MIOPEN_THROW(miopenStatusBadParm, "Kthvalue: dim doesn't not exist"); + } } const TensorDescriptor& GetInputDesc() const { return inputDesc; } const TensorDescriptor& GetOutputDesc() const { return outputDesc; } + const TensorDescriptor& GetIndicesDesc() const { return indicesDesc; } int32_t GetDim() const { return dim; } + size_t GetK() const { return k; } NetworkConfig MakeNetworkConfig() const override; public: TensorDescriptor inputDesc; TensorDescriptor outputDesc; + TensorDescriptor indicesDesc; int32_t dim; + size_t k; }; } // namespace kthvalue diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index f1dec6c75d..1366c714f4 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -24,16 +24,17 @@ * *******************************************************************************/ -#include -#include -// #ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +// #include +#include +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS #include #include -// #endif +#endif #include "float_types.h" #include "radix.hpp" #include "tensor_view.hpp" +#include "warp_shuffle.hpp" #ifndef IN_OUT_TYPE #define IN_OUT_TYPE float @@ -70,7 +71,9 @@ __device__ void kthvalueFwd(const DTYPE* input, size_t k, size_t dim_size, size_t dim_stride, - tensor_view_t<4> input_tv) + tensor_view_t<4> input_tv, + tensor_view_t<4> output_tv, + tensor_view_t<4> indices_tv) { /* * Example) @@ -81,9 +84,10 @@ __device__ void kthvalueFwd(const DTYPE* input, */ size_t lid = threadIdx.x; - size_t gid = blockIdx.y * blockDim.y + threadIdx.y; + size_t gid = blockIdx.x; __shared__ size_t lsum[LOCAL_SIZE][RADIX_SIZE]; + __shared__ size_t smem_count[RADIX_SIZE]; __shared__ DTYPE lval; __shared__ long lidx; size_t counts[RADIX_SIZE]; @@ -115,7 +119,6 @@ __device__ void kthvalueFwd(const DTYPE* input, lsum[lid][i] = counts[i]; } __syncthreads(); - // warp shuffle for(size_t i = LOCAL_SIZE >> 1; i > 0; i >>= 1) { if(lid < i) @@ -127,13 +130,25 @@ __device__ void kthvalueFwd(const DTYPE* input, } __syncthreads(); } - // remove use share mem for(size_t i = 0; i < RADIX_SIZE; ++i) { counts[i] = lsum[0][i]; } __syncthreads(); + // __syncthreads(); + // #pragma unroll + // for(size_t i = 0; i < RADIX_SIZE; ++i) + // { + // counts[i] = block_reduce_sum(counts[i]); + // if(lid == 0) + // { + // smem_count[i] = counts[i]; + // } + // __syncthreads(); + // counts[i] = smem_count[i]; + // } + bool found = false; // Process in ascending order for(size_t j = 0; j < RADIX_SIZE; ++j) @@ -176,18 +191,23 @@ __device__ void kthvalueFwd(const DTYPE* input, __syncthreads(); if(lid == 0) { - output[gid] = lval; - indices[gid] = lidx; + auto output_layout = tensor_layout_t<4>(output_tv, gid); + auto indices_layout = tensor_layout_t<4>(indices_tv, gid); + output[output_tv.get_tensor_view_idx(output_layout)] = lval; + indices[indices_tv.get_tensor_view_idx(indices_layout)] = lidx; } } -extern "C" __global__ void KthvalueUnreducedFwd(const IN_OUT_TYPE* input, - IN_OUT_TYPE* output, - size_t* indices, - size_t k, - size_t dim_size, - size_t dim_stride, - tensor_view_t<4> input_tv) +extern "C" __global__ void KthvalueFwd(const IN_OUT_TYPE* input, + IN_OUT_TYPE* output, + size_t* indices, + size_t k, + size_t dim_size, + size_t dim_stride, + tensor_view_t<4> input_tv, + tensor_view_t<4> output_tv, + tensor_view_t<4> indices_tv) { - kthvalueFwd(input, output, indices, k, dim_size, dim_stride, input_tv); + kthvalueFwd( + input, output, indices, k, dim_size, dim_stride, input_tv, output_tv, indices_tv); } diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp index 2c74a16712..a8a6f10966 100644 --- a/src/kernels/radix.hpp +++ b/src/kernels/radix.hpp @@ -1,5 +1,6 @@ #pragma once #include +#include // #include "backend/modnn/core/op/device/hip/utils/hip_atomic.h" @@ -20,13 +21,13 @@ struct RadixType { }; -DEFINE_RADIX_TYPE(uint8_t, uint32_t) -DEFINE_RADIX_TYPE(int8_t, uint32_t) -DEFINE_RADIX_TYPE(int16_t, uint32_t) +// DEFINE_RADIX_TYPE(uint8_t, uint32_t) +// DEFINE_RADIX_TYPE(int8_t, uint32_t) +// DEFINE_RADIX_TYPE(int16_t, uint32_t) DEFINE_RADIX_TYPE(int32_t, uint32_t) DEFINE_RADIX_TYPE(int64_t, uint64_t) DEFINE_RADIX_TYPE(bool, bool) -DEFINE_RADIX_TYPE(_Float16, uint16_t) +// DEFINE_RADIX_TYPE(_Float16, uint16_t) DEFINE_RADIX_TYPE(float, uint32_t) DEFINE_RADIX_TYPE(double, uint64_t) @@ -37,18 +38,18 @@ __device__ inline Radix encode(DTYPE v) { return v; } - else if constexpr(std::is_same::value) - { - return v; - } - else if constexpr(std::is_same::value) - { - return 128u + v; - } - else if constexpr(std::is_same::value) - { - return 32768u + v; - } + // else if constexpr(std::is_same::value) + // { + // return v; + // } + // else if constexpr(std::is_same::value) + // { + // return 128u + v; + // } + // else if constexpr(std::is_same::value) + // { + // return 32768u + v; + // } else if constexpr(std::is_same::value) { return 2147483648u + v; @@ -84,18 +85,18 @@ __device__ inline DTYPE decode(Radix v) { return v; } - else if constexpr(std::is_same::value) - { - return v; - } - else if constexpr(std::is_same::value) - { - return v - 128u; - } - else if constexpr(std::is_same::value) - { - return v - 32768u; - } + // else if constexpr(std::is_same::value) + // { + // return v; + // } + // else if constexpr(std::is_same::value) + // { + // return v - 128u; + // } + // else if constexpr(std::is_same::value) + // { + // return v - 32768u; + // } else if constexpr(std::is_same::value) { return v - 2147483648u; @@ -133,4 +134,4 @@ template __device__ inline Radix SetBitFieldImpl(Radix x, Radix a, int pos, int bits) { return x | (a << pos); -} \ No newline at end of file +} diff --git a/src/kernels/warp_shuffle.hpp b/src/kernels/warp_shuffle.hpp new file mode 100644 index 0000000000..addd236f69 --- /dev/null +++ b/src/kernels/warp_shuffle.hpp @@ -0,0 +1,76 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#endif + +#include "float_types.h" + +#ifndef REDUCE_SIZE +#define REDUCE_SIZE 256 +#endif + +template +__device__ __forceinline__ DTYPE warp_reduce_sum(DTYPE val) +{ + if(warpSize >= 64) + val += __shfl_down(val, 32); + if(warpSize >= 32) + val += __shfl_down(val, 16); + if(warpSize >= 16) + val += __shfl_down(val, 8); + if(warpSize >= 8) + val += __shfl_down(val, 4); + if(warpSize >= 4) + val += __shfl_down(val, 2); + if(warpSize >= 2) + val += __shfl_down(val, 1); + return val; +} + +template +__device__ __forceinline__ DTYPE block_reduce_sum(DTYPE val) +{ + static __shared__ DTYPE shared[REDUCE_SIZE / warpSize]; + auto lane = threadIdx.x % warpSize; + auto wid = threadIdx.x / warpSize; + + val = warp_reduce_sum(val); + + if(lane == 0) + shared[wid] = val; + __syncthreads(); + + val = threadIdx.x < REDUCE_SIZE / warpSize ? shared[lane] : 0; + if(wid == 0) + // val = (threadIdx.x % warpSize) < REDUCE_SIZE / warpSize ? shared[lane] : 0; + // if(lane == 0) + val = warp_reduce_sum(val); + + return val; +} diff --git a/src/kthvalue.cpp b/src/kthvalue.cpp index 98f6e0d01b..093841ca32 100644 --- a/src/kthvalue.cpp +++ b/src/kthvalue.cpp @@ -43,43 +43,38 @@ miopenStatus_t KthvalueForward(Handle& handle, ConstData_t input, const TensorDescriptor& outputDesc, Data_t output, + const TensorDescriptor& indicesDesc, size_t* indices, size_t k, int32_t dim) { - // const auto problem = - // kthvalue::KthvalueFwdProblemDescription{inputDesc, targetDesc, outputDesc, reduction}; + const auto problem = + kthvalue::KthvalueFwdProblemDescription{inputDesc, outputDesc, indicesDesc, dim, k}; - // const auto invoke_params = [&]() { - // auto tmp = kthvalue::FwdInvokeParams{}; - // tmp.inputDesc = &inputDesc; - // tmp.targetDesc = &targetDesc; - // tmp.outputDesc = &outputDesc; - // tmp.input = input; - // tmp.target = target; - // tmp.output = output; - // tmp.workspace = workspace; - // tmp.workspace_size = workspaceSizeInBytes; - // tmp.alpha = alpha; - // tmp.gamma = gamma; - // tmp.reduction = reduction; - // return tmp; - // }(); + if(dim < 0) + { + dim += indicesDesc.GetSize(); + } - // if(reduction == MIOPEN_LOSS_REDUCTION_NONE) - // { - // const auto algo = AlgorithmName{"KthvalueUnreducedFwd"}; - // const auto solvers = solver::SolverContainer{}; + const auto invoke_params = [&]() { + auto tmp = kthvalue::FwdInvokeParams{}; + tmp.inputDesc = &inputDesc; + tmp.outputDesc = &outputDesc; + tmp.indicesDesc = &indicesDesc; + tmp.input = input; + tmp.indices = indices; + tmp.output = output; + tmp.workspace = workspace; + tmp.workspace_size = workspaceSizeInBytes; + tmp.k = k; + tmp.dim = dim; + return tmp; + }(); - // solvers.ExecutePrimitive(handle, problem, algo, invoke_params); - // } - // else - // { - // const auto algo = AlgorithmName{"KthvalueFwd"}; - // const auto solvers = solver::SolverContainer{}; + const auto algo = AlgorithmName{"KthvalueFwd"}; + const auto solvers = solver::SolverContainer{}; - // solvers.ExecutePrimitive(handle, problem, algo, invoke_params); - // } + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); return miopenStatusSuccess; } diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp index 96e7f2ed67..a569044dbc 100644 --- a/src/kthvalue/problem_description.cpp +++ b/src/kthvalue/problem_description.cpp @@ -37,13 +37,11 @@ NetworkConfig KthvalueFwdProblemDescription::MakeNetworkConfig() const { auto input_dtype = inputDesc.GetType(); auto size = inputDesc.GetElementSize(); - auto dim_num = inputDesc.GetSize(); std::ostringstream ss; ss << "kthvalue_fwd"; ss << "i_dtype" << input_dtype; - ss << "dim_num" << dim_num; ss << "size" << size; return NetworkConfig{ss.str()}; diff --git a/src/kthvalue_api.cpp b/src/kthvalue_api.cpp index 20408ecea5..7aae2d4d2a 100644 --- a/src/kthvalue_api.cpp +++ b/src/kthvalue_api.cpp @@ -80,6 +80,7 @@ extern "C" miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, const void* input, miopenTensorDescriptor_t outputDesc, void* output, + miopenTensorDescriptor_t indicesDesc, size_t* indices, size_t k, int32_t dim) @@ -91,6 +92,7 @@ extern "C" miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, input, outputDesc, output, + indicesDesc, indices, k, dim); @@ -105,6 +107,7 @@ extern "C" miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, DataCast(input), miopen::deref(outputDesc), DataCast(output), + miopen::deref(indicesDesc), indices, k, dim); diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp index 07b2c2f5de..8eb494cf02 100644 --- a/src/solver/kthvalue/forward_kthvalue.cpp +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -65,7 +65,7 @@ KthvalueFwd::GetSolution(const ExecutionContext& context, auto dim_size = input_desc.GetLengths()[problem.GetDim()]; size_t xlocalsize = LOCAL_SIZE; - size_t xgridsize = AlignUp(size / dim_size, xlocalsize); + size_t xgridsize = size / dim_size * xlocalsize; size_t ylocalsize = 1; size_t ygridsize = 1; size_t zlocalsize = 1; @@ -74,7 +74,7 @@ KthvalueFwd::GetSolution(const ExecutionContext& context, auto kernel = KernelInfo{}; kernel.kernel_file = "MIOpenKthvalue.cpp"; - kernel.kernel_name = "KthvalueForward"; + kernel.kernel_name = "KthvalueFwd"; const auto build_params = KernelBuildParameters{ {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, @@ -103,8 +103,10 @@ KthvalueFwd::GetSolution(const ExecutionContext& context, size_t dimSize = params.inputDesc->GetLengths()[params.dim]; size_t dimStride = params.inputDesc->GetStrides()[params.dim]; - auto input_tv = get_inner_expanded_tv<5>(deref(params.inputDesc)); - auto input_tv_without_reduced_dim = get_tv_without_dim<5, 4>(input_tv, params.dim); + auto input_tv = get_inner_expanded_tv<5>(deref(params.inputDesc)); + auto output_tv = get_inner_expanded_tv<4>(deref(params.outputDesc)); + auto indices_tv = get_inner_expanded_tv<4>(deref(params.indicesDesc)); + auto input_tv_without_selected_dim = get_tv_without_dim<5, 4>(input_tv, params.dim); kernel(params.input, params.output, @@ -112,7 +114,9 @@ KthvalueFwd::GetSolution(const ExecutionContext& context, params.k, dimSize, dimStride, - input_tv_without_reduced_dim); + input_tv_without_selected_dim, + output_tv, + indices_tv); }; }; From b27cb594570752f282fe24c730a1501ca0cb08a0 Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Thu, 27 Jun 2024 14:37:26 +0700 Subject: [PATCH 03/28] check improvement over rocm --- driver/kthvalue_driver.hpp | 20 +++++++--- .../miopen/kthvalue/problem_description.hpp | 12 +++--- src/include/miopen/kthvalue/solvers.hpp | 18 +++------ src/kernels/MIOpenKthvalue.cpp | 38 +++++++++---------- src/kernels/float_types.h | 2 +- src/kernels/radix.hpp | 38 ++----------------- src/kthvalue.cpp | 2 +- src/kthvalue/problem_description.cpp | 2 +- src/solver/kthvalue/forward_kthvalue.cpp | 25 ++++++------ 9 files changed, 66 insertions(+), 91 deletions(-) diff --git a/driver/kthvalue_driver.hpp b/driver/kthvalue_driver.hpp index b6e9b46574..f2922f031b 100644 --- a/driver/kthvalue_driver.hpp +++ b/driver/kthvalue_driver.hpp @@ -48,6 +48,7 @@ void mloKthvalueFwdRunHost(TIO* input, TIO* outputHost, miopenTensorDescriptor_t outputDesc, size_t* indices, + miopenTensorDescriptor_t indicesDesc, size_t k, int dim) { @@ -57,7 +58,8 @@ void mloKthvalueFwdRunHost(TIO* input, size_t dimStride = inputDesc.GetStrides()[dim]; auto inputTv = miopen::get_inner_expanded_tv<5>(miopen::deref(pInputDesc)); auto inputTvWithoutDim = miopen::get_tv_without_dim<5, 4>(inputTv, dim); - // auto outputTv = miopen::get_inner_expanded_tv<5>(miopen::deref(outputDesc)); + auto outputTv = miopen::get_inner_expanded_tv<4>(miopen::deref(outputDesc)); + auto indicesTv = miopen::get_inner_expanded_tv<4>(miopen::deref(indicesDesc)); size_t numSlice = inputSize / dimSize; @@ -82,8 +84,10 @@ void mloKthvalueFwdRunHost(TIO* input, std::sort(ids.begin(), ids.end(), [=](size_t x, size_t y) -> bool { return elements[x] < elements[y]; }); - outputHost[slideID] = elements[ids[k - 1]]; - indices[slideID] = ids[k - 1]; + auto output_layout = tensor_layout_t<4>(outputTv, slideID); + auto indices_layout = tensor_layout_t<4>(indicesTv, slideID); + outputHost[outputTv.get_tensor_view_idx(output_layout)] = elements[ids[k - 1]]; + indices[indicesTv.get_tensor_view_idx(indices_layout)] = ids[k - 1]; } } @@ -361,8 +365,14 @@ int KthvalueDriver::RunForwardGPU() template int KthvalueDriver::RunForwardCPU() { - mloKthvalueFwdRunHost( - input.data(), inputDesc, outputHost.data(), outputDesc, indicesHost.data(), k, dim); + mloKthvalueFwdRunHost(input.data(), + inputDesc, + outputHost.data(), + outputDesc, + indicesHost.data(), + indicesDesc, + k, + dim); return miopenStatusSuccess; } diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp index 0f3ea7eb39..866cac1be7 100644 --- a/src/include/miopen/kthvalue/problem_description.hpp +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -40,13 +40,13 @@ struct NetworkConfig; namespace kthvalue { -struct KthvalueFwdProblemDescription : ProblemDescriptionBase +struct FwdProblemDescription : ProblemDescriptionBase { - KthvalueFwdProblemDescription(const TensorDescriptor& inputDesc_, - const TensorDescriptor& outputDesc_, - const TensorDescriptor& indicesDesc_, - int32_t dim_, - size_t k_) + FwdProblemDescription(const TensorDescriptor& inputDesc_, + const TensorDescriptor& outputDesc_, + const TensorDescriptor& indicesDesc_, + int32_t dim_, + size_t k_) : inputDesc(inputDesc_), outputDesc(outputDesc_), indicesDesc(indicesDesc_), diff --git a/src/include/miopen/kthvalue/solvers.hpp b/src/include/miopen/kthvalue/solvers.hpp index ab17daae83..7e90192c04 100644 --- a/src/include/miopen/kthvalue/solvers.hpp +++ b/src/include/miopen/kthvalue/solvers.hpp @@ -37,25 +37,17 @@ namespace solver { namespace kthvalue { using KthvalueFwdSolverBase = - NonTunableSolverBase; + NonTunableSolverBase; struct KthvalueFwd final : KthvalueFwdSolverBase { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool - IsApplicable(const ExecutionContext& context, - const miopen::kthvalue::KthvalueFwdProblemDescription& problem) const override; + bool IsApplicable(const ExecutionContext& context, + const miopen::kthvalue::FwdProblemDescription& problem) const override; - ConvSolution - GetSolution(const ExecutionContext& context, - const miopen::kthvalue::KthvalueFwdProblemDescription& problem) const override; - - std::size_t - GetWorkspaceSize(const ExecutionContext& context, - const miopen::kthvalue::KthvalueFwdProblemDescription& problem) const override; - - bool MayNeedWorkspace() const override { return true; } + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::kthvalue::FwdProblemDescription& problem) const override; }; } // namespace kthvalue diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index 1366c714f4..2e698f0c53 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -18,23 +18,21 @@ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACDTYPEN OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECDTYPEN WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * *******************************************************************************/ -// #include -#include #ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS #include #include #endif #include "float_types.h" -#include "radix.hpp" #include "tensor_view.hpp" #include "warp_shuffle.hpp" +#include "radix.hpp" #ifndef IN_OUT_TYPE #define IN_OUT_TYPE float @@ -64,9 +62,11 @@ #define RADIX_MASK (RADIX_SIZE - 1) #endif -template -__device__ void kthvalueFwd(const DTYPE* input, - DTYPE* output, +#define RADIX_TYPE RadixType::type + +template +__device__ void kthvalueFwd(const TIO* input, + TIO* output, size_t* indices, size_t k, size_t dim_size, @@ -87,8 +87,7 @@ __device__ void kthvalueFwd(const DTYPE* input, size_t gid = blockIdx.x; __shared__ size_t lsum[LOCAL_SIZE][RADIX_SIZE]; - __shared__ size_t smem_count[RADIX_SIZE]; - __shared__ DTYPE lval; + __shared__ FLOAT_ACCUM lval; __shared__ long lidx; size_t counts[RADIX_SIZE]; RADIX_TYPE desired_mask = 0; @@ -107,10 +106,10 @@ __device__ void kthvalueFwd(const DTYPE* input, for(size_t i = lid; i < dim_size; i += LOCAL_SIZE) { size_t input_idx = idx + i * dim_stride; - RADIX_TYPE val = ENCODE(input[input_idx]); + RADIX_TYPE val = encode(CVT_FLOAT2ACCUM(input[input_idx])); if((val & desired_mask) == desired) { - ++counts[GetBitField(val, pos, RADIX_BITS)]; + ++counts[GetBitFieldImpl(val, pos, RADIX_BITS)]; } } @@ -164,11 +163,11 @@ __device__ void kthvalueFwd(const DTYPE* input, // There are multiple answers so we return any of them for(size_t i = lid; i < dim_size; i += LOCAL_SIZE) { - size_t input_idx = idx + i * dim_stride; - DTYPE val_ori = input[input_idx]; - RADIX_TYPE val = ENCODE(val_ori); + size_t input_idx = idx + i * dim_stride; + FLOAT_ACCUM val_ori = CVT_FLOAT2ACCUM(input[input_idx]); + RADIX_TYPE val = encode(val_ori); if((val & desired_mask) == desired && - GetBitField(val, pos, RADIX_BITS) == j) + GetBitFieldImpl(val, pos, RADIX_BITS) == j) { // For case 2, this will be non-deterministic. lval = val_ori; @@ -178,8 +177,9 @@ __device__ void kthvalueFwd(const DTYPE* input, found = true; break; } - desired = SetBitField(desired, j, pos, RADIX_BITS); - desired_mask = SetBitField(desired_mask, RADIX_MASK, pos, RADIX_BITS); + desired = SetBitFieldImpl(desired, j, pos, RADIX_BITS); + desired_mask = + SetBitFieldImpl(desired_mask, RADIX_MASK, pos, RADIX_BITS); break; } k -= counts[j]; @@ -193,7 +193,7 @@ __device__ void kthvalueFwd(const DTYPE* input, { auto output_layout = tensor_layout_t<4>(output_tv, gid); auto indices_layout = tensor_layout_t<4>(indices_tv, gid); - output[output_tv.get_tensor_view_idx(output_layout)] = lval; + output[output_tv.get_tensor_view_idx(output_layout)] = CVT_ACCUM2FLOAT(lval); indices[indices_tv.get_tensor_view_idx(indices_layout)] = lidx; } } diff --git a/src/kernels/float_types.h b/src/kernels/float_types.h index dc29a66a41..7a88112474 100644 --- a/src/kernels/float_types.h +++ b/src/kernels/float_types.h @@ -106,7 +106,7 @@ #define _FLOAT_ACCUM double #endif // __HIP_PLATFORM_AMD__ #define MAX_VAL_ACCUM DBL_MAX -#else // MIOPEN_USE_DOUBLE_ACCUM +#else // MIOPEN_USE_DOUBLE_ACCUM #ifdef __HIP_PLATFORM_AMD__ #define FLOAT_ACCUM float #else diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp index a8a6f10966..429ce9b6c8 100644 --- a/src/kernels/radix.hpp +++ b/src/kernels/radix.hpp @@ -2,12 +2,10 @@ #include #include -// #include "backend/modnn/core/op/device/hip/utils/hip_atomic.h" - -#define ENCODE encode -#define RADIX_TYPE typename RadixType::type -#define GetBitField(x, pos, bits) GetBitFieldImpl(x, pos, bits) -#define SetBitField(x, a, pos, bits) SetBitFieldImpl(x, a, pos, bits) +// #define ENCODE encode +// #define RADIX_TYPE typename RadixType::type +// #define GetBitField(x, pos, bits) GetBitFieldImpl(x, pos, bits) +// #define SetBitField(x, a, pos, bits) SetBitFieldImpl(x, a, pos, bits) #define DEFINE_RADIX_TYPE(DTYPE, cpp_type) \ template <> \ @@ -21,13 +19,9 @@ struct RadixType { }; -// DEFINE_RADIX_TYPE(uint8_t, uint32_t) -// DEFINE_RADIX_TYPE(int8_t, uint32_t) -// DEFINE_RADIX_TYPE(int16_t, uint32_t) DEFINE_RADIX_TYPE(int32_t, uint32_t) DEFINE_RADIX_TYPE(int64_t, uint64_t) DEFINE_RADIX_TYPE(bool, bool) -// DEFINE_RADIX_TYPE(_Float16, uint16_t) DEFINE_RADIX_TYPE(float, uint32_t) DEFINE_RADIX_TYPE(double, uint64_t) @@ -38,18 +32,6 @@ __device__ inline Radix encode(DTYPE v) { return v; } - // else if constexpr(std::is_same::value) - // { - // return v; - // } - // else if constexpr(std::is_same::value) - // { - // return 128u + v; - // } - // else if constexpr(std::is_same::value) - // { - // return 32768u + v; - // } else if constexpr(std::is_same::value) { return 2147483648u + v; @@ -85,18 +67,6 @@ __device__ inline DTYPE decode(Radix v) { return v; } - // else if constexpr(std::is_same::value) - // { - // return v; - // } - // else if constexpr(std::is_same::value) - // { - // return v - 128u; - // } - // else if constexpr(std::is_same::value) - // { - // return v - 32768u; - // } else if constexpr(std::is_same::value) { return v - 2147483648u; diff --git a/src/kthvalue.cpp b/src/kthvalue.cpp index 093841ca32..03d5cb3fb4 100644 --- a/src/kthvalue.cpp +++ b/src/kthvalue.cpp @@ -49,7 +49,7 @@ miopenStatus_t KthvalueForward(Handle& handle, int32_t dim) { const auto problem = - kthvalue::KthvalueFwdProblemDescription{inputDesc, outputDesc, indicesDesc, dim, k}; + kthvalue::FwdProblemDescription{inputDesc, outputDesc, indicesDesc, dim, k}; if(dim < 0) { diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp index a569044dbc..dc812109c0 100644 --- a/src/kthvalue/problem_description.cpp +++ b/src/kthvalue/problem_description.cpp @@ -33,7 +33,7 @@ namespace miopen { namespace kthvalue { -NetworkConfig KthvalueFwdProblemDescription::MakeNetworkConfig() const +NetworkConfig FwdProblemDescription::MakeNetworkConfig() const { auto input_dtype = inputDesc.GetType(); auto size = inputDesc.GetElementSize(); diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp index 8eb494cf02..b311155eee 100644 --- a/src/solver/kthvalue/forward_kthvalue.cpp +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -27,6 +27,7 @@ #include "miopen/errors.hpp" #include "miopen/kthvalue/problem_description.hpp" #include "miopen/miopen.h" +#include "miopen/tensor.hpp" #include "miopen/tensor_view_utils.hpp" #include #include @@ -43,17 +44,26 @@ namespace solver { namespace kthvalue { +bool IsImprovementOverROCm(const miopen::kthvalue::FwdProblemDescription& problem) +{ + TensorDescriptor inputDesc = problem.GetInputDesc(); + size_t dimSize = inputDesc.GetLengths()[problem.GetDim()]; + + return dimSize >= 300; +} + bool KthvalueFwd::IsApplicable(const ExecutionContext& /*context*/, - const miopen::kthvalue::KthvalueFwdProblemDescription& problem) const + const miopen::kthvalue::FwdProblemDescription& problem) const { if(problem.GetInputDesc().GetSize() > 5) return false; + if(!IsImprovementOverROCm(problem)) + return false; return true; } -ConvSolution -KthvalueFwd::GetSolution(const ExecutionContext& context, - const miopen::kthvalue::KthvalueFwdProblemDescription& problem) const +ConvSolution KthvalueFwd::GetSolution(const ExecutionContext& context, + const miopen::kthvalue::FwdProblemDescription& problem) const { std::ignore = context; auto result = ConvSolution{miopenStatusSuccess}; @@ -123,13 +133,6 @@ KthvalueFwd::GetSolution(const ExecutionContext& context, return result; } -std::size_t KthvalueFwd::GetWorkspaceSize( - const ExecutionContext& /*context*/, - const miopen::kthvalue::KthvalueFwdProblemDescription& /*problem*/) const -{ - return 0; -} - } // namespace kthvalue } // namespace solver From 0f0e59f5ae20e1c3a6aa81bfb959b061c85e19df Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Fri, 28 Jun 2024 14:00:17 +0700 Subject: [PATCH 04/28] validate tensor length in prob description --- driver/kthvalue_driver.hpp | 52 ++--- include/miopen/miopen.h | 10 +- src/include/miopen/kthvalue.hpp | 5 +- src/include/miopen/kthvalue/invoke_params.hpp | 14 +- .../miopen/kthvalue/problem_description.hpp | 50 ++++- .../miopen/reduce/problem_description.hpp | 2 +- src/kernels/MIOpenKthvalue.cpp | 12 +- src/kernels/float_types.h | 2 +- src/kthvalue.cpp | 32 ++- src/kthvalue/problem_description.cpp | 4 +- src/kthvalue_api.cpp | 23 +-- src/solver/kthvalue/forward_kthvalue.cpp | 7 +- test/cpu_kthvalue.hpp | 54 +++++ test/gtest/kthvalue.cpp | 114 ++++++++++ test/gtest/kthvalue.hpp | 194 ++++++++++++++++++ 15 files changed, 485 insertions(+), 90 deletions(-) create mode 100644 test/cpu_kthvalue.hpp create mode 100644 test/gtest/kthvalue.cpp create mode 100644 test/gtest/kthvalue.hpp diff --git a/driver/kthvalue_driver.hpp b/driver/kthvalue_driver.hpp index f2922f031b..d35e6e5208 100644 --- a/driver/kthvalue_driver.hpp +++ b/driver/kthvalue_driver.hpp @@ -58,12 +58,12 @@ void mloKthvalueFwdRunHost(TIO* input, size_t dimStride = inputDesc.GetStrides()[dim]; auto inputTv = miopen::get_inner_expanded_tv<5>(miopen::deref(pInputDesc)); auto inputTvWithoutDim = miopen::get_tv_without_dim<5, 4>(inputTv, dim); - auto outputTv = miopen::get_inner_expanded_tv<4>(miopen::deref(outputDesc)); - auto indicesTv = miopen::get_inner_expanded_tv<4>(miopen::deref(indicesDesc)); + auto outputTv = miopen::get_inner_expanded_tv<5>(miopen::deref(outputDesc)); + auto indicesTv = miopen::get_inner_expanded_tv<5>(miopen::deref(indicesDesc)); size_t numSlice = inputSize / dimSize; - std::vector elements; + std::vector elements; std::vector ids(dimSize); for(int i = 0; i < dimSize; ++i) { @@ -78,16 +78,17 @@ void mloKthvalueFwdRunHost(TIO* input, for(int j = 0; j < dimSize; ++j) { - elements.push_back(input[idx + j * dimStride]); + elements.push_back(static_cast(input[idx + j * dimStride])); } std::sort(ids.begin(), ids.end(), [=](size_t x, size_t y) -> bool { return elements[x] < elements[y]; }); - auto output_layout = tensor_layout_t<4>(outputTv, slideID); - auto indices_layout = tensor_layout_t<4>(indicesTv, slideID); - outputHost[outputTv.get_tensor_view_idx(output_layout)] = elements[ids[k - 1]]; - indices[indicesTv.get_tensor_view_idx(indices_layout)] = ids[k - 1]; + auto output_layout = tensor_layout_t<5>(outputTv, slideID); + auto indices_layout = tensor_layout_t<5>(indicesTv, slideID); + outputHost[outputTv.get_tensor_view_idx(output_layout)] = + static_cast(elements[ids[k - 1]]); + indices[indicesTv.get_tensor_view_idx(indices_layout)] = ids[k - 1]; } } @@ -146,7 +147,6 @@ class KthvalueDriver : public Driver std::unique_ptr output_dev; std::unique_ptr doutput_dev; std::unique_ptr dinput_dev; - std::unique_ptr workspace_dev; std::vector input; std::vector indices; @@ -156,13 +156,11 @@ class KthvalueDriver : public Driver std::vector doutput; std::vector dinput; std::vector dinputHost; - std::vector workspace; bool isContiguous; int dim; size_t k; - - size_t workSpaceSizeInBytes; + bool keepDim; }; template @@ -172,6 +170,7 @@ int KthvalueDriver::ParseCmdLineArgs(int argc, char* argv[]) isContiguous = inflags.GetValueInt("is-contiguous") == 1 ? true : false; k = inflags.GetValueInt("k"); dim = inflags.GetValueInt("dim"); + keepDim = inflags.GetValueInt("keep-dim") == 1 ? true : false; auto inDims = inflags.GetValueTensor("dim-lengths").lengths; int num_dim = inDims.size(); if(dim < -num_dim || dim >= num_dim) @@ -197,13 +196,22 @@ int KthvalueDriver::GetandSetData() { dim += inDims.size(); } - outDims.erase(outDims.begin() + dim); + if(!keepDim) + { + outDims.erase(outDims.begin() + dim); + if(outDims.empty()) + outDims.push_back(1); + } + else + { + outDims[dim] = 1; + } SetTensorNd(inputDesc, inDims, inStride, data_type); SetTensorNd(doutputDesc, outDims, data_type); SetTensorNd(dinputDesc, inDims, data_type); SetTensorNd(outputDesc, outDims, data_type); - // miopenDataType_t doesn't support size_t tensor, I use double instead (both types use 64 bits) + // miopenDataType_t doesn't support size_t, I use double instead (both types use 64 bits) SetTensorNd(indicesDesc, outDims, miopen_type{}); return 0; @@ -232,6 +240,11 @@ int KthvalueDriver::AddCmdLineArgs() "dim-lengths", 'D', "256x4x2", "The dimensional lengths of the input tensor"); inflags.AddInputFlag("k", 'k', "1", "dim (Default=1)", "int"); inflags.AddInputFlag("dim", 'd', "-1", "dim (Default=-1)", "int"); + inflags.AddInputFlag("keep-dim", + 'K', + "0", + "Whether the output tensor has dim retained or not (Default=0)", + "int"); inflags.AddInputFlag("is-contiguous", 'c', "1", "is-contiguous (Default=1)", "int"); inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int"); inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int"); @@ -259,11 +272,6 @@ int KthvalueDriver::AllocateBuffersAndCopy() doutput_dev = std::unique_ptr(new GPUMem(ctx, dO_sz, sizeof(TIO))); dinput_dev = std::unique_ptr(new GPUMem(ctx, dI_sz, sizeof(TIO))); - // miopenGetKthvalueForwardWorkspaceSize(handle, inputDesc, outputDesc, &workSpaceSizeInBytes); - workSpaceSizeInBytes = 0; - workspace_dev = - std::unique_ptr(new GPUMem(ctx, workSpaceSizeInBytes / sizeof(TIO), sizeof(TIO))); - input = std::vector(in_sz, static_cast(0)); indices = std::vector(idx_sz, 0); indicesHost = std::vector(idx_sz, 0); @@ -272,7 +280,6 @@ int KthvalueDriver::AllocateBuffersAndCopy() doutput = std::vector(dO_sz, static_cast(0)); dinput = std::vector(dI_sz, static_cast(0)); dinputHost = std::vector(dI_sz, static_cast(0)); - workspace = std::vector(workSpaceSizeInBytes / sizeof(TIO), static_cast(0)); for(int i = 0; i < in_sz; i++) { @@ -302,9 +309,6 @@ int KthvalueDriver::AllocateBuffersAndCopy() if(dinput_dev->ToGPU(GetStream(), dinput.data()) != 0) std::cerr << "Error copying (dI) to GPU, size: " << dinput_dev->GetSize() << std::endl; - if(workspace_dev->ToGPU(GetStream(), workspace.data()) != 0) - std::cerr << "Error copying (dI) to GPU, size: " << workspace_dev->GetSize() << std::endl; - return miopenStatusSuccess; } @@ -320,8 +324,6 @@ int KthvalueDriver::RunForwardGPU() for(int i = 0; i < inflags.GetValueInt("iter"); i++) { miopenKthvalueForward(GetHandle(), - workspace_dev->GetMem(), - workSpaceSizeInBytes, inputDesc, input_dev->GetMem(), outputDesc, diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index ce5fa5d4d2..3cdde80edb 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -524,7 +524,7 @@ typedef enum miopenActivationABS = 5, /*!< Absolute value \f$abs(x)\f$ */ miopenActivationPOWER = 6, /*!< Scaled and shifted power \f$(\alpha + \beta * x)^{gamma}\f$ */ miopenActivationCLIPPEDRELU = - 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ + 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ miopenActivationLEAKYRELU = 8, /*!< Leaky Rectified Linear Unit \f$ \alpha * x | x <= 0; x | x > 0 \f$ */ miopenActivationELU = @@ -7684,8 +7684,6 @@ MIOPEN_EXPORT miopenStatus_t miopenRoPEBackward(miopenHandle_t handle, /*! @brief Execute a Kthvalue forward layer * * @param handle MIOpen handle (input) - * @param workspace Address of the allocated workspace data (input) - * @param workspaceSizeInBytes Size in bytes of the allocated workspace data (input) * @param inputDesc Tensor descriptor for input tensor (input) * @param input Data tensor input (input) * @param outputDesc Tensor descriptor for output tensor (input) @@ -7693,11 +7691,10 @@ MIOPEN_EXPORT miopenStatus_t miopenRoPEBackward(miopenHandle_t handle, * @param indices Data tensor index (output) * @param k The k-th smallest element(input) * @param dim The dimension to find the kth value along(input) + * @param keepDim Whether the output tensor has dim retained or not(input) * @return miopenStatus_t */ MIOPEN_EXPORT miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, - void* workspace, - size_t workspaceSizeInBytes, miopenTensorDescriptor_t inputDesc, const void* input, miopenTensorDescriptor_t outputDesc, @@ -7705,7 +7702,8 @@ MIOPEN_EXPORT miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, miopenTensorDescriptor_t indicesDesc, size_t* indices, size_t k, - int32_t dim = -1); + int32_t dim = -1, + bool keepDim = false); /** @} */ // CLOSEOUT kthvalue DOXYGEN GROUP diff --git a/src/include/miopen/kthvalue.hpp b/src/include/miopen/kthvalue.hpp index d7dd6f1332..990ec77ddb 100644 --- a/src/include/miopen/kthvalue.hpp +++ b/src/include/miopen/kthvalue.hpp @@ -34,8 +34,6 @@ struct Handle; struct TensorDescriptor; miopenStatus_t KthvalueForward(Handle& handle, - Data_t workspace, - size_t workspaceSizeInBytes, const TensorDescriptor& inputDesc, ConstData_t input, const TensorDescriptor& outputDesc, @@ -43,7 +41,8 @@ miopenStatus_t KthvalueForward(Handle& handle, const TensorDescriptor& indicesDesc, size_t* indices, size_t k, - int32_t dim); + int32_t dim, + bool keepDim); } // namespace miopen #endif // MIOPEN_KTHVALUE_HPP_ diff --git a/src/include/miopen/kthvalue/invoke_params.hpp b/src/include/miopen/kthvalue/invoke_params.hpp index 6cd6896d9b..08d80c510f 100644 --- a/src/include/miopen/kthvalue/invoke_params.hpp +++ b/src/include/miopen/kthvalue/invoke_params.hpp @@ -42,15 +42,13 @@ struct KthvalueInvokeParams : public miopen::InvokeParams const TensorDescriptor* inputDesc = nullptr; - Data_t workspace = nullptr; - std::size_t workspace_size = 0; - ConstData_t input = nullptr; + ConstData_t input = nullptr; - size_t k = 1; - int32_t dim = 0; - - std::size_t GetWorkspaceSize() const { return workspace_size; } - Data_t GetWorkspace() const { return workspace; } + size_t k = 1; + int32_t dim = 0; + bool keepDim = false; + std::size_t GetWorkspaceSize() const { return 0; } + Data_t GetWorkspace() const { return nullptr; } }; struct FwdInvokeParams : KthvalueInvokeParams diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp index 866cac1be7..c6f0e4d87c 100644 --- a/src/include/miopen/kthvalue/problem_description.hpp +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -58,11 +58,57 @@ struct FwdProblemDescription : ProblemDescriptionBase MIOPEN_THROW(miopenStatusBadParm, "Kthvalue: k must be less than the size of the dimension"); } - int num_dim = inputDesc.GetSize(); - if(dim < -num_dim || dim >= num_dim) + if(dim < 0 || dim >= inputDesc.GetSize()) { MIOPEN_THROW(miopenStatusBadParm, "Kthvalue: dim doesn't not exist"); } + if(inputDesc.GetType() != outputDesc.GetType()) + { + MIOPEN_THROW(miopenStatusBadParm, "Reduce: Input, output tensor types do not match."); + } + if(!IsRightLength()) + { + MIOPEN_THROW(miopenStatusBadParm, + "Reduce: Input and output tensor dimension lengths do not match."); + } + if(!IsSameLength()) + { + MIOPEN_THROW(miopenStatusBadParm, + "Reduce: Output and indices tensor dimension lengths do not match."); + } + } + + bool IsRightLength() const + { + if(inputDesc.GetLengths().size() == 1) + return true; + + int32_t posOut = 0; + for(int32_t i = 0; i < inputDesc.GetLengths().size(); i++) + { + if(i == dim) + continue; + + if(inputDesc.GetLengths()[i] != outputDesc.GetLengths()[posOut]) + { + return false; + } + + posOut++; + } + return true; + } + + bool IsSameLength() const + { + for(int32_t i = 0; i < outputDesc.GetLengths().size(); i++) + { + if(outputDesc.GetLengths()[i] != indicesDesc.GetLengths()[i]) + { + return false; + } + } + return true; } const TensorDescriptor& GetInputDesc() const { return inputDesc; } diff --git a/src/include/miopen/reduce/problem_description.hpp b/src/include/miopen/reduce/problem_description.hpp index 7e327ba565..294c79777e 100644 --- a/src/include/miopen/reduce/problem_description.hpp +++ b/src/include/miopen/reduce/problem_description.hpp @@ -239,7 +239,7 @@ struct ProblemDescriptionCalculation : ProblemDescriptionBase bool IsValidDim() const { - if((dim < 0) || (dim > xDesc.GetLengths().size())) + if((dim < 0) || (dim >= xDesc.GetLengths().size())) { MIOPEN_THROW( miopenStatusBadParm, diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index 2e698f0c53..dbf5be350d 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -72,8 +72,8 @@ __device__ void kthvalueFwd(const TIO* input, size_t dim_size, size_t dim_stride, tensor_view_t<4> input_tv, - tensor_view_t<4> output_tv, - tensor_view_t<4> indices_tv) + tensor_view_t<5> output_tv, + tensor_view_t<5> indices_tv) { /* * Example) @@ -191,8 +191,8 @@ __device__ void kthvalueFwd(const TIO* input, __syncthreads(); if(lid == 0) { - auto output_layout = tensor_layout_t<4>(output_tv, gid); - auto indices_layout = tensor_layout_t<4>(indices_tv, gid); + auto output_layout = tensor_layout_t<5>(output_tv, gid); + auto indices_layout = tensor_layout_t<5>(indices_tv, gid); output[output_tv.get_tensor_view_idx(output_layout)] = CVT_ACCUM2FLOAT(lval); indices[indices_tv.get_tensor_view_idx(indices_layout)] = lidx; } @@ -205,8 +205,8 @@ extern "C" __global__ void KthvalueFwd(const IN_OUT_TYPE* input, size_t dim_size, size_t dim_stride, tensor_view_t<4> input_tv, - tensor_view_t<4> output_tv, - tensor_view_t<4> indices_tv) + tensor_view_t<5> output_tv, + tensor_view_t<5> indices_tv) { kthvalueFwd( input, output, indices, k, dim_size, dim_stride, input_tv, output_tv, indices_tv); diff --git a/src/kernels/float_types.h b/src/kernels/float_types.h index 7a88112474..dc29a66a41 100644 --- a/src/kernels/float_types.h +++ b/src/kernels/float_types.h @@ -106,7 +106,7 @@ #define _FLOAT_ACCUM double #endif // __HIP_PLATFORM_AMD__ #define MAX_VAL_ACCUM DBL_MAX -#else // MIOPEN_USE_DOUBLE_ACCUM +#else // MIOPEN_USE_DOUBLE_ACCUM #ifdef __HIP_PLATFORM_AMD__ #define FLOAT_ACCUM float #else diff --git a/src/kthvalue.cpp b/src/kthvalue.cpp index 03d5cb3fb4..c098b3ee5b 100644 --- a/src/kthvalue.cpp +++ b/src/kthvalue.cpp @@ -37,8 +37,6 @@ namespace miopen { miopenStatus_t KthvalueForward(Handle& handle, - Data_t workspace, - size_t workspaceSizeInBytes, const TensorDescriptor& inputDesc, ConstData_t input, const TensorDescriptor& outputDesc, @@ -46,28 +44,28 @@ miopenStatus_t KthvalueForward(Handle& handle, const TensorDescriptor& indicesDesc, size_t* indices, size_t k, - int32_t dim) + int32_t dim, + bool keepDim) { - const auto problem = - kthvalue::FwdProblemDescription{inputDesc, outputDesc, indicesDesc, dim, k}; - if(dim < 0) { dim += indicesDesc.GetSize(); } + const auto problem = + kthvalue::FwdProblemDescription{inputDesc, outputDesc, indicesDesc, dim, k}; + const auto invoke_params = [&]() { - auto tmp = kthvalue::FwdInvokeParams{}; - tmp.inputDesc = &inputDesc; - tmp.outputDesc = &outputDesc; - tmp.indicesDesc = &indicesDesc; - tmp.input = input; - tmp.indices = indices; - tmp.output = output; - tmp.workspace = workspace; - tmp.workspace_size = workspaceSizeInBytes; - tmp.k = k; - tmp.dim = dim; + auto tmp = kthvalue::FwdInvokeParams{}; + tmp.inputDesc = &inputDesc; + tmp.outputDesc = &outputDesc; + tmp.indicesDesc = &indicesDesc; + tmp.input = input; + tmp.indices = indices; + tmp.output = output; + tmp.k = k; + tmp.dim = dim; + tmp.keepDim = keepDim; return tmp; }(); diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp index dc812109c0..0d0086af6f 100644 --- a/src/kthvalue/problem_description.cpp +++ b/src/kthvalue/problem_description.cpp @@ -37,12 +37,14 @@ NetworkConfig FwdProblemDescription::MakeNetworkConfig() const { auto input_dtype = inputDesc.GetType(); auto size = inputDesc.GetElementSize(); + auto dim_size = inputDesc.GetLengths()[dim]; + auto output_size = size / dim_size; std::ostringstream ss; ss << "kthvalue_fwd"; ss << "i_dtype" << input_dtype; - ss << "size" << size; + ss << "output_size" << output_size; return NetworkConfig{ss.str()}; } diff --git a/src/kthvalue_api.cpp b/src/kthvalue_api.cpp index 7aae2d4d2a..db1833d369 100644 --- a/src/kthvalue_api.cpp +++ b/src/kthvalue_api.cpp @@ -74,8 +74,6 @@ static void LogCmdKthvalue(const miopenTensorDescriptor_t inputDesc, bool is_fwd } extern "C" miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, - void* workspace, - size_t workspaceSizeInBytes, miopenTensorDescriptor_t inputDesc, const void* input, miopenTensorDescriptor_t outputDesc, @@ -83,26 +81,16 @@ extern "C" miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, miopenTensorDescriptor_t indicesDesc, size_t* indices, size_t k, - int32_t dim) + int32_t dim, + bool keepDim) { - MIOPEN_LOG_FUNCTION(handle, - workspace, - workspaceSizeInBytes, - inputDesc, - input, - outputDesc, - output, - indicesDesc, - indices, - k, - dim); + MIOPEN_LOG_FUNCTION( + handle, inputDesc, input, outputDesc, output, indicesDesc, indices, k, dim, keepDim); LogCmdKthvalue(inputDesc, true); return miopen::try_([&] { miopen::KthvalueForward(miopen::deref(handle), - DataCast(workspace), - workspaceSizeInBytes, miopen::deref(inputDesc), DataCast(input), miopen::deref(outputDesc), @@ -110,6 +98,7 @@ extern "C" miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, miopen::deref(indicesDesc), indices, k, - dim); + dim, + keepDim); }); } diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp index b311155eee..9998933a88 100644 --- a/src/solver/kthvalue/forward_kthvalue.cpp +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -113,11 +113,12 @@ ConvSolution KthvalueFwd::GetSolution(const ExecutionContext& context, size_t dimSize = params.inputDesc->GetLengths()[params.dim]; size_t dimStride = params.inputDesc->GetStrides()[params.dim]; - auto input_tv = get_inner_expanded_tv<5>(deref(params.inputDesc)); - auto output_tv = get_inner_expanded_tv<4>(deref(params.outputDesc)); - auto indices_tv = get_inner_expanded_tv<4>(deref(params.indicesDesc)); + auto input_tv = get_inner_expanded_tv<5>(deref(params.inputDesc)); auto input_tv_without_selected_dim = get_tv_without_dim<5, 4>(input_tv, params.dim); + auto output_tv = get_inner_expanded_tv<5>(deref(params.outputDesc)); + auto indices_tv = get_inner_expanded_tv<5>(deref(params.indicesDesc)); + kernel(params.input, params.output, params.indices, diff --git a/test/cpu_kthvalue.hpp b/test/cpu_kthvalue.hpp new file mode 100644 index 0000000000..269f22e3c4 --- /dev/null +++ b/test/cpu_kthvalue.hpp @@ -0,0 +1,54 @@ +#pragma once + +#include "miopen/tensor.hpp" +#include "tensor_holder.hpp" +#include "tensor_view.hpp" +#include "miopen/tensor_view_utils.hpp" +#include + +template +void cpu_kthvalue(tensor input, + tensor& outputHost, + std::vector& indices, + miopen::TensorDescriptor indiceDesc, + size_t k, + int dim) +{ + size_t inputSize = input.desc.GetElementSize(); + size_t dimSize = input.desc.GetLengths()[dim]; + size_t dimStride = input.desc.GetStrides()[dim]; + auto inputTv = miopen::get_inner_expanded_tv<5>(input.desc); + auto inputTvWithoutDim = miopen::get_tv_without_dim<5, 4>(inputTv, dim); + auto outputTv = miopen::get_inner_expanded_tv<5>(outputHost.desc); + auto indicesTv = miopen::get_inner_expanded_tv<5>(indiceDesc); + + size_t numSlice = inputSize / dimSize; + + std::vector elements; + std::vector ids(dimSize); + for(int i = 0; i < dimSize; ++i) + { + ids[i] = i; + } + + for(int slideID = 0; slideID < numSlice; ++slideID) + { + elements.clear(); + tensor_layout_t<4> layout(inputTvWithoutDim, slideID); + auto idx = inputTvWithoutDim.get_tensor_view_idx(layout); + + for(int j = 0; j < dimSize; ++j) + { + elements.push_back(static_cast(input[idx + j * dimStride])); + } + + std::sort(ids.begin(), ids.end(), [=](size_t x, size_t y) -> bool { + return elements[x] < elements[y]; + }); + auto output_layout = tensor_layout_t<5>(outputTv, slideID); + auto indices_layout = tensor_layout_t<5>(indicesTv, slideID); + outputHost[outputTv.get_tensor_view_idx(output_layout)] = + static_cast(elements[ids[k - 1]]); + indices[indicesTv.get_tensor_view_idx(indices_layout)] = ids[k - 1]; + } +} diff --git a/test/gtest/kthvalue.cpp b/test/gtest/kthvalue.cpp new file mode 100644 index 0000000000..8fbbd10478 --- /dev/null +++ b/test/gtest/kthvalue.cpp @@ -0,0 +1,114 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include "kthvalue.hpp" +#include "miopen/bfloat16.hpp" +#include "tensor_holder.hpp" +#include + +MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL) + +namespace kthvalue { + +std::string GetFloatArg() +{ + const auto& tmp = miopen::GetStringEnv(ENV(MIOPEN_TEST_FLOAT_ARG)); + if(tmp.empty()) + { + return ""; + } + return tmp; +} + +struct KthvalueForwardTestFloat32 : KthvalueFwdTest +{ +}; + +struct KthvalueForwardTestFloat16 : KthvalueFwdTest +{ +}; + +struct KthvalueForwardTestBFloat16 : KthvalueFwdTest +{ +}; + +using namespace kthvalue; + +TEST_P(KthvalueForwardTestFloat32, KthvalueForwardTest) +{ + if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || + (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--float"))) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; + +// TEST_P(KthvalueForwardTestFloat16, KthvalueForwardTest) +// { +// if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || +// (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--half"))) +// { +// RunTest(); +// Verify(); +// } +// else +// { +// GTEST_SKIP(); +// } +// }; + +// TEST_P(KthvalueForwardTestBFloat16, KthvalueForwardTest) +// { +// if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || +// (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--bfloat16"))) +// { +// RunTest(); +// Verify(); +// } +// else +// { +// GTEST_SKIP(); +// } +// }; + +INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, + KthvalueForwardTestFloat32, + testing::ValuesIn(KthvalueTestConfigs())); + +// INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, +// KthvalueForwardTestFloat16, +// testing::ValuesIn(KthvalueTestConfigs())); + +// INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, +// KthvalueForwardTestBFloat16, +// testing::ValuesIn(KthvalueTestConfigs())); +} // namespace kthvalue diff --git a/test/gtest/kthvalue.hpp b/test/gtest/kthvalue.hpp new file mode 100644 index 0000000000..8d6b15610a --- /dev/null +++ b/test/gtest/kthvalue.hpp @@ -0,0 +1,194 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "../driver/tensor_driver.hpp" +#include "cpu_kthvalue.hpp" +#include "get_handle.hpp" +#include "miopen/allocator.hpp" +#include "miopen/tensor.hpp" +#include "random.hpp" +#include "tensor_holder.hpp" +#include "verify.hpp" +#include +#include +#include +#include + +struct KthvalueTestCase +{ + std::vector dims; + bool isContiguous; + int32_t dim; + size_t k; + bool keepDim; + friend std::ostream& operator<<(std::ostream& os, const KthvalueTestCase& tc) + { + os << "dims: "; + for(auto dim_size : tc.dims) + { + os << dim_size << " "; + } + return os << "is_contiguous " << tc.isContiguous << " selected_dim " << tc.dim << " k " + << tc.k << " keepDim " << tc.keepDim; + } + + std::vector GetDims() const { return dims; } + + KthvalueTestCase() {} + + KthvalueTestCase(std::vector dims_, + size_t k_, + bool isContiguous_ = true, + int32_t dim_ = -1, + bool keepDim_ = false) + : dims(dims_), isContiguous(isContiguous_), dim(dim_), k(k_), keepDim(keepDim_) + { + } + + std::vector ComputeStrides(std::vector inputDim) const + { + if(!isContiguous) + std::swap(inputDim.front(), inputDim.back()); + std::vector strides(inputDim.size()); + strides.back() = 1; + for(int i = inputDim.size() - 2; i >= 0; --i) + strides[i] = strides[i + 1] * inputDim[i + 1]; + if(!isContiguous) + std::swap(strides.front(), strides.back()); + return strides; + } +}; + +inline std::vector KthvalueTestConfigs() +{ + return { + KthvalueTestCase({4000}, 3), // 1D cont + KthvalueTestCase({100, 500}, 10), // 2D cont + KthvalueTestCase({100, 500}, 10, false), // 2D non-cont + KthvalueTestCase({10, 20, 300}, 100, false), // 3D cont + KthvalueTestCase({10, 20, 300}, 1, false), // 3D non-cont + KthvalueTestCase({8, 3, 2000, 10}, 2000, true, 2), // 4D cont + KthvalueTestCase({8, 3, 2000, 10}, 2000, false, 2), // 4D non-cont + KthvalueTestCase({2, 2, 3000, 4, 10}, 1, true, 2), // 5D cont + KthvalueTestCase({2, 2, 3000, 4, 10}, 1, false, 2), // 5D cont + }; +} + +template +struct KthvalueFwdTest : public ::testing::TestWithParam +{ +protected: + void SetUp() override + { + auto&& handle = get_handle(); + config = GetParam(); + + auto inDims = config.GetDims(); + auto inStrides = config.ComputeStrides(inDims); + if(config.dim < 0) + { + config.dim += inDims.size(); + } + EXPECT_TRUE(config.dim >= 0 and config.dim < inDims.size()); + auto outDims = config.GetDims(); + if(!config.keepDim) + { + outDims.erase(outDims.begin() + config.dim); + if(outDims.empty()) + outDims.push_back(1); + } + else + { + outDims[config.dim] = 1; + } + + auto in_gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(0.1, 200); }; + input = tensor{inDims, inStrides}.generate(in_gen_value); + + output = tensor{outDims}; + std::fill(output.begin(), output.end(), 0); + + outputHost = tensor{outDims}; + std::fill(outputHost.begin(), outputHost.end(), 0); + + // miopenDataType_t doesn't support size_t, I use double instead (both types use 64 bits) + indicesDesc = miopen::TensorDescriptor(miopenDouble, outDims); + size_t outputSize = indicesDesc.GetElementSize(); + indices.resize(outputSize); + indicesHost.resize(outputSize); + + input_dev = handle.Write(input.data); + indices_dev = handle.Write(indices); + output_dev = handle.Write(output.data); + } + + void RunTest() + { + auto&& handle = get_handle(); + + miopenStatus_t status; + + status = miopen::KthvalueForward(handle, + input.desc, + input_dev.get(), + output.desc, + output_dev.get(), + indicesDesc, + (size_t*)indices_dev.get(), + config.k, + config.dim, + config.keepDim); + cpu_kthvalue(input, outputHost, indicesHost, indicesDesc, config.k, config.dim); + + EXPECT_EQ(status, miopenStatusSuccess); + output.data = handle.Read(output_dev, output.data.size()); + indices = handle.Read(indices_dev, indices.size()); + } + + void Verify() + { + double threshold = std::numeric_limits::epsilon(); + + auto error = miopen::rms_range(outputHost, output); + + EXPECT_TRUE(miopen::range_distance(outputHost) == miopen::range_distance(output)); + EXPECT_TRUE(error < threshold * 10) << "Error output beyond tolerance Error: " << error + << ", Thresholdx10: " << threshold * 10; + } + KthvalueTestCase config; + + tensor input; + // tensor holder doesn't support size_t, so I use vector instead + miopen::TensorDescriptor indicesDesc; + std::vector indices; + tensor output; + + tensor outputHost; + std::vector indicesHost; + + miopen::Allocator::ManageDataPtr input_dev; + miopen::Allocator::ManageDataPtr indices_dev; + miopen::Allocator::ManageDataPtr output_dev; +}; From 7853ddbcc2dc42b31c3fa15fb79173430d23499f Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Mon, 1 Jul 2024 10:11:22 +0700 Subject: [PATCH 05/28] add keepDim param --- driver/kthvalue_driver.hpp | 3 +- include/miopen/miopen.h | 2 +- .../miopen/kthvalue/problem_description.hpp | 31 ++++++++++++++++--- src/kernels/MIOpenKthvalue.cpp | 10 +++--- src/kernels/radix.hpp | 1 + src/kthvalue.cpp | 2 +- src/kthvalue/problem_description.cpp | 10 +++--- src/solver/kthvalue/forward_kthvalue.cpp | 2 ++ test/gtest/kthvalue.hpp | 22 ++++++------- 9 files changed, 56 insertions(+), 27 deletions(-) diff --git a/driver/kthvalue_driver.hpp b/driver/kthvalue_driver.hpp index d35e6e5208..bdf0509fe9 100644 --- a/driver/kthvalue_driver.hpp +++ b/driver/kthvalue_driver.hpp @@ -331,7 +331,8 @@ int KthvalueDriver::RunForwardGPU() indicesDesc, (size_t*)indices_dev->GetMem(), k, - dim); + dim, + keepDim); float time = 0.0; miopenGetKernelTime(GetHandle(), &time); kernel_total_time += time; diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 3cdde80edb..f8c8ad7774 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -524,7 +524,7 @@ typedef enum miopenActivationABS = 5, /*!< Absolute value \f$abs(x)\f$ */ miopenActivationPOWER = 6, /*!< Scaled and shifted power \f$(\alpha + \beta * x)^{gamma}\f$ */ miopenActivationCLIPPEDRELU = - 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ + 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ miopenActivationLEAKYRELU = 8, /*!< Leaky Rectified Linear Unit \f$ \alpha * x | x <= 0; x | x > 0 \f$ */ miopenActivationELU = diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp index c6f0e4d87c..14416762ad 100644 --- a/src/include/miopen/kthvalue/problem_description.hpp +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -46,12 +46,14 @@ struct FwdProblemDescription : ProblemDescriptionBase const TensorDescriptor& outputDesc_, const TensorDescriptor& indicesDesc_, int32_t dim_, - size_t k_) + size_t k_, + bool keepDim_) : inputDesc(inputDesc_), outputDesc(outputDesc_), indicesDesc(indicesDesc_), dim(dim_), - k(k_) + k(k_), + keepDim(keepDim_) { if(k > inputDesc.GetLengths()[dim]) { @@ -76,6 +78,7 @@ struct FwdProblemDescription : ProblemDescriptionBase MIOPEN_THROW(miopenStatusBadParm, "Reduce: Output and indices tensor dimension lengths do not match."); } + isInputContiguous = checkContiguous(inputDesc); } bool IsRightLength() const @@ -87,9 +90,13 @@ struct FwdProblemDescription : ProblemDescriptionBase for(int32_t i = 0; i < inputDesc.GetLengths().size(); i++) { if(i == dim) - continue; - - if(inputDesc.GetLengths()[i] != outputDesc.GetLengths()[posOut]) + { + if(!keepDim) + continue; + if(outputDesc.GetLengths()[posOut] != 1) + return false; + } + else if(inputDesc.GetLengths()[i] != outputDesc.GetLengths()[posOut]) { return false; } @@ -111,6 +118,18 @@ struct FwdProblemDescription : ProblemDescriptionBase return true; } + bool checkContiguous(const TensorDescriptor& tensorDesc) + { + size_t stride = 1; + for(int i = tensorDesc.GetSize() - 1; i >= 0; --i) + { + if(stride != tensorDesc.GetStrides()[i]) + return false; + stride *= tensorDesc.GetLengths()[i]; + } + return true; + } + const TensorDescriptor& GetInputDesc() const { return inputDesc; } const TensorDescriptor& GetOutputDesc() const { return outputDesc; } const TensorDescriptor& GetIndicesDesc() const { return indicesDesc; } @@ -124,6 +143,8 @@ struct FwdProblemDescription : ProblemDescriptionBase TensorDescriptor indicesDesc; int32_t dim; size_t k; + bool isInputContiguous; + bool keepDim; }; } // namespace kthvalue diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index dbf5be350d..9eb9595477 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -191,10 +191,12 @@ __device__ void kthvalueFwd(const TIO* input, __syncthreads(); if(lid == 0) { - auto output_layout = tensor_layout_t<5>(output_tv, gid); - auto indices_layout = tensor_layout_t<5>(indices_tv, gid); - output[output_tv.get_tensor_view_idx(output_layout)] = CVT_ACCUM2FLOAT(lval); - indices[indices_tv.get_tensor_view_idx(indices_layout)] = lidx; + // auto output_layout = tensor_layout_t<5>(output_tv, + // gid); auto indices_layout = + // tensor_layout_t<5>(indices_tv, gid); output[output_tv.get_tensor_view_idx(output_layout)] + // = CVT_ACCUM2FLOAT(lval); indices[indices_tv.get_tensor_view_idx(indices_layout)] = lidx; + output[gid] = CVT_ACCUM2FLOAT(lval); + indices[gid] = lidx; } } diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp index 429ce9b6c8..a4aa161d4d 100644 --- a/src/kernels/radix.hpp +++ b/src/kernels/radix.hpp @@ -24,6 +24,7 @@ DEFINE_RADIX_TYPE(int64_t, uint64_t) DEFINE_RADIX_TYPE(bool, bool) DEFINE_RADIX_TYPE(float, uint32_t) DEFINE_RADIX_TYPE(double, uint64_t) +DEFINE_RADIX_TYPE(_Float16, uint16_t) template ::type> __device__ inline Radix encode(DTYPE v) diff --git a/src/kthvalue.cpp b/src/kthvalue.cpp index c098b3ee5b..6f93d76573 100644 --- a/src/kthvalue.cpp +++ b/src/kthvalue.cpp @@ -53,7 +53,7 @@ miopenStatus_t KthvalueForward(Handle& handle, } const auto problem = - kthvalue::FwdProblemDescription{inputDesc, outputDesc, indicesDesc, dim, k}; + kthvalue::FwdProblemDescription{inputDesc, outputDesc, indicesDesc, dim, k, keepDim}; const auto invoke_params = [&]() { auto tmp = kthvalue::FwdInvokeParams{}; diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp index 0d0086af6f..7dfd804b84 100644 --- a/src/kthvalue/problem_description.cpp +++ b/src/kthvalue/problem_description.cpp @@ -35,16 +35,18 @@ namespace kthvalue { NetworkConfig FwdProblemDescription::MakeNetworkConfig() const { - auto input_dtype = inputDesc.GetType(); - auto size = inputDesc.GetElementSize(); - auto dim_size = inputDesc.GetLengths()[dim]; - auto output_size = size / dim_size; + auto input_dtype = inputDesc.GetType(); + auto size = inputDesc.GetElementSize(); + auto dim_size = inputDesc.GetLengths()[dim]; + auto output_size = size / dim_size; + auto input_contiguous = isInputContiguous; std::ostringstream ss; ss << "kthvalue_fwd"; ss << "i_dtype" << input_dtype; ss << "output_size" << output_size; + ss << "input_contigious" << input_contiguous; return NetworkConfig{ss.str()}; } diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp index 9998933a88..a3e91f9943 100644 --- a/src/solver/kthvalue/forward_kthvalue.cpp +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -59,6 +59,8 @@ bool KthvalueFwd::IsApplicable(const ExecutionContext& /*context*/, return false; if(!IsImprovementOverROCm(problem)) return false; + if(!problem.isInputContiguous) + return false; return true; } diff --git a/test/gtest/kthvalue.hpp b/test/gtest/kthvalue.hpp index 8d6b15610a..fbb7b7d0d9 100644 --- a/test/gtest/kthvalue.hpp +++ b/test/gtest/kthvalue.hpp @@ -60,9 +60,9 @@ struct KthvalueTestCase KthvalueTestCase(std::vector dims_, size_t k_, - bool isContiguous_ = true, int32_t dim_ = -1, - bool keepDim_ = false) + bool keepDim_ = false, + bool isContiguous_ = true) : dims(dims_), isContiguous(isContiguous_), dim(dim_), k(k_), keepDim(keepDim_) { } @@ -84,15 +84,15 @@ struct KthvalueTestCase inline std::vector KthvalueTestConfigs() { return { - KthvalueTestCase({4000}, 3), // 1D cont - KthvalueTestCase({100, 500}, 10), // 2D cont - KthvalueTestCase({100, 500}, 10, false), // 2D non-cont - KthvalueTestCase({10, 20, 300}, 100, false), // 3D cont - KthvalueTestCase({10, 20, 300}, 1, false), // 3D non-cont - KthvalueTestCase({8, 3, 2000, 10}, 2000, true, 2), // 4D cont - KthvalueTestCase({8, 3, 2000, 10}, 2000, false, 2), // 4D non-cont - KthvalueTestCase({2, 2, 3000, 4, 10}, 1, true, 2), // 5D cont - KthvalueTestCase({2, 2, 3000, 4, 10}, 1, false, 2), // 5D cont + KthvalueTestCase({4000}, 3), // 1D cont + KthvalueTestCase({100, 500}, 10), // 2D cont + KthvalueTestCase({100, 500}, 10), // 2D non-cont + KthvalueTestCase({10, 20, 300}, 100), // 3D cont + KthvalueTestCase({10, 20, 300}, 1), // 3D non-cont + KthvalueTestCase({8, 3, 2000, 10}, 2000, 2), // 4D cont + KthvalueTestCase({8, 3, 2000, 10}, 2000, 2), // 4D non-cont + KthvalueTestCase({2, 2, 3000, 4, 10}, 1, 2), // 5D cont + KthvalueTestCase({2, 2, 3000, 4, 10}, 1, 2), // 5D cont }; } From c290a2e507469812778d6824d1cd798563cf01b5 Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Tue, 2 Jul 2024 16:38:35 +0700 Subject: [PATCH 06/28] backup code --- driver/kthvalue_driver.hpp | 122 ------------------ .../miopen/kthvalue/problem_description.hpp | 5 +- src/kernels/MIOpenKthvalue.cpp | 63 ++++----- src/kernels/radix.hpp | 8 +- src/kernels/warp_shuffle.hpp | 2 - src/kthvalue/problem_description.cpp | 2 +- src/solver/kthvalue/forward_kthvalue.cpp | 33 ++--- 7 files changed, 58 insertions(+), 177 deletions(-) diff --git a/driver/kthvalue_driver.hpp b/driver/kthvalue_driver.hpp index bdf0509fe9..72df4dd9e1 100644 --- a/driver/kthvalue_driver.hpp +++ b/driver/kthvalue_driver.hpp @@ -383,108 +383,12 @@ int KthvalueDriver::RunForwardCPU() template int KthvalueDriver::RunBackwardGPU() { - // float kernel_total_time = 0; - // float kernel_first_time = 0; - - // Timer t; - // START_TIME - - // for(int i = 0; i < inflags.GetValueInt("iter"); i++) - // { - // void* p_dtarget = nullptr; - // if(isTargetGradientComputed) - // { - // p_dtarget = dtarget_dev->GetMem(); - // } - - // miopenKthvalueBackward(GetHandle(), - // inputDesc, - // input_dev->GetMem(), - // targetDesc, - // target_dev->GetMem(), - // doutputDesc, - // doutput_dev->GetMem(), - // dinputDesc, - // dinput_dev->GetMem(), - // dtargetDesc, - // p_dtarget, - // alpha, - // gamma, - // reduction); - - // float time = 0.0; - // miopenGetKernelTime(GetHandle(), &time); - // kernel_total_time += time; - // if(i == 0) - // kernel_first_time = time; - // } - - // if(inflags.GetValueInt("time") == 1) - // { - // STOP_TIME - // int iter = inflags.GetValueInt("iter"); - // if(WALL_CLOCK) - // std::cout << "Wall-clock Time Sigmoid Focal Loss Bwd Elapsed: " << t.gettime_ms() / - // iter - // << " ms" << std::endl; - - // float kernel_average_time = - // iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; - // std::cout << "GPU Kernel Time Sigmoid Focal Loss Bwd Elapsed: " << kernel_average_time - // << " ms" << std::endl; - // } - - // if(dinput_dev->FromGPU(GetStream(), dinput.data()) != 0) - // std::cerr << "Error copying (dI_dev) from GPU, size: " << dinput_dev->GetSize() - // << std::endl; - // if(isTargetGradientComputed && dtarget_dev->FromGPU(GetStream(), dtarget.data()) != 0) - // std::cerr << "Error copying (dT_dev) from GPU, size: " << dtarget_dev->GetSize() - // << std::endl; - return miopenStatusSuccess; } template int KthvalueDriver::RunBackwardCPU() { - // TIO* p_dtarget = nullptr; - // if(isTargetGradientComputed) - // { - // p_dtarget = dtargetHost.data(); - // } - // if(reduction == MIOPEN_LOSS_REDUCTION_NONE) - // { - - // mloKthvalueUnreducedBwdRunHost(input.data(), - // inputDesc, - // target.data(), - // targetDesc, - // doutput.data(), - // doutputDesc, - // dinputHost.data(), - // dinputDesc, - // p_dtarget, - // dtargetDesc, - // alpha, - // gamma); - // } - // else - // { - // mloKthvalueBwdRunHost(input.data(), - // inputDesc, - // target.data(), - // targetDesc, - // doutput.data(), - // doutputDesc, - // dinputHost.data(), - // dinputDesc, - // p_dtarget, - // dtargetDesc, - // alpha, - // gamma, - // divisor); - // } - return miopenStatusSuccess; } @@ -514,31 +418,5 @@ int KthvalueDriver::VerifyForward() template int KthvalueDriver::VerifyBackward() { - // RunBackwardCPU(); - - // double tolerance = std::numeric_limits::epsilon() * 10; - // auto dinputError = miopen::rms_range(dinputHost, dinput); - // auto dtargetError = miopen::rms_range(dtargetHost, dtarget); - - // if(!std::isfinite(dinputError) || dinputError > tolerance) - // { - // std::cout << "Backward " << reduction << " Sigmoid Focal Loss FAILED: " << dinputError - // << " > " << tolerance << std::endl; - // return EC_VerifyFwd; - // } - // else if(isTargetGradientComputed && (!std::isfinite(dtargetError) || dtargetError > - // tolerance)) - // { - // std::cout << "Backward " << reduction << " Sigmoid Focal Loss FAILED: " << dtargetError - // << " > " << tolerance << std::endl; - // return EC_VerifyFwd; - // } - // else - // { - // std::cout << "Backward " << reduction - // << " Sigmoid Focal Loss Verifies OK on CPU reference (dinput: " << dinputError - // << ", dtarget: " << dtargetError << "< " << tolerance << ')' << std::endl; - // } - return miopenStatusSuccess; } diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp index 14416762ad..6b581eab8b 100644 --- a/src/include/miopen/kthvalue/problem_description.hpp +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -78,7 +78,8 @@ struct FwdProblemDescription : ProblemDescriptionBase MIOPEN_THROW(miopenStatusBadParm, "Reduce: Output and indices tensor dimension lengths do not match."); } - isInputContiguous = checkContiguous(inputDesc); + isContiguous = checkContiguous(inputDesc) && checkContiguous(outputDesc) && + checkContiguous(indicesDesc); } bool IsRightLength() const @@ -143,7 +144,7 @@ struct FwdProblemDescription : ProblemDescriptionBase TensorDescriptor indicesDesc; int32_t dim; size_t k; - bool isInputContiguous; + bool isContiguous; bool keepDim; }; diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index 9eb9595477..0ae5faf168 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -62,15 +62,14 @@ #define RADIX_MASK (RADIX_SIZE - 1) #endif -#define RADIX_TYPE RadixType::type - -template -__device__ void kthvalueFwd(const TIO* input, - TIO* output, +template +__device__ void kthvalueFwd(const DTYPE* input, + DTYPE* output, size_t* indices, size_t k, size_t dim_size, size_t dim_stride, + size_t output_size, tensor_view_t<4> input_tv, tensor_view_t<5> output_tv, tensor_view_t<5> indices_tv) @@ -85,28 +84,34 @@ __device__ void kthvalueFwd(const TIO* input, size_t lid = threadIdx.x; size_t gid = blockIdx.x; + if(gid >= output_size) + { + return; + } __shared__ size_t lsum[LOCAL_SIZE][RADIX_SIZE]; - __shared__ FLOAT_ACCUM lval; + __shared__ DTYPE lval; __shared__ long lidx; size_t counts[RADIX_SIZE]; RADIX_TYPE desired_mask = 0; RADIX_TYPE desired = 0; - tensor_layout_t<4> layout(input_tv, gid); - auto idx = input_tv.get_tensor_view_idx(layout); + // tensor_layout_t<4> layout(input_tv, gid); + // auto idx = input_tv.get_tensor_view_idx(layout); + size_t idx = gid * dim_size; + dim_stride = 1; for(size_t pos = sizeof(RADIX_TYPE) * 8 - RADIX_BITS; pos >= 0; pos -= RADIX_BITS) { - for(size_t i = 0; i < RADIX_SIZE; ++i) + for(unsigned long& count : counts) { - counts[i] = 0; + count = 0; } for(size_t i = lid; i < dim_size; i += LOCAL_SIZE) { size_t input_idx = idx + i * dim_stride; - RADIX_TYPE val = encode(CVT_FLOAT2ACCUM(input[input_idx])); + RADIX_TYPE val = encode(input[input_idx]); if((val & desired_mask) == desired) { ++counts[GetBitFieldImpl(val, pos, RADIX_BITS)]; @@ -135,19 +140,6 @@ __device__ void kthvalueFwd(const TIO* input, } __syncthreads(); - // __syncthreads(); - // #pragma unroll - // for(size_t i = 0; i < RADIX_SIZE; ++i) - // { - // counts[i] = block_reduce_sum(counts[i]); - // if(lid == 0) - // { - // smem_count[i] = counts[i]; - // } - // __syncthreads(); - // counts[i] = smem_count[i]; - // } - bool found = false; // Process in ascending order for(size_t j = 0; j < RADIX_SIZE; ++j) @@ -163,9 +155,9 @@ __device__ void kthvalueFwd(const TIO* input, // There are multiple answers so we return any of them for(size_t i = lid; i < dim_size; i += LOCAL_SIZE) { - size_t input_idx = idx + i * dim_stride; - FLOAT_ACCUM val_ori = CVT_FLOAT2ACCUM(input[input_idx]); - RADIX_TYPE val = encode(val_ori); + size_t input_idx = idx + i * dim_stride; + DTYPE val_ori = input[input_idx]; + RADIX_TYPE val = encode(val_ori); if((val & desired_mask) == desired && GetBitFieldImpl(val, pos, RADIX_BITS) == j) { @@ -194,8 +186,8 @@ __device__ void kthvalueFwd(const TIO* input, // auto output_layout = tensor_layout_t<5>(output_tv, // gid); auto indices_layout = // tensor_layout_t<5>(indices_tv, gid); output[output_tv.get_tensor_view_idx(output_layout)] - // = CVT_ACCUM2FLOAT(lval); indices[indices_tv.get_tensor_view_idx(indices_layout)] = lidx; - output[gid] = CVT_ACCUM2FLOAT(lval); + // = lval; indices[indices_tv.get_tensor_view_idx(indices_layout)] = lidx; + output[gid] = lval; indices[gid] = lidx; } } @@ -206,10 +198,19 @@ extern "C" __global__ void KthvalueFwd(const IN_OUT_TYPE* input, size_t k, size_t dim_size, size_t dim_stride, + size_t output_size, tensor_view_t<4> input_tv, tensor_view_t<5> output_tv, tensor_view_t<5> indices_tv) { - kthvalueFwd( - input, output, indices, k, dim_size, dim_stride, input_tv, output_tv, indices_tv); + kthvalueFwd(input, + output, + indices, + k, + dim_size, + dim_stride, + output_size, + input_tv, + output_tv, + indices_tv); } diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp index a4aa161d4d..b75506d382 100644 --- a/src/kernels/radix.hpp +++ b/src/kernels/radix.hpp @@ -3,7 +3,7 @@ #include // #define ENCODE encode -// #define RADIX_TYPE typename RadixType::type +#define RADIX_TYPE typename RadixType::type // #define GetBitField(x, pos, bits) GetBitFieldImpl(x, pos, bits) // #define SetBitField(x, a, pos, bits) SetBitFieldImpl(x, a, pos, bits) @@ -24,7 +24,7 @@ DEFINE_RADIX_TYPE(int64_t, uint64_t) DEFINE_RADIX_TYPE(bool, bool) DEFINE_RADIX_TYPE(float, uint32_t) DEFINE_RADIX_TYPE(double, uint64_t) -DEFINE_RADIX_TYPE(_Float16, uint16_t) +DEFINE_RADIX_TYPE(__half, ushort) template ::type> __device__ inline Radix encode(DTYPE v) @@ -41,7 +41,7 @@ __device__ inline Radix encode(DTYPE v) { return 9223372036854775808ull + v; } - else if constexpr(std::is_same<_Float16, DTYPE>::value) + else if constexpr(std::is_same<__half, DTYPE>::value) { Radix x = __half_as_ushort(v); Radix mask = (x & 0x8000) ? 0xffff : 0x8000; @@ -76,7 +76,7 @@ __device__ inline DTYPE decode(Radix v) { return v - 9223372036854775808ull; } - else if constexpr(std::is_same<_Float16, DTYPE>::value) + else if constexpr(std::is_same<__half, DTYPE>::value) { Radix mask = (v & 0x8000) ? 0x8000 : 0xffff; return __ushort_as_half((ushort)(v ^ mask)); diff --git a/src/kernels/warp_shuffle.hpp b/src/kernels/warp_shuffle.hpp index addd236f69..6f74dee16c 100644 --- a/src/kernels/warp_shuffle.hpp +++ b/src/kernels/warp_shuffle.hpp @@ -68,8 +68,6 @@ __device__ __forceinline__ DTYPE block_reduce_sum(DTYPE val) val = threadIdx.x < REDUCE_SIZE / warpSize ? shared[lane] : 0; if(wid == 0) - // val = (threadIdx.x % warpSize) < REDUCE_SIZE / warpSize ? shared[lane] : 0; - // if(lane == 0) val = warp_reduce_sum(val); return val; diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp index 7dfd804b84..6a2e520b0e 100644 --- a/src/kthvalue/problem_description.cpp +++ b/src/kthvalue/problem_description.cpp @@ -39,7 +39,7 @@ NetworkConfig FwdProblemDescription::MakeNetworkConfig() const auto size = inputDesc.GetElementSize(); auto dim_size = inputDesc.GetLengths()[dim]; auto output_size = size / dim_size; - auto input_contiguous = isInputContiguous; + auto input_contiguous = isContiguous; std::ostringstream ss; diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp index a3e91f9943..60e240a455 100644 --- a/src/solver/kthvalue/forward_kthvalue.cpp +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -47,9 +47,10 @@ namespace kthvalue { bool IsImprovementOverROCm(const miopen::kthvalue::FwdProblemDescription& problem) { TensorDescriptor inputDesc = problem.GetInputDesc(); + int dimNum = inputDesc.GetSize(); size_t dimSize = inputDesc.GetLengths()[problem.GetDim()]; - return dimSize >= 300; + return dimNum >= 2 && problem.GetDim() == dimNum - 1 && dimSize >= 256; } bool KthvalueFwd::IsApplicable(const ExecutionContext& /*context*/, @@ -57,9 +58,9 @@ bool KthvalueFwd::IsApplicable(const ExecutionContext& /*context*/, { if(problem.GetInputDesc().GetSize() > 5) return false; - if(!IsImprovementOverROCm(problem)) + if(!problem.isContiguous) return false; - if(!problem.isInputContiguous) + if(!IsImprovementOverROCm(problem)) return false; return true; } @@ -70,14 +71,16 @@ ConvSolution KthvalueFwd::GetSolution(const ExecutionContext& context, std::ignore = context; auto result = ConvSolution{miopenStatusSuccess}; - auto input_desc = problem.GetInputDesc(); - auto in_dtype = miopen::GetDataType(input_desc.GetType()); - auto dtype = problem.GetOutputDesc().GetType(); - auto size = input_desc.GetElementSize(); - auto dim_size = input_desc.GetLengths()[problem.GetDim()]; + auto input_desc = problem.GetInputDesc(); + auto in_dtype = miopen::GetDataType(input_desc.GetType()); + auto dtype = problem.GetOutputDesc().GetType(); + auto size = input_desc.GetElementSize(); + auto dim_size = input_desc.GetLengths()[problem.GetDim()]; + size_t output_size = size / dim_size; - size_t xlocalsize = LOCAL_SIZE; - size_t xgridsize = size / dim_size * xlocalsize; + size_t xlocalsize; + xlocalsize = LOCAL_SIZE; + size_t xgridsize = output_size * xlocalsize; size_t ylocalsize = 1; size_t ygridsize = 1; size_t zlocalsize = 1; @@ -93,7 +96,7 @@ ConvSolution KthvalueFwd::GetSolution(const ExecutionContext& context, {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, {"IN_OUT_TYPE", in_dtype == "bfloat16" ? "ushort" : in_dtype}, - {"LOCAL_SIZE", LOCAL_SIZE}, + {"LOCAL_SIZE", xlocalsize}, }; kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); @@ -112,8 +115,7 @@ ConvSolution KthvalueFwd::GetSolution(const ExecutionContext& context, return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { decltype(auto) kernel = handle_.Run(kernels.front()); decltype(auto) params = raw_params.CastTo(); - size_t dimSize = params.inputDesc->GetLengths()[params.dim]; - size_t dimStride = params.inputDesc->GetStrides()[params.dim]; + size_t dim_stride = params.inputDesc->GetStrides()[params.dim]; auto input_tv = get_inner_expanded_tv<5>(deref(params.inputDesc)); auto input_tv_without_selected_dim = get_tv_without_dim<5, 4>(input_tv, params.dim); @@ -125,8 +127,9 @@ ConvSolution KthvalueFwd::GetSolution(const ExecutionContext& context, params.output, params.indices, params.k, - dimSize, - dimStride, + dim_size, + dim_stride, + output_size, input_tv_without_selected_dim, output_tv, indices_tv); From e4ffcececcc0f2620a3d6d84d99319990e743d16 Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Tue, 2 Jul 2024 18:03:22 +0700 Subject: [PATCH 07/28] add constraint to accept only last dim --- src/kernels/MIOpenKthvalue.cpp | 19 +++--- src/kernels/warp_shuffle.hpp | 74 ------------------------ src/kthvalue/problem_description.cpp | 12 ++-- src/solver/kthvalue/forward_kthvalue.cpp | 15 ++--- 4 files changed, 22 insertions(+), 98 deletions(-) delete mode 100644 src/kernels/warp_shuffle.hpp diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index 0ae5faf168..e96730a998 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -31,7 +31,6 @@ #include "float_types.h" #include "tensor_view.hpp" -#include "warp_shuffle.hpp" #include "radix.hpp" #ifndef IN_OUT_TYPE @@ -96,10 +95,8 @@ __device__ void kthvalueFwd(const DTYPE* input, RADIX_TYPE desired_mask = 0; RADIX_TYPE desired = 0; - // tensor_layout_t<4> layout(input_tv, gid); - // auto idx = input_tv.get_tensor_view_idx(layout); - size_t idx = gid * dim_size; - dim_stride = 1; + tensor_layout_t<4> layout(input_tv, gid); + auto idx = input_tv.get_tensor_view_idx(layout); for(size_t pos = sizeof(RADIX_TYPE) * 8 - RADIX_BITS; pos >= 0; pos -= RADIX_BITS) { @@ -183,12 +180,12 @@ __device__ void kthvalueFwd(const DTYPE* input, __syncthreads(); if(lid == 0) { - // auto output_layout = tensor_layout_t<5>(output_tv, - // gid); auto indices_layout = - // tensor_layout_t<5>(indices_tv, gid); output[output_tv.get_tensor_view_idx(output_layout)] - // = lval; indices[indices_tv.get_tensor_view_idx(indices_layout)] = lidx; - output[gid] = lval; - indices[gid] = lidx; + auto output_layout = tensor_layout_t<5>(output_tv, gid); + auto indices_layout = tensor_layout_t<5>(indices_tv, gid); + output[output_tv.get_tensor_view_idx(output_layout)] = lval; + indices[indices_tv.get_tensor_view_idx(indices_layout)] = lidx; + // output[gid] = lval; + // indices[gid] = lidx; } } diff --git a/src/kernels/warp_shuffle.hpp b/src/kernels/warp_shuffle.hpp deleted file mode 100644 index 6f74dee16c..0000000000 --- a/src/kernels/warp_shuffle.hpp +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (c) 2024 Advanced Micro Devices, Inc. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ - -#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS -#include -#include -#endif - -#include "float_types.h" - -#ifndef REDUCE_SIZE -#define REDUCE_SIZE 256 -#endif - -template -__device__ __forceinline__ DTYPE warp_reduce_sum(DTYPE val) -{ - if(warpSize >= 64) - val += __shfl_down(val, 32); - if(warpSize >= 32) - val += __shfl_down(val, 16); - if(warpSize >= 16) - val += __shfl_down(val, 8); - if(warpSize >= 8) - val += __shfl_down(val, 4); - if(warpSize >= 4) - val += __shfl_down(val, 2); - if(warpSize >= 2) - val += __shfl_down(val, 1); - return val; -} - -template -__device__ __forceinline__ DTYPE block_reduce_sum(DTYPE val) -{ - static __shared__ DTYPE shared[REDUCE_SIZE / warpSize]; - auto lane = threadIdx.x % warpSize; - auto wid = threadIdx.x / warpSize; - - val = warp_reduce_sum(val); - - if(lane == 0) - shared[wid] = val; - __syncthreads(); - - val = threadIdx.x < REDUCE_SIZE / warpSize ? shared[lane] : 0; - if(wid == 0) - val = warp_reduce_sum(val); - - return val; -} diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp index 6a2e520b0e..9f28413e21 100644 --- a/src/kthvalue/problem_description.cpp +++ b/src/kthvalue/problem_description.cpp @@ -35,18 +35,18 @@ namespace kthvalue { NetworkConfig FwdProblemDescription::MakeNetworkConfig() const { - auto input_dtype = inputDesc.GetType(); - auto size = inputDesc.GetElementSize(); - auto dim_size = inputDesc.GetLengths()[dim]; - auto output_size = size / dim_size; - auto input_contiguous = isContiguous; + auto input_dtype = inputDesc.GetType(); + auto size = inputDesc.GetElementSize(); + auto dim_size = inputDesc.GetLengths()[dim]; + auto output_size = size / dim_size; std::ostringstream ss; ss << "kthvalue_fwd"; ss << "i_dtype" << input_dtype; + ss << "dim_size" << dim_size; ss << "output_size" << output_size; - ss << "input_contigious" << input_contiguous; + ss << "contiguous" << isContiguous; return NetworkConfig{ss.str()}; } diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp index 60e240a455..c9b51725bc 100644 --- a/src/solver/kthvalue/forward_kthvalue.cpp +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -36,8 +36,6 @@ #include #include -#define LOCAL_SIZE 256 - namespace miopen { namespace solver { @@ -50,18 +48,18 @@ bool IsImprovementOverROCm(const miopen::kthvalue::FwdProblemDescription& proble int dimNum = inputDesc.GetSize(); size_t dimSize = inputDesc.GetLengths()[problem.GetDim()]; - return dimNum >= 2 && problem.GetDim() == dimNum - 1 && dimSize >= 256; + return dimNum >= 2 && problem.GetDim() == dimNum - 1 && dimSize >= 300; } bool KthvalueFwd::IsApplicable(const ExecutionContext& /*context*/, const miopen::kthvalue::FwdProblemDescription& problem) const { + if(!IsImprovementOverROCm(problem)) + return false; if(problem.GetInputDesc().GetSize() > 5) return false; if(!problem.isContiguous) return false; - if(!IsImprovementOverROCm(problem)) - return false; return true; } @@ -78,8 +76,11 @@ ConvSolution KthvalueFwd::GetSolution(const ExecutionContext& context, auto dim_size = input_desc.GetLengths()[problem.GetDim()]; size_t output_size = size / dim_size; - size_t xlocalsize; - xlocalsize = LOCAL_SIZE; + size_t xlocalsize = 256; + if(dim_size >= 8192) + { + xlocalsize = 512; + } size_t xgridsize = output_size * xlocalsize; size_t ylocalsize = 1; size_t ygridsize = 1; From f30ddddde6429f9972e4bc155e81fbf5ba5047a7 Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Wed, 3 Jul 2024 11:10:09 +0700 Subject: [PATCH 08/28] add utils function to tensor.cpp --- docs/reference/index.rst | 1 + driver/kthvalue_driver.hpp | 31 ----------- src/include/miopen/kthvalue/invoke_params.hpp | 23 +++----- .../miopen/kthvalue/problem_description.hpp | 31 ++--------- src/include/miopen/tensor.hpp | 1 + src/kernels/MIOpenKthvalue.cpp | 4 +- src/kthvalue/problem_description.cpp | 6 ++- src/solver.cpp | 2 +- src/solver/kthvalue/forward_kthvalue.cpp | 2 +- test/gtest/kthvalue.cpp | 54 ++++++------------- test/gtest/kthvalue.hpp | 18 +++---- 11 files changed, 44 insertions(+), 129 deletions(-) diff --git a/docs/reference/index.rst b/docs/reference/index.rst index 2574dbbf5e..f1b7cb9649 100644 --- a/docs/reference/index.rst +++ b/docs/reference/index.rst @@ -36,3 +36,4 @@ The MIOpen API library is structured as follows: * :doc:`Getitem <../doxygen/html/group__getitem>` (experimental) * :doc:`ReduceCalculation <../doxygen/html/group__ReduceCalculation>` (experimental) * :doc:`RotaryPositionalEmbeddings <../doxygen/html/group__RotaryPositionalEmbeddings>` (experimental) + * :doc:`Kthvalue <../doxygen/html/group__kthvalue>` (experimental) diff --git a/driver/kthvalue_driver.hpp b/driver/kthvalue_driver.hpp index 72df4dd9e1..190fb9519c 100644 --- a/driver/kthvalue_driver.hpp +++ b/driver/kthvalue_driver.hpp @@ -101,8 +101,6 @@ class KthvalueDriver : public Driver miopenCreateTensorDescriptor(&inputDesc); miopenCreateTensorDescriptor(&indicesDesc); miopenCreateTensorDescriptor(&outputDesc); - miopenCreateTensorDescriptor(&doutputDesc); - miopenCreateTensorDescriptor(&dinputDesc); data_type = miopen_type{}; } @@ -129,8 +127,6 @@ class KthvalueDriver : public Driver miopenDestroyTensorDescriptor(inputDesc); miopenDestroyTensorDescriptor(indicesDesc); miopenDestroyTensorDescriptor(outputDesc); - miopenDestroyTensorDescriptor(doutputDesc); - miopenDestroyTensorDescriptor(dinputDesc); } private: @@ -139,23 +135,16 @@ class KthvalueDriver : public Driver miopenTensorDescriptor_t inputDesc; miopenTensorDescriptor_t indicesDesc; miopenTensorDescriptor_t outputDesc; - miopenTensorDescriptor_t doutputDesc; - miopenTensorDescriptor_t dinputDesc; std::unique_ptr input_dev; std::unique_ptr indices_dev; std::unique_ptr output_dev; - std::unique_ptr doutput_dev; - std::unique_ptr dinput_dev; std::vector input; std::vector indices; std::vector indicesHost; std::vector output; std::vector outputHost; - std::vector doutput; - std::vector dinput; - std::vector dinputHost; bool isContiguous; int dim; @@ -208,8 +197,6 @@ int KthvalueDriver::GetandSetData() } SetTensorNd(inputDesc, inDims, inStride, data_type); - SetTensorNd(doutputDesc, outDims, data_type); - SetTensorNd(dinputDesc, inDims, data_type); SetTensorNd(outputDesc, outDims, data_type); // miopenDataType_t doesn't support size_t, I use double instead (both types use 64 bits) SetTensorNd(indicesDesc, outDims, miopen_type{}); @@ -261,38 +248,26 @@ int KthvalueDriver::AllocateBuffersAndCopy() size_t in_sz = miopen::deref(inputDesc).GetElementSize(); size_t idx_sz = miopen::deref(indicesDesc).GetElementSize(); size_t out_sz = miopen::deref(outputDesc).GetElementSize(); - size_t dO_sz = miopen::deref(doutputDesc).GetElementSize(); - size_t dI_sz = miopen::deref(dinputDesc).GetElementSize(); uint32_t ctx = 0; input_dev = std::unique_ptr(new GPUMem(ctx, in_sz, sizeof(TIO))); indices_dev = std::unique_ptr(new GPUMem(ctx, idx_sz, sizeof(size_t))); output_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(TIO))); - doutput_dev = std::unique_ptr(new GPUMem(ctx, dO_sz, sizeof(TIO))); - dinput_dev = std::unique_ptr(new GPUMem(ctx, dI_sz, sizeof(TIO))); input = std::vector(in_sz, static_cast(0)); indices = std::vector(idx_sz, 0); indicesHost = std::vector(idx_sz, 0); output = std::vector(out_sz, static_cast(0)); outputHost = std::vector(out_sz, static_cast(0)); - doutput = std::vector(dO_sz, static_cast(0)); - dinput = std::vector(dI_sz, static_cast(0)); - dinputHost = std::vector(dI_sz, static_cast(0)); for(int i = 0; i < in_sz; i++) { input[i] = prng::gen_A_to_B(static_cast(-10), static_cast(10)); } - for(int i = 0; i < dO_sz; ++i) - { - doutput[i] = prng::gen_A_to_B(static_cast(-2), static_cast(2)); - } fill(output.begin(), output.end(), static_cast(0)); fill(indices.begin(), indices.end(), static_cast(0)); - fill(dinput.begin(), dinput.end(), static_cast(0)); if(input_dev->ToGPU(GetStream(), input.data()) != 0) std::cerr << "Error copying (in) to GPU, size: " << input_dev->GetSize() << std::endl; @@ -303,12 +278,6 @@ int KthvalueDriver::AllocateBuffersAndCopy() if(output_dev->ToGPU(GetStream(), output.data()) != 0) std::cerr << "Error copying (out) to GPU, size: " << output_dev->GetSize() << std::endl; - if(doutput_dev->ToGPU(GetStream(), doutput.data()) != 0) - std::cerr << "Error copying (dO) to GPU, size: " << doutput_dev->GetSize() << std::endl; - - if(dinput_dev->ToGPU(GetStream(), dinput.data()) != 0) - std::cerr << "Error copying (dI) to GPU, size: " << dinput_dev->GetSize() << std::endl; - return miopenStatusSuccess; } diff --git a/src/include/miopen/kthvalue/invoke_params.hpp b/src/include/miopen/kthvalue/invoke_params.hpp index 08d80c510f..36b9f0c03d 100644 --- a/src/include/miopen/kthvalue/invoke_params.hpp +++ b/src/include/miopen/kthvalue/invoke_params.hpp @@ -36,13 +36,16 @@ namespace miopen { namespace kthvalue { -struct KthvalueInvokeParams : public miopen::InvokeParams +struct FwdInvokeParams : public miopen::InvokeParams { - KthvalueInvokeParams() = default; - - const TensorDescriptor* inputDesc = nullptr; + FwdInvokeParams() = default; - ConstData_t input = nullptr; + const TensorDescriptor* inputDesc = nullptr; + ConstData_t input = nullptr; + const TensorDescriptor* outputDesc = nullptr; + Data_t output = nullptr; + const TensorDescriptor* indicesDesc = nullptr; + size_t* indices = nullptr; size_t k = 1; int32_t dim = 0; @@ -51,16 +54,6 @@ struct KthvalueInvokeParams : public miopen::InvokeParams Data_t GetWorkspace() const { return nullptr; } }; -struct FwdInvokeParams : KthvalueInvokeParams -{ - FwdInvokeParams() = default; - - const TensorDescriptor* outputDesc = nullptr; - Data_t output = nullptr; - const TensorDescriptor* indicesDesc = nullptr; - size_t* indices = nullptr; -}; - } // namespace kthvalue } // namespace miopen diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp index 6b581eab8b..a801bc841d 100644 --- a/src/include/miopen/kthvalue/problem_description.hpp +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -73,13 +73,12 @@ struct FwdProblemDescription : ProblemDescriptionBase MIOPEN_THROW(miopenStatusBadParm, "Reduce: Input and output tensor dimension lengths do not match."); } - if(!IsSameLength()) + if(!outputDesc.IsSameLength(indicesDesc)) { MIOPEN_THROW(miopenStatusBadParm, "Reduce: Output and indices tensor dimension lengths do not match."); } - isContiguous = checkContiguous(inputDesc) && checkContiguous(outputDesc) && - checkContiguous(indicesDesc); + isInputContiguous = inputDesc.IsContiguous(); } bool IsRightLength() const @@ -107,30 +106,6 @@ struct FwdProblemDescription : ProblemDescriptionBase return true; } - bool IsSameLength() const - { - for(int32_t i = 0; i < outputDesc.GetLengths().size(); i++) - { - if(outputDesc.GetLengths()[i] != indicesDesc.GetLengths()[i]) - { - return false; - } - } - return true; - } - - bool checkContiguous(const TensorDescriptor& tensorDesc) - { - size_t stride = 1; - for(int i = tensorDesc.GetSize() - 1; i >= 0; --i) - { - if(stride != tensorDesc.GetStrides()[i]) - return false; - stride *= tensorDesc.GetLengths()[i]; - } - return true; - } - const TensorDescriptor& GetInputDesc() const { return inputDesc; } const TensorDescriptor& GetOutputDesc() const { return outputDesc; } const TensorDescriptor& GetIndicesDesc() const { return indicesDesc; } @@ -144,7 +119,7 @@ struct FwdProblemDescription : ProblemDescriptionBase TensorDescriptor indicesDesc; int32_t dim; size_t k; - bool isContiguous; + bool isInputContiguous; bool keepDim; }; diff --git a/src/include/miopen/tensor.hpp b/src/include/miopen/tensor.hpp index f4d2b2dca7..b962f3d033 100644 --- a/src/include/miopen/tensor.hpp +++ b/src/include/miopen/tensor.hpp @@ -233,6 +233,7 @@ struct MIOPEN_INTERNALS_EXPORT TensorDescriptor : miopenTensorDescriptor /// Checks only lengths. bool AllLengthsFitIntoInt() const; + bool IsSameLength(const TensorDescriptor& otherDesc) const; bool operator==(const TensorDescriptor& rhs) const; bool operator!=(const TensorDescriptor& rhs) const; bool operator<(const TensorDescriptor& rhs) const; diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index e96730a998..72b974af4c 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -98,7 +98,7 @@ __device__ void kthvalueFwd(const DTYPE* input, tensor_layout_t<4> layout(input_tv, gid); auto idx = input_tv.get_tensor_view_idx(layout); - for(size_t pos = sizeof(RADIX_TYPE) * 8 - RADIX_BITS; pos >= 0; pos -= RADIX_BITS) + for(int pos = sizeof(RADIX_TYPE) * 8 - RADIX_BITS; pos >= 0; pos -= RADIX_BITS) { for(unsigned long& count : counts) { @@ -184,8 +184,6 @@ __device__ void kthvalueFwd(const DTYPE* input, auto indices_layout = tensor_layout_t<5>(indices_tv, gid); output[output_tv.get_tensor_view_idx(output_layout)] = lval; indices[indices_tv.get_tensor_view_idx(indices_layout)] = lidx; - // output[gid] = lval; - // indices[gid] = lidx; } } diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp index 9f28413e21..00a0a59b20 100644 --- a/src/kthvalue/problem_description.cpp +++ b/src/kthvalue/problem_description.cpp @@ -38,6 +38,8 @@ NetworkConfig FwdProblemDescription::MakeNetworkConfig() const auto input_dtype = inputDesc.GetType(); auto size = inputDesc.GetElementSize(); auto dim_size = inputDesc.GetLengths()[dim]; + int dim_num = inputDesc.GetSize(); + bool is_last_dim = (dim == dim_num - 1); auto output_size = size / dim_size; std::ostringstream ss; @@ -45,8 +47,10 @@ NetworkConfig FwdProblemDescription::MakeNetworkConfig() const ss << "kthvalue_fwd"; ss << "i_dtype" << input_dtype; ss << "dim_size" << dim_size; + ss << "dim_num" << dim_num; + ss << "is_last_dim" << is_last_dim; ss << "output_size" << output_size; - ss << "contiguous" << isContiguous; + ss << "input_contiguous" << isInputContiguous; return NetworkConfig{ss.str()}; } diff --git a/src/solver.cpp b/src/solver.cpp index 85f47f9c81..52b5301cd7 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -24,7 +24,6 @@ * *******************************************************************************/ -#include "miopen/kthvalue/solvers.hpp" #include #include @@ -34,6 +33,7 @@ #include #include #include +#include #include #include #include diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp index c9b51725bc..49b2616658 100644 --- a/src/solver/kthvalue/forward_kthvalue.cpp +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -58,7 +58,7 @@ bool KthvalueFwd::IsApplicable(const ExecutionContext& /*context*/, return false; if(problem.GetInputDesc().GetSize() > 5) return false; - if(!problem.isContiguous) + if(!problem.isInputContiguous) return false; return true; } diff --git a/test/gtest/kthvalue.cpp b/test/gtest/kthvalue.cpp index 8fbbd10478..ff32840fc1 100644 --- a/test/gtest/kthvalue.cpp +++ b/test/gtest/kthvalue.cpp @@ -52,10 +52,6 @@ struct KthvalueForwardTestFloat16 : KthvalueFwdTest { }; -struct KthvalueForwardTestBFloat16 : KthvalueFwdTest -{ -}; - using namespace kthvalue; TEST_P(KthvalueForwardTestFloat32, KthvalueForwardTest) @@ -72,43 +68,25 @@ TEST_P(KthvalueForwardTestFloat32, KthvalueForwardTest) } }; -// TEST_P(KthvalueForwardTestFloat16, KthvalueForwardTest) -// { -// if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || -// (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--half"))) -// { -// RunTest(); -// Verify(); -// } -// else -// { -// GTEST_SKIP(); -// } -// }; - -// TEST_P(KthvalueForwardTestBFloat16, KthvalueForwardTest) -// { -// if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || -// (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--bfloat16"))) -// { -// RunTest(); -// Verify(); -// } -// else -// { -// GTEST_SKIP(); -// } -// }; +TEST_P(KthvalueForwardTestFloat16, KthvalueForwardTest) +{ + if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || + (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--half"))) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, KthvalueForwardTestFloat32, testing::ValuesIn(KthvalueTestConfigs())); -// INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, -// KthvalueForwardTestFloat16, -// testing::ValuesIn(KthvalueTestConfigs())); - -// INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, -// KthvalueForwardTestBFloat16, -// testing::ValuesIn(KthvalueTestConfigs())); +INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, + KthvalueForwardTestFloat16, + testing::ValuesIn(KthvalueTestConfigs())); } // namespace kthvalue diff --git a/test/gtest/kthvalue.hpp b/test/gtest/kthvalue.hpp index fbb7b7d0d9..fe6272f984 100644 --- a/test/gtest/kthvalue.hpp +++ b/test/gtest/kthvalue.hpp @@ -61,8 +61,8 @@ struct KthvalueTestCase KthvalueTestCase(std::vector dims_, size_t k_, int32_t dim_ = -1, - bool keepDim_ = false, - bool isContiguous_ = true) + bool isContiguous_ = true, + bool keepDim_ = false) : dims(dims_), isContiguous(isContiguous_), dim(dim_), k(k_), keepDim(keepDim_) { } @@ -84,15 +84,11 @@ struct KthvalueTestCase inline std::vector KthvalueTestConfigs() { return { - KthvalueTestCase({4000}, 3), // 1D cont - KthvalueTestCase({100, 500}, 10), // 2D cont - KthvalueTestCase({100, 500}, 10), // 2D non-cont - KthvalueTestCase({10, 20, 300}, 100), // 3D cont - KthvalueTestCase({10, 20, 300}, 1), // 3D non-cont - KthvalueTestCase({8, 3, 2000, 10}, 2000, 2), // 4D cont - KthvalueTestCase({8, 3, 2000, 10}, 2000, 2), // 4D non-cont - KthvalueTestCase({2, 2, 3000, 4, 10}, 1, 2), // 5D cont - KthvalueTestCase({2, 2, 3000, 4, 10}, 1, 2), // 5D cont + KthvalueTestCase({100, 500}, 10), + KthvalueTestCase({10, 20, 300}, 1), + KthvalueTestCase({8, 3, 10, 2000}, 2000), + KthvalueTestCase({2, 2, 4, 10, 3000}, 1), + KthvalueTestCase({2, 2, 4, 10, 3000}, 1, -1, false), }; } From 3b7ac602107ba57cb1b90af5cbfa0d6b618f5f49 Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Wed, 3 Jul 2024 14:56:50 +0700 Subject: [PATCH 09/28] minor changes --- src/kernels/MIOpenKthvalue.cpp | 50 +++++++++++++++++----------------- src/kernels/radix.hpp | 3 -- test/gtest/kthvalue.hpp | 2 +- 3 files changed, 26 insertions(+), 29 deletions(-) diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index 72b974af4c..666b16e519 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -78,7 +78,7 @@ __device__ void kthvalueFwd(const DTYPE* input, * input : {A, B, C, D, E} * output/indices : {A, B, 1, D, E} or {A, B, D, E} * dim = 2 (C) - * => gws = {LOCAL_SIZE, A * B * D * E}, lws = {LOCAL_SIZE, 1} + * => grid = {A * B * D * E, 1}, block = {LOCAL_SIZE, 1} */ size_t lid = threadIdx.x; @@ -141,37 +141,37 @@ __device__ void kthvalueFwd(const DTYPE* input, // Process in ascending order for(size_t j = 0; j < RADIX_SIZE; ++j) { - if(counts[j] >= k) + if(counts[j] < k) { - // Answer is inside this count - if(counts[j] == 1 || pos == 0) + k -= counts[j]; + continue; + } + // Answer is inside this count + if(counts[j] == 1 || pos == 0) + { + // 1. counts[j] == 1 + // We found an unique answer. + // 2. pos == 0 + // There are multiple answers so we return any of them + for(size_t i = lid; i < dim_size; i += LOCAL_SIZE) { - // 1. counts[j] == 1 - // We found an unique answer. - // 2. pos == 0 - // There are multiple answers so we return any of them - for(size_t i = lid; i < dim_size; i += LOCAL_SIZE) + size_t input_idx = idx + i * dim_stride; + DTYPE val_ori = input[input_idx]; + RADIX_TYPE val = encode(val_ori); + if((val & desired_mask) == desired && + GetBitFieldImpl(val, pos, RADIX_BITS) == j) { - size_t input_idx = idx + i * dim_stride; - DTYPE val_ori = input[input_idx]; - RADIX_TYPE val = encode(val_ori); - if((val & desired_mask) == desired && - GetBitFieldImpl(val, pos, RADIX_BITS) == j) - { - // For case 2, this will be non-deterministic. - lval = val_ori; - lidx = i; - } + // For case 2, this will be non-deterministic. + lval = val_ori; + lidx = i; } - found = true; - break; } - desired = SetBitFieldImpl(desired, j, pos, RADIX_BITS); - desired_mask = - SetBitFieldImpl(desired_mask, RADIX_MASK, pos, RADIX_BITS); + found = true; break; } - k -= counts[j]; + desired = SetBitFieldImpl(desired, j, pos, RADIX_BITS); + desired_mask = SetBitFieldImpl(desired_mask, RADIX_MASK, pos, RADIX_BITS); + break; } if(found) break; diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp index b75506d382..455519735f 100644 --- a/src/kernels/radix.hpp +++ b/src/kernels/radix.hpp @@ -2,10 +2,7 @@ #include #include -// #define ENCODE encode #define RADIX_TYPE typename RadixType::type -// #define GetBitField(x, pos, bits) GetBitFieldImpl(x, pos, bits) -// #define SetBitField(x, a, pos, bits) SetBitFieldImpl(x, a, pos, bits) #define DEFINE_RADIX_TYPE(DTYPE, cpp_type) \ template <> \ diff --git a/test/gtest/kthvalue.hpp b/test/gtest/kthvalue.hpp index fe6272f984..f4d7c438bd 100644 --- a/test/gtest/kthvalue.hpp +++ b/test/gtest/kthvalue.hpp @@ -88,7 +88,7 @@ inline std::vector KthvalueTestConfigs() KthvalueTestCase({10, 20, 300}, 1), KthvalueTestCase({8, 3, 10, 2000}, 2000), KthvalueTestCase({2, 2, 4, 10, 3000}, 1), - KthvalueTestCase({2, 2, 4, 10, 3000}, 1, -1, false), + KthvalueTestCase({2, 2, 4, 10, 3000}, 1, -1, true, true), }; } From e6707ddd303c2e2caa8715ea30d47309d1fe8b48 Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Thu, 4 Jul 2024 11:48:35 +0700 Subject: [PATCH 10/28] apply the kernel to non-cont tensor --- .../miopen/kthvalue/problem_description.hpp | 2 -- src/kthvalue/problem_description.cpp | 5 ++--- src/solver/kthvalue/forward_kthvalue.cpp | 7 +++---- test/gtest/kthvalue.hpp | 14 +++++++++----- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp index a801bc841d..ac8755a5e7 100644 --- a/src/include/miopen/kthvalue/problem_description.hpp +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -78,7 +78,6 @@ struct FwdProblemDescription : ProblemDescriptionBase MIOPEN_THROW(miopenStatusBadParm, "Reduce: Output and indices tensor dimension lengths do not match."); } - isInputContiguous = inputDesc.IsContiguous(); } bool IsRightLength() const @@ -119,7 +118,6 @@ struct FwdProblemDescription : ProblemDescriptionBase TensorDescriptor indicesDesc; int32_t dim; size_t k; - bool isInputContiguous; bool keepDim; }; diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp index 00a0a59b20..dec26f3b86 100644 --- a/src/kthvalue/problem_description.cpp +++ b/src/kthvalue/problem_description.cpp @@ -38,8 +38,8 @@ NetworkConfig FwdProblemDescription::MakeNetworkConfig() const auto input_dtype = inputDesc.GetType(); auto size = inputDesc.GetElementSize(); auto dim_size = inputDesc.GetLengths()[dim]; + auto dim_stride = inputDesc.GetStrides()[dim]; int dim_num = inputDesc.GetSize(); - bool is_last_dim = (dim == dim_num - 1); auto output_size = size / dim_size; std::ostringstream ss; @@ -48,9 +48,8 @@ NetworkConfig FwdProblemDescription::MakeNetworkConfig() const ss << "i_dtype" << input_dtype; ss << "dim_size" << dim_size; ss << "dim_num" << dim_num; - ss << "is_last_dim" << is_last_dim; + ss << "dim_stride" << dim_stride; ss << "output_size" << output_size; - ss << "input_contiguous" << isInputContiguous; return NetworkConfig{ss.str()}; } diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp index 49b2616658..81ff449977 100644 --- a/src/solver/kthvalue/forward_kthvalue.cpp +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -45,10 +45,11 @@ namespace kthvalue { bool IsImprovementOverROCm(const miopen::kthvalue::FwdProblemDescription& problem) { TensorDescriptor inputDesc = problem.GetInputDesc(); - int dimNum = inputDesc.GetSize(); size_t dimSize = inputDesc.GetLengths()[problem.GetDim()]; + size_t dimStride = inputDesc.GetStrides()[problem.GetDim()]; + size_t dimNum = inputDesc.GetLengths().size(); - return dimNum >= 2 && problem.GetDim() == dimNum - 1 && dimSize >= 300; + return dimNum >= 2 && dimStride == 1 && dimSize >= 300; } bool KthvalueFwd::IsApplicable(const ExecutionContext& /*context*/, @@ -58,8 +59,6 @@ bool KthvalueFwd::IsApplicable(const ExecutionContext& /*context*/, return false; if(problem.GetInputDesc().GetSize() > 5) return false; - if(!problem.isInputContiguous) - return false; return true; } diff --git a/test/gtest/kthvalue.hpp b/test/gtest/kthvalue.hpp index f4d7c438bd..41e3c8192b 100644 --- a/test/gtest/kthvalue.hpp +++ b/test/gtest/kthvalue.hpp @@ -84,11 +84,15 @@ struct KthvalueTestCase inline std::vector KthvalueTestConfigs() { return { - KthvalueTestCase({100, 500}, 10), - KthvalueTestCase({10, 20, 300}, 1), - KthvalueTestCase({8, 3, 10, 2000}, 2000), - KthvalueTestCase({2, 2, 4, 10, 3000}, 1), - KthvalueTestCase({2, 2, 4, 10, 3000}, 1, -1, true, true), + KthvalueTestCase({100, 500}, 10, 1, true, true), // test keep dim + KthvalueTestCase({100, 500}, 10), // 2D cont + KthvalueTestCase({400, 10}, 10, 0, false), // 2D non-cont + KthvalueTestCase({10, 20, 300}, 1), // 3D cont + KthvalueTestCase({350, 10, 20}, 5, 0, false), // 3D non-cont + KthvalueTestCase({8, 3, 10, 2000}, 2000), // 4D cont + KthvalueTestCase({1000, 3, 10, 15}, 1000, 0, false), // 4D non-cont + KthvalueTestCase({2, 2, 4, 10, 3000}, 120), // 5D cont + KthvalueTestCase({3000, 8, 2, 4, 20}, 9, 0, false), // 5D non-cont }; } From 986595186e4250768c15da74dab0c76a0ef6e35c Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Fri, 5 Jul 2024 14:28:59 +0700 Subject: [PATCH 11/28] remove redundant include preprocessors --- driver/kthvalue_driver.hpp | 15 ++-- src/include/miopen/kthvalue/invoke_params.hpp | 4 - .../miopen/kthvalue/problem_description.hpp | 6 -- src/include/miopen/kthvalue/solvers.hpp | 2 - src/kernels/MIOpenKthvalue.cpp | 8 -- src/kernels/radix.hpp | 79 +++++++++---------- src/kthvalue.cpp | 7 +- src/kthvalue/problem_description.cpp | 1 - src/kthvalue_api.cpp | 3 - src/solver/kthvalue/forward_kthvalue.cpp | 9 +-- test/cpu_kthvalue.hpp | 31 +++++++- test/gtest/kthvalue.cpp | 22 ++++++ test/gtest/kthvalue.hpp | 6 +- 13 files changed, 105 insertions(+), 88 deletions(-) diff --git a/driver/kthvalue_driver.hpp b/driver/kthvalue_driver.hpp index 190fb9519c..a367e386ef 100644 --- a/driver/kthvalue_driver.hpp +++ b/driver/kthvalue_driver.hpp @@ -23,23 +23,20 @@ * SOFTWARE. * *******************************************************************************/ - #pragma once #include "InputFlags.hpp" #include "driver.hpp" -#include "miopen/errors.hpp" -#include -#include -#include -#include "miopen/miopen.h" #include "tensor_driver.hpp" #include "timer.hpp" #include "random.hpp" + #include <../test/tensor_holder.hpp> #include <../test/verify.hpp> -#include -#include -#include + +#include +#include +#include + #include template diff --git a/src/include/miopen/kthvalue/invoke_params.hpp b/src/include/miopen/kthvalue/invoke_params.hpp index 36b9f0c03d..701538f9b9 100644 --- a/src/include/miopen/kthvalue/invoke_params.hpp +++ b/src/include/miopen/kthvalue/invoke_params.hpp @@ -25,13 +25,9 @@ *******************************************************************************/ #pragma once -#include "miopen/common.hpp" -#include "miopen/miopen.h" #include #include -#include - namespace miopen { namespace kthvalue { diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp index ac8755a5e7..c25f89aec0 100644 --- a/src/include/miopen/kthvalue/problem_description.hpp +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -25,15 +25,9 @@ *******************************************************************************/ #pragma once -#include "miopen/errors.hpp" -#include "miopen/miopen.h" -#include #include #include -#include -#include - namespace miopen { struct NetworkConfig; diff --git a/src/include/miopen/kthvalue/solvers.hpp b/src/include/miopen/kthvalue/solvers.hpp index 7e90192c04..9c58795730 100644 --- a/src/include/miopen/kthvalue/solvers.hpp +++ b/src/include/miopen/kthvalue/solvers.hpp @@ -28,8 +28,6 @@ #include #include -#include - namespace miopen { namespace solver { diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index 666b16e519..8b82460453 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -37,14 +37,6 @@ #define IN_OUT_TYPE float #endif -#ifndef CVT_ACCUM2FLOAT -#define CVT_ACCUM2FLOAT(x) (float_to_bfloat16(x)) -#endif - -#ifndef CVT_FLOAT2ACCUM -#define CVT_FLOAT2ACCUM(x) (bfloat16_to_float(x)) -#endif - #ifndef LOCAL_SIZE #define LOCAL_SIZE 256 #endif diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp index 455519735f..c32d32727a 100644 --- a/src/kernels/radix.hpp +++ b/src/kernels/radix.hpp @@ -1,5 +1,34 @@ -#pragma once +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef GUARD_RADIX_H +#define GUARD_RADIX_H + #include +// #include #include #define RADIX_TYPE typename RadixType::type @@ -20,8 +49,8 @@ DEFINE_RADIX_TYPE(int32_t, uint32_t) DEFINE_RADIX_TYPE(int64_t, uint64_t) DEFINE_RADIX_TYPE(bool, bool) DEFINE_RADIX_TYPE(float, uint32_t) -DEFINE_RADIX_TYPE(double, uint64_t) DEFINE_RADIX_TYPE(__half, ushort) +// DEFINE_RADIX_TYPE(__hip_bfloat16, ushort) template ::type> __device__ inline Radix encode(DTYPE v) @@ -38,6 +67,12 @@ __device__ inline Radix encode(DTYPE v) { return 9223372036854775808ull + v; } + // else if constexpr(std::is_same<__hip_bfloat16, DTYPE>::value) + // { + // Radix x = __bfloat16_as_ushort(v); + // Radix mask = (x & 0x8000) ? 0xffff : 0x8000; + // return (v == v) ? (x ^ mask) : 0xffff; + // } else if constexpr(std::is_same<__half, DTYPE>::value) { Radix x = __half_as_ushort(v); @@ -50,44 +85,6 @@ __device__ inline Radix encode(DTYPE v) Radix mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; return (v == v) ? (x ^ mask) : 0xffffffff; } - else if constexpr(std::is_same::value) - { - Radix x = __double_as_ulonglong(v); - Radix mask = -((x >> 63)) | 0x8000000000000000; - return (v == v) ? (x ^ mask) : 0xffffffffffffffff; - } -} - -template -__device__ inline DTYPE decode(Radix v) -{ - if constexpr(std::is_same::value) - { - return v; - } - else if constexpr(std::is_same::value) - { - return v - 2147483648u; - } - else if constexpr(std::is_same::value) - { - return v - 9223372036854775808ull; - } - else if constexpr(std::is_same<__half, DTYPE>::value) - { - Radix mask = (v & 0x8000) ? 0x8000 : 0xffff; - return __ushort_as_half((ushort)(v ^ mask)); - } - else if constexpr(std::is_same::value) - { - Radix mask = (v & 0x80000000) ? 0x80000000 : 0xffffffff; - return __uint_as_float(v ^ mask); - } - else if constexpr(std::is_same::value) - { - Radix mask = ((v >> 63) - 1) | 0x8000000000000000; - return __ulonglong_as_double(v ^ mask); - } } // returns x[pos+bits:pos] @@ -103,3 +100,5 @@ __device__ inline Radix SetBitFieldImpl(Radix x, Radix a, int pos, int bits) { return x | (a << pos); } + +#endif // GUARD_RADIX_H diff --git a/src/kthvalue.cpp b/src/kthvalue.cpp index 6f93d76573..bb3a4331d0 100644 --- a/src/kthvalue.cpp +++ b/src/kthvalue.cpp @@ -24,15 +24,14 @@ * *******************************************************************************/ -#include "miopen/miopen.h" #include "miopen/kthvalue/invoke_params.hpp" #include "miopen/kthvalue/problem_description.hpp" #include "miopen/kthvalue/solvers.hpp" -#include + +#include #include -#include -#include #include +#include namespace miopen { diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp index dec26f3b86..ef1f4c34d3 100644 --- a/src/kthvalue/problem_description.cpp +++ b/src/kthvalue/problem_description.cpp @@ -25,7 +25,6 @@ *******************************************************************************/ #include -#include #include diff --git a/src/kthvalue_api.cpp b/src/kthvalue_api.cpp index db1833d369..03405e845e 100644 --- a/src/kthvalue_api.cpp +++ b/src/kthvalue_api.cpp @@ -24,11 +24,8 @@ * *******************************************************************************/ -#include "miopen/miopen.h" #include -#include #include -#include #include inline std::ostream& operator<<(std::ostream& os, const std::vector& v) diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp index 81ff449977..fb0d51dede 100644 --- a/src/solver/kthvalue/forward_kthvalue.cpp +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -24,17 +24,14 @@ * *******************************************************************************/ -#include "miopen/errors.hpp" -#include "miopen/kthvalue/problem_description.hpp" -#include "miopen/miopen.h" -#include "miopen/tensor.hpp" -#include "miopen/tensor_view_utils.hpp" +#include +#include +#include #include #include #include #include #include -#include namespace miopen { diff --git a/test/cpu_kthvalue.hpp b/test/cpu_kthvalue.hpp index 269f22e3c4..8d3439a53c 100644 --- a/test/cpu_kthvalue.hpp +++ b/test/cpu_kthvalue.hpp @@ -1,9 +1,36 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + #pragma once -#include "miopen/tensor.hpp" #include "tensor_holder.hpp" #include "tensor_view.hpp" -#include "miopen/tensor_view_utils.hpp" + +#include + #include template diff --git a/test/gtest/kthvalue.cpp b/test/gtest/kthvalue.cpp index ff32840fc1..9a2274c0d6 100644 --- a/test/gtest/kthvalue.cpp +++ b/test/gtest/kthvalue.cpp @@ -52,6 +52,10 @@ struct KthvalueForwardTestFloat16 : KthvalueFwdTest { }; +// struct KthvalueForwardTestBFloat16 : KthvalueFwdTest +// { +// }; + using namespace kthvalue; TEST_P(KthvalueForwardTestFloat32, KthvalueForwardTest) @@ -82,6 +86,20 @@ TEST_P(KthvalueForwardTestFloat16, KthvalueForwardTest) } }; +// TEST_P(KthvalueForwardTestBFloat16, KthvalueForwardTest) +// { +// if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || +// (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--bfloat16"))) +// { +// RunTest(); +// Verify(); +// } +// else +// { +// GTEST_SKIP(); +// } +// }; + INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, KthvalueForwardTestFloat32, testing::ValuesIn(KthvalueTestConfigs())); @@ -89,4 +107,8 @@ INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, KthvalueForwardTestFloat16, testing::ValuesIn(KthvalueTestConfigs())); + +// INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, +// KthvalueForwardTestBFloat16, +// testing::ValuesIn(KthvalueTestConfigs())); } // namespace kthvalue diff --git a/test/gtest/kthvalue.hpp b/test/gtest/kthvalue.hpp index 41e3c8192b..2aa7e6fd41 100644 --- a/test/gtest/kthvalue.hpp +++ b/test/gtest/kthvalue.hpp @@ -26,16 +26,16 @@ #include "../driver/tensor_driver.hpp" #include "cpu_kthvalue.hpp" #include "get_handle.hpp" -#include "miopen/allocator.hpp" -#include "miopen/tensor.hpp" + #include "random.hpp" #include "tensor_holder.hpp" #include "verify.hpp" -#include + #include #include #include +#include struct KthvalueTestCase { std::vector dims; From 58693f33497cc2ba0759ed57b9101632001bc360 Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Fri, 5 Jul 2024 18:02:02 +0700 Subject: [PATCH 12/28] backup --- driver/kthvalue_driver.hpp | 2 +- include/miopen/miopen.h | 9 +++++---- src/CMakeLists.txt | 2 +- .../miopen/kthvalue/problem_description.hpp | 15 ++++++++++++--- src/include/miopen/tensor.hpp | 1 - src/include/miopen/tensor_view_utils.hpp | 6 +++--- src/kernels/MIOpenKthvalue.cpp | 6 ++++-- src/kernels/radix.hpp | 19 ++++++++++--------- 8 files changed, 36 insertions(+), 24 deletions(-) diff --git a/driver/kthvalue_driver.hpp b/driver/kthvalue_driver.hpp index a367e386ef..f2d1f434fc 100644 --- a/driver/kthvalue_driver.hpp +++ b/driver/kthvalue_driver.hpp @@ -222,7 +222,7 @@ int KthvalueDriver::AddCmdLineArgs() inflags.AddInputFlag("forw", 'F', "1", "Run only Forward (Default=1)", "int"); inflags.AddTensorFlag( "dim-lengths", 'D', "256x4x2", "The dimensional lengths of the input tensor"); - inflags.AddInputFlag("k", 'k', "1", "dim (Default=1)", "int"); + inflags.AddInputFlag("k", 'k', "1", "k (Default=1)", "int"); inflags.AddInputFlag("dim", 'd', "-1", "dim (Default=-1)", "int"); inflags.AddInputFlag("keep-dim", 'K', diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index f8c8ad7774..522cc08375 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -524,7 +524,7 @@ typedef enum miopenActivationABS = 5, /*!< Absolute value \f$abs(x)\f$ */ miopenActivationPOWER = 6, /*!< Scaled and shifted power \f$(\alpha + \beta * x)^{gamma}\f$ */ miopenActivationCLIPPEDRELU = - 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ + 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ miopenActivationLEAKYRELU = 8, /*!< Leaky Rectified Linear Unit \f$ \alpha * x | x <= 0; x | x > 0 \f$ */ miopenActivationELU = @@ -7688,10 +7688,11 @@ MIOPEN_EXPORT miopenStatus_t miopenRoPEBackward(miopenHandle_t handle, * @param input Data tensor input (input) * @param outputDesc Tensor descriptor for output tensor (input) * @param output Data tensor output (output) - * @param indices Data tensor index (output) + * @param indices Data tensor indices (output) * @param k The k-th smallest element(input) - * @param dim The dimension to find the kth value along(input) - * @param keepDim Whether the output tensor has dim retained or not(input) + * @param dim The dimension to find the kth value along (Default = -1)(input) + * @param keepDim Whether the output tensor has dim retained or not (Default = + * false)(input) * @return miopenStatus_t */ MIOPEN_EXPORT miopenStatus_t miopenKthvalueForward(miopenHandle_t handle, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index a595ffac1d..b8009420a9 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -643,6 +643,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN groupnorm.cpp getitem.cpp kernel_cache.cpp + kthvalue.cpp layernorm.cpp lrn.cpp mlo_dir_conv.cpp @@ -666,7 +667,6 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN pooling.cpp t5layernorm.cpp ocl/fusionopconvocl.cpp - kthvalue.cpp ocl/fusionopbiasbnactivocl.cpp reducecalculation.cpp reduceextreme.cpp diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp index c25f89aec0..171bb81153 100644 --- a/src/include/miopen/kthvalue/problem_description.hpp +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -49,10 +49,10 @@ struct FwdProblemDescription : ProblemDescriptionBase k(k_), keepDim(keepDim_) { - if(k > inputDesc.GetLengths()[dim]) + if(k < 1 || k > inputDesc.GetLengths()[dim]) { MIOPEN_THROW(miopenStatusBadParm, - "Kthvalue: k must be less than the size of the dimension"); + "Kthvalue: selected number k out of range for dimension"); } if(dim < 0 || dim >= inputDesc.GetSize()) { @@ -67,7 +67,7 @@ struct FwdProblemDescription : ProblemDescriptionBase MIOPEN_THROW(miopenStatusBadParm, "Reduce: Input and output tensor dimension lengths do not match."); } - if(!outputDesc.IsSameLength(indicesDesc)) + if(outputDesc.GetLengths() != indicesDesc.GetLengths()) { MIOPEN_THROW(miopenStatusBadParm, "Reduce: Output and indices tensor dimension lengths do not match."); @@ -79,6 +79,15 @@ struct FwdProblemDescription : ProblemDescriptionBase if(inputDesc.GetLengths().size() == 1) return true; + if(keepDim && inputDesc.GetSize() != outputDesc.GetSize()) + { + return false; + } + if(!keepDim && inputDesc.GetSize() != outputDesc.GetSize() + 1) + { + return false; + } + int32_t posOut = 0; for(int32_t i = 0; i < inputDesc.GetLengths().size(); i++) { diff --git a/src/include/miopen/tensor.hpp b/src/include/miopen/tensor.hpp index b962f3d033..f4d2b2dca7 100644 --- a/src/include/miopen/tensor.hpp +++ b/src/include/miopen/tensor.hpp @@ -233,7 +233,6 @@ struct MIOPEN_INTERNALS_EXPORT TensorDescriptor : miopenTensorDescriptor /// Checks only lengths. bool AllLengthsFitIntoInt() const; - bool IsSameLength(const TensorDescriptor& otherDesc) const; bool operator==(const TensorDescriptor& rhs) const; bool operator!=(const TensorDescriptor& rhs) const; bool operator<(const TensorDescriptor& rhs) const; diff --git a/src/include/miopen/tensor_view_utils.hpp b/src/include/miopen/tensor_view_utils.hpp index 12b57e2055..55382605c2 100644 --- a/src/include/miopen/tensor_view_utils.hpp +++ b/src/include/miopen/tensor_view_utils.hpp @@ -80,10 +80,10 @@ inline void slice_tv(tensor_view_t& tensor_view, int32_t sliceCount, const in } } -template -inline tensor_view_t get_tv_without_dim(const tensor_view_t& origin_tv, int selected_dim) +template +inline tensor_view_t get_tv_without_dim(const tensor_view_t& origin_tv, int selected_dim) { - tensor_view_t res; + tensor_view_t res; for(int i = 0; i < N; ++i) { if(i == selected_dim) diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index 8b82460453..1d7a339fb7 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -26,6 +26,7 @@ #ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS #include +#include #include #endif @@ -101,6 +102,7 @@ __device__ void kthvalueFwd(const DTYPE* input, { size_t input_idx = idx + i * dim_stride; RADIX_TYPE val = encode(input[input_idx]); + // printf("%u\n", val); if((val & desired_mask) == desired) { ++counts[GetBitFieldImpl(val, pos, RADIX_BITS)]; @@ -161,8 +163,8 @@ __device__ void kthvalueFwd(const DTYPE* input, found = true; break; } - desired = SetBitFieldImpl(desired, j, pos, RADIX_BITS); - desired_mask = SetBitFieldImpl(desired_mask, RADIX_MASK, pos, RADIX_BITS); + desired = SetBitFieldImpl(desired, j, pos); + desired_mask = SetBitFieldImpl(desired_mask, RADIX_MASK, pos); break; } if(found) diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp index c32d32727a..bd28391e26 100644 --- a/src/kernels/radix.hpp +++ b/src/kernels/radix.hpp @@ -28,7 +28,6 @@ #define GUARD_RADIX_H #include -// #include #include #define RADIX_TYPE typename RadixType::type @@ -50,11 +49,12 @@ DEFINE_RADIX_TYPE(int64_t, uint64_t) DEFINE_RADIX_TYPE(bool, bool) DEFINE_RADIX_TYPE(float, uint32_t) DEFINE_RADIX_TYPE(__half, ushort) -// DEFINE_RADIX_TYPE(__hip_bfloat16, ushort) +DEFINE_RADIX_TYPE(ushort, ushort) // bfloat16 template ::type> __device__ inline Radix encode(DTYPE v) { + // convert negative number to positive representation in Radix type. if constexpr(std::is_same::value) { return v; @@ -67,12 +67,13 @@ __device__ inline Radix encode(DTYPE v) { return 9223372036854775808ull + v; } - // else if constexpr(std::is_same<__hip_bfloat16, DTYPE>::value) - // { - // Radix x = __bfloat16_as_ushort(v); - // Radix mask = (x & 0x8000) ? 0xffff : 0x8000; - // return (v == v) ? (x ^ mask) : 0xffff; - // } + // bfloat16 is passed as ushort in kernel + else if constexpr(std::is_same::value) + { + Radix x = v; + Radix mask = (x & 0x8000) ? 0xffff : 0x8000; + return (v == v) ? (x ^ mask) : 0xffff; + } else if constexpr(std::is_same<__half, DTYPE>::value) { Radix x = __half_as_ushort(v); @@ -96,7 +97,7 @@ __device__ inline Radix GetBitFieldImpl(Radix x, int pos, int bits) // x[pos+bits:pos] = a template -__device__ inline Radix SetBitFieldImpl(Radix x, Radix a, int pos, int bits) +__device__ inline Radix SetBitFieldImpl(Radix x, Radix a, int pos) { return x | (a << pos); } From c3d229a4e668b2e3f72ee9d88f3e329d6236d72b Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Fri, 5 Jul 2024 18:05:49 +0700 Subject: [PATCH 13/28] resolve comments --- driver/kthvalue_driver.hpp | 2 +- include/miopen/miopen.h | 2 +- src/include/miopen/tensor_view_utils.hpp | 2 +- src/solver/kthvalue/forward_kthvalue.cpp | 2 +- test/cpu_kthvalue.hpp | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/driver/kthvalue_driver.hpp b/driver/kthvalue_driver.hpp index f2d1f434fc..4ac646830b 100644 --- a/driver/kthvalue_driver.hpp +++ b/driver/kthvalue_driver.hpp @@ -54,7 +54,7 @@ void mloKthvalueFwdRunHost(TIO* input, size_t dimSize = inputDesc.GetLengths()[dim]; size_t dimStride = inputDesc.GetStrides()[dim]; auto inputTv = miopen::get_inner_expanded_tv<5>(miopen::deref(pInputDesc)); - auto inputTvWithoutDim = miopen::get_tv_without_dim<5, 4>(inputTv, dim); + auto inputTvWithoutDim = miopen::get_tv_without_dim<5>(inputTv, dim); auto outputTv = miopen::get_inner_expanded_tv<5>(miopen::deref(outputDesc)); auto indicesTv = miopen::get_inner_expanded_tv<5>(miopen::deref(indicesDesc)); diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 522cc08375..4f65bfb28d 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -524,7 +524,7 @@ typedef enum miopenActivationABS = 5, /*!< Absolute value \f$abs(x)\f$ */ miopenActivationPOWER = 6, /*!< Scaled and shifted power \f$(\alpha + \beta * x)^{gamma}\f$ */ miopenActivationCLIPPEDRELU = - 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ + 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ miopenActivationLEAKYRELU = 8, /*!< Leaky Rectified Linear Unit \f$ \alpha * x | x <= 0; x | x > 0 \f$ */ miopenActivationELU = diff --git a/src/include/miopen/tensor_view_utils.hpp b/src/include/miopen/tensor_view_utils.hpp index 55382605c2..d77993fdd0 100644 --- a/src/include/miopen/tensor_view_utils.hpp +++ b/src/include/miopen/tensor_view_utils.hpp @@ -81,7 +81,7 @@ inline void slice_tv(tensor_view_t& tensor_view, int32_t sliceCount, const in } template -inline tensor_view_t get_tv_without_dim(const tensor_view_t& origin_tv, int selected_dim) +inline tensor_view_t get_tv_without_dim(const tensor_view_t& origin_tv, int selected_dim) { tensor_view_t res; for(int i = 0; i < N; ++i) diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp index fb0d51dede..b550fbd889 100644 --- a/src/solver/kthvalue/forward_kthvalue.cpp +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -115,7 +115,7 @@ ConvSolution KthvalueFwd::GetSolution(const ExecutionContext& context, size_t dim_stride = params.inputDesc->GetStrides()[params.dim]; auto input_tv = get_inner_expanded_tv<5>(deref(params.inputDesc)); - auto input_tv_without_selected_dim = get_tv_without_dim<5, 4>(input_tv, params.dim); + auto input_tv_without_selected_dim = get_tv_without_dim<5>(input_tv, params.dim); auto output_tv = get_inner_expanded_tv<5>(deref(params.outputDesc)); auto indices_tv = get_inner_expanded_tv<5>(deref(params.indicesDesc)); diff --git a/test/cpu_kthvalue.hpp b/test/cpu_kthvalue.hpp index 8d3439a53c..e1260a29f1 100644 --- a/test/cpu_kthvalue.hpp +++ b/test/cpu_kthvalue.hpp @@ -45,7 +45,7 @@ void cpu_kthvalue(tensor input, size_t dimSize = input.desc.GetLengths()[dim]; size_t dimStride = input.desc.GetStrides()[dim]; auto inputTv = miopen::get_inner_expanded_tv<5>(input.desc); - auto inputTvWithoutDim = miopen::get_tv_without_dim<5, 4>(inputTv, dim); + auto inputTvWithoutDim = miopen::get_tv_without_dim<5>(inputTv, dim); auto outputTv = miopen::get_inner_expanded_tv<5>(outputHost.desc); auto indicesTv = miopen::get_inner_expanded_tv<5>(indiceDesc); From a552ba289c9ec9505bc7699b342b8fa8c4834521 Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Fri, 5 Jul 2024 18:11:26 +0700 Subject: [PATCH 14/28] resolve cmts --- include/miopen/miopen.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 4f65bfb28d..b73498d919 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -524,7 +524,7 @@ typedef enum miopenActivationABS = 5, /*!< Absolute value \f$abs(x)\f$ */ miopenActivationPOWER = 6, /*!< Scaled and shifted power \f$(\alpha + \beta * x)^{gamma}\f$ */ miopenActivationCLIPPEDRELU = - 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ + 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ miopenActivationLEAKYRELU = 8, /*!< Leaky Rectified Linear Unit \f$ \alpha * x | x <= 0; x | x > 0 \f$ */ miopenActivationELU = @@ -7689,6 +7689,7 @@ MIOPEN_EXPORT miopenStatus_t miopenRoPEBackward(miopenHandle_t handle, * @param outputDesc Tensor descriptor for output tensor (input) * @param output Data tensor output (output) * @param indices Data tensor indices (output) + * @param indicesDesc Tensor descriptor for indices tensor (input) * @param k The k-th smallest element(input) * @param dim The dimension to find the kth value along (Default = -1)(input) * @param keepDim Whether the output tensor has dim retained or not (Default = From 8ddbb85fb070fc0c8e42cfdfc92f6bec2af3a40a Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Fri, 5 Jul 2024 18:12:31 +0700 Subject: [PATCH 15/28] fix cmts --- include/miopen/miopen.h | 2 +- src/kthvalue.cpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index b73498d919..5f85a93333 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -524,7 +524,7 @@ typedef enum miopenActivationABS = 5, /*!< Absolute value \f$abs(x)\f$ */ miopenActivationPOWER = 6, /*!< Scaled and shifted power \f$(\alpha + \beta * x)^{gamma}\f$ */ miopenActivationCLIPPEDRELU = - 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ + 7, /*!< Clipped Rectified Linear Unit \f$ min(\alpha, max(0,x)) \f$ */ miopenActivationLEAKYRELU = 8, /*!< Leaky Rectified Linear Unit \f$ \alpha * x | x <= 0; x | x > 0 \f$ */ miopenActivationELU = diff --git a/src/kthvalue.cpp b/src/kthvalue.cpp index bb3a4331d0..506bd39861 100644 --- a/src/kthvalue.cpp +++ b/src/kthvalue.cpp @@ -48,7 +48,7 @@ miopenStatus_t KthvalueForward(Handle& handle, { if(dim < 0) { - dim += indicesDesc.GetSize(); + dim += inputDesc.GetSize(); } const auto problem = From ab4ee7341c8eb0c8b30d68ca424940236002914f Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Mon, 8 Jul 2024 13:49:40 +0700 Subject: [PATCH 16/28] resolve cmts --- src/include/miopen/tensor_view_utils.hpp | 2 +- src/kernels/MIOpenKthvalue.cpp | 1 - test/gtest/kthvalue.cpp | 38 ++++++++++++------------ 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/src/include/miopen/tensor_view_utils.hpp b/src/include/miopen/tensor_view_utils.hpp index d77993fdd0..67820f87c8 100644 --- a/src/include/miopen/tensor_view_utils.hpp +++ b/src/include/miopen/tensor_view_utils.hpp @@ -104,4 +104,4 @@ inline tensor_view_t get_tv_without_dim(const tensor_view_t& origin_tv } // namespace miopen -#endif // MIOPEN_TENSOR_REORDER_UTIL_HPP_ +#endif // MIOPEN_TENSOR_VIEW_UTIL_HPP_ diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index 1d7a339fb7..4aeb9d2f1b 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -102,7 +102,6 @@ __device__ void kthvalueFwd(const DTYPE* input, { size_t input_idx = idx + i * dim_stride; RADIX_TYPE val = encode(input[input_idx]); - // printf("%u\n", val); if((val & desired_mask) == desired) { ++counts[GetBitFieldImpl(val, pos, RADIX_BITS)]; diff --git a/test/gtest/kthvalue.cpp b/test/gtest/kthvalue.cpp index 9a2274c0d6..32324b46ac 100644 --- a/test/gtest/kthvalue.cpp +++ b/test/gtest/kthvalue.cpp @@ -52,9 +52,9 @@ struct KthvalueForwardTestFloat16 : KthvalueFwdTest { }; -// struct KthvalueForwardTestBFloat16 : KthvalueFwdTest -// { -// }; +struct KthvalueForwardTestBFloat16 : KthvalueFwdTest +{ +}; using namespace kthvalue; @@ -86,19 +86,19 @@ TEST_P(KthvalueForwardTestFloat16, KthvalueForwardTest) } }; -// TEST_P(KthvalueForwardTestBFloat16, KthvalueForwardTest) -// { -// if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || -// (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--bfloat16"))) -// { -// RunTest(); -// Verify(); -// } -// else -// { -// GTEST_SKIP(); -// } -// }; +TEST_P(KthvalueForwardTestBFloat16, KthvalueForwardTest) +{ + if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || + (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--bfloat16"))) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, KthvalueForwardTestFloat32, @@ -108,7 +108,7 @@ INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, KthvalueForwardTestFloat16, testing::ValuesIn(KthvalueTestConfigs())); -// INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, -// KthvalueForwardTestBFloat16, -// testing::ValuesIn(KthvalueTestConfigs())); +INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, + KthvalueForwardTestBFloat16, + testing::ValuesIn(KthvalueTestConfigs())); } // namespace kthvalue From 81af7f3396aed6666b829831325b91957ba8a011 Mon Sep 17 00:00:00 2001 From: Bui Chi Trung Date: Tue, 9 Jul 2024 10:15:03 +0700 Subject: [PATCH 17/28] update include header --- src/kthvalue.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/kthvalue.cpp b/src/kthvalue.cpp index 506bd39861..e92aa8b454 100644 --- a/src/kthvalue.cpp +++ b/src/kthvalue.cpp @@ -24,10 +24,9 @@ * *******************************************************************************/ -#include "miopen/kthvalue/invoke_params.hpp" -#include "miopen/kthvalue/problem_description.hpp" -#include "miopen/kthvalue/solvers.hpp" - +#include +#include +#include #include #include #include From aec61d2b018a7eba1427a9c887285452e6db2771 Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Tue, 20 Aug 2024 10:47:07 +0700 Subject: [PATCH 18/28] remove outdate functions --- src/CMakeLists.txt | 1 - .../miopen/kthvalue/problem_description.hpp | 6 +++--- src/kthvalue.cpp | 2 +- src/kthvalue/problem_description.cpp | 2 +- src/solver/kthvalue/forward_kthvalue.cpp | 2 +- test/gtest/kthvalue.cpp | 20 ++++++++++++------- 6 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index b8009420a9..f086b3eac4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -470,7 +470,6 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/stride_array.hpp kernels/tensor_view.hpp kernels/utilities.inc - kernels/warp_shuffle.hpp kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1536vgprs_fp16_fp16acc_f2x3_c16_stride1.inc kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1536vgprs_fp16_fp16acc_f2x3_c32_stride1.inc kernels/winograd/Conv_Winograd_Fury_v2_4_1_gfx11_1024vgprs_fp16_fp16acc_f2x3_c16_stride1.inc diff --git a/src/include/miopen/kthvalue/problem_description.hpp b/src/include/miopen/kthvalue/problem_description.hpp index 171bb81153..d597ed0e45 100644 --- a/src/include/miopen/kthvalue/problem_description.hpp +++ b/src/include/miopen/kthvalue/problem_description.hpp @@ -54,7 +54,7 @@ struct FwdProblemDescription : ProblemDescriptionBase MIOPEN_THROW(miopenStatusBadParm, "Kthvalue: selected number k out of range for dimension"); } - if(dim < 0 || dim >= inputDesc.GetSize()) + if(dim < 0 || dim >= inputDesc.GetNumDims()) { MIOPEN_THROW(miopenStatusBadParm, "Kthvalue: dim doesn't not exist"); } @@ -79,11 +79,11 @@ struct FwdProblemDescription : ProblemDescriptionBase if(inputDesc.GetLengths().size() == 1) return true; - if(keepDim && inputDesc.GetSize() != outputDesc.GetSize()) + if(keepDim && inputDesc.GetNumDims() != outputDesc.GetNumDims()) { return false; } - if(!keepDim && inputDesc.GetSize() != outputDesc.GetSize() + 1) + if(!keepDim && inputDesc.GetNumDims() != outputDesc.GetNumDims() + 1) { return false; } diff --git a/src/kthvalue.cpp b/src/kthvalue.cpp index e92aa8b454..a9f4e73067 100644 --- a/src/kthvalue.cpp +++ b/src/kthvalue.cpp @@ -47,7 +47,7 @@ miopenStatus_t KthvalueForward(Handle& handle, { if(dim < 0) { - dim += inputDesc.GetSize(); + dim += inputDesc.GetNumDims(); } const auto problem = diff --git a/src/kthvalue/problem_description.cpp b/src/kthvalue/problem_description.cpp index ef1f4c34d3..9f3cec4f64 100644 --- a/src/kthvalue/problem_description.cpp +++ b/src/kthvalue/problem_description.cpp @@ -38,7 +38,7 @@ NetworkConfig FwdProblemDescription::MakeNetworkConfig() const auto size = inputDesc.GetElementSize(); auto dim_size = inputDesc.GetLengths()[dim]; auto dim_stride = inputDesc.GetStrides()[dim]; - int dim_num = inputDesc.GetSize(); + int dim_num = inputDesc.GetNumDims(); auto output_size = size / dim_size; std::ostringstream ss; diff --git a/src/solver/kthvalue/forward_kthvalue.cpp b/src/solver/kthvalue/forward_kthvalue.cpp index b550fbd889..2639a41b01 100644 --- a/src/solver/kthvalue/forward_kthvalue.cpp +++ b/src/solver/kthvalue/forward_kthvalue.cpp @@ -54,7 +54,7 @@ bool KthvalueFwd::IsApplicable(const ExecutionContext& /*context*/, { if(!IsImprovementOverROCm(problem)) return false; - if(problem.GetInputDesc().GetSize() > 5) + if(problem.GetInputDesc().GetNumDims() > 5) return false; return true; } diff --git a/test/gtest/kthvalue.cpp b/test/gtest/kthvalue.cpp index 32324b46ac..24f470f625 100644 --- a/test/gtest/kthvalue.cpp +++ b/test/gtest/kthvalue.cpp @@ -36,7 +36,7 @@ namespace kthvalue { std::string GetFloatArg() { - const auto& tmp = miopen::GetStringEnv(ENV(MIOPEN_TEST_FLOAT_ARG)); + const auto& tmp = env::value(MIOPEN_TEST_FLOAT_ARG); if(tmp.empty()) { return ""; @@ -44,6 +44,15 @@ std::string GetFloatArg() return tmp; } +bool CheckFloatArg(std::string arg) +{ + if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && GetFloatArg() == arg)) + { + return true; + } + return false; +} + struct KthvalueForwardTestFloat32 : KthvalueFwdTest { }; @@ -60,8 +69,7 @@ using namespace kthvalue; TEST_P(KthvalueForwardTestFloat32, KthvalueForwardTest) { - if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || - (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--float"))) + if(CheckFloatArg("--float")) { RunTest(); Verify(); @@ -74,8 +82,7 @@ TEST_P(KthvalueForwardTestFloat32, KthvalueForwardTest) TEST_P(KthvalueForwardTestFloat16, KthvalueForwardTest) { - if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || - (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--half"))) + if(CheckFloatArg("--half")) { RunTest(); Verify(); @@ -88,8 +95,7 @@ TEST_P(KthvalueForwardTestFloat16, KthvalueForwardTest) TEST_P(KthvalueForwardTestBFloat16, KthvalueForwardTest) { - if(miopen::IsUnset(ENV(MIOPEN_TEST_ALL)) || - (miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--bfloat16"))) + if(CheckFloatArg("--bfloat16")) { RunTest(); Verify(); From c485922762baf03de48e9bd3c2a4f311efee242b Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Wed, 21 Aug 2024 11:40:01 +0700 Subject: [PATCH 19/28] resolve comments and fix window build fail --- driver/kthvalue_driver.hpp | 3 +-- include/miopen/miopen.h | 2 +- src/include/miopen/kthvalue.hpp | 2 +- src/kernels/MIOpenKthvalue.cpp | 23 ++++++++--------------- src/kernels/radix.hpp | 18 +++++++++--------- 5 files changed, 20 insertions(+), 28 deletions(-) diff --git a/driver/kthvalue_driver.hpp b/driver/kthvalue_driver.hpp index 4ac646830b..75f7e5b535 100644 --- a/driver/kthvalue_driver.hpp +++ b/driver/kthvalue_driver.hpp @@ -195,8 +195,7 @@ int KthvalueDriver::GetandSetData() SetTensorNd(inputDesc, inDims, inStride, data_type); SetTensorNd(outputDesc, outDims, data_type); - // miopenDataType_t doesn't support size_t, I use double instead (both types use 64 bits) - SetTensorNd(indicesDesc, outDims, miopen_type{}); + SetTensorNd(indicesDesc, outDims, miopenInt64); return 0; } diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 5f85a93333..48a20cf57c 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -5015,7 +5015,7 @@ MIOPEN_EXPORT miopenStatus_t miopenCTCLoss(miopenHandle_t handle, */ /*! @enum miopenRNGType_t - * random number generator type + * random number generator typnte */ typedef enum { diff --git a/src/include/miopen/kthvalue.hpp b/src/include/miopen/kthvalue.hpp index 990ec77ddb..3cb9527920 100644 --- a/src/include/miopen/kthvalue.hpp +++ b/src/include/miopen/kthvalue.hpp @@ -33,7 +33,7 @@ namespace miopen { struct Handle; struct TensorDescriptor; -miopenStatus_t KthvalueForward(Handle& handle, +MIOPEN_INTERNALS_EXPORT miopenStatus_t KthvalueForward(Handle& handle, const TensorDescriptor& inputDesc, ConstData_t input, const TensorDescriptor& outputDesc, diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index 4aeb9d2f1b..3bfb4dd45a 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -42,17 +42,7 @@ #define LOCAL_SIZE 256 #endif -#ifndef RADIX_BITS -#define RADIX_BITS 2 -#endif - -#ifndef RADIX_SIZE -#define RADIX_SIZE (1 << RADIX_BITS) -#endif - -#ifndef RADIX_MASK -#define RADIX_MASK (RADIX_SIZE - 1) -#endif +using RADIX_TYPE = typename RadixType::type; template __device__ void kthvalueFwd(const DTYPE* input, @@ -73,6 +63,9 @@ __device__ void kthvalueFwd(const DTYPE* input, * dim = 2 (C) * => grid = {A * B * D * E, 1}, block = {LOCAL_SIZE, 1} */ + const int RADIX_BITS = 2; + const int RADIX_SIZE = 1 << RADIX_BITS; + const int RADIX_MASK = RADIX_SIZE - 1; size_t lid = threadIdx.x; size_t gid = blockIdx.x; @@ -83,7 +76,7 @@ __device__ void kthvalueFwd(const DTYPE* input, __shared__ size_t lsum[LOCAL_SIZE][RADIX_SIZE]; __shared__ DTYPE lval; - __shared__ long lidx; + __shared__ size_t lidx; size_t counts[RADIX_SIZE]; RADIX_TYPE desired_mask = 0; RADIX_TYPE desired = 0; @@ -93,7 +86,7 @@ __device__ void kthvalueFwd(const DTYPE* input, for(int pos = sizeof(RADIX_TYPE) * 8 - RADIX_BITS; pos >= 0; pos -= RADIX_BITS) { - for(unsigned long& count : counts) + for(size_t& count : counts) { count = 0; } @@ -104,7 +97,7 @@ __device__ void kthvalueFwd(const DTYPE* input, RADIX_TYPE val = encode(input[input_idx]); if((val & desired_mask) == desired) { - ++counts[GetBitFieldImpl(val, pos, RADIX_BITS)]; + ++counts[GetBitFieldImpl(val, pos)]; } } @@ -152,7 +145,7 @@ __device__ void kthvalueFwd(const DTYPE* input, DTYPE val_ori = input[input_idx]; RADIX_TYPE val = encode(val_ori); if((val & desired_mask) == desired && - GetBitFieldImpl(val, pos, RADIX_BITS) == j) + GetBitFieldImpl(val, pos) == j) { // For case 2, this will be non-deterministic. lval = val_ori; diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp index bd28391e26..354ff68775 100644 --- a/src/kernels/radix.hpp +++ b/src/kernels/radix.hpp @@ -27,11 +27,10 @@ #ifndef GUARD_RADIX_H #define GUARD_RADIX_H +#include #include #include -#define RADIX_TYPE typename RadixType::type - #define DEFINE_RADIX_TYPE(DTYPE, cpp_type) \ template <> \ struct RadixType \ @@ -50,6 +49,7 @@ DEFINE_RADIX_TYPE(bool, bool) DEFINE_RADIX_TYPE(float, uint32_t) DEFINE_RADIX_TYPE(__half, ushort) DEFINE_RADIX_TYPE(ushort, ushort) // bfloat16 +#undef DEFINE_RADIX_TYPE template ::type> __device__ inline Radix encode(DTYPE v) @@ -61,36 +61,36 @@ __device__ inline Radix encode(DTYPE v) } else if constexpr(std::is_same::value) { - return 2147483648u + v; + return static_cast(std::numeric_limits::max()) + v + 1; } else if constexpr(std::is_same::value) { - return 9223372036854775808ull + v; + return static_cast(std::numeric_limits::max()) + v + 1; } // bfloat16 is passed as ushort in kernel else if constexpr(std::is_same::value) { Radix x = v; Radix mask = (x & 0x8000) ? 0xffff : 0x8000; - return (v == v) ? (x ^ mask) : 0xffff; + return isnan(v) ? 0xffff : (x ^ mask); } else if constexpr(std::is_same<__half, DTYPE>::value) { Radix x = __half_as_ushort(v); Radix mask = (x & 0x8000) ? 0xffff : 0x8000; - return (v == v) ? (x ^ mask) : 0xffff; + return isnan(v) ? 0xffff : (x ^ mask); } else if constexpr(std::is_same::value) { Radix x = __float_as_uint(v); Radix mask = (x & 0x80000000) ? 0xffffffff : 0x80000000; - return (v == v) ? (x ^ mask) : 0xffffffff; + return isnan(v) ? 0xffffffff : (x ^ mask); } } // returns x[pos+bits:pos] -template -__device__ inline Radix GetBitFieldImpl(Radix x, int pos, int bits) +template +__device__ inline Radix GetBitFieldImpl(Radix x, int pos) { return (x >> pos) & ((1 << bits) - 1); } From bc382a2414c252b577109703a3f4dd01a55fff32 Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Wed, 21 Aug 2024 11:53:26 +0700 Subject: [PATCH 20/28] clang-format --- src/include/miopen/kthvalue.hpp | 18 +++++++++--------- src/kernels/radix.hpp | 2 +- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/include/miopen/kthvalue.hpp b/src/include/miopen/kthvalue.hpp index 3cb9527920..32cb008e0b 100644 --- a/src/include/miopen/kthvalue.hpp +++ b/src/include/miopen/kthvalue.hpp @@ -34,15 +34,15 @@ struct Handle; struct TensorDescriptor; MIOPEN_INTERNALS_EXPORT miopenStatus_t KthvalueForward(Handle& handle, - const TensorDescriptor& inputDesc, - ConstData_t input, - const TensorDescriptor& outputDesc, - Data_t output, - const TensorDescriptor& indicesDesc, - size_t* indices, - size_t k, - int32_t dim, - bool keepDim); + const TensorDescriptor& inputDesc, + ConstData_t input, + const TensorDescriptor& outputDesc, + Data_t output, + const TensorDescriptor& indicesDesc, + size_t* indices, + size_t k, + int32_t dim, + bool keepDim); } // namespace miopen #endif // MIOPEN_KTHVALUE_HPP_ diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp index 354ff68775..476f0b4d68 100644 --- a/src/kernels/radix.hpp +++ b/src/kernels/radix.hpp @@ -78,7 +78,7 @@ __device__ inline Radix encode(DTYPE v) { Radix x = __half_as_ushort(v); Radix mask = (x & 0x8000) ? 0xffff : 0x8000; - return isnan(v) ? 0xffff : (x ^ mask); + return isnan(v) ? 0xffff : (x ^ mask); } else if constexpr(std::is_same::value) { From df0fbc9b7c5bb9e16b32f4eed46dd9c449d9dd3b Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Thu, 22 Aug 2024 11:12:35 +0700 Subject: [PATCH 21/28] fix case isnan compile err in fp16 tests --- include/miopen/miopen.h | 2 +- src/kernels/MIOpenKthvalue.cpp | 4 ++-- src/kernels/radix.hpp | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h index 48a20cf57c..5f85a93333 100644 --- a/include/miopen/miopen.h +++ b/include/miopen/miopen.h @@ -5015,7 +5015,7 @@ MIOPEN_EXPORT miopenStatus_t miopenCTCLoss(miopenHandle_t handle, */ /*! @enum miopenRNGType_t - * random number generator typnte + * random number generator type */ typedef enum { diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index 3bfb4dd45a..41f441f596 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -42,8 +42,6 @@ #define LOCAL_SIZE 256 #endif -using RADIX_TYPE = typename RadixType::type; - template __device__ void kthvalueFwd(const DTYPE* input, DTYPE* output, @@ -63,6 +61,8 @@ __device__ void kthvalueFwd(const DTYPE* input, * dim = 2 (C) * => grid = {A * B * D * E, 1}, block = {LOCAL_SIZE, 1} */ + using RADIX_TYPE = typename RadixType::type; + const int RADIX_BITS = 2; const int RADIX_SIZE = 1 << RADIX_BITS; const int RADIX_MASK = RADIX_SIZE - 1; diff --git a/src/kernels/radix.hpp b/src/kernels/radix.hpp index 476f0b4d68..f75443e149 100644 --- a/src/kernels/radix.hpp +++ b/src/kernels/radix.hpp @@ -78,7 +78,7 @@ __device__ inline Radix encode(DTYPE v) { Radix x = __half_as_ushort(v); Radix mask = (x & 0x8000) ? 0xffff : 0x8000; - return isnan(v) ? 0xffff : (x ^ mask); + return __hisnan(v) ? 0xffff : (x ^ mask); } else if constexpr(std::is_same::value) { From 732cead7a77edb9a77bebdf6b32e8b7ea4c16a48 Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Fri, 23 Aug 2024 11:21:06 +0700 Subject: [PATCH 22/28] update test name config --- test/gtest/kthvalue.cpp | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/test/gtest/kthvalue.cpp b/test/gtest/kthvalue.cpp index 24f470f625..c4f742893d 100644 --- a/test/gtest/kthvalue.cpp +++ b/test/gtest/kthvalue.cpp @@ -53,21 +53,21 @@ bool CheckFloatArg(std::string arg) return false; } -struct KthvalueForwardTestFloat32 : KthvalueFwdTest +struct GPU_Kthvalue_fwd_FP32 : KthvalueFwdTest { }; -struct KthvalueForwardTestFloat16 : KthvalueFwdTest +struct GPU_Kthvalue_fwd_FP16 : KthvalueFwdTest { }; -struct KthvalueForwardTestBFloat16 : KthvalueFwdTest +struct GPU_Kthvalue_fwd_BFP16 : KthvalueFwdTest { }; using namespace kthvalue; -TEST_P(KthvalueForwardTestFloat32, KthvalueForwardTest) +TEST_P(GPU_Kthvalue_fwd_FP32, Test) { if(CheckFloatArg("--float")) { @@ -80,7 +80,7 @@ TEST_P(KthvalueForwardTestFloat32, KthvalueForwardTest) } }; -TEST_P(KthvalueForwardTestFloat16, KthvalueForwardTest) +TEST_P(GPU_Kthvalue_fwd_FP16, Test) { if(CheckFloatArg("--half")) { @@ -93,7 +93,7 @@ TEST_P(KthvalueForwardTestFloat16, KthvalueForwardTest) } }; -TEST_P(KthvalueForwardTestBFloat16, KthvalueForwardTest) +TEST_P(GPU_Kthvalue_fwd_BFP16, Test) { if(CheckFloatArg("--bfloat16")) { @@ -106,15 +106,13 @@ TEST_P(KthvalueForwardTestBFloat16, KthvalueForwardTest) } }; -INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, - KthvalueForwardTestFloat32, +INSTANTIATE_TEST_SUITE_P(Smoke, + GPU_Kthvalue_fwd_FP32, testing::ValuesIn(KthvalueTestConfigs())); - -INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, - KthvalueForwardTestFloat16, +INSTANTIATE_TEST_SUITE_P(Smoke, + GPU_Kthvalue_fwd_FP16, testing::ValuesIn(KthvalueTestConfigs())); - -INSTANTIATE_TEST_SUITE_P(KthvalueForwardTestSet, - KthvalueForwardTestBFloat16, +INSTANTIATE_TEST_SUITE_P(Smoke, + GPU_Kthvalue_fwd_BFP16, testing::ValuesIn(KthvalueTestConfigs())); } // namespace kthvalue From be322dd3840b701506b840862c0c2789d5624bc2 Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Fri, 23 Aug 2024 14:43:49 +0700 Subject: [PATCH 23/28] update tensor view util to pass cppcheck --- src/include/miopen/tensor_view_utils.hpp | 4 ++-- test/gtest/kthvalue.cpp | 12 +++--------- 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/include/miopen/tensor_view_utils.hpp b/src/include/miopen/tensor_view_utils.hpp index a7016c67e2..b92f3b1b2d 100644 --- a/src/include/miopen/tensor_view_utils.hpp +++ b/src/include/miopen/tensor_view_utils.hpp @@ -39,7 +39,7 @@ inline tensor_view_t get_inner_expanded_tv(const TensorDescriptor Desc) auto dims = Desc.GetLengths(); auto strides = Desc.GetStrides(); - tensor_view_t tensor_view; + tensor_view_t tensor_view{}; for(size_t i = 0; i < N; ++i) { if(dims.empty()) @@ -84,7 +84,7 @@ inline void slice_tv(tensor_view_t& tensor_view, int32_t sliceCount, const in template inline tensor_view_t get_tv_without_dim(const tensor_view_t& origin_tv, int selected_dim) { - tensor_view_t res; + tensor_view_t res{}; for(int i = 0; i < N; ++i) { if(i == selected_dim) diff --git a/test/gtest/kthvalue.cpp b/test/gtest/kthvalue.cpp index c4f742893d..1efdac934d 100644 --- a/test/gtest/kthvalue.cpp +++ b/test/gtest/kthvalue.cpp @@ -106,13 +106,7 @@ TEST_P(GPU_Kthvalue_fwd_BFP16, Test) } }; -INSTANTIATE_TEST_SUITE_P(Smoke, - GPU_Kthvalue_fwd_FP32, - testing::ValuesIn(KthvalueTestConfigs())); -INSTANTIATE_TEST_SUITE_P(Smoke, - GPU_Kthvalue_fwd_FP16, - testing::ValuesIn(KthvalueTestConfigs())); -INSTANTIATE_TEST_SUITE_P(Smoke, - GPU_Kthvalue_fwd_BFP16, - testing::ValuesIn(KthvalueTestConfigs())); +INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP32, testing::ValuesIn(KthvalueTestConfigs())); +INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP16, testing::ValuesIn(KthvalueTestConfigs())); +INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_BFP16, testing::ValuesIn(KthvalueTestConfigs())); } // namespace kthvalue From dcc4d76c0278f729f06bb20d7445488d18f19426 Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Fri, 23 Aug 2024 22:26:25 +0700 Subject: [PATCH 24/28] try comment out unit-test to see if the pipeline pass --- test/gtest/kthvalue.cpp | 148 ++++++++++++++++++++-------------------- 1 file changed, 74 insertions(+), 74 deletions(-) diff --git a/test/gtest/kthvalue.cpp b/test/gtest/kthvalue.cpp index 1efdac934d..f0c494fe09 100644 --- a/test/gtest/kthvalue.cpp +++ b/test/gtest/kthvalue.cpp @@ -24,89 +24,89 @@ * *******************************************************************************/ -#include "kthvalue.hpp" -#include "miopen/bfloat16.hpp" -#include "tensor_holder.hpp" -#include +// #include "kthvalue.hpp" +// #include "miopen/bfloat16.hpp" +// #include "tensor_holder.hpp" +// #include -MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG) -MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL) +// MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG) +// MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL) -namespace kthvalue { +// namespace kthvalue { -std::string GetFloatArg() -{ - const auto& tmp = env::value(MIOPEN_TEST_FLOAT_ARG); - if(tmp.empty()) - { - return ""; - } - return tmp; -} +// std::string GetFloatArg() +// { +// const auto& tmp = env::value(MIOPEN_TEST_FLOAT_ARG); +// if(tmp.empty()) +// { +// return ""; +// } +// return tmp; +// } -bool CheckFloatArg(std::string arg) -{ - if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && GetFloatArg() == arg)) - { - return true; - } - return false; -} +// bool CheckFloatArg(std::string arg) +// { +// if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && GetFloatArg() == arg)) +// { +// return true; +// } +// return false; +// } -struct GPU_Kthvalue_fwd_FP32 : KthvalueFwdTest -{ -}; +// struct GPU_Kthvalue_fwd_FP32 : KthvalueFwdTest +// { +// }; -struct GPU_Kthvalue_fwd_FP16 : KthvalueFwdTest -{ -}; +// struct GPU_Kthvalue_fwd_FP16 : KthvalueFwdTest +// { +// }; -struct GPU_Kthvalue_fwd_BFP16 : KthvalueFwdTest -{ -}; +// struct GPU_Kthvalue_fwd_BFP16 : KthvalueFwdTest +// { +// }; -using namespace kthvalue; +// using namespace kthvalue; -TEST_P(GPU_Kthvalue_fwd_FP32, Test) -{ - if(CheckFloatArg("--float")) - { - RunTest(); - Verify(); - } - else - { - GTEST_SKIP(); - } -}; +// TEST_P(GPU_Kthvalue_fwd_FP32, Test) +// { +// if(CheckFloatArg("--float")) +// { +// RunTest(); +// Verify(); +// } +// else +// { +// GTEST_SKIP(); +// } +// }; -TEST_P(GPU_Kthvalue_fwd_FP16, Test) -{ - if(CheckFloatArg("--half")) - { - RunTest(); - Verify(); - } - else - { - GTEST_SKIP(); - } -}; +// TEST_P(GPU_Kthvalue_fwd_FP16, Test) +// { +// if(CheckFloatArg("--half")) +// { +// RunTest(); +// Verify(); +// } +// else +// { +// GTEST_SKIP(); +// } +// }; -TEST_P(GPU_Kthvalue_fwd_BFP16, Test) -{ - if(CheckFloatArg("--bfloat16")) - { - RunTest(); - Verify(); - } - else - { - GTEST_SKIP(); - } -}; +// TEST_P(GPU_Kthvalue_fwd_BFP16, Test) +// { +// if(CheckFloatArg("--bfloat16")) +// { +// RunTest(); +// Verify(); +// } +// else +// { +// GTEST_SKIP(); +// } +// }; -INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP32, testing::ValuesIn(KthvalueTestConfigs())); -INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP16, testing::ValuesIn(KthvalueTestConfigs())); -INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_BFP16, testing::ValuesIn(KthvalueTestConfigs())); -} // namespace kthvalue +// INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP32, testing::ValuesIn(KthvalueTestConfigs())); +// INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP16, testing::ValuesIn(KthvalueTestConfigs())); +// INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_BFP16, testing::ValuesIn(KthvalueTestConfigs())); +// } // namespace kthvalue From 452e1c198e9bea9e1a99cf4ec82a8542d40d1132 Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Fri, 23 Aug 2024 23:37:52 +0700 Subject: [PATCH 25/28] fix gtest format --- test/gtest/kthvalue.cpp | 148 ++++++++++++++++++++-------------------- 1 file changed, 74 insertions(+), 74 deletions(-) diff --git a/test/gtest/kthvalue.cpp b/test/gtest/kthvalue.cpp index f0c494fe09..38c1170a2b 100644 --- a/test/gtest/kthvalue.cpp +++ b/test/gtest/kthvalue.cpp @@ -24,89 +24,89 @@ * *******************************************************************************/ -// #include "kthvalue.hpp" -// #include "miopen/bfloat16.hpp" -// #include "tensor_holder.hpp" -// #include +#include "kthvalue.hpp" +#include "tensor_holder.hpp" +#include +#include -// MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG) -// MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL) +MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL) -// namespace kthvalue { +namespace kthvalue { -// std::string GetFloatArg() -// { -// const auto& tmp = env::value(MIOPEN_TEST_FLOAT_ARG); -// if(tmp.empty()) -// { -// return ""; -// } -// return tmp; -// } +std::string GetFloatArg() +{ + const auto& tmp = env::value(MIOPEN_TEST_FLOAT_ARG); + if(tmp.empty()) + { + return ""; + } + return tmp; +} -// bool CheckFloatArg(std::string arg) -// { -// if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && GetFloatArg() == arg)) -// { -// return true; -// } -// return false; -// } +bool CheckFloatArg(std::string arg) +{ + if(!MIOPEN_TEST_ALL || (env::enabled(MIOPEN_TEST_ALL) && GetFloatArg() == arg)) + { + return true; + } + return false; +} -// struct GPU_Kthvalue_fwd_FP32 : KthvalueFwdTest -// { -// }; +struct GPU_Kthvalue_fwd_FP32 : KthvalueFwdTest +{ +}; -// struct GPU_Kthvalue_fwd_FP16 : KthvalueFwdTest -// { -// }; +struct GPU_Kthvalue_fwd_FP16 : KthvalueFwdTest +{ +}; -// struct GPU_Kthvalue_fwd_BFP16 : KthvalueFwdTest -// { -// }; +struct GPU_Kthvalue_fwd_BFP16 : KthvalueFwdTest +{ +}; +}; // namespace kthvalue -// using namespace kthvalue; +using namespace kthvalue; -// TEST_P(GPU_Kthvalue_fwd_FP32, Test) -// { -// if(CheckFloatArg("--float")) -// { -// RunTest(); -// Verify(); -// } -// else -// { -// GTEST_SKIP(); -// } -// }; +TEST_P(GPU_Kthvalue_fwd_FP32, Test) +{ + if(CheckFloatArg("--float")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; -// TEST_P(GPU_Kthvalue_fwd_FP16, Test) -// { -// if(CheckFloatArg("--half")) -// { -// RunTest(); -// Verify(); -// } -// else -// { -// GTEST_SKIP(); -// } -// }; +TEST_P(GPU_Kthvalue_fwd_FP16, Test) +{ + if(CheckFloatArg("--half")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; -// TEST_P(GPU_Kthvalue_fwd_BFP16, Test) -// { -// if(CheckFloatArg("--bfloat16")) -// { -// RunTest(); -// Verify(); -// } -// else -// { -// GTEST_SKIP(); -// } -// }; +TEST_P(GPU_Kthvalue_fwd_BFP16, Test) +{ + if(CheckFloatArg("--bfloat16")) + { + RunTest(); + Verify(); + } + else + { + GTEST_SKIP(); + } +}; -// INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP32, testing::ValuesIn(KthvalueTestConfigs())); -// INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP16, testing::ValuesIn(KthvalueTestConfigs())); -// INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_BFP16, testing::ValuesIn(KthvalueTestConfigs())); -// } // namespace kthvalue +INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP32, testing::ValuesIn(KthvalueTestConfigs())); +INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP16, testing::ValuesIn(KthvalueTestConfigs())); +INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_BFP16,testing::ValuesIn(KthvalueTestConfigs())); From 664afa4d258f9846c0af1ddd50369beb08b35e95 Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Mon, 26 Aug 2024 09:40:47 +0700 Subject: [PATCH 26/28] check hip tidy --- .../miopen/solver/implicitgemm_ck_util.hpp | 28 +++++++++---------- test/gtest/kthvalue.cpp | 2 +- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/include/miopen/solver/implicitgemm_ck_util.hpp b/src/include/miopen/solver/implicitgemm_ck_util.hpp index e6cceaef0f..abdd171227 100644 --- a/src/include/miopen/solver/implicitgemm_ck_util.hpp +++ b/src/include/miopen/solver/implicitgemm_ck_util.hpp @@ -378,7 +378,7 @@ class TransposeInstance void ZeroOutBuffer() { - [[maybe_unused]] auto status = hipMemsetAsync(buf_handle.get(), 0, tensor_sz); + [[maybe_unused]] auto status = hipMemset(buf_handle.get(), 0, tensor_sz); assert(status == hipSuccess); } @@ -680,7 +680,7 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx, internal::MakeTaggedTransposeInstances( result, ctx, problem, ck_args, input1_op, input2_op, output_op, _ck_buff_des); - result.invoker_factory = [split_k = split_k, + result.invoker_factory = [split_k, ck_args = std::move(ck_args), sh_conv_ptr = std::shared_ptr{std::move(*ptr_iter)}, input1_tr_inst = std::move(_input1_tr_inst), @@ -689,7 +689,7 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx, output_init_tr_inst = std::move(_output_init_tr_inst), ck_buff_des = _ck_buff_des](const std::vector& kernels) mutable { - return [split_k = split_k, + return [split_k, kernels, ck_args = std::move(ck_args), sh_conv_ptr = std::move(sh_conv_ptr), @@ -697,8 +697,8 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx, input2_tr_inst = std::move(input2_tr_inst), output_tr_inst = std::move(output_tr_inst), output_init_tr_inst = std::move(output_init_tr_inst), - ck_buff_des = ck_buff_des](const Handle& handle, - const AnyInvokeParams& primitive_parameters) mutable { + ck_buff_des](const Handle& handle, + const AnyInvokeParams& primitive_parameters) mutable { handle.ResetKernelTime(); const auto& data_ctx = primitive_parameters.CastTo(); @@ -826,17 +826,17 @@ ConvSolution InitInvokerFactoryNHWC(const ExecutionContext&, [[maybe_unused]] bool should_allocated_wrw_buffer = ShouldAllocateWorkSpaceBufferForWRW(problem); - result.invoker_factory = [split_k = split_k, - ck_args = CKArgsType{problem}, - alpha_beta_case = alpha_beta_case, - should_allocated_wrw_buffer = should_allocated_wrw_buffer, + result.invoker_factory = [split_k, + ck_args = CKArgsType{problem}, + alpha_beta_case, + should_allocated_wrw_buffer, sh_conv_ptr = std::shared_ptr{std::move(*ptr_iter)}]( const std::vector&) mutable { - return [split_k = split_k, - ck_args = std::move(ck_args), - alpha_beta_case = alpha_beta_case, - should_allocated_wrw_buffer = should_allocated_wrw_buffer, - sh_conv_ptr = std::move(sh_conv_ptr)]( + return [split_k, + ck_args = std::move(ck_args), + alpha_beta_case, + should_allocated_wrw_buffer, + sh_conv_ptr = std::move(sh_conv_ptr)]( const Handle& handle, const AnyInvokeParams& primitive_parameters) { const auto& data_ctx = primitive_parameters.CastTo(); std::unique_ptr argument_ptr; diff --git a/test/gtest/kthvalue.cpp b/test/gtest/kthvalue.cpp index 38c1170a2b..0a08a25288 100644 --- a/test/gtest/kthvalue.cpp +++ b/test/gtest/kthvalue.cpp @@ -109,4 +109,4 @@ TEST_P(GPU_Kthvalue_fwd_BFP16, Test) INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP32, testing::ValuesIn(KthvalueTestConfigs())); INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_FP16, testing::ValuesIn(KthvalueTestConfigs())); -INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_BFP16,testing::ValuesIn(KthvalueTestConfigs())); +INSTANTIATE_TEST_SUITE_P(Smoke, GPU_Kthvalue_fwd_BFP16, testing::ValuesIn(KthvalueTestConfigs())); From eedd52b9373aa3dea5b7a4a4462e330456bfccf0 Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Sat, 31 Aug 2024 11:53:33 +0700 Subject: [PATCH 27/28] update kernel comment --- src/kernels/MIOpenKthvalue.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/kernels/MIOpenKthvalue.cpp b/src/kernels/MIOpenKthvalue.cpp index 41f441f596..624308b017 100644 --- a/src/kernels/MIOpenKthvalue.cpp +++ b/src/kernels/MIOpenKthvalue.cpp @@ -55,11 +55,11 @@ __device__ void kthvalueFwd(const DTYPE* input, tensor_view_t<5> indices_tv) { /* - * Example) - * input : {A, B, C, D, E} - * output/indices : {A, B, 1, D, E} or {A, B, D, E} - * dim = 2 (C) - * => grid = {A * B * D * E, 1}, block = {LOCAL_SIZE, 1} + * Input : {N, C, D, H, W}. Select dim: 2(D) + * Output/indices : {N, C, H, W} or {N, C, 1, H, W} (if keepDim param in miopen.h = True) + * Each lws handle dim_size elements to find the kth value. + * Lws = {256 or 512, 1, 1} + * Gws = {A * B * D * E * lws.x, 1, 1}, */ using RADIX_TYPE = typename RadixType::type; From 8baccb8dea99e84ceb9ffc65fe0efe063762c99e Mon Sep 17 00:00:00 2001 From: BuiChiTrung Date: Wed, 25 Sep 2024 10:44:33 +0700 Subject: [PATCH 28/28] rollback src/include/miopen/solver/implicitgemm_ck_util.hpp --- .../miopen/solver/implicitgemm_ck_util.hpp | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/include/miopen/solver/implicitgemm_ck_util.hpp b/src/include/miopen/solver/implicitgemm_ck_util.hpp index abdd171227..0cb381b407 100644 --- a/src/include/miopen/solver/implicitgemm_ck_util.hpp +++ b/src/include/miopen/solver/implicitgemm_ck_util.hpp @@ -376,9 +376,10 @@ class TransposeInstance Run(handle, kernels, out_ptr, buf_handle.get()); } - void ZeroOutBuffer() + void ZeroOutBuffer(const Handle& handle) { - [[maybe_unused]] auto status = hipMemset(buf_handle.get(), 0, tensor_sz); + [[maybe_unused]] auto status = + hipMemsetAsync(buf_handle.get(), 0, tensor_sz, handle.GetStream()); assert(status == hipSuccess); } @@ -680,7 +681,7 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx, internal::MakeTaggedTransposeInstances( result, ctx, problem, ck_args, input1_op, input2_op, output_op, _ck_buff_des); - result.invoker_factory = [split_k, + result.invoker_factory = [split_k = split_k, ck_args = std::move(ck_args), sh_conv_ptr = std::shared_ptr{std::move(*ptr_iter)}, input1_tr_inst = std::move(_input1_tr_inst), @@ -689,7 +690,7 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx, output_init_tr_inst = std::move(_output_init_tr_inst), ck_buff_des = _ck_buff_des](const std::vector& kernels) mutable { - return [split_k, + return [split_k = split_k, kernels, ck_args = std::move(ck_args), sh_conv_ptr = std::move(sh_conv_ptr), @@ -697,8 +698,8 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx, input2_tr_inst = std::move(input2_tr_inst), output_tr_inst = std::move(output_tr_inst), output_init_tr_inst = std::move(output_init_tr_inst), - ck_buff_des](const Handle& handle, - const AnyInvokeParams& primitive_parameters) mutable { + ck_buff_des = ck_buff_des](const Handle& handle, + const AnyInvokeParams& primitive_parameters) mutable { handle.ResetKernelTime(); const auto& data_ctx = primitive_parameters.CastTo(); @@ -734,7 +735,7 @@ ConvSolution InitInvokerFactoryNCHW(const ExecutionContext& ctx, /// \todo: Will need SetTensor() to properly zero out non-packed tensors if(output_tr_inst.GetConvOperandTag() == internal::ConvOperandTag::Weights) { - output_tr_inst.ZeroOutBuffer(); + output_tr_inst.ZeroOutBuffer(handle); } std::array tr_ptrs = { @@ -826,17 +827,17 @@ ConvSolution InitInvokerFactoryNHWC(const ExecutionContext&, [[maybe_unused]] bool should_allocated_wrw_buffer = ShouldAllocateWorkSpaceBufferForWRW(problem); - result.invoker_factory = [split_k, - ck_args = CKArgsType{problem}, - alpha_beta_case, - should_allocated_wrw_buffer, + result.invoker_factory = [split_k = split_k, + ck_args = CKArgsType{problem}, + alpha_beta_case = alpha_beta_case, + should_allocated_wrw_buffer = should_allocated_wrw_buffer, sh_conv_ptr = std::shared_ptr{std::move(*ptr_iter)}]( const std::vector&) mutable { - return [split_k, - ck_args = std::move(ck_args), - alpha_beta_case, - should_allocated_wrw_buffer, - sh_conv_ptr = std::move(sh_conv_ptr)]( + return [split_k = split_k, + ck_args = std::move(ck_args), + alpha_beta_case = alpha_beta_case, + should_allocated_wrw_buffer = should_allocated_wrw_buffer, + sh_conv_ptr = std::move(sh_conv_ptr)]( const Handle& handle, const AnyInvokeParams& primitive_parameters) { const auto& data_ctx = primitive_parameters.CastTo(); std::unique_ptr argument_ptr;