Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Bug Fixed] Fix batch norm when grad_req is add #18500

Merged
merged 23 commits into from
Jun 8, 2020
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/operator/nn/batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ void BatchNormBackward(const OpContext &ctx, const BatchNormParam& param,
const std::vector<TBlob> &outputs) {
CHECK_EQ(inputs.size(), 8U);
CHECK_EQ(outputs.size(), 3U);

std::vector<TBlob> out_grad(1);
std::vector<TBlob> out_data(3);
std::vector<TBlob> in_data(3);
Expand Down
83 changes: 62 additions & 21 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,31 @@ static inline void ForEachFast(const BNTensor3<DType1> &in_data,
}
}

template<typename DType1, typename DType2, typename DType3, typename OnData>
static inline void ForEachFast(const BNTensor3<DType1> &in_data,
const BNTensor3<DType2> &in_data2,
const BNTensor3<DType3> &out_data,
const size_t channel,
OnData onData) {
const size_t num = in_data.OuterSize();
const size_t matrixSize = in_data.InnerSize();
const size_t skipLength = in_data.SkipLengthToNextSameChannelData();
const size_t startOffset = in_data.StartOffset(channel);

DType1 *data = in_data.dptr_ + startOffset;
DType2 *data2 = in_data2.dptr_ + startOffset;
DType3 *odata = out_data.dptr_ + startOffset;

for (size_t outer = 0; outer < num; ++outer) {
for (size_t i = 0; i < matrixSize; ++i) {
onData(data++, data2++, odata++);
}
data += skipLength;
data2 += skipLength;
odata += skipLength;
}
}

} // namespace batchnorm

/*! \brief Forward CPU */
Expand Down Expand Up @@ -263,7 +288,7 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
dotp += (*thisInputData - mean) * (*gradOut_data);
});

if (!gradIn.IsEmpty() && IsBNWriting(req[batchnorm::kData])) { // if there's a grad input
if (!gradIn.IsEmpty() && req[batchnorm::kData] != kNullOp) { // if there's a grad input
if (is_train_and_not_global_stats) {
// when in training mode
// Q(X) = X - E[x] ; i.e. input centered to zero mean
Expand All @@ -272,44 +297,60 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,

// projection of gradOutput on to output scaled by std
const AccReal k = dotp * invstd * invstd / itemCount;
ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
[&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
*gradIn_data = (*inputDataPtr - mean) * k;
});

const AccReal iw = invstd * w;
const AccReal gradMean = sumGradOut / itemCount;
ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
[iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
*gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
});
if (req[batchnorm::kData] != kAddTo) {
ForEachFast(inputData, gradIn, static_cast<size_t>(channel),
[&mean, &k](const DType *inputDataPtr, DType *gradIn_data) {
*gradIn_data = (*inputDataPtr - mean) * k;
});

ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
[iw, gradMean](const DType *gradOut_data, DType *gradIn_data) {
*gradIn_data = (*gradOut_data - gradMean - *gradIn_data) * iw;
});
} else {
ForEachFast(inputData, gradOut, gradIn, static_cast<size_t>(channel),
[&mean, &k, iw, gradMean](const DType *inputDataPtr,
const DType *gradOut_data,
DType *gradIn_data) {
DType normal_val = (*inputDataPtr - mean) * k;
*gradIn_data += (*gradOut_data - gradMean -
normal_val) * iw;
});
}
} else {
// when in evaluation mode
// Q(X) = X - running_mean ; i.e. input centered to zero mean
// Y = Q(X) / running_std ; i.e. BN output before weight and bias
// dL/dX = w / running_std
const AccReal iw = invstd * w;
ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
[iw](const DType *gradOut_data, DType *gradIn_data) {
*gradIn_data = *gradOut_data * iw;
});
if (req[batchnorm::kData] != kAddTo) {
ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
[iw](const DType *gradOut_data, DType *gradIn_data) {
*gradIn_data = *gradOut_data * iw;
});
} else {
ForEachFast(gradOut, gradIn, static_cast<size_t>(channel),
[iw](const DType *gradOut_data, DType *gradIn_data) {
*gradIn_data += *gradOut_data * iw;
});
}
}
}

// May want to make this a param eventually
const AccReal scale = 1.0f;

if (IsBNWriting(req[batchnorm::kGamma])) {
if (!param_.fix_gamma) {
gradWeightData[channel] = scale * dotp * invstd;
} else {
if (!param_.fix_gamma) {
KERNEL_ASSIGN(gradWeightData[channel], req[batchnorm::kGamma], scale * dotp * invstd);
} else {
if (IsBNWriting(req[batchnorm::kGamma])) {
gradWeightData[channel] = AccReal(0);
}
}

if (IsBNWriting(req[batchnorm::kBeta])) {
gradBiasData[channel] = scale * sumGradOut;
}
KERNEL_ASSIGN(gradBiasData[channel], req[batchnorm::kBeta], scale * sumGradOut);
}
}

Expand Down
68 changes: 52 additions & 16 deletions src/operator/nn/batch_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
#define FIX_GAMMA_FLAG 8
#define IS_TRAINING_FLAG 16
#define USE_GLOBAL_STATS_FLAG 32
#define ADDTO_DATA_FLAG (1 << 6)
#define ADDTO_GAMMA_FLAG (1 << 7)
#define ADDTO_BETA_FLAG (1 << 8)

#if MXNET_USE_CUDNN == 1
#include "./cudnn/cudnn_batch_norm-inl.h"
Expand Down Expand Up @@ -361,33 +364,60 @@ static __global__ void BatchNormalizationBackwardKernel(
* momentum + localVariance * (AccReal(1) - momentum);
}

if (gradInput.Size() > 0 && (flags & WRITE_DATA_FLAG) != 0) {
for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
const DType gradOut = gradOutput.get_ref(batch, plane, x);
if (is_train_and_not_global_stats) {
const DType inp = input.get_ref(batch, plane, x);
const AccReal proj = (inp - mean) * projScale;
gradInput.get_ref(batch, plane, x) =
ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
} else {
gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
gradOut * gradScale);
if (gradInput.Size() > 0 && (flags & (WRITE_DATA_FLAG | ADDTO_DATA_FLAG)) != 0) {
const bool grad_write = flags & WRITE_DATA_FLAG;
if (grad_write) {
for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
const DType gradOut = gradOutput.get_ref(batch, plane, x);
if (is_train_and_not_global_stats) {
const DType inp = input.get_ref(batch, plane, x);
const AccReal proj = (inp - mean) * projScale;
gradInput.get_ref(batch, plane, x) =
ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
} else {
gradInput.get_ref(batch, plane, x) = ScalarConvert<AccReal, DType>::to(
gradOut * gradScale);
}
}
}
} else {
// grad addto
for (int batch = 0, nbatch = gradOutput.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = gradOutput.InnerSize(); x < nx; x += blockDim.x) {
const DType gradOut = gradOutput.get_ref(batch, plane, x);
if (is_train_and_not_global_stats) {
const DType inp = input.get_ref(batch, plane, x);
const AccReal proj = (inp - mean) * projScale;
gradInput.get_ref(batch, plane, x) +=
ScalarConvert<AccReal, DType>::to((gradOut - proj - gradMean) * gradScale);
} else {
gradInput.get_ref(batch, plane, x) += ScalarConvert<AccReal, DType>::to(
gradOut * gradScale);
}
}
}
}
}

if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_GAMMA_FLAG) != 0) {
if (tensors.gradWeight.numElements() > 0 && threadIdx.x == 0 &&
(flags & (WRITE_GAMMA_FLAG | ADDTO_GAMMA_FLAG)) != 0) {
if ((flags & FIX_GAMMA_FLAG) == 0) {
tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
if (flags & WRITE_GAMMA_FLAG)
tensors.gradWeight[plane] = ScalarConvert<AccReal, DType>::to(dotP * invstd);
else
tensors.gradWeight[plane] += ScalarConvert<AccReal, DType>::to(dotP * invstd);
} else {
tensors.gradWeight[plane] = DType(0);
}
}

if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 &&
(flags & (WRITE_BETA_FLAG | ADDTO_BETA_FLAG)) != 0) {
if (flags & WRITE_BETA_FLAG)
tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
else
tensors.gradBias[plane] += ScalarConvert<AccReal, DType>::to(gradOutputSum);
}
}

Expand Down Expand Up @@ -579,12 +609,18 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
if (IsBNWriting(req[batchnorm::kData])) {
flags |= WRITE_DATA_FLAG;
} else if (req[batchnorm::kData] == kAddTo) {
flags |= ADDTO_DATA_FLAG;
}
if (IsBNWriting(req[batchnorm::kGamma])) {
flags |= WRITE_GAMMA_FLAG;
} else if (req[batchnorm::kGamma] == kAddTo) {
flags |= ADDTO_GAMMA_FLAG;
}
if (IsBNWriting(req[batchnorm::kBeta])) {
flags |= WRITE_BETA_FLAG;
} else if (req[batchnorm::kBeta] == kAddTo) {
flags |= ADDTO_BETA_FLAG;
}
return flags;
}
Expand Down
15 changes: 13 additions & 2 deletions src/operator/nn/cudnn/cudnn_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,24 @@ class CuDNNBatchNormOp {

if (param_.fix_gamma) gamma = 1.f;

bool grad_add_gamma_beta = (req[cudnnbatchnorm::kGamma] == kAddTo) ||
(req[cudnnbatchnorm::kBeta] == kAddTo);
if (grad_add_gamma_beta) {
if (IsBNWriting(req[cudnnbatchnorm::kGamma])) {
dgamma = 0.f;
}
if (IsBNWriting(req[cudnnbatchnorm::kBeta])) {
dbeta = 0.f;
}
}

CUDNN_CALL(cudnnBatchNormalizationBackward(
s->dnn_handle_,
mode,
&a,
&b,
req[cudnnbatchnorm::kData] == kAddTo ? &b_add : &b,
&a,
req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add,
grad_add_gamma_beta ? &b_add : &b, // gamma and beta
io_desc_,
x.dptr_,
io_desc_,
Expand Down
41 changes: 30 additions & 11 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,8 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
else if (diff.IsDefaultData())
diff_mem = diff.GetMKLDNNDataReorder(data_mem->get_desc());
auto &bwd = GetBNBackward<DType>(param, ctx, data, *data_mem, diff, *diff_mem, flags);
auto gradi_mem = const_cast<NDArray &>(gradIn).CreateMKLDNNData(data_mem->get_desc());
auto gradi_mem = CreateMKLDNNMem(const_cast<NDArray &>(gradIn),
bwd.pd.diff_src_desc(), req[batchnorm::kData]);

if (static_cast<int>(flags) & static_cast<int>(mkldnn::normalization_flags::use_scale_shift)) {
const NDArray &gamma = in_data[batchnorm::kGamma];
Expand All @@ -368,7 +369,7 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
}
mkldnn_args_map_t net_args;
net_args[MKLDNN_ARG_SRC] = *data_mem;
net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem;
net_args[MKLDNN_ARG_DIFF_SRC] = *gradi_mem.second;
net_args[MKLDNN_ARG_SCALE_SHIFT] = bwd.GetWeight();
net_args[MKLDNN_ARG_DIFF_SCALE_SHIFT] = bwd.GetGradw();
net_args[MKLDNN_ARG_DIFF_DST] = *diff_mem;
Expand Down Expand Up @@ -401,28 +402,46 @@ void MKLDNNBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
}
net_args[MKLDNN_ARG_MEAN] = *(out_mean.GetMKLDNNData());
net_args[MKLDNN_ARG_VARIANCE] = var_mem;
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
MKLDNNStream::Get()->Submit();
} else {
net_args[MKLDNN_ARG_MEAN] = *(moving_mean.GetMKLDNNData());
net_args[MKLDNN_ARG_VARIANCE] = *(moving_var.GetMKLDNNData());
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
MKLDNNStream::Get()->Submit();
}
MKLDNNStream::Get()->RegisterPrimArgs(bwd.GetBwd(), net_args);
CommitOutput(gradIn, gradi_mem);
MKLDNNStream::Get()->Submit();

// copy data from gradw_mem to in_grad[1] and in_grad[2]
DType *gw_buf = reinterpret_cast<DType *>(bwd.GetGradw().get_data_handle());
DType *w_grad_1 = in_grad[1].data().dptr<DType>();
DType *w_grad_2 = in_grad[2].data().dptr<DType>();
DType *w_grad_1 = in_grad[batchnorm::kGamma].data().dptr<DType>();
DType *w_grad_2 = in_grad[batchnorm::kBeta].data().dptr<DType>();

// the gradient of gamma
if (!param.fix_gamma) {
memcpy(w_grad_1, gw_buf, copy_size);
memcpy(w_grad_2, &gw_buf[channels_], copy_size);
if (req[batchnorm::kGamma] != kNullOp) {
if (req[batchnorm::kGamma] != kAddTo) {
memcpy(w_grad_1, gw_buf, copy_size);
} else {
for (int i = 0; i < channels_; i++) {
w_grad_1[i] += gw_buf[i];
}
}
}
} else {
for (int i = 0; i < channels_; i++) {
(in_grad[1].data().dptr<DType>())[i] = 0.0f;
}
memcpy(w_grad_2, &gw_buf[channels_], copy_size);
}

// the gradient of beta
if (req[batchnorm::kBeta] != kNullOp) {
if (req[batchnorm::kBeta] != kAddTo) {
memcpy(w_grad_2, &gw_buf[channels_], copy_size);
} else {
DType *grad_beta = &gw_buf[channels_];
for (int i = 0; i < channels_; i++) {
w_grad_2[i] += grad_beta[i];
}
}
}
} else {
LOG(FATAL) << "MKLDNN batch normalization backward: should not reach here ...";
Expand Down
Loading