From 2337e86fdc9aaa359b939dac5c5302598e287ec9 Mon Sep 17 00:00:00 2001
From: nihuini
Date: Mon, 10 Jul 2023 17:28:24 +0800
Subject: [PATCH 1/2] pnnx convert conv with non-zero padding mode

---
 tools/pnnx/src/pass_level1/nn_Conv3d.cpp |  22 ++-
 tools/pnnx/src/pass_ncnn/nn_Conv1d.cpp   | 159 ++++++++++++++++++--
 tools/pnnx/src/pass_ncnn/nn_Conv2d.cpp   | 170 +++++++++++++++++++--
 tools/pnnx/src/pass_ncnn/nn_Conv3d.cpp   | 183 +++++++++++++++++++++--
 tools/pnnx/tests/ncnn/test_nn_Conv1d.py  |   8 +-
 tools/pnnx/tests/ncnn/test_nn_Conv2d.py  |   8 +-
 tools/pnnx/tests/ncnn/test_nn_Conv3d.py  |   8 +-
 7 files changed, 503 insertions(+), 55 deletions(-)

diff --git a/tools/pnnx/src/pass_level1/nn_Conv3d.cpp b/tools/pnnx/src/pass_level1/nn_Conv3d.cpp
index ac87f4c921f..9e18718a7bd 100644
--- a/tools/pnnx/src/pass_level1/nn_Conv3d.cpp
+++ b/tools/pnnx/src/pass_level1/nn_Conv3d.cpp
@@ -47,6 +47,7 @@ class Conv3d : public FuseModulePass
 
         const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution");
         const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode");
+        const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
         // const torch::jit::Node* reflection_pad3d = find_node_by_kind(graph, "aten::reflection_pad3d");
         // const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d");
 
@@ -62,6 +63,25 @@ class Conv3d : public FuseModulePass
         op->params["out_channels"] = weight.size(0);
         op->params["kernel_size"] = Parameter{weight.size(2), weight.size(3), weight.size(4)};
         op->params["stride"] = convolution->namedInput("stride");
+        if (pad)
+        {
+            op->params["padding_mode"] = pad->namedInput("mode");
+            op->params["padding"] = pad->namedInput("pad");
+            std::vector<int>& padding = op->params["padding"].ai;
+            if (padding.size() == 6)
+            {
+                // Conv3d only accepts tuple of three integers
+                if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5])
+                {
+                    padding.resize(3);
+                }
+                else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2])
+                {
+                    padding.resize(0);
+                    op->params["padding"].s = "same";
+                }
+            }
+        }
         // if (reflection_pad3d)
         // {
         //     op->params["padding_mode"] = "reflect";
@@ -100,7 +120,7 @@ class Conv3d : public FuseModulePass
         //         }
         //     }
         // }
-        // else
+        else
         {
             op->params["padding_mode"] = "zeros";
             op->params["padding"] = convolution->namedInput("padding");
diff --git a/tools/pnnx/src/pass_ncnn/nn_Conv1d.cpp b/tools/pnnx/src/pass_ncnn/nn_Conv1d.cpp
index 8d531ef8aac..fae2365e049 100644
--- a/tools/pnnx/src/pass_ncnn/nn_Conv1d.cpp
+++ b/tools/pnnx/src/pass_ncnn/nn_Conv1d.cpp
@@ -26,7 +26,7 @@ class nn_Conv1d : public GraphRewriterPass
         return R"PNNXIR(7767517
 3 2
 pnnx.Input              input       0 1 input
-nn.Conv1d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
+nn.Conv1d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
 pnnx.Output             output      1 0 out
 )PNNXIR";
     }
@@ -43,12 +43,6 @@ pnnx.Output             output      1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
     {
-        std::string padding_mode = captured_params.at("padding_mode").s;
-        if (padding_mode != "zeros")
-        {
-            fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
-        }
-
         op->params["0"] = captured_params.at("out_channels");
         op->params["1"] = captured_params.at("kernel_size").ai[0];
         op->params["2"] = captured_params.at("dilation").ai[0];
@@ -83,7 +77,7 @@ class nn_Conv1d_1 : public GraphRewriterPass
         return R"PNNXIR(7767517
 3 2
 pnnx.Input              input       0 1 input
-nn.Conv1d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
+nn.Conv1d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
 pnnx.Output             output      1 0 out
 )PNNXIR";
     }
@@ -100,12 +94,6 @@ pnnx.Output             output      1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
     {
-        std::string padding_mode = captured_params.at("padding_mode").s;
-        if (padding_mode != "zeros")
-        {
-            fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
-        }
-
         op->params["0"] = captured_params.at("out_channels");
         op->params["1"] = captured_params.at("kernel_size").ai[0];
         op->params["2"] = captured_params.at("dilation").ai[0];
@@ -133,8 +121,151 @@ pnnx.Output             output      1 0 out
     }
 };
 
+class nn_Conv1d_2 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input              input       0 1 input
+nn.Conv1d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input              input       0 1 input
+Padding                 pad         1 1 input a
+Convolution1D           conv        1 1 a out
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        const std::string& padding_mode = captured_params.at("padding_mode").s;
+        if (padding_mode == "zeros")
+            return false;
+
+        return true;
+    }
+
+    bool match(const std::map<std::string, const Operator*>& matched_operators) const
+    {
+        const Operator* conv = matched_operators.at("op_0");
+        if (conv->params.at("padding").type == 4 && conv->params.at("padding").s == "same")
+        {
+            const std::vector<int> input_shape = conv->inputs[0]->shape;
+            if (input_shape.size() != 2 && input_shape.size() != 3)
+            {
+                fprintf(stderr, "can not resolve pads without shape\n");
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        std::vector<int> padding;
+        if (captured_params.at("padding").type == 4)
+        {
+            if (captured_params.at("padding").s == "same")
+            {
+                // resolve pads
+                const std::vector<int> input_shape = ops.at("pad")->inputs[0]->shape;
+                const int w = input_shape[input_shape.size() - 1];
+                const int kernel_w = captured_params.at("kernel_size").ai[0];
+                const int dilation_w = captured_params.at("dilation").ai[0];
+                const int stride_w = captured_params.at("stride").ai[0];
+
+                const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
+
+                int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w;
+
+                padding = std::vector<int>{wpad / 2, wpad - wpad / 2};
+            }
+            else if (captured_params.at("padding").s == "valid")
+            {
+                padding = std::vector<int>{0, 0};
+            }
+        }
+        else
+        {
+            int wpad = captured_params.at("padding").ai[0];
+            padding = std::vector<int>{wpad, wpad};
+        }
+
+        ops.at("pad")->params["0"] = 0;
+        ops.at("pad")->params["1"] = 0;
+        ops.at("pad")->params["2"] = padding[0];
+        ops.at("pad")->params["3"] = padding[1];
+
+        std::string padding_mode = captured_params.at("padding_mode").s;
+        if (padding_mode == "reflect")
+        {
+            ops.at("pad")->params["4"] = 2; // type=reflect
+        }
+        else if (padding_mode == "replicate")
+        {
+            ops.at("pad")->params["4"] = 1; // type=replicate
+        }
+        else
+        {
+            fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
+        }
+
+        ops.at("conv")->params["0"] = captured_params.at("out_channels");
+        ops.at("conv")->params["1"] = captured_params.at("kernel_size").ai[0];
+        ops.at("conv")->params["2"] = captured_params.at("dilation").ai[0];
+        ops.at("conv")->params["3"] = captured_params.at("stride").ai[0];
+        ops.at("conv")->params["4"] = 0;
+        ops.at("conv")->params["5"] = captured_params.at("bias").b ? 1 : 0;
+        ops.at("conv")->params["6"] = captured_attrs.at("op_0.weight").elemcount();
+        ops.at("conv")->params["7"] = captured_params.find("groups") != captured_params.end() ? captured_params.at("groups") : 1;
+
+        ops.at("conv")->attrs["0"] = Attribute();
+        ops.at("conv")->attrs["0"].data = {0, 0, 0, 0};
+        ops.at("conv")->attrs["1"] = captured_attrs.at("op_0.weight");
+        if (captured_params.at("bias").b)
+            ops.at("conv")->attrs["2"] = captured_attrs.at("op_0.bias");
+    }
+};
+
+class nn_Conv1d_3 : public nn_Conv1d_2
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input              input       0 1 input
+nn.Conv1d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input              input       0 1 input
+Padding                 pad         1 1 input a
+ConvolutionDepthWise1D  conv        1 1 a out
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+};
+
 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv1d, 20)
 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv1d_1, 21)
+REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv1d_2, 22)
+REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv1d_3, 23)
 
 } // namespace ncnn
 
diff --git a/tools/pnnx/src/pass_ncnn/nn_Conv2d.cpp b/tools/pnnx/src/pass_ncnn/nn_Conv2d.cpp
index 74aee445a30..a033d7cd213 100644
--- a/tools/pnnx/src/pass_ncnn/nn_Conv2d.cpp
+++ b/tools/pnnx/src/pass_ncnn/nn_Conv2d.cpp
@@ -26,7 +26,7 @@ class nn_Conv2d : public GraphRewriterPass
         return R"PNNXIR(7767517
 3 2
 pnnx.Input              input       0 1 input
-nn.Conv2d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
+nn.Conv2d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
 pnnx.Output             output      1 0 out
 )PNNXIR";
     }
@@ -43,12 +43,6 @@ pnnx.Output             output      1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
     {
-        std::string padding_mode = captured_params.at("padding_mode").s;
-        if (padding_mode != "zeros")
-        {
-            fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
-        }
-
         op->params["0"] = captured_params.at("out_channels");
         op->params["1"] = captured_params.at("kernel_size").ai[1];
         op->params["11"] = captured_params.at("kernel_size").ai[0];
@@ -87,7 +81,7 @@ class nn_Conv2d_1 : public GraphRewriterPass
         return R"PNNXIR(7767517
 3 2
 pnnx.Input              input       0 1 input
-nn.Conv2d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
+nn.Conv2d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
 pnnx.Output             output      1 0 out
 )PNNXIR";
     }
@@ -104,12 +98,6 @@ pnnx.Output             output      1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
     {
-        std::string padding_mode = captured_params.at("padding_mode").s;
-        if (padding_mode != "zeros")
-        {
-            fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
-        }
-
         op->params["0"] = captured_params.at("out_channels");
         op->params["1"] = captured_params.at("kernel_size").ai[1];
         op->params["11"] = captured_params.at("kernel_size").ai[0];
@@ -141,8 +129,162 @@ pnnx.Output             output      1 0 out
     }
 };
 
+class nn_Conv2d_2 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input              input       0 1 input
+nn.Conv2d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input              input       0 1 input
+Padding                 pad         1 1 input a
+Convolution             conv        1 1 a out
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        const std::string& padding_mode = captured_params.at("padding_mode").s;
+        if (padding_mode == "zeros")
+            return false;
+
+        return true;
+    }
+
+    bool match(const std::map<std::string, const Operator*>& matched_operators) const
+    {
+        const Operator* conv = matched_operators.at("op_0");
+        if (conv->params.at("padding").type == 4 && conv->params.at("padding").s == "same")
+        {
+            const std::vector<int> input_shape = conv->inputs[0]->shape;
+            if (input_shape.size() != 3 && input_shape.size() != 4)
+            {
+                fprintf(stderr, "can not resolve pads without shape\n");
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        std::vector<int> padding;
+        if (captured_params.at("padding").type == 4)
+        {
+            if (captured_params.at("padding").s == "same")
+            {
+                // resolve pads
+                const std::vector<int> input_shape = ops.at("pad")->inputs[0]->shape;
+                const int w = input_shape[input_shape.size() - 1];
+                const int h = input_shape[input_shape.size() - 2];
+                const int kernel_w = captured_params.at("kernel_size").ai[1];
+                const int kernel_h = captured_params.at("kernel_size").ai[0];
+                const int dilation_w = captured_params.at("dilation").ai[1];
+                const int dilation_h = captured_params.at("dilation").ai[0];
+                const int stride_w = captured_params.at("stride").ai[1];
+                const int stride_h = captured_params.at("stride").ai[0];
+
+                const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
+                const int kernel_extent_h = dilation_h * (kernel_h - 1) + 1;
+
+                int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w;
+                int hpad = kernel_extent_h + (h - 1) / stride_h * stride_h - h;
+
+                padding = std::vector<int>{hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2};
+            }
+            else if (captured_params.at("padding").s == "valid")
+            {
+                padding = std::vector<int>{0, 0, 0, 0};
+            }
+        }
+        else
+        {
+            int hpad = captured_params.at("padding").ai[0];
+            int wpad = captured_params.at("padding").ai[1];
+            padding = std::vector<int>{hpad, hpad, wpad, wpad};
+        }
+
+        ops.at("pad")->params["0"] = padding[0];
+        ops.at("pad")->params["1"] = padding[1];
+        ops.at("pad")->params["2"] = padding[2];
+        ops.at("pad")->params["3"] = padding[3];
+
+        std::string padding_mode = captured_params.at("padding_mode").s;
+        if (padding_mode == "reflect")
+        {
+            ops.at("pad")->params["4"] = 2; // type=reflect
+        }
+        else if (padding_mode == "replicate")
+        {
+            ops.at("pad")->params["4"] = 1; // type=replicate
+        }
+        else
+        {
+            fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
+        }
+
+        ops.at("conv")->params["0"] = captured_params.at("out_channels");
+        ops.at("conv")->params["1"] = captured_params.at("kernel_size").ai[1];
+        ops.at("conv")->params["11"] = captured_params.at("kernel_size").ai[0];
+        ops.at("conv")->params["2"] = captured_params.at("dilation").ai[1];
+        ops.at("conv")->params["12"] = captured_params.at("dilation").ai[0];
+        ops.at("conv")->params["3"] = captured_params.at("stride").ai[1];
+        ops.at("conv")->params["13"] = captured_params.at("stride").ai[0];
+        ops.at("conv")->params["4"] = 0;
+        ops.at("conv")->params["14"] = 0;
+        ops.at("conv")->params["5"] = captured_params.at("bias").b ? 1 : 0;
+        ops.at("conv")->params["6"] = captured_attrs.at("op_0.weight").elemcount();
+        ops.at("conv")->params["7"] = captured_params.find("groups") != captured_params.end() ? captured_params.at("groups") : 1;
+
+        ops.at("conv")->attrs["0"] = Attribute();
+        ops.at("conv")->attrs["0"].data = {0, 0, 0, 0};
+        ops.at("conv")->attrs["1"] = captured_attrs.at("op_0.weight");
+        if (captured_params.at("bias").b)
+            ops.at("conv")->attrs["2"] = captured_attrs.at("op_0.bias");
+    }
+};
+
+class nn_Conv2d_3 : public nn_Conv2d_2
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input              input       0 1 input
+nn.Conv2d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input              input       0 1 input
+Padding                 pad         1 1 input a
+ConvolutionDepthWise    conv        1 1 a out
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+};
+
 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv2d, 20)
 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv2d_1, 21)
+REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv2d_2, 22)
+REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv2d_3, 23)
 
 } // namespace ncnn
 
diff --git a/tools/pnnx/src/pass_ncnn/nn_Conv3d.cpp b/tools/pnnx/src/pass_ncnn/nn_Conv3d.cpp
index 0f33cb0d819..b39e47d26be 100644
--- a/tools/pnnx/src/pass_ncnn/nn_Conv3d.cpp
+++ b/tools/pnnx/src/pass_ncnn/nn_Conv3d.cpp
@@ -26,7 +26,7 @@ class nn_Conv3d : public GraphRewriterPass
         return R"PNNXIR(7767517
 3 2
 pnnx.Input              input       0 1 input
-nn.Conv3d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
+nn.Conv3d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
 pnnx.Output             output      1 0 out
 )PNNXIR";
     }
@@ -43,12 +43,6 @@ pnnx.Output             output      1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
     {
-        std::string padding_mode = captured_params.at("padding_mode").s;
-        if (padding_mode != "zeros")
-        {
-            fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
-        }
-
         op->params["0"] = captured_params.at("out_channels");
         op->params["1"] = captured_params.at("kernel_size").ai[2];
         op->params["11"] = captured_params.at("kernel_size").ai[1];
@@ -91,7 +85,7 @@ class nn_Conv3d_1 : public GraphRewriterPass
         return R"PNNXIR(7767517
 3 2
 pnnx.Input              input       0 1 input
-nn.Conv3d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
+nn.Conv3d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
 pnnx.Output             output      1 0 out
 )PNNXIR";
     }
@@ -108,12 +102,6 @@ pnnx.Output             output      1 0 out
 
     void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
     {
-        std::string padding_mode = captured_params.at("padding_mode").s;
-        if (padding_mode != "zeros")
-        {
-            fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
-        }
-
         op->params["0"] = captured_params.at("out_channels");
         op->params["1"] = captured_params.at("kernel_size").ai[2];
         op->params["11"] = captured_params.at("kernel_size").ai[1];
@@ -149,8 +137,175 @@ pnnx.Output             output      1 0 out
     }
 };
 
+class nn_Conv3d_2 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input              input       0 1 input
+nn.Conv3d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input              input       0 1 input
+Padding                 pad         1 1 input a
+Convolution3D           conv        1 1 a out
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+
+    bool match(const std::map<std::string, Parameter>& captured_params) const
+    {
+        const std::string& padding_mode = captured_params.at("padding_mode").s;
+        if (padding_mode == "zeros")
+            return false;
+
+        return true;
+    }
+
+    bool match(const std::map<std::string, const Operator*>& matched_operators) const
+    {
+        const Operator* conv = matched_operators.at("op_0");
+        if (conv->params.at("padding").type == 4 && conv->params.at("padding").s == "same")
+        {
+            const std::vector<int> input_shape = conv->inputs[0]->shape;
+            if (input_shape.size() != 4 && input_shape.size() != 5)
+            {
+                fprintf(stderr, "can not resolve pads without shape\n");
+                return false;
+            }
+        }
+
+        return true;
+    }
+
+    void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
+    {
+        std::vector<int> padding;
+        if (captured_params.at("padding").type == 4)
+        {
+            if (captured_params.at("padding").s == "same")
+            {
+                // resolve pads
+                const std::vector<int> input_shape = ops.at("pad")->inputs[0]->shape;
+                const int w = input_shape[input_shape.size() - 1];
+                const int h = input_shape[input_shape.size() - 2];
+                const int d = input_shape[input_shape.size() - 3];
+                const int kernel_w = captured_params.at("kernel_size").ai[2];
+                const int kernel_h = captured_params.at("kernel_size").ai[1];
+                const int kernel_d = captured_params.at("kernel_size").ai[0];
+                const int dilation_w = captured_params.at("dilation").ai[2];
+                const int dilation_h = captured_params.at("dilation").ai[1];
+                const int dilation_d = captured_params.at("dilation").ai[0];
+                const int stride_w = captured_params.at("stride").ai[2];
+                const int stride_h = captured_params.at("stride").ai[1];
+                const int stride_d = captured_params.at("stride").ai[0];
+
+                const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;
+                const int kernel_extent_h = dilation_h * (kernel_h - 1) + 1;
+                const int kernel_extent_d = dilation_d * (kernel_d - 1) + 1;
+
+                int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w;
+                int hpad = kernel_extent_h + (h - 1) / stride_h * stride_h - h;
+                int dpad = kernel_extent_d + (d - 1) / stride_d * stride_d - d;
+
+                padding = std::vector<int>{hpad / 2, hpad - hpad / 2, wpad / 2, wpad - wpad / 2, dpad / 2, dpad - dpad / 2};
+            }
+            else if (captured_params.at("padding").s == "valid")
+            {
+                padding = std::vector<int>{0, 0, 0, 0, 0, 0};
+            }
+        }
+        else
+        {
+            int dpad = captured_params.at("padding").ai[0];
+            int hpad = captured_params.at("padding").ai[1];
+            int wpad = captured_params.at("padding").ai[2];
+            padding = std::vector<int>{hpad, hpad, wpad, wpad, dpad, dpad};
+        }
+
+        ops.at("pad")->params["0"] = padding[0];
+        ops.at("pad")->params["1"] = padding[1];
+        ops.at("pad")->params["2"] = padding[2];
+        ops.at("pad")->params["3"] = padding[3];
+        ops.at("pad")->params["7"] = padding[4];
+        ops.at("pad")->params["8"] = padding[5];
+
+        std::string padding_mode = captured_params.at("padding_mode").s;
+        if (padding_mode == "reflect")
+        {
+            ops.at("pad")->params["4"] = 2; // type=reflect
+        }
+        else if (padding_mode == "replicate")
+        {
+            ops.at("pad")->params["4"] = 1; // type=replicate
+        }
+        else
+        {
+            fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
+        }
+
+        ops.at("conv")->params["0"] = captured_params.at("out_channels");
+        ops.at("conv")->params["1"] = captured_params.at("kernel_size").ai[2];
+        ops.at("conv")->params["11"] = captured_params.at("kernel_size").ai[1];
+        ops.at("conv")->params["21"] = captured_params.at("kernel_size").ai[0];
+        ops.at("conv")->params["2"] = captured_params.at("dilation").ai[2];
+        ops.at("conv")->params["12"] = captured_params.at("dilation").ai[1];
+        ops.at("conv")->params["22"] = captured_params.at("dilation").ai[0];
+        ops.at("conv")->params["3"] = captured_params.at("stride").ai[2];
+        ops.at("conv")->params["13"] = captured_params.at("stride").ai[1];
+        ops.at("conv")->params["23"] = captured_params.at("stride").ai[0];
+        ops.at("conv")->params["4"] = 0;
+        ops.at("conv")->params["14"] = 0;
+        ops.at("conv")->params["24"] = 0;
+        ops.at("conv")->params["5"] = captured_params.at("bias").b ? 1 : 0;
+        ops.at("conv")->params["6"] = captured_attrs.at("op_0.weight").elemcount();
+        ops.at("conv")->params["7"] = captured_params.find("groups") != captured_params.end() ? captured_params.at("groups") : 1;
+
+        ops.at("conv")->attrs["0"] = Attribute();
+        ops.at("conv")->attrs["0"].data = {0, 0, 0, 0};
+        ops.at("conv")->attrs["1"] = captured_attrs.at("op_0.weight");
+        if (captured_params.at("bias").b)
+            ops.at("conv")->attrs["2"] = captured_attrs.at("op_0.bias");
+    }
+};
+
+class nn_Conv3d_3 : public nn_Conv3d_2
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+3 2
+pnnx.Input              input       0 1 input
+nn.Conv3d               op_0        1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+
+    const char* replace_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+4 3
+pnnx.Input              input       0 1 input
+Padding                 pad         1 1 input a
+ConvolutionDepthWise3D  conv        1 1 a out
+pnnx.Output             output      1 0 out
+)PNNXIR";
+    }
+};
+
 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv3d, 20)
 REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv3d_1, 21)
+REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv3d_2, 22)
+REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv3d_3, 23)
 
 } // namespace ncnn
 
diff --git a/tools/pnnx/tests/ncnn/test_nn_Conv1d.py b/tools/pnnx/tests/ncnn/test_nn_Conv1d.py
index ce187d2512f..18aa1a9de49 100644
--- a/tools/pnnx/tests/ncnn/test_nn_Conv1d.py
+++ b/tools/pnnx/tests/ncnn/test_nn_Conv1d.py
@@ -30,8 +30,8 @@ def __init__(self):
         else:
             self.conv_3 = nn.Conv1d(in_channels=24, out_channels=28, kernel_size=5, stride=1, padding='valid', dilation=1, groups=4, bias=True)
         self.conv_4 = nn.Conv1d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=2, groups=2, bias=False, padding_mode='zeros')
-        #self.conv_5 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
-        #self.conv_6 = nn.Conv1d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
+        self.conv_5 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
+        self.conv_6 = nn.Conv1d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
 
     def forward(self, x):
         x = self.conv_0(x)
@@ -39,8 +39,8 @@ def forward(self, x):
         x = self.conv_2(x)
         x = self.conv_3(x)
         x = self.conv_4(x)
-        #x = self.conv_5(x)
-        #x = self.conv_6(x)
+        x = self.conv_5(x)
+        x = self.conv_6(x)
 
         return x
 
diff --git a/tools/pnnx/tests/ncnn/test_nn_Conv2d.py b/tools/pnnx/tests/ncnn/test_nn_Conv2d.py
index 19aa5fe9645..5cde9e1b530 100644
--- a/tools/pnnx/tests/ncnn/test_nn_Conv2d.py
+++ b/tools/pnnx/tests/ncnn/test_nn_Conv2d.py
@@ -30,8 +30,8 @@ def __init__(self):
         else:
             self.conv_3 = nn.Conv2d(in_channels=24, out_channels=28, kernel_size=(5,4), stride=1, padding='valid', dilation=1, groups=4, bias=True)
         self.conv_4 = nn.Conv2d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=(1,2), groups=2, bias=False, padding_mode='zeros')
-        #self.conv_5 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
-        #self.conv_6 = nn.Conv2d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
+        self.conv_5 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
+        self.conv_6 = nn.Conv2d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
 
     def forward(self, x):
         x = self.conv_0(x)
@@ -39,8 +39,8 @@ def forward(self, x):
         x = self.conv_2(x)
         x = self.conv_3(x)
         x = self.conv_4(x)
-        #x = self.conv_5(x)
-        #x = self.conv_6(x)
+        x = self.conv_5(x)
+        x = self.conv_6(x)
 
         return x
 
diff --git a/tools/pnnx/tests/ncnn/test_nn_Conv3d.py b/tools/pnnx/tests/ncnn/test_nn_Conv3d.py
index cfe9382d645..3fa339ecfad 100644
--- a/tools/pnnx/tests/ncnn/test_nn_Conv3d.py
+++ b/tools/pnnx/tests/ncnn/test_nn_Conv3d.py
@@ -30,8 +30,8 @@ def __init__(self):
         else:
             self.conv_3 = nn.Conv3d(in_channels=24, out_channels=28, kernel_size=(5,4,3), stride=1, padding='valid', dilation=1, groups=4, bias=True)
         self.conv_4 = nn.Conv3d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=(1,2,2), groups=2, bias=False, padding_mode='zeros')
-        #self.conv_5 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
-        #self.conv_6 = nn.Conv3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
+        self.conv_5 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
+        self.conv_6 = nn.Conv3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
         #self.conv_7 = nn.Conv3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular')
 
     def forward(self, x):
@@ -40,8 +40,8 @@ def forward(self, x):
         x = self.conv_2(x)
         x = self.conv_3(x)
         x = self.conv_4(x)
-        #x = self.conv_5(x)
-        #x = self.conv_6(x)
+        x = self.conv_5(x)
+        x = self.conv_6(x)
         #x = self.conv_7(x)
 
         return x

From 10ed8b209dfee2f62cd9a641c3a86cfaf31a4812 Mon Sep 17 00:00:00 2001
From: nihui
Date: Mon, 10 Jul 2023 22:45:44 +0800
Subject: [PATCH 2/2] fix

---
 tools/pnnx/src/pass_level1/nn_Conv3d.cpp | 80 ++++++++++++------------
 tools/pnnx/tests/ncnn/test_nn_Conv3d.py  | 10 ++-
 tools/pnnx/tests/test_nn_Conv3d.py       | 14 +++--
 3 files changed, 56 insertions(+), 48 deletions(-)

diff --git a/tools/pnnx/src/pass_level1/nn_Conv3d.cpp b/tools/pnnx/src/pass_level1/nn_Conv3d.cpp
index 9e18718a7bd..271c024d90b 100644
--- a/tools/pnnx/src/pass_level1/nn_Conv3d.cpp
+++ b/tools/pnnx/src/pass_level1/nn_Conv3d.cpp
@@ -48,8 +48,8 @@ class Conv3d : public FuseModulePass
         const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution");
         const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode");
         const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
-        // const torch::jit::Node* reflection_pad3d = find_node_by_kind(graph, "aten::reflection_pad3d");
-        // const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d");
+        const torch::jit::Node* reflection_pad3d = find_node_by_kind(graph, "aten::reflection_pad3d");
+        const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d");
 
         if (convolution_mode)
         {
@@ -82,44 +82,44 @@ class Conv3d : public FuseModulePass
                 }
             }
         }
-        // if (reflection_pad3d)
-        // {
-        //     op->params["padding_mode"] = "reflect";
-        //     op->params["padding"] = reflection_pad3d->namedInput("padding");
-        //     std::vector<int>& padding = op->params["padding"].ai;
-        //     if (padding.size() == 6)
-        //     {
-        //         // Conv3d only accepts tuple of three integers
-        //         if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5])
-        //         {
-        //             padding.resize(3);
-        //         }
-        //         else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2])
-        //         {
-        //             padding.resize(0);
-        //             op->params["padding"].s = "same";
-        //         }
-        //     }
-        // }
-        // else if (replication_pad3d)
-        // {
-        //     op->params["padding_mode"] = "replicate";
-        //     op->params["padding"] = replication_pad3d->namedInput("padding");
-        //     std::vector<int>& padding = op->params["padding"].ai;
-        //     if (padding.size() == 6)
-        //     {
-        //         // Conv3d only accepts tuple of three integers
-        //         if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5])
-        //         {
-        //             padding.resize(3);
-        //         }
-        //         else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2])
-        //         {
-        //             padding.resize(0);
-        //             op->params["padding"].s = "same";
-        //         }
-        //     }
-        // }
+        else if (reflection_pad3d)
+        {
+            op->params["padding_mode"] = "reflect";
+            op->params["padding"] = reflection_pad3d->namedInput("padding");
+            std::vector<int>& padding = op->params["padding"].ai;
+            if (padding.size() == 6)
+            {
+                // Conv3d only accepts tuple of three integers
+                if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5])
+                {
+                    padding.resize(3);
+                }
+                else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2])
+                {
+                    padding.resize(0);
+                    op->params["padding"].s = "same";
+                }
+            }
+        }
+        else if (replication_pad3d)
+        {
+            op->params["padding_mode"] = "replicate";
+            op->params["padding"] = replication_pad3d->namedInput("padding");
+            std::vector<int>& padding = op->params["padding"].ai;
+            if (padding.size() == 6)
+            {
+                // Conv3d only accepts tuple of three integers
+                if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5])
+                {
+                    padding.resize(3);
+                }
+                else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2])
+                {
+                    padding.resize(0);
+                    op->params["padding"].s = "same";
+                }
+            }
+        }
         else
         {
             op->params["padding_mode"] = "zeros";
diff --git a/tools/pnnx/tests/ncnn/test_nn_Conv3d.py b/tools/pnnx/tests/ncnn/test_nn_Conv3d.py
index 3fa339ecfad..b68a937af9d 100644
--- a/tools/pnnx/tests/ncnn/test_nn_Conv3d.py
+++ b/tools/pnnx/tests/ncnn/test_nn_Conv3d.py
@@ -30,9 +30,10 @@ def __init__(self):
         else:
             self.conv_3 = nn.Conv3d(in_channels=24, out_channels=28, kernel_size=(5,4,3), stride=1, padding='valid', dilation=1, groups=4, bias=True)
         self.conv_4 = nn.Conv3d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=(1,2,2), groups=2, bias=False, padding_mode='zeros')
-        self.conv_5 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
-        self.conv_6 = nn.Conv3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
-        #self.conv_7 = nn.Conv3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular')
+        if version.parse(torch.__version__) >= version.parse('1.10'):
+            self.conv_5 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
+            self.conv_6 = nn.Conv3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
+            # self.conv_7 = nn.Conv3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular')
 
     def forward(self, x):
         x = self.conv_0(x)
@@ -40,6 +41,9 @@ def forward(self, x):
         x = self.conv_2(x)
         x = self.conv_3(x)
         x = self.conv_4(x)
+        if version.parse(torch.__version__) < version.parse('1.10'):
+            return x
+
         x = self.conv_5(x)
         x = self.conv_6(x)
         #x = self.conv_7(x)
diff --git a/tools/pnnx/tests/test_nn_Conv3d.py b/tools/pnnx/tests/test_nn_Conv3d.py
index 250d496020a..76221d9bca4 100644
--- a/tools/pnnx/tests/test_nn_Conv3d.py
+++ b/tools/pnnx/tests/test_nn_Conv3d.py
@@ -30,9 +30,10 @@ def __init__(self):
         else:
             self.conv_3 = nn.Conv3d(in_channels=24, out_channels=28, kernel_size=(5,4,3), stride=1, padding='valid', dilation=1, groups=4, bias=True)
         self.conv_4 = nn.Conv3d(in_channels=28, out_channels=32, kernel_size=3, stride=1, padding='same', dilation=(1,2,2), groups=2, bias=False, padding_mode='zeros')
-        #self.conv_5 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
-        #self.conv_6 = nn.Conv3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
-        #self.conv_7 = nn.Conv3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular')
+        if version.parse(torch.__version__) >= version.parse('1.10'):
+            self.conv_5 = nn.Conv3d(in_channels=32, out_channels=32, kernel_size=2, stride=2, padding=3, dilation=1, groups=32, bias=True, padding_mode='reflect')
+            self.conv_6 = nn.Conv3d(in_channels=32, out_channels=28, kernel_size=2, stride=1, padding=2, dilation=1, groups=1, bias=False, padding_mode='replicate')
+            # self.conv_7 = nn.Conv3d(in_channels=28, out_channels=24, kernel_size=3, stride=2, padding=(5,6), dilation=2, groups=1, bias=True, padding_mode='circular')
 
     def forward(self, x):
         x = self.conv_0(x)
@@ -40,8 +41,11 @@ def forward(self, x):
         x = self.conv_2(x)
         x = self.conv_3(x)
         x = self.conv_4(x)
-        #x = self.conv_5(x)
-        #x = self.conv_6(x)
+        if version.parse(torch.__version__) < version.parse('1.10'):
+            return x
+
+        x = self.conv_5(x)
+        x = self.conv_6(x)
         #x = self.conv_7(x)
 
         return x