diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h
index d752c9adfdc2..396ad12150ed 100644
--- a/include/tvm/ir_pass.h
+++ b/include/tvm/ir_pass.h
@@ -266,6 +266,15 @@ Stmt CoProcSync(Stmt stmt);
  */
 Stmt LiftAttrScope(Stmt stmt, std::string attr_key);
 
+/*!
+ * \brief Lower attached storage access information.
+ *  Run this pass after all storage access analyses finish.
+ *
+ * \param stmt The stmt to be transformed.
+ * \return Transformed stmt.
+ */
+Stmt LowerStorageAccessInfo(Stmt stmt);
+
 /*!
  * \brief Make an user callable API LoweredFunc.
  *
diff --git a/include/tvm/target_info.h b/include/tvm/target_info.h
index b1d11a8c0098..73dc2c5982ec 100644
--- a/include/tvm/target_info.h
+++ b/include/tvm/target_info.h
@@ -23,11 +23,17 @@ struct MemoryInfoNode : public Node {
   int max_num_bits;
   /*! \brief maximum number of bits to be used in simd op */
   int max_simd_bits;
+  /*!
+   * \brief head address of the buffer, if visible to the CPU.
+   *  This address can be None.
+   */
+  Expr head_address;
 
   void VisitAttrs(AttrVisitor* v) final {
     v->Visit("unit_bits", &unit_bits);
     v->Visit("max_num_bits", &max_num_bits);
     v->Visit("max_simd_bits", &max_simd_bits);
+    v->Visit("head_address", &head_address);
   }
 
   static constexpr const char* _type_key = "MemoryInfo";
diff --git a/python/tvm/build_module.py b/python/tvm/build_module.py
index dca3cb90aacb..8506f34729b5 100644
--- a/python/tvm/build_module.py
+++ b/python/tvm/build_module.py
@@ -197,7 +197,6 @@ def lower(sch,
     stmt = ir_pass.VectorizeLoop(stmt)
     stmt = ir_pass.InjectVirtualThread(stmt)
     stmt = ir_pass.StorageRewrite(stmt)
-    stmt = ir_pass.CoProcSync(stmt)
     cfg = BuildConfig.current
     stmt = ir_pass.UnrollLoop(
         stmt,
@@ -210,6 +209,7 @@ def lower(sch,
     stmt = ir_pass.Simplify(stmt)
     if simple_mode:
         return stmt
+    stmt = ir_pass.LowerStorageAccessInfo(stmt)
     return ir_pass.MakeAPI(stmt, name, arg_list, 0, cfg.restricted_func)
diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc
index 4591e778ca31..fca67d3b32a5 100644
--- a/src/api/api_pass.cc
+++ b/src/api/api_pass.cc
@@ -95,6 +95,7 @@ REGISTER_PASS2(BindDeviceType);
 REGISTER_PASS1(SplitHostDevice);
 REGISTER_PASS1(StorageRewrite);
 REGISTER_PASS1(CoProcSync);
+REGISTER_PASS1(LowerStorageAccessInfo);
 REGISTER_PASS1(InjectVirtualThread);
 REGISTER_PASS1(InjectPrefetch);
 REGISTER_PASS1(LoopPartition);
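Taken together, the changes above add an optional head_address to MemoryInfo, register the new LowerStorageAccessInfo pass, and move special-memory lowering to the very end of lower(): with simple_mode=True the pass is skipped, so tagged allocations and tvm_access_ptr calls remain visible for inspection. A minimal sketch of driving the new pass by hand (the schedule s and tensors A, B are assumed to exist):

    import tvm

    # simple_mode output still carries tagged scopes and tvm_access_ptr calls
    stmt = tvm.lower(s, [A, B], simple_mode=True)
    # applying the pass explicitly mirrors what the full lower() now does
    stmt = tvm.ir_pass.LowerStorageAccessInfo(stmt)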
diff --git a/src/pass/coproc_sync.cc b/src/pass/coproc_sync.cc
new file mode 100644
index 000000000000..a1315b67b775
--- /dev/null
+++ b/src/pass/coproc_sync.cc
@@ -0,0 +1,409 @@
+/*!
+ * Copyright (c) 2017 by Contributors
+ * \file coproc_sync.cc
+ */
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/ir_pass.h>
+#include <tvm/ir_visitor.h>
+#include <unordered_map>
+#include <unordered_set>
+#include "./ir_util.h"
+#include "./storage_access.h"
+
+
+namespace tvm {
+namespace ir {
+
+// Visitor to find touched set by co-processor scope.
+class CoProcTouchedBuffer : public IRVisitor {
+ public:
+  void Visit_(const Load* op) final {
+    if (in_scope_) {
+      touched_[op->buffer_var.get()].coproc = true;
+    } else {
+      touched_[op->buffer_var.get()].normal = true;
+    }
+    IRVisitor::Visit_(op);
+  }
+  void Visit_(const Store* op) final {
+    if (in_scope_) {
+      touched_[op->buffer_var.get()].coproc = true;
+    } else {
+      touched_[op->buffer_var.get()].normal = true;
+    }
+    IRVisitor::Visit_(op);
+  }
+  void Visit_(const Call* op) final {
+    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+      const Variable* buffer = op->args[1].as<Variable>();
+      if (in_scope_) {
+        touched_[buffer].coproc = true;
+      } else {
+        touched_[buffer].normal = true;
+      }
+    }
+    IRVisitor::Visit_(op);
+  }
+  void Visit_(const AttrStmt* op) final {
+    if (op->attr_key == attr::coproc_scope && !in_scope_) {
+      in_scope_ = true;
+      IterVar iv(op->node.node_);
+      coproc_.insert(iv);
+      IRVisitor::Visit_(op);
+      in_scope_ = false;
+    } else {
+      IRVisitor::Visit_(op);
+    }
+  }
+
+  // Touch entry
+  struct TouchEntry {
+    bool normal{false};
+    bool coproc{false};
+  };
+  std::unordered_map<const Variable*, TouchEntry> touched_;
+  std::unordered_set<IterVar> coproc_;
+
+ private:
+  bool in_scope_{false};
+};
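CoProcTouchedBuffer only classifies accesses; CoProcSyncInserter further below keeps exactly the buffers whose TouchEntry has both flags set. A sketch, in the style of the updated unit test, of IR that produces such a buffer (the cop pseudo thread axis marks the co-processor scope):

    import tvm

    ib = tvm.ir_builder.create()
    cp = tvm.thread_axis((0, 1), "cop")
    A = ib.allocate("float32", 16, name="A", scope="global")
    with ib.for_range(0, 16, name="i") as i:
        A[i] = A[i] + 1.0          # normal (CPU) access: touched_[A].normal
    with ib.for_range(0, 16, name="i") as i:
        ib.scope_attr(cp, "coproc_scope", 1)
        A[i] = A[i] + 2.0          # coproc access: touched_[A].coproc
    stmt = tvm.ir_pass.CoProcSync(ib.get())  # A now qualifies for planning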
+
+// Synchronization planning with co-processor.
+class CoProcSyncPlanner : public StorageAccessVisitor {
+ public:
+  explicit CoProcSyncPlanner(
+      const std::unordered_set<const Variable*>& touched,
+      const std::string& coproc_name)
+      : touched_(touched), coproc_name_(coproc_name) {
+  }
+
+  void Plan(const Stmt& stmt) {
+    this->Visit(stmt);
+    PlanSync(scope_.back(), nullptr, true);
+    if (sync_.size() == 0) {
+      sync_[stmt.get()] = GetSync(coproc_name_ + ".coproc_sync");
+    }
+  }
+
+  // Write synchronization to be inserted before or after stmt.
+  std::unordered_map<const Node*, std::vector<Stmt> > sync_;
+
+ protected:
+  bool Enabled(const Variable* buf,
+               const StorageScope& scope) const final {
+    return touched_.count(buf);
+  }
+
+  // Plan the sync
+  std::vector<AccessEntry> Summarize(
+      std::vector<StmtEntry> seq, const For* loop) final {
+    return PlanSync(seq, loop, false);
+  }
+
+ private:
+  // Plan write synchronization if the writes are not coherent.
+  std::vector<AccessEntry> PlanSync(
+      std::vector<StmtEntry> seq, const For* loop,
+      bool force_sync_at_end) {
+    // Detect write barriers for accesses by the co-processor.
+    std::vector<AccessEntry> co_access;
+    bool contain_sync = false;
+
+    auto find_conflict = [&](const AccessEntry& acc) {
+      for (const AccessEntry& x : co_access) {
+        if (x.buffer.same_as(acc.buffer) &&
+            ((acc.type == kRead && x.type == kWrite) ||
+             acc.type == kWrite)) {
+          return true;
+        }
+      }
+      return false;
+    };
+    for (size_t i = 0; i < seq.size(); ++i) {
+      const StmtEntry& s = seq[i];
+      bool sync_write = false;
+      for (const AccessEntry& acc : s.access) {
+        if (acc.threads.size() == 0 && find_conflict(acc)) {
+          sync_write = true; break;
+        }
+        if (acc.type == kSync) {
+          co_access.clear();
+          contain_sync = true;
+        }
+      }
+      if (sync_write) {
+        CHECK_NE(i, 0U);
+        sync_[seq[i - 1].stmt] = GetSync(co_access);
+        co_access.clear();
+        contain_sync = true;
+      }
+      for (const AccessEntry& acc : s.access) {
+        if (acc.threads.size() != 0) {
+          co_access.push_back(acc);
+        }
+      }
+    }
+    bool sync_at_end = force_sync_at_end;
+    if (loop != nullptr && !sync_at_end) {
+      // loop-carried dependency
+      for (size_t i = 0; i < seq.size(); ++i) {
+        const StmtEntry& s = seq[i];
+        for (const AccessEntry& acc : s.access) {
+          if (acc.threads.size() == 0 && find_conflict(acc)) {
+            sync_at_end = true; break;
+          }
+        }
+        if (sync_.count(s.stmt) || sync_at_end) break;
+      }
+    }
+    if (sync_at_end && co_access.size() != 0) {
+      CHECK_NE(seq.size(), 0);
+      contain_sync = true;
+      sync_[seq.back().stmt] = GetSync(co_access);
+      co_access.clear();
+    }
+    if (contain_sync) {
+      AccessEntry e;
+      e.type = kSync;
+      co_access.insert(co_access.begin(), e);
+    }
+    return co_access;
+  }
+  // Add write synchronization.
+  std::vector<Stmt> GetSync(const std::vector<AccessEntry>& co_access) {
+    // Does not consider memory coherence; that needs runtime support.
+    CHECK_NE(co_access.size(), 0U);
+    CHECK_EQ(co_access[0].threads.size(), 1U);
+    return GetSync(coproc_name_ + ".coproc_sync");
+  }
+
+  std::vector<Stmt> GetSync(std::string sync_name) {
+    return {Evaluate::make(Call::make(
+        Int(32),
+        sync_name,
+        {}, Call::Intrinsic))};
+  }
+
+  const std::unordered_set<const Variable*>& touched_;
+  std::string coproc_name_;
+};
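The planner's find_conflict treats the pending co-processor accesses of a buffer as hazards: a CPU read conflicts only with a pending coproc write (read-after-write), while a CPU write conflicts with any pending coproc access (write-after-read or write-after-write). A toy Python model of that predicate (the names are illustrative, not part of the pass):

    # pending: types of coproc accesses to the same buffer seen so far
    def conflicts(cpu_access, pending):
        return any(cpu_access == "write" or p == "write" for p in pending)

    print(conflicts("read",  ["write"]))   # True  -> plan a sync
    print(conflicts("read",  ["read"]))    # False -> reads may overlap
    print(conflicts("write", ["read"]))    # True  -> WAR still syncs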
+
+// Detect memory barriers when the coproc reads or writes memory.
+class CoProcBarrierDetector : public StorageAccessVisitor {
+ public:
+  explicit CoProcBarrierDetector(
+      const std::unordered_set<const Variable*>& touched,
+      const std::string& coproc_name)
+      : touched_(touched) {
+    read_barrier_name_ = coproc_name + ".coproc_read_barrier";
+    write_barrier_name_ = coproc_name + ".coproc_write_barrier";
+  }
+
+  void PlanReadBarrier(Stmt stmt) {
+    read_barrier_ = true;
+    this->Visit(stmt);
+  }
+  void PlanWriteBarrier(Stmt stmt) {
+    read_barrier_ = false;
+    this->Visit(stmt);
+  }
+
+  std::unordered_map<const Node*, std::vector<Stmt> > barrier_before_;
+  std::unordered_map<const Node*, std::vector<Stmt> > barrier_after_;
+
+ protected:
+  bool Enabled(const Variable* buf,
+               const StorageScope& scope) const final {
+    return touched_.count(buf);
+  }
+
+  // Plan the barriers
+  std::vector<AccessEntry> Summarize(
+      std::vector<StmtEntry> seq, const For* loop) final {
+    if (read_barrier_) {
+      return PlanReadBarrier(seq, loop);
+    } else {
+      return PlanWriteBarrier(seq, loop);
+    }
+  }
+
+ private:
+  // Plan write barriers at read-after-write points.
+  std::vector<AccessEntry> PlanWriteBarrier(
+      std::vector<StmtEntry> seq, const For* loop) {
+    std::vector<AccessEntry> read_seq;
+    std::unordered_map<const Variable*, std::vector<AccessEntry> > write_set;
+
+    auto fupdate = [&](size_t i, const AccessEntry& acc) {
+      auto it = write_set.find(acc.buffer.get());
+      if (it != write_set.end()) {
+        CHECK_NE(i, 0U);
+        barrier_after_[seq[i - 1].stmt].push_back(
+            MakeBarrier(write_barrier_name_, it->second));
+        write_set.erase(it);
+      }
+    };
+
+    for (size_t i = 0; i < seq.size(); ++i) {
+      const StmtEntry& s = seq[i];
+      for (const AccessEntry& acc : s.access) {
+        if (acc.threads.size() == 0 && acc.type == kRead) {
+          fupdate(i, acc);
+          read_seq.push_back(acc);
+        }
+      }
+      for (const AccessEntry& acc : s.access) {
+        if (acc.threads.size() != 0 && acc.type == kWrite) {
+          write_set[acc.buffer.get()].push_back(acc);
+        }
+      }
+    }
+    // loop carry
+    if (loop != nullptr) {
+      for (const AccessEntry& acc : read_seq) {
+        fupdate(seq.size(), acc);
+      }
+    }
+    for (const auto &kv : write_set) {
+      read_seq.insert(read_seq.end(), kv.second.begin(), kv.second.end());
+    }
+    return read_seq;
+  }
+
+  std::vector<AccessEntry> PlanReadBarrier(
+      std::vector<StmtEntry> seq, const For* loop) {
+    std::vector<AccessEntry> write_seq;
+    std::unordered_map<const Variable*, std::vector<AccessEntry> > read_set;
+
+    auto fupdate = [&](size_t i, const AccessEntry& acc) {
+      auto it = read_set.find(acc.buffer.get());
+      if (it != read_set.end()) {
+        CHECK_NE(i, seq.size());
+        barrier_before_[seq[i].stmt].push_back(
+            MakeBarrier(read_barrier_name_, it->second));
+        read_set.erase(it);
+      }
+    };
+
+    for (size_t i = seq.size(); i != 0; --i) {
+      const StmtEntry& s = seq[i - 1];
+      for (const AccessEntry& acc : s.access) {
+        if (acc.threads.size() == 0 && acc.type == kWrite) {
+          CHECK_NE(i, seq.size());
+          fupdate(i, acc);
+          write_seq.push_back(acc);
+        }
+      }
+      for (const AccessEntry& acc : s.access) {
+        if (acc.threads.size() != 0 && acc.type == kRead) {
+          read_set[acc.buffer.get()].push_back(acc);
+        }
+      }
+    }
+    // loop carry
+    if (loop != nullptr) {
+      for (const AccessEntry& acc : write_seq) {
+        fupdate(0, acc);
+      }
+    }
+    for (const auto &kv : read_set) {
+      write_seq.insert(write_seq.end(), kv.second.begin(), kv.second.end());
+    }
+    return write_seq;
+  }
+
+  Stmt MakeBarrier(const std::string& func, const std::vector<AccessEntry>& wvec) {
+    // insert write point
+    Array<arith::IntSet> wset;
+    for (const AccessEntry& acc : wvec) {
+      CHECK(acc.dtype == wvec[0].dtype);
+      wset.push_back(acc.touched);
+    }
+    Range none;
+    Range r = arith::Union(wset).cover_range(none);
+    CHECK(r.defined())
+        << "Cannot deduce write range of " << wvec[0].buffer;
+    Expr min = r->min;
+    Expr extent = r->extent;
+    return Evaluate::make(Call::make(
+        Int(32), func,
+        {wvec[0].buffer, wvec[0].dtype.bits(), r->min, r->extent}, Call::Intrinsic));
+  }
+  // Whether we are currently planning read barriers (versus write barriers).
+  bool read_barrier_{false};
+  std::string read_barrier_name_;
+  std::string write_barrier_name_;
+  const std::unordered_set<const Variable*>& touched_;
+};
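MakeBarrier collapses all touched intervals of a buffer into one covering range and emits a single intrinsic call of the form name(buffer, dtype_bits, min, extent). A toy model of that union (the interval values are made up; the real pass derives them from arith::IntSet):

    # (min, extent) intervals touched by coproc accesses of one buffer
    intervals = [(0, 10), (10, 70)]
    lo = min(m for m, _ in intervals)
    hi = max(m + e for m, e in intervals)
    print(("A", 32, lo, hi - lo))   # ('A', 32, 0, 80): one covering barrier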
+
+
+class CoProcSyncInserter : public IRMutator {
+ public:
+  Stmt Insert(Stmt stmt) {
+    CoProcTouchedBuffer visitor;
+    visitor.Visit(stmt);
+    if (visitor.coproc_.size() == 0) return stmt;
+    std::unordered_set<const Variable*> touched;
+
+    for (const auto &kv : visitor.touched_) {
+      if (kv.second.normal && kv.second.coproc) {
+        touched.insert(kv.first);
+      }
+    }
+    CHECK_EQ(visitor.coproc_.size(), 1U);
+    std::string coproc_name = (*visitor.coproc_.begin())->var->name_hint;
+    // Plan the syncs.
+    CoProcSyncPlanner sync_planner(touched, coproc_name);
+    sync_planner.Plan(stmt);
+    for (const auto& kv : sync_planner.sync_) {
+      auto& vec = insert_after_[kv.first];
+      vec.insert(vec.end(), kv.second.begin(), kv.second.end());
+    }
+    // Detect the barriers.
+    CoProcBarrierDetector barrier_detector(touched, coproc_name);
+    barrier_detector.PlanReadBarrier(stmt);
+    barrier_detector.PlanWriteBarrier(stmt);
+    for (const auto& kv : barrier_detector.barrier_before_) {
+      auto& vec = insert_before_[kv.first];
+      vec.insert(vec.end(), kv.second.begin(), kv.second.end());
+    }
+    for (const auto& kv : barrier_detector.barrier_after_) {
+      auto& vec = insert_after_[kv.first];
+      vec.insert(vec.end(), kv.second.begin(), kv.second.end());
+    }
+    return Mutate(stmt);
+  }
+
+  Stmt Mutate(Stmt stmt) final {
+    Stmt before, after;
+    auto it = insert_before_.find(stmt.get());
+    if (it != insert_before_.end()) {
+      before = MergeSeq(it->second);
+    }
+    it = insert_after_.find(stmt.get());
+    if (it != insert_after_.end()) {
+      after = MergeSeq(it->second);
+    }
+    stmt = IRMutator::Mutate(stmt);
+    if (before.defined()) {
+      stmt = Block::make(before, stmt);
+    }
+    if (after.defined()) {
+      stmt = Block::make(stmt, after);
+    }
+    return stmt;
+  }
+
+ private:
+  std::unordered_map<const Node*, std::vector<Stmt> > insert_before_;
+  std::unordered_map<const Node*, std::vector<Stmt> > insert_after_;
+};
+
+Stmt CoProcSync(Stmt stmt) {
+  return CoProcSyncInserter().Insert(stmt);
+}
+
+}  // namespace ir
+}  // namespace tvm
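Note that lower() no longer calls CoProcSync (see the build_module.py hunk above); presumably device-specific pipelines are expected to invoke it directly, as the updated unit test does:

    stmt = tvm.ir_pass.CoProcSync(stmt)   # inserts <coproc>.coproc_sync plus
                                          # <coproc>.coproc_read/write_barrier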
diff --git a/src/pass/lower_tvm_builtin.cc b/src/pass/lower_tvm_builtin.cc
index c3650cda52bf..a7f78f85f6ce 100644
--- a/src/pass/lower_tvm_builtin.cc
+++ b/src/pass/lower_tvm_builtin.cc
@@ -6,16 +6,13 @@
 #include <tvm/ir.h>
 #include <tvm/ir_mutator.h>
 #include <tvm/ir_pass.h>
-#include <tvm/target_info.h>
 #include <unordered_set>
 #include "./ir_util.h"
 #include "../arithmetic/compute_expr.h"
-#include "../runtime/thread_storage_scope.h"
+
 namespace tvm {
 namespace ir {
 
-using runtime::StorageScope;
-
 inline Expr ConstInt32(size_t index) {
   CHECK_LE(index, std::numeric_limits<int>::max());
   return make_const(Int(32), static_cast<int>(index));
@@ -69,14 +66,7 @@ class BuiltinLower : public IRMutator {
     // Lower allocate to device allocate when needed.
     Stmt stmt = IRMutator::Mutate_(op, s);
     op = stmt.as<Allocate>();
-    // For special memory, remove allocate.
-    auto it = storage_info_.find(op->buffer_var.get());
-    if (it != storage_info_.end() && it->second.scope.tag.length() != 0) {
-      ++it->second.alloc_count;
-      CHECK_LE(it->second.alloc_count, 1)
-          << "Double allocation of " << it->second.scope.to_string();
-      return op->body;
-    }
+    if (op->new_expr.defined()) return stmt;
     // Get constant allocation bound.
     int64_t dev_type;
     int64_t nbytes = GetVectorBytes(op->type);
@@ -139,25 +129,12 @@ class BuiltinLower : public IRMutator {
       CHECK(!device_type_.defined());
       device_type_ = op->value;
       return Mutate(op->body);
-    } else if (op->attr_key == attr::storage_scope) {
-      const Variable* buf = op->node.as<Variable>();
-      StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
-      StorageEntry e;
-      e.scope = scope;
-      if (scope.tag.length() != 0) {
-        e.info = GetMemoryInfo(op->value.as<StringImm>()->value);
-        CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
-      }
-      storage_info_[buf] = e;
-      return IRMutator::Mutate_(op, s);
     } else {
       return IRMutator::Mutate_(op, s);
     }
   }
 
   Expr Mutate_(const Call* op, const Expr &e) final {
-    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
-      return MakeAccessPtr(op, e);
-    } else if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
+    if (op->is_intrinsic(intrinsic::tvm_call_packed)) {
       return MakeCallPacked(op, e);
     } else if (op->is_intrinsic(intrinsic::tvm_stack_make_shape)) {
       return MakeShape(op, e);
@@ -167,14 +144,6 @@ class BuiltinLower : public IRMutator {
       return IRMutator::Mutate_(op, e);
     }
   }
-
-  Expr Convert(Type t, Expr e) {
-    if (e.type() != t) {
-      return Cast::make(t, e);
-    } else {
-      return e;
-    }
-  }
   // call shape
   Expr MakeShape(const Call* op, const Expr& e) {
     size_t stack_begin = run_shape_stack_;
@@ -183,7 +152,7 @@ class BuiltinLower : public IRMutator {
     op = expr.as<Call>();
     for (size_t i = 0; i < op->args.size(); ++i) {
       prep_seq_.emplace_back(
-          Store::make(stack_shape_, Convert(Int(64), op->args[i]),
+          Store::make(stack_shape_, cast(Int(64), op->args[i]),
                       ConstInt32(stack_begin + i), const_true(1)));
     }
     return AddressOffset(stack_shape_, Int(64), stack_begin);
@@ -224,15 +193,15 @@ class BuiltinLower : public IRMutator {
     }
     prep_seq_.emplace_back(
         TVMStructSet(stack_array_, idx, intrinsic::kArrByteOffset,
-                     Convert(UInt(64), byte_offset)));
+                     cast(UInt(64), byte_offset)));
     CHECK(device_type_.defined()) << "Unknown device type in current IR";
     CHECK(device_id_.defined()) << "Unknown device id in current IR";
     prep_seq_.emplace_back(
         TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceId,
-                     Convert(Int(32), device_id_)));
+                     cast(Int(32), device_id_)));
     prep_seq_.emplace_back(
         TVMStructSet(stack_array_, idx, intrinsic::kArrDeviceType,
-                     Convert(Int(32), device_type_)));
+                     cast(Int(32), device_type_)));
     return TVMStructGet(Handle(), stack_array_, idx, intrinsic::kArrAddr);
   }
   // call packed.
@@ -280,33 +249,6 @@ class BuiltinLower : public IRMutator {
         Int(32), intrinsic::tvm_call_packed_lowered,
         packed_args, Call::Intrinsic);
   }
-  // tvm_access_ptr
-  Expr MakeAccessPtr(const Call* op, const Expr& e) {
-    // Specially handle the buffer packed intrinsic
-    Expr expr = IRMutator::Mutate_(op, e);
-    op = expr.as<Call>();
-    CHECK_EQ(op->args.size(), 5U);
-    Type dtype = op->args[0].type();
-    const Variable* buffer = op->args[1].as<Variable>();
-    Expr offset = op->args[2];
-    auto it = storage_info_.find(buffer);
-    if (it != storage_info_.end() && it->second.scope.tag.length() != 0) {
-      return MakeTaggedAccessPtr(
-          op->type, dtype, offset,
-          it->second.info.defined() ?
-          it->second.info->unit_bits : 8);
-    }
-    CHECK(op->type.is_handle());
-    // Change to address_of
-    return AddressOffset(Var(op->args[1].node_), dtype, offset);
-  }
-
-  Expr MakeTaggedAccessPtr(Type ptr_type, Type dtype,
-                           Expr offset, int unit_bits) {
-    int dtype_bits = dtype.bits() * dtype.lanes();
-    CHECK_EQ(unit_bits % dtype_bits, 0);
-    return Convert(ptr_type,
-                   ir::Simplify(offset / make_const(offset.type(), unit_bits / dtype_bits)));
-  }
 
  private:
   bool IsArrayHandle(const Expr& arg) {
@@ -337,17 +279,6 @@ class BuiltinLower : public IRMutator {
   uint64_t max_shape_stack_{0};
   uint64_t max_array_stack_{0};
   uint64_t max_arg_stack_{0};
-  // The storage entry.
-  struct StorageEntry {
-    // Whether it is tagged memory.
-    StorageScope scope;
-    // The memory info if any.
-    MemoryInfo info;
-    // Allocation counter
-    int alloc_count{0};
-  };
-  // The storage scope of each buffer
-  std::unordered_map<const Variable*, StorageEntry> storage_info_;
 };
 
 LoweredFunc LowerTVMBuiltin(LoweredFunc f) {
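With MakeAccessPtr gone from this file, offset scaling for tagged memory now happens in LowerStorageAccessInfo (the storage_access.cc change that follows): an element offset is divided by the number of elements that fit in one addressing unit. The arithmetic, with assumed numbers:

    unit_bits, dtype_bits = 128, 32        # assumed tag config, float32 elements
    elems_per_unit = unit_bits // dtype_bits
    offset = 16                            # element offset in tvm_access_ptr
    print(offset // elems_per_unit)        # 4: the lowered unit offset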
diff --git a/src/pass/storage_access.cc b/src/pass/storage_access.cc
index 80a768b174bc..399d92133f74 100644
--- a/src/pass/storage_access.cc
+++ b/src/pass/storage_access.cc
@@ -2,7 +2,12 @@
  * Copyright (c) 2017 by Contributors
  * \file storage_access.cc
  */
+#include <tvm/ir.h>
+#include <tvm/ir_mutator.h>
+#include <tvm/target_info.h>
+#include "./ir_util.h"
 #include "./storage_access.h"
+#include "../arithmetic/compute_expr.h"
 
 namespace tvm {
 namespace ir {
@@ -191,5 +196,110 @@ StorageScope StorageAccessVisitor::GetScope(const Variable* buf) const {
   if (it == storage_scope_.end()) return s;
   return it->second;
 }
+
+class StorageAccessInfoLower : public IRMutator {
+ public:
+  Stmt Mutate_(const Allocate* op, const Stmt& s) final {
+    // Lower allocate to device allocate when needed.
+    Stmt stmt = IRMutator::Mutate_(op, s);
+    op = stmt.as<Allocate>();
+    // For special memory, remove the allocate or point it at the head address.
+    auto it = storage_info_.find(op->buffer_var.get());
+    if (it != storage_info_.end() && it->second.info.defined()) {
+      const MemoryInfo& info = it->second.info;
+      ++it->second.alloc_count;
+      CHECK_LE(it->second.alloc_count, 1)
+          << "Double allocation of " << it->second.scope.to_string();
+      if (info->head_address.defined()) {
+        return Allocate::make(
+            op->buffer_var, op->type, op->extents, op->condition,
+            op->body, info->head_address, "nop");
+      }
+      return op->body;
+    } else {
+      return stmt;
+    }
+  }
+  Stmt Mutate_(const AttrStmt* op, const Stmt& s) final {
+    if (op->attr_key == attr::storage_scope) {
+      const Variable* buf = op->node.as<Variable>();
+      StorageScope scope = StorageScope::make(op->value.as<StringImm>()->value);
+      StorageEntry e;
+      e.scope = scope;
+      if (scope.tag.length() != 0) {
+        e.info = GetMemoryInfo(op->value.as<StringImm>()->value);
+        CHECK(e.info.defined()) << "Cannot find memory info of " << scope.to_string();
+      }
+      storage_info_[buf] = e;
+      return IRMutator::Mutate_(op, s);
+    } else {
+      return IRMutator::Mutate_(op, s);
+    }
+  }
+
+  Expr Mutate_(const Call* op, const Expr &e) final {
+    if (op->is_intrinsic(intrinsic::tvm_access_ptr)) {
+      return MakeAccessPtr(op, e);
+    } else {
+      return IRMutator::Mutate_(op, e);
+    }
+  }
+
+ private:
+  // tvm_access_ptr
+  Expr MakeAccessPtr(const Call* op, const Expr& e) {
+    // Specially handle the buffer packed intrinsic
+    Expr expr = IRMutator::Mutate_(op, e);
+    op = expr.as<Call>();
+    CHECK_EQ(op->args.size(), 5U);
+    Type dtype = op->args[0].type();
+    const Variable* buffer = op->args[1].as<Variable>();
+    Var buffer_var(op->args[1].node_);
+    Expr offset = op->args[2];
+    auto it = storage_info_.find(buffer);
+    if (it != storage_info_.end() && it->second.info.defined()) {
+      return MakeTaggedAccessPtr(
+          op->type, buffer_var, dtype, offset,
+          it->second.info);
+    }
+    CHECK(op->type.is_handle());
+    // Change to address_of
+    return AddressOffset(buffer_var, dtype, offset);
+  }
+
+  Expr MakeTaggedAccessPtr(Type ptr_type,
+                           Var buffer_var,
+                           Type dtype,
+                           Expr offset,
+                           const MemoryInfo& info) {
+    if (ptr_type.is_handle()) {
+      CHECK(info->head_address.defined())
+          << buffer_var << " is not addressable.";
+      return AddressOffset(buffer_var, dtype, offset);
+    }
+    int dtype_bits = dtype.bits() * dtype.lanes();
+    CHECK_EQ(info->unit_bits % dtype_bits, 0);
+    return cast(ptr_type,
+                ir::Simplify(offset / make_const(
+                    offset.type(), info->unit_bits / dtype_bits)));
+  }
+  // The storage entry.
+  struct StorageEntry {
+    // Whether it is tagged memory.
+    StorageScope scope;
+    // The memory info if any.
+    MemoryInfo info;
+    // Allocation counter
+    int alloc_count{0};
+  };
+  // The storage scope of each buffer
+  std::unordered_map<const Variable*, StorageEntry> storage_info_;
+};
+
+Stmt LowerStorageAccessInfo(Stmt stmt) {
+  return StorageAccessInfoLower().Mutate(stmt);
+}
+
 }  // namespace ir
 }  // namespace tvm
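A sketch of the new pass in isolation, assuming the "global.cache" MemoryInfo registration shown in the test at the end of this patch has already run (buffer and loop are hypothetical): the tagged allocate is rebuilt with new_expr set to the head address and a "nop" free function, so codegen performs no runtime allocation; without a head_address, the allocate is dropped entirely.

    import tvm

    ib = tvm.ir_builder.create()
    A = ib.allocate("float32", 128, name="A", scope="global.cache")
    with ib.for_range(0, 128, name="i") as i:
        A[i] = 0.0
    stmt = tvm.ir_pass.LowerStorageAccessInfo(ib.get())
    print(stmt)   # the Allocate now carries the registered head address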
diff --git a/src/pass/storage_flatten.cc b/src/pass/storage_flatten.cc
index 0776388cbb09..543286efc1fa 100644
--- a/src/pass/storage_flatten.cc
+++ b/src/pass/storage_flatten.cc
@@ -86,7 +86,6 @@ class StorageFlattener : public IRMutator {
       return this->Mutate(op->body);
     } else {
       // create a buffer entry
-      // TODO(tqchen) allow permutation and inference of index dimension.
       BufferEntry e;
       e.bounds = op->bounds;
       Array<Expr> shape;
diff --git a/src/pass/storage_sync.cc b/src/pass/storage_sync.cc
index aac60b993554..13773321edac 100644
--- a/src/pass/storage_sync.cc
+++ b/src/pass/storage_sync.cc
@@ -153,7 +153,6 @@ class ThreadSyncInserter : public IRMutator {
 
   Stmt Mutate(Stmt stmt) final {
     if (syncs_.size() == 0) return stmt;
-    stmt = IRMutator::Mutate(stmt);
     if (syncs_.count(stmt.get())) {
       Stmt barrier;
       if (sync_scope_.rank == 0) {
@@ -164,7 +163,11 @@ class ThreadSyncInserter : public IRMutator {
             {StringImm::make(sync_scope_.to_string())},
             Call::Intrinsic));
       }
+      // Mutate after the query, so the lookup key is the original stmt.
+      stmt = IRMutator::Mutate(stmt);
       stmt = Block::make(barrier, stmt);
+    } else {
+      stmt = IRMutator::Mutate(stmt);
     }
     return stmt;
   }
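The reordering in ThreadSyncInserter fixes a lookup-after-mutation bug: IRMutator::Mutate can return a fresh node, so syncs_.count(stmt.get()) must run on the original pointer. A toy illustration of the failure mode:

    class Node:
        pass

    n = Node()
    syncs = {n}                  # keyed by node identity, like syncs_
    mutated = Node()             # Mutate() may hand back a new node
    print(mutated in syncs)      # False: querying after mutation misses
    print(n in syncs)            # True:  query first, as the patch now does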
@@ -296,201 +299,5 @@ LoweredFunc ThreadSync(LoweredFunc f, std::string storage_scope) {
   return LoweredFunc(n);
 }
 
-// Visitor to find touched set by co-processor scope.
-class CoProcTouchedBuffer : public IRVisitor {
- public:
-  void Visit_(const Load* op) final {
-    if (in_scope_) {
-      touched_.insert(op->buffer_var.get());
-    }
-    IRVisitor::Visit_(op);
-  }
-  void Visit_(const Store* op) final {
-    if (in_scope_) {
-      touched_.insert(op->buffer_var.get());
-    }
-    IRVisitor::Visit_(op);
-  }
-  void Visit_(const Call* op) final {
-    if (op->is_intrinsic(intrinsic::tvm_access_ptr) && in_scope_) {
-      const Variable* buffer = op->args[1].as<Variable>();
-      touched_.insert(buffer);
-    }
-    IRVisitor::Visit_(op);
-  }
-  void Visit_(const AttrStmt* op) final {
-    if (op->attr_key == attr::coproc_scope && !in_scope_) {
-      in_scope_ = true;
-      IterVar iv(op->node.node_);
-      coproc_.insert(iv);
-      IRVisitor::Visit_(op);
-      in_scope_ = false;
-    } else {
-      IRVisitor::Visit_(op);
-    }
-  }
-
-  std::unordered_set<const Variable*> touched_;
-  std::unordered_set<IterVar> coproc_;
-
- private:
-  bool in_scope_{false};
-};
-
-// Synchronization planning with co-processor.
-class CoProcSyncPlanner : public StorageAccessVisitor {
- public:
-  void Plan(const Stmt& stmt) {
-    CoProcTouchedBuffer visitor;
-    visitor.Visit(stmt);
-    touched_ = std::move(visitor.touched_);
-    if (!touched_.empty()) {
-      this->Visit(stmt);
-      PlanWriteSync(scope_.back(), nullptr, true);
-      CHECK_EQ(visitor.coproc_.size(), 1U);
-      if (write_sync_.size() == 0) {
-        write_sync_[stmt.get()] = GetWriteSync(
-            (*visitor.coproc_.begin())->var->name_hint + ".coproc_sync");
-      }
-    }
-  }
-
-  // Write synchronization to be inserted before or after stmt.
-  std::unordered_map<const Node*, std::vector<Stmt> > write_sync_;
-
- protected:
-  bool Enabled(const Variable* buf,
-               const StorageScope& scope) const final {
-    return touched_.count(buf) && scope == global_scope_;
-  }
-
-  // Plan the sync
-  std::vector<AccessEntry> Summarize(
-      std::vector<StmtEntry> seq, const For* loop) final {
-    return PlanWriteSync(seq, loop, false);
-  }
-
- private:
-  // Plan write synchronization if write is not coherent
-  std::vector<AccessEntry> PlanWriteSync(
-      std::vector<StmtEntry> seq, const For* loop,
-      bool force_sync_at_end) {
-    // detect write barriers
-    // access by the co-processor.
-    std::vector<AccessEntry> co_access;
-    bool contain_sync = false;
-
-    auto find_conflict = [&](const AccessEntry& acc) {
-      for (const AccessEntry& x : co_access) {
-        if (x.buffer.same_as(acc.buffer) &&
-            ((acc.type == kRead && x.type == kWrite) ||
-             acc.type == kWrite)) {
-          return true;
-        }
-      }
-      return false;
-    };
-    for (size_t i = 0; i < seq.size(); ++i) {
-      const StmtEntry& s = seq[i];
-      bool sync_write = false;
-      for (const AccessEntry& acc : s.access) {
-        if (acc.threads.size() == 0 && find_conflict(acc)) {
-          sync_write = true; break;
-        }
-        if (acc.type == kSync) {
-          co_access.clear();
-          contain_sync = true;
-        }
-      }
-      if (sync_write) {
-        CHECK_NE(i, 0U);
-        write_sync_[seq[i - 1].stmt] = GetWriteSync(co_access);
-        co_access.clear();
-        contain_sync = true;
-      }
-      for (const AccessEntry& acc : s.access) {
-        if (acc.threads.size() != 0) {
-          co_access.push_back(acc);
-        }
-      }
-    }
-    bool sync_at_end = force_sync_at_end;
-    if (loop != nullptr && !sync_at_end) {
-      // loop carray dependency
-      for (size_t i = 0; i < seq.size(); ++i) {
-        const StmtEntry& s = seq[i];
-        for (const AccessEntry& acc : s.access) {
-          if (acc.threads.size() == 0 && find_conflict(acc)) {
-            sync_at_end = true; break;
-          }
-        }
-        if (write_sync_.count(s.stmt) || sync_at_end) break;
-      }
-    }
-    if (sync_at_end && co_access.size() != 0) {
-      CHECK_NE(seq.size(), 0);
-      contain_sync = true;
-      write_sync_[seq.back().stmt] = GetWriteSync(co_access);
-      co_access.clear();
-    }
-    if (contain_sync) {
-      AccessEntry e;
-      e.type = kSync;
-      e.scope = global_scope_;
-      co_access.insert(co_access.begin(), e);
-    }
-    return co_access;
-  }
-  // Add write Synchronization
-  std::vector<Stmt> GetWriteSync(const std::vector<AccessEntry>& co_access) {
-    // Does not consider memory coherence, need runtime.
-    CHECK_NE(co_access.size(), 0U);
-    CHECK_EQ(co_access[0].threads.size(), 1U);
-    return GetWriteSync(co_access[0].threads[0]->var->name_hint + ".coproc_sync");
-  }
-
-  std::vector<Stmt> GetWriteSync(std::string sync_name) {
-    std::vector<Stmt> stmts;
-    stmts.emplace_back(
-        Evaluate::make(Call::make(
-            Int(32),
-            sync_name,
-            {}, Call::Intrinsic)));
-    return stmts;
-  }
-
-  std::unordered_set<const Variable*> touched_;
-  StorageScope global_scope_ = StorageScope::make("global");
-};
-
-class CoProcSyncInserter : public IRMutator {
- public:
-  explicit CoProcSyncInserter(
-      const std::unordered_map<const Node*, std::vector<Stmt> >& write_sync)
-      : write_sync_(write_sync) {}
-
-  Stmt Mutate(Stmt stmt) final {
-    stmt = IRMutator::Mutate(stmt);
-    auto it = write_sync_.find(stmt.get());
-    if (it != write_sync_.end()) {
-      stmt = Block::make(stmt, MergeSeq(it->second));
-    }
-    return stmt;
-  }
-
- private:
-  const std::unordered_map<const Node*, std::vector<Stmt> >& write_sync_;
-};
-
-Stmt CoProcSync(Stmt stmt) {
-  CoProcSyncPlanner planner;
-  planner.Plan(stmt);
-  if (planner.write_sync_.size() != 0) {
-    return CoProcSyncInserter(planner.write_sync_).Mutate(stmt);
-  } else {
-    return stmt;
-  }
-}
-
 }  // namespace ir
 }  // namespace tvm
diff --git a/tests/python/unittest/test_pass_storage_sync.py b/tests/python/unittest/test_pass_storage_sync.py
index 417419fe15b0..8734da696a4a 100644
--- a/tests/python/unittest/test_pass_storage_sync.py
+++ b/tests/python/unittest/test_pass_storage_sync.py
@@ -32,16 +32,31 @@ def test_coproc_sync():
     ib = tvm.ir_builder.create()
     n = tvm.var("n")
     cp = tvm.thread_axis((0, 1), "cop")
-    A = ib.allocate("float32", n, name="A", scope="global")
+
+    @tvm.register_func("tvm.info.mem.global.cache")
+    def meminfo_cache():
+        return tvm.make.node(
+            "MemoryInfo",
+            unit_bits=8,
+            max_simd_bits=32,
+            max_num_bits=128,
+            head_address=tvm.call_extern("handle", "global_cache"))
+    A = ib.allocate("float32", 128, name="A", scope="global.cache")
     with ib.for_range(0, n, name="i") as i:
         A[i] = A[i] + 1
-        with ib.for_range(0, 10, name="j") as j:
-            ib.scope_attr(cp, "coproc_scope", 1)
-            A[j] = A[j] + 2
-    body = ib.get()
-    body = tvm.ir_pass.CoProcSync(body)
-    body = body.body.body.body
-    assert(tvm.make.stmt_list(body)[-1].value.name == "cop.coproc_sync")
+        with ib.for_range(0, 8, name="k") as k:
+            with ib.for_range(0, 10, name="j") as j:
+                ib.scope_attr(cp, "coproc_scope", 1)
+                A[j] = A[j + k * 10] + 2
+    stmt = ib.get()
+    stmt = tvm.ir_pass.CoProcSync(stmt)
+    body = stmt.body.body.body
+    blist = tvm.make.stmt_list(body)
+    assert(blist[1].value.name == "cop.coproc_read_barrier")
+    assert(blist[1].value.args[3].value == 80)
+    assert(blist[-2].value.name == "cop.coproc_sync")
+    assert(blist[-1].value.name == "cop.coproc_write_barrier")
+    assert(blist[-1].value.args[3].value == 10)
 
 
 if __name__ == "__main__":
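The barrier extents asserted in the test follow directly from its loop bounds: the co-processor reads A[j + k * 10] for j in [0, 10) and k in [0, 8), and writes A[j] for j in [0, 10):

    reads = [j + k * 10 for k in range(8) for j in range(10)]
    print(max(reads) + 1)        # 80: read barrier extent (args[3])
    print(max(range(10)) + 1)    # 10: write barrier extent (args[3])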