From 838de6849851694b3b1a9ef656ae205132c43f6e Mon Sep 17 00:00:00 2001
From: tqchen
Date: Sun, 17 Jan 2021 11:44:27 -0500
Subject: [PATCH] [TIR][REFACTOR] ForNode update

- Remove deprecated device_api.
- Add ThreadBinding for_type.
- Add additional annotations.
---
 include/tvm/tir/stmt.h                        | 61 ++++++++++++++-----
 python/tvm/script/scope_handler.py            |  8 +--
 python/tvm/te/hybrid/parser.py                |  2 +-
 python/tvm/tir/ir_builder.py                  |  2 +-
 python/tvm/tir/stmt.py                        | 35 +++++++++--
 src/auto_scheduler/feature.cc                 |  2 +-
 src/autotvm/feature_visitor.cc                |  4 ++
 src/printer/tir_text_printer.cc               |  3 +
 src/printer/tvmscript_printer.cc              |  4 ++
 src/target/llvm/codegen_cpu.cc                |  5 +-
 src/te/operation/hybrid_op.cc                 | 13 ++--
 src/te/operation/op_utils.cc                  |  4 +-
 ...hedule_postproc_rewrite_for_tensor_core.cc |  3 +-
 src/tir/ir/stmt.cc                            | 22 ++++---
 src/tir/transforms/inject_double_buffer.cc    |  3 +-
 src/tir/transforms/inject_virtual_thread.cc   |  7 ++-
 src/tir/transforms/ir_utils.cc                |  3 +-
 src/tir/transforms/loop_partition.cc          |  3 +-
 src/tir/transforms/narrow_datatype.cc         |  2 +-
 src/tir/transforms/storage_flatten.cc         |  4 +-
 src/tir/transforms/storage_rewrite.cc         |  4 +-
 src/tir/transforms/unroll_loop.cc             |  4 +-
 src/tir/transforms/vectorize_loop.cc          |  7 ++-
 .../unittest/test_arith_domain_touched.py     |  6 +-
 .../test_runtime_module_based_interface.py    |  3 +-
 .../unittest/test_runtime_module_load.py      |  6 +-
 tests/python/unittest/test_tir_constructor.py |  2 +-
 tests/python/unittest/test_tir_nodes.py       |  2 +-
 .../test_tir_transform_loop_partition.py      |  1 +
 .../test_tir_transform_remove_no_op.py        | 11 ++--
 tutorials/dev/low_level_custom_pass.py        |  4 +-
 vta/python/vta/transform.py                   | 12 +++-
 32 files changed, 167 insertions(+), 85 deletions(-)

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 2b7f1e67bda5..064f3283d275 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -752,16 +752,32 @@ class Evaluate : public Stmt {
   TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
 };
 
-/*! \brief Additional annotation of for loop. */
+/*!
+ * \brief The type of the loop.
+ *
+ *  ForType can change the control flow semantics
+ *  of the loop. So the for_type field needs to be considered
+ *  in all TIR passes.
+ */
 enum class ForType : int {
-  /*! \brief serial execution. */
+  /*! \brief default semantics -- serial execution. */
   Serial = 0,
-  /*! \brief parallel execution on CPU. */
+  /*! \brief Parallel execution on CPU. */
   Parallel = 1,
-  /*! \brief Vector SIMD loop annotaion. */
+  /*!
+   * \brief Vector SIMD loop.
+   *  The loop body will be vectorized.
+   */
   Vectorized = 2,
-  /*! \brief Unroll annotation. */
-  Unrolled = 3
+  /*! \brief The loop body must be unrolled. */
+  Unrolled = 3,
+  /*!
+   * \brief The loop variable is bound to a thread in
+   *  an environment. In the final stage of lowering,
+   *  the loop is simply removed and the loop variable is
+   *  mapped to the corresponding context thread.
+   */
+  ThreadBinding = 4
 };
 
 // Kevice api of for loop
@@ -789,28 +805,39 @@ class ForNode : public StmtNode {
   PrimExpr extent;
   /*! \brief The type of the for loop. */
   ForType for_type;
-  /*!
-   * \brief Deprecated, reserved for backward compatibility.
-   *  Consider refactor and remove later.
-   */
-  DeviceAPI device_api;
   /*! \brief The body of the for loop. */
   Stmt body;
+  /*!
+   * \brief Only valid when for_type == ForType::ThreadBinding
+   *  The context thread that this loop variable binds to.
+   */
+  Optional<IterVar> thread_binding;
+  /*!
+   * \brief Additional annotations about the loop.
+   *
+   *  These annotations can be used as auxiliary hints
+   *  to future transformations. An annotation should
+   *  not change the control flow semantics of the loop
+   *  and can be ignored in most passes.
+   */
+  Map<String, ObjectRef> annotations;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("loop_var", &loop_var);
     v->Visit("min", &min);
     v->Visit("extent", &extent);
     v->Visit("for_type", &for_type);
-    v->Visit("device_api", &device_api);
     v->Visit("body", &body);
+    v->Visit("thread_binding", &thread_binding);
+    v->Visit("annotations", &annotations);
     v->Visit("span", &span);
   }
 
   bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
     return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
            equal(extent, other->extent) && equal(for_type, other->for_type) &&
-           equal(device_api, other->device_api) && equal(body, other->body);
+           equal(body, other->body) && equal(thread_binding, other->thread_binding) &&
+           equal(annotations, other->annotations);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
@@ -818,8 +845,9 @@ class ForNode : public StmtNode {
     hash_reduce(min);
     hash_reduce(extent);
     hash_reduce(for_type);
-    hash_reduce(device_api);
     hash_reduce(body);
+    hash_reduce(thread_binding);
+    hash_reduce(annotations);
   }
 
   static constexpr const char* _type_key = "tir.For";
@@ -832,8 +860,9 @@
 */
 class For : public Stmt {
  public:
-  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api,
-              Stmt body, Span span = Span());
+  TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, Stmt body,
+              Optional<IterVar> thread_binding = NullOpt,
+              Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
 };
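For orientation, this is how the reshaped ForNode surfaces on the Python side once the patch is applied. A minimal sketch, not part of the patch itself; the annotation key "pragma_hint" is made up, since no pass consumes any particular key yet:

import tvm
from tvm import tir

i = tir.Var("i", "int32")
body = tir.Evaluate(0)  # no-op loop body

# device_api is gone; thread_binding and annotations are optional trailing args.
loop = tir.For(i, 0, 16, tir.For.Serial, body,
               annotations={"pragma_hint": 1})
assert loop.thread_binding is None
assert loop.annotations["pragma_hint"].value == 1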
diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py
index 21ed7f6e4682..9449cbdc156c 100644
--- a/python/tvm/script/scope_handler.py
+++ b/python/tvm/script/scope_handler.py
@@ -226,7 +226,7 @@ def serial(begin, end, span):
                 self.context.report_error("Expect exact 1 loop var", span)
             ana = tvm.arith.Analyzer()
             extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 0, 0, self.body, span=span)
+            return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body, span=span)
 
         super().__init__(serial)
 
@@ -241,7 +241,7 @@ def parallel(begin, end, span):
                 self.context.report_error("Expect exact 1 loop var")
             ana = tvm.arith.Analyzer()
             extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 1, 0, self.body, span=span)
+            return tvm.tir.For(self.loop_vars[0], begin, extent, 1, self.body, span=span)
 
         super().__init__(parallel)
 
@@ -256,7 +256,7 @@ def vectorized(begin, end, span):
                 self.context.report_error("Expect exact 1 loop var")
             ana = tvm.arith.Analyzer()
             extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 2, 0, self.body, span=span)
+            return tvm.tir.For(self.loop_vars[0], begin, extent, 2, self.body, span=span)
 
         super().__init__(vectorized)
 
@@ -271,6 +271,6 @@ def unroll(begin, end, span):
                 self.context.report_error("Expect exact 1 loop var")
             ana = tvm.arith.Analyzer()
             extent = end if begin == 0 else ana.simplify(end - begin)
-            return tvm.tir.For(self.loop_vars[0], begin, extent, 3, 0, self.body, span=span)
+            return tvm.tir.For(self.loop_vars[0], begin, extent, 3, self.body, span=span)
 
         super().__init__(unroll)
diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py
index d47b2ee879fc..8f87283356e7 100644
--- a/python/tvm/te/hybrid/parser.py
+++ b/python/tvm/te/hybrid/parser.py
@@ -532,7 +532,7 @@ def visit_For(self, node):
         _internal_assert(
             not isinstance(for_type, tuple), "Micro expansion should be handled before!"
         )
-        res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext, for_type, 0, _body)
+        res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext, for_type, _body)
         self.symbols.pop(_name)
         return res
 
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 6dcc8580a221..d89a15b839fb 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -259,7 +259,7 @@ def _exit_cb():
                 for_type_id = 3
             else:
                 raise ValueError("Unknown for_type")
-            self.emit(_stmt.For(loop_var, begin, extent, for_type_id, 0, self._pop_seq()))
+            self.emit(_stmt.For(loop_var, begin, extent, for_type_id, self._pop_seq()))
 
         return WithScope(loop_var, _exit_cb)
 
diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py
index 6857b68c261d..282da3b0deb2 100644
--- a/python/tvm/tir/stmt.py
+++ b/python/tvm/tir/stmt.py
@@ -100,12 +100,16 @@ class For(Stmt):
     for_type : int
         The for type.
 
-    device_api : int
-        The device api type.
-
     body : Stmt
         The body statement.
 
+    thread_binding: Optional[tir.IterVar]
+        The thread this loop binds to. Only valid
+        if for_type is ThreadBinding.
+
+    annotations: tvm.ir.Map
+        Additional annotation hints.
+
     span : Optional[Span]
         The location of this itervar in the source code.
     """
@@ -114,10 +118,29 @@ class For(Stmt):
     Parallel = 1
     Vectorized = 2
     Unrolled = 3
-
-    def __init__(self, loop_var, min_val, extent, for_type, device_api, body, span=None):
+    ThreadBinding = 4
+
+    def __init__(
+        self,
+        loop_var,
+        min_val,
+        extent,
+        for_type,
+        body,
+        thread_binding=None,
+        annotations=None,
+        span=None,
+    ):
         self.__init_handle_by_constructor__(
-            _ffi_api.For, loop_var, min_val, extent, for_type, device_api, body, span
+            _ffi_api.For,
+            loop_var,
+            min_val,
+            extent,
+            for_type,
+            body,
+            thread_binding,
+            annotations,
+            span,
         )
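Every Python call site in this patch follows the same migration: the fifth positional argument (device_api) disappears and body moves up one slot. A before/after sketch with a hypothetical no-op loop, assuming this patch:

import tvm
from tvm import tir

i = tir.Var("i", "int32")
body = tir.Evaluate(0)

# before: tir.For(i, 0, 8, tir.For.Serial, 0, body)  # fifth arg was device_api
loop = tir.For(i, 0, 8, tir.For.Serial, body)         # after
assert len(loop.annotations) == 0  # the FFI defaults annotations to an empty Map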
diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc
index 1b10cd5f2601..460548d2bb32 100755
--- a/src/auto_scheduler/feature.cc
+++ b/src/auto_scheduler/feature.cc
@@ -618,7 +618,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
       is_gpu_ = true;
 
       // make a fake for node for blockIdx.x or threadIdx.x
-      Stmt fake_for_node = For(var, 0, extent, ForType::Parallel, DeviceAPI::None, node->body);
+      Stmt fake_for_node = For(var, 0, extent, ForType::Parallel, node->body);
 
       outer_loop_prod_ *= extent;
       for_loop_stack_.push_back(fake_for_node.as<ForNode>());
diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc
index 15e09755cee2..5253ea6ef0f4 100644
--- a/src/autotvm/feature_visitor.cc
+++ b/src/autotvm/feature_visitor.cc
@@ -47,6 +47,10 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) {
     case ForType::Serial:
       ann = kSerial;
       break;
+    case ForType::ThreadBinding:
+      LOG(FATAL) << "Loop ThreadBinding is reserved for future use and "
+                 << "not yet supported in TIR";
+      break;
   }
 
   if (EnterItervar_(op->loop_var, loop_extent, ann)) {
diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc
index 107817db29b3..76e1d0ab5a96 100644
--- a/src/printer/tir_text_printer.cc
+++ b/src/printer/tir_text_printer.cc
@@ -475,6 +475,9 @@ inline const char* ForType2String(ForType t) {
       return "vectorized";
     case ForType::Unrolled:
       return "unroll";
+    case ForType::ThreadBinding:
+      LOG(FATAL) << "Loop ThreadBinding is reserved for future use and "
+                 << "not yet supported in TIR";
   }
   LOG(FATAL) << "Unknown ForType";
   return "Unknown";
diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc
index 09f95e44b6d8..e25c7efd1c18 100644
--- a/src/printer/tvmscript_printer.cc
+++ b/src/printer/tvmscript_printer.cc
@@ -659,6 +659,10 @@ inline const char* ForType2String(ForType t) {
       return "vectorized";
     case ForType::Unrolled:
       return "unroll";
+    case ForType::ThreadBinding:
+      LOG(FATAL) << "Loop ThreadBinding is reserved for future use and "
+                 << "not yet supported in TIR";
+      return "threadbinding";
   }
   LOG(FATAL) << "Unknown ForType";
   return "Unknown";
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index 6143e7050495..71d4a0efb1cf 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -980,8 +980,9 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
     CodeGenLLVM::VisitStmt_(op);
   } else if (op->for_type == ForType::Parallel) {
     if (parallel_env_.penv == nullptr) {
-      CreateParallelLaunch(
-          For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), 0);
+      CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->for_type, op->body,
+                               op->thread_binding, op->annotations),
+                           0);
     } else {
       // already in parallel env.
       ICHECK(parallel_env_.task_id.defined());
diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc
index 94e06d206ddb..53b26303132f 100644
--- a/src/te/operation/hybrid_op.cc
+++ b/src/te/operation/hybrid_op.cc
@@ -234,9 +234,9 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Rang
         PrimExpr cond = likely(outer * stride < (op->extent - inner));
         ret = IfThenElse(cond, ret);
         ret = For(inner->var, PrimExpr(0), inner->dom->extent,
-                  IterVarTypeToForType(inner->iter_type), op->device_api, ret);
+                  IterVarTypeToForType(inner->iter_type), ret);
         ret = For(outer->var, PrimExpr(0), outer->dom->extent,
-                  IterVarTypeToForType(outer->iter_type), op->device_api, ret);
+                  IterVarTypeToForType(outer->iter_type), ret);
         splitted = true;
         return ret;
       }
@@ -277,8 +277,8 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Rang
           rmap[op->loop_var.get()] = indexdiv(parent, extent);
           body = tir::Substitute(body, rmap);
           under_outer = false;
-          return For(parent->var, PrimExpr(0), extent * op->extent, op->for_type, op->device_api,
-                     body);
+          return For(parent->var, PrimExpr(0), extent * op->extent, op->for_type, body,
+                     op->thread_binding, op->annotations);
         } else if (under_outer) {
           Stmt body = this->VisitStmt(op->body);
           std::unordered_map<const VarNode*, PrimExpr> rmap;
@@ -332,7 +332,7 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map<IterVar,
         return AttrStmt(iter_var, "thread_extent", op->extent, body);
       } else {
         return For(op->loop_var, op->min, op->extent, IterVarTypeToForType(attr->iter_type),
-                   op->device_api, op->body);
+                   op->body, op->thread_binding, op->annotations);
       }
     }
     return StmtMutator::VisitStmt_(op);
@@ -414,7 +414,8 @@ Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map<IterVar, Range
       for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
     }
     const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
-    return For(target->var, range->min, range->extent, for_type, DeviceAPI::None, body);
+    return For(target->var, range->min, range->extent, for_type, body, op->thread_binding,
+               op->annotations);
   }
 };
diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc
index f1991c181e67..bfce9b71b733 100644
--- a/src/te/operation/op_utils.cc
+++ b/src/te/operation/op_utils.cc
@@ -115,11 +115,11 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
         nest[i + 1].emplace_back(LetStmt(var, cast(var.dtype(), dom->min), no_op));
         value_map[iv] = cast(var.dtype(), dom->min);
       } else if (is_zero(dom->min)) {
-        nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, DeviceAPI::None, no_op));
+        nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, no_op));
         value_map[iv] = var;
       } else {
         Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
-        nest[i + 1].emplace_back(For(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op));
+        nest[i + 1].emplace_back(For(idx, 0, dom->extent, for_type, no_op));
         PrimExpr new_value = dom->min + idx;
         value_map[iv] = new_value;
         nest[i + 1].emplace_back(LetStmt(var, new_value, no_op));
diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
index f81d72e0fe02..a543eaeaa096 100644
--- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
+++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
@@ -968,7 +968,8 @@ class TensorCoreIRMutator : public StmtExprMutator {
         scaled_extent_value = ori_extent_value / scale_factor;
       }
       PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
-      stmt = For(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api, op->body);
+      stmt = For(op->loop_var, op->min, scaled_extent, op->for_type, op->body, op->thread_binding,
+                 op->annotations);
     }
   }
   return stmt;
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index fd03046376f8..e18ae95fda6d 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -128,8 +128,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     });
 
 // For
-For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api,
-         Stmt body, Span span) {
+For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, Stmt body,
+         Optional<IterVar> thread_binding, Map<String, ObjectRef> annotations, Span span) {
   ICHECK(min.defined());
   ICHECK(extent.defined());
   ICHECK(min.dtype().is_scalar());
@@ -142,18 +142,19 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAP
   node->min = std::move(min);
   node->extent = std::move(extent);
   node->for_type = for_type;
-  node->device_api = device_api;
   node->body = std::move(body);
+  node->thread_binding = std::move(thread_binding);
+  node->annotations = std::move(annotations);
   node->span = std::move(span);
   data_ = std::move(node);
 }
 
-TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent,
-                                                 int for_type, int device_api, Stmt body,
-                                                 Span span) {
-  return For(loop_var, min, extent, static_cast<ForType>(for_type),
-             static_cast<DeviceAPI>(device_api), body, span);
-});
+TVM_REGISTER_GLOBAL("tir.For").set_body_typed(
+    [](Var loop_var, PrimExpr min, PrimExpr extent, int for_type, Stmt body,
+       Optional<IterVar> thread_binding, Optional<Map<String, ObjectRef>> annotations, Span span) {
+      return For(loop_var, min, extent, static_cast<ForType>(for_type), body, thread_binding,
+                 annotations.value_or(Map<String, ObjectRef>()), span);
+    });
 
 TVM_REGISTER_NODE_TYPE(ForNode);
 
@@ -171,6 +172,9 @@ std::ostream& operator<<(std::ostream& out, ForType type) {  // NOLINT(*)
     case ForType::Vectorized:
       out << "vectorized";
       break;
+    case ForType::ThreadBinding:
+      out << "launch_thread";
+      break;
   }
   return out;
 }
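Because SEqualReduce and SHashReduce now fold in thread_binding and annotations, two loops that differ only in their hints are no longer structurally equal. A small check, assuming this patch; the "hint" key is made up:

import tvm
from tvm import tir

i = tir.Var("i", "int32")
body = tir.Evaluate(0)

plain = tir.For(i, 0, 4, tir.For.Serial, body)
hinted = tir.For(i, 0, 4, tir.For.Serial, body, annotations={"hint": 1})
assert tvm.ir.structural_equal(plain, tir.For(i, 0, 4, tir.For.Serial, body))
assert not tvm.ir.structural_equal(plain, hinted)  # annotations now matter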
diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc
index 22a6ca23c24c..10e8bf5d457b 100644
--- a/src/tir/transforms/inject_double_buffer.cc
+++ b/src/tir/transforms/inject_double_buffer.cc
@@ -158,8 +158,7 @@ class DoubleBufferInjector : public StmtExprMutator {
         vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i);
         loop_seq.emplace_back(Substitute(old_loop->body, vmap));
       }
-      Stmt loop = For(outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
-                      SeqStmt::Flatten(loop_seq));
+      Stmt loop = For(outer_var, zero, outer_ext, old_loop->for_type, SeqStmt::Flatten(loop_seq));
       // tail
       std::vector<Stmt> tail_seq;
       Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);
diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc
index 5622d140a625..9c255272624d 100644
--- a/src/tir/transforms/inject_virtual_thread.cc
+++ b/src/tir/transforms/inject_virtual_thread.cc
@@ -303,7 +303,10 @@ class VTInjector : public StmtExprMutator {
     if (extent.same_as(op->extent) && body.same_as(op->body)) {
       return GetRef<Stmt>(op);
     } else {
-      return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
+      auto n = CopyOnWrite(op);
+      n->extent = std::move(extent);
+      n->body = std::move(body);
+      return Stmt(n);
     }
   }
   // IfThenElse
@@ -417,7 +420,7 @@ class VTInjector : public StmtExprMutator {
       Map<Var, PrimExpr> values{{var_, idx}};
       stmt = Substitute(stmt, values);
       return For(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_),
-                 ForType::Serial, DeviceAPI::None, stmt);
+                 ForType::Serial, stmt);
     }
   }
 
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 033a2e093a2a..1579f4f3ab53 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -149,7 +149,8 @@ class IRConvertSSA final : public StmtExprMutator {
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
       scope_[v.get()].pop_back();
       op = stmt.as<ForNode>();
-      return For(new_var, op->min, op->extent, op->for_type, op->device_api, op->body);
+      return For(new_var, op->min, op->extent, op->for_type, op->body, op->thread_binding,
+                 op->annotations);
     } else {
       defined_.insert(v.get());
       return StmtExprMutator::VisitStmt_(op);
diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc
index a104dbb029eb..bc75b39535cb 100644
--- a/src/tir/transforms/loop_partition.cc
+++ b/src/tir/transforms/loop_partition.cc
@@ -607,8 +607,9 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b
     // If the loop extent is 1, do not create the loop anymore
     return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}});
   } else {
+    ICHECK(for_node->for_type == ForType::Serial || for_node->for_type == ForType::Unrolled);
     return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->for_type,
-               for_node->device_api, body);
+               body);
   }
 }
 
diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc
index 0b248959ec6e..bc8993156c31 100644
--- a/src/tir/transforms/narrow_datatype.cc
+++ b/src/tir/transforms/narrow_datatype.cc
@@ -221,7 +221,7 @@ class DataTypeRewriter : public StmtExprMutator {
     PrimExpr e = VisitExpr(op->loop_var);
     Var var = Downcast<Var>(e);
     return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->for_type,
-               op->device_api, op->body);
+               op->body, op->thread_binding, op->annotations);
   }
 
   Stmt VisitStmt_(const AttrStmtNode* op) final {
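The transform updates above all take one of two shapes: switch to CopyOnWrite so untouched fields carry over automatically, or forward op->thread_binding and op->annotations explicitly. The same care applies when a Python pass rebuilds a loop. A hedged sketch, with an arbitrary extent-doubling rewrite standing in for a real transformation and a made-up "hint" key:

import tvm
from tvm import tir

def double_extent(op):
    # Forward thread_binding/annotations, or the hints silently disappear.
    if isinstance(op, tir.For):
        return tir.For(op.loop_var, op.min, op.extent * 2, op.for_type, op.body,
                       op.thread_binding, op.annotations)
    return None

i = tir.Var("i", "int32")
loop = tir.For(i, 0, 8, tir.For.Serial, tir.Evaluate(0), annotations={"hint": 1})
out = tvm.tir.stmt_functor.ir_transform(loop, None, double_extent, ["tir.For"])
assert out.annotations["hint"].value == 1  # preserved across the rewrite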
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index d392866b3694..2ef7370b57a8 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -318,14 +318,14 @@ class StorageFlattener : public StmtExprMutator {
     }
     for (int i = starts; i >= 0; --i) {
       if (i < starts) {
-        stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt);
+        stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial, stmt);
       } else {
         PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype);
         PrimExpr address = Call(DataType::Handle(), builtin::address_of(), {load});
         PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1});
         stmt = Evaluate(prefetch);
         PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1;
-        stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt);
+        stmt = For(vars[i], 0, extent, ForType::Serial, stmt);
       }
     }
     return stmt;
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index d4c5ca09650b..fc0ee1f72f1b 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -444,8 +444,8 @@ class StoragePlanRewriter : public StmtExprMutator {
       auto& svec = attach_map_[op];
       Stmt stmt = StmtExprMutator::VisitStmt_(op);
      op = stmt.as<ForNode>();
-      return For(op->loop_var, op->min, op->extent, op->for_type, op->device_api,
-                 MakeAttach(svec, op->body));
+      return For(op->loop_var, op->min, op->extent, op->for_type, MakeAttach(svec, op->body),
+                 op->thread_binding, op->annotations);
     } else {
       return StmtExprMutator::VisitStmt_(op);
     }
diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc
index 71ad899273a6..7c77fd5ed8ab 100644
--- a/src/tir/transforms/unroll_loop.cc
+++ b/src/tir/transforms/unroll_loop.cc
@@ -125,8 +125,8 @@ class LoopUnroller : public StmtExprMutator {
     } else {
       if (auto_unroll) {
         if (op->for_type != ForType::Unrolled) {
-          return For(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api,
-                     op->body);
+          return For(op->loop_var, op->min, op->extent, ForType::Unrolled, op->body,
+                     op->thread_binding, op->annotations);
         }
       }
       return stmt;
diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc
index 239f42266b83..5ba79ef6c2bd 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -365,7 +365,8 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimEx
     if (extent.same_as(op->extent) && body.same_as(op->body)) {
       return GetRef<Stmt>(op);
     } else {
-      return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body);
+      return For(op->loop_var, op->min, extent, op->for_type, body, op->thread_binding,
+                 op->annotations);
     }
   }
   // IfThenElse
@@ -436,7 +437,7 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimEx
     Var idx(var_->name_hint + ".s", var_->dtype);
     Map<Var, PrimExpr> values{{var_, idx}};
     stmt = Substitute(stmt, values);
-    return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt);
+    return For(idx, 0, var_lanes_, ForType::Serial, stmt);
   }
   // ProducerStore
   Stmt VisitStmt_(const ProducerStoreNode* op) final {
@@ -546,7 +547,7 @@ class VectorizeSkipper : public StmtMutator {
     Stmt stmt = StmtMutator::VisitStmt_(op);
     op = stmt.as<ForNode>();
     if (op->for_type == ForType::Vectorized) {
-      return For(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, op->body);
+      return For(op->loop_var, op->min, op->extent, ForType::Serial, op->body);
     } else {
      return stmt;
     }
diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py
index ca5df4af6a71..e7d251655e14 100644
--- a/tests/python/unittest/test_arith_domain_touched.py
+++ b/tests/python/unittest/test_arith_domain_touched.py
@@ -31,14 +31,12 @@ def test_domain_touched():
         i,
         0,
         n,
-        0,
-        0,
+        tvm.tir.For.Serial,
         tvm.tir.For(
             j,
             0,
             m,
-            0,
-            0,
+            tvm.tir.For.Serial,
             tvm.tir.BufferStore(
                 a,
                 tvm.tir.BufferLoad(b, [i - 1, j + 1]) + tvm.tir.BufferLoad(a, [i - 1, j - 1]),
diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py
index 64f87fb3c561..ea757a13bdf6 100644
--- a/tests/python/unittest/test_runtime_module_based_interface.py
+++ b/tests/python/unittest/test_runtime_module_based_interface.py
@@ -547,8 +547,7 @@ def make_func(symbol):
             i,
             0,
             n - 1,
-            0,
-            0,
+            tvm.tir.For.Serial,
             tvm.tir.Store(Ab.data, tvm.tir.Load("float32", Ab.data, i) + 1, i + 1),
         )
         return tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", symbol)
diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py
index 7befed3bbcdd..0394074e008f 100644
--- a/tests/python/unittest/test_runtime_module_load.py
+++ b/tests/python/unittest/test_runtime_module_load.py
@@ -55,7 +55,11 @@ def save_object(names):
     i = te.var("i")
     # for i in 0 to n-1:
     stmt = tvm.tir.For(
-        i, 0, n - 1, 0, 0, tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1)
+        i,
+        0,
+        n - 1,
+        tvm.tir.For.Serial,
+        tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1),
     )
     mod = tvm.IRModule.from_expr(
         tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main")
diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py
index 2bf4ba51937e..b490e85185bc 100644
--- a/tests/python/unittest/test_tir_constructor.py
+++ b/tests/python/unittest/test_tir_constructor.py
@@ -142,7 +142,7 @@ def test_stmt_constructor():
     assert isinstance(x, tvm.tir.AssertStmt)
     assert x.body == nop
 
-    x = tvm.tir.For(te.var("x"), 0, 10, 0, 0, nop)
+    x = tvm.tir.For(te.var("x"), 0, 10, tvm.tir.For.Serial, nop)
     assert isinstance(x, tvm.tir.For)
     assert x.min.value == 0
     assert x.extent.value == 10
diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py
index 4d57ed8ec366..6e9753b2fcb9 100644
--- a/tests/python/unittest/test_tir_nodes.py
+++ b/tests/python/unittest/test_tir_nodes.py
@@ -129,7 +129,7 @@ def test_basic():
 
 def test_stmt():
     x = tvm.tir.Evaluate(0)
-    tvm.tir.For(te.var("i"), 0, 1, tvm.tir.For.Serial, 0, x)
+    tvm.tir.For(te.var("i"), 0, 1, tvm.tir.For.Serial, x)
 
 
 def test_dir():
diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py
index ecaff319441d..e3a2cec77be2 100644
--- a/tests/python/unittest/test_tir_transform_loop_partition.py
+++ b/tests/python/unittest/test_tir_transform_loop_partition.py
@@ -462,6 +462,7 @@ def test_multilevel_splitting_with_indivisble_factors():
         def visit_stmt(op):
             return isinstance(op, tvm.tir.Max)
 
+        print(lowered_body)
         num_max = collect_visit(lowered_body, visit_stmt)
         assert num_max.count(True) == 10
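The test updates are mechanical, but they also drop the magic numbers: tvm.tir.For.Serial reads better than the old bare "0, 0" pair. A compact sketch of the new call shape, using a hypothetical two-level nest:

import tvm
from tvm import te

i, j = te.var("i"), te.var("j")
inner = tvm.tir.For(j, 0, 8, tvm.tir.For.Vectorized, tvm.tir.Evaluate(0))
outer = tvm.tir.For(i, 0, 8, tvm.tir.For.Parallel, inner)
assert outer.for_type == tvm.tir.For.Parallel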
diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py
index 2edb8cf980c2..4eef79540389 100644
--- a/tests/python/unittest/test_tir_transform_remove_no_op.py
+++ b/tests/python/unittest/test_tir_transform_remove_no_op.py
@@ -34,20 +34,17 @@ def test_remove_no_op():
         i,
         0,
         4,
-        0,
-        0,
+        tvm.tir.For.Serial,
         tvm.tir.For(
             j,
             0,
             n,
-            0,
-            0,
+            tvm.tir.For.Serial,
             tvm.tir.For(
                 k,
                 0,
                 m,
-                0,
-                0,
+                tvm.tir.For.Serial,
                 tvm.tir.IfThenElse((i * m + j + k < n), tvm.tir.Evaluate(m), tvm.tir.Evaluate(n)),
             ),
         ),
@@ -65,7 +62,7 @@ def test_remove_no_op():
     assert ret == store
 
     # remove zero extent loop
-    stmt3 = tvm.tir.For(i, 0, 0, 0, 0, store)
+    stmt3 = tvm.tir.For(i, 0, 0, tvm.tir.For.Serial, store)
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt3))
     ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body
     assert isinstance(ret, tvm.tir.Evaluate)
diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py
index 44fe59f99201..072b244aa867 100644
--- a/tutorials/dev/low_level_custom_pass.py
+++ b/tutorials/dev/low_level_custom_pass.py
@@ -116,8 +116,8 @@ def vectorize8(op):
         name = op.loop_var.name
         lo, li = te.var(name + ".outer"), te.var(name + ".inner")
         body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li})
-        body = tvm.tir.For(li, 0, 8, tvm.tir.For.Vectorized, 0, body)
-        body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.For.Serial, 0, body)
+        body = tvm.tir.For(li, 0, 8, tvm.tir.For.Vectorized, body)
+        body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.For.Serial, body)
         return body
     return None
 
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
index a485d2cfb7b8..8baf2a3e5d19 100644
--- a/vta/python/vta/transform.py
+++ b/vta/python/vta/transform.py
@@ -231,7 +231,13 @@ def _merge_block(slist, body):
             body = tvm.tir.AttrStmt(op.node, op.attr_key, op.value, body)
         elif isinstance(op, tvm.tir.For):
             body = tvm.tir.For(
-                op.loop_var, op.min, op.extent, op.for_type, op.device_api, body
+                op.loop_var,
+                op.min,
+                op.extent,
+                op.for_type,
+                body,
+                op.thread_binding,
+                op.annotations,
             )
         else:
             raise RuntimeError("unexpected op")
@@ -314,7 +320,9 @@ def _do_fold(stmt):
         if _match_pragma(stmt, "trim_loop"):
             op = stmt.body
             assert isinstance(op, tvm.tir.For)
-            return tvm.tir.For(op.loop_var, op.min, 2, op.for_type, op.device_api, op.body)
+            return tvm.tir.For(
+                op.loop_var, op.min, 2, op.for_type, op.body, op.thread_binding, op.annotations
+            )
         return None
 
     return f.with_body(