Skip to content

Commit

Permalink
[v1.x] backport apache#18500 - [Bug Fixed] Fix batch norm when grad_r…
Browse files Browse the repository at this point in the history
…eq is `add` (apache#18518)

* Fix batch norm when grad_req is

* fix

* remove softmax test

* fix
  • Loading branch information
wkcn authored and ChaiBapchya committed Aug 15, 2020
1 parent fe52b49 commit 2472d4a
Show file tree
Hide file tree
Showing 7 changed files with 385 additions and 83 deletions.
1 change: 1 addition & 0 deletions src/operator/nn/batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param,
const std::vector<TBlob> &outputs) {
CHECK_EQ(inputs.size(), 8U);
CHECK_EQ(outputs.size(), 3U);

std::vector<TBlob> out_grad(1);
std::vector<TBlob> out_data(3);
std::vector<TBlob> in_data(3);
Expand Down
83 changes: 62 additions & 21 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,31 @@ static inline void ForEachFast(const BNTensor3<DType1> &in_data,
}
}

template<typename DType1, typename DType2, typename DType3, typename OnData>
static inline void ForEachFast(const BNTensor3<DType1> &in_data,
const BNTensor3<DType2> &in_data2,
const BNTensor3<DType3> &out_data,
const size_t channel,
OnData onData) {
const size_t num = in_data.OuterSize();
const size_t matrixSize = in_data.InnerSize();
const size_t skipLength = in_data.SkipLengthToNextSameChannelData();
const size_t startOffset = in_data.StartOffset(channel);

DType1 *data = in_data.dptr_ + startOffset;
DType2 *data2 = in_data2.dptr_ + startOffset;
DType3 *odata = out_data.dptr_ + startOffset;

for (size_t outer = 0; outer < num; ++outer) {
for (size_t i = 0; i < matrixSize; ++i) {
onData(data++, data2++, odata++);
}
data += skipLength;
data2 += skipLength;
odata += skipLength;
}
}

} // namespace batchnorm

/*! \brief Forward CPU */
Expand Down Expand Up @@ -264,7 +289,7 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
dotp += (*thisInputData - mean) * (*gradOut_data);
});

if (!gradIn.IsEmpty() && IsBNWriting(req[batchnorm::kData])) { // if there's a grad input
if (!gradIn.IsEmpty() && req[batchnorm::kData] != kNullOp) { // if there's a grad input
if (is_train_and_not_global_stats) {
// when in training mode
// Q(X) = X - E[x] ; i.e. input centered to zero mean
Expand All @@ -273,44 +298,60 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,

// projection of gradOutput on to output scaled by std
const AccReal k = dotp * invstd * invstd / itemCount;
ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
[&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
*gradIn_data = (*inputDataPtr - mean) * k;
});

const AccReal iw = invstd * w;
const AccReal gradMean = sumGradOut / itemCount;
ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
[iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
*gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
});
if (req[batchnorm::kData] != kAddTo) {
ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
[&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
*gradIn_data = (*inputDataPtr - mean) * k;
});

ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
[iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
*gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
});
} else {
ForEachFast(inputData, gradOut, gradIn, static_cast<size_t>(channel),
[&mean, &k, iw, gradMean](const DType *inputDataPtr,
const DType *gradOut_data,
DType *gradIn_data) {
DType normal_val = (*inputDataPtr - mean) * k;
*gradIn_data += (*gradOut_data - gradMean -
normal_val) * iw;
});
}
} else {
// when in evaluation mode
// Q(X) = X - running_mean ; i.e. input centered to zero mean
// Y = Q(X) / running_std ; i.e. BN output before weight and bias
// dL/dX = w / running_std
const AccReal iw = invstd * w;
ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
[iw](const DType *gradOut_data, DType *gradIn_data) {
*gradIn_data = *gradOut_data * iw;
});
if (req[batchnorm::kData] != kAddTo) {
ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
[iw](const DType *gradOut_data, DType *gradIn_data) {
*gradIn_data = *gradOut_data * iw;
});
} else {
ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
[iw](const DType *gradOut_data, DType *gradIn_data) {
*gradIn_data += *gradOut_data * iw;
});
}
}
}

// May want to make this a param eventually
const AccReal scale = 1.0f;

if (IsBNWriting(req[batchnorm::kGamma])) {
if (!param_.fix_gamma) {
gradWeightData[channel] = scale * dotp * invstd;
} else {
if (!param_.fix_gamma) {
KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd);
} else {
if (IsBNWriting(req[batchnorm::kGamma])) {
gradWeightData[channel] = AccReal(0);
}
}

if (IsBNWriting(req[batchnorm::kBeta])) {
gradBiasData[channel] = scale * sumGradOut;
}
KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut);
}
}

Expand Down
68 changes: 52 additions & 16 deletions src/operator/nn/batch_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
#define FIX_GAMMA_FLAG 8
#define IS_TRAINING_FLAG 16
#define USE_GLOBAL_STATS_FLAG 32
#define ADDTO_DATA_FLAG (1 << 6)
#define ADDTO_GAMMA_FLAG (1 << 7)
#define ADDTO_BETA_FLAG (1 << 8)

#if MXNET_USE_CUDNN == 1
#include "./cudnn/cudnn_batch_norm-inl.h"
Expand Down Expand Up @@ -362,33 +365,60 @@ static __global__ void BatchNormalizationBackwardKernel(
* momentum + localVariance * (AccReal(1) - momentum);
}

if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) {
for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
const DType gradOut = gradOutput.get_ref(batch, plane, x);
if (is_train_and_not_global_stats) {
const DType inp = input.get_ref(batch, plane, x);
const AccReal proj = (inp - mean) * projScale;
gradInput.get_ref(batch, plane, x) =
ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
} else {
gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
gradOut * gradScale);
if (gradInput.Size() > 0 && (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) != 0) {
const bool grad_write = flags & WRITE_DATA_FLAG;
if (grad_write) {
for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
const DType gradOut = gradOutput.get_ref(batch, plane, x);
if (is_train_and_not_global_stats) {
const DType inp = input.get_ref(batch, plane, x);
const AccReal proj = (inp - mean) * projScale;
gradInput.get_ref(batch, plane, x) =
ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
} else {
gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
gradOut * gradScale);
}
}
}
} else {
// grad addto
for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
const DType gradOut = gradOutput.get_ref(batch, plane, x);
if (is_train_and_not_global_stats) {
const DType inp = input.get_ref(batch, plane, x);
const AccReal proj = (inp - mean) * projScale;
gradInput.get_ref(batch, plane, x) +=
ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
} else {
gradInput.get_ref(batch, plane, x) += ScalarConvert<AccReal, DType>::to(
gradOut * gradScale);
}
}
}
}
}

if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) {
if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 &&
(flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) {
if ((flags & FIX_GAMMA_FLAG) == 0) {
tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
if (flags & WRITE_GAMMA_FLAG)
tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
else
tensors.gradWeight[plane] += ScalarConvert<AccReal, DType>::to(dotP * invstd);
} else {
tensors.gradWeight[plane] = DType(0);
}
}

if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 &&
(flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) {
if (flags & WRITE_BETA_FLAG)
tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
else
tensors.gradBias[plane] += ScalarConvert<AccReal, DType>::to(gradOutputSum);
}
}

Expand Down Expand Up @@ -585,12 +615,18 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
if (IsBNWriting(req[batchnorm::kData])) {
flags |= WRITE_DATA_FLAG;
} else if (req[batchnorm::kData] == kAddTo) {
flags |= ADDTO_DATA_FLAG;
}
if (IsBNWriting(req[batchnorm::kGamma])) {
flags |= WRITE_GAMMA_FLAG;
} else if (req[batchnorm::kGamma] == kAddTo) {
flags |= ADDTO_GAMMA_FLAG;
}
if (IsBNWriting(req[batchnorm::kBeta])) {
flags |= WRITE_BETA_FLAG;
} else if (req[batchnorm::kBeta] == kAddTo) {
flags |= ADDTO_BETA_FLAG;
}
return flags;
}
Expand Down
15 changes: 13 additions & 2 deletions src/operator/nn/cudnn/cudnn_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,13 +220,24 @@ class CuDNNBatchNormOp {

if (param_.fix_gamma) gamma = 1.f;

bool grad_add_gamma_beta = (req[cudnnbatchnorm::kGamma] == kAddTo) ||
(req[cudnnbatchnorm::kBeta] == kAddTo);
if (grad_add_gamma_beta) {
if (IsBNWriting(req[cudnnbatchnorm::kGamma])) {
dgamma = 0.f;
}
if (IsBNWriting(req[cudnnbatchnorm::kBeta])) {
dbeta = 0.f;
}
}

CUDNN_CALL(cudnnBatchNormalizationBackward(
s->dnn_handle_,
mode,
&a,
&b,
req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b,
&a,
req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add,
grad_add_gamma_beta ? &b_add : &b, // gamma and beta
io_desc_,
x.dptr_,
io_desc_,
Expand Down
41 changes: 30 additions & 11 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,8 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
else if (diff.IsDefaultData())
diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_desc());
auto gradi_mem = CreateMKLDNNMem(const_cast<NDArray &>(gradIn),
bwd.pd.diff_src_desc(), req[batchnorm::kData]);

if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
const NDArray &gamma = in_data[batchnorm::kGamma];
Expand All @@ -368,7 +369,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
}
mkldnn_args_map_t net_args;
net_args[MKLDNN_ARG_SRC] = *data_mem;
net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem.second;
net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight();
net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem;
Expand Down Expand Up @@ -401,28 +402,46 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
}
net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData());
net_args[MKLDNN_ARG_VARIANCE] = var_mem;
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
MKLDNNStream::Get()->Submit();
} else {
net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData());
net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
MKLDNNStream::Get()->Submit();
}
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
CommitOutput(gradIn, gradi_mem);
MKLDNNStream::Get()->Submit();

// copy data from gradw_mem to in_grad[1] and in_grad[2]
DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
DType *w_grad_1 = in_grad[1].data().dptr<DType>();
DType *w_grad_2 = in_grad[2].data().dptr<DType>();
DType *w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<DType>();
DType *w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<DType>();

// the gradient of gamma
if (!param.fix_gamma) {
memcpy(w_grad_1, gw_buf, copy_size);
memcpy(w_grad_2, &gw_buf[channels_], copy_size);
if (req[batchnorm::kGamma] != kNullOp) {
if (req[batchnorm::kGamma] != kAddTo) {
memcpy(w_grad_1, gw_buf, copy_size);
} else {
for (int i = 0; i < channels_; i++) {
w_grad_1[i] += gw_buf[i];
}
}
}
} else {
for (int i = 0; i < channels_; i++) {
(in_grad[1].data().dptr<DType>())[i] = 0.0f;
}
memcpy(w_grad_2, &gw_buf[channels_], copy_size);
}

// the gradient of beta
if (req[batchnorm::kBeta] != kNullOp) {
if (req[batchnorm::kBeta] != kAddTo) {
memcpy(w_grad_2, &gw_buf[channels_], copy_size);
} else {
DType *grad_beta = &gw_buf[channels_];
for (int i = 0; i < channels_; i++) {
w_grad_2[i] += grad_beta[i];
}
}
}
} else {
LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
Expand Down
Loading

0 comments on commit 2472d4a

Please sign in to comment.