From d1ef522b4427646ff033567d9c304d9321d0a344 Mon Sep 17 00:00:00 2001 From: Xiangyun Yang <96721969+mzmzm@users.noreply.github.com> Date: Tue, 11 Jan 2022 11:44:13 +0800 Subject: [PATCH] [opt] Accelerate whole_kernel_cse pass (#3957) * second step * third step * forth step * final? * final?1 * delete TI_AUTO_PROF * remove TI_AUTO_PROF * remove TI_AUTO_PROF * Auto Format * apply suggestion from code review * apply suggestion from code review * simplify code * Apply suggestions from code review Co-authored-by: Mingkuan Xu * skip GPS and LUS * Apply suggestions from code review Co-authored-by: Mingkuan Xu Co-authored-by: Taichi Gardener Co-authored-by: Mingkuan Xu --- taichi/transforms/whole_kernel_cse.cpp | 34 +++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/taichi/transforms/whole_kernel_cse.cpp b/taichi/transforms/whole_kernel_cse.cpp index c4c2144287e0b..1ecb6c1954032 100644 --- a/taichi/transforms/whole_kernel_cse.cpp +++ b/taichi/transforms/whole_kernel_cse.cpp @@ -47,7 +47,7 @@ class WholeKernelCSE : public BasicStmtVisitor { private: std::unordered_set visited_; // each scope corresponds to an unordered_set - std::vector>> + std::vector > > visible_stmts_; DelayedIRModifier modifier_; @@ -67,9 +67,31 @@ class WholeKernelCSE : public BasicStmtVisitor { visited_.insert(stmt->instance_id); } + static std::size_t operand_hash(const Stmt *stmt) { + std::size_t hash_code{0}; + auto hash_type = + std::hash{}(std::type_index(typeid(stmt))); + if (stmt->is() || stmt->is()) { + // special cases in common_statement_eliminable() + return hash_type; + } + auto op = stmt->get_operands(); + for (auto &x : op) { + if (x == nullptr) + continue; + // Hash the addresses of the operand pointers. + hash_code = + (hash_code * 33) ^ + (std::hash{}(reinterpret_cast(x))); + } + return hash_type ^ hash_code; + } + static bool common_statement_eliminable(Stmt *this_stmt, Stmt *prev_stmt) { // Is this_stmt eliminable given that prev_stmt appears before it and has // the same type with it? + if (this_stmt->type() != prev_stmt->type()) + return false; if (this_stmt->is()) { auto this_ptr = this_stmt->as(); auto prev_ptr = prev_stmt->as(); @@ -95,13 +117,17 @@ class WholeKernelCSE : public BasicStmtVisitor { void visit(Stmt *stmt) override { if (!stmt->common_statement_eliminable()) return; + // container_statement does not need to be CSE-ed + if (stmt->is_container_statement()) + return; // Generic visitor for all CSE-able statements. + std::size_t hash_value = operand_hash(stmt); if (is_done(stmt)) { - visible_stmts_.back()[std::type_index(typeid(*stmt))].insert(stmt); + visible_stmts_.back()[hash_value].insert(stmt); return; } for (auto &scope : visible_stmts_) { - for (auto &prev_stmt : scope[std::type_index(typeid(*stmt))]) { + for (auto &prev_stmt : scope[hash_value]) { if (common_statement_eliminable(stmt, prev_stmt)) { MarkUndone::run(&visited_, stmt); stmt->replace_usages_with(prev_stmt); @@ -110,7 +136,7 @@ class WholeKernelCSE : public BasicStmtVisitor { } } } - visible_stmts_.back()[std::type_index(typeid(*stmt))].insert(stmt); + visible_stmts_.back()[hash_value].insert(stmt); set_done(stmt); }