Skip to content

Commit

Permalink
pnnx fuse conv3d-bn and deconv3d-bn (#5045)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Sep 21, 2023
1 parent 019176c commit b8d5a5d
Show file tree
Hide file tree
Showing 9 changed files with 529 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,10 @@ set(pnnx_pass_level5_SRCS
pass_level5/fuse_constant_expression.cpp
pass_level5/fuse_conv1d_batchnorm1d.cpp
pass_level5/fuse_conv2d_batchnorm2d.cpp
pass_level5/fuse_conv3d_batchnorm3d.cpp
pass_level5/fuse_convtranspose1d_batchnorm1d.cpp
pass_level5/fuse_convtranspose2d_batchnorm2d.cpp
pass_level5/fuse_convtranspose3d_batchnorm3d.cpp
pass_level5/fuse_contiguous_view.cpp
pass_level5/fuse_linear_batchnorm1d.cpp
pass_level5/fuse_pad_conv1d.cpp
Expand Down
4 changes: 4 additions & 0 deletions tools/pnnx/src/pass_level5.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
#include "pass_level5/fuse_constant_expression.h"
#include "pass_level5/fuse_conv1d_batchnorm1d.h"
#include "pass_level5/fuse_conv2d_batchnorm2d.h"
#include "pass_level5/fuse_conv3d_batchnorm3d.h"
#include "pass_level5/fuse_convtranspose1d_batchnorm1d.h"
#include "pass_level5/fuse_convtranspose2d_batchnorm2d.h"
#include "pass_level5/fuse_convtranspose3d_batchnorm3d.h"
#include "pass_level5/fuse_contiguous_view.h"
#include "pass_level5/fuse_layernorm.h"
#include "pass_level5/fuse_linear_batchnorm1d.h"
Expand Down Expand Up @@ -101,8 +103,10 @@ void pass_level5(Graph& g, const std::set<std::string>& foldable_constants, cons

fuse_conv1d_batchnorm1d(g);
fuse_conv2d_batchnorm2d(g);
fuse_conv3d_batchnorm3d(g);
fuse_convtranspose1d_batchnorm1d(g);
fuse_convtranspose2d_batchnorm2d(g);
fuse_convtranspose3d_batchnorm3d(g);
fuse_linear_batchnorm1d(g);

fuse_pad_conv1d(g);
Expand Down
138 changes: 138 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_conv3d_batchnorm3d.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

class fuse_conv3d_batchnorm3d_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.Conv3d op_0 1 1 input a in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
nn.BatchNorm3d op_1 1 1 a out num_features=%num_features eps=%eps affine=%affine @running_mean @running_var @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.Conv3d";
}

const char* name_str() const
{
return "convbn3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["padding_mode"] = captured_params.at("padding_mode");
op->params["stride"] = captured_params.at("stride");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;

// resolve merged conv3d weight and bias
int channels = captured_params.at("num_features").i;
float bn_eps = captured_params.at("eps").f;
bool has_bn_affine = captured_params.at("affine").b;
bool has_conv_bias = captured_params.at("bias").b;

auto bn_running_mean = captured_attrs.at("op_1.running_mean").get_float32_data();
auto bn_running_var = captured_attrs.at("op_1.running_var").get_float32_data();
auto bn_weight = has_bn_affine ? captured_attrs.at("op_1.weight").get_float32_data() : std::vector<float>();
auto bn_bias = has_bn_affine ? captured_attrs.at("op_1.bias").get_float32_data() : std::vector<float>();

// a = bias - slope * mean / sqrt(var + eps)
// b = slope / sqrt(var + eps)
// value = value * b + a

std::vector<float> a(channels);
std::vector<float> b(channels);
for (int i = 0; i < channels; i++)
{
double sqrt_var = sqrt(bn_running_var[i] + bn_eps);

if (has_bn_affine)
{
a[i] = (float)(bn_bias[i] - bn_weight[i] * bn_running_mean[i] / sqrt_var);
b[i] = (float)(bn_weight[i] / sqrt_var);
}
else
{
a[i] = (float)(-bn_running_mean[i] / sqrt_var);
b[i] = (float)(1.f / sqrt_var);
}
}

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (has_conv_bias)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
else
{
// init bias as zero
op->attrs["bias"] = Attribute();
op->attrs["bias"].type = op->attrs["weight"].type;
op->attrs["bias"].shape = {channels};
op->attrs["bias"].set_float32_data(std::vector<float>(channels, 0.f));
}

auto conv_weight = op->attrs["weight"].get_float32_data();
auto conv_bias = op->attrs["bias"].get_float32_data();

const int outch = captured_params.at("out_channels").i;
const int weight_per_outch = op->attrs["weight"].elemcount() / outch;

for (int i = 0; i < channels; i++)
{
float* conv_weight_outch = conv_weight.data() + weight_per_outch * i;
for (int j = 0; j < weight_per_outch; j++)
{
conv_weight_outch[j] *= b[i];
}

conv_bias[i] = conv_bias[i] * b[i] + a[i];
}

op->attrs["weight"].set_float32_data(conv_weight);
op->attrs["bias"].set_float32_data(conv_bias);
}
};

void fuse_conv3d_batchnorm3d(Graph& graph)
{
fuse_conv3d_batchnorm3d_pass a;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
}

} // namespace pnnx
21 changes: 21 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_conv3d_batchnorm3d.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_conv3d_batchnorm3d(Graph& graph);

} // namespace pnnx
156 changes: 156 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "fuse_convtranspose3d_batchnorm3d.h"

#include "pass_level2.h"

#include <math.h>
#include <string.h>

namespace pnnx {

class fuse_convtranspose3d_batchnorm3d_pass : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
nn.ConvTranspose3d op_0 1 1 input a in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride output_padding=%output_padding padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
nn.BatchNorm3d op_1 1 1 a out num_features=%num_features eps=%eps affine=%affine @running_mean @running_var @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "nn.ConvTranspose3d";
}

const char* name_str() const
{
return "convtransposebn3d";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
op->params["in_channels"] = captured_params.at("in_channels");
op->params["out_channels"] = captured_params.at("out_channels");
op->params["kernel_size"] = captured_params.at("kernel_size");
op->params["stride"] = captured_params.at("stride");
op->params["output_padding"] = captured_params.at("output_padding");
op->params["padding"] = captured_params.at("padding");
op->params["dilation"] = captured_params.at("dilation");
op->params["groups"] = captured_params.at("groups");
op->params["bias"] = true;

// resolve merged convtranspose3d weight and bias
int channels = captured_params.at("num_features").i;
float bn_eps = captured_params.at("eps").f;
bool has_bn_affine = captured_params.at("affine").b;
bool has_convtranspose_bias = captured_params.at("bias").b;

auto bn_running_mean = captured_attrs.at("op_1.running_mean").get_float32_data();
auto bn_running_var = captured_attrs.at("op_1.running_var").get_float32_data();
auto bn_weight = has_bn_affine ? captured_attrs.at("op_1.weight").get_float32_data() : std::vector<float>();
auto bn_bias = has_bn_affine ? captured_attrs.at("op_1.bias").get_float32_data() : std::vector<float>();

// a = bias - slope * mean / sqrt(var + eps)
// b = slope / sqrt(var + eps)
// value = value * b + a

std::vector<float> a(channels);
std::vector<float> b(channels);
for (int i = 0; i < channels; i++)
{
double sqrt_var = sqrt(bn_running_var[i] + bn_eps);

if (has_bn_affine)
{
a[i] = (float)(bn_bias[i] - bn_weight[i] * bn_running_mean[i] / sqrt_var);
b[i] = (float)(bn_weight[i] / sqrt_var);
}
else
{
a[i] = (float)(-bn_running_mean[i] / sqrt_var);
b[i] = (float)(1.f / sqrt_var);
}
}

op->attrs["weight"] = captured_attrs.at("op_0.weight");

if (has_convtranspose_bias)
{
op->attrs["bias"] = captured_attrs.at("op_0.bias");
}
else
{
// init bias as zero
op->attrs["bias"] = Attribute();
op->attrs["bias"].type = op->attrs["weight"].type;
op->attrs["bias"].shape = {channels};
op->attrs["bias"].set_float32_data(std::vector<float>(channels, 0.f));
}

auto conv_weight = op->attrs["weight"].get_float32_data();
auto conv_bias = op->attrs["bias"].get_float32_data();

// group-inch/group-outch/group-kh-kw
const int inch = captured_params.at("in_channels").i;
const int outch = captured_params.at("out_channels").i;
const int groups = captured_params.at("groups").i;
const int kd = captured_params.at("kernel_size").ai[0];
const int kh = captured_params.at("kernel_size").ai[1];
const int kw = captured_params.at("kernel_size").ai[2];

const int outch_g = outch / groups;
const int inch_g = inch / groups;
const int maxk = kd * kh * kw;

for (int g = 0; g < groups; g++)
{
float* wg = (float*)conv_weight.data() + g * inch_g * outch_g * maxk;
for (int i = 0; i < inch_g; i++)
{
for (int j = 0; j < outch_g; j++)
{
for (int k = 0; k < maxk; k++)
{
wg[(i * outch_g + j) * maxk + k] *= b[g * outch_g + j];
}
}
}
}

for (int i = 0; i < channels; i++)
{
conv_bias[i] = conv_bias[i] * b[i] + a[i];
}

op->attrs["weight"].set_float32_data(conv_weight);
op->attrs["bias"].set_float32_data(conv_bias);
}
};

void fuse_convtranspose3d_batchnorm3d(Graph& graph)
{
fuse_convtranspose3d_batchnorm3d_pass a;
int opindex = 0;

pnnx_graph_rewrite(graph, &a, opindex);
}

} // namespace pnnx
21 changes: 21 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_convtranspose3d_batchnorm3d.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void fuse_convtranspose3d_batchnorm3d(Graph& graph);

} // namespace pnnx
2 changes: 2 additions & 0 deletions tools/pnnx/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,10 @@ pnnx_add_test(pnnx_expression)
pnnx_add_test(pnnx_fold_constant)
pnnx_add_test(pnnx_fuse_conv1d_batchnorm1d)
pnnx_add_test(pnnx_fuse_conv2d_batchnorm2d)
pnnx_add_test(pnnx_fuse_conv3d_batchnorm3d)
pnnx_add_test(pnnx_fuse_convtranspose1d_batchnorm1d)
pnnx_add_test(pnnx_fuse_convtranspose2d_batchnorm2d)
pnnx_add_test(pnnx_fuse_convtranspose3d_batchnorm3d)
pnnx_add_test(pnnx_fuse_input_unpack)
pnnx_add_test(pnnx_fuse_layernorm)
pnnx_add_test(pnnx_fuse_linear_batchnorm1d)
Expand Down
Loading

0 comments on commit b8d5a5d

Please sign in to comment.