New relay backend for meta schedule task extraction
masahi committed Mar 11, 2022
1 parent 6f01901 commit 109187f
Showing 4 changed files with 124 additions and 1 deletion.
19 changes: 18 additions & 1 deletion python/tvm/meta_schedule/integration.py
@@ -18,7 +18,10 @@
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional, Union

from tvm._ffi import register_object
import numpy as np
import tvm.runtime.ndarray as nd

from tvm._ffi import register_object, get_global_func
from tvm.ir import IRModule, transform
from tvm.relay import Any
from tvm.relay import Function as RelayFunc
@@ -230,6 +233,20 @@ def extract_task_from_relay(
        The tasks extracted from this network
    """

    extract_task_func = get_global_func("relay.backend.MetaScheduleExtractTask")
    assert extract_task_func

    target = Target(target) if isinstance(target, str) else target

    # Convert numpy arrays to TVM NDArrays so the parameters can cross the FFI boundary.
    for name, param in params.items():
        if isinstance(param, np.ndarray):
            params[name] = nd.array(param)

    with transform.PassContext(opt_level=opt_level):
        with target:
            tasks = extract_task_func(mod, target, params)

    return tasks

@contextmanager
def _autotvm_silencer():
    from tvm import autotvm  # pylint: disable=import-outside-toplevel
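For a sense of how this new entry point is used, here is a minimal, hypothetical sketch; the toy dense model, the "llvm" target, and the printed task_name field are illustrative assumptions, not part of this commit:

import numpy as np
import tvm
from tvm import relay
from tvm.meta_schedule.integration import extract_task_from_relay

# Build a tiny Relay module: a single dense layer.
data = relay.var("data", shape=(1, 64), dtype="float32")
weight = relay.var("weight", shape=(32, 64), dtype="float32")
mod = tvm.IRModule.from_expr(relay.nn.dense(data, weight))
params = {"weight": np.random.rand(32, 64).astype("float32")}

# Each extracted task corresponds to one fused primitive function.
tasks = extract_task_from_relay(mod, target="llvm", params=params)
for task in tasks:
    print(task.task_name)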
86 changes: 86 additions & 0 deletions src/relay/backend/metaschedule_task_extraction.cc
@@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <tvm/meta_schedule/integration.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/function.h>
#include <tvm/target/target.h>

#include "../../te/operation/create_primfunc.h"
#include "te_compiler_cache.h"
#include "utils.h"

namespace tvm {
namespace relay {
namespace backend {
namespace metaschedule {

using meta_schedule::ExtractedTask;

Array<ExtractedTask> ExtractTask(IRModule mod, Target target, Map<String, Constant> params) {
  if (params.size()) {
    // Bind the provided constants to the parameters of "main" so they are folded
    // into the module before task extraction.
    std::unordered_map<std::string, runtime::NDArray> params_;
    for (const auto& kv : params) {
      params_[kv.first] = kv.second->data;
    }
    BaseFunc base_func = mod->Lookup("main");
    ICHECK(base_func->IsInstance<FunctionNode>());
    auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params_);
    auto gvar = mod->GetGlobalVar("main");
    mod->Add(gvar, f);
  }

  // Run the standard optimization prefix, then fuse operators so that each fused
  // primitive function becomes one extraction candidate.
  Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogeneous=*/true, /*is_vm=*/true);
  pass_seqs.push_back(transform::FuseOps());

  transform::Sequential seq(pass_seqs);
  auto opt_mod = seq(std::move(mod));

  Array<ExtractedTask> tasks;
  PostOrderVisit(opt_mod->Lookup("main"), [target, &tasks](const Expr& exp) {
    if (exp->IsInstance<FunctionNode>()) {
      Function relay_func = Downcast<Function>(exp);
      if (relay_func->HasNonzeroAttr(attr::kPrimitive)) {
        // Lower the fused function to TE, then build a TIR PrimFunc from the
        // resulting tensor expressions.
        Array<te::Tensor> outputs;
        std::string fused_name;
        std::tie(outputs, fused_name) = tec::LowerTECompute(target, relay_func);
        auto prim_func = tir::CreatePrimFunc(outputs);
        auto prim_fn_var = GlobalVar(fused_name);
        auto relay_mod = IRModule({{prim_fn_var, relay_func}});
        auto tir_mod = IRModule({{prim_fn_var, prim_func}});
        tasks.push_back(ExtractedTask(prim_fn_var->name_hint, relay_mod, target, {tir_mod}));
      }
    }
  });

  return tasks;
}

TVM_REGISTER_GLOBAL("relay.backend.MetaScheduleExtractTask")
    .set_body_typed([](IRModule mod, Target target, Map<String, Constant> params) {
      return ExtractTask(mod, target, params);
    });

} // namespace metaschedule
} // namespace backend
} // namespace relay
} // namespace tvm
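Continuing the Python sketch above, each ExtractedTask produced by this visitor pairs the fused Relay function with its lowered TIR module. Assuming the Python bindings mirror the C++ constructor ExtractedTask(name_hint, relay_mod, target, {tir_mod}), the pieces can be inspected like this:

# Hypothetical inspection of the tasks extracted earlier.
for task in tasks:
    print(task.task_name)      # name_hint of the fused function's GlobalVar
    print(task.mod)            # IRModule wrapping the fused Relay function
    print(task.dispatched[0])  # IRModule holding the TIR PrimFunc from CreatePrimFunc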
18 changes: 18 additions & 0 deletions src/relay/backend/te_compiler_cache.cc
@@ -754,6 +754,24 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
  return MakeShapeFunc().Create(prim_func, target, renamer);
}

std::pair<Array<te::Tensor>, std::string> LowerTECompute(Target target, const Function& relay_func,
                                                         bool return_inputs) {
  LowerToTECompute lower_te_compute(target);
  // Use an identity renamer: keep the names generated during lowering as-is.
  auto outputs = lower_te_compute.Lower(relay_func, [&](std::string name) { return name; });
  // Following ScheduleBuilder, remove placeholder ops from outputs.
  tvm::Array<te::Tensor> tensor_outs;
  for (const auto& tensor : outputs) {
    if (!tensor->op.as<te::PlaceholderOpNode>()) {
      tensor_outs.push_back(tensor);
    }
  }
  if (return_inputs) {
    return std::make_pair(Concat(lower_te_compute.fn_inputs_, tensor_outs),
                          lower_te_compute.candidate_name_);
  }
  return std::make_pair(tensor_outs, lower_te_compute.candidate_name_);
}

/*!
* \brief Get unique name from name.
 * \param name The original name.
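The tir::CreatePrimFunc call that consumes these tensors in the new backend has a Python counterpart, te.create_prim_func. Here is a minimal sketch of the same TE-to-TIR conversion; the toy compute is an illustrative assumption, not code from this commit:

import tvm
from tvm import te

# A toy TE compute standing in for the tensors LowerTECompute returns.
A = te.placeholder((128, 128), name="A")
B = te.compute((128, 128), lambda i, j: A[i, j] * 2.0, name="B")

# Build a TIR PrimFunc from the tensor expressions, as CreatePrimFunc does in C++.
prim_func = te.create_prim_func([A, B])
print(prim_func.script())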
2 changes: 2 additions & 0 deletions src/relay/backend/te_compiler_cache.h
@@ -204,6 +204,8 @@ class CCacheValue : public ObjectRef {

Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);

/*!
 * \brief Lower a fused primitive Relay function to TE compute. Returns the output tensors,
 * optionally prepended with the function inputs, together with a candidate name for the
 * fused function.
 */
std::pair<Array<te::Tensor>, std::string> LowerTECompute(Target target, const Function& relay_func,
                                                         bool return_inputs = true);

/*!
* \brief Create schedule for target.
* \param source_func The primitive function to be lowered.
