diff --git a/docs/reference/index.rst b/docs/reference/index.rst
index 02bcb88622..99da4c7de7 100644
--- a/docs/reference/index.rst
+++ b/docs/reference/index.rst
@@ -32,3 +32,4 @@ The MIOpen API library is structured as follows:
   * :doc:`GroupNorm <../doxygen/html/group__groupnorm>` (experimental)
   * :doc:`Cat <../doxygen/html/group__cat>` (experimental)
   * :doc:`Argmax<./argmax>` (experimental)
+  * :doc:`GLU <../doxygen/html/group__glu>` (experimental)
diff --git a/driver/CMakeLists.txt b/driver/CMakeLists.txt
index 224e550fed..531a9e89fc 100644
--- a/driver/CMakeLists.txt
+++ b/driver/CMakeLists.txt
@@ -42,6 +42,7 @@ add_executable(MIOpenDriver
     dm_dropout.cpp
     dm_fusion.cpp
     dm_gemm.cpp
+    dm_glu.cpp
     dm_groupnorm.cpp
     dm_layernorm.cpp
     dm_lrn.cpp
diff --git a/driver/dm_glu.cpp b/driver/dm_glu.cpp
new file mode 100644
index 0000000000..666ccf03da
--- /dev/null
+++ b/driver/dm_glu.cpp
@@ -0,0 +1,40 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#include "registry_driver_maker.hpp"
+#include "glu_driver.hpp"
+
+static Driver* makeDriver(const std::string& base_arg) // maps a GLU base argument to a new driver
+{
+    if(base_arg == "glu") // default-precision variant
+        return new GLUDriver();
+    if(base_arg == "glufp16") // half-precision variant
+        return new GLUDriver();
+    if(base_arg == "glubfp16") // bfloat16 variant
+        return new GLUDriver();
+    return nullptr; // not a GLU base argument
+}
+
+REGISTER_DRIVER_MAKER(makeDriver); // presumably adds makeDriver to the driver registry - confirm
diff --git a/driver/driver.hpp b/driver/driver.hpp
index 4cfc2b544e..efcfec9503 100644
--- a/driver/driver.hpp
+++ b/driver/driver.hpp
@@ -151,7 +151,7 @@ inline void PadBufferSize(size_t& sz, int datatype_sz)
            "pool[fp16], lrn[fp16], "
            "activ[fp16], softmax[fp16], bnorm[fp16], rnn[fp16], gemm[fp16], ctc, dropout[fp16], "
            "tensorop[fp16], reduce[fp16|fp64], layernorm[bfp16|fp16], sum[bfp16|fp16], "
-           "argmax[bfp16|fp16], groupnorm[bfp16|fp16], cat[bfp16|fp16]\n");
+           "argmax[bfp16|fp16], groupnorm[bfp16|fp16], cat[bfp16|fp16], glu[bfp16|fp16]\n");
     exit(0); // NOLINT (concurrency-mt-unsafe)
 }
 
@@ -176,7 +176,8 @@ inline std::string ParseBaseArg(int argc, char* argv[])
        arg != "layernormfp16" && arg != "layernormbfp16" && arg != "sum" && arg != "sumfp16" &&
        arg != "sumbfp16" && arg != "argmax" && arg != "argmaxfp16" && arg != "argmaxbfp16" &&
        arg != "groupnorm" && arg != "groupnormfp16" && arg != "groupnormbfp16" && arg != "cat" &&
-       arg != "catfp16" && arg != "catbfp16" && arg != "--version")
+       arg != "catfp16" && arg != "catbfp16" && arg != "glu" && arg != "glufp16" &&
+       arg != "glubfp16" && arg != "--version")
     {
         printf("FAILED: Invalid Base Input Argument\n");
         Usage();
diff --git a/driver/glu_driver.hpp b/driver/glu_driver.hpp
new file mode 100644
index 0000000000..5a0ccb0590
--- /dev/null +++ b/driver/glu_driver.hpp @@ -0,0 +1,366 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2024 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+#ifndef GUARD_MIOPEN_GLU_DRIVER_HPP
+#define GUARD_MIOPEN_GLU_DRIVER_HPP
+
+#include "InputFlags.hpp"
+#include "driver.hpp"
+#include "tensor_driver.hpp"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "random.hpp"
+#include "timer.hpp"
+#include "../test/verify.hpp"
+
+#ifndef MLO_GLUHOST_H_
+#define MLO_GLUHOST_H_
+
+template
+T sigmoid(T x) // logistic function 1 / (1 + e^-x) used by the host reference
+{
+    return 1.0f / (1.0f + exp(-x));
+}
+
+template
+int32_t mloGLUForwardContiguousRunHost(miopenTensorDescriptor_t inputDesc, // not read in the body
+                                       Tgpu* input,
+                                       miopenTensorDescriptor_t outputDesc,
+                                       Tcheck* outputHost)
+{
+    auto output_dims = miopen::deref(outputDesc).GetLengths();
+
+    auto output_numel = // product of all output lengths
+        std::accumulate(output_dims.begin(), output_dims.end(), 1L, std::multiplies());
+    auto inputFirstHalf = input;
+    auto inputSecondHalf = input + output_numel; // second half starts output_numel elements in
+
+    int32_t ret = 0; // never modified; the function always returns 0
+
+    for(size_t o = 0; o < output_numel; o++)
+    {
+        Tcheck valA = static_cast(inputFirstHalf[o]);
+        Tcheck valB = static_cast(inputSecondHalf[o]);
+        Tcheck val = valA * sigmoid(valB); // GLU: first half gated by sigmoid of second half
+        outputHost[o] = val;
+    }
+
+    return ret;
+}
+#endif
+
+template
+class GLUDriver : public Driver // command-line driver that runs and verifies miopenGLUForward
+{
+public:
+    GLUDriver() : Driver()
+    {
+        miopenCreateTensorDescriptor(&inputTensor);
+        miopenCreateTensorDescriptor(&outputTensor);
+
+        data_type = miopen_type{};
+    }
+
+    int AddCmdLineArgs() override;
+    int ParseCmdLineArgs(int argc, char* argv[]) override;
+    InputFlags& GetInputFlags() override { return inflags; }
+
+    int GetandSetData() override;
+    std::vector GetInputTensorLengthsFromCmdLine();
+
+    void splitInput(); // NOTE(review): declared only; no definition appears in this header
+
+    int AllocateBuffersAndCopy() override;
+
+    int RunForwardGPU() override;
+    int RunForwardCPU(); // Verify implements it
+
+    int RunBackwardGPU() override;
+    int RunBackwardCPU(); // Verify implements it
+
+    Tref GetTolerance();
+    int VerifyBackward() override;
+    int VerifyForward() override;
+    ~GLUDriver() override
+    {
+        miopenDestroyTensorDescriptor(outputTensor);
+        miopenDestroyTensorDescriptor(inputTensor);
+    }
+
+private:
+    InputFlags inflags;
+
+    miopenTensorDescriptor_t inputTensor;
+    miopenTensorDescriptor_t outputTensor;
+
+    std::unique_ptr in_dev;  // device buffer for the input tensor
+    std::unique_ptr out_dev; // device buffer for the output tensor
+
+    std::vector in;      // host copy of the input
+    std::vector out;     // GPU result copied back to the host
+    std::vector outhost; // CPU reference result
+
+    int dim; // dimension to split, read from the "DimToSplit" flag
+};
+
+template
+int GLUDriver::ParseCmdLineArgs(int argc, char* argv[])
+{
+    inflags.Parse(argc, argv);
+
+    if(inflags.GetValueInt("time") == 1) // kernel timing needs profiling enabled on the handle
+    {
+        miopenEnableProfiling(GetHandle(), true);
+    }
+    return miopenStatusSuccess;
+}
+
+template
+int GLUDriver::GetandSetData() // configures the input descriptor and the halved output descriptor
+{
+    std::vector in_len = GetInputTensorLengthsFromCmdLine();
+    dim = inflags.GetValueInt("DimToSplit");
+
+    SetTensorNd(inputTensor, in_len, data_type);
+
+    std::vector out_len;
+
+    for(int i = 0; i < in_len.size(); i++)
+    {
+        if(i != dim)
+        {
+            out_len.push_back(in_len[i]);
+        }
+        else
+        {
+            out_len.push_back(in_len[i] / 2); // the split dimension is halved in the output
+        }
+    }
+
+    if(out_len.empty())
+        out_len.push_back(1);
+
+    SetTensorNd(outputTensor, out_len, data_type);
+
+    return (0);
+}
+
+template
+int GLUDriver::AddCmdLineArgs()
+{
+    inflags.AddInputFlag("forw", 'F', "0", "Run only Forward GLU (Default=0)", "int");
+    inflags.AddInputFlag("batchsize", 'n', "100", "Mini-batch size (Default=100)", "int");
+    inflags.AddInputFlag("in_channels", 'c', "3", "Number of Input Channels (Default=3)", "int");
+    inflags.AddInputFlag("in_d", 'D', "0", "Input Depth (Default=0)", "int");
+    inflags.AddInputFlag("in_h", 'H', "32", "Input Height (Default=32)", "int");
+    inflags.AddInputFlag("in_w", 'W', "32", "Input Width (Default=32)", "int");
+
+    inflags.AddInputFlag(
+        "DimToSplit", 'R', "0", "The index of the dimension to be split in half (Default=0)", "int");
+    inflags.AddInputFlag("iter", 'i', "10", "Number of Iterations (Default=10)", "int");
+    inflags.AddInputFlag("verify", 'V', "1", "Verify Each Layer (Default=1)", "int");
+    inflags.AddInputFlag("time", 't', "0", "Time Each Layer (Default=0)", "int");
+    inflags.AddInputFlag(
+        "wall", 'w', "0", "Wall-clock Time Each Layer, Requires time == 1 (Default=0)", "int");
+
+    return miopenStatusSuccess;
+}
+
+template
+std::vector GLUDriver::GetInputTensorLengthsFromCmdLine() // a zero-valued flag drops that axis
+{
+    int in_n = inflags.GetValueInt("batchsize");
+    int in_c = inflags.GetValueInt("in_channels");
+    int in_d = inflags.GetValueInt("in_d");
+    int in_h = inflags.GetValueInt("in_h");
+    int in_w = inflags.GetValueInt("in_w");
+
+    if((in_n != 0) && (in_c != 0) && (in_d != 0) && (in_h != 0) && (in_w != 0))
+    {
+        return std::vector({in_n, in_c, in_d, in_h, in_w});
+    }
+    else if((in_n != 0) && (in_c != 0) && (in_h != 0) && (in_w != 0))
+    {
+        return std::vector({in_n, in_c, in_h, in_w});
+    }
+    else if((in_n != 0) && (in_c != 0) && (in_w != 0))
+    {
+        return std::vector({in_n, in_c, in_w});
+    }
+    else if((in_n != 0) && (in_w != 0))
+    {
+        return std::vector({in_n, in_w});
+    }
+    else if(in_n != 0)
+    {
+        return std::vector({in_n});
+    }
+    else
+    {
+        std::cerr << "Error Input Tensor Lengths\n" << std::endl;
+        return std::vector({0});
+    }
+}
+
+template
+int GLUDriver::AllocateBuffersAndCopy() // allocates host/device buffers and uploads random input
+{
+
+    size_t in_sz = GetTensorSpace(inputTensor);
+    size_t out_sz = GetTensorSpace(outputTensor);
+
+    uint32_t ctx = 0;
+
+    in_dev = std::unique_ptr(new GPUMem(ctx, in_sz, sizeof(Tgpu)));
+    out_dev = std::unique_ptr(new GPUMem(ctx, out_sz, sizeof(Tgpu)));
+
+    in = std::vector(in_sz, static_cast(0));
+    out = std::vector(out_sz, static_cast(0));
+    outhost = std::vector(out_sz, static_cast(0));
+
+    for(size_t i = 0; i < in_sz; i++) // size_t index matches the type of in_sz
+    {
+        in[i] = prng::gen_A_to_B(static_cast(0.0), static_cast(1.0));
+    }
+
+    if(in_dev->ToGPU(GetStream(), in.data()) != 0)
+        std::cerr << "Error copying (in) to GPU, size: " << in_dev->GetSize()
+                  << std::endl;
+
+    if(out_dev->ToGPU(GetStream(), out.data()) != 0)
+        std::cerr << "Error copying (out) to GPU, size: " << out_dev->GetSize() << std::endl;
+
+    return miopenStatusSuccess;
+}
+
+template
+int GLUDriver::RunForwardGPU() // runs the GPU forward pass "iter" times and reports timing
+{
+    float kernel_total_time = 0;
+    float kernel_first_time = 0;
+
+    Timer t;
+    START_TIME
+
+    for(int i = 0; i < inflags.GetValueInt("iter"); i++)
+    {
+        miopenGLUForward( // NOTE(review): the returned miopenStatus_t is not checked
+            GetHandle(), inputTensor, in_dev->GetMem(), dim, outputTensor, out_dev->GetMem());
+
+        float time = 0.0;
+        miopenGetKernelTime(GetHandle(), &time);
+        kernel_total_time += time;
+        if(i == 0)
+            kernel_first_time = time;
+    }
+
+    if(inflags.GetValueInt("time") == 1)
+    {
+        STOP_TIME
+        int iter = inflags.GetValueInt("iter");
+        if(WALL_CLOCK)
+            std::cout << "Wall-clock Time Forward GLU Elapsed: " << t.gettime_ms() / iter
+                      << " ms\n";
+
+        float kernel_average_time = // the first iteration is excluded when iter > 1
+            iter > 1 ? (kernel_total_time - kernel_first_time) / (iter - 1) : kernel_first_time;
+        std::cout << "GPU Kernel Time Forward GLU Elapsed: " << kernel_average_time << " ms\n";
+    }
+
+    if(out_dev->FromGPU(GetStream(), out.data()) != 0)
+        std::cerr << "Error copying (out_dev) from GPU, size: " << out_dev->GetSize() << std::endl;
+
+    return miopenStatusSuccess;
+}
+
+template
+int GLUDriver::RunForwardCPU() // fills outhost with the host reference result
+{
+    mloGLUForwardContiguousRunHost( // takes no dim argument; see the host helper above
+        inputTensor, in.data(), outputTensor, outhost.data());
+
+    return miopenStatusSuccess;
+}
+
+template
+int GLUDriver::RunBackwardGPU() // no-op: this driver implements only the forward direction
+{
+    return miopenStatusSuccess;
+}
+
+template
+Tref GLUDriver::GetTolerance()
+{
+    // Computation error of fp16 is ~2^13 (=8192) bigger than
+    // the one of fp32 because mantissa is shorter by 13 bits.
+    auto tolerance = std::is_same::value ? 1.5e-6 : 8.2e-3;
+
+    // bf16 mantissa has 7 bits, by 3 bits shorter than fp16.
+    if(std::is_same::value)
+        tolerance *= 8.0;
+    return tolerance;
+}
+
+template
+int GLUDriver::VerifyForward() // compares the GPU output with the CPU reference
+{
+    RunForwardCPU();
+    const Tref tolerance = GetTolerance();
+    auto error = miopen::rms_range(outhost, out);
+
+    if(!std::isfinite(error) || error > tolerance) // a non-finite error also counts as failure
+    {
+        std::cout << "Forward GLU FAILED: " << error << " > " << tolerance << std::endl;
+        return EC_VerifyFwd;
+    }
+    else
+    {
+        std::cout << "Forward GLU Verifies OK on CPU reference (" << error << " < " << tolerance
+                  << ')' << std::endl;
+    }
+
+    return miopenStatusSuccess;
+}
+
+template
+int GLUDriver::RunBackwardCPU() // no-op: no backward reference is implemented
+{
+    return miopenStatusSuccess;
+}
+
+template
+int GLUDriver::VerifyBackward() // no-op: nothing to verify for the backward direction
+{
+    return miopenStatusSuccess;
+}
+
+#endif // GUARD_MIOPEN_GLU_DRIVER_HPP
diff --git a/include/miopen/miopen.h b/include/miopen/miopen.h
index e768c7b349..e812719d0c 100644
--- a/include/miopen/miopen.h
+++ b/include/miopen/miopen.h
@@ -68,6 +68,7 @@
  * @defgroup argmax
  * @defgroup groupnorm
  * @defgroup cat
+ * @defgroup glu
  *
  */
 
@@ -6582,6 +6583,35 @@ MIOPEN_EXPORT miopenStatus_t miopenBackendInitialize(miopenBackendDescriptor_t d
 // CLOSEOUT BackendAPI DOXYGEN GROUP
 #endif // MIOPEN_BETA_API
 
+#ifdef MIOPEN_BETA_API
+
+// GLU APIs
+/** @addtogroup glu
+ *
+ * @{
+ */
+
+/*!
@brief Execute a GLU forward contiguous layer
+ *
+ * @param handle         MIOpen handle (input)
+ * @param inputDesc      Tensor descriptor for data input tensor (input)
+ * @param input          Input data tensor (input)
+ * @param dim            Index of the dimension along which input is split in half (input)
+ * @param outputDesc     Tensor descriptor for output data tensor (input)
+ * @param output         Output data tensor (output)
+ * @return               miopenStatus_t
+ */
+MIOPEN_EXPORT miopenStatus_t miopenGLUForward(miopenHandle_t handle,
+                                              const miopenTensorDescriptor_t inputDesc,
+                                              void* input,
+                                              const int32_t dim,
+                                              const miopenTensorDescriptor_t outputDesc,
+                                              void* output);
+
+/** @} */
+// CLOSEOUT GLU DOXYGEN GROUP
+#endif // MIOPEN_BETA_API
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 9671eed03c..fd5ab7ea51 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -120,6 +120,9 @@ set( MIOpen_Source
     fused_api.cpp
     fusion.cpp
     fusion/problem_description.cpp
+    glu.cpp
+    glu_api.cpp
+    glu/problem_description.cpp
     generic_search.cpp
     graphapi/convolution.cpp
     graphapi/graphapi.cpp
@@ -259,6 +262,7 @@ set( MIOpen_Source
     solver/gemm.cpp
     solver/gemm_bwd.cpp
     solver/gemm_wrw.cpp
+    solver/glu/forward_glu.cpp
     solver/groupnorm/forward_groupnorm.cpp
     solver/layernorm/forward_layernorm.cpp
     solver/layernorm/forward_layernorm2d_ck.cpp
@@ -454,6 +458,7 @@ if( MIOPEN_BACKEND MATCHES "OpenCL" OR MIOPEN_BACKEND STREQUAL "HIPOC" OR MIOPEN
     kernels/MIOpenConvDirUni.cl
     kernels/MIOpenConvDirBatchNormActiv.cl
     kernels/MIOpenConvDirGenFwd.cl
+    kernels/MIOpenGLU.cpp
     kernels/MIOpenGroupNorm.cpp
     kernels/MIOpenLayerNorm.cpp
     kernels/MIOpenLRNBwd.cl
diff --git a/src/glu.cpp b/src/glu.cpp
new file mode 100644
index 0000000000..34ae67a52b
--- /dev/null
+++ b/src/glu.cpp
@@ -0,0 +1,66 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace miopen {
+
+miopenStatus_t GLUForward(Handle& handle, // runs the "GLUForward" primitive for one problem
+                          const TensorDescriptor& inputDesc,
+                          Data_t input,
+                          int32_t dim,
+                          const TensorDescriptor& outputDesc,
+                          Data_t output)
+{
+    const auto problem = glu::ProblemDescription{inputDesc, outputDesc, dim}; // ctor may throw
+
+    const auto invoke_params = [&]() { // immediately-invoked lambda so the result can be const
+        auto tmp = glu::InvokeParams{};
+        tmp.type = InvokeType::Run;
+        tmp.inputDesc = &inputDesc; // addresses of the caller's descriptors, not copies
+        tmp.outputDesc = &outputDesc;
+        tmp.input = input;
+        tmp.output = output;
+        tmp.dim = dim;
+        return tmp;
+    }();
+
+    const auto algo = AlgorithmName{"GLUForward"};
+    const auto solvers = solver::SolverContainer{};
+
+    solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
+
+    return miopenStatusSuccess; // reached only when ExecutePrimitive returns normally
+}
+
+} // namespace miopen
diff --git a/src/glu/problem_description.cpp b/src/glu/problem_description.cpp
new file mode 100644
index 0000000000..269a517140
--- /dev/null
+++ b/src/glu/problem_description.cpp
@@ -0,0 +1,65 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#include "miopen/datatype.hpp"
+#include
+#include
+
+#include
+
+namespace miopen {
+
+namespace glu {
+
+NetworkConfig ProblemDescription::MakeNetworkConfig() const // builds the config key string
+{
+    auto inputlength = inputDesc.GetLengths();
+    auto outputlength = outputDesc.GetLengths();
+
+    auto splitdim_size = inputlength[dim]; // input length along the split dimension
+    auto output_numel = std::accumulate(outputlength.begin(), // product of all output lengths
+                                        outputlength.end(),
+                                        static_cast(1),
+                                        std::multiplies());
+
+    auto input_dtype = miopen::GetDataType(inputDesc.GetType());
+    auto output_dtype = miopen::GetDataType(outputDesc.GetType());
+
+    std::ostringstream ss;
+
+    ss << "contiguous"; // labels and values are concatenated with no separators
+    ss << "input_dtype" << input_dtype;
+    ss << "output_dtype" << output_dtype;
+    ss << "dim" << dim;
+    ss << "splitDim_size" << splitdim_size;
+    ss << "output_numel" << output_numel;
+
+    return NetworkConfig{ss.str()};
+}
+
+} // namespace glu
+
+} // namespace miopen
diff --git a/src/glu_api.cpp b/src/glu_api.cpp
new file mode 100644
index 0000000000..a167ae5aca
--- /dev/null
+++ b/src/glu_api.cpp
@@ -0,0 +1,50 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include
+#include
+#include
+#include
+#include
+
+extern "C" miopenStatus_t miopenGLUForward(miopenHandle_t handle, // C entry point for GLU forward
+                                           const miopenTensorDescriptor_t inputDesc,
+                                           void* input,
+                                           const int32_t dim,
+                                           const miopenTensorDescriptor_t outputDesc,
+                                           void* output)
+{
+    MIOPEN_LOG_FUNCTION(handle, inputDesc, input, dim, outputDesc, output);
+
+    return miopen::try_([&] { // presumably try_ maps thrown exceptions to a status - confirm
+        miopen::GLUForward(miopen::deref(handle), // the status GLUForward returns is discarded
+                           miopen::deref(inputDesc),
+                           DataCast(input),
+                           dim,
+                           miopen::deref(outputDesc),
+                           DataCast(output));
+    });
+}
diff --git a/src/include/miopen/glu.hpp b/src/include/miopen/glu.hpp
new file mode 100644
index 0000000000..e52b028ca8
--- /dev/null
+++ b/src/include/miopen/glu.hpp
@@ -0,0 +1,44 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#ifndef MIOPEN_GLU_HPP_
+#define MIOPEN_GLU_HPP_
+
+#include
+
+namespace miopen {
+
+struct Handle;
+struct TensorDescriptor;
+
+miopenStatus_t GLUForward(Handle& handle, // internal GLU forward entry point
+                          const TensorDescriptor& inputDesc,
+                          Data_t input,
+                          int32_t dim, // dimension of input to split
+                          const TensorDescriptor& outputDesc,
+                          Data_t output);
+
+} // namespace miopen
+#endif // MIOPEN_GLU_HPP_
diff --git a/src/include/miopen/glu/invoke_params.hpp b/src/include/miopen/glu/invoke_params.hpp
new file mode 100644
index 0000000000..53e76bffe9
--- /dev/null
+++ b/src/include/miopen/glu/invoke_params.hpp
@@ -0,0 +1,53 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace miopen {
+namespace glu {
+
+struct InvokeParams : public miopen::InvokeParams // run-time arguments for one GLU invocation
+{
+    InvokeParams() = default;
+
+    const TensorDescriptor* inputDesc = nullptr;  // input tensor descriptor; null until assigned
+    const TensorDescriptor* outputDesc = nullptr; // output tensor descriptor; null until assigned
+
+    ConstData_t input = nullptr; // input data pointer
+    Data_t output = nullptr;     // output data pointer
+    int32_t dim = 0;             // dimension to split; defaults to 0
+
+    std::size_t GetWorkspaceSize() const { return 0; } // always reports a zero-sized workspace
+    Data_t GetWorkspace() const { return nullptr; }    // no workspace pointer is provided
+};
+
+} // namespace glu
+
+} // namespace miopen
diff --git a/src/include/miopen/glu/problem_description.hpp b/src/include/miopen/glu/problem_description.hpp
new file mode 100644
index 0000000000..37e5b01442
--- /dev/null
+++ b/src/include/miopen/glu/problem_description.hpp
@@ -0,0 +1,162 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+
+#pragma once
+
+#include
+#include
+#include
+#include
+
+#include
+
+namespace miopen {
+
+struct NetworkConfig;
+
+namespace glu {
+
+enum class Direction
+{
+    Forward,
+    Backward,
+};
+
+struct ProblemDescription : ProblemDescriptionBase // describes one GLU problem for solver lookup
+{
+    // Forward constructor; throws miopenStatusBadParm on rank mismatch or an invalid split dim
+    ProblemDescription(const TensorDescriptor& inputDesc_,
+                       const TensorDescriptor& outputDesc_,
+                       int32_t dim_)
+        : direction(Direction::Forward), inputDesc(inputDesc_), outputDesc(outputDesc_), dim(dim_)
+    {
+        if(inputDesc.GetLengths().size() != outputDesc.GetLengths().size())
+        {
+            MIOPEN_THROW(miopenStatusBadParm,
+                         "GLU::ProblemDescription: Number of tensor dimension do not match.");
+        }
+        if(dim < 0 || dim >= inputDesc.GetLengths().size() || inputDesc.GetLengths()[dim] % 2 != 0)
+        { // dim is range-checked first so the [dim] lookup above never reads out of bounds
+            MIOPEN_THROW(miopenStatusBadParm,
+                         "GLU::ProblemDescription: The split dimension should be a valid axis of "
+                         "the input tensor and its size should be divisible by 2.");
+        }
+    }
+
+    Direction GetDirection() const { return direction; }
+    const TensorDescriptor& GetInputDesc() const { return inputDesc; }
+    const TensorDescriptor& GetOutputDesc() const { return outputDesc; }
+    int32_t GetDim() const { return dim; }
+
+    bool IsSameType() const // true when input and output share a data type
+    {
+        if(inputDesc.GetType() != outputDesc.GetType())
+        {
+#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG
+            MIOPEN_THROW(miopenStatusBadParm, "GLU: Tensor types do not match.");
+#else
+            return false;
+#endif
+        }
+        return true;
+    }
+
+    bool IsRightLength() const // output must equal input except for the halved split dimension
+    {
+        for(int32_t i = 0; i < inputDesc.GetLengths().size(); i++)
+        {
+            if(i == dim)
+            {
+                if(inputDesc.GetLengths()[i] / 2 != outputDesc.GetLengths()[i])
+                {
+                    return false;
+                }
+            }
+            else
+            {
+                if(inputDesc.GetLengths()[i] != outputDesc.GetLengths()[i])
+                {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    bool IsRightDim() const // true when dim indexes an existing tensor dimension
+    {
+        if((dim < 0) || (dim >= outputDesc.GetLengths().size())) // dim == rank is out of range
+        {
+#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG
+            MIOPEN_THROW(
+                miopenStatusBadParm,
+                "GLU: Dimension must be in the range [0, number of tensor dimensions).");
+#else
+            return false;
+#endif
+        }
+        return true;
+    }
+
+    bool IsAllPacked() const // true when both tensors are packed
+    {
+        if(!(inputDesc.IsPacked() && outputDesc.IsPacked()))
+        {
+#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG
+            MIOPEN_THROW(miopenStatusBadParm, "GLU: Unpacked tensors not supported.");
+#else
+            return false;
+#endif
+        }
+        return true;
+    }
+
+    bool IsFirstDim() const // true only when the split dimension is 0
+    {
+        if(dim != 0)
+        {
+#if MIOPEN_BUILD_DEV || !MIOPEN_NDEBUG
+            MIOPEN_THROW(miopenStatusBadParm, "GLU: Dimension is not 0.");
+#else
+            return false;
+#endif
+        }
+        return true;
+    }
+
+    NetworkConfig MakeNetworkConfig() const override;
+
+private:
+    Direction direction;
+    TensorDescriptor inputDesc;  // stored by value
+    TensorDescriptor outputDesc; // stored by value
+
+    int32_t dim; // split dimension; range-checked in the constructor
+};
+
+} // namespace glu
+
+} // namespace miopen
diff --git a/src/include/miopen/glu/solvers.hpp b/src/include/miopen/glu/solvers.hpp
new file mode 100644
index 0000000000..346155b36c
--- /dev/null
+++ b/src/include/miopen/glu/solvers.hpp
@@ -0,0 +1,54 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+#pragma once
+
+// NOTE(review): the three include targets below, and the template arguments further
+// down in this file, are missing (angle-bracketed text appears to have been lost).
+// Restore them from the original patch before building.
+#include
+#include
+#include
+
+namespace miopen {
+
+namespace solver {
+
+namespace glu {
+
+// Common base for the GLU solvers; GLU has no tunable parameters.
+using GLUSolver = NonTunableSolverBase;
+
+// Solver for the GLU forward pass.
+struct GLUForward final : GLUSolver
+{
+    // Identifier under which this solver is registered in the solver database.
+    const std::string& SolverDbId() const override { return GetSolverDbId(); }
+
+    // Returns true when this solver can handle `problem`.
+    bool IsApplicable(const ExecutionContext& context,
+                      const miopen::glu::ProblemDescription& problem) const override;
+    // Builds the kernel description and invoker for `problem`.
+    ConvSolution GetSolution(const ExecutionContext& context,
+                             const miopen::glu::ProblemDescription& problem) const override;
+};
+
+} // namespace glu
+
+} // namespace solver
+
+} // namespace miopen
diff --git a/src/include/miopen/solver_id.hpp b/src/include/miopen/solver_id.hpp
index c52dc020ac..ddf70f7193 100644
--- a/src/include/miopen/solver_id.hpp
+++ b/src/include/miopen/solver_id.hpp
@@ -56,7 +56,8 @@ enum class Primitive
     Reduce,
     Cat,
     Mha,
-    Softmax
+    Softmax,
+    glu
 };
 
 struct MIOPEN_EXPORT Id
diff --git a/src/kernels/MIOpenGLU.cpp b/src/kernels/MIOpenGLU.cpp
new file mode 100644
index 0000000000..e52bdc3c97
--- /dev/null
+++ b/src/kernels/MIOpenGLU.cpp
@@ -0,0 +1,55 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#ifndef MIOPEN_DONT_USE_HIP_RUNTIME_HEADERS
+#include <hip/hip_fp16.h>
+#include <hip/hip_runtime.h>
+#endif
+
+#include "float_types.h"
+
+// Logistic function 1 / (1 + e^-x), evaluated in the accumulator type.
+__device__ FLOAT_ACCUM sigmoid(FLOAT_ACCUM x) { return 1.0f / (1.0f + exp(-x)); }
+
+// GLU forward over a buffer whose two halves are laid out back to back:
+//   output[i] = input[i] * sigmoid(input[i + N])   for 0 <= i < N
+// Each thread handles the single element selected by its global x index.
+//
+// input  - 2 * N elements: N values followed by N gate values
+// output - N elements
+// N      - number of output elements
+template <typename TI, typename TO>
+__device__ void
+GLUFwdContiguousKernel(const TI* __restrict__ input, TO* __restrict__ output, long N)
+{
+    // N is signed while the thread index is unsigned: reject non-positive N first so
+    // the comparison below is between two unsigned values.
+    if(N <= 0)
+        return;
+
+    const size_t gid = blockIdx.x * blockDim.x + threadIdx.x;
+    // The launch grid is rounded up, so trailing threads have nothing to do.
+    if(gid >= static_cast<size_t>(N))
+        return;
+
+    const TI* inputFirstHalf  = input;
+    const TI* inputSecondHalf = input + N;
+
+    FLOAT_ACCUM val1 = CVT_FLOAT2ACCUM(inputFirstHalf[gid]);
+    FLOAT_ACCUM val2 = sigmoid(CVT_FLOAT2ACCUM(inputSecondHalf[gid]));
+    FLOAT_ACCUM val  = val1 * val2;
+    output[gid]      = CVT_ACCUM2FLOAT(val);
+}
+
+// Kernel entry point; INPUT_TYPE and OUTPUT_TYPE are supplied as build parameters.
+extern "C" __global__ void
+GLUFwdContiguous(const INPUT_TYPE* __restrict__ input, OUTPUT_TYPE* __restrict__ output, long N)
+{
+    GLUFwdContiguousKernel<INPUT_TYPE, OUTPUT_TYPE>(input, output, N);
+}
diff --git a/src/solver/glu/forward_glu.cpp b/src/solver/glu/forward_glu.cpp
new file mode 100644
index 0000000000..f03faf9caa
--- /dev/null
+++ b/src/solver/glu/forward_glu.cpp
@@ -0,0 +1,130 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include "miopen/kernel_info.hpp"
+#include "miopen/mlo_internal.hpp"
+// NOTE(review): the seven include targets below, and several template arguments in
+// this file (static_cast, std::multiplies, std::vector, CastTo), are missing
+// (angle-bracketed text appears to have been lost). Restore them from the original
+// patch before building.
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+// Work-group size along x used for the GLU forward kernel launch.
+#define LOCAL_SIZE 256
+
+namespace miopen {
+
+namespace solver {
+
+namespace glu {
+
+// Returns true only when every ProblemDescription check below passes.
+// `context` is not consulted.
+bool GLUForward::IsApplicable(const ExecutionContext& context,
+                              const miopen::glu::ProblemDescription& problem) const
+{
+    std::ignore = context;
+
+    if(!problem.IsSameType())
+        return false;
+    if(!problem.IsRightDim())
+        return false;
+    if(!problem.IsRightLength())
+        return false;
+    if(!problem.IsAllPacked())
+        return false;
+    if(!problem.IsFirstDim())
+        return false;
+    return true;
+}
+
+// Describes the single kernel "GLUFwdContiguous" from "MIOpenGLU.cpp" and the
+// invoker that launches it. `context` is not consulted.
+ConvSolution GLUForward::GetSolution(const ExecutionContext& context,
+                                     const miopen::glu::ProblemDescription& problem) const
+{
+    std::ignore = context;
+
+    auto result = ConvSolution{miopenStatusSuccess};
+
+    {
+        auto dtype = problem.GetInputDesc().GetType();
+        auto input_dtype = miopen::GetDataType(problem.GetInputDesc().GetType());
+        auto output_dtype = miopen::GetDataType(problem.GetOutputDesc().GetType());
+        auto outputDims = problem.GetOutputDesc().GetLengths();
+        // Number of output elements: product of all output lengths.
+        auto output_numel =
+            std::accumulate(outputDims.begin(), outputDims.end(), 1ULL, std::multiplies());
+
+        // One-dimensional launch sized from the output element count.
+        // NOTE(review): AlignUp presumably rounds up to a multiple of xlocalsize - confirm.
+        size_t xlocalsize = LOCAL_SIZE;
+        size_t xgridsize = AlignUp(output_numel, xlocalsize);
+        size_t ylocalsize = 1;
+        size_t ygridsize = 1;
+        size_t zlocalsize = 1;
+        size_t zgridsize = 1;
+
+        auto kernel = KernelInfo{};
+
+        kernel.kernel_file = "MIOpenGLU.cpp";
+        kernel.kernel_name = "GLUFwdContiguous";
+
+        // The MIOPEN_USE_* flags follow the input data type. A "bfloat16" type name is
+        // passed to the kernel as "ushort" for INPUT_TYPE / OUTPUT_TYPE.
+        const auto build_params = KernelBuildParameters{
+            {"MIOPEN_USE_FP16", static_cast(dtype == miopenHalf)},
+            {"MIOPEN_USE_FP32", static_cast(dtype == miopenFloat)},
+            {"MIOPEN_USE_FP64", static_cast(dtype == miopenDouble)},
+            {"MIOPEN_USE_BFP16", static_cast(dtype == miopenBFloat16)},
+            {"INPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : input_dtype},
+            {"OUTPUT_TYPE", output_dtype == "bfloat16" ? "ushort" : output_dtype}};
+
+        kernel.comp_options = build_params.GenerateFor(kbp::HIP{});
+
+        kernel.l_wk.push_back(xlocalsize);
+        kernel.l_wk.push_back(ylocalsize);
+        kernel.l_wk.push_back(zlocalsize);
+
+        kernel.g_wk.push_back(xgridsize);
+        kernel.g_wk.push_back(ygridsize);
+        kernel.g_wk.push_back(zgridsize);
+
+        result.construction_params.push_back(kernel);
+    }
+
+    // The invoker recomputes the output element count from the invoke-time output
+    // descriptor and passes (input, output, count) to the first kernel.
+    result.invoker_factory = [](const std::vector& kernels) {
+        return [=](const Handle& handle_, const AnyInvokeParams& raw_params) {
+            decltype(auto) kernel = handle_.Run(kernels.front());
+            decltype(auto) params = raw_params.CastTo();
+            auto outputDims = params.outputDesc->GetLengths();
+            auto output_numel = std::accumulate(
+                outputDims.begin(), outputDims.end(), 1ULL, std::multiplies());
+
+            kernel(params.input, params.output, output_numel);
+        };
+    };
+
+    return result;
+}
+
+} // namespace glu
+
+} // namespace solver
+
+} // namespace miopen
diff --git a/test/cpu_glu.hpp b/test/cpu_glu.hpp
new file mode 100644
index 0000000000..52b2e9abc3
--- /dev/null
+++ b/test/cpu_glu.hpp
@@ -0,0 +1,52 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to deal
+ * in the Software without restriction, including without limitation the rights
+ * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ * copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ *
+ *******************************************************************************/
+#ifndef GUARD_CPU_GLU_HPP
+#define GUARD_CPU_GLU_HPP
+
+#include "tensor_holder.hpp"
+
+#include <cmath>
+#include <cstddef>
+#include <functional>
+#include <numeric>
+
+// Logistic function 1 / (1 + e^-x); the result is converted back to T.
+template <class T>
+T sigmoid(T x)
+{
+    return static_cast<T>(1.0f / (1.0f + exp(-x)));
+}
+
+// CPU reference for the GLU forward pass:
+//   ref_output[o] = input[o] * sigmoid(input[o + output_numel])
+// where output_numel is the product of ref_output's lengths.
+//
+// input      - read only; taken by const reference so the tensor is not copied
+// ref_output - receives one value per output element
+template <class T>
+void cpu_glu_forward(const tensor<T>& input, tensor<T>& ref_output)
+{
+    const auto output_dims = ref_output.desc.GetLengths();
+
+    const auto output_numel = std::accumulate(
+        output_dims.begin(), output_dims.end(), std::size_t{1}, std::multiplies<std::size_t>());
+
+    par_ford(output_numel)([&](size_t o) {
+        T valA        = input[o];
+        T valB        = input[o + output_numel];
+        T val         = valA * sigmoid(valB);
+        ref_output[o] = val;
+    });
+}
+#endif
diff --git a/test/gtest/glu.cpp b/test/gtest/glu.cpp
new file mode 100644
index 0000000000..f489f18cee
--- /dev/null
+++ b/test/gtest/glu.cpp
@@ -0,0 +1,102 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include "glu.hpp"
+// NOTE(review): the include target below and the template arguments of the GLUTest
+// base classes are missing (angle-bracketed text appears to have been lost).
+// Restore them from the original patch before building.
+#include
+using float16 = half_float::half;
+
+MIOPEN_DECLARE_ENV_VAR_STR(MIOPEN_TEST_FLOAT_ARG)
+MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_TEST_ALL)
+
+namespace glu {
+
+// Returns the value of the MIOPEN_TEST_FLOAT_ARG environment variable, or an empty
+// string when the value read is empty.
+std::string GetFloatArg()
+{
+    const auto& tmp = miopen::GetStringEnv(ENV(MIOPEN_TEST_FLOAT_ARG));
+    if(tmp.empty())
+    {
+        return "";
+    }
+    return tmp;
+}
+
+// One fixture per data type, selected at run time by the "--float", "--fp16" and
+// "--bfp16" values checked in the tests below.
+struct GLUTestFloat : GLUTest
+{
+};
+
+struct GLUTestFP16 : GLUTest
+{
+};
+
+struct GLUTestBFP16 : GLUTest
+{
+};
+
+} // namespace glu
+using namespace glu;
+
+// Each test runs only when MIOPEN_TEST_ALL is enabled and MIOPEN_TEST_FLOAT_ARG
+// matches the fixture's data type; otherwise it is skipped.
+TEST_P(GLUTestFloat, GLUTestFw)
+{
+    if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--float"))
+    {
+        RunTest();
+        Verify();
+    }
+    else
+    {
+        GTEST_SKIP();
+    }
+};
+
+TEST_P(GLUTestFP16, GLUTestFw)
+{
+    if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--fp16"))
+    {
+        RunTest();
+        Verify();
+    }
+    else
+    {
+        GTEST_SKIP();
+    }
+};
+
+TEST_P(GLUTestBFP16, GLUTestFw)
+{
+    if(miopen::IsEnabled(ENV(MIOPEN_TEST_ALL)) && (GetFloatArg() == "--bfp16"))
+    {
+        RunTest();
+        Verify();
+    }
+    else
+    {
+        GTEST_SKIP();
+    }
+}
+
+// All three fixtures share the same list of shapes.
+INSTANTIATE_TEST_SUITE_P(GLUTestSet, GLUTestFloat, testing::ValuesIn(GLUTestConfigs()));
+INSTANTIATE_TEST_SUITE_P(GLUTestSet, GLUTestFP16, testing::ValuesIn(GLUTestConfigs()));
+INSTANTIATE_TEST_SUITE_P(GLUTestSet, GLUTestBFP16, testing::ValuesIn(GLUTestConfigs()));
diff --git a/test/gtest/glu.hpp b/test/gtest/glu.hpp
new file mode 100644
index 0000000000..ded9ee6cc4
--- /dev/null
+++ b/test/gtest/glu.hpp
@@ -0,0 +1,183 @@
+/*******************************************************************************
+ *
+ * MIT License
+ *
+ * Copyright (c) 2024 Advanced Micro Devices, Inc.
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. 
+ *
+ *******************************************************************************/
+
+#include "../driver/tensor_driver.hpp"
+#include "cpu_glu.hpp"
+#include "get_handle.hpp"
+#include "random.hpp"
+#include "tensor_holder.hpp"
+#include "verify.hpp"
+// NOTE(review): the four include targets below, and most template arguments in this
+// file (std::vector, tensor, std::numeric_limits, std::is_same, the `template`
+// head of GLUTest), are missing (angle-bracketed text appears to have been lost).
+// Restore them from the original patch before building.
+#include
+#include
+#include
+#include
+
+// One test shape. A length of 0 marks a dimension as absent (see GetInput()).
+struct GLUTestCase
+{
+    size_t N;
+    size_t C;
+    size_t D;
+    size_t H;
+    size_t W;
+    int32_t dim;
+    friend std::ostream& operator<<(std::ostream& os, const GLUTestCase& tc)
+    {
+        return os << " N:" << tc.N << " C:" << tc.C << " D:" << tc.D << " H:" << tc.H
+                  << " W:" << tc.W << " dim:" << tc.dim;
+    }
+
+    // Builds the input lengths from the non-zero fields, trying the 5-D form first:
+    // {N,C,D,H,W}, {N,C,H,W}, {N,C,W}, {N,W}, {N}. If none matches, prints an error
+    // and returns {0}.
+    std::vector GetInput()
+    {
+        if((N != 0) && (C != 0) && (D != 0) && (H != 0) && (W != 0))
+        {
+            return std::vector({N, C, D, H, W});
+        }
+        else if((N != 0) && (C != 0) && (H != 0) && (W != 0))
+        {
+            return std::vector({N, C, H, W});
+        }
+        else if((N != 0) && (C != 0) && (W != 0))
+        {
+            return std::vector({N, C, W});
+        }
+        else if((N != 0) && (W != 0))
+        {
+            return std::vector({N, W});
+        }
+        else if((N != 0))
+        {
+            return std::vector({N});
+        }
+        else
+        {
+            std::cout << "Error Input Tensor Lengths\n" << std::endl;
+            return std::vector({0});
+        }
+    }
+};
+
+// Shapes exercised by every GLU test; the trailing comments name the source model.
+std::vector GLUTestConfigs()
+{ // n c d h w dim
+    // clang-format off
+    return {
+        { 8, 120, 0, 0, 1, 0}, // bart
+        { 8, 1023, 0, 0, 1, 0}, // gpt_neo
+        { 8, 1024, 0, 0, 768, 0},
+        { 16, 1024, 0, 0, 768, 0}, // gpt2
+        { 48, 8, 0, 512, 512, 0}, // t5
+    };
+    // clang-format on
+}
+
+// Parameterized fixture: SetUp() builds the tensors, RunTest() runs the CPU
+// reference and miopen::GLUForward, Verify() compares the two outputs.
+template
+struct GLUTest : public ::testing::TestWithParam
+{
+protected:
+    void SetUp() override
+    {
+
+        auto&& handle = get_handle();
+        glu_config = GetParam();
+        auto gen_value = [](auto...) { return prng::gen_descreet_uniform_sign(1e-2, 100); };
+
+        dim = glu_config.dim;
+
+        auto in_dims = glu_config.GetInput();
+
+        input = tensor{in_dims}.generate(gen_value);
+
+        std::vector out_dims;
+
+        // The output has the input's lengths, except that axis `dim` is halved.
+        for(int i = 0; i < in_dims.size(); i++)
+        {
+            if(i != dim)
+            {
+                out_dims.push_back(in_dims[i]);
+            }
+            else
+            {
+                out_dims.push_back(in_dims[i] / 2);
+            }
+        }
+
+        // Both outputs start as NaN so that unwritten elements are detectable.
+        output = tensor{out_dims};
+        std::fill(output.begin(), output.end(), std::numeric_limits::quiet_NaN());
+
+        ref_output = tensor{out_dims};
+        std::fill(ref_output.begin(), ref_output.end(), std::numeric_limits::quiet_NaN());
+
+        input_dev = handle.Write(input.data);
+        output_dev = handle.Write(output.data);
+    }
+
+    // Computes the CPU reference, runs miopen::GLUForward, expects success and reads
+    // the device output back into `output`.
+    void RunTest()
+    {
+        auto&& handle = get_handle();
+
+        cpu_glu_forward(input, ref_output);
+        miopenStatus_t status;
+
+        status = miopen::GLUForward(
+            handle, input.desc, input_dev.get(), dim, output.desc, output_dev.get());
+
+        EXPECT_EQ(status, miopenStatusSuccess);
+
+        output.data = handle.Read(output_dev, output.data.size());
+    }
+
+    // Base tolerance for the element type under test.
+    double GetTolerance()
+    {
+        // Computation error of fp16 is ~2^13 (=8192) bigger than
+        // the one of fp32 because mantissa is shorter by 13 bits.
+        double tolerance = std::is_same::value ? 1.5e-6 : 8.2e-3;
+
+        // bf16 mantissa has 7 bits, by 3 bits shorter than fp16.
+        if(std::is_same::value)
+            tolerance *= 8.0;
+        return tolerance;
+    }
+
+    // Expects equal range lengths and an RMS error below ten times the tolerance.
+    void Verify()
+    {
+        double threshold = GetTolerance();
+        auto error = miopen::rms_range(ref_output, output);
+
+        EXPECT_TRUE(miopen::range_distance(ref_output) == miopen::range_distance(output));
+        EXPECT_TRUE(error < threshold * 10) << "Error output beyond tolerance Error:" << error
+                                            << ", Thresholdx10: " << threshold * 10;
+    }
+    GLUTestCase glu_config;
+
+    // Host-side input, device result and CPU reference.
+    tensor input;
+    tensor output;
+
+    tensor ref_output;
+
+    // Device buffers backing `input` and `output`.
+    miopen::Allocator::ManageDataPtr input_dev;
+    miopen::Allocator::ManageDataPtr output_dev;
+
+    int32_t dim;
+};