Skip to content

Commit

Permalink
Decouple TE compute and schedule lowering in ScheduleBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 10, 2022
1 parent 48793f3 commit f21f8f4
Showing 1 changed file with 141 additions and 113 deletions.
254 changes: 141 additions & 113 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
#include "../../te/operation/create_primfunc.h"
#include "../op/memory/memory.h"
#include "../transforms/pass_utils.h"
#include "tvm/relay/op_strategy.h"
#include "utils.h"

namespace tvm {
Expand Down Expand Up @@ -115,99 +116,25 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
}

// Construct a schedule for a given Relay primitive function and target.
class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
public:
explicit ScheduleBuilder(Target target, bool create_schedule = true)
: target_(target),
device_copy_op_(Op::Get("device_copy")),
create_schedule_(create_schedule) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
use_meta_schedule_ = backend::IsMetaScheduleEnabled();
}
explicit LowerToTECompute(Target target)
: target_(target), device_copy_op_(Op::Get("device_copy")) {}

CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
Array<tvm::te::Tensor> fn_inputs;
Array<te::Tensor> Lower(const Function& relay_func,
std::function<std::string(std::string)> renamer) {
for (Var param : relay_func->params) {
Array<tvm::te::Tensor> inputs;
for (const auto& ttype : FlattenTupleType(param->checked_type())) {
tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype);
fn_inputs.push_back(tensor);
inputs.push_back(tensor);
fn_inputs_.push_back(tensor);
}
memo_[param] = inputs;
}
readable_name_stream_ << "fused";
auto outputs = this->VisitExpr(relay_func->body);
auto candidate_name = readable_name_stream_.str();
constexpr static size_t kMaxFuncNameLength = 80;
// WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
// whenever the value of kMaxFuncNameLength changes
if (candidate_name.size() > kMaxFuncNameLength) {
std::stringstream truncated_name;
truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
candidate_name = truncated_name.str();
}

// TODO(mbs): This should be the definitive global by which the PrimFunc is known and
// no other GlobalVar ctors should appear inside the lowering machinery.
auto prim_fn_var = GlobalVar(renamer(candidate_name));
prim_fn_var->checked_type_ = relay_func->checked_type();

// Fusion over tupled results may leave identity relationships
// between inputs and outputs, and those should not be scheduled.
// Hence schedule only non PlaceholderOp outputs.
tvm::Array<te::Tensor> tensor_outs;
for (const auto& tensor : outputs) {
if (!tensor->op.as<te::PlaceholderOpNode>()) {
tensor_outs.push_back(tensor);
}
}

te::Schedule schedule{nullptr};
tir::PrimFunc prim_func{nullptr};
// No need to register schedule for device copy op.
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
if (use_auto_scheduler_) {
const auto* fauto_schedule =
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
ICHECK(fauto_schedule != nullptr)
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
if (obj.defined()) {
schedule = Downcast<te::Schedule>(obj);
}
}
if (use_meta_schedule_) {
prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
Optional<ObjectRef> opt_mod_or_base_func =
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
prim_func = GetRef<tir::PrimFunc>(result);
} else {
prim_func = tir::PrimFunc(nullptr);
}
}

// Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
if (!schedule.defined() && !prim_func.defined()) {
ICHECK(anchor_implementation_.defined());
schedule = anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
}
if (schedule.defined()) {
for (const auto& scalar : scalars_) {
if (schedule->Contain(scalar)) {
schedule[scalar].compute_inline();
}
}
}
}

return CachedFunc(target_, prim_fn_var, fn_inputs, outputs, schedule, prim_func, {},
IRModule(Map<GlobalVar, BaseFunc>({})), constant_tensors_);
return outputs;
}

Array<te::Tensor> VisitExpr_(const VarNode* op) final {
Expand Down Expand Up @@ -254,7 +181,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
}

Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
ICHECK(flower_call) << "relay.backend.lower_call is not registered.";

Expand All @@ -278,28 +204,13 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);

Array<te::Tensor> outputs;
OpImplementation impl;
// TODO(mbs): device_copy cleanup
ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";

LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
outputs = lowered_out->outputs;
impl = lowered_out->implementation;
Array<te::Tensor> outputs = lowered_out->outputs;
anchor_implementation_ = lowered_out->implementation;

if (create_schedule_) {
int op_pattern = fpattern[op];
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
}
if (op_pattern >= anchor_op_pattern_) {
anchor_op_ = op;
anchor_attrs_ = call_node->attrs;
anchor_op_pattern_ = op_pattern;
anchor_implementation_ = impl;
}
}
if (outputs.size() != 1) {
const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
ICHECK(tuple_type) << "Expected output to be a tuple type "
Expand All @@ -308,8 +219,6 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
ICHECK_EQ(tuple_type->fields.size(), outputs.size());
}

// TODO(mbs): device_copy cleanup
ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";
readable_name_stream_ << '_' << op->name;
return outputs;
}
Expand Down Expand Up @@ -347,27 +256,146 @@ class ScheduleBuilder : public backend::MemoizedExprTranslator<Array<te::Tensor>
return {tuple[op->index]};
}

public:
// outputs
Array<tvm::te::Tensor> fn_inputs_;
Array<te::Operation> scalars_;
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
std::ostringstream readable_name_stream_;
OpImplementation anchor_implementation_;

private:
tvm::Target target_;
// Index of the global constants
static int const_index;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const Op& device_copy_op_;
};

int LowerToTECompute::const_index = 0;

// Construct a schedule for a given Relay primitive function and target.
class ScheduleBuilder : ExprVisitor {
public:
explicit ScheduleBuilder(Target target, bool create_schedule = true)
: target_(target),

create_schedule_(create_schedule) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
use_meta_schedule_ = backend::IsMetaScheduleEnabled();
}

CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
LowerToTECompute lower_te_compute(target_);
Array<te::Tensor> outputs = lower_te_compute.Lower(relay_func, renamer);
std::string candidate_name = lower_te_compute.readable_name_stream_.str();
VisitExpr(relay_func->body);

constexpr static size_t kMaxFuncNameLength = 80;
// WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
// whenever the value of kMaxFuncNameLength changes
if (candidate_name.size() > kMaxFuncNameLength) {
std::stringstream truncated_name;
truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
candidate_name = truncated_name.str();
}

// TODO(mbs): This should be the definitive global by which the PrimFunc is known and
// no other GlobalVar ctors should appear inside the lowering machinery.
auto prim_fn_var = GlobalVar(renamer(candidate_name));
prim_fn_var->checked_type_ = relay_func->checked_type();

// Fusion over tupled results may leave identity relationships
// between inputs and outputs, and those should not be scheduled.
// Hence schedule only non PlaceholderOp outputs.
tvm::Array<te::Tensor> tensor_outs;
for (const auto& tensor : outputs) {
if (!tensor->op.as<te::PlaceholderOpNode>()) {
tensor_outs.push_back(tensor);
}
}

te::Schedule schedule{nullptr};
tir::PrimFunc prim_func{nullptr};
// No need to register schedule for device copy op.
if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr && create_schedule_) {
if (use_auto_scheduler_) {
const auto* fauto_schedule =
runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
ICHECK(fauto_schedule != nullptr)
<< "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
if (obj.defined()) {
schedule = Downcast<te::Schedule>(obj);
}
}
if (use_meta_schedule_) {
prim_func = tir::CreatePrimFuncFromOutputs(tensor_outs);
Optional<ObjectRef> opt_mod_or_base_func =
meta_schedule::MetaScheduleContext::QueryInsideWithScope(
prim_fn_var->name_hint, IRModule({{prim_fn_var, relay_func}}), target_,
Array<IRModule>{IRModule({{prim_fn_var, prim_func}})});
if (const auto* result = opt_mod_or_base_func.as<tir::PrimFuncNode>()) {
prim_func = GetRef<tir::PrimFunc>(result);
} else {
prim_func = tir::PrimFunc(nullptr);
}
}

// Use TOPI schedule if user specificed, or the function has no auto_scheduler schedule.
if (!schedule.defined() && !prim_func.defined()) {
ICHECK(lower_te_compute.anchor_implementation_.defined());
schedule =
lower_te_compute.anchor_implementation_.Schedule(anchor_attrs_, tensor_outs, target_);
}
if (schedule.defined()) {
for (const auto& scalar : lower_te_compute.scalars_) {
if (schedule->Contain(scalar)) {
schedule[scalar].compute_inline();
}
}
}
}

return CachedFunc(target_, prim_fn_var, lower_te_compute.fn_inputs_, outputs, schedule,
prim_func, {}, IRModule(Map<GlobalVar, BaseFunc>({})),
lower_te_compute.constant_tensors_);
}

void VisitExpr_(const CallNode* call_node) final {
static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");

ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
Op op = Downcast<Op>(call_node->op);

if (create_schedule_) {
int op_pattern = fpattern[op];
if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
}
if (op_pattern >= anchor_op_pattern_) {
anchor_op_ = op;
anchor_attrs_ = call_node->attrs;
anchor_op_pattern_ = op_pattern;
}
}
}

private:
tvm::Target target_;
Op anchor_op_;
Attrs anchor_attrs_;
int anchor_op_pattern_{0};
OpImplementation anchor_implementation_;
std::ostringstream readable_name_stream_;
Array<te::Operation> scalars_;
std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
bool use_auto_scheduler_;
bool use_meta_schedule_;
// Cache device copy op for equivalence checking to reduce registry lookup
// overhead for each invocation of call node when retrieving schedules.
const Op& device_copy_op_;
bool create_schedule_;
// Index of the global constants
static int const_index;
};

int ScheduleBuilder::const_index = 0;

/*!
* \brief Create schedule for target.
* \param source_func The primitive function to be lowered.
Expand Down

0 comments on commit f21f8f4

Please sign in to comment.