Skip to content

Commit

Permalink
add F_pairwise_distance to pnnx and ncnn (#4942)
Browse files Browse the repository at this point in the history
  • Loading branch information
Marsyule authored Sep 22, 2023
1 parent 1d7720e commit 69d6051
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 0 deletions.
1 change: 1 addition & 0 deletions tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ set(pnnx_pass_level2_SRCS
pass_level2/F_mish.cpp
pass_level2/F_normalize.cpp
pass_level2/F_pad.cpp
pass_level2/F_pairwise_distance.cpp
pass_level2/F_pixel_shuffle.cpp
pass_level2/F_pixel_unshuffle.cpp
pass_level2/F_prelu.cpp
Expand Down
44 changes: 44 additions & 0 deletions tools/pnnx/src/pass_level2/F_pairwise_distance.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "pass_level2.h"

namespace pnnx {

class F_pairwise_distance : public GraphRewriterPass
{
public:
const char* match_pattern_graph() const
{
return R"PNNXIR(7767517
7 6
pnnx.Input input_0 0 1 x1
pnnx.Input input_1 0 1 x2
prim::Constant op_0 0 1 p value=%p
prim::Constant op_1 0 1 eps value=%eps
prim::Constant op_2 0 1 keepdim value=%keepdim
aten::pairwise_distance op_3 5 1 x1 x2 p eps keepdim out
pnnx.Output output 1 0 out
)PNNXIR";
}

const char* type_str() const
{
return "F.pairwise_distance";
}
};

REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_pairwise_distance, 10)

} // namespace pnnx
1 change: 1 addition & 0 deletions tools/pnnx/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ pnnx_add_test(F_max_pool2d)
pnnx_add_test(F_max_pool3d)
pnnx_add_test(F_normalize)
pnnx_add_test(F_pad)
pnnx_add_test(F_pairwise_distance)
pnnx_add_test(F_pixel_shuffle)
pnnx_add_test(F_pixel_unshuffle)
pnnx_add_test(F_prelu)
Expand Down
58 changes: 58 additions & 0 deletions tools/pnnx/tests/test_F_pairwise_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y):
z1 = F.pairwise_distance(x,y,p=1,keepdim=False)
z2 = F.pairwise_distance(x,y,p=2,keepdim=True)
z3 = F.pairwise_distance(x,y)
z4 = F.pairwise_distance(x,y,eps = 1e-3)
return z1,z2,z3,z4

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(12, 128, 128)
y = torch.rand(12, 128, 128)

a0,a1,a2,a3 = net(x, y)

# export torchscript
mod = torch.jit.trace(net, (x, y))
mod.save("test_F_pairwise_distance.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_F_pairwise_distance.pt inputshape=[12,128,128],[12,128,128]")

# pnnx inference
import test_F_pairwise_distance_pnnx
b0,b1,b2,b3 = test_F_pairwise_distance_pnnx.test_inference()

return torch.equal(a0,b0) and torch.equal(a1,b1) and torch.equal(a2,b2) and torch.equal(a3,b3)

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

0 comments on commit 69d6051

Please sign in to comment.