[ARM]fix group_norm compute error when compared with paddle (#5683)
* fix group_norm compute error when compared with paddle. test=develop
chenjiaoAngel authored Mar 12, 2021
1 parent af6805f commit d923435
Showing 4 changed files with 118 additions and 63 deletions.
22 changes: 16 additions & 6 deletions lite/kernels/arm/group_norm_compute.cc
@@ -35,15 +35,23 @@ void GroupNormCompute::Run() {
float epsilon = param.epsilon;
int groups = param.groups;
int channels = param.channels;
int n = param.x->dims()[0];
int c = param.x->dims()[1];
auto x_dims = param.x->dims();
int n = x_dims[0];
int c = x_dims[1];
if (channels == -1) {
CHECK_EQ(param.data_layout_str, "NCHW")
<< "it only support NCHW layout!, but recived layout is "
<< param.data_layout_str;
channels = c;
}
int height = x_dims[2];
int width = x_dims[3];
int ch_per_group = channels / groups;
int height = param.x->dims()[2];
int width = param.x->dims()[3];
int spatial_size = ch_per_group * height * width;
int ngroup = n * groups;
int cnt = spatial_size >> 4;
int remain = spatial_size % 16;
float* std_vec = new float[param.saved_variance->numel()];
// compute saved_mean and saved_variance
#pragma omp parallel for
for (int n = 0; n < ngroup; ++n) {
@@ -103,7 +111,8 @@ void GroupNormCompute::Run() {
float variance = (summ - mean * mean * spatial_size) / spatial_size;
float std = 1.f / sqrtf(variance + epsilon);
saved_mean[n] = mean;
saved_variance[n] = std;
saved_variance[n] = variance;
std_vec[n] = std;
}
int in_size = height * width;
cnt = in_size >> 4;
@@ -117,7 +126,7 @@ void GroupNormCompute::Run() {
numc *= ch_per_group;
for (int c = 0; c < ch_per_group; c++) {
int chin = numc + c;
const float sstd_val = scale[chin] * saved_variance[i];
const float sstd_val = scale[chin] * std_vec[i];
const float bias_val = bias[chin];
const float mean_val = saved_mean[i];
const float32x4_t vsstd = vdupq_n_f32(sstd_val);
@@ -158,6 +167,7 @@ void GroupNormCompute::Run() {
}
}
}
delete[] std_vec;
}

} // namespace arm
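For context, the core change in group_norm_compute.cc is that saved_variance previously received the inverse standard deviation, while Paddle's group_norm op reports the raw variance. The fixed kernel stores the variance in saved_variance and keeps the inverse standard deviation in a temporary std_vec buffer that is only used for normalization. Below is a minimal scalar sketch of that per-group computation, assuming contiguous NCHW float input; the function and parameter names are illustrative, and it omits the NEON intrinsics and OpenMP parallelism of the real kernel.

#include <cmath>
#include <vector>

// Hypothetical scalar sketch (not the Lite kernel itself): compute per-group
// mean and variance for an NCHW tensor, store the raw variance (Paddle's
// convention), and keep 1/sqrt(var + eps) only as a local normalization factor.
void GroupNormStatsSketch(const float* x, int n, int c, int h, int w,
                          int groups, float epsilon,
                          std::vector<float>* saved_mean,
                          std::vector<float>* saved_variance,
                          std::vector<float>* inv_std) {
  const int ch_per_group = c / groups;
  const int spatial_size = ch_per_group * h * w;
  const int ngroup = n * groups;
  saved_mean->resize(ngroup);
  saved_variance->resize(ngroup);
  inv_std->resize(ngroup);
  for (int g = 0; g < ngroup; ++g) {
    const float* ptr = x + g * spatial_size;
    float sum = 0.f;
    float sqsum = 0.f;
    for (int i = 0; i < spatial_size; ++i) {
      sum += ptr[i];
      sqsum += ptr[i] * ptr[i];
    }
    const float mean = sum / spatial_size;
    const float variance = sqsum / spatial_size - mean * mean;
    (*saved_mean)[g] = mean;
    (*saved_variance)[g] = variance;  // raw variance, matching Paddle's output
    // inverse std is only a local factor for y = scale * (x - mean) * inv_std + bias
    (*inv_std)[g] = 1.f / std::sqrt(variance + epsilon);
  }
}

In the kernel above, std_vec[n] plays the role of inv_std here, and saved_variance[n] now receives the variance instead of the inverse std.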
27 changes: 19 additions & 8 deletions lite/operators/group_norm_op.cc
@@ -34,27 +34,35 @@ bool GroupNormOp::CheckShape() const {
auto scale_dims = param_.scale->dims();
auto bias_dims = param_.bias->dims();
if (param_.channels == -1) {
param_.channels = x_dims[1];
param_.channels = (param_.data_layout_str == "NCHW")
? x_dims[1]
: x_dims[x_dims.size() - 1];
}
// only supports NCHW
CHECK_EQ(param_.data_layout_str, "NCHW") << "data_layout must be NCHW";
CHECK(x_dims.size() >= 2 && x_dims.size() <= 5)
<< "Input X must have 2 to 5 dimensions.";
CHECK_EQ(scale_dims.size(), 1UL) << "Input Scale must have 1 dimension.";
CHECK_EQ(bias_dims.size(), 1UL) << "Input Bias must have 1 dimension.";
CHECK_GT(param_.epsilon, 0.f) << "epsilon should be greater than 0.f";
CHECK_LT(param_.epsilon, 0.01f) << "epsilon should be less than 0.01f";
CHECK_EQ(param_.channels, x_dims[1])
<< "Input channels must be equal to input_shape[1]";
CHECK_EQ(param_.channels % param_.groups, 0)
<< "channels must be divisible by groups";
CHECK_LE(param_.groups, param_.channels)
<< "groups should be less than or equal to channels";
CHECK_GE(param_.groups, 1) << "groups should be at least 1";
CHECK_EQ(param_.channels, scale_dims[0])
<< "The Input(Scale)'s first dimension size of Op(group_norm) must be "
"equal to the number of channels";
CHECK_EQ(param_.channels, bias_dims[0])
<< "The Input(Bias)'s first dimension size of Op(group_norm) must be "
"equal to the number of channels";
return true;
}

bool GroupNormOp::InferShapeImpl() const {
auto x_dims = param_.x->dims();
int64_t batch_size = x_dims[0];
int64_t num = param_.channels / param_.groups;
param_.saved_mean->Resize({batch_size * num});
param_.saved_variance->Resize({batch_size * num});
param_.saved_mean->Resize({batch_size, param_.groups});
param_.saved_variance->Resize({batch_size, param_.groups});
param_.out->Resize(x_dims);
return true;
}
@@ -82,6 +90,9 @@ bool GroupNormOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
}
param_.out =
scope->FindVar(op_desc.Output("Y").front())->GetMutable<Tensor>();
if (op_desc.HasAttr("data_layout")) {
param_.data_layout_str = op_desc.GetAttr<std::string>("data_layout");
}
param_.epsilon = op_desc.GetAttr<float>("epsilon");
param_.groups = op_desc.GetAttr<int>("groups");
if (op_desc.HasAttr("channels")) {
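The shape change in InferShapeImpl is the other half of the fix: group_norm produces one mean/variance per (batch, group) pair, so the saved statistics are now resized to {batch_size, groups} rather than a flat buffer of batch_size * (channels / groups) elements. A small illustration of the difference, using hypothetical sizes:

#include <cstdio>

// Hypothetical sizes, purely to illustrate the Resize change above.
int main() {
  const int batch_size = 4, channels = 8, groups = 2;
  // Old: a flat tensor of batch_size * (channels / groups) elements.
  std::printf("old saved_mean numel: %d\n", batch_size * (channels / groups));  // 16
  // New: shape {batch_size, groups} -- one statistic per (batch, group) pair.
  std::printf("new saved_mean shape: {%d, %d}, numel: %d\n",
              batch_size, groups, batch_size * groups);  // {4, 2}, 8
  return 0;
}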
1 change: 1 addition & 0 deletions lite/operators/op_params.h
@@ -1797,6 +1797,7 @@ struct GroupNormParam : ParamBase {
lite::Tensor* scale{};
lite::Tensor* saved_mean{};
lite::Tensor* saved_variance{};
std::string data_layout_str{"NCHW"};
float epsilon;
int groups;
int channels;
131 changes: 82 additions & 49 deletions lite/tests/kernels/group_norm_compute_test.cc
@@ -34,20 +34,20 @@ class GroupNormComputeTest : public arena::TestCase {
DDim dims_{{4, 5, 19, 19}};
float epsilon_ = 1e-5f;
int groups_ = 1;
int channels_ = dims_[1];
std::string data_layout_str_ = "NCHW";

public:
GroupNormComputeTest(const Place& place,
const std::string& alias,
DDim dims,
float epsilon,
int groups,
int channels)
std::string data_layout_str)
: TestCase(place, alias),
dims_(dims),
epsilon_(epsilon),
groups_(groups),
channels_(channels) {}
data_layout_str_(data_layout_str) {}

void RunBaseline(Scope* scope) override {
auto x = scope->FindTensor(x_);
@@ -59,7 +59,7 @@ class GroupNormComputeTest : public arena::TestCase {
CHECK(y);
CHECK(saved_mean);
CHECK(saved_variance);
DDim saved_dim({dims_[0] * groups_});
DDim saved_dim({dims_[0], groups_});
y->Resize(dims_);
saved_mean->Resize(saved_dim);
saved_variance->Resize(saved_dim);
@@ -68,49 +68,82 @@
auto scale_data = scale->data<float>();
auto bias_data = bias->data<float>();
auto y_data = y->mutable_data<float>();
auto saved_mean_data = saved_mean->mutable_data<float>();
auto saved_variance_data = saved_variance->mutable_data<float>();

int n = x->dims()[0];
int ch_per_group = channels_ / groups_;
CHECK_EQ(x->dims()[1], channels_);
int spatial_size = ch_per_group * x->dims()[2] * x->dims()[3];
// compute mean
for (int i = 0; i < n * groups_; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float sum = 0.f;
for (int j = 0; j < spatial_size; ++j) {
sum += x_ptr[j];
}
saved_mean_data[i] = sum / spatial_size;
}
// compute variance
for (int i = 0; i < n * groups_; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float sum = 0.f;
for (int j = 0; j < spatial_size; ++j) {
sum +=
(x_ptr[j] - saved_mean_data[i]) * (x_ptr[j] - saved_mean_data[i]);
}
saved_variance_data[i] = 1.f / sqrtf(sum / spatial_size + epsilon_);
}
int in_size = x->dims()[2] * x->dims()[3];
// compute out
for (int i = 0; i < n * groups_; ++i) {
const float* x_ptr = x_data + i * spatial_size;
float* y_ptr = y_data + i * spatial_size;
int c_num = i % groups_;
for (int c = 0; c < ch_per_group; c++) {
int chin = c_num * ch_per_group + c;
float scale_val = scale_data[chin];
float bias_val = bias_data[chin];
const float* x_ch_ptr = x_ptr + c * in_size;
float* y_ch_ptr = y_ptr + c * in_size;
for (int j = 0; j < in_size; j++) {
y_ch_ptr[j] = scale_val * (x_ch_ptr[j] - saved_mean_data[i]) *
saved_variance_data[i] +
bias_val;
auto mean_data = saved_mean->mutable_data<float>();
auto var_data = saved_variance->mutable_data<float>();

auto x_dims = x->dims();
int groups = groups_;
int channels =
(data_layout_str_ == "NCHW") ? x_dims[1] : x_dims[x_dims.size() - 1];
int group_size = (channels - 1) / groups + 1;
int imsize = (data_layout_str_ == "NCHW") ? (x_dims[2] * x_dims[3])
: (x_dims[1] * x_dims[2]);

auto* iter_x_data = x_data;
auto* iter_y_data = y_data;
for (int bid = 0; bid < x_dims[0]; bid++) {
for (int gid = 0; gid < groups; gid++) {
float x_mean = 0;
float x_var = 0;
int number =
std::min(group_size, static_cast<int>(channels - gid * group_size));
auto* tmp_x = iter_x_data;
auto* x_src_data = iter_x_data;
auto* tmp_y = iter_y_data;
auto* y_src_data = iter_y_data;

if (data_layout_str_ == "NCHW") {
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize; imid++, iter_x_data++) {
x_mean += iter_x_data[0];
x_var += iter_x_data[0] * iter_x_data[0];
}
}
} else {
for (int cid = 0; cid < number; cid++) {
iter_x_data = tmp_x + cid;
for (int imid = 0; imid < imsize; imid++, iter_x_data += channels) {
x_mean += iter_x_data[0];
x_var += iter_x_data[0] * iter_x_data[0];
}
}
iter_x_data = tmp_x + group_size;
}

x_mean /= number * imsize;
x_var /= number * imsize;
x_var = x_var - x_mean * x_mean;
float var_inv = 1.0 / std::sqrt(x_var + epsilon_);
mean_data[bid * groups + gid] = x_mean;
var_data[bid * groups + gid] = x_var;

if (data_layout_str_ == "NCHW") {
for (int cid = 0; cid < number; cid++) {
for (int imid = 0; imid < imsize; imid++, tmp_x++, iter_y_data++) {
float val = (tmp_x[0] - x_mean) * var_inv;
if (scale_data) val *= scale_data[gid * group_size + cid];
if (bias_data) val += bias_data[gid * group_size + cid];
iter_y_data[0] = val;
}
}
} else {
for (int cid = 0; cid < number; cid++) {
tmp_x = x_src_data + cid;
iter_y_data = y_src_data + cid;
for (int imid = 0; imid < imsize;
imid++, tmp_x += channels, iter_y_data += channels) {
float val = (tmp_x[0] - x_mean) * var_inv;
if (scale_data) val *= scale_data[gid * group_size + cid];
if (bias_data) val += bias_data[gid * group_size + cid];
iter_y_data[0] = val;
}
}
iter_y_data = tmp_y + group_size;
}
}
if (data_layout_str_ == "NCHW") {
iter_x_data = x_data + (bid + 1) * channels * imsize;
iter_y_data = y_data + (bid + 1) * channels * imsize;
}
}
}
@@ -125,7 +158,7 @@ class GroupNormComputeTest : public arena::TestCase {
op_desc->SetOutput("Variance", {saved_variance_});
op_desc->SetAttr("epsilon", epsilon_);
op_desc->SetAttr("groups", groups_);
op_desc->SetAttr("channels", channels_);
op_desc->SetAttr("data_layout", data_layout_str_);
}

void PrepareData() override {
@@ -148,7 +181,7 @@ void TestGroupNorm(Place place,
float abs_error = 6e-5,
std::vector<std::string> ignored_outs = {}) {
for (auto& n : {1, 3, 16}) {
for (auto& c : {1}) {
for (auto& c : {1, 2}) {
for (auto& h : {1, 16, 33, 56}) {
for (auto& w : {1, 17, 55}) {
for (auto& groups : {1, 2, 4}) {
@@ -158,7 +191,7 @@ void TestGroupNorm(Place place,
DDim dim_in({n, c, h, w});
float epsilon = 1e-5f;
std::unique_ptr<arena::TestCase> tester(new GroupNormComputeTest(
place, "def", dim_in, epsilon, groups, c));
place, "def", dim_in, epsilon, groups, "NCHW"));
#ifdef LITE_WITH_ARM
if (place == TARGET(kARM)) {
auto& ctx = tester->context()->As<ARMContext>();
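As a restatement (not repository code), the rewritten RunBaseline above computes the standard group-norm definition. With C_g = channels / groups channels per group and an NCHW input, the statistics written to Mean and Variance and the output Y are:

\mu_{ng} = \frac{1}{C_g H W} \sum_{c \in \text{group } g} \sum_{h,w} x_{nchw}, \qquad
\sigma^2_{ng} = \frac{1}{C_g H W} \sum_{c \in \text{group } g} \sum_{h,w} x_{nchw}^2 \;-\; \mu_{ng}^2,

y_{nchw} = \gamma_c \, \frac{x_{nchw} - \mu_{ng}}{\sqrt{\sigma^2_{ng} + \varepsilon}} + \beta_c

The baseline stores \mu and \sigma^2 in Mean and Variance with shape {N, groups}, which is the convention the fixed ARM kernel now matches.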
