Skip to content

Commit

Permalink
[TECompiler] Decouple TE compute and schedule lowering in ScheduleBui…
Browse files Browse the repository at this point in the history
…lder (#10561)

* Decouple TE compute and schedule lowering in ScheduleBuilder

* fixed merge conflict

* removed create_schedule stuff

* add public, fix include path convention

* Forgot visiting arg in ScheduleBuilder CallNode vsit

* fixed anchor impl selection
  • Loading branch information
masahi authored Mar 11, 2022
1 parent 51ae845 commit 076fa33
Showing 1 changed file with 146 additions and 114 deletions.
260 changes: 146 additions & 114 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,13 @@
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_strategy.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/function.h>
#include <tvm/topi/tags.h>

#include <functional>
Expand Down Expand Up @@ -114,100 +116,40 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
return res;
}

// Construct a schedule for a given Relay primitive function and target.
class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
// Lowers Relay primitive Function to TE Compute
class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
explicit ScheduleBuilder(Target target, bool create_schedule = true)
: target_(target),
device_copy_op_(Op::Get("device_copy")),
create_schedule_(create_schedule) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
use_meta_schedule_ = backend::IsMetaScheduleEnabled();
}
explicit LowerToTECompute(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {}

CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
Array<tvm::te::Tensor> fn_inputs;
Array<te::Tensor> Lower(const Function& relay_func,
std::function<std::string(std::string)> renamer) {
for (Var param : relay_func->params) {
Array<tvm::te::Tensor> inputs;
for (const auto& ttype : FlattenTupleType(param->checked_type())) {
tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
fn_inputs.push_back(tensor);
inputs.push_back(tensor);
fn_inputs_.push_back(tensor);
}
memo_[param] = inputs;
}
readable_name_stream_ << "fused";
auto outputs = this->VisitExpr(relay_func->body);
auto candidate_name = readable_name_stream_.str();

Array<te::Tensor> outputs = this->VisitExpr(relay_func->body);

candidate_name_ = readable_name_stream_.str();

constexpr static size_t kMaxFuncNameLength = 80;
// WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
// whenever the value of kMaxFuncNameLength changes
if (candidate_name.size() > kMaxFuncNameLength) {
if (candidate_name_.size() > kMaxFuncNameLength) {
std::stringstream truncated_name;
truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
candidate_name = truncated_name.str();
}

// TODO(mbs): This should be the definitive global by which the PrimFunc is known and
// no other GlobalVar ctors should appear inside the lowering machinery.
auto prim_fn_var = GlobalVar(renamer(candidate_name));
prim_fn_var->checked_type_ = relay_func->checked_type();

// Fusion over tupled results may leave identity relationships
// between inputs and outputs, and those should not be scheduled.
// Hence schedule only non PlaceholderOp outputs.
tvm::Array<te::Tensor> tensor_outs;
for (const auto& tensor : outputs) {
if (!tensor->op.as<te::PlaceholderOpNode>()) {
tensor_outs.push_back(tensor);
}
}

te::Schedule schedule{nullptr};
tir::PrimFunc prim_func{nullptr};
// No need to register schedule for device copy op.
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
if (use_auto_scheduler_) {
const auto* fauto_schedule =
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
ICHECK(fauto_schedule != nullptr)
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
if (obj.defined()) {
schedule = Downcast<te::Schedule>(obj);
}
}
if (use_meta_schedule_) {
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
Optional<ObjectRef> opt_mod_or_base_func =
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
prim_func = GetRef<tir::PrimFunc>(result);
} else {
prim_func = tir::PrimFunc(nullptr);
}
}

// Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
if (!schedule.defined() && !prim_func.defined()) {
ICHECK(anchor_implementation_.defined());
schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
}
if (schedule.defined()) {
for (const auto& scalar : scalars_) {
if (schedule->Contain(scalar)) {
schedule[scalar].compute_inline();
}
}
}
truncated_name << candidate_name_.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name_) << "_";
candidate_name_ = truncated_name.str();
}

return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
IRModule(Map<GlobalVar, BaseFunc>({})), constant_tensors_);
return outputs;
}

Array<te::Tensor> VisitExpr_(const VarNode* op) final {
Expand Down Expand Up @@ -254,7 +196,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
}

Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
ICHECK(flower_call) << "relay.backend.lower_call is not registered.";

Expand All @@ -278,28 +219,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);

Array<te::Tensor> outputs;
OpImplementation impl;
// TODO(mbs): device_copy cleanup
ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";

LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
outputs = lowered_out->outputs;
impl = lowered_out->implementation;

if (create_schedule_) {
int op_pattern = fpattern[op];
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
}
if (op_pattern >= anchor_op_pattern_) {
anchor_op_ = op;
anchor_attrs_ = call_node->attrs;
anchor_op_pattern_ = op_pattern;
anchor_implementation_ = impl;
}
}
Array<te::Tensor> outputs = lowered_out->outputs;
op_implementations_[op.operator->()] = lowered_out->implementation;

if (outputs.size() != 1) {
const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
ICHECK(tuple_type) << "Expected output to be a tuple type "
Expand All @@ -308,8 +234,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
ICHECK_EQ(tuple_type->fields.size(), outputs.size());
}

// TODO(mbs): device_copy cleanup
ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
readable_name_stream_ << '_' << op->name;
return outputs;
}
Expand Down Expand Up @@ -347,26 +271,131 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
return {tuple[op->index]};
}

public:
// Additional outputs
Array<tvm::te::Tensor> fn_inputs_;
Array<te::Operation> scalars_;
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
std::unordered_map<const OpNode*, OpImplementation> op_implementations_;
std::string candidate_name_;

private:
tvm::Target target_;
Op anchor_op_;
Attrs anchor_attrs_;
int anchor_op_pattern_{0};
OpImplementation anchor_implementation_;
std::ostringstream readable_name_stream_;
Array<te::Operation> scalars_;
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
bool use_auto_scheduler_;
bool use_meta_schedule_;
// Index of the global constants
static int const_index;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const Op& device_copy_op_;
bool create_schedule_;
// Index of the global constants
static int const_index;
};

int ScheduleBuilder::const_index = 0;
int LowerToTECompute::const_index = 0;

// Construct a schedule for a given Relay primitive function and target.
class ScheduleBuilder : public ExprVisitor {
public:
explicit ScheduleBuilder(Target target) : target_(target) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
}

CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
LowerToTECompute lower_te_compute(target_);
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
VisitExpr(relay_func->body);

// TODO(mbs): This should be the definitive global by which the PrimFunc is known and
// no other GlobalVar ctors should appear inside the lowering machinery.
auto prim_fn_var = GlobalVar(renamer(lower_te_compute.candidate_name_));
prim_fn_var->checked_type_ = relay_func->checked_type();

// Fusion over tupled results may leave identity relationships
// between inputs and outputs, and those should not be scheduled.
// Hence schedule only non PlaceholderOp outputs.
tvm::Array<te::Tensor> tensor_outs;
for (const auto& tensor : outputs) {
if (!tensor->op.as<te::PlaceholderOpNode>()) {
tensor_outs.push_back(tensor);
}
}

te::Schedule schedule{nullptr};
tir::PrimFunc prim_func{nullptr};
// No need to register schedule for device copy op.
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
if (use_auto_scheduler_) {
const auto* fauto_schedule =
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
ICHECK(fauto_schedule != nullptr)
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
if (obj.defined()) {
schedule = Downcast<te::Schedule>(obj);
}
}
if (backend::IsMetaScheduleEnabled()) {
prim_func = tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs));
Optional<ObjectRef> opt_mod_or_base_func =
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
prim_func = GetRef<tir::PrimFunc>(result);
} else {
prim_func = tir::PrimFunc(nullptr);
}
}

// Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
if (!schedule.defined() && !prim_func.defined()) {
auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->());
ICHECK(anchor_impl != lower_te_compute.op_implementations_.end());
schedule = anchor_impl->second.Schedule(anchor_attrs_, tensor_outs, target_);
}
if (schedule.defined()) {
for (const auto& scalar : lower_te_compute.scalars_) {
if (schedule->Contain(scalar)) {
schedule[scalar].compute_inline();
}
}
}
}

return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
}

void VisitExpr_(const CallNode* call_node) final {
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");

ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);

for (Expr arg : call_node->args) {
VisitExpr(arg);
}

int op_pattern = fpattern[op];
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
}
if (op_pattern >= anchor_op_pattern_) {
anchor_op_ = op;
anchor_attrs_ = call_node->attrs;
anchor_op_pattern_ = op_pattern;
}
}

private:
tvm::Target target_;
Op anchor_op_;
Attrs anchor_attrs_;
int anchor_op_pattern_{0};
bool use_auto_scheduler_;
};

/*!
* \brief Create schedule for target.
Expand Down Expand Up @@ -750,9 +779,12 @@ std::string GetUniqueName(std::string name, std::unordered_map<std::string, int>
}

TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
return ScheduleBuilder(tvm::Target("ext_dev"), false).Create(prim_func, [&](std::string name) {
return name;
});
auto tgt = tvm::Target("ext_dev");
LowerToTECompute lower_te_compute(tgt);
auto outputs = lower_te_compute.Lower(prim_func, [&](std::string name) { return name; });
return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_,
outputs, te::Schedule(), tir::PrimFunc(), {},
IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
});

} // namespace tec
Expand Down

0 comments on commit 076fa33

Please sign in to comment.