Skip to content

Commit

Permalink
add new branch for test map
Browse files Browse the repository at this point in the history
  • Loading branch information
sen.li committed May 20, 2024
1 parent e5509d9 commit ac7d29a
Show file tree
Hide file tree
Showing 4 changed files with 252 additions and 64 deletions.
142 changes: 108 additions & 34 deletions tools/pnnx/src/parse/pnnx_graph_parse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,26 +65,35 @@ bool PnnxGraph::getNvpPnnxModel(const std::string& pt_path, const std::string& i
return false;
}

bool PnnxGraph::loadModel(const std::string& param_path, const std::string& bin_path)
bool PnnxGraph::loadModel(const std::string& param_path, const std::string& bin_path, const std::string& key)
{
if (this->graph_ != nullptr)
{
this->graph_.reset();
}

this->graph_ = std::make_unique<Graph>();

int32_t load_result = this->graph_->load(param_path, bin_path);
// this->graph_ = std::make_unique<Graph>();

auto it = this->graph_map_.find(key);
if (it != this->graph_map_.end())
{
std::cout << "your input key: " << key << "has been registered"<< std::endl;
return false;
}

std::unique_ptr<Graph> graph_;
graph_ = std::make_unique<Graph>();
this->graph_map_[key] = std::move(graph_);
int32_t load_result = graph_->load(param_path, bin_path);
if (load_result != 0)
{
std::cout << "Can not find the param path or bin path: " << param_path << " " << bin_path << std::endl;
return false;
}

std::cout << "123" << bin_path << std::endl;
//parse all operator
this->operators_.clear();
this->input_ops_.clear();
this->output_ops_.clear();
std::vector<Operator*> operators = this->graph_->ops;
std::vector<Operator> operators_;
std::vector<Operand> operands_;
std::vector<Operator> input_ops_;
std::vector<Operator> output_ops_;
std::vector<Operator*> operators = graph_->ops;

if (operators.empty())
{
Expand All @@ -109,20 +118,19 @@ bool PnnxGraph::loadModel(const std::string& param_path, const std::string& bin_
// py::vector<char> vec_data = py::vector<char>(data1.size(), data1.data());
attr.b_data = py::bytes(data1.data(), data1.size());
}
this->operators_.push_back(*op);
operators_.push_back(*op);
if (op->inputs.empty())
{
this->input_ops_.push_back(*op);
input_ops_.push_back(*op);
}
if (op->outputs.empty())
{
this->output_ops_.push_back(*op);
output_ops_.push_back(*op);
}
}
}
//parse all operand
this->operands_.clear();
std::vector<Operand*> operands = this->graph_->operands;
std::vector<Operand*> operands = graph_->operands;

if (operands.empty())
{
Expand All @@ -138,42 +146,108 @@ bool PnnxGraph::loadModel(const std::string& param_path, const std::string& bin_
}
else
{
this->operands_.push_back(*blob);
operands_.push_back(*blob);
}
}

this->operators_map_[key] = operators_;
this->operands_map_[key] = operands_;
this->input_ops_map_[key] = input_ops_;
this->output_ops_map_[key] = output_ops_;

return true;
}

std::vector<Operator> PnnxGraph::getOperators() const
std::vector<Operator> PnnxGraph::getOperators(const std::string& key) const
{
return this->operators_;
auto it = this->operators_map_.find(key);
std::vector<Operator> operators;
if (it != this->operators_map_.end())
{
operators = this->operators_map_.at(key);
}
else
{
std::cout << "your input key: " << key << "has not register"<< std::endl;

}
return operators;

}

std::vector<Operand> PnnxGraph::getOperands() const
std::vector<Operand> PnnxGraph::getOperands(const std::string& key) const
{
return this->operands_;

auto it = this->operands_map_.find(key);
std::vector<Operand> operands;
if (it != this->operands_map_.end())
{
operands = this->operands_map_.at(key);
}
else
{
std::cout << "your input key: " << key << "has not register"<< std::endl;

}
return operands;
}

std::vector<Operator> PnnxGraph::getInputOps() const
std::vector<Operator> PnnxGraph::getInputOps(const std::string& key) const
{
return this->input_ops_;

auto it = this->input_ops_map_.find(key);
std::vector<Operator> operators;
if (it != this->input_ops_map_.end())
{
operators = this->input_ops_map_.at(key);
}
else
{
std::cout << "your input key: " << key << "has not register"<< std::endl;

}
return operators;
}

std::vector<Operator> PnnxGraph::getOutputOps() const
std::vector<Operator> PnnxGraph::getOutputOps(const std::string& key) const
{
return this->output_ops_;
auto it = this->output_ops_map_.find(key);
std::vector<Operator> operators;

if (it != this->output_ops_map_.end())
{
operators = this->output_ops_map_.at(key);
}
else
{
std::cout << "your input key: " << key << "has not register"<< std::endl;

}
return operators;
}


bool PnnxGraph::saveModel(const std::string& parampath, const std::vector<Operator>& operators, const std::vector<Operand>& operands)
{
int32_t save_result = this->graph_->save_param(parampath, operators, operands);
if (save_result != 0)
bool PnnxGraph::saveModel(const std::string& parampath, const std::vector<Operator>& operators, const std::vector<Operand>& operands, const std::string& key)
{

auto it = graph_map_.find(key);
if (it != graph_map_.end())
{
// int32_t save_result = this->graph_->save_param(parampath, operators, operands);
Graph* graphPtr = graph_map_[key].get();
int32_t save_result = graphPtr->save_param(parampath, operators, operands);
if (save_result != 0)
{
std::cout << "Can not save params to param path: " << parampath << std::endl;
return false;
}
return true;
}else
{
std::cout << "Can not save params to param path: " << parampath << std::endl;
return false;
}
return true;
std::cout << "Please input a src model" << std::endl;
}
return false;

}

} // namespace pnnx_graph
24 changes: 13 additions & 11 deletions tools/pnnx/src/parse/pnnx_graph_parse.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <string>
#include <memory>
#include <iostream>
#include <unordered_map>
#include "pnnx_ir_parse.h"
using namespace pnnx_ir;
namespace pnnx_graph {
Expand Down Expand Up @@ -30,7 +31,7 @@ class PnnxGraph
* @return true
* @return false
*/
bool loadModel(const std::string& param_path, const std::string& bin_path);
bool loadModel(const std::string& param_path, const std::string& bin_path, const std::string& key);

/**
* @brief
Expand All @@ -41,49 +42,50 @@ class PnnxGraph
* @return true
* @return false
*/
bool saveModel(const std::string& parampath, const std::vector<Operator>& operators, const std::vector<Operand>& operands);
bool saveModel(const std::string& parampath, const std::vector<Operator>& operators, const std::vector<Operand>& operands, const std::string& key);

/**
* @brief Get the Operator object
*
* @return std::vector<std::shared_ptr<pnnx::Operator>>
*/
std::vector<Operator> getOperators() const;
std::vector<Operator> getOperators(const std::string& key) const;
/**
* @brief Get the Operands object
*
* @return std::vector<std::shared_ptr<pnnx::Operand>>
*/
std::vector<Operand> getOperands() const;
std::vector<Operand> getOperands(const std::string& key) const;

/**
* @brief Get the Input Ops object
*
* @return std::vector<std::shared_ptr<pnnx::Operator>>
*/

std::vector<Operator> getInputOps() const;
std::vector<Operator> getInputOps(const std::string& key) const;

/**
* @brief Get the Output Ops object
*
* @return std::vector<std::shared_ptr<pnnx::Operator>>
*/
std::vector<Operator> getOutputOps() const;
std::vector<Operator> getOutputOps(const std::string& key) const;




private:
/// @brief load pnnx graph
std::unique_ptr<Graph> graph_;
// std::unique_ptr<Graph> graph_;
std::unordered_map<std::string, std::unique_ptr<Graph>> graph_map_;
/// @brief all operator
std::vector<Operator> operators_;
std::unordered_map<std::string, std::vector<Operator>> operators_map_;
/// @brief all operand
std::vector<Operand> operands_;
std::unordered_map<std::string, std::vector<Operand>> operands_map_;
/// @brief all input operator
std::vector<Operator> input_ops_;
std::unordered_map<std::string, std::vector<Operator>> input_ops_map_;
/// @brief all output operator
std::vector<Operator> output_ops_;
std::unordered_map<std::string, std::vector<Operator>> output_ops_map_;
};
} // namespace pnnx_graph
Loading

0 comments on commit ac7d29a

Please sign in to comment.