Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR][REFACTOR] ForNode introduce thread binding and remove legacy field #7306

Merged
merged 4 commits into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 54 additions & 31 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -752,23 +752,34 @@ class Evaluate : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
};

/*! \brief Additional annotation of for loop. */
enum class ForType : int {
/*! \brief serial execution. */
Serial = 0,
/*! \brief parallel execution on CPU. */
Parallel = 1,
/*! \brief Vector SIMD loop annotation. */
Vectorized = 2,
/*! \brief Unroll annotation. */
Unrolled = 3
/*!
* \brief The kind of the loop.
*
* ForKind can change the control flow semantics
* of the loop. So the kind field needs to be considered
* in all TIR passes.
*/
enum class ForKind : int {
/*! \brief default semantics -- serial execution. */
kSerial = 0,
/*! \brief Parallel execution on CPU. */
kParallel = 1,
/*!
* \brief Vector SIMD loop.
* The loop body will be vectorized.
*/
kVectorized = 2,
/*! \brief The loop body must be unrolled. */
kUnrolled = 3,
/*!
* \brief The loop variable is bound to a thread in
* an environment. In the final stage of lowering,
* the loop is simply removed and the loop variable is
* mapped to the corresponding context thread.
*/
kThreadBinding = 4
};

// Device api of for loop
// kept for backward compatibility
// consider refactor and remove later.
enum class DeviceAPI : int { None = 0 };

/*!
* \brief A for loop, with possible type annotations.
*
Expand All @@ -787,39 +798,50 @@ class ForNode : public StmtNode {
PrimExpr min;
/*! \brief The extent of the iteration. */
PrimExpr extent;
/*! \brief The type of the for loop. */
ForType for_type;
/*!
* \brief Deprecated, reserved for backward compatibility.
* Consider refactor and remove later.
*/
DeviceAPI device_api;
/*! \brief The kind of the for loop. */
ForKind kind;
/*! \brief The body of the for loop. */
Stmt body;
/*!
* \brief Only valid when kind == ForKind::kThreadBinding
* The context thread that this loop variable is bound to.
*/
Optional<IterVar> thread_binding;
/*!
* \brief Additional annotations about the loop.
*
* These annotations can be used as auxiliary hint
* to future transformations. An annotation should
* not change the control flow semantics of the loop
* and can be ignored in most passes.
*/
Map<String, ObjectRef> annotations;

void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_var", &loop_var);
v->Visit("min", &min);
v->Visit("extent", &extent);
v->Visit("for_type", &for_type);
v->Visit("device_api", &device_api);
v->Visit("kind", &kind);
v->Visit("body", &body);
v->Visit("thread_binding", &thread_binding);
v->Visit("annotations", &annotations);
v->Visit("span", &span);
}

bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
equal(extent, other->extent) && equal(for_type, other->for_type) &&
equal(device_api, other->device_api) && equal(body, other->body);
equal(extent, other->extent) && equal(kind, other->kind) && equal(body, other->body) &&
equal(thread_binding, other->thread_binding) && equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(loop_var);
hash_reduce(min);
hash_reduce(extent);
hash_reduce(for_type);
hash_reduce(device_api);
hash_reduce(kind);
hash_reduce(body);
hash_reduce(thread_binding);
hash_reduce(annotations);
}

static constexpr const char* _type_key = "tir.For";
Expand All @@ -832,8 +854,9 @@ class ForNode : public StmtNode {
*/
class For : public Stmt {
public:
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api,
Stmt body, Span span = Span());
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForKind kind, Stmt body,
Optional<IterVar> thread_binding = NullOpt,
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
};
Expand Down Expand Up @@ -1015,7 +1038,7 @@ inline bool IsPragmaKey(const std::string& attr_key) {
TVM_DLL PrimExpr TypeAnnotation(DataType dtype, Span span = Span());

// overload printing of for type.
TVM_DLL std::ostream& operator<<(std::ostream& os, ForType for_type);
TVM_DLL std::ostream& operator<<(std::ostream& os, ForKind kind);

} // namespace tir
} // namespace tvm
Expand Down
8 changes: 4 additions & 4 deletions python/tvm/script/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def serial(begin, end, span):
self.context.report_error("Expect exact 1 loop var", span)
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
return tvm.tir.For(self.loop_vars[0], begin, extent, 0, 0, self.body, span=span)
return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body, span=span)

super().__init__(serial)

Expand All @@ -241,7 +241,7 @@ def parallel(begin, end, span):
self.context.report_error("Expect exact 1 loop var")
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
return tvm.tir.For(self.loop_vars[0], begin, extent, 1, 0, self.body, span=span)
return tvm.tir.For(self.loop_vars[0], begin, extent, 1, self.body, span=span)

super().__init__(parallel)

Expand All @@ -256,7 +256,7 @@ def vectorized(begin, end, span):
self.context.report_error("Expect exact 1 loop var")
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
return tvm.tir.For(self.loop_vars[0], begin, extent, 2, 0, self.body, span=span)
return tvm.tir.For(self.loop_vars[0], begin, extent, 2, self.body, span=span)

super().__init__(vectorized)

Expand All @@ -271,6 +271,6 @@ def unroll(begin, end, span):
self.context.report_error("Expect exact 1 loop var")
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
return tvm.tir.For(self.loop_vars[0], begin, extent, 3, 0, self.body, span=span)
return tvm.tir.For(self.loop_vars[0], begin, extent, 3, self.body, span=span)

super().__init__(unroll)
20 changes: 10 additions & 10 deletions python/tvm/te/hybrid/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,18 @@
from tvm.target import Target
from tvm.tir import expr as _expr
from tvm.tir import call_intrin
from tvm.tir.stmt import For
from tvm.tir.stmt import ForKind

from .utils import _internal_assert

# pylint: disable=redefined-builtin,invalid-name

LOOP_INTRIN = {
"range": For.Serial,
"unroll": For.Unrolled,
"parallel": For.Parallel,
"vectorize": For.Vectorized,
"const_range": (For.Unrolled,),
"range": ForKind.SERIAL,
"unroll": ForKind.UNROLLED,
"parallel": ForKind.PARALLEL,
"vectorize": ForKind.VECTORIZED,
"const_range": (ForKind.UNROLLED,),
}


Expand All @@ -48,9 +48,9 @@ def _range(annotation, args):
low, ext = args[0], args[1]
if not tvm.tir.analysis.expr_deep_equal(low, const(0, dtype="int32")):
ext = ext - low
for_type = LOOP_INTRIN[annotation]
kind = LOOP_INTRIN[annotation]
iter_var = None
return iter_var, low, ext, for_type
return iter_var, low, ext, kind


range = unroll = vectorize = parallel = const_range = _range # pylint: disable=invalid-name
Expand All @@ -63,8 +63,8 @@ def bind(func_id, args):
_internal_assert(isinstance(args[0], str), "A loop bind's first argument should be a string!")
low, ext = const(0, "int32"), args[1]
iter_var = tvm.te.thread_axis((low, ext), args[0])
for_type = None
return iter_var, low, ext, for_type
kind = None
return iter_var, low, ext, kind


def _math_intrin(func_id, args):
Expand Down
14 changes: 7 additions & 7 deletions python/tvm/te/hybrid/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,14 +480,14 @@ def visit_Call(self, node):
return op

def visit_For(self, node):
iter_var, low, ext, for_type = self.visit(node.iter)
iter_var, low, ext, kind = self.visit(node.iter)
_internal_assert(
isinstance(node.target, ast.Name), "The loop iterator should be a variable!"
)

_name = node.target.id

if isinstance(for_type, tuple):
if isinstance(kind, tuple):
low = self.analyzer.simplify(low)
ext = self.analyzer.simplify(ext)
_internal_assert(
Expand All @@ -511,28 +511,28 @@ def visit_For(self, node):
return concat_list_to_block(bodies)

if iter_var is None:
_internal_assert(for_type is not None, "The loop iterating function parse error!")
_internal_assert(kind is not None, "The loop iterating function parse error!")
offset = iter_var = tvm.te.var(_name)
if not tvm.tir.analysis.expr_deep_equal(low, tvm.runtime.const(0, "int32")):
offset = iter_var + low
self.add_symbol(_name, Symbol.LoopVar, offset)
_body = visit_list_to_block(self.visit, node.body)
else:
_internal_assert(for_type is None, "The loop bind function parse error!")
_internal_assert(kind is None, "The loop bind function parse error!")
self.add_symbol(_name, Symbol.ThreadBind, iter_var)
self.device += 1
_body = visit_list_to_block(self.visit, node.body)
self.device -= 1

_body = self.wrap_up_realize(node, _body)

if for_type is None:
if kind is None:
res = _body
else:
_internal_assert(
not isinstance(for_type, tuple), "Micro expansion should be handled before!"
not isinstance(kind, tuple), "Micro expansion should be handled before!"
)
res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext, for_type, 0, _body)
res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext, kind, _body)

self.symbols.pop(_name)
return res
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .expr import Select, BufferLoad, ProducerLoad, Load, Ramp, Broadcast, Shuffle
from .expr import Call, CallEffectKind, Let, IterVar, Any

from .stmt import Stmt, LetStmt, AssertStmt, For
from .stmt import Stmt, LetStmt, AssertStmt, ForKind, For
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
Expand Down
24 changes: 12 additions & 12 deletions python/tvm/tir/ir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def scope_attr(self, node, attr_key, value):
value = op.max(1, value)
self.emit(lambda x: _stmt.AttrStmt(node, attr_key, value, x))

def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
def for_range(self, begin, end, name="i", dtype="int32", kind="serial"):
"""Create a for iteration scope.

Parameters
Expand All @@ -224,7 +224,7 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
dtype : str, optional
The data type of iteration variable.

for_type : str, optional
kind : str, optional
The special tag on the for loop.

Returns
Expand All @@ -249,17 +249,17 @@ def for_range(self, begin, end, name="i", dtype="int32", for_type="serial"):
extent = end if begin == 0 else (end - begin)

def _exit_cb():
if for_type == "serial":
for_type_id = 0
elif for_type == "parallel":
for_type_id = 1
elif for_type == "vectorize":
for_type_id = 2
elif for_type == "unroll":
for_type_id = 3
if kind == "serial":
kind_id = _stmt.ForKind.SERIAL
elif kind == "parallel":
kind_id = _stmt.ForKind.PARALLEL
elif kind == "vectorize":
kind_id = _stmt.ForKind.VECTORIZED
elif kind == "unroll":
kind_id = _stmt.ForKind.UNROLLED
else:
raise ValueError("Unknown for_type")
self.emit(_stmt.For(loop_var, begin, extent, for_type_id, 0, self._pop_seq()))
raise ValueError("Unknown kind")
self.emit(_stmt.For(loop_var, begin, extent, kind_id, self._pop_seq()))

return WithScope(loop_var, _exit_cb)

Expand Down
Loading