Skip to content

Commit

Permalink
1. Add trans_TensorTypeAs2TensorTo pass in pass level 6
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Jun 6, 2024
1 parent ca44aa6 commit 6770840
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 3 deletions.
5 changes: 4 additions & 1 deletion tools/pnnx/Releasenotes
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,7 @@ dev.1.0.15.20240603
2. Add trans_ReshapeAs2Reshape pass

dev.1.0.16.20240605
1. fix bug of Tensor.index with two inputs
1. fix bug of Tensor.index with two inputs

dev.1.0.17.20240606
1. Add trans_TensorTypeAs2TensorTo pass in pass level 7
49 changes: 48 additions & 1 deletion tools/pnnx/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ set(pnnx_pass_level6_SRCS
pass_level6/trans_expression2TupleConstruct.cpp
pass_level6/trans_Stack2Unsqueeze.cpp
pass_level6/trans_ReshapeAs2Reshape.cpp
pass_level6/trans_TensorTypeAs2TensorTo.cpp
)

set(pnnx_pass_ncnn_SRCS
Expand Down Expand Up @@ -700,7 +701,7 @@ set(pnnx_SRCS
${pnnx_pass_ncnn_SRCS}
)

# add_executable(pnnx ${pnnx_SRCS})

file(GLOB_RECURSE SRC_PARSE_FILES
${CMAKE_CURRENT_SOURCE_DIR}/parse/*.cpp
)
Expand Down Expand Up @@ -766,6 +767,52 @@ set_target_properties(ptx PROPERTIES MACOSX_RPATH TRUE)
include(GNUInstallDirs)
# install(TARGETS pnnx RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
install(TARGETS ptx LIBRARY DESTINATION ${CMAKE_INSTALL_BINDIR})

if(${CMAKE_BUILD_TYPE} STREQUAL Debug)
add_executable(pnnx ${pnnx_SRCS})

set_property(SOURCE main.cpp APPEND PROPERTY COMPILE_DEFINITIONS BUILD_TORCH2PNNX)
target_link_libraries(pnnx PRIVATE torch2pnnx)

if(TorchVision_FOUND)
target_link_libraries(pnnx PRIVATE ${TORCHVISION_LIBRARY})
endif()

if(WIN32)
target_link_libraries(pnnx PRIVATE ${TORCH_LIBRARIES})
else()
target_link_libraries(pnnx PRIVATE ${TORCH_LIBRARIES} pthread dl)
endif()

if(PROTOBUF_FOUND)
set_property(SOURCE main.cpp APPEND PROPERTY COMPILE_DEFINITIONS BUILD_PNNX2ONNX)
target_link_libraries(pnnx PRIVATE pnnx2onnx)
endif()

if(onnxruntime_FOUND)
set_property(SOURCE main.cpp APPEND PROPERTY COMPILE_DEFINITIONS BUILD_ONNX2PNNX)
target_link_libraries(pnnx PRIVATE onnx2pnnx)
endif()

if(PNNX_COVERAGE)
target_compile_options(pnnx PUBLIC -coverage -fprofile-arcs -ftest-coverage)
target_link_libraries(pnnx PUBLIC -coverage -lgcov)
endif()

# set_target_properties(pnnx PROPERTIES COMPILE_FLAGS -fsanitize=address)
# set_target_properties(pnnx PROPERTIES LINK_FLAGS -fsanitize=address)

if(APPLE)
set_target_properties(pnnx PROPERTIES INSTALL_RPATH "@executable_path/")
else()
set_target_properties(pnnx PROPERTIES INSTALL_RPATH "$ORIGIN/")
endif()
set_target_properties(pnnx PROPERTIES MACOSX_RPATH TRUE)

include(GNUInstallDirs)
install(TARGETS pnnx RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
endif()

if (WIN32)
file(GLOB TORCH_DLL "${TORCH_INSTALL_PREFIX}/lib/*.dll")
install(FILES ${TORCH_DLL} DESTINATION ${CMAKE_INSTALL_BINDIR})
Expand Down
3 changes: 3 additions & 0 deletions tools/pnnx/src/pass_level6.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include "pass_level6/trans_expression2TupleConstruct.h"
#include "pass_level6/trans_Stack2Unsqueeze.h"
#include "pass_level6/trans_ReshapeAs2Reshape.h"
#include "pass_level6/trans_TensorTypeAs2TensorTo.h"

namespace pnnx {

void pass_level6(Graph& g, const std::set<std::string>& foldable_constants, const std::string& foldable_constants_zippath)
Expand All @@ -26,6 +28,7 @@ void pass_level6(Graph& g, const std::set<std::string>& foldable_constants, cons
trans_expression2TupleConstruct(g);
trans_Stack2Unsqueeze(g);
trans_ReshapeAs2Reshape(g);
trans_TensorTypeAs2TensorTo(g);
}

} // namespace pnnx
121 changes: 121 additions & 0 deletions tools/pnnx/src/pass_level6/trans_TensorTypeAs2TensorTo.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "trans_TensorTypeAs2TensorTo.h"

#include <algorithm>
#include "pass_level2.h"

namespace pnnx {

void trans_TensorTypeAs2TensorTo(Graph& graph)
{
while (1)
{
bool matched = false;

for (size_t i = 0; i < graph.ops.size(); i++)
{
Operator* op = graph.ops[i];

if (op->type != "Tensor.type_as")
continue;
matched = true;
// get the input size of input1
Operand* input1 = op->inputs.at(1);
int tensor_type_index = input1->type;
if (tensor_type_index == 0) op->params["dtype"] = "torch.float";
if (tensor_type_index == 1) op->params["dtype"] = "torch.float";
if (tensor_type_index == 2) op->params["dtype"] = "torch.double";
if (tensor_type_index == 3) op->params["dtype"] = "torch.half";
if (tensor_type_index == 4) op->params["dtype"] = "torch.int";
if (tensor_type_index == 5) op->params["dtype"] = "torch.long";
if (tensor_type_index == 6) op->params["dtype"] = "torch.short";
if (tensor_type_index == 7) op->params["dtype"] = "torch.int8";
if (tensor_type_index == 8) op->params["dtype"] = "torch.uint8";
if (tensor_type_index == 9) op->params["dtype"] = "torch.bool";
if (tensor_type_index == 10) op->params["dtype"] = "torch.complex64";
if (tensor_type_index == 11) op->params["dtype"] = "torch.complex128";
if (tensor_type_index == 12) op->params["dtype"] = "torch.complex32";
// type_as 2 to
op->type = "Tensor.to";
op->params["copy"] = false;
op->inputs.pop_back();

std::vector<Operand*> delete_operands = {};
std::vector<Operator*> delete_ops = {};
std::vector<Operand*> operand_squence = {input1};
while(operand_squence.size() > 0)
{
Operand* cur_operand = operand_squence.front();
operand_squence.erase(operand_squence.begin());
if (cur_operand->consumers.size() == 1)
{
delete_operands.push_back(cur_operand);
Operator* pre_producer = cur_operand->producer;
if(pre_producer->outputs.size() == 1)
{
delete_ops.push_back(pre_producer);
for(auto cur_input: pre_producer->inputs)
{
operand_squence.push_back(cur_input);
}
}
else
{
for(auto out : pre_producer->outputs)
{
if (out->name == input1->name)
{
std::swap(out, pre_producer->outputs.back());
pre_producer->outputs.pop_back();
break;
}
}
}
}
else
{
for(int index = 0; index < cur_operand->consumers.size(); index++)
{
if(cur_operand->consumers[index]->name == op->name)
{
cur_operand->consumers.erase(cur_operand->consumers.begin() + index);
}
}
}
}

for(auto delete_op: delete_ops)
{
delete_op->inputs.clear();
delete_op->outputs.clear();
graph.ops.erase(std::find(graph.ops.begin(), graph.ops.end(), delete_op));
delete delete_op;
}
for(auto delete_operand: delete_operands)
{
graph.operands.erase(std::find(graph.operands.begin(), graph.operands.end(), delete_operand));
delete delete_operand;
}
break;

}

if (!matched)
break;
}
}

} // namespace pnnx
22 changes: 22 additions & 0 deletions tools/pnnx/src/pass_level6/trans_TensorTypeAs2TensorTo.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "ir.h"

namespace pnnx {

void trans_TensorTypeAs2TensorTo(Graph& graph);

} // namespace pnnx
2 changes: 1 addition & 1 deletion tools/pnnx/src/py_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// #include <torch/extension.h>
#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)
#define MYLIBRARY_VERSION "dev.1.0.16.20240605"
#define MYLIBRARY_VERSION "dev.1.0.17.20240606"
using namespace pnnx_graph;
using namespace pnnx_ir;
namespace py = pybind11;
Expand Down

0 comments on commit 6770840

Please sign in to comment.