[TIR][REFACTOR] ForNode update
- Remove deprecated device_api.
- Add ThreadBinding for_type.
- Add additional annotations.
tqchen committed Jan 18, 2021
1 parent 052ad3d commit 838de68
Showing 32 changed files with 167 additions and 85 deletions.
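For orientation, a minimal sketch of the constructor change as it surfaces through the Python frontend touched below; the variable names are illustrative only:

    import tvm
    from tvm import tir

    i = tir.Var("i", "int32")
    body = tir.Evaluate(0)          # placeholder loop body

    # Before: tir.For(i, 0, 128, 0, 0, body)  -- device_api was a positional argument.
    # After: device_api is gone; thread_binding and annotations are optional keywords.
    loop = tir.For(i, 0, 128, 0, body)          # for_type 0 == Serial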
61 changes: 45 additions & 16 deletions include/tvm/tir/stmt.h
@@ -752,16 +752,32 @@ class Evaluate : public Stmt {
TVM_DEFINE_OBJECT_REF_METHODS(Evaluate, Stmt, EvaluateNode);
};

/*! \brief Additional annotation of for loop. */
/*!
* \brief The type of the loop.
*
* ForType can change the control flow semantics
* of the loop. So the for_type field needs to be considered
* in all TIR passes.
*/
enum class ForType : int {
/*! \brief serial execution. */
/*! \brief default semantics -- serial execution. */
Serial = 0,
/*! \brief parallel execution on CPU. */
/*! \brief Parallel execution on CPU. */
Parallel = 1,
/*! \brief Vector SIMD loop annotaion. */
/*!
* \brief Vector SIMD loop.
* The loop body will be vectorized.
*/
Vectorized = 2,
/*! \brief Unroll annotation. */
Unrolled = 3
/*! \brief The loop body must be unrolled. */
Unrolled = 3,
/*!
* \brief The loop variable is bound to a thread in
* an environment. In the final stage of lowering,
* the loop is simply removed and the loop variable is
* mapped to the corresponding context thread.
*/
ThreadBinding = 4
};
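A rough Python-side sketch of how these values line up with the for_type argument of tir.For (assuming the Python wrapper updated later in this diff); everything below is illustrative:

    import tvm
    from tvm import tir

    i = tir.Var("i", "int32")
    body = tir.Evaluate(0)

    serial     = tir.For(i, 0, 16, 0, body)     # ForType::Serial
    parallel   = tir.For(i, 0, 16, 1, body)     # ForType::Parallel
    vectorized = tir.For(i, 0, 16, 2, body)     # ForType::Vectorized
    unrolled   = tir.For(i, 0, 16, 3, body)     # ForType::Unrolled
    # ForType::ThreadBinding (4) additionally expects a thread_binding IterVar;
    # see the For reference class further down.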

// Device api of for loop
@@ -789,37 +805,49 @@ class ForNode : public StmtNode {
PrimExpr extent;
/*! \brief The type of the for loop. */
ForType for_type;
/*!
* \brief Deprecated, reserved for backward compatibility.
* Consider refactor and remove later.
*/
DeviceAPI device_api;
/*! \brief The body of the for loop. */
Stmt body;
/*!
* \brief Only valid when for_type == ForType::ThreadBinding.
* The context thread that this loop variable binds to.
*/
Optional<IterVar> thread_binding;
/*!
* \brief Additional annotations about the loop.
*
* These annotations can be used as auxiliary hint
* to future transformations. An annotation should
* not change the control flow semantics of the loop
* and can be ignored in most passes.
*/
Map<String, ObjectRef> annotations;

void VisitAttrs(AttrVisitor* v) {
v->Visit("loop_var", &loop_var);
v->Visit("min", &min);
v->Visit("extent", &extent);
v->Visit("for_type", &for_type);
v->Visit("device_api", &device_api);
v->Visit("body", &body);
v->Visit("thread_binding", &thread_binding);
v->Visit("annotations", &annotations);
v->Visit("span", &span);
}

bool SEqualReduce(const ForNode* other, SEqualReducer equal) const {
return equal.DefEqual(loop_var, other->loop_var) && equal(min, other->min) &&
equal(extent, other->extent) && equal(for_type, other->for_type) &&
equal(device_api, other->device_api) && equal(body, other->body);
equal(body, other->body) && equal(thread_binding, other->thread_binding) &&
equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(loop_var);
hash_reduce(min);
hash_reduce(extent);
hash_reduce(for_type);
hash_reduce(device_api);
hash_reduce(body);
hash_reduce(thread_binding);
hash_reduce(annotations);
}

static constexpr const char* _type_key = "tir.For";
@@ -832,8 +860,9 @@ class ForNode : public StmtNode {
*/
class For : public Stmt {
public:
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api,
Stmt body, Span span = Span());
TVM_DLL For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, Stmt body,
Optional<IterVar> thread_binding = NullOpt,
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(), Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(For, Stmt, ForNode);
};
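A hedged sketch of the two new optional arguments as seen from Python (te.thread_axis is used here only to build an IterVar; the annotation key is made up for illustration):

    import tvm
    from tvm import te, tir

    i = tir.Var("i", "int32")
    body = tir.Evaluate(0)

    # Bind the loop variable to a context thread (for_type 4 == ThreadBinding).
    tx = te.thread_axis("threadIdx.x")
    bound = tir.For(i, 0, 32, 4, body, thread_binding=tx)

    # Attach auxiliary hints; passes may ignore them without changing semantics.
    hinted = tir.For(i, 0, 32, 0, body, annotations={"example_hint": 1})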
8 changes: 4 additions & 4 deletions python/tvm/script/scope_handler.py
@@ -226,7 +226,7 @@ def serial(begin, end, span):
self.context.report_error("Expect exact 1 loop var", span)
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
return tvm.tir.For(self.loop_vars[0], begin, extent, 0, 0, self.body, span=span)
return tvm.tir.For(self.loop_vars[0], begin, extent, 0, self.body, span=span)

super().__init__(serial)

@@ -241,7 +241,7 @@ def parallel(begin, end, span):
self.context.report_error("Expect exact 1 loop var")
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
return tvm.tir.For(self.loop_vars[0], begin, extent, 1, 0, self.body, span=span)
return tvm.tir.For(self.loop_vars[0], begin, extent, 1, self.body, span=span)

super().__init__(parallel)

@@ -256,7 +256,7 @@ def vectorized(begin, end, span):
self.context.report_error("Expect exact 1 loop var")
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
return tvm.tir.For(self.loop_vars[0], begin, extent, 2, 0, self.body, span=span)
return tvm.tir.For(self.loop_vars[0], begin, extent, 2, self.body, span=span)

super().__init__(vectorized)

@@ -271,6 +271,6 @@ def unroll(begin, end, span):
self.context.report_error("Expect exact 1 loop var")
ana = tvm.arith.Analyzer()
extent = end if begin == 0 else ana.simplify(end - begin)
return tvm.tir.For(self.loop_vars[0], begin, extent, 3, 0, self.body, span=span)
return tvm.tir.For(self.loop_vars[0], begin, extent, 3, self.body, span=span)

super().__init__(unroll)
2 changes: 1 addition & 1 deletion python/tvm/te/hybrid/parser.py
@@ -532,7 +532,7 @@ def visit_For(self, node):
_internal_assert(
not isinstance(for_type, tuple), "Micro expansion should be handled before!"
)
res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext, for_type, 0, _body)
res = tvm.tir.For(iter_var, tvm.runtime.const(0, "int32"), ext, for_type, _body)

self.symbols.pop(_name)
return res
2 changes: 1 addition & 1 deletion python/tvm/tir/ir_builder.py
@@ -259,7 +259,7 @@ def _exit_cb():
for_type_id = 3
else:
raise ValueError("Unknown for_type")
self.emit(_stmt.For(loop_var, begin, extent, for_type_id, 0, self._pop_seq()))
self.emit(_stmt.For(loop_var, begin, extent, for_type_id, self._pop_seq()))

return WithScope(loop_var, _exit_cb)
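For reference, a small sketch of the same mapping exercised through the ir_builder front end (the loop body is a placeholder):

    import tvm
    from tvm import tir

    ib = tir.ir_builder.create()
    with ib.for_range(0, 16, name="i", for_type="parallel") as i:
        ib.emit(tir.Evaluate(i))    # placeholder body
    stmt = ib.get()                 # a tir.For with for_type == 1 (Parallel)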

35 changes: 29 additions & 6 deletions python/tvm/tir/stmt.py
@@ -100,12 +100,16 @@ class For(Stmt):
for_type : int
The for type.
device_api : int
The device api type.
body : Stmt
The body statement.
thread_binding: Optional[tir.IterVar]
The thread this loop binds to. Only valid
if for_type is ThreadBinding
annotations: tvm.ir.Map
Additional annotation hints.
span : Optional[Span]
The location of this stmt in the source code.
"""
@@ -114,10 +118,29 @@ class For(Stmt):
Parallel = 1
Vectorized = 2
Unrolled = 3

def __init__(self, loop_var, min_val, extent, for_type, device_api, body, span=None):
ThreadBinding = 4

def __init__(
self,
loop_var,
min_val,
extent,
for_type,
body,
thread_binding=None,
annotations=None,
span=None,
):
self.__init_handle_by_constructor__(
_ffi_api.For, loop_var, min_val, extent, for_type, device_api, body, span
_ffi_api.For,
loop_var,
min_val,
extent,
for_type,
body,
thread_binding,
annotations,
span,
)
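A small, hypothetical round-trip check of the new fields through the FFI, using the class attributes defined above:

    import tvm
    from tvm import tir

    i = tir.Var("i", "int32")
    loop = tir.For(i, 0, 8, tir.For.Unrolled, tir.Evaluate(0))
    assert loop.for_type == tir.For.Unrolled    # for_type is stored as a plain int
    assert loop.thread_binding is None          # only populated for ThreadBinding loops
    print(loop.annotations)                     # defaults to an empty map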


2 changes: 1 addition & 1 deletion src/auto_scheduler/feature.cc
@@ -618,7 +618,7 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
is_gpu_ = true;

// make a fake for node for blockIdx.x or threadIdx.x
Stmt fake_for_node = For(var, 0, extent, ForType::Parallel, DeviceAPI::None, node->body);
Stmt fake_for_node = For(var, 0, extent, ForType::Parallel, node->body);

outer_loop_prod_ *= extent;
for_loop_stack_.push_back(fake_for_node.as<ForNode>());
4 changes: 4 additions & 0 deletions src/autotvm/feature_visitor.cc
@@ -47,6 +47,10 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) {
case ForType::Serial:
ann = kSerial;
break;
case ForType::ThreadBinding:
LOG(FATAL) << "Loop ThreadBinding is reserved for future use and "
<< "not yet supported in TIR";
break;
}

if (EnterItervar_(op->loop_var, loop_extent, ann)) {
3 changes: 3 additions & 0 deletions src/printer/tir_text_printer.cc
@@ -475,6 +475,9 @@ inline const char* ForType2String(ForType t) {
return "vectorized";
case ForType::Unrolled:
return "unroll";
case ForType::ThreadBinding:
LOG(FATAL) << "Loop ThreadBinding is reserved for future use and "
<< "not yet supported in TIR";
}
LOG(FATAL) << "Unknown ForType";
return "Unknown";
4 changes: 4 additions & 0 deletions src/printer/tvmscript_printer.cc
@@ -659,6 +659,10 @@ inline const char* ForType2String(ForType t) {
return "vectorized";
case ForType::Unrolled:
return "unroll";
case ForType::ThreadBinding:
LOG(FATAL) << "Loop ThreadBinding is reserved for future use and "
<< "not yet supported in TIR";
return "threadbinding";
}
LOG(FATAL) << "Unknown ForType";
return "Unknown";
5 changes: 3 additions & 2 deletions src/target/llvm/codegen_cpu.cc
@@ -980,8 +980,9 @@ void CodeGenCPU::VisitStmt_(const ForNode* op) {
CodeGenLLVM::VisitStmt_(op);
} else if (op->for_type == ForType::Parallel) {
if (parallel_env_.penv == nullptr) {
CreateParallelLaunch(
For(op->loop_var, op->min, op->extent, op->for_type, op->device_api, op->body), 0);
CreateParallelLaunch(For(op->loop_var, op->min, op->extent, op->for_type, op->body,
op->thread_binding, op->annotations),
0);
} else {
// already in parallel env.
ICHECK(parallel_env_.task_id.defined());
13 changes: 7 additions & 6 deletions src/te/operation/hybrid_op.cc
@@ -234,9 +234,9 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range
PrimExpr cond = likely(outer * factor < (op->extent - inner));
ret = IfThenElse(cond, ret);
ret = For(inner->var, PrimExpr(0), inner->dom->extent,
IterVarTypeToForType(inner->iter_type), op->device_api, ret);
IterVarTypeToForType(inner->iter_type), ret);
ret = For(outer->var, PrimExpr(0), outer->dom->extent,
IterVarTypeToForType(outer->iter_type), op->device_api, ret);
IterVarTypeToForType(outer->iter_type), ret);
splitted = true;
return ret;
}
@@ -277,8 +277,8 @@ Stmt ApplyLoopShapes(const Stage& stage, const std::unordered_map<IterVar, Range
rmap[op->loop_var.get()] = indexdiv(parent, extent);
body = tir::Substitute(body, rmap);
under_outer = false;
return For(parent->var, PrimExpr(0), extent * op->extent, op->for_type, op->device_api,
body);
return For(parent->var, PrimExpr(0), extent * op->extent, op->for_type, body,
op->thread_binding, op->annotations);
} else if (under_outer) {
Stmt body = this->VisitStmt(op->body);
std::unordered_map<const VarNode*, PrimExpr> rmap;
@@ -332,7 +332,7 @@ Stmt ApplyLoopAnnotations(const Stage& stage, const std::unordered_map<IterVar,
return AttrStmt(iter_var, "thread_extent", op->extent, body);
} else {
return For(op->loop_var, op->min, op->extent, IterVarTypeToForType(attr->iter_type),
op->device_api, op->body);
op->body, op->thread_binding, op->annotations);
}
}
return StmtMutator::VisitStmt_(op);
@@ -414,7 +414,8 @@ Stmt ApplyLoopOrder(const Stage& stage, const std::unordered_map<IterVar, Range>
for_type = IterVarTypeToForType(stage->iter_var_attrs[target]->iter_type);
}
const Range& range = target->dom.defined() ? target->dom : dom_map.find(target)->second;
return For(target->var, range->min, range->extent, for_type, DeviceAPI::None, body);
return For(target->var, range->min, range->extent, for_type, body, op->thread_binding,
op->annotations);
}
};

4 changes: 2 additions & 2 deletions src/te/operation/op_utils.cc
@@ -115,11 +115,11 @@ std::vector<std::vector<Stmt> > MakeLoopNest(const Stage& stage,
nest[i + 1].emplace_back(LetStmt(var, cast(var.dtype(), dom->min), no_op));
value_map[iv] = cast(var.dtype(), dom->min);
} else if (is_zero(dom->min)) {
nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, DeviceAPI::None, no_op));
nest[i + 1].emplace_back(For(var, 0, dom->extent, for_type, no_op));
value_map[iv] = var;
} else {
Var idx(bind_iv->var->name_hint + ".idx", bind_iv->var.dtype());
nest[i + 1].emplace_back(For(idx, 0, dom->extent, for_type, DeviceAPI::None, no_op));
nest[i + 1].emplace_back(For(idx, 0, dom->extent, for_type, no_op));
PrimExpr new_value = dom->min + idx;
value_map[iv] = new_value;
nest[i + 1].emplace_back(LetStmt(var, new_value, no_op));
3 changes: 2 additions & 1 deletion src/te/schedule/schedule_postproc_rewrite_for_tensor_core.cc
@@ -968,7 +968,8 @@ class TensorCoreIRMutator : public StmtExprMutator {
scaled_extent_value = ori_extent_value / scale_factor;
}
PrimExpr scaled_extent = make_const(op->extent.dtype(), scaled_extent_value);
stmt = For(op->loop_var, op->min, scaled_extent, op->for_type, op->device_api, op->body);
stmt = For(op->loop_var, op->min, scaled_extent, op->for_type, op->body, op->thread_binding,
op->annotations);
}
}
return stmt;
22 changes: 13 additions & 9 deletions src/tir/ir/stmt.cc
@@ -128,8 +128,8 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
});

// For
For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAPI device_api,
Stmt body, Span span) {
For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, Stmt body,
Optional<IterVar> thread_binding, Map<String, ObjectRef> annotations, Span span) {
ICHECK(min.defined());
ICHECK(extent.defined());
ICHECK(min.dtype().is_scalar());
@@ -142,18 +142,19 @@ For::For(Var loop_var, PrimExpr min, PrimExpr extent, ForType for_type, DeviceAP
node->min = std::move(min);
node->extent = std::move(extent);
node->for_type = for_type;
node->device_api = device_api;
node->body = std::move(body);
node->thread_binding = std::move(thread_binding);
node->annotations = std::move(annotations);
node->span = std::move(span);
data_ = std::move(node);
}

TVM_REGISTER_GLOBAL("tir.For").set_body_typed([](Var loop_var, PrimExpr min, PrimExpr extent,
int for_type, int device_api, Stmt body,
Span span) {
return For(loop_var, min, extent, static_cast<ForType>(for_type),
static_cast<DeviceAPI>(device_api), body, span);
});
TVM_REGISTER_GLOBAL("tir.For").set_body_typed(
[](Var loop_var, PrimExpr min, PrimExpr extent, int for_type, Stmt body,
Optional<IterVar> thread_binding, Optional<Map<String, ObjectRef>> annotations, Span span) {
return For(loop_var, min, extent, static_cast<ForType>(for_type), body, thread_binding,
annotations.value_or(Map<String, ObjectRef>()), span);
});

TVM_REGISTER_NODE_TYPE(ForNode);

@@ -171,6 +172,9 @@ std::ostream& operator<<(std::ostream& out, ForType type) { // NOLINT(*)
case ForType::Vectorized:
out << "vectorized";
break;
case ForType::ThreadBinding:
out << "launch_thread";
break;
}
return out;
}
3 changes: 1 addition & 2 deletions src/tir/transforms/inject_double_buffer.cc
@@ -158,8 +158,7 @@ class DoubleBufferInjector : public StmtExprMutator {
vmap[old_loop->loop_var.get()] = outer_var * factor + make_const(factor.dtype(), i);
loop_seq.emplace_back(Substitute(old_loop->body, vmap));
}
Stmt loop = For(outer_var, zero, outer_ext, old_loop->for_type, old_loop->device_api,
SeqStmt::Flatten(loop_seq));
Stmt loop = For(outer_var, zero, outer_ext, old_loop->for_type, SeqStmt::Flatten(loop_seq));
// tail
std::vector<Stmt> tail_seq;
Stmt tail_body = StripDoubleBufferWrite()(old_loop->body);