Skip to content

Commit

Permalink
[BYOC][Optimization] Run accelerator specific optimizations (#6068)
Browse files Browse the repository at this point in the history
* register and invoke optimization pipeline for external codegen

* add unit test
  • Loading branch information
zhiics authored Jul 16, 2020
1 parent ae4480a commit f9e2b95
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 31 deletions.
83 changes: 52 additions & 31 deletions src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -261,37 +261,26 @@ class Partitioner : public MixedModeMutator {
}

/*!
* \brief Create a function and its function call for the given region. If the function has
* multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes
* will be created to serve output consumers.
* \brief Check if an expr is a constant or a tuple that only contain constants.
*/
void CreateFunction(AnnotatedRegion region, const CallNode* end_node) {
// Create fields which is a unique list of outputs.
Array<Expr> fields;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> out_expr_to_idx;
int out_idx = 0;
for (auto region_end_node : region->GetOutputs()) {
auto ret_node = Downcast<Call>(region_end_node)->args[0];
// Don't duplicate outputs.
if (!out_expr_to_idx.count(ret_node)) {
auto ret_expr = MixedModeMutator::VisitExpr(ret_node);
fields.push_back(ret_expr);
out_expr_to_idx[ret_node] = out_idx++;
}
}
bool IsConstant(const Expr& expr) const {
if (expr->IsInstance<ConstantNode>()) return true;
if (!expr->IsInstance<TupleNode>()) return false;
const auto* tn = expr.as<TupleNode>();
return std::all_of(tn->fields.begin(), tn->fields.end(),
[](const Expr& e) { return e->IsInstance<ConstantNode>(); });
}

/*!
* \brief Create a call to the function that represents a region.
* \note The customized optimization pipeline will be invoked as well to
* optimize each function that is handled by external codegen.
*/
Call CreateRegionCall(AnnotatedRegion region, const Array<Expr>& fields,
const CallNode* end_node) {
Array<Var> params;
Array<Expr> param_expr;
Map<Var, Expr> params_bind;

auto IsConstant = [](const Expr& expr) {
if (expr->IsInstance<ConstantNode>()) return true;
if (!expr->IsInstance<TupleNode>()) return false;
const auto* tn = expr.as<TupleNode>();
return std::all_of(tn->fields.begin(), tn->fields.end(),
[](const Expr& e) { return e->IsInstance<ConstantNode>(); });
};

for (auto pair : region_func_meta_[region].args) {
params.push_back(pair.first);
if (IsConstant(pair.second)) {
Expand All @@ -314,18 +303,25 @@ class Partitioner : public MixedModeMutator {
std::string target = end_node->attrs.as<CompilerAttrs>()->compiler;
std::string name = target + "_" + std::to_string(region->GetID());

// Constant propagation
if (!params_bind.empty()) {
global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind));
}
std::string ext_opt = "relay.ext." + target + ".optimize";
auto pf = tvm::runtime::Registry::Get(ext_opt);
if (pf != nullptr) {
auto mod = IRModule::FromExpr(global_region_func);
mod = (*pf)(mod);
global_region_func = Downcast<Function>(mod->Lookup("main"));
}

global_region_func =
WithAttr(std::move(global_region_func), tvm::attr::kGlobalSymbol, runtime::String(name));
global_region_func = WithAttr(std::move(global_region_func), attr::kPrimitive, tvm::Integer(1));
global_region_func =
WithAttr(std::move(global_region_func), attr::kCompiler, tvm::runtime::String(target));
global_region_func = WithAttr(std::move(global_region_func), attr::kInline, tvm::Integer(1));

// Constant propagation
if (!params_bind.empty()) {
global_region_func = Downcast<Function>(relay::Bind(global_region_func, params_bind));
}

std::string fname = name;
CHECK(!module_->ContainGlobalVar(fname)) << "Global function " << fname << " already exists";
// Create a global function and add it to the IRModule for the region.
Expand All @@ -340,6 +336,31 @@ class Partitioner : public MixedModeMutator {
auto call = Call(glob_func, param_expr);
region_func_meta_[region].func_call = call;

return call;
}

/*!
* \brief Create a function and its function call for the given region. If the function has
* multiple outputs, a Tuple will be formed to aggregate all outputs, and TupleGetItem nodes
* will be created to serve output consumers.
*/
void CreateFunction(AnnotatedRegion region, const CallNode* end_node) {
// Create fields which is a unique list of outputs.
Array<Expr> fields;
std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> out_expr_to_idx;
int out_idx = 0;
for (auto region_end_node : region->GetOutputs()) {
auto ret_node = Downcast<Call>(region_end_node)->args[0];
// Don't duplicate outputs.
if (!out_expr_to_idx.count(ret_node)) {
auto ret_expr = MixedModeMutator::VisitExpr(ret_node);
fields.push_back(ret_expr);
out_expr_to_idx[ret_node] = out_idx++;
}
}

Call call = CreateRegionCall(region, fields, end_node);

// Create output expr(s) for the function call.
if (out_expr_to_idx.size() == 1) {
// Single output direcly uses the call node as the output expr.
Expand Down
32 changes: 32 additions & 0 deletions tests/python/relay/test_pass_partition_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1286,6 +1286,37 @@ def test_tuple_output_exec():
[(10, 10), (10, 10)],
[(a_data + b_data), (a_data - b_data)])

def test_extern_opt():
def Optimize(mod):
return relay.transform.FoldConstant()(mod)

tvm.register_func("relay.ext.test_target.optimize", Optimize)

x = relay.var('x', shape=(2, 2))
y0 = relay.var('y0', shape=(2, 2))
y1 = relay.var('y1', shape=(2, 2))
yy0 = relay.annotation.compiler_begin(y0, 'test_target')
yy1 = relay.annotation.compiler_begin(y1, 'test_target')
z = yy0 + yy1
end = relay.annotation.compiler_end(z, 'test_target')
f = relay.Function([x, y0, y1], end * x)
c = np.ones(shape=(2, 2), dtype="float32")
f = bind_params_by_name(f, {"y0": tvm.nd.array(c), "y1": tvm.nd.array(c)})
mod = tvm.IRModule()
mod["main"] = f
mod = transform.PartitionGraph()(mod)

try:
t0 = mod["test_target_0"]
except:
raise KeyError("test_target_0 not found")

assert isinstance(t0.body, relay.Constant)
expected = np.empty([2, 2])
expected.fill(2)
tvm.testing.assert_allclose(t0.body.data.asnumpy(), expected, rtol=1e-5,
atol=1e-5)

if __name__ == "__main__":
test_multi_node_compiler()
test_extern_ccompiler_single_op()
Expand All @@ -1305,3 +1336,4 @@ def test_tuple_output_exec():
test_constant_tuples()
test_flatten_tuple_output()
test_tuple_output_exec()
test_extern_opt()

0 comments on commit f9e2b95

Please sign in to comment.