diff --git a/src/kernels/MIOpenPadConstantFwd.cpp b/src/kernels/MIOpenPadConstantFwd.cpp index 7cc7c8e753..c7b964950b 100644 --- a/src/kernels/MIOpenPadConstantFwd.cpp +++ b/src/kernels/MIOpenPadConstantFwd.cpp @@ -40,16 +40,16 @@ __device__ T inline get5DValueAt( w * x_strides[4]]; } -template -__device__ void padconstantfwdcontiguous(const TI* __restrict__ x, - TO* __restrict__ y, +template +__device__ void padconstantfwdcontiguous(const T* __restrict__ x, + T* __restrict__ y, const tensor_view_5d_t x_tv, const tensor_view_5d_t y_tv, const padding_5d_t padding, const size_t output_size, - TO value) + T value) { - TO padding_value = CVT_ACCUM2FLOAT(value); + T padding_value = CVT_ACCUM2FLOAT(value); const uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x; if(gid >= output_size) @@ -69,14 +69,13 @@ __device__ void padconstantfwdcontiguous(const TI* __restrict__ x, y[gid] = flag ? get5DValueAt(x, x_tv.stride, o[0], o[1], o[2], o[3], o[4]) : padding_value; } -extern "C" __global__ void PadConstantFwdContiguous(const INPUT_TYPE* __restrict__ x, - OUTPUT_TYPE* __restrict__ y, +extern "C" __global__ void PadConstantFwdContiguous(const DTYPE* __restrict__ x, + DTYPE* __restrict__ y, const tensor_view_5d_t x_tv, const tensor_view_5d_t y_tv, const padding_5d_t padding, const size_t output_size, FLOAT_ACCUM value) { - padconstantfwdcontiguous( - x, y, x_tv, y_tv, padding, output_size, value); + padconstantfwdcontiguous(x, y, x_tv, y_tv, padding, output_size, value); } diff --git a/src/solver/pad_constant/pad_constant_fwd_contiguous.cpp b/src/solver/pad_constant/pad_constant_fwd_contiguous.cpp index d05b55b75f..d133be369f 100644 --- a/src/solver/pad_constant/pad_constant_fwd_contiguous.cpp +++ b/src/solver/pad_constant/pad_constant_fwd_contiguous.cpp @@ -56,8 +56,7 @@ ConvSolution PadConstantFwdContiguous::GetSolution( auto result = ConvSolution{miopenStatusSuccess}; auto ydims = problem.GetYDesc().GetLengths(); - auto input_dtype = miopen::GetDataType(problem.GetXDesc().GetType()); - auto output_dtype = miopen::GetDataType(problem.GetYDesc().GetType()); + auto dtype = miopen::GetDataType(problem.GetXDesc().GetType()); size_t output_size = problem.GetYDesc().GetElementSize(); @@ -73,12 +72,11 @@ ConvSolution PadConstantFwdContiguous::GetSolution( // TODO: Actually understand how to use this properly const auto build_params = KernelBuildParameters{ - {"INPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : input_dtype}, - {"OUTPUT_TYPE", output_dtype == "bfloat16" ? "ushort" : output_dtype}, - {"MIOPEN_USE_FP64", static_cast(output_dtype == "double")}, - {"MIOPEN_USE_FP32", static_cast(output_dtype == "float")}, - {"MIOPEN_USE_FP16", static_cast(output_dtype == "half")}, - {"MIOPEN_USE_BFP16", static_cast(output_dtype == "bfloat16")}, + {"DTYPE", dtype == "bfloat16" ? "ushort" : dtype}, + {"MIOPEN_USE_FP64", static_cast(dtype == "double")}, + {"MIOPEN_USE_FP32", static_cast(dtype == "float")}, + {"MIOPEN_USE_FP16", static_cast(dtype == "half")}, + {"MIOPEN_USE_BFP16", static_cast(dtype == "bfloat16")}, }; kernel.comp_options = build_params.GenerateFor(kbp::HIP{});