Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#52 from feifei-111/cinn-trivalop-fuse
Browse files Browse the repository at this point in the history
update
  • Loading branch information
feifei-111 authored Mar 11, 2024
2 parents efdd3e8 + f47ca40 commit 9661fb2
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions paddle/cinn/hlir/framework/pir/trivial_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -499,12 +499,12 @@ struct FusionGraph {
}

std::vector<ir::Expr> DoFusion(){
fuse_trivial_node();
return get_expr_results();
TrivialFusion();
return GetExprResults();
}

private:
FusionNode* find_trivial_node(){
FusionNode* FindTrivialFuseableNode(){
for (FusionNode* node: all_fusion_nodes_){
if (IsTrivialKind(node->op_pattern) && node->downstream.size() > 0){
CHECK(node->op_compute_body.size() == 1);
Expand All @@ -514,9 +514,9 @@ struct FusionGraph {
return nullptr;
}

void fuse_trivial_node(){
void TrivialFusion(){
FusionNode* upstream;
while((upstream = find_trivial_node()) != nullptr){
while((upstream = FindTrivialFuseableNode()) != nullptr){
std::unordered_map<FusionNode*, ::pir::Value> fusion_candidate = upstream->downstream;
upstream->downstream.clear();
for (const auto& pair_data : fusion_candidate) {
Expand All @@ -537,22 +537,22 @@ struct FusionGraph {
}

new_node->replace_topo_structure_of_fused_nodes(upstream, downstream);
append_fusion_node(new_node);
remove_fusion_node(downstream);
AppendNode(new_node);
RemoveNode(downstream);
}
remove_fusion_node(upstream);
RemoveNode(upstream);
}
}

std::vector<ir::Expr> get_expr_results() {
std::vector<ir::Expr> GetExprResults() {
std::vector<ir::Expr> output_exprs;
for (const auto& node : all_fusion_nodes_) {
output_exprs.insert(output_exprs.end(), node->op_compute_body.begin(), node->op_compute_body.end());
}
return output_exprs;
}

void remove_fusion_node(FusionNode* node){
void RemoveNode(FusionNode* node){
if (all_fusion_nodes_.find(node) != all_fusion_nodes_.end()){
all_fusion_nodes_.erase(node);
}
Expand All @@ -565,7 +565,7 @@ struct FusionGraph {
delete node;
}

void append_fusion_node(FusionNode* node){
void AppendNode(FusionNode* node){
all_fusion_nodes_.emplace(node);
if (node->upstream.size() == 0){
entrance_nodes_.emplace(node);
Expand Down

0 comments on commit 9661fb2

Please sign in to comment.