Skip to content

Commit

Permalink
pnnx support torch-2.1 (#5074)
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui authored Oct 10, 2023
1 parent b4f8fa6 commit bedbe59
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 0 deletions.
4 changes: 4 additions & 0 deletions .ci/pnnx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ jobs:
torchvision-version: 0.15.1
torchvision-cache-key: '0_15_1'

- torch-version: 2.1.0
torchvision-version: 0.16.0
torchvision-cache-key: '0_16_0'

runs-on:
pool-name: docker
container:
Expand Down
5 changes: 5 additions & 0 deletions tools/pnnx/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ if(Torch_VERSION VERSION_LESS "1.8")
message(FATAL_ERROR "pnnx only supports PyTorch >= 1.8")
endif()

# Torch 2.1+ headers use C++17 language features, so raise the compile standard.
# NOTE(review): CMAKE_CXX_STANDARD alone allows decay to an older standard if the
# compiler lacks C++17; consider also set(CMAKE_CXX_STANDARD_REQUIRED ON) — confirm
# against the minimum supported toolchains.
if(Torch_VERSION VERSION_GREATER_EQUAL "2.1")
# c++17 is required for using torch 2.1+ headers
set(CMAKE_CXX_STANDARD 17)
endif()

if(TorchVision_FOUND)
message(STATUS "Building with TorchVision")
add_definitions(-DPNNX_TORCHVISION)
Expand Down
5 changes: 5 additions & 0 deletions tools/pnnx/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,11 @@ int main(int argc, char** argv)
fprintf(stderr, "\n");
}

#ifdef PNNX_TORCHVISION
// call some vision api to register vision ops :P
(void)vision::cuda_version();
#endif

for (auto m : customop_modules)
{
fprintf(stderr, "load custom module %s\n", m.c_str());
Expand Down
27 changes: 27 additions & 0 deletions tools/pnnx/src/pass_level2/F_scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,31 @@ pnnx.Output output 1 0 out

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention, 10)

// Second rewrite pass for scaled_dot_product_attention: matches the 7-operand
// form of aten::scaled_dot_product_attention (query, key, value, attn_mask,
// dropout_p, is_causal, scale). The trailing `scale` operand is presumably the
// signature emitted by torch 2.1+ traces (this commit adds torch 2.1 support);
// the earlier pass above handles the variant without it — TODO confirm.
class F_scaled_dot_product_attention_1 : public GraphRewriterPass
{
public:
// Returns the textual PNNX IR pattern to match. The three prim::Constant
// operands are captured as named parameters (%dropout_p, %is_causal, %scale)
// so their traced values carry over to the rewritten op.
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
9 8
pnnx.Input input_0 0 1 query
pnnx.Input input_1 0 1 key
pnnx.Input input_2 0 1 value
pnnx.Input input_3 0 1 attn_mask
prim::Constant op_0 0 1 dropout_p value=%dropout_p
prim::Constant op_1 0 1 is_causal value=%is_causal
prim::Constant op_2 0 1 scale value=%scale
aten::scaled_dot_product_attention op_3 7 1 query key value attn_mask dropout_p is_causal scale out
pnnx.Output output 1 0 out
)PNNXIR";
}

// Name of the replacement operator written into the output graph.
const char* type_str() const
{
return "F.scaled_dot_product_attention";
}
};

// Register at priority 10, same as the companion pass above, so both variants
// of the op are rewritten to F.scaled_dot_product_attention.
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10)

} // namespace pnnx

0 comments on commit bedbe59

Please sign in to comment.