pnnx update (#4870)
Tensor.fill
Tensor.index_put
Tensor.to
Tensor.type_as
torch.topk
fmod
call Tensor member functions with inputnames
static shape_as_tensor
nn.Linear dynamic bias
eliminate noop type_as
convert two-dim nn.Linear to ncnn gemm
convert torch.stack to ncnn concat+reshape
ignore torch einsum path input
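
As a rough sketch (not a test case from the repository; module and tensor names are illustrative), the kind of PyTorch code these changes target looks like:

    import torch
    import torch.nn as nn

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            # bias=False leaves a bias attribute that holds None ("dynamic bias" case)
            self.linear = nn.Linear(16, 8, bias=False)

        def forward(self, x, y, idx):
            a = self.linear(x)                            # two-dim nn.Linear -> ncnn gemm
            a = a.type_as(y)                              # Tensor.type_as (no-op casts get eliminated)
            a.fill_(0.25)                                 # Tensor.fill
            a.index_put_((idx,), torch.tensor(1.0))       # Tensor.index_put
            b = torch.stack([a, a], dim=1)                # torch.stack -> ncnn concat + reshape
            b = b.to(torch.float)                         # Tensor.to
            values, indices = torch.topk(b, k=2, dim=-1)  # torch.topk
            e = torch.fmod(values, 2.0)                   # fmod in expression folding
            return e, indices, torch._shape_as_tensor(e)  # static shape_as_tensor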
nihui authored Jul 24, 2023
1 parent 5570970 commit 669ee2f
Showing 33 changed files with 1,278 additions and 21 deletions.
7 changes: 7 additions & 0 deletions tools/pnnx/src/CMakeLists.txt
@@ -180,7 +180,9 @@ set(pnnx_pass_level2_SRCS
pass_level2/Tensor_copy.cpp
pass_level2/Tensor_expand.cpp
pass_level2/Tensor_expand_as.cpp
pass_level2/Tensor_fill.cpp
pass_level2/Tensor_index.cpp
pass_level2/Tensor_index_put.cpp
pass_level2/Tensor_masked_fill.cpp
pass_level2/Tensor_new_empty.cpp
pass_level2/Tensor_new_ones.cpp
@@ -189,6 +191,8 @@
pass_level2/Tensor_reshape.cpp
pass_level2/Tensor_select.cpp
pass_level2/Tensor_slice.cpp
pass_level2/Tensor_to.cpp
pass_level2/Tensor_type_as.cpp
pass_level2/Tensor_view.cpp
pass_level2/torch_addmm.cpp
pass_level2/torch_amax.cpp
@@ -252,6 +256,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/torch_sum.cpp
pass_level2/torch_permute.cpp
pass_level2/torch_tensor_split.cpp
pass_level2/torch_topk.cpp
pass_level2/torch_transpose.cpp
pass_level2/torch_unbind.cpp
pass_level2/torch_unsqueeze.cpp
@@ -320,6 +325,7 @@ set(pnnx_pass_level5_SRCS
pass_level5/eliminate_noop_slice.cpp
pass_level5/eliminate_noop_view_reshape.cpp
pass_level5/eliminate_reshape_shape_expression.cpp
pass_level5/eliminate_type_as.cpp
pass_level5/eval_expression.cpp
pass_level5/fold_constants.cpp
pass_level5/fuse_adjacent_reshape.cpp
@@ -361,6 +367,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/convert_torch_chunk.cpp
pass_ncnn/convert_torch_einsum.cpp
pass_ncnn/convert_torch_split.cpp
pass_ncnn/convert_torch_stack.cpp
pass_ncnn/convert_torch_tensor_split.cpp
pass_ncnn/convert_torch_unbind.cpp
pass_ncnn/convert_Tensor_select.cpp
39 changes: 35 additions & 4 deletions tools/pnnx/src/ir.cpp
@@ -1297,10 +1297,12 @@ static std::string expand_expression(const Operator* op)
exprstack.push(r);
}
else if (t == "atan2"
|| t == "fmod"
|| t == "pow")
{
std::string binaryop;
if (t == "atan2") binaryop = "torch.atan2";
if (t == "fmod") binaryop = "torch.fmod";
if (t == "pow") binaryop = "torch.pow";

std::string a = exprstack.top();
@@ -1311,14 +1313,15 @@
std::string r = binaryop + "(" + a + ", " + b + ")";
exprstack.push(r);
}
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift")
else if (t == "add" || t == "sub" || t == "mul" || t == "div" || t == "floor_divide" || t == "remainder" || t == "and" || t == "or" || t == "xor" || t == "lshift" || t == "rshift")
{
std::string binaryop;
if (t == "add") binaryop = "+";
if (t == "sub") binaryop = "-";
if (t == "mul") binaryop = "*";
if (t == "div") binaryop = "/";
if (t == "floor_divide") binaryop = "//";
if (t == "remainder") binaryop = "%";
if (t == "and") binaryop = "&";
if (t == "or") binaryop = "|";
if (t == "xor") binaryop = "^";
@@ -2152,11 +2155,39 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)

if (op->type.substr(0, 7) == "Tensor.")
{
fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
if (op->type == "Tensor.fill")
{
fprintf(pyfp, " = v_%s.fill_(", sanitize_identifier(op->inputs[0]->name).c_str());
}
else
{
fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
}

if (op->inputnames.size() == op->inputs.size())
{
for (size_t i = 1; i < op->inputs.size(); i++)
{
if (!op->inputnames[i].empty())
continue;

fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
}

for (size_t i = 1; i < op->inputs.size(); i++)
{
if (op->inputnames[i].empty())
continue;

fprintf(pyfp, "%s=v_%s, ", op->inputnames[i].c_str(), sanitize_identifier(op->inputs[i]->name).c_str());
}
}
else
{
fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
for (size_t i = 1; i < op->inputs.size(); i++)
{
fprintf(pyfp, "v_%s, ", sanitize_identifier(op->inputs[i]->name).c_str());
}
}
}
else
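
As a hypothetical illustration (variable names invented for this note, not taken from a real export), the Python emitted for a Tensor member call now uses fill_ for Tensor.fill and passes named inputs as keyword arguments after the positional ones:

    # roughly what Graph::python() writes for these cases
    v_out0 = v_x.fill_(v_value)
    v_out1 = v_x.index_put(v_idx, values=v_vals)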
3 changes: 2 additions & 1 deletion tools/pnnx/src/pass_level0/shape_inference.cpp
@@ -39,7 +39,8 @@ static bool value_link_input(const torch::jit::Value* v, const std::vector<torch
|| optype == "aten::empty_like"
|| optype == "aten::full_like"
|| optype == "aten::ones_like"
|| optype == "aten::zeros_like")
|| optype == "aten::zeros_like"
|| optype == "aten::_shape_as_tensor")
return false;
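
The new branch makes shape inference treat values produced by aten::_shape_as_tensor like the other *_like constant producers; for reference, the op just returns a tensor holding the input's dimensions (a small sketch):

    import torch

    x = torch.rand(2, 3, 4)
    s = torch._shape_as_tensor(x)
    print(s)        # tensor([2, 3, 4])
    print(s.dtype)  # torch.int64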
}

4 changes: 2 additions & 2 deletions tools/pnnx/src/pass_level1/nn_Linear.cpp
@@ -39,10 +39,10 @@ class Linear : public FuseModulePass

op->params["in_features"] = weight.size(1);
op->params["out_features"] = weight.size(0);
op->params["bias"] = mod.hasattr("bias");
op->params["bias"] = mod.hasattr("bias") && mod.attr("bias").isTensor();

op->attrs["weight"] = weight;
if (mod.hasattr("bias"))
if (mod.hasattr("bias") && mod.attr("bias").isTensor())
{
op->attrs["bias"] = mod.attr("bias").toTensor();
}
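
The added isTensor() check covers cases where a bias attribute exists but does not hold a constant tensor, for example a module built with bias=False, where the eager-mode attribute is None; roughly:

    import torch.nn as nn

    m = nn.Linear(16, 8, bias=False)
    print(hasattr(m, "bias"))  # True  - the attribute is registered
    print(m.bias)              # None  - but there is no bias tensor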
41 changes: 41 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_fill.cpp
@@ -0,0 +1,41 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_fill : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 value
aten::fill op_0 2 1 input value out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.fill";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_fill, 20)

} // namespace pnnx
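
On the Python side this pattern corresponds to the in-place Tensor.fill_ call; a minimal example (not a repository test):

    import torch

    x = torch.empty(2, 3)
    x.fill_(1.5)  # every element becomes 1.5; this is the call that ends up matched by the aten::fill pattern above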
43 changes: 43 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_index_put.cpp
@@ -0,0 +1,43 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_index_put : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
6 5
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 indices
pnnx.Input input_2 0 1 values
prim::Constant op_0 0 1 accumulate value=%accumulate
aten::index_put op_1 4 1 input indices values accumulate out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.index_put";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_index_put, 20)

} // namespace pnnx
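
A minimal PyTorch usage of the matched op, with illustrative values:

    import torch

    x = torch.zeros(4, 5)
    idx = (torch.tensor([0, 2]), torch.tensor([1, 3]))
    vals = torch.tensor([7.0, 9.0])
    x.index_put_(idx, vals, accumulate=False)  # sets x[0, 1] = 7.0 and x[2, 3] = 9.0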
2 changes: 1 addition & 1 deletion tools/pnnx/src/pass_level2/Tensor_masked_fill.cpp
@@ -26,7 +26,7 @@ class Tensor_masked_fill : public GraphRewriterPass
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 mask
pnnx.Input input_2 0 1 value
aten::masked_fill op_1 3 1 input mask value out
aten::masked_fill op_0 3 1 input mask value out
pnnx.Output output 1 0 out
)PNNXIR";
}
89 changes: 89 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_to.cpp
@@ -0,0 +1,89 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_to : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input_0 0 1 input
prim::Constant op_0 0 1 dtype value=%dtype
prim::Constant op_1 0 1 non_blocking value=*
prim::Constant op_2 0 1 copy value=%copy
prim::Constant op_3 0 1 memory_format value=%memory_format
aten::to op_4 5 1 input dtype non_blocking copy memory_format out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.to";
}

void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const
{
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8";
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8";
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short";
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int";
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long";
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half";
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float";
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double";
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32";
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64";
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128";
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool";

op->params["copy"] = captured_params.at("copy");

if (captured_params.at("memory_format").i == 0)
op->params["memory_format"] = "torch.contiguous_format";
if (captured_params.at("memory_format").i == 1)
op->params["memory_format"] = "torch.preserve_format";
if (captured_params.at("memory_format").i == 2)
op->params["memory_format"] = "torch.channels_last";
}
};

class Tensor_to_1 : public Tensor_to
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
8 7
pnnx.Input input_0 0 1 input
prim::Constant op_0 0 1 device value=*
prim::Constant op_1 0 1 dtype value=%dtype
prim::Constant op_2 0 1 non_blocking value=*
prim::Constant op_3 0 1 copy value=%copy
prim::Constant op_4 0 1 memory_format value=%memory_format
aten::to op_5 6 1 input device dtype non_blocking copy memory_format out
pnnx.Output output 1 0 out
)PNNXIR";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to, 20)
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_1, 20)

} // namespace pnnx
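
The two patterns cover the dtype-only and device+dtype overloads of Tensor.to; the integer dtype codes above are TorchScript ScalarType values (6 is float32, 5 is float16, and so on). A small sketch of both call shapes:

    import torch

    x = torch.rand(2, 3)
    y = x.to(torch.half)                      # dtype-only form (first pattern)
    z = x.to("cpu", torch.float, copy=False)  # device + dtype form (second pattern)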
41 changes: 41 additions & 0 deletions tools/pnnx/src/pass_level2/Tensor_type_as.cpp
@@ -0,0 +1,41 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class Tensor_type_as : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
4 3
pnnx.Input input_0 0 1 input
pnnx.Input input_1 0 1 other
aten::type_as op_0 2 1 input other out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "Tensor.type_as";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_type_as, 20)

} // namespace pnnx
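
type_as casts a tensor to another tensor's dtype; when both already share a dtype the call does nothing, which is what the new eliminate_type_as pass at level 5 removes. For example:

    import torch

    a = torch.rand(3)    # float32
    b = torch.arange(3)  # int64
    c = b.type_as(a)     # int64 -> float32
    d = a.type_as(c)     # already float32, so a no-op candidate for elimination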
9 changes: 8 additions & 1 deletion tools/pnnx/src/pass_level2/torch_einsum.cpp
@@ -47,7 +47,7 @@ class torch_einsum_1 : public GraphRewriterPass
5 4
pnnx.Input input_0 0 1 equation
pnnx.Input input_1 0 1 operands
prim::Constant op_0 0 1 path value=None
pnnx.Input input_2 0 1 path
aten::einsum op_1 3 1 equation operands path out
pnnx.Output output 1 0 out
)PNNXIR";
@@ -57,6 +57,13 @@ pnnx.Output output 1 0 out
{
return "torch.einsum";
}

void write(Operator* op, const std::map<std::string, Parameter>& /*captured_params*/) const
{
// drop path input
op->inputs[2]->remove_consumer(op);
op->inputs.resize(2);
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(torch_einsum_1, 20)
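
In some PyTorch versions torch.einsum traces to aten::einsum with an extra opt-einsum path argument; the rewritten pattern now accepts it as a graph input, and the write() above simply drops it. A quick sketch of the call being matched:

    import torch

    x = torch.rand(2, 3)
    y = torch.rand(3, 4)
    z = torch.einsum("ij,jk->ik", x, y)  # traced as aten::einsum(equation, operands, path)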