-
Notifications
You must be signed in to change notification settings - Fork 4.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Tensor.fill Tensor.index_put Tensor.to Tensor.type_as torch.topk fmod call Tensor member functions with inputnames static shape_as_tensor nn.Linear dynamic bias eliminate noop type_as convert two-dim nn.Linear to ncnn gemm convert torch.stack to ncnn concat+reshape ignore torch einsum path input
- Loading branch information
Showing
33 changed files
with
1,278 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Tencent is pleased to support the open source community by making ncnn available. | ||
// | ||
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | ||
// | ||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | ||
// in compliance with the License. You may obtain a copy of the License at | ||
// | ||
// https://opensource.org/licenses/BSD-3-Clause | ||
// | ||
// Unless required by applicable law or agreed to in writing, software distributed | ||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | ||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations under the License. | ||
|
||
#include "pass_level2.h" | ||
|
||
namespace pnnx { | ||
|
||
class Tensor_fill : public GraphRewriterPass | ||
{ | ||
public: | ||
const char* match_pattern_graph() const | ||
{ | ||
return R"PNNXIR(7767517 | ||
4 3 | ||
pnnx.Input input_0 0 1 input | ||
pnnx.Input input_1 0 1 value | ||
aten::fill op_0 2 1 input value out | ||
pnnx.Output output 1 0 out | ||
)PNNXIR"; | ||
} | ||
|
||
const char* type_str() const | ||
{ | ||
return "Tensor.fill"; | ||
} | ||
}; | ||
|
||
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_fill, 20) | ||
|
||
} // namespace pnnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
// Tencent is pleased to support the open source community by making ncnn available. | ||
// | ||
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | ||
// | ||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | ||
// in compliance with the License. You may obtain a copy of the License at | ||
// | ||
// https://opensource.org/licenses/BSD-3-Clause | ||
// | ||
// Unless required by applicable law or agreed to in writing, software distributed | ||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | ||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations under the License. | ||
|
||
#include "pass_level2.h" | ||
|
||
namespace pnnx { | ||
|
||
class Tensor_index_put : public GraphRewriterPass | ||
{ | ||
public: | ||
const char* match_pattern_graph() const | ||
{ | ||
return R"PNNXIR(7767517 | ||
6 5 | ||
pnnx.Input input_0 0 1 input | ||
pnnx.Input input_1 0 1 indices | ||
pnnx.Input input_2 0 1 values | ||
prim::Constant op_0 0 1 accumulate value=%accumulate | ||
aten::index_put op_1 4 1 input indices values accumulate out | ||
pnnx.Output output 1 0 out | ||
)PNNXIR"; | ||
} | ||
|
||
const char* type_str() const | ||
{ | ||
return "Tensor.index_put"; | ||
} | ||
}; | ||
|
||
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_index_put, 20) | ||
|
||
} // namespace pnnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
// Tencent is pleased to support the open source community by making ncnn available. | ||
// | ||
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | ||
// | ||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | ||
// in compliance with the License. You may obtain a copy of the License at | ||
// | ||
// https://opensource.org/licenses/BSD-3-Clause | ||
// | ||
// Unless required by applicable law or agreed to in writing, software distributed | ||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | ||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations under the License. | ||
|
||
#include "pass_level2.h" | ||
|
||
namespace pnnx { | ||
|
||
class Tensor_to : public GraphRewriterPass | ||
{ | ||
public: | ||
const char* match_pattern_graph() const | ||
{ | ||
return R"PNNXIR(7767517 | ||
7 6 | ||
pnnx.Input input_0 0 1 input | ||
prim::Constant op_0 0 1 dtype value=%dtype | ||
prim::Constant op_1 0 1 non_blocking value=* | ||
prim::Constant op_2 0 1 copy value=%copy | ||
prim::Constant op_3 0 1 memory_format value=%memory_format | ||
aten::to op_4 5 1 input dtype non_blocking copy memory_format out | ||
pnnx.Output output 1 0 out | ||
)PNNXIR"; | ||
} | ||
|
||
const char* type_str() const | ||
{ | ||
return "Tensor.to"; | ||
} | ||
|
||
void write(Operator* op, const std::map<std::string, Parameter>& captured_params) const | ||
{ | ||
if (captured_params.at("dtype").i == 0) op->params["dtype"] = "torch.uint8"; | ||
if (captured_params.at("dtype").i == 1) op->params["dtype"] = "torch.int8"; | ||
if (captured_params.at("dtype").i == 2) op->params["dtype"] = "torch.short"; | ||
if (captured_params.at("dtype").i == 3) op->params["dtype"] = "torch.int"; | ||
if (captured_params.at("dtype").i == 4) op->params["dtype"] = "torch.long"; | ||
if (captured_params.at("dtype").i == 5) op->params["dtype"] = "torch.half"; | ||
if (captured_params.at("dtype").i == 6) op->params["dtype"] = "torch.float"; | ||
if (captured_params.at("dtype").i == 7) op->params["dtype"] = "torch.double"; | ||
if (captured_params.at("dtype").i == 8) op->params["dtype"] = "torch.complex32"; | ||
if (captured_params.at("dtype").i == 9) op->params["dtype"] = "torch.complex64"; | ||
if (captured_params.at("dtype").i == 10) op->params["dtype"] = "torch.complex128"; | ||
if (captured_params.at("dtype").i == 11) op->params["dtype"] = "torch.bool"; | ||
|
||
op->params["copy"] = captured_params.at("copy"); | ||
|
||
if (captured_params.at("memory_format").i == 0) | ||
op->params["memory_format"] = "torch.contiguous_format"; | ||
if (captured_params.at("memory_format").i == 1) | ||
op->params["memory_format"] = "torch.preserve_format"; | ||
if (captured_params.at("memory_format").i == 2) | ||
op->params["memory_format"] = "torch.channels_last"; | ||
} | ||
}; | ||
|
||
class Tensor_to_1 : public Tensor_to | ||
{ | ||
public: | ||
const char* match_pattern_graph() const | ||
{ | ||
return R"PNNXIR(7767517 | ||
8 7 | ||
pnnx.Input input_0 0 1 input | ||
prim::Constant op_0 0 1 device value=* | ||
prim::Constant op_1 0 1 dtype value=%dtype | ||
prim::Constant op_2 0 1 non_blocking value=* | ||
prim::Constant op_3 0 1 copy value=%copy | ||
prim::Constant op_4 0 1 memory_format value=%memory_format | ||
aten::to op_5 6 1 input device dtype non_blocking copy memory_format out | ||
pnnx.Output output 1 0 out | ||
)PNNXIR"; | ||
} | ||
}; | ||
|
||
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to, 20) | ||
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_1, 20) | ||
|
||
} // namespace pnnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Tencent is pleased to support the open source community by making ncnn available. | ||
// | ||
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved. | ||
// | ||
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except | ||
// in compliance with the License. You may obtain a copy of the License at | ||
// | ||
// https://opensource.org/licenses/BSD-3-Clause | ||
// | ||
// Unless required by applicable law or agreed to in writing, software distributed | ||
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR | ||
// CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations under the License. | ||
|
||
#include "pass_level2.h" | ||
|
||
namespace pnnx { | ||
|
||
class Tensor_type_as : public GraphRewriterPass | ||
{ | ||
public: | ||
const char* match_pattern_graph() const | ||
{ | ||
return R"PNNXIR(7767517 | ||
4 3 | ||
pnnx.Input input_0 0 1 input | ||
pnnx.Input input_1 0 1 other | ||
aten::type_as op_0 2 1 input other out | ||
pnnx.Output output 1 0 out | ||
)PNNXIR"; | ||
} | ||
|
||
const char* type_str() const | ||
{ | ||
return "Tensor.type_as"; | ||
} | ||
}; | ||
|
||
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_type_as, 20) | ||
|
||
} // namespace pnnx |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.