Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NHWC] Enable batch norm by refactoring OCL kernel #1244

Merged
merged 13 commits into from
Dec 23, 2021
28 changes: 25 additions & 3 deletions src/include/miopen/batchnorm/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ struct ProblemDescription
resultsave(resultsave_),
resultrunning(resultrunning_)
{
in_layout = xDesc.GetLayout(xDesc.GetLengths().size() == 4 ? "NCHW" : "NCDHW");
out_layout = yOrDyDesc.GetLayout(yOrDyDesc.GetLengths().size() == 4 ? "NCHW" : "NCDHW");
}

// Forward
Expand Down Expand Up @@ -100,6 +102,9 @@ struct ProblemDescription
epsilon(epsilon_),
useSaved(useSaved_)
{
in_layout = xDesc.GetLayout(xDesc.GetLengths().size() == 4 ? "NCHW" : "NCDHW");
out_layout = yOrDyDesc.GetLayout(yOrDyDesc.GetLengths().size() == 4 ? "NCHW" : "NCDHW");
din_layout = dxDesc.GetLayout(dxDesc.GetLengths().size() == 4 ? "NCHW" : "NCDHW");
}

Direction GetDirection() const { return direction; }
Expand Down Expand Up @@ -154,6 +159,20 @@ struct ProblemDescription
return useSaved;
}

bool IsLayoutNHWC() const
{
if(direction == Direction::Backward)
{
return xDesc.GetLengths().size() == 4
? ((in_layout == "NHWC") && (out_layout == "NHWC") && (din_layout == "NHWC"))
: ((in_layout == "NDHWC") && (out_layout == "NDHWC") &&
(din_layout == "NDHWC"));
}

return xDesc.GetLengths().size() == 4 ? ((in_layout == "NHWC") && (out_layout == "NHWC"))
: ((in_layout == "NDHWC") && (out_layout == "NDHWC"));
}

NetworkConfig MakeNetworkConfig() const;

void Serialize(std::ostream& stream) const;
Expand All @@ -173,9 +192,12 @@ struct ProblemDescription
TensorDescriptor scaleBiasDesc;
double expAvgFactor = 0;
double epsilon;
bool resultsave = false;
bool resultrunning = false;
bool useSaved = false;
bool resultsave = false;
bool resultrunning = false;
bool useSaved = false;
std::string in_layout = "NCHW";
std::string out_layout = "NCHW";
std::string din_layout = "NCHW";

NetworkConfig MakeForwardTrainingNetworkConfig() const;
NetworkConfig MakeForwardInferenceNetworkConfig() const;
Expand Down
105 changes: 87 additions & 18 deletions src/kernels/MIOpenBatchNormBwdSpatial.cl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@
#include "batchnorm_functions.h"
#include "reduction_functions.h"

#ifndef MIO_LAYOUT_NHWC
#define MIO_LAYOUT_NHWC 0
#endif

#if (MIO_LAYOUT_NHWC != 0) && (MIO_LAYOUT_NHWC != 1)
#error "MIO_LAYOUT_NHWC must be 0 or 1"
#endif

#if(MIO_BN_VARIANT == 0)

#define MIO_BN_SEGTMP_1 (MIO_BN_GRP0 / MIO_BN_HW)
Expand Down Expand Up @@ -144,7 +152,7 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
{
variance = 0;
}
invVariance = rsqrt(variance + epsilon);
invVariance = rsqrt(variance + epsilon);
#endif // end -- Recalc mean and variance
//-------------------------------------------

Expand Down Expand Up @@ -229,9 +237,15 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,

#elif(MIO_BN_VARIANT == 1)

#if MIO_LAYOUT_NHWC
#define MIO_MAX_READ 1
#define RD_BLK 1
#define GRPRD (MIO_BN_GRP0 * RD_BLK)
#else
#define MIO_MAX_READ 2
#define RD_BLK 1
#define GRPRD (MIO_BN_GRP0 * RD_BLK * 4)
#endif
#define MIO_BN_REM4 (MIO_BN_NHW - ((MIO_BN_NHW / GRPRD) * GRPRD))
#define MIO_BN_LESS4 (MIO_BN_NHW - MIO_BN_REM4)
#define MIO_BN_CHUNK4 (MIO_MAX_READ * GRPRD)
Expand Down Expand Up @@ -278,7 +292,9 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
unsigned int index = 0;
unsigned int lid = get_local_id(0);
unsigned int grpid = get_group_id(0);
#if !MIO_LAYOUT_NHWC
unsigned int chwid = grpid * MIO_BN_HW;
#endif
unsigned int nidx = 0;
unsigned int hwidx = 0;

Expand All @@ -295,7 +311,7 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
#if(MIO_BN_USESAVED == 0)
//==== CALC MEAN and VARIANCE ONCE AGAIN =======================
_FLOAT_PREC variance = (_FLOAT_PREC)0.;
#if(MIO_BN_HW >= 4096)
#if !MIO_LAYOUT_NHWC && MIO_BN_HW >= 4096
_FLOAT4 read4;
#if(MIO_BN_N > MIO_BN_LOOP_UNROLL_MAXN)
__attribute__((opencl_unroll_hint(4))) for(unsigned int k = lid << 2; k < MIO_BN_LESS4;
Expand Down Expand Up @@ -350,7 +366,11 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
{
nidx = k / MIO_BN_HW;
hwidx = k - (nidx * MIO_BN_HW);
#if MIO_LAYOUT_NHWC
index = nidx * MIO_BN_CHW + hwidx * MIO_BN_C + grpid;
#else
index = nidx * MIO_BN_CHW + chwid + hwidx;
#endif
_FLOAT_PREC in = (_FLOAT_PREC)(*(x_in + index));
mean += in;
variance = mad(in, in, variance);
Expand All @@ -361,7 +381,11 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
unsigned int remkey = lid + MIO_BN_LESS;
nidx = remkey / MIO_BN_HW;
hwidx = remkey - (nidx * MIO_BN_HW);
#if MIO_LAYOUT_NHWC
index = nidx * MIO_BN_CHW + hwidx * MIO_BN_C + grpid;
#else
index = nidx * MIO_BN_CHW + chwid + hwidx;
#endif
_FLOAT_PREC in = (index < MIO_BN_NCHW) ? (_FLOAT_PREC)(*(x_in + index)) : (_FLOAT_PREC)0.;
mean += in;
variance = mad(in, in, variance);
Expand Down Expand Up @@ -396,19 +420,35 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,

#endif

#if MIO_LAYOUT_NHWC
_FLOAT dyRead;
_FLOAT xread;
_FLOAT_PREC xhat_tmp;
#else
_FLOAT4 dyRead4;
_FLOAT4 xread4;
_FLOAT_PREC4 xhat4;
#endif
#if(MIO_BN_N > MIO_BN_LOOP_UNROLL_MAXN)
__attribute__((opencl_unroll_hint(4))) for(unsigned int k = lid << 2; k < MIO_BN_LESS4;
__attribute__((opencl_unroll_hint(4))) for(unsigned int k = lid << 2*(1 - MIO_LAYOUT_NHWC);
k < MIO_BN_LESS4;
k += GRPRD)
#else
__attribute__((opencl_unroll_hint(2))) for(unsigned int k = lid << 2; k < MIO_BN_LESS4;
__attribute__((opencl_unroll_hint(2))) for(unsigned int k = lid << 2*(1 - MIO_LAYOUT_NHWC);
k < MIO_BN_LESS4;
k += GRPRD)
#endif
{
nidx = k / MIO_BN_HW;
hwidx = k - (nidx * MIO_BN_HW);
nidx = k / MIO_BN_HW;
hwidx = k - (nidx * MIO_BN_HW);
#if MIO_LAYOUT_NHWC
index = nidx * MIO_BN_CHW + hwidx * MIO_BN_C + grpid;
xread = *((const global _FLOAT*)(x_in + index));
dyRead = *((const global _FLOAT*)(dy_in + index));
xhat_tmp = ((_FLOAT_PREC)xread - mean) * invVariance;
db += (_FLOAT_PREC)dyRead;
ds = mad(xhat_tmp, (_FLOAT_PREC)dyRead, ds);
#else
index = nidx * MIO_BN_CHW + chwid + hwidx;
xread4 = *((const global _FLOAT4*)(x_in + index));
dyRead4 = *((const global _FLOAT4*)(dy_in + index));
Expand All @@ -424,13 +464,25 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
ds = mad(xhat4.y, (_FLOAT_PREC)dyRead4.y, ds);
ds = mad(xhat4.z, (_FLOAT_PREC)dyRead4.z, ds);
ds = mad(xhat4.w, (_FLOAT_PREC)dyRead4.w, ds);
#endif
}

#if(MIO_BN_REM4)
unsigned int remkey = (lid << 2) + MIO_BN_LESS4;
nidx = remkey / MIO_BN_HW;
hwidx = remkey - (nidx * MIO_BN_HW);
index = nidx * MIO_BN_CHW + chwid + hwidx;
unsigned int remkey = (lid << 2*(1 - MIO_LAYOUT_NHWC)) + MIO_BN_LESS4;
nidx = remkey / MIO_BN_HW;
hwidx = remkey - (nidx * MIO_BN_HW);
index = nidx * MIO_BN_CHW +
#if MIO_LAYOUT_NHWC
hwidx * MIO_BN_C + grpid;
if(index < MIO_BN_NCHW)
{
xread = *((const global _FLOAT*)(x_in + index));
dyRead = *((const global _FLOAT*)(dy_in + index));
xhat_tmp = ((_FLOAT_PREC)xread - mean) * invVariance;
db += (_FLOAT_PREC)dyRead;
ds = mad(xhat_tmp, (_FLOAT_PREC)dyRead, ds);
#else
chwid + hwidx;
if(index < (MIO_BN_NCHW - 3))
{
xread4 = *((const global _FLOAT4*)(x_in + index));
Expand All @@ -447,6 +499,7 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
ds = mad(xhat4.y, (_FLOAT_PREC)dyRead4.y, ds);
ds = mad(xhat4.z, (_FLOAT_PREC)dyRead4.z, ds);
ds = mad(xhat4.w, (_FLOAT_PREC)dyRead4.w, ds);
#endif
}

#endif
Expand Down Expand Up @@ -491,12 +544,16 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
for(unsigned int j = 0; j < MIO_MAX_READ; j++)
#endif
{
unsigned int l = k + j;
nidx = l / MIO_BN_HW;
hwidx = l - (nidx * MIO_BN_HW);
index = nidx * MIO_BN_CHW + chwid + hwidx;
dyvalue = (_FLOAT_PREC)(*(dy_in + index));
xhat = ((_FLOAT_PREC)(*(x_in + index)) - mean) * invVariance;
unsigned int l = k + j;
nidx = l / MIO_BN_HW;
hwidx = l - (nidx * MIO_BN_HW);
#if MIO_LAYOUT_NHWC
index = nidx * MIO_BN_CHW + hwidx * MIO_BN_C + grpid;
#else
index = nidx * MIO_BN_CHW + chwid + hwidx;
#endif
dyvalue = (_FLOAT_PREC)(*(dy_in + index));
xhat = ((_FLOAT_PREC)(*(x_in + index)) - mean) * invVariance;
#if MIOPEN_USE_FP16 == 1
float temp_tmp1 = mad((float)NHW, (float)dyvalue, -temp_db);
float temp_tmp2 = -((float)xhat) * temp_ds;
Expand All @@ -518,7 +575,11 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
unsigned int l = k + j;
nidx = l / MIO_BN_HW;
hwidx = l - (nidx * MIO_BN_HW);
#if MIO_LAYOUT_NHWC
index = nidx * MIO_BN_CHW + hwidx * MIO_BN_C + grpid;
#else
index = nidx * MIO_BN_CHW + chwid + hwidx;
#endif
*(dx_out + index) = (_FLOAT)vals[j];
}
}
Expand All @@ -534,7 +595,11 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
unsigned int l = remkeyout + j;
nidx = l / MIO_BN_HW;
hwidx = l - (nidx * MIO_BN_HW);
#if MIO_LAYOUT_NHWC
index = nidx * MIO_BN_CHW + hwidx * MIO_BN_C + grpid;
#else
index = nidx * MIO_BN_CHW + chwid + hwidx;
#endif
if(index < MIO_BN_NCHW)
{
dyvalue = (_FLOAT_PREC)(*(dy_in + index));
Expand All @@ -554,7 +619,11 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
unsigned int l = remkeyout + j;
nidx = l / MIO_BN_HW;
hwidx = l - (nidx * MIO_BN_HW);
#if MIO_LAYOUT_NHWC
index = nidx * MIO_BN_CHW + hwidx * MIO_BN_C + grpid;
#else
index = nidx * MIO_BN_CHW + chwid + hwidx;
#endif
if(index < MIO_BN_NCHW)
{
*(dx_out + index) = (_FLOAT_PREC)vals[j];
Expand Down Expand Up @@ -681,7 +750,7 @@ MIOpenBatchNormBwdSpatialDScaleDBias(const __global _FLOAT* x_in,
const __global _FLOAT* savedMean,
const __global _FLOAT* savedInvVariance
#endif
)
)
{

unsigned int xgid = get_global_id(0);
Expand Down Expand Up @@ -995,7 +1064,7 @@ MIOpenBatchNormBwdSpatial(const __global _FLOAT* __restrict x_in,
#else // maxn
db += (_FLOAT_PREC)(*(dy_in + index));
_FLOAT_PREC xhat = (((_FLOAT_PREC)(*(x_in + index)) - mean) * invVariance);
ds = mad(xhat, (_FLOAT_PREC)(*(dy_in + index)), ds);
ds = mad(xhat, (_FLOAT_PREC)(*(dy_in + index)), ds);
#endif
}
}
Expand Down
Loading