diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 2b7f1e67bda5..093d49ca2dd4 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -752,23 +752,34 @@ class Evaluate : public Stmt { TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode); }; -/*! \brief Additional annotation of for loop. */ -enum class ForType : int { - /*! \brief serial execution. */ - Serial = 0, - /*! \brief parallel execution on CPU. */ - Parallel = 1, - /*! \brief Vector SIMD loop annotaion. */ - Vectorized = 2, - /*! \brief Unroll annotation. */ - Unrolled = 3 +/*! + * \brief The kind of the loop. + * + * ForKind can change the control flow semantics + * of the loop. So the kind field needs to be considered + * in all TIR passes. + */ +enum class ForKind : int { + /*! \brief default semantics -- serial execution. */ + kSerial = 0, + /*! \brief Parallel execution on CPU. */ + kParallel = 1, + /*! + * \brief Vector SIMD loop. + * The loop body will be vectorized. + */ + kVectorized = 2, + /*! \brief The loop body must be unrolled. */ + kUnrolled = 3, + /*! + * \brief The loop variable is bound to a thread in + * an environment. In the final stage of lowering, + * the loop is simply removed and the loop variable is + * mapped to the corresponding context thread. + */ + kThreadBinding = 4 }; -// Kevice api of for loop -// kept for backward compatibility -// consider refactor and remove later. -enum class DeviceAPI : int { None = 0 }; - /*! * \brief A for loop, with poissible type annotations. * @@ -787,39 +798,50 @@ class ForNode : public StmtNode { PrimExpr min; /*! \brief The extent of the iteration. */ PrimExpr extent; - /*! \brief The type of the for loop. */ - ForType for_type; - /*! - * \brief Deprecated, reserved for backward compatibility. - * Consider refactor and remove later. - */ - DeviceAPI device_api; + /*! \brief The kind of the for loop. */ + ForKind kind; /*! \brief The body of the for loop. */ Stmt body; + /*! + * \brief Only valid when kind == ForKind::kThreadBinding + * The context thread that this loop variable bounds to. + */ + Optional thread_binding; + /*! + * \brief Additional annotations about the loop. + * + * These annotations can be used as auxiliary hint + * to future transformations. An annotation should + * not change the control flow semantics of the loop + * and can be ignored in most passes. + */ + Map annotations; void VisitAttrs(AttrVisitor* v) { v->Visit("loop_var", &loop_var); v->Visit("min", &min); v->Visit("extent", &extent); - v->Visit("for_type", &for_type); - v->Visit("device_api", &device_api); + v->Visit("kind", &kind); v->Visit("body", &body); + v->Visit("thread_binding", &thread_binding); + v->Visit("annotations", &annotations); v->Visit("span", &span); } bool SEqualReduce(const ForNode* other, SEqualReducer equal) const { return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) && - equal(extent, other->extent) && equal(for_type, other->for_type) && - equal(device_api, other->device_api) && equal(body, other->body); + equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) && + equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations); } void SHashReduce(SHashReducer hash_reduce) const { hash_reduce.DefHash(loop_var); hash_reduce(min); hash_reduce(extent); - hash_reduce(for_type); - hash_reduce(device_api); + hash_reduce(kind); hash_reduce(body); + hash_reduce(thread_binding); + hash_reduce(annotations); } static constexpr const char* _type_key = "tir.For"; @@ -832,8 +854,9 @@ class ForNode : public StmtNode { */ class For : public Stmt { public: - TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api, - Stmt body, Span span = Span()); + TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, + Optional thread_binding = NullOpt, + Map annotations = Map(), Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode); }; @@ -1015,7 +1038,7 @@ inline bool IsPragmaKey(const std::string& attr_key) { TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span()); // overload printing of for type. -TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type); +TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind); } // namespace tir } // namespace tvm diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index 21ed7f6e4682..9449cbdc156c 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -226,7 +226,7 @@ def serial(begin, end, span): self.context.report_error("Expect exact 1 loop var", span) ana = tvm.arith.Analyzer() extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(self.loop_vars[0], begin, extent, 0, 0, self.body, span=span) + return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body, span=span) super().__init__(serial) @@ -241,7 +241,7 @@ def parallel(begin, end, span): self.context.report_error("Expect exact 1 loop var") ana = tvm.arith.Analyzer() extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(self.loop_vars[0], begin, extent, 1, 0, self.body, span=span) + return tvm.tir.For(self.loop_vars[0], begin, extent, 1, self.body, span=span) super().__init__(parallel) @@ -256,7 +256,7 @@ def vectorized(begin, end, span): self.context.report_error("Expect exact 1 loop var") ana = tvm.arith.Analyzer() extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(self.loop_vars[0], begin, extent, 2, 0, self.body, span=span) + return tvm.tir.For(self.loop_vars[0], begin, extent, 2, self.body, span=span) super().__init__(vectorized) @@ -271,6 +271,6 @@ def unroll(begin, end, span): self.context.report_error("Expect exact 1 loop var") ana = tvm.arith.Analyzer() extent = end if begin == 0 else ana.simplify(end - begin) - return tvm.tir.For(self.loop_vars[0], begin, extent, 3, 0, self.body, span=span) + return tvm.tir.For(self.loop_vars[0], begin, extent, 3, self.body, span=span) super().__init__(unroll) diff --git a/python/tvm/te/hybrid/calls.py b/python/tvm/te/hybrid/calls.py index 761189115050..6785457c3bd7 100644 --- a/python/tvm/te/hybrid/calls.py +++ b/python/tvm/te/hybrid/calls.py @@ -23,18 +23,18 @@ from tvm.target import Target from tvm.tir import expr as _expr from tvm.tir import call_intrin -from tvm.tir.stmt import For +from tvm.tir.stmt import ForKind from .utils import _internal_assert # pylint: disable=redefined-builtin,invalid-name LOOP_INTRIN = { - "range": For.Serial, - "unroll": For.Unrolled, - "parallel": For.Parallel, - "vectorize": For.Vectorized, - "const_range": (For.Unrolled,), + "range": ForKind.SERIAL, + "unroll": ForKind.UNROLLED, + "parallel": ForKind.PARALLEL, + "vectorize": ForKind.VECTORIZED, + "const_range": (ForKind.UNROLLED,), } @@ -48,9 +48,9 @@ def _range(annotation, args): low, ext = args[0], args[1] if not tvm.tir.analysis.expr_deep_equal(low, const(0, dtype="int32")): ext = ext - low - for_type = LOOP_INTRIN[annotation] + kind = LOOP_INTRIN[annotation] iter_var = None - return iter_var, low, ext, for_type + return iter_var, low, ext, kind range = unroll = vectorize = parallel = const_range = _range # pylint: disable=invalid-name @@ -63,8 +63,8 @@ def bind(func_id, args): _internal_assert(isinstance(args[0], str), "A loop bind's first argument should be a string!") low, ext = const(0, "int32"), args[1] iter_var = tvm.te.thread_axis((low, ext), args[0]) - for_type = None - return iter_var, low, ext, for_type + kind = None + return iter_var, low, ext, kind def _math_intrin(func_id, args): diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index d47b2ee879fc..7bb85e3da83c 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -480,14 +480,14 @@ def visit_Call(self, node): return op def visit_For(self, node): - iter_var, low, ext, for_type = self.visit(node.iter) + iter_var, low, ext, kind = self.visit(node.iter) _internal_assert( isinstance(node.target, ast.Name), "The loop iterator should be a variable!" ) _name = node.target.id - if isinstance(for_type, tuple): + if isinstance(kind, tuple): low = self.analyzer.simplify(low) ext = self.analyzer.simplify(ext) _internal_assert( @@ -511,14 +511,14 @@ def visit_For(self, node): return concat_list_to_block(bodies) if iter_var is None: - _internal_assert(for_type is not None, "The loop iterating function parse error!") + _internal_assert(kind is not None, "The loop iterating function parse error!") offset = iter_var = tvm.te.var(_name) if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, "int32")): offset = iter_var + low self.add_symbol(_name, Symbol.LoopVar, offset) _body = visit_list_to_block(self.visit, node.body) else: - _internal_assert(for_type is None, "The loop bind function parse error!") + _internal_assert(kind is None, "The loop bind function parse error!") self.add_symbol(_name, Symbol.ThreadBind, iter_var) self.device += 1 _body = visit_list_to_block(self.visit, node.body) @@ -526,13 +526,13 @@ def visit_For(self, node): _body = self.wrap_up_realize(node, _body) - if for_type is None: + if kind is None: res = _body else: _internal_assert( - not isinstance(for_type, tuple), "Micro expansion should be handled before!" + not isinstance(kind, tuple), "Micro expansion should be handled before!" ) - res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext, for_type, 0, _body) + res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext, kind, _body) self.symbols.pop(_name) return res diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 901c89ed9106..324c4daf19ba 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -27,7 +27,7 @@ from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle from .expr import Call, CallEffectKind, Let, IterVar, Any -from .stmt import Stmt, LetStmt, AssertStmt, For +from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt from .stmt import ProducerRealize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 6dcc8580a221..437e8f6610f4 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -206,7 +206,7 @@ def scope_attr(self, node, attr_key, value): value = op.max(1, value) self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x)) - def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"): + def for_range(self, begin, end, name="i", dtype="int32", kind="serial"): """Create a for iteration scope. Parameters @@ -224,7 +224,7 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"): dtype : str, optional The data type of iteration variable. - for_type : str, optional + kind : str, optional The special tag on the for loop. Returns @@ -249,17 +249,17 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"): extent = end if begin == 0 else (end - begin) def _exit_cb(): - if for_type == "serial": - for_type_id = 0 - elif for_type == "parallel": - for_type_id = 1 - elif for_type == "vectorize": - for_type_id = 2 - elif for_type == "unroll": - for_type_id = 3 + if kind == "serial": + kind_id = _stmt.ForKind.SERIAL + elif kind == "parallel": + kind_id = _stmt.ForKind.PARALLEL + elif kind == "vectorize": + kind_id = _stmt.ForKind.VECTORIZED + elif kind == "unroll": + kind_id = _stmt.ForKind.UNROLLED else: - raise ValueError("Unknown for_type") - self.emit(_stmt.For(loop_var, begin, extent, for_type_id, 0, self._pop_seq())) + raise ValueError("Unknown kind") + self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq())) return WithScope(loop_var, _exit_cb) diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 6857b68c261d..9e1ef56cca58 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -26,6 +26,7 @@ assert isinstance(st, tvm.tir.stmt.Store) assert(st.buffer_var == a) """ +from enum import IntEnum import tvm._ffi from tvm.runtime import Object @@ -82,6 +83,22 @@ def __init__(self, condition, message, body, span=None): self.__init_handle_by_constructor__(_ffi_api.AssertStmt, condition, message, body, span) +class ForKind(IntEnum): + """The kind of the for loop. + + note + ---- + ForKind can change the control flow semantics + of the loop and need to be considered in all TIR passes. + """ + + SERIAL = 0 + PARALLEL = 1 + VECTORIZED = 2 + UNROLLED = 3 + THREAD_BINDING = 4 + + @tvm._ffi.register_object("tir.For") class For(Stmt): """For node. @@ -97,27 +114,44 @@ class For(Stmt): extent : PrimExpr The length of the loop. - for_type : int - The for type. - - device_api : int - The device api type. + kind : ForKind + The type of the for. body : Stmt The body statement. + thread_binding: Optional[tir.IterVar] + The thread this loop binds to. Only valid + if kind is ThreadBinding + + annotations: tvm.ir.Map + Additional annotation hints. + span : Optional[Span] The location of this itervar in the source code. """ - Serial = 0 - Parallel = 1 - Vectorized = 2 - Unrolled = 3 - - def __init__(self, loop_var, min_val, extent, for_type, device_api, body, span=None): + def __init__( + self, + loop_var, + min_val, + extent, + kind, + body, + thread_binding=None, + annotations=None, + span=None, + ): self.__init_handle_by_constructor__( - _ffi_api.For, loop_var, min_val, extent, for_type, device_api, body, span + _ffi_api.For, + loop_var, + min_val, + extent, + kind, + body, + thread_binding, + annotations, + span, ) diff --git a/python/tvm/topi/cuda/nms.py b/python/tvm/topi/cuda/nms.py index 6f3ed789ffc1..0c01cc9fbbdf 100644 --- a/python/tvm/topi/cuda/nms.py +++ b/python/tvm/topi/cuda/nms.py @@ -580,7 +580,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): j = bx * max_threads + tx with ib.if_scope(j < nkeep): src_idx = base_src_idx + sorted_index[i * num_anchors + j] * box_data_length - with ib.for_range(0, 4, for_type="unroll") as k: + with ib.for_range(0, 4, kind="unroll") as k: out_bboxes[(base_bbox_idx + j * 4 + k)] = data[src_idx + coord_start + k] out_scores[i * num_anchors + j] = data[src_idx + score_index] @@ -593,7 +593,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): # Only needed for return_indices = False case if return_indices is False: with ib.if_scope(j < num_anchors): - with ib.for_range(0, 4, for_type="unroll") as k: + with ib.for_range(0, 4, kind="unroll") as k: out_bboxes[(base_bbox_idx + j * 4 + k)] = -1.0 out_scores[i, j] = -1.0 @@ -609,7 +609,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): with ib.if_scope(j < valid_count[i]): src_offset = base_src_idx + j * box_data_length - with ib.for_range(0, 4, for_type="unroll") as k: + with ib.for_range(0, 4, kind="unroll") as k: out_bboxes[base_bbox_idx + j * 4 + k] = data[src_offset + coord_start + k] out_scores[i * num_anchors + j] = data[src_offset + score_index] @@ -855,7 +855,7 @@ def ir(out_bboxes, out_scores, out_class_ids, out): i = by with ib.if_scope(tid < num_anchors): - with ib.for_range(0, 4, for_type="unroll") as j: + with ib.for_range(0, 4, kind="unroll") as j: out[i, tid, coord_start + j] = out_bboxes[i, tid, j] out[i, tid, score_index] = out_scores[i, tid] if id_index >= 0: diff --git a/python/tvm/topi/cuda/rcnn/proposal.py b/python/tvm/topi/cuda/rcnn/proposal.py index 5b7884c7363b..e5e83b4911a3 100644 --- a/python/tvm/topi/cuda/rcnn/proposal.py +++ b/python/tvm/topi/cuda/rcnn/proposal.py @@ -181,7 +181,7 @@ def argsort_ir(data_buf, out_index_buf): idxm = tvm.tir.indexmod - with ib.for_range(0, batch, for_type="unroll") as b: + with ib.for_range(0, batch, kind="unroll") as b: start = b * num_bbox for i in range(2): bbox_id = tid * 2 + i @@ -259,7 +259,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): ib.scope_attr(tx, "thread_extent", nthread_tx) ib.scope_attr(bx, "thread_extent", nthread_bx) i = bx * max_threads + tx - with ib.for_range(0, batch, for_type="unroll", name="n") as b: + with ib.for_range(0, batch, kind="unroll", name="n") as b: base_idx = b * num_bbox with ib.if_scope(i < num_bbox): p_out[base_idx + i] = False @@ -323,7 +323,7 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf): tvm.tir.all(i[0] < rpn_post_nms_top_n, p_remove[(b * num_bbox + j)] == False) ): p_out[offset_i] = tvm.tir.Cast("float32", b) - with ib.for_range(0, 4, for_type="unroll") as k: + with ib.for_range(0, 4, kind="unroll") as k: p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k] i[0] = i[0] + 1 diff --git a/python/tvm/topi/cuda/sparse.py b/python/tvm/topi/cuda/sparse.py index f2cecacbc618..cb61d9686919 100644 --- a/python/tvm/topi/cuda/sparse.py +++ b/python/tvm/topi/cuda/sparse.py @@ -228,8 +228,8 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): ) # zero block - with ib.for_range(0, bs_m, name="x", for_type="unroll") as x: - with ib.for_range(0, bs_n, name="y", for_type="unroll") as y: + with ib.for_range(0, bs_m, name="x", kind="unroll") as x: + with ib.for_range(0, bs_n, name="y", kind="unroll") as y: block[x, y] = 0.0 # compute into thread local storage using warp_size chunks with ib.for_range(0, rowlength_bo, name="bb") as bb: @@ -240,26 +240,26 @@ def gen_ir(data, w_data, w_indices, w_indptr, out): # each thread has a row # TODO: ideally we could vectorize this with ib.for_range(0, rowlength_bi, name="bi") as bi: - with ib.for_range(0, bs_m, name="x", for_type="unroll") as x: - with ib.for_range(0, bs_k, name="z", for_type="unroll") as z: + with ib.for_range(0, bs_m, name="x", kind="unroll") as x: + with ib.for_range(0, bs_k, name="z", kind="unroll") as z: # This memory acces should be out of bounds when # m_index >= mb (which occurs when the dense matrix # rows % 32 != 0), but it seems to work just fine... data_cache[bi, x, z] = data_ptr[indices[bi] * bs_k + z, m_index * bs_m + x] # cache w_data elem_idx = bb * rowlength_bi + tx - with ib.for_range(0, bs_n, name="y", for_type="unroll") as y: - with ib.for_range(0, bs_k, name="z", for_type="unroll") as z: + with ib.for_range(0, bs_n, name="y", kind="unroll") as y: + with ib.for_range(0, bs_k, name="z", kind="unroll") as z: w_data_cache[tx, y, z] = w_data_ptr[row_start + elem_idx, y, z] with ib.for_range(0, mi, name="i") as i: # thread local block matmul - with ib.for_range(0, bs_m, name="x", for_type="unroll") as x: - with ib.for_range(0, bs_n, name="y", for_type="unroll") as y: - with ib.for_range(0, bs_k, name="z", for_type="unroll") as z: + with ib.for_range(0, bs_m, name="x", kind="unroll") as x: + with ib.for_range(0, bs_n, name="y", kind="unroll") as y: + with ib.for_range(0, bs_k, name="z", kind="unroll") as z: block[x, y] += data_cache[i, x, z] * w_data_cache[i, y, z] # store results - with ib.for_range(0, bs_m, name="x", for_type="unroll") as x: - with ib.for_range(0, bs_n, name="y", for_type="unroll") as y: + with ib.for_range(0, bs_m, name="x", kind="unroll") as x: + with ib.for_range(0, bs_n, name="y", kind="unroll") as y: with ib.if_scope(m_index < mb): with ib.if_scope(n_index < nb): # It doesn't seem like we would be getting coelesced diff --git a/python/tvm/topi/nn/sparse.py b/python/tvm/topi/nn/sparse.py index cdccc80bb5f8..8145ed80af47 100644 --- a/python/tvm/topi/nn/sparse.py +++ b/python/tvm/topi/nn/sparse.py @@ -294,26 +294,26 @@ def _csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr): n = get_const_tuple(indptr.shape)[0] - 1 nnz = get_const_tuple(data.shape)[0] - with irb.for_range(0, n, for_type="parallel", name="col") as col: + with irb.for_range(0, n, kind="parallel", name="col") as col: out_indptr_ptr[col] = 0 - with irb.for_range(0, nnz, for_type="serial", name="nz_idx") as nz_idx: + with irb.for_range(0, nnz, kind="serial", name="nz_idx") as nz_idx: out_indptr_ptr[indices_ptr[nz_idx]] += 1 cumsum = irb.allocate("int32", (1,), name="cumsum", scope="local") temp = irb.allocate("int32", (1,), name="temp", scope="local") cumsum[0] = 0 - with irb.for_range(0, n, for_type="serial", name="col") as col: + with irb.for_range(0, n, kind="serial", name="col") as col: temp[0] = out_indptr_ptr[col] out_indptr_ptr[col] = cumsum[0] cumsum[0] += temp[0] out_indptr_ptr[n] = nnz - with irb.for_range(0, n, for_type="serial", name="row") as row: + with irb.for_range(0, n, kind="serial", name="row") as row: offset = indptr_ptr[row] diff = indptr_ptr[row + 1] - indptr_ptr[row] - with irb.for_range(0, diff, for_type="serial", name="idx") as idx: + with irb.for_range(0, diff, kind="serial", name="idx") as idx: real_idx = offset + idx col = indices_ptr[real_idx] dest = out_indptr_ptr[col] @@ -325,7 +325,7 @@ def _csr_transpose_ir(data, indices, indptr, out_data, out_indices, out_indptr): last = irb.allocate("int32", (1,), name="last", scope="local") temp2 = irb.allocate("int32", (1,), name="temp2", scope="local") last[0] = 0 - with irb.for_range(0, n, for_type="serial", name="col") as col: + with irb.for_range(0, n, kind="serial", name="col") as col: temp2[0] = out_indptr_ptr[col] out_indptr_ptr[col] = last[0] last[0] = temp2[0] diff --git a/python/tvm/topi/sparse/csrmm.py b/python/tvm/topi/sparse/csrmm.py index f578e6001351..39ba3332fc72 100644 --- a/python/tvm/topi/sparse/csrmm.py +++ b/python/tvm/topi/sparse/csrmm.py @@ -72,8 +72,8 @@ def csrmm_default_ir(data, indices, indptr, weight, out): out_ptr = irb.buffer_ptr(out) M = simplify(indptr.shape[0] - 1) _, N = weight.shape - with irb.for_range(0, N, for_type="vectorize", name="n") as n: - with irb.for_range(0, M, for_type="parallel", name="row") as row: + with irb.for_range(0, N, kind="vectorize", name="n") as n: + with irb.for_range(0, M, kind="parallel", name="row") as row: dot = irb.allocate("float32", (1,), name="dot", scope="local") out_ptr[row * N + n] = 0.0 dot[0] = 0.0 diff --git a/python/tvm/topi/sparse/csrmv.py b/python/tvm/topi/sparse/csrmv.py index afe3bc76d121..a2d22afe01e0 100644 --- a/python/tvm/topi/sparse/csrmv.py +++ b/python/tvm/topi/sparse/csrmv.py @@ -63,7 +63,7 @@ def csrmv_default_ir(data, indices, indptr, weight, out): weight_ptr = irb.buffer_ptr(weight) out_ptr = irb.buffer_ptr(out) num_rows = indptr.shape[0] - 1 - with irb.for_range(0, num_rows, for_type="parallel", name="row") as row: + with irb.for_range(0, num_rows, kind="parallel", name="row") as row: dot = irb.allocate("float32", (1,), name="dot", scope="local") out_ptr[row] = 0.0 dot[0] = 0.0 diff --git a/python/tvm/topi/sparse/dense.py b/python/tvm/topi/sparse/dense.py index d1516d0c20fc..5c63e44f691a 100644 --- a/python/tvm/topi/sparse/dense.py +++ b/python/tvm/topi/sparse/dense.py @@ -74,8 +74,8 @@ def dense_default_ir(data, indices, indptr, weight, out): out_ptr = irb.buffer_ptr(out) M = simplify(indptr.shape[0] - 1) N, K = weight.shape - with irb.for_range(0, N, for_type="vectorize", name="n") as n: - with irb.for_range(0, M, for_type="parallel", name="m") as m: + with irb.for_range(0, N, kind="vectorize", name="n") as n: + with irb.for_range(0, M, kind="parallel", name="m") as m: dot = irb.allocate(dtype, (1,), name="dot", scope="local") out_ptr[m * N + n] = tvm.tir.const(0, dtype) dot[0] = tvm.tir.const(0, dtype) @@ -153,8 +153,8 @@ def dense_default_ir(data, w_data, w_indices, w_indptr, out): out_ptr = irb.buffer_ptr(out) M, K = data.shape N = simplify(w_indptr.shape[0] - 1) - with irb.for_range(0, M, for_type="vectorize", name="m") as m: - with irb.for_range(0, N, for_type="parallel", name="n") as n: + with irb.for_range(0, M, kind="vectorize", name="m") as m: + with irb.for_range(0, N, kind="parallel", name="n") as n: dot = irb.allocate(dtype, (1,), name="dot", scope="local") out_ptr[m * N + n] = tvm.tir.const(0, dtype) dot[0] = tvm.tir.const(0, dtype) diff --git a/python/tvm/topi/vision/rcnn/proposal.py b/python/tvm/topi/vision/rcnn/proposal.py index 89726efd5d0e..e15ba8cd27c7 100644 --- a/python/tvm/topi/vision/rcnn/proposal.py +++ b/python/tvm/topi/vision/rcnn/proposal.py @@ -208,7 +208,7 @@ def argsort_ir(data_buf, out_index_buf): temp_data = ib.allocate("float32", (1,), name="temp_data", scope="local") temp_index = ib.allocate("int32", (1,), name="temp_index", scope="local") idxm = tvm.tir.indexmod - with ib.for_range(0, batch, for_type="unroll") as b: + with ib.for_range(0, batch, kind="unroll") as b: start = b * num_bbox for i in range(2): with ib.for_range(0, (num_bbox + 1) // 2) as tid: @@ -279,7 +279,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx): ib = tvm.tir.ir_builder.create() p_data = ib.buffer_ptr(sorted_bbox_buf) p_out = ib.buffer_ptr(out_buf) - with ib.for_range(0, batch, for_type="unroll", name="n") as b: + with ib.for_range(0, batch, kind="unroll", name="n") as b: base_idx = b * num_bbox for i in range(num_bbox): p_out[base_idx + i] = False @@ -345,7 +345,7 @@ def prepare_output_ir(sorted_bbox_buf, remove_mask_buf, out_buf): ) ): p_out[offset_i] = tvm.tir.Cast("float32", b) - with ib.for_range(0, 4, for_type="unroll") as k: + with ib.for_range(0, 4, kind="unroll") as k: p_out[offset_i + k + 1] = p_sorted_bbox[offset_j + k] i[b] = i[b] + 1 diff --git a/python/tvm/topi/x86/scatter.py b/python/tvm/topi/x86/scatter.py index 8147d3a00135..8bb3f57e82e4 100644 --- a/python/tvm/topi/x86/scatter.py +++ b/python/tvm/topi/x86/scatter.py @@ -84,7 +84,7 @@ def gen_ir(data_ptr, indices_ptr, out_ptr): out[i] = tvm.tir.Cast(data_ptr.dtype, 0) with ib.for_range(0, fused_indices_dimension) as i: - with ib.for_range(0, fused_data_dimension, for_type="parallel") as j: + with ib.for_range(0, fused_data_dimension, kind="parallel") as j: offset = fused_data_dimension index = j # This is x_M, .. x_{N-1} part of the index into out. # Build up the indices[0, y_0, .. y_{K-1}], .. indices[M-1, y_0, .. y_{K-1}] part diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index 1b10cd5f2601..cf516d8452e2 100755 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -618,7 +618,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor { is_gpu_ = true; // make a fake for node for blockIdx.x or threadIdx.x - Stmt fake_for_node = For(var, 0, extent, ForType::Parallel, DeviceAPI::None, node->body); + Stmt fake_for_node = For(var, 0, extent, ForKind::kParallel, node->body); outer_loop_prod_ *= extent; for_loop_stack_.push_back(fake_for_node.as()); @@ -642,11 +642,11 @@ class PerStoreFeatureExtractor : public StmtExprVisitor { void VisitStmt_(const ForNode* node) final { int64_t loop_extent = GetLoopExtent(node); - if (node->for_type == ForType::Vectorized) { + if (node->kind == ForKind::kVectorized) { vec_for_stack_.push_back(node); - } else if (node->for_type == ForType::Unrolled) { + } else if (node->kind == ForKind::kUnrolled) { unroll_for_stack_.push_back(node); - } else if (node->for_type == ForType::Parallel) { + } else if (node->kind == ForKind::kParallel) { parallel_for_stack_.push_back(node); } @@ -656,11 +656,11 @@ class PerStoreFeatureExtractor : public StmtExprVisitor { for_loop_stack_.pop_back(); outer_loop_prod_ /= loop_extent; - if (node->for_type == ForType::Vectorized) { + if (node->kind == ForKind::kVectorized) { vec_for_stack_.pop_back(); - } else if (node->for_type == ForType::Unrolled) { + } else if (node->kind == ForKind::kUnrolled) { unroll_for_stack_.pop_back(); - } else if (node->for_type == ForType::Parallel) { + } else if (node->kind == ForKind::kParallel) { parallel_for_stack_.pop_back(); } } diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc index 15e09755cee2..59cac9cc9827 100644 --- a/src/autotvm/feature_visitor.cc +++ b/src/autotvm/feature_visitor.cc @@ -34,19 +34,23 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) { int64_t loop_extent = -1; if (extent != nullptr) loop_extent = extent->value; AnnotationType ann = kSerial; - switch (op->for_type) { - case ForType ::Parallel: + switch (op->kind) { + case ForKind ::kParallel: ann = kParallel; break; - case ForType::Unrolled: + case ForKind::kUnrolled: ann = kUnrolled; break; - case ForType::Vectorized: + case ForKind::kVectorized: ann = kVectorized; break; - case ForType::Serial: + case ForKind::kSerial: ann = kSerial; break; + case ForKind::kThreadBinding: + LOG(FATAL) << "Loop ThreadBinding is reserved for future used and " + << "not yet supported in TIR"; + break; } if (EnterItervar_(op->loop_var, loop_extent, ann)) { diff --git a/src/printer/tir_text_printer.cc b/src/printer/tir_text_printer.cc index 107817db29b3..4b0871ae2ce6 100644 --- a/src/printer/tir_text_printer.cc +++ b/src/printer/tir_text_printer.cc @@ -465,18 +465,21 @@ Doc TIRTextPrinter::VisitStmt_(const EvaluateNode* op) { return doc; } -inline const char* ForType2String(ForType t) { +inline const char* ForKind2String(ForKind t) { switch (t) { - case ForType::Serial: + case ForKind::kSerial: return "serial"; - case ForType::Parallel: + case ForKind::kParallel: return "parallel"; - case ForType::Vectorized: + case ForKind::kVectorized: return "vectorized"; - case ForType::Unrolled: + case ForKind::kUnrolled: return "unroll"; + case ForKind::kThreadBinding: + LOG(FATAL) << "Loop ThreadBinding is reserved for future used and " + << "not yet supported in TIR"; } - LOG(FATAL) << "Unknown ForType"; + LOG(FATAL) << "Unknown ForKind"; return "Unknown"; } @@ -484,8 +487,8 @@ Doc TIRTextPrinter::VisitStmt_(const ForNode* op) { Doc doc; doc << "for (" << Print(op->loop_var) << ", " << Print(op->min) << ", " << Print(op->min + op->extent) << ")"; - if (op->for_type != ForType::Serial) { - doc << " " << Doc::StrLiteral(ForType2String(op->for_type)); + if (op->kind != ForKind::kSerial) { + doc << " " << Doc::StrLiteral(ForKind2String(op->kind)); } doc << PrintBody(op->body); return doc; diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 09f95e44b6d8..86b175e1676c 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -649,27 +649,30 @@ Doc TVMScriptPrinter::VisitStmt_(const EvaluateNode* op) { return doc; } -inline const char* ForType2String(ForType t) { +inline const char* ForKind2String(ForKind t) { switch (t) { - case ForType::Serial: + case ForKind::kSerial: return "serial"; - case ForType::Parallel: + case ForKind::kParallel: return "parallel"; - case ForType::Vectorized: + case ForKind::kVectorized: return "vectorized"; - case ForType::Unrolled: + case ForKind::kUnrolled: return "unroll"; + case ForKind::kThreadBinding: + LOG(FATAL) << "Loop ThreadBinding is reserved for future used and " + << "not yet supported in TIR"; + return "threadbinding"; } - LOG(FATAL) << "Unknown ForType"; + LOG(FATAL) << "Unknown ForKind"; return "Unknown"; } Doc TVMScriptPrinter::VisitStmt_(const ForNode* op) { Doc doc; var_not_in_headers.insert(op->loop_var.get()); - doc << "for " << Print(op->loop_var) - << " in tir." + std::string(ForType2String(op->for_type)) + "(" << Print(op->min) << ", " - << Print(op->min + op->extent) + doc << "for " << Print(op->loop_var) << " in tir." + std::string(ForKind2String(op->kind)) + "(" + << Print(op->min) << ", " << Print(op->min + op->extent) << "):" << Doc::Indent(4, Doc::NewLine() << PrintBody(op->body)); return doc; } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index 6143e7050495..e2a8553199f0 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -976,12 +976,13 @@ void CodeGenCPU::VisitStmt_(const AttrStmtNode* op) { void CodeGenCPU::VisitStmt_(const ForNode* op) { ICHECK(is_zero(op->min)); - if (op->for_type == ForType::Serial || op->for_type == ForType::Unrolled) { + if (op->kind == ForKind::kSerial || op->kind == ForKind::kUnrolled) { CodeGenLLVM::VisitStmt_(op); - } else if (op->for_type == ForType::Parallel) { + } else if (op->kind == ForKind::kParallel) { if (parallel_env_.penv == nullptr) { - CreateParallelLaunch( - For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), 0); + CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->kind, op->body, + op->thread_binding, op->annotations), + 0); } else { // already in parallel env. ICHECK(parallel_env_.task_id.defined()); @@ -1007,7 +1008,7 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) { ++parallel_env_.parallel_loop_count; } } else { - LOG(FATAL) << "cannot handle for type " << op->for_type; + LOG(FATAL) << "cannot handle for type " << op->kind; } } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 34f3897cce88..1dd76f6b9d51 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1318,11 +1318,11 @@ void CodeGenLLVM::VisitStmt_(const StoreNode* op) { void CodeGenLLVM::VisitStmt_(const ForNode* op) { ICHECK(is_zero(op->min)); analyzer_->Bind(op->loop_var, Range::FromMinExtent(op->min, op->extent)); - if (op->for_type == ForType::Unrolled) { + if (op->kind == ForKind::kUnrolled) { LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, " << " consider set unroll_explicit=True"; } else { - ICHECK(op->for_type == ForType::Serial); + ICHECK(op->kind == ForKind::kSerial); } CreateSerialFor(MakeValue(op->min), MakeValue(op->extent), llvm::ConstantInt::getSigned(GetLLVMType(op->extent), 1), op->loop_var, op->body); diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 6c73716edc18..e5547315613f 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -84,7 +84,7 @@ std::string CodeGenCUDA::Finish() { void CodeGenCUDA::VisitStmt_(const tir::ForNode* op) { ICHECK(is_const_int(op->min, 0)); - if (op->for_type == tir::ForType::Unrolled) { + if (op->kind == tir::ForKind::kUnrolled) { PrintIndent(); stream << "#pragma unroll\n"; } diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index c3b12ab943c6..51d136d5510e 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -492,7 +492,7 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { loop_var.SetIncoming(0, init_value, init_label); spirv::Value loop_cond = builder_->LT(loop_var, extent_value); uint32_t control = - (op->for_type == ForType::Unrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone); + (op->kind == ForKind::kUnrolled ? spv::LoopControlUnrollMask : spv::LoopControlMaskNone); builder_->MakeInst(spv::OpLoopMerge, merge_label, continue_label, control); builder_->MakeInst(spv::OpBranchConditional, loop_cond, body_label, merge_label, weight_likely_branch_, 1); diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 94e06d206ddb..65b8660ca1fb 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -234,9 +234,9 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_mapextent - inner)); ret = IfThenElse(cond, ret); ret = For(inner->var, PrimExpr(0), inner->dom->extent, - IterVarTypeToForType(inner->iter_type), op->device_api, ret); + IterVarTypeToForKind(inner->iter_type), ret); ret = For(outer->var, PrimExpr(0), outer->dom->extent, - IterVarTypeToForType(outer->iter_type), op->device_api, ret); + IterVarTypeToForKind(outer->iter_type), ret); splitted = true; return ret; } @@ -277,8 +277,8 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_maploop_var.get()] = indexdiv(parent, extent); body = tir::Substitute(body, rmap); under_outer = false; - return For(parent->var, PrimExpr(0), extent * op->extent, op->for_type, op->device_api, - body); + return For(parent->var, PrimExpr(0), extent * op->extent, op->kind, body, + op->thread_binding, op->annotations); } else if (under_outer) { Stmt body = this->VisitStmt(op->body); std::unordered_map rmap; @@ -331,8 +331,8 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_mapbody, rmap); return AttrStmt(iter_var, "thread_extent", op->extent, body); } else { - return For(op->loop_var, op->min, op->extent, IterVarTypeToForType(attr->iter_type), - op->device_api, op->body); + return For(op->loop_var, op->min, op->extent, IterVarTypeToForKind(attr->iter_type), + op->body, op->thread_binding, op->annotations); } } return StmtMutator::VisitStmt_(op); @@ -345,18 +345,18 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_mapsecond : iter_var; const VarNode* var = actual->var.get(); - ForType expected = IterVarTypeToForType(iter_var->iter_type); + ForKind expected = IterVarTypeToForKind(iter_var->iter_type); IterVarAttr attr; if (stage->iter_var_attrs.count(iter_var)) { attr = stage->iter_var_attrs[iter_var]; - expected = IterVarTypeToForType(attr->iter_type); + expected = IterVarTypeToForKind(attr->iter_type); } PostOrderVisit(stmt, [&found, &var, &attr, &expected, &need_change](const ObjectRef& node) { if (const ForNode* op = node.as()) { if (op->loop_var.get() == var) { ++found; - need_change = expected != op->for_type || (attr.defined() && attr->bind_thread.defined()); + need_change = expected != op->kind || (attr.defined() && attr->bind_thread.defined()); } } }); @@ -409,12 +409,13 @@ Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map if (body_.same_as(op->body) && op->loop_var.get() == target->var.get()) return GetRef(op); const Stmt& body = op->body.same_as(body_) ? op->body : body_; - ForType for_type = IterVarTypeToForType(target->iter_type); + ForKind kind = IterVarTypeToForKind(target->iter_type); if (stage->iter_var_attrs.count(target)) { - for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type); + kind = IterVarTypeToForKind(stage->iter_var_attrs[target]->iter_type); } const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second; - return For(target->var, range->min, range->extent, for_type, DeviceAPI::None, body); + return For(target->var, range->min, range->extent, kind, body, op->thread_binding, + op->annotations); } }; @@ -448,7 +449,7 @@ std::vector GatherLoopVars(Stmt stmt) { if (const ForNode* op = node.as()) { Var loop_var(op->loop_var); Range dom = Range::FromMinExtent(op->min, op->extent); - res_.push_back(IterVar(dom, loop_var, ForTypeToIterVarType(op->for_type))); + res_.push_back(IterVar(dom, loop_var, ForKindToIterVarType(op->kind))); } }); std::reverse(res_.begin(), res_.end()); diff --git a/src/te/operation/op_utils.cc b/src/te/operation/op_utils.cc index f1991c181e67..32ffccbbec1f 100644 --- a/src/te/operation/op_utils.cc +++ b/src/te/operation/op_utils.cc @@ -77,7 +77,7 @@ std::vector > MakeLoopNest(const Stage& stage, var = Var(iv->var->name_hint + ".init", bind_iv->var.dtype()); } - ForType for_type = ForType::Serial; + ForKind kind = ForKind::kSerial; IterVarAttr it_attr; if (stage->iter_var_attrs.count(iv)) { it_attr = stage->iter_var_attrs[iv]; @@ -85,13 +85,13 @@ std::vector > MakeLoopNest(const Stage& stage, if (it_attr.defined()) { switch (it_attr->iter_type) { case kUnrolled: - for_type = ForType::Unrolled; + kind = ForKind::kUnrolled; break; case kVectorized: - for_type = ForType::Vectorized; + kind = ForKind::kVectorized; break; case kParallelized: - for_type = ForType::Parallel; + kind = ForKind::kParallel; break; case kDataPar: break; @@ -115,11 +115,11 @@ std::vector > MakeLoopNest(const Stage& stage, nest[i + 1].emplace_back(LetStmt(var, cast(var.dtype(), dom->min), no_op)); value_map[iv] = cast(var.dtype(), dom->min); } else if (is_zero(dom->min)) { - nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, DeviceAPI::None, no_op)); + nest[i + 1].emplace_back(For(var, 0, dom->extent, kind, no_op)); value_map[iv] = var; } else { Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype()); - nest[i + 1].emplace_back(For(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op)); + nest[i + 1].emplace_back(For(idx, 0, dom->extent, kind, no_op)); PrimExpr new_value = dom->min + idx; value_map[iv] = new_value; nest[i + 1].emplace_back(LetStmt(var, new_value, no_op)); @@ -243,33 +243,33 @@ Stmt Substitute(Stmt s, const std::unordered_map& value_map) return tir::Substitute(s, init); } -IterVarType ForTypeToIterVarType(tir::ForType for_type) { - switch (for_type) { - case ForType::Serial: +IterVarType ForKindToIterVarType(tir::ForKind kind) { + switch (kind) { + case ForKind::kSerial: return kDataPar; - case ForType::Parallel: + case ForKind::kParallel: return kParallelized; - case ForType::Vectorized: + case ForKind::kVectorized: return kVectorized; - case ForType::Unrolled: + case ForKind::kUnrolled: return kUnrolled; default: return kDataPar; } } -tir::ForType IterVarTypeToForType(IterVarType iter_type) { +tir::ForKind IterVarTypeToForKind(IterVarType iter_type) { switch (iter_type) { case kDataPar: - return ForType::Serial; + return ForKind::kSerial; case kParallelized: - return ForType::Parallel; + return ForKind::kParallel; case kVectorized: - return ForType::Vectorized; + return ForKind::kVectorized; case kUnrolled: - return ForType::Unrolled; + return ForKind::kUnrolled; default: - return ForType::Serial; + return ForKind::kSerial; } } diff --git a/src/te/operation/op_utils.h b/src/te/operation/op_utils.h index 16f7d96cfa77..e6bf2caae6e0 100644 --- a/src/te/operation/op_utils.h +++ b/src/te/operation/op_utils.h @@ -88,16 +88,16 @@ PrimExpr ReplaceTensor(PrimExpr expr, const std::unordered_map& Stmt Substitute(Stmt stmt, const std::unordered_map& value_map); /*! - * \brief Converts Halide ForType to its corresponding IterVarType - * \param for_type The ForType to be converted + * \brief Converts Halide ForKind to its corresponding IterVarType + * \param kind The ForKind to be converted */ -IterVarType ForTypeToIterVarType(tir::ForType for_type); +IterVarType ForKindToIterVarType(tir::ForKind kind); /*! - * \brief Converts IterVarType to its corresponding Halide ForType + * \brief Converts IterVarType to its corresponding Halide ForKind * \param iter_type The IterVarType to be converted */ -tir::ForType IterVarTypeToForType(IterVarType iter_type); +tir::ForKind IterVarTypeToForKind(IterVarType iter_type); } // namespace te } // namespace tvm diff --git a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc index f81d72e0fe02..74d1a19d2cfe 100644 --- a/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc +++ b/src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc @@ -968,7 +968,8 @@ class TensorCoreIRMutator : public StmtExprMutator { scaled_extent_value = ori_extent_value / scale_factor; } PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value); - stmt = For(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api, op->body); + stmt = For(op->loop_var, op->min, scaled_extent, op->kind, op->body, op->thread_binding, + op->annotations); } } return stmt; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index fd03046376f8..92dc38797544 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -128,8 +128,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); // For -For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api, - Stmt body, Span span) { +For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body, + Optional thread_binding, Map annotations, Span span) { ICHECK(min.defined()); ICHECK(extent.defined()); ICHECK(min.dtype().is_scalar()); @@ -141,36 +141,40 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAP node->loop_var = std::move(loop_var); node->min = std::move(min); node->extent = std::move(extent); - node->for_type = for_type; - node->device_api = device_api; + node->kind = kind; node->body = std::move(body); + node->thread_binding = std::move(thread_binding); + node->annotations = std::move(annotations); node->span = std::move(span); data_ = std::move(node); } -TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent, - int for_type, int device_api, Stmt body, - Span span) { - return For(loop_var, min, extent, static_cast(for_type), - static_cast(device_api), body, span); -}); +TVM_REGISTER_GLOBAL("tir.For").set_body_typed( + [](Var loop_var, PrimExpr min, PrimExpr extent, int kind, Stmt body, + Optional thread_binding, Optional> annotations, Span span) { + return For(loop_var, min, extent, static_cast(kind), body, thread_binding, + annotations.value_or(Map()), span); + }); TVM_REGISTER_NODE_TYPE(ForNode); -std::ostream& operator<<(std::ostream& out, ForType type) { // NOLINT(*) +std::ostream& operator<<(std::ostream& out, ForKind type) { // NOLINT(*) switch (type) { - case ForType::Serial: + case ForKind::kSerial: out << "for"; break; - case ForType::Parallel: + case ForKind::kParallel: out << "parallel"; break; - case ForType::Unrolled: + case ForKind::kUnrolled: out << "unrolled"; break; - case ForType::Vectorized: + case ForKind::kVectorized: out << "vectorized"; break; + case ForKind::kThreadBinding: + out << "launch_thread"; + break; } return out; } @@ -179,7 +183,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { auto* op = static_cast(node.get()); p->PrintIndent(); - p->stream << op->for_type << " (" << op->loop_var << ", "; + p->stream << op->kind << " (" << op->loop_var << ", "; p->Print(op->min); p->stream << ", "; p->Print(op->extent); diff --git a/src/tir/transforms/combine_context_call.cc b/src/tir/transforms/combine_context_call.cc index 03a0d5e751cf..4a3986460b15 100644 --- a/src/tir/transforms/combine_context_call.cc +++ b/src/tir/transforms/combine_context_call.cc @@ -72,7 +72,7 @@ class ContextCallCombiner final : public StmtExprMutator { } Stmt VisitStmt_(const ForNode* op) final { - if (op->for_type == ForType::Parallel) { + if (op->kind == ForKind::kParallel) { // Map of comparison expression to variable std::unordered_map temp; std::swap(temp, ctx_map_); diff --git a/src/tir/transforms/inject_double_buffer.cc b/src/tir/transforms/inject_double_buffer.cc index 22a6ca23c24c..7a16c06d8058 100644 --- a/src/tir/transforms/inject_double_buffer.cc +++ b/src/tir/transforms/inject_double_buffer.cc @@ -158,8 +158,7 @@ class DoubleBufferInjector : public StmtExprMutator { vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i); loop_seq.emplace_back(Substitute(old_loop->body, vmap)); } - Stmt loop = For(outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api, - SeqStmt::Flatten(loop_seq)); + Stmt loop = For(outer_var, zero, outer_ext, old_loop->kind, SeqStmt::Flatten(loop_seq)); // tail std::vector tail_seq; Stmt tail_body = StripDoubleBufferWrite()(old_loop->body); diff --git a/src/tir/transforms/inject_prefetch.cc b/src/tir/transforms/inject_prefetch.cc index b5c4cf5ec582..4ce9c7639b77 100644 --- a/src/tir/transforms/inject_prefetch.cc +++ b/src/tir/transforms/inject_prefetch.cc @@ -71,11 +71,11 @@ class PrefetchInjector : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { auto& var = op->loop_var; loop_nest_.push_back(var); - if (op->for_type == ForType::Vectorized) { + if (op->kind == ForKind::kVectorized) { vectorized_[var.get()] = IntSet::Interval(op->min, (op->min + op->extent) - 1); } Stmt ret = StmtMutator::VisitStmt_(op); - if (op->for_type == ForType::Vectorized) { + if (op->kind == ForKind::kVectorized) { vectorized_.erase(var.get()); } loop_nest_.pop_back(); diff --git a/src/tir/transforms/inject_virtual_thread.cc b/src/tir/transforms/inject_virtual_thread.cc index 5622d140a625..b24a0e95cd53 100644 --- a/src/tir/transforms/inject_virtual_thread.cc +++ b/src/tir/transforms/inject_virtual_thread.cc @@ -303,7 +303,10 @@ class VTInjector : public StmtExprMutator { if (extent.same_as(op->extent) && body.same_as(op->body)) { return GetRef(op); } else { - return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body); + auto n = CopyOnWrite(op); + n->extent = std::move(extent); + n->body = std::move(body); + return Stmt(n); } } // IfThenElse @@ -417,7 +420,7 @@ class VTInjector : public StmtExprMutator { Map values{{var_, idx}}; stmt = Substitute(stmt, values); return For(idx, make_zero(idx.dtype()), make_const(idx.dtype(), num_threads_), - ForType::Serial, DeviceAPI::None, stmt); + ForKind::kSerial, stmt); } } diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 033a2e093a2a..cbae3f95ec68 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -149,7 +149,8 @@ class IRConvertSSA final : public StmtExprMutator { Stmt stmt = StmtExprMutator::VisitStmt_(op); scope_[v.get()].pop_back(); op = stmt.as(); - return For(new_var, op->min, op->extent, op->for_type, op->device_api, op->body); + return For(new_var, op->min, op->extent, op->kind, op->body, op->thread_binding, + op->annotations); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/transforms/loop_partition.cc b/src/tir/transforms/loop_partition.cc index a104dbb029eb..f1d816f0baef 100644 --- a/src/tir/transforms/loop_partition.cc +++ b/src/tir/transforms/loop_partition.cc @@ -607,8 +607,8 @@ inline Stmt LoopPartitioner::MakeFor(const Object* node, PrimExpr extent, Stmt b // If the loop extent is 1, do not create the loop anymore return Substitute(body, {{Var{for_node->loop_var}, make_const(DataType::Int(32), 0)}}); } else { - return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->for_type, - for_node->device_api, body); + ICHECK(for_node->kind != ForKind::kThreadBinding); + return For(for_node->loop_var, IntImm(for_node->min.dtype(), 0), extent, for_node->kind, body); } } diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index adbe78a6d627..0946af6f640a 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -46,9 +46,9 @@ class ReturnRewriter : public StmtMutator { explicit ReturnRewriter(Var ret_var, Var ret_tcode) : ret_var_(ret_var), ret_tcode_(ret_tcode) {} Stmt VisitStmt_(const ForNode* node) override { - if (node->for_type == ForType::Parallel) in_parallel_ += 1; + if (node->kind == ForKind::kParallel) in_parallel_ += 1; Stmt ret = StmtMutator::VisitStmt_(node); - if (node->for_type == ForType::Parallel) in_parallel_ -= 1; + if (node->kind == ForKind::kParallel) in_parallel_ -= 1; return ret; } diff --git a/src/tir/transforms/narrow_datatype.cc b/src/tir/transforms/narrow_datatype.cc index 0b248959ec6e..dc34626205a1 100644 --- a/src/tir/transforms/narrow_datatype.cc +++ b/src/tir/transforms/narrow_datatype.cc @@ -220,8 +220,8 @@ class DataTypeRewriter : public StmtExprMutator { << ", but get " << s->GetTypeKey(); PrimExpr e = VisitExpr(op->loop_var); Var var = Downcast(e); - return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->for_type, - op->device_api, op->body); + return For(var, cast(var.dtype(), op->min), cast(var.dtype(), op->extent), op->kind, op->body, + op->thread_binding, op->annotations); } Stmt VisitStmt_(const AttrStmtNode* op) final { diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index d392866b3694..43fc1f1ec53f 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -318,14 +318,14 @@ class StorageFlattener : public StmtExprMutator { } for (int i = starts; i >= 0; --i) { if (i < starts) { - stmt = For(vars[i], 0, op->bounds[i]->extent, ForType::Serial, DeviceAPI::None, stmt); + stmt = For(vars[i], 0, op->bounds[i]->extent, ForKind::kSerial, stmt); } else { PrimExpr load = e.buffer.vload(e.RelIndex(args), e.buffer->dtype); PrimExpr address = Call(DataType::Handle(), builtin::address_of(), {load}); PrimExpr prefetch = Call(op->buffer->dtype, builtin::prefetch(), {address, 0, 3, 1}); stmt = Evaluate(prefetch); PrimExpr extent = (op->bounds[i]->extent - 1) / stride + 1; - stmt = For(vars[i], 0, extent, ForType::Serial, DeviceAPI::None, stmt); + stmt = For(vars[i], 0, extent, ForKind::kSerial, stmt); } } return stmt; diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index d4c5ca09650b..0b1429ca7efa 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -438,14 +438,14 @@ class StoragePlanRewriter : public StmtExprMutator { } } Stmt VisitStmt_(const ForNode* op) final { - ICHECK(op->for_type != ForType::Vectorized) << "VectorizeLoop before LiftStorageAlloc"; + ICHECK(op->kind != ForKind::kVectorized) << "VectorizeLoop before LiftStorageAlloc"; // remake all the allocation at the attach scope. if (attach_map_.count(op)) { auto& svec = attach_map_[op]; Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as(); - return For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, - MakeAttach(svec, op->body)); + return For(op->loop_var, op->min, op->extent, op->kind, MakeAttach(svec, op->body), + op->thread_binding, op->annotations); } else { return StmtExprMutator::VisitStmt_(op); } @@ -765,7 +765,7 @@ class StoragePlanRewriter : public StmtExprMutator { } } else if (s.stmt->IsInstance()) { const auto* op = static_cast(s.stmt); - if (op->for_type == ForType::Parallel) { + if (op->kind == ForKind::kParallel) { if (thread_scope_ == nullptr || thread_scope_ == op) { PlanNewScope(op); } diff --git a/src/tir/transforms/unroll_loop.cc b/src/tir/transforms/unroll_loop.cc index 71ad899273a6..c6e0b5c5f41e 100644 --- a/src/tir/transforms/unroll_loop.cc +++ b/src/tir/transforms/unroll_loop.cc @@ -100,13 +100,13 @@ class LoopUnroller : public StmtExprMutator { op = stmt.as(); int value = GetExtent(op); // condition for auto unroll - bool auto_unroll = (op->for_type == ForType::Serial && value >= 0 && normal_loop_depth_ == 0 && + bool auto_unroll = (op->kind == ForKind::kSerial && value >= 0 && normal_loop_depth_ == 0 && unroll_depth_ <= auto_max_depth_); auto_unroll = auto_unroll && (value * step_count_ <= auto_max_step_ || value <= auto_max_extent_); - if (op->for_type == ForType::Unrolled) { + if (op->kind == ForKind::kUnrolled) { ICHECK_GE(value, 0) << "Cannot unroll non-constant loop"; auto_unroll = true; } @@ -124,9 +124,9 @@ class LoopUnroller : public StmtExprMutator { return Unroll(op); } else { if (auto_unroll) { - if (op->for_type != ForType::Unrolled) { - return For(op->loop_var, op->min, op->extent, ForType::Unrolled, op->device_api, - op->body); + if (op->kind != ForKind::kUnrolled) { + return For(op->loop_var, op->min, op->extent, ForKind::kUnrolled, op->body, + op->thread_binding, op->annotations); } } return stmt; diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 239f42266b83..66f4ae329f69 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -352,7 +352,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorfor_type == ForType::Vectorized) { + if (op->kind == ForKind::kVectorized) { LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring..."; } ICHECK(is_zero(op->min)); @@ -365,7 +365,8 @@ class Vectorizer : public StmtMutator, public ExprFunctorextent) && body.same_as(op->body)) { return GetRef(op); } else { - return For(op->loop_var, op->min, extent, op->for_type, op->device_api, body); + return For(op->loop_var, op->min, extent, op->kind, body, op->thread_binding, + op->annotations); } } // IfThenElse @@ -436,7 +437,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorname_hint + ".s", var_->dtype); Map values{{var_, idx}}; stmt = Substitute(stmt, values); - return For(idx, 0, var_lanes_, ForType::Serial, DeviceAPI::None, stmt); + return For(idx, 0, var_lanes_, ForKind::kSerial, stmt); } // ProducerStore Stmt VisitStmt_(const ProducerStoreNode* op) final { @@ -525,7 +526,7 @@ class Vectorizer : public StmtMutator, public ExprFunctorfor_type == ForType::Vectorized) { + if (op->kind == ForKind::kVectorized) { ICHECK(is_zero(op->min)); auto* extent_as_int = op->extent.as(); if (!extent_as_int || extent_as_int->value < 1) { @@ -545,8 +546,8 @@ class VectorizeSkipper : public StmtMutator { Stmt VisitStmt_(const ForNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); - if (op->for_type == ForType::Vectorized) { - return For(op->loop_var, op->min, op->extent, ForType::Serial, op->device_api, op->body); + if (op->kind == ForKind::kVectorized) { + return For(op->loop_var, op->min, op->extent, ForKind::kSerial, op->body); } else { return stmt; } diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py index ca5df4af6a71..af06a038e1f7 100644 --- a/tests/python/unittest/test_arith_domain_touched.py +++ b/tests/python/unittest/test_arith_domain_touched.py @@ -31,14 +31,12 @@ def test_domain_touched(): i, 0, n, - 0, - 0, + tvm.tir.ForKind.SERIAL, tvm.tir.For( j, 0, m, - 0, - 0, + tvm.tir.ForKind.SERIAL, tvm.tir.BufferStore( a, tvm.tir.BufferLoad(b, [i - 1, j + 1]) + tvm.tir.BufferLoad(a, [i - 1, j - 1]), diff --git a/tests/python/unittest/test_runtime_module_based_interface.py b/tests/python/unittest/test_runtime_module_based_interface.py index 64f87fb3c561..51a587242ae3 100644 --- a/tests/python/unittest/test_runtime_module_based_interface.py +++ b/tests/python/unittest/test_runtime_module_based_interface.py @@ -547,8 +547,7 @@ def make_func(symbol): i, 0, n - 1, - 0, - 0, + tvm.tir.ForKind.SERIAL, tvm.tir.Store(Ab.data, tvm.tir.Load("float32", Ab.data, i) + 1, i + 1), ) return tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", symbol) diff --git a/tests/python/unittest/test_runtime_module_load.py b/tests/python/unittest/test_runtime_module_load.py index 7befed3bbcdd..38800e8de6ad 100644 --- a/tests/python/unittest/test_runtime_module_load.py +++ b/tests/python/unittest/test_runtime_module_load.py @@ -55,7 +55,11 @@ def save_object(names): i = te.var("i") # for i in 0 to n-1: stmt = tvm.tir.For( - i, 0, n - 1, 0, 0, tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1) + i, + 0, + n - 1, + tvm.tir.ForKind.SERIAL, + tvm.tir.Store(Ab.data, tvm.tir.Load(dtype, Ab.data, i) + 1, i + 1), ) mod = tvm.IRModule.from_expr( tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "main") diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py index e87767475ab2..a22fe10c1321 100644 --- a/tests/python/unittest/test_target_codegen_cuda.py +++ b/tests/python/unittest/test_target_codegen_cuda.py @@ -200,7 +200,7 @@ def test_cuda_shuffle(): def MyVectorize(): def vectorizer(op): - if op.for_type == tvm.tir.For.Vectorized: + if op.kind == tvm.tir.ForKind.VECTORIZED: four = tvm.tir.const(4, "int32") idx = tvm.tir.Ramp(thrx.var * four, tvm.tir.const(1, "int32"), 4) all_ones = tvm.tir.const(1, "int32x4") diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py index 4b67752367db..67c1f6bff429 100644 --- a/tests/python/unittest/test_target_codegen_llvm.py +++ b/tests/python/unittest/test_target_codegen_llvm.py @@ -761,7 +761,7 @@ def do_atomic_add(A): atomic_add_return = ib.allocate(A.dtype, (1,), name="atomic_add_return", scope="local") one = tvm.tir.const(1, A.dtype) A_ptr = ib.buffer_ptr(A) - with ib.for_range(0, n, name="i", for_type="parallel") as i: + with ib.for_range(0, n, name="i", kind="parallel") as i: atomic_add_return[0] = atomic_add( tvm.tir.call_intrin("handle", "tir.address_of", A_ptr[0]), one ) diff --git a/tests/python/unittest/test_target_codegen_static_init.py b/tests/python/unittest/test_target_codegen_static_init.py index 179e302984cc..b0c19dfcffeb 100644 --- a/tests/python/unittest/test_target_codegen_static_init.py +++ b/tests/python/unittest/test_target_codegen_static_init.py @@ -30,7 +30,7 @@ def test_static_callback(): cp = te.thread_axis((0, 1), "cop") finit = tvm.tir.StringImm("TVMBackendRunOnce") ib.scope_attr(cp, "coproc_uop_scope", finit) - with ib.for_range(0, n, "i", for_type="parallel") as i: + with ib.for_range(0, n, "i", kind="parallel") as i: A[i] = A[i] + 1 stmt = ib.get() diff --git a/tests/python/unittest/test_target_codegen_vm_basic.py b/tests/python/unittest/test_target_codegen_vm_basic.py index 26f1493c4ec1..9bbee76e2736 100644 --- a/tests/python/unittest/test_target_codegen_vm_basic.py +++ b/tests/python/unittest/test_target_codegen_vm_basic.py @@ -109,7 +109,7 @@ def test_vm_parallel(): i = te.size_var("i") ib = tvm.tir.ir_builder.create() A = ib.buffer_ptr(Ab) - with ib.for_range(0, n, "i", for_type="parallel") as i: + with ib.for_range(0, n, "i", kind="parallel") as i: A[i] = A[i] + 1 stmt = ib.get() mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt).with_attr("global_symbol", "test")) diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index 06d409933f1f..be9956529dcc 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -267,9 +267,9 @@ def looptype(a, b, c): iloop = ir[0] jloop = ir[1] kloop = ir[2] - assert iloop.for_type == tvm.tir.For.Parallel - assert jloop.for_type == tvm.tir.For.Vectorized - assert kloop.for_type == tvm.tir.For.Unrolled + assert iloop.kind == tvm.tir.ForKind.PARALLEL + assert jloop.kind == tvm.tir.ForKind.VECTORIZED + assert kloop.kind == tvm.tir.ForKind.UNROLLED func, ins, outs = run_and_check(looptype, [a, b, c]) run_and_check(func, ins, outs=outs) diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 2bf4ba51937e..2cc21dbce91d 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -142,7 +142,7 @@ def test_stmt_constructor(): assert isinstance(x, tvm.tir.AssertStmt) assert x.body == nop - x = tvm.tir.For(te.var("x"), 0, 10, 0, 0, nop) + x = tvm.tir.For(te.var("x"), 0, 10, tvm.tir.ForKind.SERIAL, nop) assert isinstance(x, tvm.tir.For) assert x.min.value == 0 assert x.extent.value == 10 diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index 4d57ed8ec366..bff60f70f53b 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -129,7 +129,7 @@ def test_basic(): def test_stmt(): x = tvm.tir.Evaluate(0) - tvm.tir.For(te.var("i"), 0, 1, tvm.tir.For.Serial, 0, x) + tvm.tir.For(te.var("i"), 0, 1, tvm.tir.ForKind.SERIAL, x) def test_dir(): diff --git a/tests/python/unittest/test_tir_transform_remove_no_op.py b/tests/python/unittest/test_tir_transform_remove_no_op.py index 2edb8cf980c2..8b7a16952af9 100644 --- a/tests/python/unittest/test_tir_transform_remove_no_op.py +++ b/tests/python/unittest/test_tir_transform_remove_no_op.py @@ -34,20 +34,17 @@ def test_remove_no_op(): i, 0, 4, - 0, - 0, + tvm.tir.ForKind.SERIAL, tvm.tir.For( j, 0, n, - 0, - 0, + tvm.tir.ForKind.SERIAL, tvm.tir.For( k, 0, m, - 0, - 0, + tvm.tir.ForKind.SERIAL, tvm.tir.IfThenElse((i * m + j + k < n), tvm.tir.Evaluate(m), tvm.tir.Evaluate(n)), ), ), @@ -65,7 +62,7 @@ def test_remove_no_op(): assert ret == store # remove zero extent loop - stmt3 = tvm.tir.For(i, 0, 0, 0, 0, store) + stmt3 = tvm.tir.For(i, 0, 0, tvm.tir.ForKind.SERIAL, store) mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([Ab], stmt3)) ret = tvm.tir.transform.RemoveNoOp()(mod)["main"].body assert isinstance(ret, tvm.tir.Evaluate) diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index cc2b4273a5e3..49adcfb568a7 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -269,7 +269,7 @@ def verify(n): def test_parallel_alloc(): ib = tvm.tir.ir_builder.create() n = te.var("n") - with ib.for_range(0, n, name="i", for_type="parallel") as i: + with ib.for_range(0, n, name="i", kind="parallel") as i: with ib.for_range(0, 10, name="j") as j: A = ib.allocate("float32", n, name="A", scope="global") A[j] = A[j] + 2 @@ -286,7 +286,7 @@ def test_parallel_alloc(): ib.scope_attr( tvm.tir.const(1, "int32"), "pragma_scope", tvm.tir.StringImm("parallel_launch_point") ) - with ib.for_range(0, n, name="i", for_type="parallel") as i: + with ib.for_range(0, n, name="i", kind="parallel") as i: with ib.for_range(0, 10, name="j") as j: A = ib.allocate("float32", n, name="A", scope="global") A[j] = A[j] + 2 diff --git a/tests/python/unittest/test_tir_transform_unroll_loop.py b/tests/python/unittest/test_tir_transform_unroll_loop.py index 57b7810198c0..b511118f8b52 100644 --- a/tests/python/unittest/test_tir_transform_unroll_loop.py +++ b/tests/python/unittest/test_tir_transform_unroll_loop.py @@ -27,7 +27,7 @@ def test_unroll_loop(): Aptr = ib.buffer_ptr(Ab) # for i in 0 to n-1: with ib.for_range(n, n + 2, name="i") as i: - with ib.for_range(0, 8, name="i", for_type="unroll") as j: + with ib.for_range(0, 8, name="i", kind="unroll") as j: Aptr[j + 1] = Aptr[i] + 1 stmt = ib.get() @@ -48,7 +48,7 @@ def test_unroll_loop(): ): ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body assert isinstance(ret, tvm.tir.For) - assert ret.for_type == tvm.tir.For.Unrolled + assert ret.kind == tvm.tir.ForKind.UNROLLED ib = tvm.tir.ir_builder.create() ib.scope_attr(tvm.tir.const(0, "int32"), "pragma_auto_unroll_max_step", 16) @@ -63,9 +63,9 @@ def test_unroll_loop(): ): ret = tvm.tir.transform.UnrollLoop()(mod)["main"].body assert isinstance(ret[0], tvm.tir.For) - assert ret[0].for_type == tvm.tir.For.Unrolled + assert ret[0].kind == tvm.tir.ForKind.UNROLLED assert isinstance(ret[1], tvm.tir.For) - assert ret[1].for_type != tvm.tir.For.Unrolled + assert ret[1].kind != tvm.tir.ForKind.UNROLLED def test_unroll_fake_loop(): diff --git a/tests/python/unittest/test_tir_transform_vectorize.py b/tests/python/unittest/test_tir_transform_vectorize.py index 204e26feb6a9..5ae47e01f681 100644 --- a/tests/python/unittest/test_tir_transform_vectorize.py +++ b/tests/python/unittest/test_tir_transform_vectorize.py @@ -24,7 +24,7 @@ def test_vectorize_loop(): ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") with ib.for_range(0, n) as i: - with ib.for_range(0, 4, for_type="vectorize") as j: + with ib.for_range(0, 4, kind="vectorize") as j: A[j] = tvm.tir.const(1, A.dtype) stmt = ib.get() @@ -45,7 +45,7 @@ def test_vectorize_vector(): ib = tvm.tir.ir_builder.create() A = ib.pointer("float32x4", name="A") with ib.for_range(0, n) as i: - with ib.for_range(0, 4, for_type="vectorize") as j: + with ib.for_range(0, 4, kind="vectorize") as j: A[j] = tvm.tir.const(1, A.dtype) stmt = ib.get() assert isinstance(stmt.body, tvm.tir.For) @@ -64,7 +64,7 @@ def test_vectorize_with_if(): x = te.var("x") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, for_type="vectorize") as i: + with ib.for_range(0, 4, kind="vectorize") as i: with ib.if_scope(x < n): A[i] = A[i] + 1 with ib.else_scope(): @@ -86,7 +86,7 @@ def test_vectorize_let(): v = tvm.tir.Var("v", "float32") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, for_type="vectorize") as i: + with ib.for_range(0, 4, kind="vectorize") as i: ib.emit(lambda body: tvm.tir.LetStmt(v, A[i] + 1, body)) A[i] = v + 2 @@ -100,7 +100,7 @@ def test_vectorize_with_le_cond(): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, for_type="vectorize") as i: + with ib.for_range(0, 4, kind="vectorize") as i: with ib.if_scope(i <= n): A[i] = A[i] + 1 stmt = ib.get() @@ -115,7 +115,7 @@ def test_vectorize_with_ge_cond(): n = te.var("n") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, for_type="vectorize") as i: + with ib.for_range(0, 4, kind="vectorize") as i: with ib.if_scope(i >= n): A[i] = A[i] + 1 stmt = ib.get() @@ -131,7 +131,7 @@ def test_vectorize_if_then_else(): x = te.var("x") ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") - with ib.for_range(0, 4, for_type="vectorize") as i: + with ib.for_range(0, 4, kind="vectorize") as i: A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else", i > 0, A[i] + 1, A[i]) stmt = ib.get() @@ -143,7 +143,7 @@ def test_vectorize_if_then_else(): ib = tvm.tir.ir_builder.create() A = ib.pointer("float32", name="A") with ib.for_range(0, n) as k: - with ib.for_range(0, 4, for_type="vectorize") as i: + with ib.for_range(0, 4, kind="vectorize") as i: A[k * 4 + i] = tvm.tir.call_intrin( "float32", "tir.if_then_else", k > 0, A[k * 4 + i], 0 ) diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py index 44fe59f99201..0bd656dd81dd 100644 --- a/tutorials/dev/low_level_custom_pass.py +++ b/tutorials/dev/low_level_custom_pass.py @@ -116,8 +116,8 @@ def vectorize8(op): name = op.loop_var.name lo, li = te.var(name + ".outer"), te.var(name + ".inner") body = tvm.tir.stmt_functor.substitute(op.body, {op.loop_var: lo * 8 + li}) - body = tvm.tir.For(li, 0, 8, tvm.tir.For.Vectorized, 0, body) - body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.For.Serial, 0, body) + body = tvm.tir.For(li, 0, 8, tvm.tir.ForKind.VECTORIZED, body) + body = tvm.tir.For(lo, 0, extent // 8, tvm.tir.ForKind.SERIAL, body) return body return None diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py index a485d2cfb7b8..9770857fb0b9 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -231,7 +231,13 @@ def _merge_block(slist, body): body = tvm.tir.AttrStmt(op.node, op.attr_key, op.value, body) elif isinstance(op, tvm.tir.For): body = tvm.tir.For( - op.loop_var, op.min, op.extent, op.for_type, op.device_api, body + op.loop_var, + op.min, + op.extent, + op.kind, + body, + op.thread_binding, + op.annotations, ) else: raise RuntimeError("unexpected op") @@ -314,7 +320,9 @@ def _do_fold(stmt): if _match_pragma(stmt, "trim_loop"): op = stmt.body assert isinstance(op, tvm.tir.For) - return tvm.tir.For(op.loop_var, op.min, 2, op.for_type, op.device_api, op.body) + return tvm.tir.For( + op.loop_var, op.min, 2, op.kind, op.body, op.thread_binding, op.annotations + ) return None return f.with_body(