Skip to content

Commit

Permalink
1. Support load input tensor to export
Browse files Browse the repository at this point in the history
  • Loading branch information
sen.li committed Jun 20, 2024
1 parent d85c333 commit 4f310c4
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 24 deletions.
5 changes: 4 additions & 1 deletion tools/pnnx/Releasenotes
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,7 @@ dev.1.0.20.20240617
1. Add loop op parse function

dev.1.0.21.20240619
1. Support export sub_model
1. Support export sub_model

dev.1.0.22.20240620
1. Support load input tensor to export
114 changes: 94 additions & 20 deletions tools/pnnx/src/load_torchscript.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
#else
#include <dlfcn.h>
#endif

#include <iostream>
#include <cstdio>
#include <torch/script.h>
#include <torch/csrc/api/include/torch/version.h>
#ifdef PNNX_TORCHVISION
Expand All @@ -33,6 +34,16 @@ int64_t cuda_version();

namespace pnnx {

// Check whether a file at 'path' exists and is readable, by probing with fopen.
// NOTE(review): this reports "readable", not strictly "exists" — a file the
// process lacks permission to read will return false.
static bool fileExists(const std::string& path) {
    FILE* fp = fopen(path.c_str(), "r");
    const bool readable = (fp != nullptr);
    if (fp)
        fclose(fp);
    return readable;
}

static int get_at_tensor_type(const at::ScalarType& st)
{
if (st == c10::ScalarType::Float) return 1;
Expand Down Expand Up @@ -429,7 +440,8 @@ const torch::jit::Node* find_node_by_kind(const std::shared_ptr<torch::jit::Grap
return 0;
}

int load_torchscript(const std::string& ptpath, \
int load_torchscript(const std::string& ptpath,
const std::string& save_dir,
std::unordered_map<std::string, std::shared_ptr<pnnx::Graph>>& pnnx_graph_map,
const std::string& device,
const std::vector<std::vector<int64_t> >& input_shapes,
Expand Down Expand Up @@ -463,32 +475,94 @@ int load_torchscript(const std::string& ptpath, \
}
#endif
}

std::vector<at::Tensor> input_tensors;
for (size_t i = 0; i < input_shapes.size(); i++)
std::vector<at::Tensor> input_tensors2;
// load input data
std::string input_data_path = save_dir + "/input_tensor_container.pt";
bool load_input_data_flag = false;
if(fileExists(input_data_path))
{
const std::vector<int64_t>& shape = input_shapes[i];
const std::string& type = input_types[i];
try
{
torch::jit::script::Module tensors = torch::jit::load(input_data_path);
bool has_input_tensors = tensors.hasattr("input_tensors");
if(!has_input_tensors)
{
fprintf(stderr, "############# %s exist, but there are not input_tensors, still creating tensor based on shape\n",input_data_path.c_str());
}
else
{
c10::IValue input_values = tensors.attr("input_tensors");
bool is_tensor_vector = input_values.isTensorList();
if(!is_tensor_vector)
{
fprintf(stderr, "############# input_tensors is not a tensor list, still creating tensor based on shape\n");
}
else
{
bool has_input_tensors2 = tensors.hasattr("input_tensors2");
if(has_input_tensors2)
{
c10::IValue input_values2 = tensors.attr("input_tensors2");
bool is_tensor2_vector = input_values2.isTensorList();
if(!is_tensor2_vector)
{
fprintf(stderr, "############# input_tensors2 is not a tensor list, still creating tensor based on shape\n");
}
else
{
input_tensors2 = input_values2.toTensorVector();
input_tensors = input_values.toTensorVector();
load_input_data_flag = true;
}

}
else
{
input_tensors = input_values.toTensorVector();
load_input_data_flag = true;
}

}
}



}
catch (const c10::Error& e)
{
fprintf(stderr, "############# Failed to load input_tensor_container, still creating tensor based on shape\n");
}
}

if(!load_input_data_flag)
{
for (size_t i = 0; i < input_shapes.size(); i++)
{
const std::vector<int64_t>& shape = input_shapes[i];
const std::string& type = input_types[i];

at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();
at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();

input_tensors.push_back(t);
}
input_tensors.push_back(t);
}

std::vector<at::Tensor> input_tensors2;
for (size_t i = 0; i < input_shapes2.size(); i++)
{
const std::vector<int64_t>& shape = input_shapes2[i];
const std::string& type = input_types2[i];
for (size_t i = 0; i < input_shapes2.size(); i++)
{
const std::vector<int64_t>& shape = input_shapes2[i];
const std::string& type = input_types2[i];

at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();
at::Tensor t = torch::ones(shape, input_type_to_c10_ScalarType(type));
if (device == "gpu")
t = t.cuda();

input_tensors2.push_back(t);
input_tensors2.push_back(t);
}
}


torch::jit::Module mod;

Expand Down
3 changes: 2 additions & 1 deletion tools/pnnx/src/load_torchscript.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

namespace pnnx {

int load_torchscript(const std::string& ptpath, \
int load_torchscript(const std::string& ptpath,
const std::string& save_dir,
std::unordered_map<std::string, std::shared_ptr<pnnx::Graph>>& pnnx_graph_map,
const std::string& device,
const std::vector<std::vector<int64_t> >& input_shapes,
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ int main(int argc, char** argv)
dynamic_network = true;
}
std::unordered_map<std::string, std::shared_ptr<pnnx::Graph>> pnnx_graph_map;
load_torchscript(ptpath, pnnx_graph_map,
load_torchscript(ptpath, save_dir, pnnx_graph_map,
device, input_shapes, input_types,
input_shapes2, input_types2,
customop_modules, module_operators,
Expand Down
2 changes: 1 addition & 1 deletion tools/pnnx/src/py_proj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
// #include <torch/extension.h>
#define STRINGIFY(x) #x
#define MACRO_STRINGIFY(x) STRINGIFY(x)
#define MYLIBRARY_VERSION "dev.1.0.21.20240619"
#define MYLIBRARY_VERSION "dev.1.0.22.20240620"
using namespace pnnx_graph;
using namespace pnnx_ir;
namespace py = pybind11;
Expand Down

0 comments on commit 4f310c4

Please sign in to comment.