Skip to content

Commit

Permalink
test Tensor fill
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Jul 21, 2023
1 parent 7347091 commit 8c7d0f4
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 1 deletion.
9 changes: 8 additions & 1 deletion tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2154,7 +2154,14 @@ int Graph::python(const std::string& pypath, const std::string& pnnxbinpath)

if (op->type.substr(0, 7) == "Tensor.")
{
fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
if (op->type == "Tensor.fill")
{
fprintf(pyfp, " = v_%s.fill_(", sanitize_identifier(op->inputs[0]->name).c_str());
}
else
{
fprintf(pyfp, " = v_%s.%s(", sanitize_identifier(op->inputs[0]->name).c_str(), op->type.substr(7).c_str());
}

if (op->inputnames.size() == op->inputs.size())
{
Expand Down
6 changes: 6 additions & 0 deletions tools/pnnx/src/pass_level5/fuse_select_to_unbind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ void fuse_select_to_unbind(Graph& graph)
if (input_rank == 0)
continue;

if (input_rank == 1)
{
// skip select scalar
continue;
}

int dim = op->params.at("dim").i;
const int select_dimsize = op_in->shape[dim];

Expand Down
1 change: 1 addition & 0 deletions tools/pnnx/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ pnnx_add_test(nn_ZeroPad2d)

pnnx_add_test(Tensor_contiguous)
pnnx_add_test(Tensor_expand)
pnnx_add_test(Tensor_fill)
pnnx_add_test(Tensor_index)
pnnx_add_test(Tensor_masked_fill)
pnnx_add_test(Tensor_new_empty)
Expand Down
57 changes: 57 additions & 0 deletions tools/pnnx/tests/test_Tensor_fill.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()

def forward(self, x, y, z):
x[:2,:].fill_(z[0])
y[:1,:].fill_(0.22)
return x + y.fill_(7)

def test():
net = Model()
net.eval()

torch.manual_seed(0)
x = torch.rand(6, 16)
y = torch.rand(6, 16)
z = torch.rand(1)

a = net(x, y, z)

# export torchscript
mod = torch.jit.trace(net, (x, y, z))
mod.save("test_Tensor_fill.pt")

# torchscript to pnnx
import os
os.system("../src/pnnx test_Tensor_fill.pt inputshape=[6,16],[6,16],[1]")

# pnnx inference
import test_Tensor_fill_pnnx
b = test_Tensor_fill_pnnx.test_inference()

return torch.equal(a, b)

if __name__ == "__main__":
if test():
exit(0)
else:
exit(1)

0 comments on commit 8c7d0f4

Please sign in to comment.