Skip to content

Commit

Permalink
[opt] Accelerate whole_kernel_cse pass (#3957)
Browse files Browse the repository at this point in the history
* second step

* third step

* forth step

* final?

* final?1

* delete TI_AUTO_PROF

* remove TI_AUTO_PROF

* remove TI_AUTO_PROF

* Auto Format

* apply suggestion from code review

* apply suggestion from code review

* simplify code

* Apply suggestions from code review

Co-authored-by: Mingkuan Xu <xumingkuan0721@126.com>

* skip GPS and LUS

* Apply suggestions from code review

Co-authored-by: Mingkuan Xu <xumingkuan0721@126.com>

Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
Co-authored-by: Mingkuan Xu <xumingkuan0721@126.com>
  • Loading branch information
3 people committed Jan 11, 2022
1 parent 13e6320 commit d1ef522
Showing 1 changed file with 30 additions and 4 deletions.
34 changes: 30 additions & 4 deletions taichi/transforms/whole_kernel_cse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class WholeKernelCSE : public BasicStmtVisitor {
private:
std::unordered_set<int> visited_;
// each scope corresponds to an unordered_set
std::vector<std::unordered_map<std::type_index, std::unordered_set<Stmt *>>>
std::vector<std::unordered_map<std::size_t, std::unordered_set<Stmt *> > >
visible_stmts_;
DelayedIRModifier modifier_;

Expand All @@ -67,9 +67,31 @@ class WholeKernelCSE : public BasicStmtVisitor {
visited_.insert(stmt->instance_id);
}

static std::size_t operand_hash(const Stmt *stmt) {
std::size_t hash_code{0};
auto hash_type =
std::hash<std::type_index>{}(std::type_index(typeid(stmt)));
if (stmt->is<GlobalPtrStmt>() || stmt->is<LoopUniqueStmt>()) {
// special cases in common_statement_eliminable()
return hash_type;
}
auto op = stmt->get_operands();
for (auto &x : op) {
if (x == nullptr)
continue;
// Hash the addresses of the operand pointers.
hash_code =
(hash_code * 33) ^
(std::hash<unsigned long>{}(reinterpret_cast<unsigned long>(x)));
}
return hash_type ^ hash_code;
}

static bool common_statement_eliminable(Stmt *this_stmt, Stmt *prev_stmt) {
// Is this_stmt eliminable given that prev_stmt appears before it and has
// the same type with it?
if (this_stmt->type() != prev_stmt->type())
return false;
if (this_stmt->is<GlobalPtrStmt>()) {
auto this_ptr = this_stmt->as<GlobalPtrStmt>();
auto prev_ptr = prev_stmt->as<GlobalPtrStmt>();
Expand All @@ -95,13 +117,17 @@ class WholeKernelCSE : public BasicStmtVisitor {
void visit(Stmt *stmt) override {
if (!stmt->common_statement_eliminable())
return;
// container_statement does not need to be CSE-ed
if (stmt->is_container_statement())
return;
// Generic visitor for all CSE-able statements.
std::size_t hash_value = operand_hash(stmt);
if (is_done(stmt)) {
visible_stmts_.back()[std::type_index(typeid(*stmt))].insert(stmt);
visible_stmts_.back()[hash_value].insert(stmt);
return;
}
for (auto &scope : visible_stmts_) {
for (auto &prev_stmt : scope[std::type_index(typeid(*stmt))]) {
for (auto &prev_stmt : scope[hash_value]) {
if (common_statement_eliminable(stmt, prev_stmt)) {
MarkUndone::run(&visited_, stmt);
stmt->replace_usages_with(prev_stmt);
Expand All @@ -110,7 +136,7 @@ class WholeKernelCSE : public BasicStmtVisitor {
}
}
}
visible_stmts_.back()[std::type_index(typeid(*stmt))].insert(stmt);
visible_stmts_.back()[hash_value].insert(stmt);
set_done(stmt);
}

Expand Down

0 comments on commit d1ef522

Please sign in to comment.