diff --git a/tools/pnnx/src/pass_level2/Tensor_to.cpp b/tools/pnnx/src/pass_level2/Tensor_to.cpp index 52a7047105b..8ab1f124960 100644 --- a/tools/pnnx/src/pass_level2/Tensor_to.cpp +++ b/tools/pnnx/src/pass_level2/Tensor_to.cpp @@ -55,12 +55,15 @@ pnnx.Output output 1 0 out op->params["copy"] = captured_params.at("copy"); - if (captured_params.at("memory_format").i == 0) - op->params["memory_format"] = "torch.contiguous_format"; - if (captured_params.at("memory_format").i == 1) - op->params["memory_format"] = "torch.preserve_format"; - if (captured_params.at("memory_format").i == 2) - op->params["memory_format"] = "torch.channels_last"; + if (captured_params.at("memory_format").type == 2) + { + if (captured_params.at("memory_format").i == 0) + op->params["memory_format"] = "torch.contiguous_format"; + if (captured_params.at("memory_format").i == 1) + op->params["memory_format"] = "torch.preserve_format"; + if (captured_params.at("memory_format").i == 2) + op->params["memory_format"] = "torch.channels_last"; + } } }; @@ -83,7 +86,29 @@ pnnx.Output output 1 0 out } }; +class Tensor_to_2 : public Tensor_to +{ +public: + const char* match_pattern_graph() const + { + return R"PNNXIR(7767517 +10 9 +pnnx.Input input_0 0 1 input +prim::Constant op_0 0 1 dtype value=%dtype +prim::Constant op_1 0 1 layout value=* +prim::Constant op_2 0 1 device value=* +prim::Constant op_3 0 1 pin_memory value=* +prim::Constant op_4 0 1 non_blocking value=* +prim::Constant op_5 0 1 copy value=%copy +prim::Constant op_6 0 1 memory_format value=%memory_format +aten::to op_7 8 1 input dtype layout device pin_memory non_blocking copy memory_format out +pnnx.Output output 1 0 out +)PNNXIR"; + } +}; + REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to, 20) REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_1, 20) +REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(Tensor_to_2, 20) } // namespace pnnx diff --git a/tools/pnnx/tests/test_Tensor_to.py b/tools/pnnx/tests/test_Tensor_to.py index 71c157cb341..c4ac834c44c 100644 --- a/tools/pnnx/tests/test_Tensor_to.py +++ b/tools/pnnx/tests/test_Tensor_to.py @@ -27,7 +27,8 @@ def forward(self, x, y): x = x.to(device='cpu', dtype=torch.int, copy=True) x = x + 1 y = y - 2 - return x, y + z = x.to(y.device) + return x, y, z def test(): net = Model()