Skip to content

Commit

Permalink
More style consistency refactor to make the ForNode
Browse files Browse the repository at this point in the history
to be consistent with rest of the codebase.

- ForType => ForKind
- Add constant prefix k to enum consts per Google C style
- Introduce ForKind to the python side.
  • Loading branch information
tqchen committed Jan 18, 2021
1 parent 838de68 commit 307386f
Show file tree
Hide file tree
Showing 56 changed files with 252 additions and 249 deletions.
42 changes: 18 additions & 24 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -753,38 +753,33 @@ class Evaluate : public Stmt {
};

/*!
* \brief The type of the loop.
* \brief The kind of the loop.
*
* ForType can change the control flow semantics
* of the loop. So the for_type field needs to be considered
* ForKind can change the control flow semantics
* of the loop. So the kind field needs to be considered
* in all TIR passes.
*/
enum class ForType : int {
enum class ForKind : int {
/*! \brief default semantics -- serial execution. */
Serial = 0,
kSerial = 0,
/*! \brief Parallel execution on CPU. */
Parallel = 1,
kParallel = 1,
/*!
* \brief Vector SIMD loop.
* The loop body will be vectorized.
*/
Vectorized = 2,
kVectorized = 2,
/*! \brief The loop body must be unrolled. */
Unrolled = 3,
kUnrolled = 3,
/*!
* \brief The loop variable is bound to a thread in
* an environment. In the final stage of lowering,
* the loop is simply removed and the loop variable is
* mapped to the corresponding context thread.
*/
ThreadBinding = 4
kThreadBinding = 4
};

// Device api of for loop
// kept for backward compatibility
// consider refactor and remove later.
enum class DeviceAPI : int { None = 0 };

/*!
* \brief A for loop, with possible type annotations.
*
Expand All @@ -803,12 +798,12 @@ class ForNode : public StmtNode {
PrimExpr min;
/*! \brief The extent of the iteration. */
PrimExpr extent;
/*! \brief The type of the for loop. */
ForType for_type;
/*! \brief The kind of the for loop. */
ForKind kind;
/*! \brief The body of the for loop. */
Stmt body;
/*!
* \brief Only valid when for_type == ForType::ThreadBinding
* \brief Only valid when kind == ForKind::kThreadBinding
* The context thread that this loop variable bounds to.
*/
Optional<IterVar> thread_binding;
Expand All @@ -826,7 +821,7 @@ class ForNode : public StmtNode {
v->Visit("loop_var", &loop_var);
v->Visit("min", &min);
v->Visit("extent", &extent);
v->Visit("for_type", &for_type);
v->Visit("kind", &kind);
v->Visit("body", &body);
v->Visit("thread_binding", &thread_binding);
v->Visit("annotations", &annotations);
Expand All @@ -835,16 +830,15 @@ class ForNode : public StmtNode {

bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
equal(extent, other->extent) && equal(for_type, other->for_type) &&
equal(body, other->body) && equal(thread_binding, other->thread_binding) &&
equal(annotations, other->annotations);
equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) &&
equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(loop_var);
hash_reduce(min);
hash_reduce(extent);
hash_reduce(for_type);
hash_reduce(kind);
hash_reduce(body);
hash_reduce(thread_binding);
hash_reduce(annotations);
Expand All @@ -860,7 +854,7 @@ class ForNode : public StmtNode {
*/
class For : public Stmt {
public:
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, Stmt body,
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
Optional<IterVar> thread_binding = NullOpt,
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());

Expand Down Expand Up @@ -1044,7 +1038,7 @@ inline bool IsPragmaKey(const std::string& attr_key) {
TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());

// overload printing of for kind.
TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type);
TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);

} // namespace tir
} // namespace tvm
Expand Down
20 changes: 10 additions & 10 deletions python/tvm/te/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@
from tvm.target import Target
from tvm.tir import expr as _expr
from tvm.tir import call_intrin
from tvm.tir.stmt import For
from tvm.tir.stmt import ForKind

from .utils import _internal_assert

# pylint: disable=redefined-builtin,invalid-name

LOOP_INTRIN = {
"range": For.Serial,
"unroll": For.Unrolled,
"parallel": For.Parallel,
"vectorize": For.Vectorized,
"const_range": (For.Unrolled,),
"range": ForKind.SERIAL,
"unroll": ForKind.UNROLLED,
"parallel": ForKind.PARALLEL,
"vectorize": ForKind.VECTORIZED,
"const_range": (ForKind.UNROLLED,),
}


Expand All @@ -48,9 +48,9 @@ def _range(annotation, args):
low, ext = args[0], args[1]
if not tvm.tir.analysis.expr_deep_equal(low, const(0, dtype="int32")):
ext = ext - low
for_type = LOOP_INTRIN[annotation]
kind = LOOP_INTRIN[annotation]
iter_var = None
return iter_var, low, ext, for_type
return iter_var, low, ext, kind


range = unroll = vectorize = parallel = const_range = _range # pylint: disable=invalid-name
Expand All @@ -63,8 +63,8 @@ def bind(func_id, args):
_internal_assert(isinstance(args[0], str), "A loop bind's first argument should be a string!")
low, ext = const(0, "int32"), args[1]
iter_var = tvm.te.thread_axis((low, ext), args[0])
for_type = None
return iter_var, low, ext, for_type
kind = None
return iter_var, low, ext, kind


def _math_intrin(func_id, args):
Expand Down
14 changes: 7 additions & 7 deletions python/tvm/te/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,14 +480,14 @@ def visit_Call(self, node):
return op

def visit_For(self, node):
iter_var, low, ext, for_type = self.visit(node.iter)
iter_var, low, ext, kind = self.visit(node.iter)
_internal_assert(
isinstance(node.target, ast.Name), "The loop iterator should be a variable!"
)

_name = node.target.id

if isinstance(for_type, tuple):
if isinstance(kind, tuple):
low = self.analyzer.simplify(low)
ext = self.analyzer.simplify(ext)
_internal_assert(
Expand All @@ -511,28 +511,28 @@ def visit_For(self, node):
return concat_list_to_block(bodies)

if iter_var is None:
_internal_assert(for_type is not None, "The loop iterating function parse error!")
_internal_assert(kind is not None, "The loop iterating function parse error!")
offset = iter_var = tvm.te.var(_name)
if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, "int32")):
offset = iter_var + low
self.add_symbol(_name, Symbol.LoopVar, offset)
_body = visit_list_to_block(self.visit, node.body)
else:
_internal_assert(for_type is None, "The loop bind function parse error!")
_internal_assert(kind is None, "The loop bind function parse error!")
self.add_symbol(_name, Symbol.ThreadBind, iter_var)
self.device += 1
_body = visit_list_to_block(self.visit, node.body)
self.device -= 1

_body = self.wrap_up_realize(node, _body)

if for_type is None:
if kind is None:
res = _body
else:
_internal_assert(
not isinstance(for_type, tuple), "Micro expansion should be handled before!"
not isinstance(kind, tuple), "Micro expansion should be handled before!"
)
res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext, for_type, _body)
res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext, kind, _body)

self.symbols.pop(_name)
return res
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
from .expr import Call, CallEffectKind, Let, IterVar, Any

from .stmt import Stmt, LetStmt, AssertStmt, For
from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
Expand Down
24 changes: 12 additions & 12 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def scope_attr(self, node, attr_key, value):
value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))

def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
"""Create a for iteration scope.
Parameters
Expand All @@ -224,7 +224,7 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
dtype : str, optional
The data type of iteration variable.
for_type : str, optional
kind : str, optional
The special tag on the for loop.
Returns
Expand All @@ -249,17 +249,17 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
extent = end if begin == 0 else (end - begin)

def _exit_cb():
if for_type == "serial":
for_type_id = 0
elif for_type == "parallel":
for_type_id = 1
elif for_type == "vectorize":
for_type_id = 2
elif for_type == "unroll":
for_type_id = 3
if kind == "serial":
kind_id = 0
elif kind == "parallel":
kind_id = 1
elif kind == "vectorize":
kind_id = 2
elif kind == "unroll":
kind_id = 3
else:
raise ValueError("Unknown for_type")
self.emit(_stmt.For(loop_var, begin, extent, for_type_id, self._pop_seq()))
raise ValueError("Unknown kind")
self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq()))

return WithScope(loop_var, _exit_cb)

Expand Down
33 changes: 22 additions & 11 deletions python/tvm/tir/stmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
assert isinstance(st, tvm.tir.stmt.Store)
assert(st.buffer_var == a)
"""
from enum import IntEnum
import tvm._ffi

from tvm.runtime import Object
Expand Down Expand Up @@ -82,6 +83,22 @@ def __init__(self, condition, message, body, span=None):
self.__init_handle_by_constructor__(_ffi_api.AssertStmt, condition, message, body, span)


class ForKind(IntEnum):
"""The kind of the for loop.
Notes
-----
ForKind can change the control flow semantics
of the loop and need to be considered in all TIR passes.
"""

SERIAL = 0
PARALLEL = 1
VECTORIZED = 2
UNROLLED = 3
THREAD_BINDING = 4


@tvm._ffi.register_object("tir.For")
class For(Stmt):
"""For node.
Expand All @@ -97,15 +114,15 @@ class For(Stmt):
extent : PrimExpr
The length of the loop.
for_type : int
The for type.
kind : ForKind
The kind of the for loop.
body : Stmt
The body statement.
thread_binding: Optional[tir.IterVar]
The thread this loop binds to. Only valid
if for_type is ThreadBinding
if kind is ThreadBinding
annotations: tvm.ir.Map
Additional annotation hints.
Expand All @@ -114,18 +131,12 @@ class For(Stmt):
The location of this itervar in the source code.
"""

Serial = 0
Parallel = 1
Vectorized = 2
Unrolled = 3
ThreadBiding = 4

def __init__(
self,
loop_var,
min_val,
extent,
for_type,
kind,
body,
thread_binding=None,
annotations=None,
Expand All @@ -136,7 +147,7 @@ def __init__(
loop_var,
min_val,
extent,
for_type,
kind,
body,
thread_binding,
annotations,
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/topi/cuda/nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
j = bx * max_threads + tx
with ib.if_scope(j < nkeep):
src_idx = base_src_idx + sorted_index[i * num_anchors + j] * box_data_length
with ib.for_range(0, 4, for_type="unroll") as k:
with ib.for_range(0, 4, kind="unroll") as k:
out_bboxes[(base_bbox_idx + j * 4 + k)] = data[src_idx + coord_start + k]

out_scores[i * num_anchors + j] = data[src_idx + score_index]
Expand All @@ -593,7 +593,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
# Only needed for return_indices = False case
if return_indices is False:
with ib.if_scope(j < num_anchors):
with ib.for_range(0, 4, for_type="unroll") as k:
with ib.for_range(0, 4, kind="unroll") as k:
out_bboxes[(base_bbox_idx + j * 4 + k)] = -1.0

out_scores[i, j] = -1.0
Expand All @@ -609,7 +609,7 @@ def calculate_overlap(out_tensor, box_a_idx, box_b_idx):
with ib.if_scope(j < valid_count[i]):
src_offset = base_src_idx + j * box_data_length

with ib.for_range(0, 4, for_type="unroll") as k:
with ib.for_range(0, 4, kind="unroll") as k:
out_bboxes[base_bbox_idx + j * 4 + k] = data[src_offset + coord_start + k]
out_scores[i * num_anchors + j] = data[src_offset + score_index]

Expand Down Expand Up @@ -855,7 +855,7 @@ def ir(out_bboxes, out_scores, out_class_ids, out):
i = by

with ib.if_scope(tid < num_anchors):
with ib.for_range(0, 4, for_type="unroll") as j:
with ib.for_range(0, 4, kind="unroll") as j:
out[i, tid, coord_start + j] = out_bboxes[i, tid, j]
out[i, tid, score_index] = out_scores[i, tid]
if id_index >= 0:
Expand Down
Loading

0 comments on commit 307386f

Please sign in to comment.