Skip to content

Commit

Permalink
[TIR] Tighten up invariance of CopyOnWrite in recursive stmt visitor (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored and ylc committed Sep 29, 2021
1 parent d51ea60 commit 5f1ae26
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,12 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
*/
template <typename TNode>
ObjectPtr<TNode> CopyOnWrite(const TNode* node) {
static_assert(std::is_base_of<StmtNode, TNode>::value,
"StmtMutator:: CopyOnWrite requires us to track uniqueness of all parent "
"nodes during the recursion. Because the child classes do not necessarily "
"check the Array, Expr and other structures during the visit, it is only safe to "
"call this function with StmtNodes for now. "
"Please create a new node directly in other cases.");
if (allow_copy_on_write_) {
// return the old node.
return runtime::GetObjectPtr<TNode>(const_cast<TNode*>(node));
Expand Down

0 comments on commit 5f1ae26

Please sign in to comment.