Skip to content

Commit

Permalink
fix: condense input and output types.
Browse files: browse the repository at this point in the history
  • Loading branch information
o2buzzle committed Apr 23, 2024
1 parent 7e69239 commit 98bba90
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 17 deletions.
17 changes: 8 additions & 9 deletions src/kernels/MIOpenPadConstantFwd.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -40,16 +40,16 @@ __device__ T inline get5DValueAt(
w * x_strides[4]];
}

template <typename TI, typename TO>
__device__ void padconstantfwdcontiguous(const TI* __restrict__ x,
TO* __restrict__ y,
template <typename T>
__device__ void padconstantfwdcontiguous(const T* __restrict__ x,
T* __restrict__ y,
const tensor_view_5d_t x_tv,
const tensor_view_5d_t y_tv,
const padding_5d_t padding,
const size_t output_size,
TO value)
T value)
{
TO padding_value = CVT_ACCUM2FLOAT(value);
T padding_value = CVT_ACCUM2FLOAT(value);

const uint64_t gid = threadIdx.x + blockIdx.x * blockDim.x;
if(gid >= output_size)
Expand All @@ -69,14 +69,13 @@ __device__ void padconstantfwdcontiguous(const TI* __restrict__ x,
y[gid] = flag ? get5DValueAt(x, x_tv.stride, o[0], o[1], o[2], o[3], o[4]) : padding_value;
}

extern "C" __global__ void PadConstantFwdContiguous(const INPUT_TYPE* __restrict__ x,
OUTPUT_TYPE* __restrict__ y,
extern "C" __global__ void PadConstantFwdContiguous(const DTYPE* __restrict__ x,
DTYPE* __restrict__ y,
const tensor_view_5d_t x_tv,
const tensor_view_5d_t y_tv,
const padding_5d_t padding,
const size_t output_size,
FLOAT_ACCUM value)
{
padconstantfwdcontiguous<INPUT_TYPE, OUTPUT_TYPE>(
x, y, x_tv, y_tv, padding, output_size, value);
padconstantfwdcontiguous<DTYPE>(x, y, x_tv, y_tv, padding, output_size, value);
}
14 changes: 6 additions & 8 deletions src/solver/pad_constant/pad_constant_fwd_contiguous.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -56,8 +56,7 @@ ConvSolution PadConstantFwdContiguous::GetSolution(
auto result = ConvSolution{miopenStatusSuccess};
auto ydims = problem.GetYDesc().GetLengths();

auto input_dtype = miopen::GetDataType(problem.GetXDesc().GetType());
auto output_dtype = miopen::GetDataType(problem.GetYDesc().GetType());
auto dtype = miopen::GetDataType(problem.GetXDesc().GetType());

size_t output_size = problem.GetYDesc().GetElementSize();

Expand All @@ -73,12 +72,11 @@ ConvSolution PadConstantFwdContiguous::GetSolution(

// TODO: Verify that these kernel build parameters (the element-type macro and the MIOPEN_USE_* flags) are the right way to configure this kernel
const auto build_params = KernelBuildParameters{
{"INPUT_TYPE", input_dtype == "bfloat16" ? "ushort" : input_dtype},
{"OUTPUT_TYPE", output_dtype == "bfloat16" ? "ushort" : output_dtype},
{"MIOPEN_USE_FP64", static_cast<int>(output_dtype == "double")},
{"MIOPEN_USE_FP32", static_cast<int>(output_dtype == "float")},
{"MIOPEN_USE_FP16", static_cast<int>(output_dtype == "half")},
{"MIOPEN_USE_BFP16", static_cast<int>(output_dtype == "bfloat16")},
{"DTYPE", dtype == "bfloat16" ? "ushort" : dtype},
{"MIOPEN_USE_FP64", static_cast<int>(dtype == "double")},
{"MIOPEN_USE_FP32", static_cast<int>(dtype == "float")},
{"MIOPEN_USE_FP16", static_cast<int>(dtype == "half")},
{"MIOPEN_USE_BFP16", static_cast<int>(dtype == "bfloat16")},
};

kernel.comp_options = build_params.GenerateFor(kbp::HIP{});
Expand Down

0 comments on commit 98bba90

Please sign in to comment.