diff --git a/docs/apireference.rst b/docs/apireference.rst index ac56c2ee04..4f69fcdaf0 100644 --- a/docs/apireference.rst +++ b/docs/apireference.rst @@ -22,4 +22,5 @@ API Reference dropout reduction layernorm + sum diff --git a/docs/layernorm.rst b/docs/layernorm.rst index 1d480df7d6..89f1a3cc2d 100644 --- a/docs/layernorm.rst +++ b/docs/layernorm.rst @@ -1,8 +1,9 @@ -Layernorm Layer -=================== +Layernorm Layer(experimental) +============================= The layernorm types and functions. +To enable this, define MIOPEN_BETA_API before including miopen.h. miopenLayerNormMode_t diff --git a/docs/sum.rst b/docs/sum.rst new file mode 100644 index 0000000000..6ec8d3ee4c --- /dev/null +++ b/docs/sum.rst @@ -0,0 +1,23 @@ + +Sum Layer(experimental) +======================== + +The sum types and functions. +To enable this, define MIOPEN_BETA_API before including miopen.h. + + +miopenSumNanPropagation_t +---------------------------------- + +.. doxygenenum:: miopenSumNanPropagation_t + +miopenGetSumWorkspaceSize +---------------------------------- + +.. doxygenfunction:: miopenGetSumWorkspaceSize + +miopenSumForward +---------------------------------- + +.. doxygenfunction:: miopenSumForward + diff --git a/driver/driver.hpp b/driver/driver.hpp index d8e5352255..b29517bc0e 100644 --- a/driver/driver.hpp +++ b/driver/driver.hpp @@ -150,7 +150,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz) printf("Supported Base Arguments: conv[fp16|int8|bfp16|fp8|bfp8], CBAInfer[fp16], " "pool[fp16], lrn[fp16], " "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm, ctc, dropout[fp16], " - "tensorop[fp16], reduce[fp16,fp64], layernorm[bfp16, fp16]\n"); + "tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16]\n"); exit(0); // NOLINT (concurrency-mt-unsafe) } @@ -172,7 +172,8 @@ inline std::string ParseBaseArg(int argc, char* argv[]) arg != "rnn_seqfp16" && arg != "gemm" /*&& arg != "gemmfp16"*/ && arg != "ctc" && arg != "dropout" && arg != "dropoutfp16" && arg != "tensorop" && arg != "tensoropfp16" && arg != "reduce" && arg != "reducefp16" && arg != "reducefp64" && arg != "layernorm" && - arg != "layernormfp16" && arg != "layernormbfp16" && arg != "--version") + arg != "layernormfp16" && arg != "layernormbfp16" && arg != "sum" && arg != "sumfp16" && + arg != "sumbfp16" && arg != "--version") { printf("FAILED: Invalid Base Input Argument\n"); Usage(); diff --git a/driver/layernorm_driver.hpp b/driver/layernorm_driver.hpp index dbdda790b5..ac53c70601 100644 --- a/driver/layernorm_driver.hpp +++ b/driver/layernorm_driver.hpp @@ -139,7 +139,7 @@ int LayerNormDriver::GetandSetData() { std::vector in_len = GetInputTensorLengthsFromCmdLine(); - dim = static_cast(inflags.GetValueDouble("nomalized_dim")); + dim = inflags.GetValueInt("normalized_dim"); std::vector inner_len; if(dim == in_len.size()) @@ -379,10 +379,6 @@ Tref LayerNormDriver::GetTolerance() { return 5e-5; } - else if(data_type == miopenDouble) - { - return 1e-10; - } else if(data_type == miopenBFloat16) { return 5e-3; diff --git a/driver/main.cpp b/driver/main.cpp index c4aa25c7e8..1e749efb77 100644 --- a/driver/main.cpp +++ b/driver/main.cpp @@ -41,9 +41,10 @@ #include "dropout_driver.hpp" #include "tensorop_driver.hpp" #include "reduce_driver.hpp" +#include "layernorm_driver.hpp" +#include "sum_driver.hpp" #include #include -#include "layernorm_driver.hpp" int main(int argc, char* argv[]) { @@ -209,6 +210,18 @@ int main(int argc, char* argv[]) { drv = new LayerNormDriver(); } + else if(base_arg == "sum") + { + 
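+        // "sum" runs the reduction driver at fp32; the fp16 and bfp16 branches below select
+        // reduced-precision driver instantiations.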
drv = new SumDriver(); + } + else if(base_arg == "sumfp16") + { + drv = new SumDriver(); + } + else if(base_arg == "sumbfp16") + { + drv = new SumDriver(); + } else { printf("Incorrect BaseArg\n"); diff --git a/driver/sum_driver.hpp b/driver/sum_driver.hpp new file mode 100644 index 0000000000..059c254ead --- /dev/null +++ b/driver/sum_driver.hpp @@ -0,0 +1,380 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef GUARD_MIOPEN_SUM_DRIVER_HPP +#define GUARD_MIOPEN_SUM_DRIVER_HPP + +#include "InputFlags.hpp" +#include "driver.hpp" +#include "tensor_driver.hpp" +#include "timer.hpp" +#include "random.hpp" +#include +#include +#include +#include +#include +#include +#include +#include +#include <../test/tensor_holder.hpp> +#include <../test/verify.hpp> + +#ifndef MLO_SUMMHOST_H_ +#define MLO_SUMMHOST_H_ + +template +int32_t mloSumForwardRunHost(miopenTensorDescriptor_t inputDesc, + miopenTensorDescriptor_t outputDesc, + Tgpu* input, + Tcheck* outputhost, + int32_t dim, + miopenSumNanPropagation_t nanPropagation) +{ + auto input_dims = miopen::deref(inputDesc).GetLengths(); + auto output_dims = miopen::deref(outputDesc).GetLengths(); + + auto reduce_size = input_dims[dim]; + auto output_numel = + std::accumulate(output_dims.begin(), output_dims.end(), 1L, std::multiplies()); + + auto inner_size = 1ULL; + for(int32_t i = dim + 1; i < input_dims.size(); i++) + { + inner_size *= input_dims[i]; + } + + int32_t ret = 0; + + for(size_t o = 0; o < output_numel; o++) + { + size_t input_idx = (o / inner_size) * inner_size * reduce_size + o % inner_size; + + Tcheck sum = 0.0f; + for(size_t i = 0; i < reduce_size; i++) + { + Tcheck val = static_cast(input[input_idx]); + if(nanPropagation && isnan(val)) + { + val = 0.0f; + } + sum += val; + input_idx += inner_size; + } + outputhost[o] = sum; + } + return ret; +} +#endif + +template +class SumDriver : public Driver +{ +public: + SumDriver() : Driver() + { + miopenCreateTensorDescriptor(&inputDesc); + miopenCreateTensorDescriptor(&outputDesc); + + data_type = miopen_type{}; + } + + int AddCmdLineArgs() override; + int ParseCmdLineArgs(int argc, char* argv[]) override; + InputFlags& GetInputFlags() override { return inflags; } + + int GetandSetData() override; + std::vector 
GetInputTensorLengthsFromCmdLine();
+
+    int AllocateBuffersAndCopy() override;
+
+    int RunForwardGPU() override;
+    int RunForwardCPU();
+
+    int RunBackwardGPU() override;
+
+    Tref GetTolerance();
+    int VerifyBackward() override;
+    int VerifyForward() override;
+    ~SumDriver() override
+    {
+        miopenDestroyTensorDescriptor(inputDesc);
+        miopenDestroyTensorDescriptor(outputDesc);
+    }
+
+private:
+    InputFlags inflags;
+
+    int forw;
+
+    miopenTensorDescriptor_t inputDesc;
+    miopenTensorDescriptor_t outputDesc;
+
+    std::unique_ptr<GPUMem> in_dev;
+    std::unique_ptr<GPUMem> out_dev;
+    std::unique_ptr<GPUMem> workspace_dev;
+
+    std::vector<Tgpu> in;
+    std::vector<Tgpu> out;
+    std::vector<Tref> outhost;
+
+    size_t ws_sizeInBytes;
+
+    int dim;
+    miopenSumNanPropagation_t nanPropagation;
+};
+
+template <typename Tgpu, typename Tref>
+int SumDriver<Tgpu, Tref>::ParseCmdLineArgs(int argc, char* argv[])
+{
+    inflags.Parse(argc, argv);
+
+    if(inflags.GetValueInt("time") == 1)
+    {
+        miopenEnableProfiling(GetHandle(), true);
+    }
+    return miopenStatusSuccess;
+}
+
+template <typename Tgpu, typename Tref>
+int SumDriver<Tgpu, Tref>::GetandSetData()
+{
+    std::vector<int> in_len = GetInputTensorLengthsFromCmdLine();
+    dim                     = inflags.GetValueInt("DimToReduce");
+
+    SetTensorNd(inputDesc, in_len, data_type);
+
+    std::vector<int> out_len;
+
+    for(int i = 0; i < in_len.size(); i++)
+    {
+        if(i != dim)
+        {
+            out_len.push_back(in_len[i]);
+        }
+    }
+
+    SetTensorNd(outputDesc, out_len, data_type);
+
+    nanPropagation = static_cast<miopenSumNanPropagation_t>(inflags.GetValueInt("NanPropagation"));
+
+    return 0;
+}
+
+template <typename Tgpu, typename Tref>
+int SumDriver<Tgpu, Tref>::AddCmdLineArgs()
+{
+    inflags.AddInputFlag("forw", 'F', "1", "Run only Forward Sum (Default=1)", "int");
+    inflags.AddInputFlag("batchsize", 'n', "256", "Mini-batch size (Default=256)", "int");
+    inflags.AddInputFlag("in_channels", 'c', "4", "Number of Input Channels (Default=4)", "int");
+    inflags.AddInputFlag("in_d", 'D', "0", "Input Depth (Default=0)", "int");
+    inflags.AddInputFlag("in_h", 'H', "0", "Input Height (Default=0)", "int");
+    inflags.AddInputFlag("in_w", 'W', "8732", "Input Width (Default=8732)", "int");
+
+    inflags.AddInputFlag(
+        "DimToReduce", 'R', "1", "The index of the dimension to be reduced (Default=1)", "int");
+    inflags.AddInputFlag("NanPropagation",
+                         'N',
+                         "0",
+                         "Nan number propagation mode (check the miopenSumNanPropagation_t in "
+                         "miopen.h) (Default=0 to indicate no Nan propagation)",
+                         "int");
+    inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int");
+    inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int");
+    inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int");
+    inflags.AddInputFlag(
+        "wall", 'w', "0", "Wall-clock Time Each Layer, Requires time == 1 (Default=0)", "int");
+
+    return miopenStatusSuccess;
+}
+
+template <typename Tgpu, typename Tref>
+std::vector<int> SumDriver<Tgpu, Tref>::GetInputTensorLengthsFromCmdLine()
+{
+    int in_n = inflags.GetValueInt("batchsize");
+    int in_c = inflags.GetValueInt("in_channels");
+    int in_w = inflags.GetValueInt("in_w");
+    int in_h = inflags.GetValueInt("in_h");
+    int in_d = inflags.GetValueInt("in_d");
+
+    if((in_n != 0) && (in_c != 0) && (in_d != 0) && (in_h != 0) && (in_w != 0))
+    {
+        return std::vector<int>({in_n, in_c, in_d, in_h, in_w});
+    }
+    else if((in_n != 0) && (in_c != 0) && (in_h != 0) && (in_w != 0))
+    {
+        return std::vector<int>({in_n, in_c, in_h, in_w});
+    }
+    else if((in_n != 0) && (in_c != 0) && (in_w != 0))
+    {
+        return std::vector<int>({in_n, in_c, in_w});
+    }
+    else if((in_n != 0) && (in_w != 0))
+    {
+        return std::vector<int>({in_n, in_w});
+    }
+    else
+    {
+        std::cerr << "Error: invalid input tensor lengths" << std::endl;
+        return std::vector<int>({0});
+    
} +} + +template +int SumDriver::AllocateBuffersAndCopy() +{ + size_t in_sz = GetTensorSize(inputDesc); + size_t out_sz = GetTensorSize(outputDesc); + + miopenGetSumWorkspaceSize(GetHandle(), inputDesc, dim, outputDesc, &ws_sizeInBytes); + if(ws_sizeInBytes == static_cast(-1)) + return miopenStatusAllocFailed; + + uint32_t ctx = 0; + + in_dev = std::unique_ptr(new GPUMem(ctx, in_sz, sizeof(Tgpu))); + out_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(Tgpu))); + workspace_dev = std::unique_ptr(new GPUMem(ctx, ws_sizeInBytes, sizeof(std::byte))); + + in = std::vector(in_sz, static_cast(0)); + out = std::vector(out_sz, static_cast(0)); + outhost = std::vector(out_sz, static_cast(0)); + + for(int i = 0; i < in_sz; i++) + { + in[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0)); + } + + if(in_dev->ToGPU(GetStream(), in.data()) != 0) + std::cerr << "Error copying (in) to GPU, size: " << in_dev->GetSize() << std::endl; + + if(out_dev->ToGPU(GetStream(), out.data()) != 0) + std::cerr << "Error copying (out) to GPU, size: " << out_dev->GetSize() << std::endl; + + return miopenStatusSuccess; +} + +template +int SumDriver::RunForwardGPU() +{ + float kernel_total_time = 0; + float kernel_first_time = 0; + + Timer t; + START_TIME + + for(int i = 0; i < inflags.GetValueInt("iter"); i++) + { + miopenSumForward(GetHandle(), + nanPropagation, + workspace_dev->GetMem(), + ws_sizeInBytes, + inputDesc, + in_dev->GetMem(), + dim, + outputDesc, + out_dev->GetMem()); + + float time = 0.0; + miopenGetKernelTime(GetHandle(), &time); + kernel_total_time += time; + if(i == 0) + kernel_first_time = time; + } + + if(inflags.GetValueInt("time") == 1) + { + STOP_TIME + int iter = inflags.GetValueInt("iter"); + if(WALL_CLOCK) + printf("Wall-clock Time Forward Sum Elapsed: %f ms\n", t.gettime_ms() / iter); + + float kernel_average_time = + iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time; + printf("GPU Kernel Time Forward Sum Elapsed: %f ms\n", kernel_average_time); + } + + if(out_dev->FromGPU(GetStream(), out.data()) != 0) + std::cerr << "Error copying (out_dev) from GPU, size: " << out_dev->GetSize() << std::endl; + + return miopenStatusSuccess; +} + +template +int SumDriver::RunForwardCPU() +{ + mloSumForwardRunHost( + inputDesc, outputDesc, in.data(), outhost.data(), dim, nanPropagation); + + return miopenStatusSuccess; +} + +template +int SumDriver::RunBackwardGPU() +{ + return miopenStatusSuccess; +} + +template +Tref SumDriver::GetTolerance() +{ + // Computation error of fp16 is ~2^13 (=8192) bigger than + // the one of fp32 because mantissa is shorter by 13 bits. + auto tolerance = (sizeof(Tgpu) == 4 || sizeof(Tgpu) == 1) ? 1.5e-6 : 8.2e-3; + + // bf16 mantissa has 7 bits, by 3 bits shorter than fp16. 
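+    // Widen the threshold by a further factor of 2^3 = 8 for bfloat16 to account for the
+    // shorter mantissa.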
+    if(std::is_same<Tgpu, bfloat16>::value)
+        tolerance *= 8.0;
+    return tolerance;
+}
+
+template <typename Tgpu, typename Tref>
+int SumDriver<Tgpu, Tref>::VerifyForward()
+{
+    RunForwardCPU();
+    const Tref tolerance = GetTolerance();
+    auto error           = miopen::rms_range(outhost, out);
+
+    if(!std::isfinite(error) || error > tolerance)
+    {
+        std::cout << "Forward Sum FAILED: " << error << " > " << tolerance << std::endl;
+        return EC_VerifyFwd;
+    }
+    else
+    {
+        std::cout << "Forward Sum Verifies OK on CPU reference (" << error << " < " << tolerance
+                  << ')' << std::endl;
+    }
+
+    return miopenStatusSuccess;
+}
+
+template <typename Tgpu, typename Tref>
+int SumDriver<Tgpu, Tref>::VerifyBackward()
+{
+    return miopenStatusSuccess;
+}
+
+#endif // GUARD_MIOPEN_SUM_DRIVER_HPP
diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h
index 9f091c835d..37654d4aac 100644
--- a/include/miopen/miopen.h
+++ b/include/miopen/miopen.h
@@ -64,6 +64,7 @@
 * @defgroup LossFunction
 * @defgroup TensorReduce
 * @defgroup find2
+ * @defgroup sum
 *
 */
@@ -5511,6 +5512,66 @@ MIOPEN_EXPORT miopenStatus_t miopenFuseProblems(miopenProblem_t problem1, miopen
 /** @} */
 // CLOSEOUT find2 DOXYGEN GROUP
 
+#ifdef MIOPEN_BETA_API
+
+/*! @ingroup sum
+ * @enum miopenSumNanPropagation_t
+ * NaN propagation modes for the sum operation
+ */
+typedef enum
+{
+    MIOPEN_SUM_NOT_PROPAGATE_NAN = 0, /*!< does not propagate NaN values */
+    MIOPEN_SUM_PROPAGATE_NAN     = 1, /*!< propagates NaN values through the reduction */
+} miopenSumNanPropagation_t;
+
+// Sum APIs
+/** @addtogroup sum
+ *
+ *  @{
+ */
+
+/*! @brief Helper function to query the minimum workspace size required by the sum call
+ *
+ * @param handle        MIOpen Handle (input)
+ * @param xDesc         Tensor descriptor for data input tensor x (input)
+ * @param dim           Dimension to sum over (input)
+ * @param yDesc         Tensor descriptor for output data tensor y (input)
+ * @param sizeInBytes   Pointer to data to return the minimum workspace size (output)
+ * @return              miopenStatus_t
+ */
+MIOPEN_EXPORT miopenStatus_t miopenGetSumWorkspaceSize(miopenHandle_t handle,
+                                                       const miopenTensorDescriptor_t xDesc,
+                                                       const int32_t dim,
+                                                       const miopenTensorDescriptor_t yDesc,
+                                                       size_t* sizeInBytes);
+
+/*! @brief Execute a sum forward layer
+ *
+ * @param handle                   MIOpen handle (input)
+ * @param nanPropagation           NaN propagation mode (input)
+ * @param workspace                Address of the allocated workspace data (input)
+ * @param workspaceSizeInBytes     Size in bytes of the allocated workspace data (input)
+ * @param xDesc                    Tensor descriptor for data input tensor x (input)
+ * @param x                        Data tensor x (input)
+ * @param dim                      Dimension to sum over 
(input) + * @param yDesc Tensor descriptor for output data tensor y (input) + * @param y Data tensor y (output) + * @return miopenStatus_t + */ +MIOPEN_EXPORT miopenStatus_t miopenSumForward(miopenHandle_t handle, + miopenSumNanPropagation_t nanPropagation, + void* workspace, + size_t workspaceSizeInBytes, + const miopenTensorDescriptor_t xDesc, + const void* x, + const int32_t dim, + const miopenTensorDescriptor_t yDesc, + void* y); + +/** @} */ +// CLOSEOUT SUM DOXYGEN GROUP +#endif + #ifdef __cplusplus } #endif diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c634e37319..dc4f9672f8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -138,6 +138,7 @@ set( MIOpen_Source readonlyramdb.cpp reducetensor.cpp reducetensor_api.cpp + reduce/problem_description.cpp rnn.cpp rnn_api.cpp rnn/rnn_util.cpp @@ -243,7 +244,9 @@ set( MIOpen_Source solver/pooling/forwardNd.cpp solver/pooling/backward2d.cpp solver/pooling/backwardNd.cpp + solver/reduce/forward_sum.cpp subbuffers.cpp + sum_api.cpp target_properties.cpp temp_file.cpp tensor.cpp @@ -430,6 +433,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN kernels/MIOpenConv1x1J1.cl kernels/MIOpenConv1x1J1_stride.cl kernels/MIOpenSoftmax.cl + kernels/MIOpenSum.cpp kernels/MIOpenUtilKernels3.cl kernels/MIOpenUtilKernels4.cl kernels/MIOpenUtilKernels5.cl @@ -561,6 +565,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN pooling.cpp ocl/fusionopconvocl.cpp ocl/fusionopbiasbnactivocl.cpp + sum.cpp ${PROJECT_BINARY_DIR}/db_path.cpp ) diff --git a/src/include/miopen/layernorm.hpp b/src/include/miopen/layernorm.hpp index 64a3ea8339..7780e57cda 100644 --- a/src/include/miopen/layernorm.hpp +++ b/src/include/miopen/layernorm.hpp @@ -23,7 +23,6 @@ * SOFTWARE. * *******************************************************************************/ -#include #ifndef MIOPEN_LAYERNORM_HPP_ #define MIOPEN_LAYERNORM_HPP_ diff --git a/src/include/miopen/reduce/invoke_params.hpp b/src/include/miopen/reduce/invoke_params.hpp new file mode 100644 index 0000000000..6ad0884dfd --- /dev/null +++ b/src/include/miopen/reduce/invoke_params.hpp @@ -0,0 +1,55 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ + +#pragma once + +#include +#include + +namespace miopen { +namespace reduce { + +struct InvokeParams : public miopen::InvokeParams +{ + InvokeParams() = default; + + const TensorDescriptor* xDesc = nullptr; + const TensorDescriptor* yDesc = nullptr; + + ConstData_t x = nullptr; + Data_t y = nullptr; + Data_t workspace = nullptr; + std::size_t workspace_size = 0; + int32_t dim = 0; + miopenSumNanPropagation_t nanPropagation = MIOPEN_SUM_NOT_PROPAGATE_NAN; + + std::size_t GetWorkspaceSize() const { return workspace_size; } + Data_t GetWorkspace() const { return workspace; } +}; + +} // namespace reduce + +} // namespace miopen diff --git a/src/include/miopen/reduce/problem_description.hpp b/src/include/miopen/reduce/problem_description.hpp new file mode 100644 index 0000000000..131eea65bb --- /dev/null +++ b/src/include/miopen/reduce/problem_description.hpp @@ -0,0 +1,119 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#pragma once
+
+#include
+#include
+#include
+
+#include
+#include
+
+namespace miopen {
+
+struct NetworkConfig;
+
+namespace reduce {
+
+struct ProblemDescription : ProblemDescriptionBase
+{
+    ProblemDescription(miopenSumNanPropagation_t nanPropagation_,
+                       const TensorDescriptor& xDesc_,
+                       const TensorDescriptor& yDesc_,
+                       int32_t dim_)
+        : nanPropagation(nanPropagation_), xDesc(xDesc_), yDesc(yDesc_), dim(dim_)
+    {
+    }
+
+    miopenSumNanPropagation_t GetNanPropagation_() const { return nanPropagation; }
+    const TensorDescriptor& GetXDesc() const { return xDesc; }
+    const TensorDescriptor& GetYDesc() const { return yDesc; }
+    int32_t GetDim() const { return dim; }
+
+    bool IsSameType() const
+    {
+        if(xDesc.GetType() != yDesc.GetType())
+        {
+            MIOPEN_THROW(miopenStatusBadParm, "SumForward: Tensor types do not match.");
+        }
+        return true;
+    }
+
+    bool IsRightLength() const
+    {
+        int32_t posy = 0;
+        for(int32_t i = 0; i < xDesc.GetLengths().size(); i++)
+        {
+            if(i == dim)
+                continue;
+
+            if(xDesc.GetLengths()[i] != yDesc.GetLengths()[posy])
+            {
+                MIOPEN_THROW(miopenStatusBadParm,
+                             "SumForward: Tensor dimension lengths do not match.");
+            }
+
+            posy++;
+        }
+        return true;
+    }
+
+    bool IsRightDim() const
+    {
+        if((dim < 0) || (dim >= xDesc.GetLengths().size()))
+        {
+            MIOPEN_THROW(
+                miopenStatusBadParm,
+                "SumForward: dim must be non-negative and less than the number of tensor dimensions.");
+        }
+        return true;
+    }
+
+    bool IsAllPacked() const
+    {
+        if(!(xDesc.IsPacked() && yDesc.IsPacked()))
+        {
+            MIOPEN_THROW(miopenStatusBadParm, "SumForward: Unpacked tensors not supported.");
+        }
+        return true;
+    }
+
+    NetworkConfig MakeNetworkConfig() const override;
+
+private:
+    miopenSumNanPropagation_t nanPropagation;
+    TensorDescriptor xDesc;
+    TensorDescriptor yDesc;
+
+    int32_t dim;
+
+    NetworkConfig MakeForwardNetworkConfig() const;
+};
+
+} // namespace reduce
+
+} // namespace miopen
diff --git a/src/include/miopen/reduce/solvers.hpp b/src/include/miopen/reduce/solvers.hpp
new file mode 100644
index 0000000000..ef3a345d01
--- /dev/null
+++ b/src/include/miopen/reduce/solvers.hpp
@@ -0,0 +1,59 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE. 
+ * + *******************************************************************************/ + +#pragma once + +#include +#include + +#include + +namespace miopen { + +namespace solver { + +namespace reduce { + +using ReduceSolver = NonTunableSolverBase; + +struct SumForward final : ReduceSolver +{ + const std::string& SolverDbId() const override { return GetSolverDbId(); } + + bool IsApplicable(const ExecutionContext& context, + const miopen::reduce::ProblemDescription& problem) const override; + ConvSolution GetSolution(const ExecutionContext& context, + const miopen::reduce::ProblemDescription& problem) const override; + std::size_t GetWorkspaceSize(const ExecutionContext& context, + const miopen::reduce::ProblemDescription& problem) const override; + bool MayNeedWorkspace() const override { return true; } +}; + +} // namespace reduce + +} // namespace solver + +} // namespace miopen diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp index 477053ed20..53f431ab6d 100644 --- a/src/include/miopen/solver_id.hpp +++ b/src/include/miopen/solver_id.hpp @@ -52,7 +52,8 @@ enum class Primitive Bias, Fusion, Pooling, - Normalization + Normalization, + Reduce }; struct Id diff --git a/src/include/miopen/sum.hpp b/src/include/miopen/sum.hpp new file mode 100644 index 0000000000..7e7c5f2b8f --- /dev/null +++ b/src/include/miopen/sum.hpp @@ -0,0 +1,52 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef MIOPEN_SUM_HPP_ +#define MIOPEN_SUM_HPP_ + +#include + +namespace miopen { + +struct Handle; +struct TensorDescriptor; + +std::size_t GetSumWorkspaceSize(Handle& handle, + const TensorDescriptor& xDesc, + const TensorDescriptor& yDesc, + int32_t dim); + +miopenStatus_t SumForward(Handle& handle, + Data_t workspace, + size_t workspaceSizeInBytes, + const TensorDescriptor& xDesc, + ConstData_t x, + const TensorDescriptor& yDesc, + Data_t y, + miopenSumNanPropagation_t nanPropagation, + int32_t dim); + +} // namespace miopen +#endif // _MIOPEN_SUM_HPP_ diff --git a/src/kernels/MIOpenSum.cpp b/src/kernels/MIOpenSum.cpp new file mode 100644 index 0000000000..049e295636 --- /dev/null +++ b/src/kernels/MIOpenSum.cpp @@ -0,0 +1,104 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS +#include +#include +#endif + +#include "float_types.h" + +#if MIOPEN_USE_BFP16 == 1 +#define CVT_FLOAT2ACCUM(x) (bfloat16_to_float(x)) +#define CVT_ACCUM2FLOAT(x) (float_to_bfloat16(x)) +#define CVT_INTEGRAL2ACCUM(x) ((_FLOAT_ACCUM)(x)) +#define CVT_FP32_2FLOAT(x) (CVT_ACCUM2FLOAT(x)) +#define CVT_FP32_2ACCUM(x) (x) +#endif + +extern "C" __global__ void SumParallelFwdContiguous(const FLOAT* __restrict__ x, + FLOAT* __restrict__ y, + uint64_t output_numel, + uint64_t reduce_size, + uint64_t parallelism_size, + uint64_t inner_size, + bool nanPropagation) +{ + const uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; + if(gid >= parallelism_size * output_numel) + return; + + uint64_t n = inner_size * parallelism_size; + + uint64_t slice_id = gid / n; + uint64_t slice_local_id = gid % n; + + uint64_t input_idx = slice_id * inner_size * reduce_size + slice_local_id; + + uint64_t parallel_id = slice_local_id / inner_size; + + FLOAT_ACCUM sum = static_cast(0); + for(uint64_t k = parallel_id; k < reduce_size; k += parallelism_size) + { + FLOAT_ACCUM val = CVT_FLOAT2ACCUM(x[input_idx]); + if(nanPropagation && isnan(val)) + { + val = static_cast(0); + } + sum += val; + input_idx += inner_size * parallelism_size; + } + + y[gid] = CVT_ACCUM2FLOAT(sum); +} + +extern "C" __global__ void SumFwdContiguous(const FLOAT* __restrict__ x, + FLOAT* __restrict__ y, + uint64_t output_numel, + uint64_t reduce_size, + uint64_t inner_size, + int32_t dim, + bool nanPropagation) +{ + const uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; + if(gid >= output_numel) + return; + + uint64_t input_idx = (gid / inner_size) * inner_size * reduce_size + gid % inner_size; + + FLOAT_ACCUM sum = static_cast(0); + for(uint64_t k = 0; k < reduce_size; ++k) + { + FLOAT_ACCUM val = CVT_FLOAT2ACCUM(x[input_idx]); + if(nanPropagation && isnan(val)) + { + val = static_cast(0); + } + sum += val; + input_idx += inner_size; + } + + y[gid] = CVT_ACCUM2FLOAT(sum); +} diff --git a/src/layernorm_api.cpp b/src/layernorm_api.cpp index dc3dcb4a53..5ab7bb4ad2 100644 --- a/src/layernorm_api.cpp +++ b/src/layernorm_api.cpp @@ -48,10 +48,6 @@ LogCmdLayerNorm(const miopenTensorDescriptor_t xDesc, const miopenLayerNormMode_ { ss << "layernormbf16"; } - else if(dtype == miopenDouble) - { - ss << "layernormfp64"; - } int32_t size = {0}; miopenGetTensorDescriptorSize(xDesc, &size); diff --git a/src/reduce/problem_description.cpp b/src/reduce/problem_description.cpp new file mode 100644 index 0000000000..7f6dd1b4fd --- /dev/null +++ b/src/reduce/problem_description.cpp @@ -0,0 +1,58 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include +#include + +#include + +namespace miopen { + +namespace reduce { + +NetworkConfig ProblemDescription::MakeNetworkConfig() const +{ + auto xlength = xDesc.GetLengths(); + auto ylength = yDesc.GetLengths(); + + auto reduce_size = xlength[dim]; + auto output_numel = std::accumulate( + ylength.begin(), ylength.end(), static_cast(1), std::multiplies()); + auto dtype = xDesc.GetType(); + + std::ostringstream ss; + + ss << "dtype" << dtype; + ss << "dim" << dim; + ss << "reduce_size" << reduce_size; + ss << "output_numel" << output_numel; + + return NetworkConfig{ss.str()}; +} + +} // namespace reduce + +} // namespace miopen diff --git a/src/solver.cpp b/src/solver.cpp index 66810637b1..43a77e06b0 100644 --- a/src/solver.cpp +++ b/src/solver.cpp @@ -31,6 +31,7 @@ #include #include #include +#include #include #include @@ -609,6 +610,7 @@ inline SolverRegistrar::SolverRegistrar(IdRegistryData& registry) Register(registry, ++id, Primitive::Normalization, norm::Layernorm2DCKForward{}.SolverDbId()); Register(registry, ++id, Primitive::Normalization, norm::Layernorm4DCKForward{}.SolverDbId()); Register(registry, ++id, Primitive::Normalization, norm::LayernormForward{}.SolverDbId()); + Register(registry, ++id, Primitive::Reduce, reduce::SumForward{}.SolverDbId()); // IMPORTANT: New solvers should be added to the end of the function! } diff --git a/src/solver/norm/forward_layernorm.cpp b/src/solver/norm/forward_layernorm.cpp index be258a5d22..aeef501f76 100644 --- a/src/solver/norm/forward_layernorm.cpp +++ b/src/solver/norm/forward_layernorm.cpp @@ -56,7 +56,17 @@ std::size_t sizeof_local_memory(const miopen::norm::ProblemDescription& problem) bool LayernormForward::IsApplicable(const ExecutionContext&, const miopen::norm::ProblemDescription& problem) const { - return (sizeof_local_memory(problem) <= TargetProperties::GetMaxLocalMemorySize()); + if(!problem.IsSameType()) + return false; + if(!problem.IsSameLength()) + return false; + if(!problem.IsAllPacked()) + return false; + if(!problem.IsRightNormDim()) + return false; + if(!(sizeof_local_memory(problem) <= TargetProperties::GetMaxLocalMemorySize())) + return false; + return true; } ConvSolution LayernormForward::GetSolution(const ExecutionContext& context, diff --git a/src/solver/reduce/forward_sum.cpp b/src/solver/reduce/forward_sum.cpp new file mode 100644 index 0000000000..e44de7ae58 --- /dev/null +++ b/src/solver/reduce/forward_sum.cpp @@ -0,0 +1,330 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include +#include +#include +#include +#include + +#define LOCAL_SIZE 256 + +namespace miopen { + +namespace solver { + +namespace reduce { + +bool IsNotLastDim(const miopen::reduce::ProblemDescription& problem) +{ + if((problem.GetDim() == problem.GetXDesc().GetLengths().size() - 1)) + { + MIOPEN_THROW(miopenStatusBadParm, "SumForward: Reduce last dim not supported."); + } + return true; +} + +size_t get_reqd_work_item_cnt(const ExecutionContext& context) +{ + // At least 4 WGs per one CU + return static_cast(LOCAL_SIZE * context.GetStream().GetMaxComputeUnits() * 4); +} + +size_t get_reqd_work_item_cnt(const Handle& handle) +{ + // At least 4 WGs per one CU + return static_cast(LOCAL_SIZE * handle.GetMaxComputeUnits() * 4); +} + +size_t get_parallelism_size(size_t reqd_work_item_cnt, size_t output_numel, size_t reduce_size) +{ + size_t parallelism_size = 1ULL; + while(parallelism_size * output_numel < reqd_work_item_cnt && + parallelism_size < std::sqrt(reduce_size)) + { + parallelism_size *= 2ULL; + } + return parallelism_size; +} + +bool is_parallelism(size_t reqd_work_item_cnt, size_t output_numel, size_t reduce_size) +{ + return !(output_numel > reqd_work_item_cnt) && + (output_numel * reduce_size > reqd_work_item_cnt); +} + +bool IsImprovementOverROCm(const ExecutionContext& context, + const miopen::reduce::ProblemDescription& problem) +{ + auto xdims = problem.GetXDesc().GetLengths(); + auto ydims = problem.GetYDesc().GetLengths(); + auto dim = problem.GetDim(); + + auto reduce_size = xdims[dim]; + auto output_numel = + std::accumulate(ydims.begin(), ydims.end(), 1ULL, std::multiplies()); + + auto reqd_work_item_cnt = get_reqd_work_item_cnt(context); + + if(is_parallelism(reqd_work_item_cnt, output_numel, reduce_size)) + { + // It's large enough to parallelization, but calling the kernel twice is overhead. + // For cases smaller than this, ROCm pytorch performed better. + bool is_improvement_ROCm = (output_numel * reduce_size < reqd_work_item_cnt * 64); + // But the reduce size is small, MIOpen HIP performed better. 
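+        // In other words, this solver steps aside only when the total work is modest
+        // (output_numel * reduce_size < reqd_work_item_cnt * 64) but the reduction itself is
+        // longer than 64 elements; everything else stays on the MIOpen kernels.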
+ bool is_reduce_large = (reduce_size > 64); + + if(is_improvement_ROCm && is_reduce_large) + return false; + } + + return true; +} + +bool SumForward::IsApplicable(const ExecutionContext& context, + const miopen::reduce::ProblemDescription& problem) const +{ + if(!problem.IsSameType()) + return false; + if(!problem.IsRightDim()) + return false; + if(!problem.IsRightLength()) + return false; + if(!problem.IsAllPacked()) + return false; + if(!IsNotLastDim(problem)) + return false; + if(!IsImprovementOverROCm(context, problem)) + return false; + return true; +} + +ConvSolution SumForward::GetSolution(const ExecutionContext& context, + const miopen::reduce::ProblemDescription& problem) const +{ + std::ignore = context; + + auto result = ConvSolution{miopenStatusSuccess}; + + auto dtype = problem.GetXDesc().GetType(); + auto xdims = problem.GetXDesc().GetLengths(); + auto ydims = problem.GetYDesc().GetLengths(); + auto dim = problem.GetDim(); + + auto reduce_size = xdims[dim]; + auto output_numel = + std::accumulate(ydims.begin(), ydims.end(), 1ULL, std::multiplies()); + + auto reqd_work_item_cnt = get_reqd_work_item_cnt(context); + + if(is_parallelism(reqd_work_item_cnt, output_numel, reduce_size)) + { + auto parallelism_size = get_parallelism_size(reqd_work_item_cnt, output_numel, reduce_size); + + size_t xlocalsize = LOCAL_SIZE; + size_t xgridsize = AlignUp(parallelism_size * output_numel, xlocalsize); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + + kernel.kernel_file = "MIOpenSum.cpp"; + kernel.kernel_name = "SumParallelFwdContiguous"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + }; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + + { + size_t xlocalsize = LOCAL_SIZE; + size_t xgridsize = AlignUp(output_numel, xlocalsize); + size_t ylocalsize = 1; + size_t ygridsize = 1; + size_t zlocalsize = 1; + size_t zgridsize = 1; + + auto kernel = KernelInfo{}; + + kernel.kernel_file = "MIOpenSum.cpp"; + kernel.kernel_name = "SumFwdContiguous"; + + const auto build_params = KernelBuildParameters{ + {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)}, + {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)}, + {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)}, + {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)}, + }; + + kernel.comp_options = build_params.GenerateFor(kbp::HIP{}); + + kernel.l_wk.push_back(xlocalsize); + kernel.l_wk.push_back(ylocalsize); + kernel.l_wk.push_back(zlocalsize); + + kernel.g_wk.push_back(xgridsize); + kernel.g_wk.push_back(ygridsize); + kernel.g_wk.push_back(zgridsize); + + result.construction_params.push_back(kernel); + } + + if(is_parallelism(reqd_work_item_cnt, output_numel, reduce_size)) + { + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) parallel_kernel = handle_.Run(kernels[0]); + decltype(auto) kernel = handle_.Run(kernels[1]); + 
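+                // kernels[0] (SumParallelFwdContiguous) writes partial sums into the workspace;
+                // kernels[1] (SumFwdContiguous) then reduces those partials from the workspace
+                // into the final output tensor y.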
decltype(auto) params = raw_params.CastTo(); + + auto xdims = params.xDesc->GetLengths(); + auto ydims = params.yDesc->GetLengths(); + auto dim = params.dim; + + auto reduce_size = xdims[dim]; + auto output_numel = + std::accumulate(ydims.begin(), ydims.end(), 1ULL, std::multiplies()); + + auto inner_size = 1ULL; + for(int32_t i = dim + 1; i < xdims.size(); i++) + { + inner_size *= xdims[i]; + } + + auto reqd_work_item_cnt = get_reqd_work_item_cnt(handle_); + auto parallelism_size = + get_parallelism_size(reqd_work_item_cnt, output_numel, reduce_size); + + auto elapsed = 0.f; + + parallel_kernel(params.x, + params.workspace, + output_numel, + reduce_size, + parallelism_size, + inner_size, + static_cast(params.nanPropagation)); + + if(handle_.IsProfilingEnabled()) + elapsed = handle_.GetKernelTime(); + + kernel(params.workspace, + params.y, + output_numel, + parallelism_size, + inner_size, + dim, + static_cast(params.nanPropagation)); + + if(handle_.IsProfilingEnabled()) + { + elapsed += handle_.GetKernelTime(); + handle_.ResetKernelTime(); + handle_.AccumKernelTime(elapsed); + }; + }; + }; + } + else + { + result.invoker_factory = [](const std::vector& kernels) { + return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { + decltype(auto) kernel = handle_.Run(kernels.front()); + decltype(auto) params = raw_params.CastTo(); + + auto xdims = params.xDesc->GetLengths(); + auto ydims = params.yDesc->GetLengths(); + auto dim = params.dim; + + auto reduce_size = xdims[dim]; + auto output_numel = + std::accumulate(ydims.begin(), ydims.end(), 1ULL, std::multiplies()); + + size_t inner_size = 1; + for(int32_t i = dim + 1; i < xdims.size(); i++) + { + inner_size *= xdims[i]; + } + + kernel(params.x, + params.y, + output_numel, + reduce_size, + inner_size, + dim, + static_cast(params.nanPropagation)); + }; + }; + } + return result; +} + +std::size_t SumForward::GetWorkspaceSize(const ExecutionContext& context, + const miopen::reduce::ProblemDescription& problem) const +{ + auto xdims = problem.GetXDesc().GetLengths(); + auto ydims = problem.GetYDesc().GetLengths(); + + auto reduce_size = xdims[problem.GetDim()]; + auto output_numel = + std::accumulate(ydims.begin(), ydims.end(), 1ULL, std::multiplies()); + + auto reqd_work_item_cnt = get_reqd_work_item_cnt(context); + + if(is_parallelism(reqd_work_item_cnt, output_numel, reduce_size)) + { + auto parallelism_size = get_parallelism_size(reqd_work_item_cnt, output_numel, reduce_size); + + return parallelism_size * output_numel * get_data_size(problem.GetXDesc().GetType()); + } + + return 0; +} + +} // namespace reduce + +} // namespace solver + +} // namespace miopen diff --git a/src/sum.cpp b/src/sum.cpp new file mode 100644 index 0000000000..7869891af4 --- /dev/null +++ b/src/sum.cpp @@ -0,0 +1,88 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace miopen { + +std::size_t GetSumWorkspaceSize(Handle& handle, + const TensorDescriptor& xDesc, + const TensorDescriptor& yDesc, + int32_t dim) +{ + auto ctx = ExecutionContext{&handle}; + const auto problem = + reduce::ProblemDescription{MIOPEN_SUM_NOT_PROPAGATE_NAN, xDesc, yDesc, dim}; + + const auto algo = AlgorithmName{"SumForward"}; + const auto solvers = solver::SolverContainer{}; + + auto pair_size_vector = solvers.GetWorkspaceSizes(ctx, problem); + + return pair_size_vector.empty() ? static_cast(-1) : pair_size_vector.front().second; +} + +miopenStatus_t SumForward(Handle& handle, + Data_t workspace, + size_t workspaceSizeInBytes, + const TensorDescriptor& xDesc, + ConstData_t x, + const TensorDescriptor& yDesc, + Data_t y, + miopenSumNanPropagation_t nanPropagation, + int32_t dim) +{ + const auto problem = reduce::ProblemDescription{nanPropagation, xDesc, yDesc, dim}; + + const auto invoke_params = [&]() { + auto tmp = reduce::InvokeParams{}; + tmp.type = InvokeType::Run; + tmp.xDesc = &xDesc; + tmp.yDesc = &yDesc; + tmp.x = x; + tmp.y = y; + tmp.workspace = workspace; + tmp.workspace_size = workspaceSizeInBytes; + tmp.nanPropagation = nanPropagation; + tmp.dim = dim; + return tmp; + }(); + + const auto algo = AlgorithmName{"SumForward"}; + const auto solvers = solver::SolverContainer{}; + + solvers.ExecutePrimitive(handle, problem, algo, invoke_params); + + return miopenStatusSuccess; +} + +} // namespace miopen diff --git a/src/sum_api.cpp b/src/sum_api.cpp new file mode 100644 index 0000000000..de3744f306 --- /dev/null +++ b/src/sum_api.cpp @@ -0,0 +1,125 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include +#include +#include +#include +#include + +static void LogCmdSum(const miopenTensorDescriptor_t xDesc, + const miopenSumNanPropagation_t nanPropagation, + bool is_fwd) +{ + if(miopen::IsLoggingCmd()) + { + std::stringstream ss; + auto dtype = miopen::deref(xDesc).GetType(); + if(dtype == miopenHalf) + { + ss << "sumfp16"; + } + else if(dtype == miopenFloat) + { + ss << "sumfp32"; + } + else if(dtype == miopenBFloat16) + { + ss << "sumbf16"; + } + + int32_t size = {0}; + miopenGetTensorDescriptorSize(xDesc, &size); + ss << " -n " << miopen::deref(xDesc).GetLengths()[0]; + if(size == 5) + { + ss << " -c " << miopen::deref(xDesc).GetLengths()[1] << " -D " + << miopen::deref(xDesc).GetLengths()[2] << " -H " + << miopen::deref(xDesc).GetLengths()[3] << " -W " + << miopen::deref(xDesc).GetLengths()[4]; + } + else if(size == 4) + { + ss << " -c " << miopen::deref(xDesc).GetLengths()[1] << " -H " + << miopen::deref(xDesc).GetLengths()[2] << " -W " + << miopen::deref(xDesc).GetLengths()[3]; + } + else if(size == 3) + { + ss << " -c " << miopen::deref(xDesc).GetLengths()[1] << " -W " + << miopen::deref(xDesc).GetLengths()[2]; + } + else if(size == 2) + { + ss << " -c " << miopen::deref(xDesc).GetLengths()[1]; + } + + ss << " -F " << ((is_fwd) ? 
"1" : "2") << " -n " << nanPropagation; + + MIOPEN_LOG_DRIVER_CMD(ss.str()); + } +} + +extern "C" miopenStatus_t miopenGetSumWorkspaceSize(miopenHandle_t handle, + const miopenTensorDescriptor_t xDesc, + int32_t dim, + const miopenTensorDescriptor_t yDesc, + size_t* sizeInBytes) +{ + + MIOPEN_LOG_FUNCTION(handle, xDesc, dim, yDesc, sizeInBytes); + + return miopen::try_([&] { + miopen::deref(sizeInBytes) = miopen::GetSumWorkspaceSize( + miopen::deref(handle), miopen::deref(xDesc), miopen::deref(yDesc), dim); + }); +}; + +extern "C" miopenStatus_t miopenSumForward(miopenHandle_t handle, + miopenSumNanPropagation_t nanPropagation, + void* workspace, + size_t workspaceSizeInBytes, + const miopenTensorDescriptor_t xDesc, + const void* x, + const int32_t dim, + const miopenTensorDescriptor_t yDesc, + void* y) +{ + MIOPEN_LOG_FUNCTION( + handle, nanPropagation, workspace, workspaceSizeInBytes, xDesc, x, dim, yDesc, y); + + LogCmdSum(xDesc, nanPropagation, true); + return miopen::try_([&] { + miopen::SumForward(miopen::deref(handle), + DataCast(workspace), + workspaceSizeInBytes, + miopen::deref(xDesc), + DataCast(x), + miopen::deref(yDesc), + DataCast(y), + nanPropagation, + dim); + }); +} diff --git a/test/cpu_sum.hpp b/test/cpu_sum.hpp new file mode 100644 index 0000000000..8898a5654b --- /dev/null +++ b/test/cpu_sum.hpp @@ -0,0 +1,67 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ * + *******************************************************************************/ +#ifndef GUARD_CPU_SUM_HPP +#define GUARD_CPU_SUM_HPP + +#include "tensor_holder.hpp" + +template +void cpu_sum_forward(tensor input, + tensor& ref_output, + int32_t dim, + miopenSumNanPropagation_t nanPropagation) +{ + auto input_dims = input.desc.GetLengths(); + auto output_dims = ref_output.desc.GetLengths(); + + auto reduce_size = input_dims[dim]; + auto output_numel = + std::accumulate(output_dims.begin(), output_dims.end(), 1L, std::multiplies()); + + auto inner_size = 1ULL; + for(int32_t i = dim + 1; i < input_dims.size(); i++) + { + inner_size *= input_dims[i]; + } + + par_ford(output_numel)([&](size_t o) { + size_t input_idx = (o / inner_size) * inner_size * reduce_size + o % inner_size; + T sum = 0.0f; + + ford(reduce_size)([&](size_t) { + T val = input[input_idx]; + if(nanPropagation && std::isnan(val)) + { + val = 0.0f; + } + sum += val; + input_idx += inner_size; + }); + + ref_output[o] = sum; + }); +} +#endif diff --git a/test/gtest/layernorm.hpp b/test/gtest/layernorm.hpp index 7740a20476..63158ef134 100644 --- a/test/gtest/layernorm.hpp +++ b/test/gtest/layernorm.hpp @@ -23,7 +23,6 @@ * SOFTWARE. * *******************************************************************************/ -#define MIOPEN_BETA_API 1 #include #include #include @@ -279,8 +278,8 @@ struct LayerNormTest : public ::testing::TestWithParam error = miopen::rms_range(ref_mean, mean); EXPECT_TRUE(miopen::range_distance(ref_mean) == miopen::range_distance(mean)); - EXPECT_TRUE(error < threshold * 20) - << "Error mean beyond tolerance Error:" << error << ", Threshold: " << threshold; + EXPECT_TRUE(error < threshold * 20) << "Error mean beyond tolerance Error:" << error + << ", Thresholdx20: " << threshold * 20; error = miopen::rms_range(ref_rstd, rstd); EXPECT_TRUE(miopen::range_distance(ref_rstd) == miopen::range_distance(rstd)); diff --git a/test/gtest/sum.cpp b/test/gtest/sum.cpp new file mode 100644 index 0000000000..747bb9ce28 --- /dev/null +++ b/test/gtest/sum.cpp @@ -0,0 +1,55 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2023 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+#include "sum.hpp"
+
+std::string GetFloatArg()
+{
+    static const auto tmp = miopen::GetEnv("MIOPEN_TEST_FLOAT_ARG");
+    if(tmp.empty())
+    {
+        return "";
+    }
+    return tmp.front();
+}
+
+struct SumTestFloat : SumTest<float>
+{
+};
+
+TEST_P(SumTestFloat, SumTestFw)
+{
+    if(miopen::IsEnvvarValueEnabled("MIOPEN_TEST_ALL") && (GetFloatArg() == "--float"))
+    {
+        RunTest();
+        Verify();
+    }
+    else
+    {
+        GTEST_SKIP();
+    }
+};
+
+INSTANTIATE_TEST_SUITE_P(SumTestSet, SumTestFloat, testing::ValuesIn(SumTestConfigs()));
diff --git a/test/gtest/sum.hpp b/test/gtest/sum.hpp
new file mode 100644
index 0000000000..f38da15f96
--- /dev/null
+++ b/test/gtest/sum.hpp
@@ -0,0 +1,212 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2023 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#include <miopen/sum.hpp>
+#include <gtest/gtest.h>
+#include <random>
+
+#include "tensor_holder.hpp"
+#include "cpu_sum.hpp"
+#include "get_handle.hpp"
+#include "../driver/tensor_driver.hpp"
+#include "verify.hpp"
+#include <miopen/miopen.h>
+
+struct SumTestCase
+{
+    size_t N;
+    size_t C;
+    size_t D;
+    size_t H;
+    size_t W;
+    int32_t dim;
+    miopenSumNanPropagation_t nanPropagation;
+    friend std::ostream& operator<<(std::ostream& os, const SumTestCase& tc)
+    {
+        return os << " N:" << tc.N << " C:" << tc.C << " D:" << tc.D << " H:" << tc.H
+                  << " W:" << tc.W << " dim:" << tc.dim << " NanPropagation:" << tc.nanPropagation;
+    }
+
+    std::vector<size_t> GetInput()
+    {
+        if((N != 0) && (C != 0) && (D != 0) && (H != 0) && (W != 0))
+        {
+            return std::vector<size_t>({N, C, D, H, W});
+        }
+        else if((N != 0) && (C != 0) && (H != 0) && (W != 0))
+        {
+            return std::vector<size_t>({N, C, H, W});
+        }
+        else if((N != 0) && (C != 0) && (W != 0))
+        {
+            return std::vector<size_t>({N, C, W});
+        }
+        else if((N != 0) && (W != 0))
+        {
+            return std::vector<size_t>({N, W});
+        }
+        else if((N != 0))
+        {
+            return std::vector<size_t>({N});
+        }
+        else
+        {
+            std::cout << "Error Input Tensor Lengths\n" << std::endl;
+            return std::vector<size_t>({0});
+        }
+    }
+};
+
+std::vector<SumTestCase> SumTestConfigs()
+{ //  n    c     d    h    w    dim  nanPropagation
+    // clang-format off
+    return {
+        { 8,  120,  0,   0,   1,   0, MIOPEN_SUM_NOT_PROPAGATE_NAN}, // bart
+        { 8,  120,  0,   0,   1,   0, MIOPEN_SUM_PROPAGATE_NAN},
+        { 8,  1023, 0,   0,   1,   0, MIOPEN_SUM_NOT_PROPAGATE_NAN}, // gpt_neo
+        { 8,  1024, 0,   0,   768, 0, MIOPEN_SUM_NOT_PROPAGATE_NAN},
+        { 8,  1023, 0,   0,   1,   0, MIOPEN_SUM_PROPAGATE_NAN},
+        { 8,  1024, 0,   0,   768, 0, MIOPEN_SUM_PROPAGATE_NAN},
+        {16,  1024, 0,   0,   768, 0, MIOPEN_SUM_NOT_PROPAGATE_NAN}, // gpt2
+        {16,  1024, 0,   0,   768, 0, MIOPEN_SUM_PROPAGATE_NAN},
+        {48,  8,    0,   512, 512, 0, MIOPEN_SUM_NOT_PROPAGATE_NAN}, // t5
+        {48,  8,    0,   512, 512, 0, MIOPEN_SUM_PROPAGATE_NAN},
+        {16,  311,  0,   98,  512, 2, MIOPEN_SUM_NOT_PROPAGATE_NAN}, // rnnt
+        {16,  311,  0,   98,  512, 2, MIOPEN_SUM_PROPAGATE_NAN}
+    };
+    // clang-format on
+}
+
+inline int32_t SetTensorLayout(miopen::TensorDescriptor& desc)
+{
+    std::vector<size_t> lens = desc.GetLengths();
+    std::vector<int32_t> int32_t_lens(lens.begin(), lens.end());
+
+    // set the strides for the tensor
+    return SetTensorNd(&desc, int32_t_lens, desc.GetType());
+}
+
+template <class T>
+struct SumTest : public ::testing::TestWithParam<SumTestCase>
+{
+protected:
+    void SetUp() override
+    {
+        auto&& handle  = get_handle();
+        sum_config     = GetParam();
+        std::mt19937 gen(0);
+        std::uniform_real_distribution<> d{-3, 3};
+        auto gen_value = [&](auto...) { return d(gen); };
+
+        dim            = sum_config.dim;
+        nanPropagation = sum_config.nanPropagation;
+
+        auto in_dims = sum_config.GetInput();
+
+        input = tensor<T>{in_dims}.generate(gen_value);
+
+        std::vector<size_t> out_dims;
+
+        for(int i = 0; i < in_dims.size(); i++)
+        {
+            if(i != dim)
+            {
+                out_dims.push_back(in_dims[i]);
+            }
+        }
+
+        SetTensorLayout(input.desc);
+
+        output = tensor<T>{out_dims};
+        SetTensorLayout(output.desc);
+        std::fill(output.begin(), output.end(), std::numeric_limits<T>::quiet_NaN());
+
+        ref_output = tensor<T>{out_dims};
+        std::fill(ref_output.begin(), ref_output.end(), std::numeric_limits<T>::quiet_NaN());
+
+        std::vector<size_t> workspace_dims;
+        ws_sizeInBytes = miopen::GetSumWorkspaceSize(handle, input.desc, output.desc, dim);
+        if(ws_sizeInBytes == static_cast<size_t>(-1))
+            GTEST_SKIP();
+
+        workspace_dims.push_back(ws_sizeInBytes / sizeof(T));
+        if(ws_sizeInBytes != 0)
+        {
+            workspace = tensor<T>{workspace_dims};
+            std::fill(workspace.begin(), workspace.end(), std::numeric_limits<T>::quiet_NaN());
+            workspace_dev = handle.Write(workspace.data);
+        }
+
+        input_dev  = handle.Write(input.data);
+        output_dev = handle.Write(output.data);
+    }
+    void RunTest()
+    {
+        auto&& handle = get_handle();
+
+        cpu_sum_forward(input, ref_output, dim, nanPropagation);
+        miopenStatus_t status;
+
+        status = miopen::SumForward(handle,
+                                    workspace_dev.get(),
+                                    ws_sizeInBytes,
+                                    input.desc,
+                                    input_dev.get(),
+                                    output.desc,
+                                    output_dev.get(),
+                                    nanPropagation,
+                                    dim);
+
+        EXPECT_EQ(status, miopenStatusSuccess);
+
+        output.data = handle.Read<T>(output_dev, output.data.size());
+    }
+
+    void Verify()
+    {
+        double threshold = std::numeric_limits<T>::epsilon();
+        auto error       = miopen::rms_range(ref_output, output);
+
+        EXPECT_TRUE(miopen::range_distance(ref_output) == miopen::range_distance(output));
+        EXPECT_TRUE(error < threshold * 10) << "Error output beyond tolerance Error:" << error
+                                            << ", Thresholdx10: " << threshold * 10;
+    }
+    SumTestCase sum_config;
+
+    tensor<T> input;
+    tensor<T> output;
+    tensor<T> workspace;
+
+    tensor<T> ref_output;
+
+    miopen::Allocator::ManageDataPtr input_dev;
+    miopen::Allocator::ManageDataPtr output_dev;
+    miopen::Allocator::ManageDataPtr workspace_dev;
+
+    size_t ws_sizeInBytes;
+
+    int32_t dim;
+    miopenSumNanPropagation_t nanPropagation;
+};
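
Note (editor's addition, not part of the patch): the new C API added in this diff is exercised above only through the internal miopen::SumForward entry point. For readers unfamiliar with the call sequence, the sketch below shows one plausible way a client could drive miopenGetSumWorkspaceSize and miopenSumForward, using only the signatures and the miopenSumNanPropagation_t values that appear in this patch. The tensor shapes, the hipMalloc-based buffer handling, and the omission of status checks are illustrative assumptions, not something the patch prescribes; as the docs hunks note, MIOPEN_BETA_API must be defined before including miopen.h.

// Minimal host-side usage sketch (assumptions: a 2-D float input summed over dim 0,
// packed strides, HIP device buffers; error handling omitted for brevity).
#define MIOPEN_BETA_API 1
#include <miopen/miopen.h>
#include <hip/hip_runtime.h>
#include <vector>

int main()
{
    miopenHandle_t handle;
    miopenCreate(&handle);

    // x: 8 x 120, y: 120 (sum over dim 0), mirroring one of the test configs above.
    std::vector<int> x_lens = {8, 120};
    std::vector<int> y_lens = {120};
    miopenTensorDescriptor_t x_desc, y_desc;
    miopenCreateTensorDescriptor(&x_desc);
    miopenCreateTensorDescriptor(&y_desc);
    miopenSetTensorDescriptor(
        x_desc, miopenFloat, static_cast<int>(x_lens.size()), x_lens.data(), nullptr);
    miopenSetTensorDescriptor(
        y_desc, miopenFloat, static_cast<int>(y_lens.size()), y_lens.data(), nullptr);

    const int32_t dim = 0;

    // Query the workspace requirement first, then allocate device buffers.
    size_t ws_bytes = 0;
    miopenGetSumWorkspaceSize(handle, x_desc, dim, y_desc, &ws_bytes);

    void* x_dev  = nullptr;
    void* y_dev  = nullptr;
    void* ws_dev = nullptr;
    hipMalloc(&x_dev, 8 * 120 * sizeof(float));
    hipMalloc(&y_dev, 120 * sizeof(float));
    if(ws_bytes != 0)
        hipMalloc(&ws_dev, ws_bytes);

    // ... copy the input data into x_dev ...

    miopenSumForward(handle,
                     MIOPEN_SUM_NOT_PROPAGATE_NAN,
                     ws_dev,
                     ws_bytes,
                     x_desc,
                     x_dev,
                     dim,
                     y_desc,
                     y_dev);

    // ... read the reduced result back from y_dev ...

    hipFree(ws_dev);
    hipFree(y_dev);
    hipFree(x_dev);
    miopenDestroyTensorDescriptor(y_desc);
    miopenDestroyTensorDescriptor(x_desc);
    miopenDestroy(handle);
    return 0;
}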