pnnx convert conv with non-zero padding mode (#4849)
nihui authored Jul 10, 2023
1 parent f1943fd commit a87be24
Showing 8 changed files with 557 additions and 101 deletions.
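To make the case this commit targets concrete: a convolution whose padding_mode is anything other than "zeros". Below is a minimal libtorch sketch, not part of the commit — the channel counts, kernel size, and the kReflect choice are arbitrary illustration:

#include <iostream>
#include <torch/torch.h>

int main()
{
    // 1D convolution that reflect-pads its input instead of zero-padding it
    torch::nn::Conv1d conv(
        torch::nn::Conv1dOptions(/*in_channels=*/3, /*out_channels=*/8, /*kernel_size=*/5)
            .stride(1)
            .padding(2)
            .padding_mode(torch::kReflect));

    torch::Tensor x = torch::randn({1, 3, 32});
    torch::Tensor y = conv(x); // reflect-pad 2 on each side, then convolve
    std::cout << y.sizes() << std::endl; // [1, 8, 32]
    return 0;
}

Previously the ncnn conversion passes below only printed "unsupported padding_mode" for such a module; with this change the rewriter emits an explicit Padding layer in front of the convolution instead.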
102 changes: 61 additions & 41 deletions tools/pnnx/src/pass_level1/nn_Conv3d.cpp
@@ -47,8 +47,9 @@ class Conv3d : public FuseModulePass

const torch::jit::Node* convolution = find_node_by_kind(graph, "aten::_convolution");
const torch::jit::Node* convolution_mode = find_node_by_kind(graph, "aten::_convolution_mode");
// const torch::jit::Node* reflection_pad3d = find_node_by_kind(graph, "aten::reflection_pad3d");
// const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d");
const torch::jit::Node* pad = find_node_by_kind(graph, "aten::pad");
const torch::jit::Node* reflection_pad3d = find_node_by_kind(graph, "aten::reflection_pad3d");
const torch::jit::Node* replication_pad3d = find_node_by_kind(graph, "aten::replication_pad3d");

if (convolution_mode)
{
@@ -62,45 +63,64 @@
op->params["out_channels"] = weight.size(0);
op->params["kernel_size"] = Parameter{weight.size(2), weight.size(3), weight.size(4)};
op->params["stride"] = convolution->namedInput("stride");
// if (reflection_pad3d)
// {
// op->params["padding_mode"] = "reflect";
// op->params["padding"] = reflection_pad3d->namedInput("padding");
// std::vector<int>& padding = op->params["padding"].ai;
// if (padding.size() == 6)
// {
// // Conv3d only accepts tuple of three integers
// if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5])
// {
// padding.resize(3);
// }
// else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2])
// {
// padding.resize(0);
// op->params["padding"].s = "same";
// }
// }
// }
// else if (replication_pad3d)
// {
// op->params["padding_mode"] = "replicate";
// op->params["padding"] = replication_pad3d->namedInput("padding");
// std::vector<int>& padding = op->params["padding"].ai;
// if (padding.size() == 6)
// {
// // Conv3d only accepts tuple of three integers
// if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5])
// {
// padding.resize(3);
// }
// else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2])
// {
// padding.resize(0);
// op->params["padding"].s = "same";
// }
// }
// }
// else
if (pad)
{
op->params["padding_mode"] = pad->namedInput("mode");
op->params["padding"] = pad->namedInput("pad");
std::vector<int>& padding = op->params["padding"].ai;
if (padding.size() == 6)
{
// Conv3d only accepts tuple of three integers
if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5])
{
padding.resize(3);
}
else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2])
{
padding.resize(0);
op->params["padding"].s = "same";
}
}
}
else if (reflection_pad3d)
{
op->params["padding_mode"] = "reflect";
op->params["padding"] = reflection_pad3d->namedInput("padding");
std::vector<int>& padding = op->params["padding"].ai;
if (padding.size() == 6)
{
// Conv3d only accepts tuple of three integers
if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5])
{
padding.resize(3);
}
else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2])
{
padding.resize(0);
op->params["padding"].s = "same";
}
}
}
else if (replication_pad3d)
{
op->params["padding_mode"] = "replicate";
op->params["padding"] = replication_pad3d->namedInput("padding");
std::vector<int>& padding = op->params["padding"].ai;
if (padding.size() == 6)
{
// Conv3d only accepts tuple of three integers
if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5])
{
padding.resize(3);
}
else if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2])
{
padding.resize(0);
op->params["padding"].s = "same";
}
}
}
else
{
op->params["padding_mode"] = "zeros";
op->params["padding"] = convolution->namedInput("padding");
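The six-element pad collapse above is the subtle part of this pass, so here is a standalone sketch of just that rule — illustration only, not code from the commit. aten::pad supplies one begin/end pair per padded dimension, while nn.Conv3d stores padding as three integers or the string "same", so only two patterns are recoverable:

#include <cstdio>
#include <vector>

// Mirrors the branch structure above: shrink `padding` to the Conv3d form
// when possible, flagging the pattern the pass maps to "same".
static bool collapse_pad3d(std::vector<int>& padding, bool& is_same)
{
    is_same = false;
    if (padding.size() != 6)
        return false;

    // all six values equal -> symmetric tuple of three integers
    if (padding[0] == padding[1] && padding[1] == padding[2] && padding[2] == padding[3] && padding[3] == padding[4] && padding[4] == padding[5])
    {
        padding.resize(3);
        return true;
    }

    // first half repeats as second half with differing values -> "same"
    if (padding[0] == padding[3] && padding[1] == padding[4] && padding[2] == padding[5] && padding[0] != padding[1] && padding[1] != padding[2])
    {
        padding.clear();
        is_same = true;
        return true;
    }

    return false; // anything else cannot be expressed on nn.Conv3d
}

int main()
{
    std::vector<int> a = {1, 1, 1, 1, 1, 1};
    std::vector<int> b = {2, 1, 0, 2, 1, 0};
    bool same;
    collapse_pad3d(a, same); // a -> {1, 1, 1}, same == false
    collapse_pad3d(b, same); // b -> {},        same == true
    printf("a.size()=%zu same_b=%d\n", a.size(), (int)same);
    return 0;
}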
159 changes: 145 additions & 14 deletions tools/pnnx/src/pass_ncnn/nn_Conv1d.cpp
@@ -26,7 +26,7 @@ class nn_Conv1d : public GraphRewriterPass
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv1d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
nn.Conv1d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}
@@ -43,12 +43,6 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
std::string padding_mode = captured_params.at("padding_mode").s;
if (padding_mode != "zeros")
{
fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
}

op->params["0"] = captured_params.at("out_channels");
op->params["1"] = captured_params.at("kernel_size").ai[0];
op->params["2"] = captured_params.at("dilation").ai[0];
@@ -83,7 +77,7 @@ class nn_Conv1d_1 : public GraphRewriterPass
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv1d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
nn.Conv1d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=zeros padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}
@@ -100,12 +94,6 @@ pnnx.Output output 1 0 out

void write(Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
std::string padding_mode = captured_params.at("padding_mode").s;
if (padding_mode != "zeros")
{
fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
}

op->params["0"] = captured_params.at("out_channels");
op->params["1"] = captured_params.at("kernel_size").ai[0];
op->params["2"] = captured_params.at("dilation").ai[0];
@@ -133,8 +121,151 @@ pnnx.Output output 1 0 out
}
};

class nn_Conv1d_2 : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv1d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=1 bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
Padding pad 1 1 input a
Convolution1D conv 1 1 a out
pnnx.Output output 1 0 out
)PNNXIR";
}

bool match(const std::map<std::string, Parameter>& captured_params) const
{
const std::string& padding_mode = captured_params.at("padding_mode").s;
if (padding_mode == "zeros")
return false;

return true;
}

bool match(const std::map<std::string, const Operator*>& matched_operators) const
{
const Operator* conv = matched_operators.at("op_0");
if (conv->params.at("padding").type == 4 && conv->params.at("padding").s == "same")
{
const std::vector<int> input_shape = conv->inputs[0]->shape;
if (input_shape.size() != 2 && input_shape.size() != 3)
{
fprintf(stderr, "can not resolve pads without shape\n");
return false;
}
}

return true;
}

void write(const std::map<std::string, Operator*>& ops, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
{
std::vector<int> padding;
if (captured_params.at("padding").type == 4)
{
if (captured_params.at("padding").s == "same")
{
// resolve pads
const std::vector<int> input_shape = ops.at("pad")->inputs[0]->shape;
const int w = input_shape[input_shape.size() - 1];
const int kernel_w = captured_params.at("kernel_size").ai[0];
const int dilation_w = captured_params.at("dilation").ai[0];
const int stride_w = captured_params.at("stride").ai[0];

const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;

int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w;

padding = std::vector<int>{wpad / 2, wpad - wpad / 2};
}
else if (captured_params.at("padding").s == "valid")
{
padding = std::vector<int>{0, 0};
}
}
else
{
int wpad = captured_params.at("padding").ai[0];
padding = std::vector<int>{wpad, wpad};
}

ops.at("pad")->params["0"] = 0;
ops.at("pad")->params["1"] = 0;
ops.at("pad")->params["2"] = padding[0];
ops.at("pad")->params["3"] = padding[1];

std::string padding_mode = captured_params.at("padding_mode").s;
if (padding_mode == "reflect")
{
ops.at("pad")->params["4"] = 2; // type=reflect
}
else if (padding_mode == "replicate")
{
ops.at("pad")->params["4"] = 1; // type=replicate
}
else
{
fprintf(stderr, "unsupported padding_mode %s\n", padding_mode.c_str());
}

ops.at("conv")->params["0"] = captured_params.at("out_channels");
ops.at("conv")->params["1"] = captured_params.at("kernel_size").ai[0];
ops.at("conv")->params["2"] = captured_params.at("dilation").ai[0];
ops.at("conv")->params["3"] = captured_params.at("stride").ai[0];
ops.at("conv")->params["4"] = 0;
ops.at("conv")->params["5"] = captured_params.at("bias").b ? 1 : 0;
ops.at("conv")->params["6"] = captured_attrs.at("op_0.weight").elemcount();
ops.at("conv")->params["7"] = captured_params.find("groups") != captured_params.end() ? captured_params.at("groups") : 1;

ops.at("conv")->attrs["0"] = Attribute();
ops.at("conv")->attrs["0"].data = {0, 0, 0, 0};
ops.at("conv")->attrs["1"] = captured_attrs.at("op_0.weight");
if (captured_params.at("bias").b)
ops.at("conv")->attrs["2"] = captured_attrs.at("op_0.bias");
}
};

class nn_Conv1d_3 : public nn_Conv1d_2
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
3 2
pnnx.Input input 0 1 input
nn.Conv1d op_0 1 1 input out in_channels=%in_channels out_channels=%out_channels kernel_size=%kernel_size stride=%stride padding_mode=%padding_mode padding=%padding dilation=%dilation groups=%groups bias=%bias @weight @bias
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* replace_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input 0 1 input
Padding pad 1 1 input a
ConvolutionDepthWise1D conv 1 1 a out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv1d, 20)
REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv1d_1, 21)
REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv1d_2, 22)
REGISTER_GLOBAL_PNNX_NCNN_GRAPH_REWRITER_PASS(nn_Conv1d_3, 23)

} // namespace ncnn
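To ground the "same" branch in nn_Conv1d_2::write() with fixed numbers (a standalone sketch, not commit code; the sizes are arbitrary): with input width 10, kernel 3, dilation 2, and stride 1, the dilated kernel spans 5 columns, so 4 columns of padding preserve the width, split 2 left / 2 right:

#include <cstdio>
#include <vector>

int main()
{
    const int w = 10, kernel_w = 3, dilation_w = 2, stride_w = 1;

    // effective kernel width under dilation: 2 * (3 - 1) + 1 = 5
    const int kernel_extent_w = dilation_w * (kernel_w - 1) + 1;

    // total pad so the strided output still covers the input: 5 + 9 - 10 = 4
    const int wpad = kernel_extent_w + (w - 1) / stride_w * stride_w - w;

    // split left/right; any odd remainder goes to the right: {2, 2}
    const std::vector<int> padding = {wpad / 2, wpad - wpad / 2};

    // these are the values written to Padding params 2 (left) and 3 (right)
    printf("left=%d right=%d\n", padding[0], padding[1]);
    return 0;
}

The match() guard in nn_Conv1d_2 exists precisely because this arithmetic needs the input width w: when the blob shape is unknown, "same" cannot be resolved and the rewrite is skipped.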

