Skip to content

Commit

Permalink
1. Add extracting sub graph function
Browse files Browse the repository at this point in the history
  • Loading branch information
root committed Jun 17, 2024
1 parent 0c5ca2e commit e33c86e
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 9 deletions.
5 changes: 4 additions & 1 deletion tools/pnnx/Releasenotes
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,7 @@ dev.1.0.17.20240606
1. Add trans_TensorTypeAs2TensorTo pass in pass level 7

dev.1.0.18.20240613
1. Skip conv2d nodes of type NoneType
1. Skip conv2d nodes of type NoneType

dev.1.0.19.20240614
1. Add extracting sub graph function
279 changes: 279 additions & 0 deletions tools/pnnx/src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4427,5 +4427,284 @@ const Operand* Graph::get_operand(const std::string& name) const

return 0;
}
int Graph::extract_sub_graph(const std::vector<std::string>& start_nodes, const std::vector<std::string>& end_nodes)
{
if(start_nodes.size() == 0 && end_nodes.size() == 0)
{
fprintf(stderr, "############# not need extract sub graph\n");
}
else
{
std::vector<std::string> extract_start_nodes = start_nodes;
std::vector<std::string> extract_end_nodes = end_nodes;
if(extract_start_nodes.size() == 0)
{
// set input node name to start_nodes
for(auto node :ops)
{
if(node->type == "pnnx.Input")
{
std::string input_tensor_name = node->outputs[0]->name;
extract_start_nodes.push_back(input_tensor_name);
}
}
}
if(extract_end_nodes.size() == 0)
{
// set output node name to start_nodes
for(auto node :ops)
{
if(node->type == "pnnx.Output")
{
std::string output_tensor_name = node->inputs[0]->name;
extract_end_nodes.push_back(output_tensor_name);
}
}
}
std::vector<Operator*> new_input_ops;
std::vector<Operator*> new_output_ops;
//get exclude_node_names exclude_tensor_names
std::vector<std::string> exclude_node_names;

for(auto node: ops)
{
// check is input node or not
std::vector<Operand*> cur_inputs = node->inputs;
int input_num = 0;
for(auto cur_input: cur_inputs)
{
std::string cur_node_name = cur_input->name;
if(std::find(extract_start_nodes.begin(), extract_start_nodes.end(), cur_node_name) != extract_start_nodes.end())
{
bool is_new_tensor = true;
for(auto new_op: new_input_ops)
{
if(new_op->outputs[0]->name == cur_node_name)
{
is_new_tensor = false;
break;
}
}
if(is_new_tensor)
{
// create new input node
Operator* op = new Operator;
op->type = "pnnx.Input";
op->name = "pnnx_input_" + std::to_string(new_input_ops.size());
op->outputs.push_back(cur_input);
std::vector<int> shape = cur_input->shape;

new_input_ops.push_back(op);

// get pre node
Operator* pre_node = cur_input->producer;
if(std::find(exclude_node_names.begin(), exclude_node_names.end(), pre_node->name) == exclude_node_names.end())
{
exclude_node_names.push_back(pre_node->name);
std::list<Operator*> List;
List.push_back(pre_node);
while(!List.empty())
{
Operator* cur_node = List.front();
List.pop_front();
std::vector<Operand*> cur_node_inputs = cur_node->inputs;
for(auto cur_node_input: cur_node_inputs)
{
Operator* pre_node_producer = cur_node_input->producer;
if(std::find(exclude_node_names.begin(), exclude_node_names.end(), pre_node_producer->name) == exclude_node_names.end())
{
exclude_node_names.push_back(pre_node_producer->name);
List.push_back(pre_node_producer);
}

}

}
}


}

input_num++;
}

}
// is start node
if(input_num != 0)
{
if(input_num > extract_start_nodes.size())
{
fprintf(stderr, "############# please check your start nodes!\n");
return -1;
}
}

// check is output node or not
std::vector<Operand*> cur_outputs = node->outputs;
int output_num = 0;
for(auto cur_output: cur_outputs)
{
std::string cur_node_name = cur_output->name;
if(std::find(extract_end_nodes.begin(), extract_end_nodes.end(), cur_node_name) != extract_end_nodes.end())
{
bool is_new_tensor = true;
for(auto new_op: new_output_ops)
{
if(new_op->inputs[0]->name == cur_node_name)
{
is_new_tensor = false;
break;
}
}
if(is_new_tensor)
{
Operator* op = new Operator;
op->type = "pnnx.Output";
op->name = "pnnx_Output_" + std::to_string(new_output_ops.size());
op->inputs.push_back(cur_output);
std::vector<int> shape = cur_output->shape;

new_output_ops.push_back(op);


// get sink node
std::vector<Operator*> sink_nodes = cur_output->consumers;
for(auto sink_node: sink_nodes)
{
if(std::find(exclude_node_names.begin(), exclude_node_names.end(), sink_node->name) == exclude_node_names.end())
{
exclude_node_names.push_back(sink_node->name);
std::list<Operator*> sink_List;
sink_List.push_back(sink_node);
while(!sink_List.empty())
{
Operator* cur_sink_node = sink_List.front();
sink_List.pop_front();
std::vector<Operand*> cur_sink_node_outputs = cur_sink_node->outputs;
for(auto cur_sink_node_output: cur_sink_node_outputs)
{
std::vector<Operator*> sink_node_consumers = cur_sink_node_output->consumers;
for(auto sink_node_consumer: sink_node_consumers)
{
if(std::find(exclude_node_names.begin(), exclude_node_names.end(), sink_node_consumer->name) == exclude_node_names.end())
{
exclude_node_names.push_back(sink_node_consumer->name);
sink_List.push_back(sink_node_consumer);
}
}

}

}
}
}

}

output_num++;
}

}

// is end node
if(output_num != 0)
{
if(output_num > extract_end_nodes.size())
{
fprintf(stderr, "############# please check your end nodes!\n");
return -1;
}
}
}

// delect exclude_node_names
while (1)
{
bool matched = false;

for (size_t i = 0; i < ops.size(); i++)
{
Operator* op = ops[i];
if(std::find(exclude_node_names.begin(), exclude_node_names.end(),op->name) == exclude_node_names.end())
{
continue;
}
matched = true;
std::vector<Operand*> inputs = op->inputs;
std::vector<Operand*> outputs = op->outputs;
for(auto match_node_output: outputs)
{
if(std::find(extract_start_nodes.begin(), extract_start_nodes.end(), match_node_output->name) != extract_start_nodes.end())
{
for(auto new_input_op: new_input_ops)
{
for(auto new_input_op_output: new_input_op->outputs)
{
if(new_input_op_output->name == match_node_output->name)
{
match_node_output->producer = new_input_op;
}
}
}
}
else
{
match_node_output->producer = 0;
match_node_output->consumers.clear();
if(std::find(operands.begin(), operands.end(), match_node_output) != operands.end())
{
operands.erase(std::find(operands.begin(), operands.end(), match_node_output));
delete match_node_output;
}

}
}

for(auto match_node_input: inputs)
{
if(std::find(extract_end_nodes.begin(), extract_end_nodes.end(), match_node_input->name) != extract_end_nodes.end())
{
for(auto new_output_op: new_output_ops)
{
for(auto new_output_op_input: new_output_op->inputs)
{
if(new_output_op_input->name == match_node_input->name)
{
match_node_input->consumers.push_back(new_output_op);
}
}
}
}
else
{
match_node_input->producer = 0;
match_node_input->consumers.clear();
if(std::find(operands.begin(), operands.end(), match_node_input) != operands.end())
{
operands.erase(std::find(operands.begin(), operands.end(), match_node_input));
delete match_node_input;
}

}
}
op->inputs.clear();
op->outputs.clear();

ops.erase(ops.begin() + i);
delete op;
break;
}

if (!matched)
break;
}

// insert new input outout node
ops.insert(ops.end(), new_input_ops.begin(), new_input_ops.end());
ops.insert(ops.end(), new_output_ops.begin(), new_output_ops.end());

}
return 1;
}

} // namespace pnnx
2 changes: 2 additions & 0 deletions tools/pnnx/src/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,8 @@ class Graph
Operand* get_operand(const std::string& name);
const Operand* get_operand(const std::string& name) const;

int extract_sub_graph(const std::vector<std::string>& start_nodes, const std::vector<std::string>& end_nodes);

std::vector<Operator*> ops;
std::vector<Operand*> operands;

Expand Down
20 changes: 20 additions & 0 deletions tools/pnnx/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ int main(int argc, char** argv)
std::vector<std::string> module_operators;
// add by senli
std::string customop_infer_py = "None";
std::vector<std::string> start_nodes;
std::vector<std::string> end_nodes;

for (int i = 2; i < argc; i++)
{
// key=value
Expand Down Expand Up @@ -281,6 +284,10 @@ int main(int argc, char** argv)
// add by senli
if (strcmp(key, "customop_infer_py") == 0)
customop_infer_py = value;
if (strcmp(key, "start_nodes") == 0)
parse_string_list(value, start_nodes);
if (strcmp(key, "end_nodes") == 0)
parse_string_list(value, end_nodes);
}

// print options
Expand Down Expand Up @@ -313,6 +320,12 @@ int main(int argc, char** argv)
// add by senli
fprintf(stderr, "customop_infer_py = %s\n", customop_infer_py.c_str());
fprintf(stderr, "\n");
fprintf(stderr, "start_nodes = ");
print_string_list(start_nodes);
fprintf(stderr, "\n");
fprintf(stderr, "end_nodes = ");
print_string_list(end_nodes);
fprintf(stderr, "\n");
}

std::set<std::string> foldable_constants;
Expand Down Expand Up @@ -365,6 +378,13 @@ int main(int argc, char** argv)

// delete foldable_constants_zippath
remove(foldable_constants_zippath.c_str());

// extract_sub_graph
int extract_flag = pnnx_graph.extract_sub_graph(start_nodes, end_nodes);
if(extract_flag == -1)
{
fprintf(stderr, "############# failed to extract_sub_graph\n");
}

pnnx_graph.save(pnnxparampath, pnnxbinpath);

Expand Down
Loading

0 comments on commit e33c86e

Please sign in to comment.