Merge all bn ck nchw branches (#3332)
bghimireamd authored Oct 27, 2024
1 parent 15bb826 commit 7bef289
Showing 12 changed files with 359 additions and 301 deletions.
5 changes: 3 additions & 2 deletions driver/bn_driver.hpp
@@ -60,7 +60,7 @@
 #define ERRTOL_FP32 1e-4
 #define ERRTOL_FP16 0.5e-3
 #define RMSTOL_FP32 1e-4
-#define RMSTOL_FP16 0.5e-3
+#define RMSTOL_FP16 2e-3

 #define MIO_DRIVER_BN_REFERENCE_COMPUTE_3D_AS_2D 1 // Resolves issue #1974

@@ -1298,7 +1298,8 @@ int BatchNormDriver<Tgpu, Tref, Tmix>::VerifyForward()

     out.CopyFromDeviceToHost(GetStream());

-    maxval = static_cast<Tref>(0.0);
+    maxval = static_cast<Tref>(0.0);
+
     auto errorOut = miopen::rms_range(out_ref.data, out.GetVector());
     if(!std::isfinite(errorOut) || errorOut > maxrms)
     {
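
Note on the tolerance bump above: the driver verifies GPU output against the host reference with an RMS-style error, so relaxing RMSTOL_FP16 from 0.5e-3 to 2e-3 directly loosens the pass bound for half precision. A minimal sketch of such a check, assuming the error is the RMS of the elementwise difference normalized by the reference magnitude (miopen::rms_range's exact normalization may differ):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Illustrative only: compare an RMS-style error against a per-type
    // tolerance such as RMSTOL_FP16 (now 2e-3).
    template <typename T>
    bool WithinRmsTol(const std::vector<T>& ref, const std::vector<T>& out, double tol)
    {
        double sq_sum = 0.0, max_mag = 0.0;
        for(std::size_t i = 0; i < ref.size(); ++i)
        {
            const double d = double(ref[i]) - double(out[i]);
            sq_sum += d * d;
            max_mag = std::max(max_mag, std::abs(double(ref[i])));
        }
        const double rms = std::sqrt(sq_sum / double(ref.size())) / std::max(max_mag, 1e-30);
        return std::isfinite(rms) && rms <= tol;
    }
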
20 changes: 15 additions & 5 deletions src/include/miopen/batchnorm/problem_description.hpp
@@ -179,11 +179,7 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase,
         return scaleDesc;
     }

-    const TensorDescriptor& GetScaleBiasDiffDesc() const
-    {
-        assert(direction == Direction::Backward);
-        return scaleDesc;
-    }
+    const TensorDescriptor& GetScaleBiasDiffDesc() const { return scaleDesc; }

     bool GetResultSave() const
     {
@@ -217,6 +213,20 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase,
                    : ((in_layout == "NDHWC") && (out_layout == "NDHWC"));
     }

+    bool IsLayoutNCHW() const
+    {
+        if(direction == Direction::Backward)
+        {
+            return xDesc.GetLengths().size() == 4
+                       ? ((in_layout == "NCHW") && (out_layout == "NCHW") && (din_layout == "NCHW"))
+                       : ((in_layout == "NCDHW") && (out_layout == "NCDHW") &&
+                          (din_layout == "NCDHW"));
+        }
+
+        return xDesc.GetLengths().size() == 4 ? ((in_layout == "NCHW") && (out_layout == "NCHW"))
+                                              : ((in_layout == "NCDHW") && (out_layout == "NCDHW"));
+    }
+
     bool Is2D() const { return xDesc.GetLengths().size() == 4; }
     bool Is3D() const { return xDesc.GetLengths().size() == 5; }
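
The new predicate mirrors the existing IsLayoutNHWC: plain string comparison on the tensor layouts, with the backward direction additionally constraining din_layout. A standalone sketch of the forward case, illustrative rather than the MIOpen implementation:

    #include <cstddef>
    #include <string>

    // 4-D problems use "NCHW"; 5-D problems use the depth variant "NCDHW".
    bool IsNCHWLike(std::size_t rank, const std::string& in, const std::string& out)
    {
        return rank == 4 ? (in == "NCHW" && out == "NCHW")
                         : (in == "NCDHW" && out == "NCDHW");
    }
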
4 changes: 2 additions & 2 deletions src/ocl/batchnormocl.cpp
@@ -250,8 +250,8 @@ void BatchNormForwardInference(Handle& handle,
     }();

     const auto algo = AlgorithmName{"miopenBatchNormalizationForwardInference"};
-    const auto solvers = solver::SolverContainer<solver::batchnorm::BnFwdInference,
-                                                 solver::batchnorm::BnCKFwdInference>{};
+    const auto solvers = solver::SolverContainer<solver::batchnorm::BnCKFwdInference,
+                                                 solver::batchnorm::BnFwdInference>{};

     solvers.ExecutePrimitive(handle, problem, algo, invoke_params);
 }
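
The reorder matters because the container is consulted in declaration order: with BnCKFwdInference listed first, the CK kernel is preferred whenever its IsApplicable check passes, and BnFwdInference remains the fallback. A sketch of that first-applicable-wins policy, assumed from this call site; the types are illustrative:

    #include <vector>

    struct SolverLike
    {
        virtual ~SolverLike()             = default;
        virtual bool IsApplicable() const = 0;
        virtual void Run() const          = 0;
    };

    // The first solver whose IsApplicable() passes wins, so putting the CK
    // solver ahead of the generic one changes which kernel actually runs.
    void ExecuteFirstApplicable(const std::vector<const SolverLike*>& ordered)
    {
        for(const auto* s : ordered)
            if(s->IsApplicable())
            {
                s->Run();
                return;
            }
    }
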
49 changes: 30 additions & 19 deletions src/solver/batchnorm/backward_ck.cpp
@@ -87,7 +87,21 @@ struct CKArgsBNormBwd

         // prep for CK
         std::sort(in_strides.begin(), in_strides.end(), std::greater<>());
-        std::rotate(lens.begin() + 1, lens.begin() + 2, lens.end());
+
+        if(problem.IsLayoutNHWC())
+        {
+            std::rotate(lens.begin() + 1, lens.begin() + 2, lens.end());
+            reduceDims = {0, 1, 2};
+        }
+        else if(problem.IsLayoutNCHW())
+        {
+            reduceDims = {0, 2, 3};
+        }
+        else
+        {
+            MIOPEN_THROW(miopenStatusInternalError,
+                         "BnCKBwd operation does not support this data layout");
+        }
     }

     CKArgsBNormBwd(const CKArgsBNormBwd&) = default;
@@ -133,7 +147,7 @@ struct CKArgsBNormBwd
     std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarStrides;

     double epsilon = 1e-5;
-    std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};
+    std::array<int, NumBatchNormReduceDim> reduceDims;
 };

 template <typename XDataType,
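
A worked example of the dimension bookkeeping above, with illustrative sizes: NHWC problems arrive with lengths {N, C, H, W}, and rotating element 1 to the back yields {N, H, W, C}, so reducing dimensions {0, 1, 2} sums over N, H, and W and leaves the trailing channel dimension; NCHW lengths are left alone, and reducing {0, 2, 3} leaves dimension 1 (C):

    #include <algorithm>
    #include <array>
    #include <cstdio>

    int main()
    {
        std::array<int, 4> lens{2, 64, 28, 28}; // N, C, H, W
        // Same rotate as above: move C to the back for the NHWC case.
        std::rotate(lens.begin() + 1, lens.begin() + 2, lens.end());
        std::printf("%d %d %d %d\n", lens[0], lens[1], lens[2], lens[3]); // 2 28 28 64
    }
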
@@ -345,14 +359,20 @@ bool BnCKBwdBackward::IsApplicable(
 #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
     if(env::disabled(MIOPEN_DEBUG_CK_BN_BACK))
         return false;
-    if(!bn_problem.IsLayoutNHWC())
+    if(!bn_problem.IsLayoutNHWC() && !bn_problem.IsLayoutNCHW())
         return false;
     if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
         return false;
     if(!bn_problem.Is2D())
         return false;
     if(bn_problem.GetDirection() != miopen::batchnorm::Direction::Backward)
         return false;
+    if(bn_problem.GetXDesc().GetType() != bn_problem.GetScaleBiasDiffDesc().GetType())
+        return false;
+    if(bn_problem.GetMode() != miopenBNSpatial)
+        return false;
+    if(!bn_problem.Is2D())
+        return false;

     switch(bn_problem.GetXDesc().GetType())
     {
@@ -376,24 +396,15 @@ ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription&
                                    InvokerFactoryMakerNHWC&& invoker_factory_maker_nhwc)
 {
 #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
-    if(problem.IsLayoutNHWC())
-    {
-        switch(problem.GetXDesc().GetType())
-        {
-        case miopenFloat: return invoker_factory_maker_nhwc(F32{});
-        case miopenDouble: return invoker_factory_maker_nhwc(F64{});
-        case miopenHalf: return invoker_factory_maker_nhwc(F16{});
-        case miopenBFloat16: return invoker_factory_maker_nhwc(BF16{});
-        default:
-            MIOPEN_THROW(miopenStatusInternalError,
-                         "BnCKBwdBackward operation does not support this data type");
-        }
-    }
-    // Todo: problem.IsLayoutDefault()
-    else
+    switch(problem.GetXDesc().GetType())
     {
+    case miopenFloat: return invoker_factory_maker_nhwc(F32{});
+    case miopenDouble: return invoker_factory_maker_nhwc(F64{});
+    case miopenHalf: return invoker_factory_maker_nhwc(F16{});
+    case miopenBFloat16: return invoker_factory_maker_nhwc(BF16{});
+    default:
         MIOPEN_THROW(miopenStatusInternalError,
-                     "BnCKBwdBackward operation does not support this data layout");
+                     "BnCKBwdBackward operation does not support this data type");
     }
 #else
     return {};
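
With unsupported layouts now rejected in IsApplicable (or by the constructor's MIOPEN_THROW), the factory no longer branches on layout and only dispatches on element type. A minimal sketch of the remaining pattern; the tag types stand in for the F16/F32/F64/BF16 aliases used above, and the maker must return one common type:

    #include <stdexcept>

    enum class DataType
    {
        Half,
        Float,
        Double,
        BFloat16
    };
    struct F16 {};
    struct F32 {};
    struct F64 {};
    struct BF16 {};

    // Layout is assumed pre-validated; only the element type is dispatched.
    template <typename Maker>
    auto DispatchOnType(DataType t, Maker&& make)
    {
        switch(t)
        {
        case DataType::Float: return make(F32{});
        case DataType::Double: return make(F64{});
        case DataType::Half: return make(F16{});
        case DataType::BFloat16: return make(BF16{});
        }
        throw std::runtime_error("unsupported data type");
    }
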
64 changes: 39 additions & 25 deletions src/solver/batchnorm/forward_inference_ck.cpp
@@ -69,6 +69,7 @@ struct CKArgsBNormFwd
 {
     CKArgsBNormFwd(const miopen::batchnorm::ProblemDescription& problem)
     {
+
         std::copy(problem.GetXDesc().GetLengths().begin(),
                   problem.GetXDesc().GetLengths().end(),
                   xyLengths.begin());
@@ -78,21 +79,39 @@
                   xyStrides.begin());
         // prep for CK
         std::sort(xyStrides.begin(), xyStrides.end(), std::greater<>());
-        std::rotate(xyLengths.begin() + 1, xyLengths.begin() + 2, xyLengths.end());

-        aligned_scaleBiasMeanVarStrides[0] = 0;
-        aligned_scaleBiasMeanVarStrides[1] = 0;
-        aligned_scaleBiasMeanVarStrides[2] = 0;
-        aligned_scaleBiasMeanVarStrides[3] = 1;
+        if(problem.IsLayoutNHWC())
+        {
+            std::rotate(xyLengths.begin() + 1, xyLengths.begin() + 2, xyLengths.end());
+            reduceDims = {0, 1, 2};
+            aligned_scaleBiasMeanVarStrides[0] = 0;
+            aligned_scaleBiasMeanVarStrides[1] = 0;
+            aligned_scaleBiasMeanVarStrides[2] = 0;
+            aligned_scaleBiasMeanVarStrides[3] = 1;
+        }
+        else if(problem.IsLayoutNCHW())
+        {
+            reduceDims = {0, 2, 3};
+            aligned_scaleBiasMeanVarStrides[0] = 0;
+            aligned_scaleBiasMeanVarStrides[1] = 1;
+            aligned_scaleBiasMeanVarStrides[2] = 0;
+            aligned_scaleBiasMeanVarStrides[3] = 0;
+        }
+        else
+        {
+            MIOPEN_THROW(miopenStatusInternalError,
+                         "BnCKFwdInference operation does not support this data layout");
+        }
     }

     std::array<ck::index_t, Rank> xyLengths;
     std::array<ck::index_t, Rank> xyStrides;
     std::vector<int> invariantDims;

     std::array<index_t, Rank> aligned_scaleBiasMeanVarStrides{3};
     std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarStrides;

-    std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};
+    std::array<int, NumBatchNormReduceDim> reduceDims;

     template <typename InvokerPtr, typename InvokerParams>
     auto MakeArgPtr(const InvokerPtr& invoker_ptr, const InvokerParams& data_ctx) const
@@ -305,14 +324,18 @@ bool BnCKFwdInference::IsApplicable(
 #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
     if(env::disabled(MIOPEN_DEBUG_CK_BN_INFER))
         return false;
-    if(!bn_problem.IsLayoutNHWC())
+    if(!bn_problem.IsLayoutNHWC() && !bn_problem.IsLayoutNCHW())
         return false;
     if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
         return false;
     if(!bn_problem.Is2D())
         return false;
     if(bn_problem.GetDirection() != miopen::batchnorm::Direction::ForwardInference)
         return false;
+    if(bn_problem.GetMode() != miopenBNSpatial)
+        return false;
+    if(bn_problem.GetXDesc().GetType() != bn_problem.GetScaleBiasDiffDesc().GetType())
+        return false;

     switch(bn_problem.GetXDesc().GetType())
     {
@@ -330,29 +353,20 @@
     return false;
 }

-template <typename InvokerFactoryMakerNHWC>
+template <typename InvokerFactoryMaker>
 ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& problem,
-                                   InvokerFactoryMakerNHWC&& invoker_factory_maker_nhwc)
+                                   InvokerFactoryMaker&& invoker_factory_maker)
 {
 #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
-    if(problem.IsLayoutNHWC())
-    {
-        switch(problem.GetXDesc().GetType())
-        {
-        case miopenFloat: return invoker_factory_maker_nhwc(F32{});
-        case miopenDouble: return invoker_factory_maker_nhwc(F64{});
-        case miopenHalf: return invoker_factory_maker_nhwc(F16{});
-        case miopenBFloat16: return invoker_factory_maker_nhwc(BF16{});
-        default:
-            MIOPEN_THROW(miopenStatusInternalError,
-                         "BnCKFwdInference operation does not support this data type");
-        }
-    }
-    // Todo: problem.IsLayoutDefault()
-    else
+    switch(problem.GetXDesc().GetType())
     {
+    case miopenFloat: return invoker_factory_maker(F32{});
+    case miopenDouble: return invoker_factory_maker(F64{});
+    case miopenHalf: return invoker_factory_maker(F16{});
+    case miopenBFloat16: return invoker_factory_maker(BF16{});
+    default:
         MIOPEN_THROW(miopenStatusInternalError,
-                     "BnCKFwdInference operation does not support this data layout");
+                     "BnCKFwdInference operation does not support this data type");
     }
 #else
     return {};
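
The zero entries written to aligned_scaleBiasMeanVarStrides in the constructor above implement broadcasting: the per-channel scale/bias/mean/variance vectors keep a nonzero stride only on the channel dimension, so one value per channel is reused at every other coordinate. A purely illustrative indexing sketch:

    #include <array>

    // NCHW case from above: strides {0, 1, 0, 0} make the offset depend on c only.
    long BroadcastOffset(const std::array<long, 4>& strides, long i0, long i1, long i2, long i3)
    {
        return i0 * strides[0] + i1 * strides[1] + i2 * strides[2] + i3 * strides[3];
    }
    // BroadcastOffset({0, 1, 0, 0}, n, c, h, w) == c for any n, h, w (NCHW);
    // BroadcastOffset({0, 0, 0, 1}, n, h, w, c) == c likewise for NHWC.
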
47 changes: 28 additions & 19 deletions src/solver/batchnorm/forward_training_ck.cpp
@@ -86,7 +86,21 @@ struct CKArgsBNormFwdTraining

         // prep for CK
         std::sort(xyStrides.begin(), xyStrides.end(), std::greater<>());
-        std::rotate(xyLengths.begin() + 1, xyLengths.begin() + 2, xyLengths.end());
+
+        if(problem.IsLayoutNHWC())
+        {
+            std::rotate(xyLengths.begin() + 1, xyLengths.begin() + 2, xyLengths.end());
+            reduceDims = {0, 1, 2};
+        }
+        else if(problem.IsLayoutNCHW())
+        {
+            reduceDims = {0, 2, 3};
+        }
+        else
+        {
+            MIOPEN_THROW(miopenStatusInternalError,
+                         "BnCKFwdTraining operation does not support this data layout");
+        }
     }

     CKArgsBNormFwdTraining(const CKArgsBNormFwdTraining&) = default;
@@ -131,7 +145,7 @@ struct CKArgsBNormFwdTraining
     std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarLengths;
     std::array<index_t, Rank - NumBatchNormReduceDim> arrScaleBiasMeanVarStrides;

-    std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};
+    std::array<int, NumBatchNormReduceDim> reduceDims;
 };

 template <typename XDataType,
@@ -337,14 +351,18 @@ bool BnCKFwdTraining::IsApplicable(
 #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
     if(env::disabled(MIOPEN_DEBUG_CK_BN_FWD_TRAINING))
         return false;
-    if(!bn_problem.IsLayoutNHWC())
+    if(!bn_problem.IsLayoutNHWC() && !bn_problem.IsLayoutNCHW())
         return false;
     if(!ck_utility::is_ck_supported_hardware(context.GetStream()))
         return false;
     if(!bn_problem.Is2D())
         return false;
     if(bn_problem.GetDirection() != miopen::batchnorm::Direction::ForwardTraining)
         return false;
+    if(bn_problem.GetMode() != miopenBNSpatial)
+        return false;
+    if(bn_problem.GetXDesc().GetType() != bn_problem.GetScaleBiasDiffDesc().GetType())
+        return false;

     switch(bn_problem.GetXDesc().GetType())
     {
Expand All @@ -367,24 +385,15 @@ ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription&
InvokerFactoryMakerNHWC&& invoker_factory_maker_nhwc)
{
#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL
if(problem.IsLayoutNHWC())
{
switch(problem.GetXDesc().GetType())
{
case miopenFloat: return invoker_factory_maker_nhwc(F32{});
case miopenDouble: return invoker_factory_maker_nhwc(F64{});
case miopenHalf: return invoker_factory_maker_nhwc(F16{});
case miopenBFloat16: return invoker_factory_maker_nhwc(BF16{});
default:
MIOPEN_THROW(miopenStatusInternalError,
"BnCKFwdTraining operation does not support this data type");
}
}
// Todo: problem.IsLayoutDefault()
else
switch(problem.GetXDesc().GetType())
{
case miopenFloat: return invoker_factory_maker_nhwc(F32{});
case miopenDouble: return invoker_factory_maker_nhwc(F64{});
case miopenHalf: return invoker_factory_maker_nhwc(F16{});
case miopenBFloat16: return invoker_factory_maker_nhwc(BF16{});
default:
MIOPEN_THROW(miopenStatusInternalError,
"BnCKFwdTraining operation does not support this data layout");
"BnCKFwdTraining operation does not support this data type");
}
#else
return {};
Expand Down
36 changes: 11 additions & 25 deletions test/bn_3d_spatial_test.cpp
@@ -1203,6 +1203,8 @@ struct batch_norm_3d_spatial_driver : test_driver
     batch_norm_3d_spatial_driver()
    {
         this->batch_factor = 4;
+        this->tolerance =
+            4e-3 / std::numeric_limits<T>::epsilon(); // ck solver has tolerance of 4e-3
         add(input,
             "input",
             get_3d_bn_spatial_input_tensor(
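
The division by epsilon is deliberate: the test harness scales tolerance by the element type's machine epsilon when it forms the final comparison threshold (an assumption consistent with the inline comment), so 4e-3 / epsilon pins the effective bound at the CK solver's 4e-3 for every T. A small arithmetic check under that assumption:

    #include <cstdio>
    #include <limits>

    template <typename T>
    double EffectiveThreshold()
    {
        const double tolerance = 4e-3 / std::numeric_limits<T>::epsilon();
        return tolerance * std::numeric_limits<T>::epsilon(); // harness-style scaling
    }

    int main()
    {
        std::printf("%g %g\n", EffectiveThreshold<float>(), EffectiveThreshold<double>()); // 0.004 0.004
    }
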
@@ -1233,34 +1235,18 @@
         miopen::DeriveBNTensorDescriptor(derivedBnDesc, input.desc, miopenBNSpatial);
         std::tie(ssn, ssc, ssd, ssh, ssw) = miopen::tien<5>(derivedBnDesc.GetLengths());

-        if(input.desc.GetType() == miopenFloat)
-        {
-            scale =
-                tensor<PREC_TYPE>{ssn, ssc, ssd, ssh, ssw}.generate(tensor_elem_gen_integer{17});
-            shift =
-                tensor<PREC_TYPE>{ssn, ssc, ssd, ssh, ssw}.generate(tensor_elem_gen_integer{17});
+        scale = tensor<PREC_TYPE>{ssn, ssc, ssd, ssh, ssw};
+        shift = tensor<PREC_TYPE>{ssn, ssc, ssd, ssh, ssw};
+        const double Data_scale = 1e-4;

-            if(d * h * w < 3072)
-            {
-                std::cout << "Choosing smaller input values for low dims" << std::endl;
-                input = tensor<T>{n, c, d, h, w}.generate(tensor_elem_gen_integer{7});
-            }
+        for(std::size_t i = 0; i < scale.desc.GetElementSize(); i++)
+        {
+            scale[i] = prng::gen_descreet_uniform_sign<PREC_TYPE>(Data_scale, 100);
+            shift[i] = prng::gen_descreet_uniform_sign<PREC_TYPE>(Data_scale, 100);
         }
-        else
+        for(std::size_t i = 0; i < input.desc.GetElementSize(); i++)
         {
-            scale = tensor<PREC_TYPE>{ssn, ssc, ssd, ssh, ssw};
-            shift = tensor<PREC_TYPE>{ssn, ssc, ssd, ssh, ssw};
-
-            const double Data_scale = 1e-4;
-            for(std::size_t i = 0; i < scale.desc.GetElementSize(); i++)
-            {
-                scale[i] = prng::gen_descreet_uniform_sign<PREC_TYPE>(Data_scale, 100);
-                shift[i] = prng::gen_descreet_uniform_sign<PREC_TYPE>(Data_scale, 100);
-            }
-            for(std::size_t i = 0; i < input.desc.GetElementSize(); i++)
-            {
-                input[i] = prng::gen_descreet_uniform_sign<T>(1e-5, 100);
-            }
+            input[i] = prng::gen_descreet_uniform_sign<T>(1e-5, 100);
         }

         // train
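
The unified generation path keeps every operand small (scale and shift near 1e-4, inputs near 1e-5) so intermediate accumulations stay well inside half-precision range for both solver families. A hypothetical stand-in for prng::gen_descreet_uniform_sign, shown only to make the scheme concrete:

    #include <random>

    // Draw an integer magnitude in [1, k], scale it down, attach a random sign.
    template <typename T>
    T GenSmallSigned(double scale, int k, std::mt19937& rng)
    {
        std::uniform_int_distribution<int> mag(1, k);
        std::bernoulli_distribution neg(0.5);
        const double v = scale * mag(rng);
        return static_cast<T>(neg(rng) ? -v : v);
    }
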