From dcd720eec40c216000b518c7ac0d6776b898ca52 Mon Sep 17 00:00:00 2001 From: Xiangyun Yang Date: Sat, 1 Jan 2022 15:15:39 +0800 Subject: [PATCH 1/4] call recursively generic_visit --- taichi/transforms/offload.cpp | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index e4501edc83767..b19ffe4c2b35d 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -464,7 +464,6 @@ class PromoteIntermediateToGlobalTmp : public BasicStmtVisitor { auto ptr = stmt->insert_after_me( Stmt::make(offset, stmt->ret_type)); ptr->insert_after_me(Stmt::make(ptr, stmt)); - throw IRModified(); } } @@ -577,7 +576,6 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { stmt->parent->replace_with(stmt, std::move(replacement), false); // To deal with the same offloaded visit_operand() stmt_to_offloaded_[stmt] = nullptr; - throw IRModified(); } // Replace local LD/ST with global LD/ST @@ -591,7 +589,6 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { auto global_load = replacement.push_back(ptr); stmt_to_offloaded_[global_load] = stmt_to_offloaded_[stmt]; stmt->parent->replace_with(stmt, std::move(replacement)); - throw IRModified(); } } @@ -605,7 +602,6 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { replacement.push_back(ptr, stmt->val); stmt_to_offloaded_[global_store] = stmt_to_offloaded_[stmt]; stmt->parent->replace_with(stmt, std::move(replacement)); - throw IRModified(); } } @@ -623,10 +619,12 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { if (op->is()) { auto copy = op->clone(); + auto pcopy = copy.get(); copy->as()->activate = false; stmt_to_offloaded_[copy.get()] = offloaded; stmt->set_operand(index, copy.get()); stmt->insert_before_me(std::move(copy)); + generic_visit(pcopy); return true; } @@ -638,9 +636,11 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { "{} is not allowed here.", op->type()); // For cases like ConstStmt auto copy = op->clone(); + auto pcopy = copy.get(); stmt_to_offloaded_[copy.get()] = offloaded; stmt->set_operand(index, copy.get()); stmt->insert_before_me(std::move(copy)); + generic_visit(pcopy); } else { auto global_temporary = Stmt::make( local_to_global_offset_[op], op->ret_type); @@ -653,10 +653,12 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { } else { // For other cases like ArgLoadStmt UnaryOpStmt which needs to load. auto load = Stmt::make(global_temporary.get()); + auto pload = load.get(); stmt_to_offloaded_[load.get()] = offloaded; stmt->set_operand(index, load.get()); stmt->insert_before_me(std::move(global_temporary)); stmt->insert_before_me(std::move(load)); + generic_visit(pload); } } return true; @@ -664,13 +666,9 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { void generic_visit(Stmt *stmt) { int n_op = stmt->num_operands(); - bool modified = false; for (int i = 0; i < n_op; i++) { - if (visit_operand(stmt, i)) - modified = true; + visit_operand(stmt, i); } - if (modified) - throw IRModified(); } void visit(Stmt *stmt) override { From dc9e640b5047dcaa02feb773cf688a4165095c7b Mon Sep 17 00:00:00 2001 From: Xiangyun Yang <96721969+mzmzm@users.noreply.github.com> Date: Sat, 1 Jan 2022 16:10:01 +0800 Subject: [PATCH 2/4] remove redundant catch code --- taichi/transforms/offload.cpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index b19ffe4c2b35d..62ab4b1ac882e 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -469,14 +469,7 @@ class PromoteIntermediateToGlobalTmp : public BasicStmtVisitor { static void run(IRNode *root, const StmtToOffsetMap &local_to_global_offset) { PromoteIntermediateToGlobalTmp pass(local_to_global_offset); - while (true) { - try { - root->accept(&pass); - } catch (IRModified) { - continue; - } - break; - } + root->accept(&pass); } private: From fa31729ee374fa2d1c10e6f059c4a4249f8b2238 Mon Sep 17 00:00:00 2001 From: Xiangyun Yang <96721969+mzmzm@users.noreply.github.com> Date: Sun, 2 Jan 2022 14:20:55 +0800 Subject: [PATCH 3/4] remove another `catch` --- taichi/transforms/offload.cpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 62ab4b1ac882e..0d197f118ae37 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -681,14 +681,7 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { OffloadedRanges *offloaded_ranges) { FixCrossOffloadReferences pass(config, local_to_global_offset, stmt_to_offloaded, offloaded_ranges); - while (true) { - try { - root->accept(&pass); - } catch (IRModified) { - continue; - } - break; - } + root->accept(&pass); } private: From f70f31d164a046a7473565cdf5f5e13509fa7b2c Mon Sep 17 00:00:00 2001 From: Xiangyun Yang <96721969+mzmzm@users.noreply.github.com> Date: Sun, 2 Jan 2022 16:43:01 +0800 Subject: [PATCH 4/4] fix --- taichi/transforms/offload.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 0d197f118ae37..4718d991bbf38 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -646,12 +646,10 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { } else { // For other cases like ArgLoadStmt UnaryOpStmt which needs to load. auto load = Stmt::make(global_temporary.get()); - auto pload = load.get(); stmt_to_offloaded_[load.get()] = offloaded; stmt->set_operand(index, load.get()); stmt->insert_before_me(std::move(global_temporary)); stmt->insert_before_me(std::move(load)); - generic_visit(pload); } } return true;