Skip to content

Commit

Permalink
Add a pass to legalize packed calls
Browse files Browse the repository at this point in the history
Change-Id: I8aa43d3a1b837b03a5cf3c6b32fc760bd78d3436
  • Loading branch information
Giuseppe Rossini committed Jun 22, 2021
1 parent 369745f commit fc7eb40
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 101 deletions.
5 changes: 5 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,11 @@ TVM_DLL Pass ConvertBlocksToOpaque();
*/
TVM_DLL Pass CompactBufferAllocation();

/*!
* This pass legalizes packed calls by wrapping their arguments into TVMValues
*/
TVM_DLL Pass LegalizePackedCalls();

/*!
* \brief Flatten the multi-dimensional BufferLoad and BufferStore
* to single dimensional Load/Store. Also remove Block to
Expand Down
94 changes: 42 additions & 52 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,50 +269,11 @@ class AOTExecutorCodegen : public ExprVisitor {
}

auto sid_value = sids_table_[sid];
if (!use_unpacked_api_) {
// Pack the sid inside the TVMValue
auto sid_array = te::Var(MakeString("sid_", sid, "_value"), DataType::Handle());
tvm::PrimExpr set_tensor =
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{sid_array, 0, tir::builtin::kArrData, sid_value});
stmts_.push_back(
tir::LetStmt(sid_array, StackAlloca("array", 1), tir::Evaluate(set_tensor)));
buffer_vars.push_back(sid_array);
} else {
buffer_vars.push_back(sid_value);
}
buffer_vars.push_back(sid_value);
}
return buffer_vars;
}

/*!
* \brief Utility function to return a parameter associated with an expression
* \param expr Relay Expression associated with the parameter
* \return Variable that represents the DLTensor associated with the parameters
*/
tir::Var PackParam(Expr expr) {
int param_sid = param_storage_ids_[params_by_expr_[expr]];
auto param_array = te::Var(MakeString("param_", param_sid, "_array"), DataType::Handle());

// Compose the lookup_call using a local stack
Array<tir::Stmt> lookup_call;
// Set the param to the value returned by lookup_call
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[expr])});

if (!use_unpacked_api_) {
tvm::PrimExpr set_param_array =
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{param_array, 0, tir::builtin::kArrData, param_handle});
stmts_.push_back(
tir::LetStmt(param_array, StackAlloca("arg_value", 1), tir::Evaluate(set_param_array)));
} else {
stmts_.push_back(tir::LetStmt(param_array, param_handle, tir::Evaluate(0)));
}

return param_array;
}

/*!
* brief Given an expression return the variable(s) associated with that expression
*/
Expand All @@ -322,9 +283,6 @@ class AOTExecutorCodegen : public ExprVisitor {
// Input variable
int main_index = std::distance(input_vars_.begin(), input_iter);
return {main_signature_[main_index]};
} else if (params_by_expr_.find(arg) != params_by_expr_.end()) {
// Parameter of the network
return {PackParam(arg)};
} else {
// Storage identifier (i.e., intermediate memory)
return PackSid(arg);
Expand All @@ -340,8 +298,14 @@ class AOTExecutorCodegen : public ExprVisitor {

// Pack the inputs
for (Expr arg : call->args) {
auto var_arg = FindExpr(arg);
args.push_back(var_arg[0]);
if (params_by_expr_.find(arg) != params_by_expr_.end()) {
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[arg])});
args.push_back(param_handle);
} else {
auto var_arg = FindExpr(arg);
args.push_back(var_arg[0]);
}
}

auto ret_expr = Downcast<Expr>(call);
Expand Down Expand Up @@ -369,7 +333,7 @@ class AOTExecutorCodegen : public ExprVisitor {
* TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a
* copy-on-write fashion.
*/
void CopyToOutput(te::Var out, te::Var in, size_t size) {
void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
// Define intermediate DLTensor to load/store the data
auto tmp0 = te::Var("tmp0", DataType::Handle());
auto tmp1 = te::Var("tmp1", DataType::Handle());
Expand All @@ -381,10 +345,15 @@ class AOTExecutorCodegen : public ExprVisitor {
PrimExpr tostore = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_get(),
{out, 0, tir::builtin::kArrData});
if (use_unpacked_api_) {
retval_get = in;
tostore = out;
}

// Do not pack the input if the flag is set or the caller
// explicitly asked to do so (e.g., copying a param to the output)
if (use_unpacked_api_ || !pack_input) {
retval_get = in;
}

// Copy the variable from the input to the output
tir::Stmt copy = tir::For(
loop_idx, 0, ConstInt32(size), tir::ForKind::kSerial,
Expand Down Expand Up @@ -563,9 +532,16 @@ class AOTExecutorCodegen : public ExprVisitor {
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), buffers[0].sid);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
auto var_expr = FindExpr(expr);
CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0],
buffers[0].size_bytes);
if (params_by_expr_.find(expr) != params_by_expr_.end()) {
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[expr])});
CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle,
/*pack_input*/ true, buffers[0].size_bytes);
} else {
auto var_expr = FindExpr(expr);
CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0],
/*pack_input*/ true, buffers[0].size_bytes);
}
}
}

Expand All @@ -584,7 +560,9 @@ class AOTExecutorCodegen : public ExprVisitor {
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), buffers[0].sid);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr),
auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
{tir::StringImm(params_by_expr_[expr])});
CopyToOutput(main_signature_[input_vars_.size() + output_index], param_handle, false,
buffers[0].size_bytes);
}
}
Expand Down Expand Up @@ -625,7 +603,9 @@ class AOTExecutorCodegen : public ExprVisitor {
throw std::invalid_argument("match case not yet implemented");
}

// Create the main PrimFunc to execute the graph
// Create the main PrimFunc to execute the graph. Please note that
// the packed function calls don't pack their arguments. The AOT
// runner function needs to be legalized by the LegalizePackedCalls pass.
tir::PrimFunc CreateMainFunc(unsigned int relay_params) {
tir::Stmt body = tir::SeqStmt(stmts_);

Expand Down Expand Up @@ -757,6 +737,9 @@ class AOTExecutorCodegen : public ExprVisitor {

VisitExpr(func->body);

// Create the runner function. Please note that the function is not legal yet
// because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
// to run the LegalizePackedCalls pass.
auto prim_func = CreateMainFunc(func->params.size());
UpdateMainWorkspaceSize(prim_func, func);
LoweredOutput ret;
Expand Down Expand Up @@ -787,6 +770,13 @@ class AOTExecutorCodegen : public ExprVisitor {
auto storage_rewrite = tir::transform::StorageRewrite();
mod_run = storage_rewrite(mod_run);

// Legalize AOT if needed. This means that all the packed calls
// need to be wrapped in TVMValues (unless use_unpacked_api is set)
if (!use_unpacked_api_) {
auto pack_calls = tir::transform::LegalizePackedCalls();
mod_run = pack_calls(mod_run);
}

// Update the lowered functions
auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
Expand Down
21 changes: 21 additions & 0 deletions src/tir/transforms/ir_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,27 @@ inline int GetTempAllocaAlignment(DataType type, int32_t const_size) {
return align;
}

/*!
* \brief Create an int32 constant
* \param index the value of the constant
* \return the PrimExpr that represents the constant
*/
inline PrimExpr ConstInt32(size_t index) {
ICHECK_LE(index, std::numeric_limits<int>::max());
return make_const(DataType::Int(32), static_cast<int>(index));
}

/*!
* \brief Allocate TVMValues on the stack
* \param type type of allocation
* \param num number of TVMValues to allocate
* \return PrimExpr representing the TVMValue
*/
inline PrimExpr StackAlloca(std::string type, size_t num) {
Array<PrimExpr> args = {StringImm(type), ConstInt32(num)};
return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args);
}

/*!
* \brief Convert a IR node to be SSA form.
* \param stmt The source statement to be converted.
Expand Down
108 changes: 108 additions & 0 deletions src/tir/transforms/legalize_packed_calls.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file make_packed_call.cc
* \brief Rewrite packed calls in AOT so that the arguments are packed
*/
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include <unordered_map>

#include "ir_utils.h"

namespace tvm {
namespace tir {

using InputMap =
std::unordered_map<PrimExpr, bool, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
/**
* This is a legalization pass only used in AOT. Traverse the TIR graph to legalize
* packed calls by making its argument wrapped in TVMValues (by using tvm_set_struct built-in)
*/
class PackedCallLegalizer : public StmtExprMutator {
public:
Stmt Legalize(const InputMap& params, tir::Stmt body) {
inputs_ = params;
return StmtExprMutator::VisitStmt(body);
}

Stmt VisitStmt_(const EvaluateNode* op) final {
if (tir::is_const_int(op->value)) return StmtExprMutator::VisitStmt_(op);
const CallNode* call = op->value.as<CallNode>();
// Given a packed call f(A,B,C), we need a set of new statements
// let A_packed = set_struct(tvm_value1, A)
// let B_packed = set_struct(tvm_value2, B)
// let C_packed = set_struct(tvm_value3, C)
// call_packed(f, A_packed, B_packed, C_packed)
std::vector<Stmt> new_stmts;
if (call) {
if (call->op.same_as(builtin::tvm_call_cpacked())) {
Array<PrimExpr> packed_args{call->args[0]};
for (unsigned i = 1; i < call->args.size(); i++) {
// No need to pack inputs of the prim_func
if (inputs_[call->args[i]] == true) {
packed_args.push_back(call->args[i]);
} else {
// Pack the argument inside a TVMValue
auto sid_array = tir::Var("tvm_value", DataType::Handle());
tir::Stmt set_struct_stmt = tir::Evaluate(
tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::tvm_struct_set(),
{sid_array, 0, tir::builtin::kArrData, call->args[i]}));
new_stmts.push_back(LetStmt(sid_array, StackAlloca("array", 1), set_struct_stmt));
packed_args.push_back(sid_array);
}
}
// Finally, evaluate the packed call and return a sequential statement
new_stmts.push_back(tir::Evaluate(tir::Call(call->dtype, call->op, packed_args)));
return tir::SeqStmt(new_stmts);
}
}
return StmtExprMutator::VisitStmt_(op);
}

private:
InputMap inputs_; // Store the inputs to the primfunc that don't need to be packed.
};

namespace transform {

Pass LegalizePackedCalls() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();

// Create the
InputMap inputs;
for (auto i : f->params) {
inputs[i] = true;
}
n->body = PackedCallLegalizer().Legalize(inputs, std::move(n->body));
return std::move(f);
};
return CreatePrimFuncPass(pass_func, 0, "tir.LegalizePackedCalls", {});
}
} // namespace transform

} // namespace tir
} // namespace tvm
10 changes: 0 additions & 10 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,6 @@
namespace tvm {
namespace tir {

inline PrimExpr ConstInt32(size_t index) {
ICHECK_LE(index, std::numeric_limits<int>::max());
return make_const(DataType::Int(32), static_cast<int>(index));
}

inline PrimExpr StackAlloca(std::string type, size_t num) {
Array<PrimExpr> args = {StringImm(type), ConstInt32(num)};
return Call(DataType::Handle(), builtin::tvm_stack_alloca(), args);
}

// Calculate the statistics of packed function.
// These information are needed during codegen.
class BuiltinLower : public StmtExprMutator {
Expand Down
Loading

0 comments on commit fc7eb40

Please sign in to comment.