From ce3d2ab52bc10e859f364b4dc20723e831fecbdf Mon Sep 17 00:00:00 2001
From: nihuini
Date: Mon, 9 Oct 2023 16:58:06 +0800
Subject: [PATCH] pnnx support torch-2.1

---
 .ci/pnnx.yml                           |  4 +++
 tools/pnnx/CMakeLists.txt              |  5 ++++
 .../F_scaled_dot_product_attention.cpp | 27 +++++++++++++++++++
 3 files changed, 36 insertions(+)

diff --git a/.ci/pnnx.yml b/.ci/pnnx.yml
index 267a0afa289..3f116a4fa2e 100644
--- a/.ci/pnnx.yml
+++ b/.ci/pnnx.yml
@@ -48,6 +48,10 @@ jobs:
       torchvision-version: 0.15.1
       torchvision-cache-key: '0_15_1'
 
+    - torch-version: 2.1.0
+      torchvision-version: 0.16.0
+      torchvision-cache-key: '0_16_0'
+
   runs-on:
     pool-name: docker
     container:
diff --git a/tools/pnnx/CMakeLists.txt b/tools/pnnx/CMakeLists.txt
index 3a08cbc249e..0c8326fc942 100644
--- a/tools/pnnx/CMakeLists.txt
+++ b/tools/pnnx/CMakeLists.txt
@@ -70,6 +70,11 @@ if(Torch_VERSION VERSION_LESS "1.8")
     message(FATAL_ERROR "pnnx only supports PyTorch >= 1.8")
 endif()
 
+if(Torch_VERSION VERSION_GREATER_EQUAL "2.1")
+    # c++17 is required for using torch 2.1+ headers
+    set(CMAKE_CXX_STANDARD 17)
+endif()
+
 if(TorchVision_FOUND)
     message(STATUS "Building with TorchVision")
     add_definitions(-DPNNX_TORCHVISION)
diff --git a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
index e7ca7bbf824..8dcfafaf12b 100644
--- a/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
+++ b/tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
@@ -42,4 +42,31 @@ pnnx.Output output 1 0 out
 
 REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention, 10)
 
+class F_scaled_dot_product_attention_1 : public GraphRewriterPass
+{
+public:
+    const char* match_pattern_graph() const
+    {
+        return R"PNNXIR(7767517
+9 8
+pnnx.Input input_0 0 1 query
+pnnx.Input input_1 0 1 key
+pnnx.Input input_2 0 1 value
+pnnx.Input input_3 0 1 attn_mask
+prim::Constant op_0 0 1 dropout_p value=%dropout_p
+prim::Constant op_1 0 1 is_causal value=%is_causal
+prim::Constant op_2 0 1 scale value=%scale
+aten::scaled_dot_product_attention op_3 7 1 query key value attn_mask dropout_p is_causal scale out
+pnnx.Output output 1 0 out
+)PNNXIR";
+    }
+
+    const char* type_str() const
+    {
+        return "F.scaled_dot_product_attention";
+    }
+};
+
+REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10)
+
 } // namespace pnnx
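
Background for the second rewriter pass: torch 2.1 adds an optional `scale` keyword to `F.scaled_dot_product_attention`, so graphs traced under 2.1 carry a 7-input `aten::scaled_dot_product_attention` (query, key, value, attn_mask, dropout_p, is_causal, scale) that the original 6-input pattern no longer matches. In the PNNXIR pattern header, `9 8` declares 9 operators and 8 operands; `%dropout_p`, `%is_causal` and `%scale` capture the constant inputs as parameters of the fused `F.scaled_dot_product_attention` operator. The Python sketch below shows the kind of traced module that should hit the new pattern; the module name, tensor shapes, and the explicit scale value are illustrative only, not taken from the patch or the pnnx test suite.

    # Minimal sketch: tracing under torch >= 2.1 records the new trailing
    # `scale` input on aten::scaled_dot_product_attention, which is what
    # the F_scaled_dot_product_attention_1 pattern above matches.
    # (Shapes and the scale value here are illustrative.)
    import torch
    import torch.nn.functional as F

    class Attention(torch.nn.Module):
        def forward(self, query, key, value, attn_mask):
            # `scale` is new in torch 2.1; older torch has no such keyword
            return F.scaled_dot_product_attention(query, key, value,
                                                  attn_mask=attn_mask,
                                                  dropout_p=0.0,
                                                  is_causal=False,
                                                  scale=0.125)

    q = torch.rand(1, 8, 16, 64)
    k = torch.rand(1, 8, 16, 64)
    v = torch.rand(1, 8, 16, 64)
    mask = torch.rand(1, 8, 16, 16)

    mod = torch.jit.trace(Attention(), (q, k, v, mask))
    # The traced graph now contains the 7-input form:
    #   aten::scaled_dot_product_attention(query, key, value, attn_mask,
    #                                      dropout_p, is_causal, scale)
    print(mod.graph)

Running such a traced model through pnnx should fold the seven inputs back into a single F.scaled_dot_product_attention operator, with dropout_p, is_causal and scale recorded as its parameters.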