forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[TIR][Transform] Implement InlinePrivateFunctions
The functionality to express a call from one `PrimFunc` to another was introduced in apache#14889. While this was initially planned to be supported at codegen for all targets (see apache#15835), this resulted in breakage on some backends (see apache#16033). After discussion, the plan was changed to support TIR inlining, which would enable the same high-level functionality in TIR without requiring immediate low-level support across all codegens. This commit implements and tests a new IRModule transform `InlinePrivateFunctions`, which can be used as part of lowering in a follow-up commit. Because this is initially implemented for use quite late in the lowering flow, many constructs are not currently supported. The current implementation has the following restrictions. * `tir::Block` nodes may not occur in the inlined function. Because a subroutine may be called multiple times, inlining of a subroutine that contains `tir::Block` would result in non-unique names. Support of subroutines with `tir::Block` instances will require de-duplication of block names. * The subroutine's callsite must occur within a `tir::Evaluate` block. Because inlining a subroutine inserts the `tir::Stmt` body at the point of use, replacement must occur in a context where a `tir::Stmt` can be returned. Support of subroutines that are called within an expression (e.g. Replacing `func` in `Buf[0] = func(1) + func(2)`) would require hoisting preprocessing done in the subroutine to the parent `tir::Stmt`. * The subroutine may only accept primitive arguments, and must have an empty `buffer_map`. Support of subroutines that are called with `tir::Buffer` or `tir::BufferRegion` arguments would require a way to represent these arguments at the callsite, and substitution of the buffer into the callee. If these unsupported constructs are used, then the inlining of those functions is skipped. This commit includes unit tests for these unsupported constructs, to validate that `InlinePrivateFunctions` produces well-formed output even when they are present.
- Loading branch information
1 parent
85e911a
commit 5dda59d
Showing
4 changed files
with
543 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,272 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file inline_private_functions.cc | ||
* \brief Inline private functions to their callsite | ||
*/ | ||
#include <tvm/runtime/registry.h> | ||
#include <tvm/tir/analysis.h> | ||
#include <tvm/tir/builtin.h> | ||
#include <tvm/tir/op.h> | ||
#include <tvm/tir/stmt.h> | ||
#include <tvm/tir/stmt_functor.h> | ||
#include <tvm/tir/transform.h> | ||
|
||
namespace tvm { | ||
namespace tir { | ||
namespace transform { | ||
|
||
namespace { | ||
|
||
template <typename T> | ||
using PSet = std::unordered_set<T, ObjectPtrHash, ObjectPtrEqual>; | ||
|
||
template <typename T, typename U> | ||
using PMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>; | ||
|
||
PMap<GlobalVar, PSet<GlobalVar>> CollectCallMap(const IRModule& mod) { | ||
struct Visitor : StmtExprVisitor { | ||
GlobalVar current; | ||
PMap<GlobalVar, PSet<GlobalVar>> caller_lookup; | ||
|
||
void VisitExpr_(const CallNode* op) { | ||
if (auto gvar = op->op.as<GlobalVar>()) { | ||
caller_lookup[gvar.value()].insert(current); | ||
} | ||
StmtExprVisitor::VisitExpr_(op); | ||
} | ||
} visitor; | ||
|
||
for (const auto& [gvar, base_func] : mod->functions) { | ||
if (auto prim_func = base_func.as<PrimFuncNode>()) { | ||
visitor.current = gvar; | ||
visitor(prim_func->body); | ||
} | ||
} | ||
|
||
return visitor.caller_lookup; | ||
} | ||
|
||
PSet<GlobalVar> CollectRecursiveFunctions(const IRModule& mod) { | ||
// Collect all direct callers | ||
auto call_map = CollectCallMap(mod); | ||
|
||
// Propagate to find all indirect callers | ||
while (true) { | ||
bool made_change = false; | ||
for (const auto& [callee, callers] : call_map) { | ||
for (const auto& caller : callers) { | ||
if (auto it = call_map.find(caller); it != call_map.end()) { | ||
PSet<GlobalVar>& indirect_callers = it->second; | ||
|
||
auto res = indirect_callers.insert(callee); | ||
made_change = made_change || res.second; | ||
} | ||
} | ||
} | ||
if (!made_change) { | ||
break; | ||
} | ||
} | ||
|
||
// Filter all GlobalVars that can be called by themselves, either | ||
// directly or indirectly. | ||
PSet<GlobalVar> recursive_funcs; | ||
for (const auto& [caller, callees] : call_map) { | ||
if (callees.count(caller)) { | ||
recursive_funcs.insert(caller); | ||
} | ||
} | ||
return recursive_funcs; | ||
} | ||
|
||
Map<GlobalVar, PrimFunc> CollectInlinablePrimFuncs(const IRModule& mod) { | ||
auto recursive_functions = CollectRecursiveFunctions(mod); | ||
|
||
Map<GlobalVar, PrimFunc> output; | ||
for (const auto& [gvar, base_func] : mod->functions) { | ||
if (auto opt = base_func.as<PrimFunc>()) { | ||
auto prim_func = opt.value(); | ||
|
||
// Only inline private functions. Externally-exposed functions | ||
// must be preserved so to avoid breaking callsites outside of | ||
// the IRModule. | ||
bool is_exposed = prim_func->GetAttr<String>(tvm::attr::kGlobalSymbol).defined(); | ||
|
||
// We do not currently implement any analysis for termination of | ||
// a function. If a recursive function requires runtime checks | ||
// in order to terminate, we would keep inlining until the | ||
// recursive visits segfault. | ||
bool is_recursive = recursive_functions.count(gvar); | ||
|
||
// We do not currently support inlining of functions that accept | ||
// buffer arguments. | ||
bool has_buffer_arguments = prim_func->buffer_map.size(); | ||
|
||
// We do not currently support inlining of schedulable TIR | ||
// functions. To support this use case, repeated names in | ||
// `tir::Block` nodes resulting from multiple calls to the same | ||
// inlined function will need to be de-duplicated. | ||
bool has_block_node = prim_func->body.as<BlockRealizeNode>(); | ||
|
||
if (!is_exposed && !is_recursive && !has_buffer_arguments && !has_block_node) { | ||
output.Set(gvar, prim_func); | ||
} | ||
} | ||
} | ||
|
||
return output; | ||
} | ||
|
||
class PrimFuncInliner : StmtExprMutator { | ||
public: | ||
PrimFuncInliner(Map<GlobalVar, PrimFunc> inlinable_funcs) : inlinable_funcs_(inlinable_funcs) { | ||
for (const auto& [gvar, callee] : inlinable_funcs_) { | ||
removable_funcs_.insert(gvar); | ||
} | ||
} | ||
|
||
PrimFunc VisitFunc(PrimFunc func) { | ||
current_target_ = func->GetAttr<Target>(tvm::attr::kTarget); | ||
auto new_body = VisitStmt(func->body); | ||
current_target_ = NullOpt; | ||
|
||
if (!new_body.same_as(func->body)) { | ||
func.CopyOnWrite()->body = new_body; | ||
} | ||
|
||
return func; | ||
} | ||
|
||
PSet<GlobalVar> GetRemovableFunctions() const { return removable_funcs_; } | ||
|
||
private: | ||
Stmt VisitStmt_(const EvaluateNode* eval) override { | ||
if (auto call = eval->value.as<CallNode>()) { | ||
if (auto gvar = call->op.as<GlobalVar>()) { | ||
if (auto opt_callee = inlinable_funcs_.Get(gvar.value())) { | ||
auto callee = opt_callee.value(); | ||
|
||
bool is_same_target = [&]() -> bool { | ||
auto callee_target = callee->GetAttr<Target>(tvm::attr::kTarget); | ||
if (current_target_ && callee_target) { | ||
return callee_target.value()->str() == current_target_.value()->str(); | ||
} else { | ||
return true; | ||
} | ||
}(); | ||
|
||
if (is_same_target) { | ||
Stmt inlined = InlineArguments(gvar.value(), callee, call->args); | ||
return VisitStmt(inlined); | ||
} | ||
} | ||
} | ||
} | ||
|
||
return StmtExprMutator::VisitStmt_(eval); | ||
} | ||
|
||
PrimExpr VisitExpr_(const CallNode* call) override { | ||
// Any callee that hasn't been inlined at this point must be kept | ||
// in the output IRModule. | ||
if (auto gvar = call->op.as<GlobalVar>()) { | ||
removable_funcs_.erase(gvar.value()); | ||
} | ||
return StmtExprMutator::VisitExpr_(call); | ||
} | ||
|
||
Stmt InlineArguments(const GlobalVar& gvar, PrimFunc callee, const Array<PrimExpr>& args) const { | ||
CHECK_EQ(callee->params.size(), args.size()) | ||
<< "Callee " << gvar << " accepts " << callee->params.size() << " parameters (" | ||
<< callee->params << "), but is called with " << args.size() << " arguments (" << args | ||
<< ")"; | ||
|
||
ICHECK(callee->buffer_map.empty()) | ||
<< "Inlining of PrimFuncs with buffer arguments is not yet supported, " | ||
<< "but callee " << gvar << " has non-empty buffer map " << callee->buffer_map; | ||
|
||
Map<Var, ObjectRef> param_map; | ||
for (size_t i = 0; i < callee->params.size(); i++) { | ||
param_map.Set(callee->params[i], args[i]); | ||
} | ||
|
||
callee = Specialize(callee, param_map); | ||
|
||
return callee->body; | ||
} | ||
|
||
// Map from GlobalVar to PrimFuncs which may be inlined. | ||
Map<GlobalVar, PrimFunc> inlinable_funcs_; | ||
|
||
/* \brief Set of callees that may be removed | ||
* | ||
* Some constructs may not be inlined (e.g. if the call site occurs | ||
* outside of an Evaluate node). For these cases, the output | ||
* IRModule must still contain the callee. | ||
*/ | ||
PSet<GlobalVar> removable_funcs_; | ||
|
||
Optional<Target> current_target_ = NullOpt; | ||
}; | ||
|
||
} // namespace | ||
|
||
Pass InlinePrivateFunctions() { | ||
auto pass_func = [](IRModule mod, PassContext ctx) { | ||
auto inlinable_prim_funcs = CollectInlinablePrimFuncs(mod); | ||
|
||
if (inlinable_prim_funcs.empty()) { | ||
// Early bail-out if the module has no inlinable PrimFuncs. | ||
return mod; | ||
} | ||
|
||
PrimFuncInliner mutator(std::move(inlinable_prim_funcs)); | ||
IRModule updates; | ||
|
||
for (const auto& [gvar, base_func] : mod->functions) { | ||
if (auto opt = base_func.as<PrimFunc>()) { | ||
auto updated = mutator.VisitFunc(opt.value()); | ||
if (!updated.same_as(base_func)) { | ||
updates->Add(gvar, updated); | ||
} | ||
} | ||
} | ||
|
||
if (updates->functions.size()) { | ||
auto write_ptr = mod.CopyOnWrite(); | ||
write_ptr->Update(updates); | ||
for (const auto& gvar : mutator.GetRemovableFunctions()) { | ||
write_ptr->Remove(gvar); | ||
} | ||
mod = ConvertSSA()(mod); | ||
} | ||
|
||
return mod; | ||
}; | ||
return tvm::transform::CreateModulePass(pass_func, 0, "tir.InlinePrivateFunctions", {}); | ||
} | ||
|
||
TVM_REGISTER_GLOBAL("tir.transform.InlinePrivateFunctions").set_body_typed(InlinePrivateFunctions); | ||
|
||
} // namespace transform | ||
|
||
} // namespace tir | ||
} // namespace tvm |
Oops, something went wrong.