Decoupling AOT from graph memory planner
In this PR we decouple AOT from the Graph Memory Planner (GMP). Since AOT expresses the runner function in TIR, we can drop the Relay-level GMP and use the StorageRewrite pass to do memory planning directly on the runner function. This also resolves the issue mentioned in apache#8062.

Change-Id: I6e33fadbf0462edf0366ee37e84ffde26123d3cb
Giuseppe Rossini committed May 21, 2021
1 parent dbd076a commit 443e7a8
Showing 3 changed files with 257 additions and 56 deletions.
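In short, the diff below replaces the call to relay.backend.GraphPlanMemory inside Codegen() with a new Relay-level AOTOnDemandAllocator, and delegates buffer reuse to the TIR StorageRewrite pass. The following is a minimal standalone sketch of the new planning step, condensed from the diff; the helper name PlanRunnerMemory is hypothetical, while the pass, symbol, and types are the ones used below.

// Condensed sketch of the new memory-planning flow (helper name is hypothetical).
#include <tvm/ir/module.h>
#include <tvm/runtime/module.h>
#include <tvm/tir/function.h>
#include <tvm/tir/transform.h>

tvm::IRModule PlanRunnerMemory(const tvm::tir::PrimFunc& runner) {
  // Wrap the TIR runner into a module keyed by the standard AOT entry symbol.
  tvm::Map<tvm::GlobalVar, tvm::BaseFunc> symbol_map;
  symbol_map.Set(tvm::GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), runner);
  tvm::IRModule mod_run(symbol_map);
  // StorageRewrite merges and reuses the tir::Allocate nodes emitted by the
  // codegen, taking over the job previously done by the Graph Memory Planner.
  return tvm::tir::transform::StorageRewrite()(mod_run);
}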
252 changes: 196 additions & 56 deletions src/relay/backend/aot_executor_codegen.cc
@@ -31,6 +31,7 @@
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/transform.h>

#include <algorithm>
#include <list>
@@ -44,52 +45,179 @@ namespace tvm {
namespace relay {
namespace backend {

/*!
 * \brief Struct containing information about the intermediate variables in the
 * runner function.
 */
struct StorageInfo {
/*! \brief unique integer identifiers of the intermediate variables */
std::vector<int> ids;
/*! \brief exact sizes of the temporaries, in bytes */
std::vector<int> sizes_bytes;
/*! \brief device types of the temporary variables */
std::vector<int> dev_types;
};
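// Example with illustrative values: a call producing a 2-tuple of tensors
// could be recorded as ids = {3, 4}, sizes_bytes = {64, 128}, dev_types = {1, 1};
// the three vectors are parallel arrays, one entry per produced tensor.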

using IntegerArray = Array<Integer>;
using TargetsMap = std::unordered_map<int, Target>;
using StorageMap =
std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;

class AotReturnSidVisitor : public ExprVisitor {
/*!
 * \brief This is an on-demand allocator for AOT. A new temporary
 * (storage allocator identifier) is allocated for each operation.
 */
class AOTOnDemandAllocator : public ExprVisitor {
public:
explicit AotReturnSidVisitor(Map<Expr, Array<IntegerArray>> storage_device_map)
: storage_device_map_{storage_device_map}, return_sid_{-1} {}
// run the visitor on a function.
void Run(const Function& func) {
node_device_map_ = CollectDeviceInfo(func);

IntegerArray FindReturnSid(Function func) {
VisitExpr(func->body);
return return_sid_;
for (Expr param : func->params) {
CreateSid(param.operator->());
}

GetSid(func->body);
}

protected:
void AssignReturnSid(Expr e) {
auto iter = storage_device_map_.find(e);
if (iter != storage_device_map_.end()) {
return_sid_ = (*iter).second[0];
std::vector<int> GetReturnIds() const { return return_ids_; }

StorageMap GetStorageMap() const { return storage_device_map_; }

void VisitExpr_(const ConstantNode* op) final {
CreateSid(op);
AssignReturnSid(GetRef<Expr>(op));
}

void VisitExpr_(const CallNode* op) final {
// Create a storage id for the call node.
CreateSid(op);
for (Expr arg : op->args) {
GetSid(arg);
}
AssignReturnSid(GetRef<Expr>(op));
}

void VisitExpr_(const ConstantNode* cn) override {
ExprVisitor::VisitExpr_(cn);
AssignReturnSid(GetRef<Expr>(cn));
void VisitExpr_(const VarNode* op) final {
ExprVisitor::VisitExpr_(op);
AssignReturnSid(GetRef<Expr>(op));
}

void VisitExpr_(const VarNode* vn) override {
ExprVisitor::VisitExpr_(vn);
AssignReturnSid(GetRef<Expr>(vn));
void VisitExpr_(const FunctionNode* op) final {
// Do not recurse into sub-functions.
}

void VisitExpr_(const CallNode* cn) override {
ExprVisitor::VisitExpr_(cn);
AssignReturnSid(GetRef<Expr>(cn));
void VisitExpr_(const GlobalVarNode* op) final {
// Do nothing.
}

void VisitExpr_(const LetNode* op) override { VisitExpr(op->body); }
void VisitExpr_(const OpNode* op) final {
// Do nothing.
}

void VisitExpr_(const TupleNode* op) final {
StorageInfo field_sid;
Expr expr = GetRef<Expr>(op);
for (Expr field : op->fields) {
auto sid = GetSid(field);
field_sid.ids.insert(field_sid.ids.end(), sid.ids.begin(), sid.ids.end());
field_sid.dev_types.insert(field_sid.dev_types.end(), sid.dev_types.begin(),
sid.dev_types.end());
field_sid.sizes_bytes.insert(field_sid.sizes_bytes.end(), sid.sizes_bytes.begin(),
sid.sizes_bytes.end());
}

storage_device_map_[expr] = field_sid;
AssignReturnSid(expr);
}

void VisitExpr_(const TupleNode* tn) override {
ExprVisitor::VisitExpr_(tn);
AssignReturnSid(GetRef<Expr>(tn));
void VisitExpr_(const TupleGetItemNode* op) final {
Expr expr = GetRef<Expr>(op);
const auto& sid = GetSid(op->tuple);
ICHECK_LT(static_cast<size_t>(op->index), sid.ids.size());
storage_device_map_[expr].ids = {sid.ids[op->index]};
storage_device_map_[expr].sizes_bytes = {sid.sizes_bytes[op->index]};
storage_device_map_[expr].dev_types = {sid.dev_types[op->index]};
AssignReturnSid(expr);
}

void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; }

void VisitExpr_(const LetNode* op) final { LOG(FATAL) << "let is not supported."; }

private:
Map<Expr, Array<IntegerArray>> storage_device_map_;
IntegerArray return_sid_;
void AssignReturnSid(Expr e) {
auto iter = storage_device_map_.find(e);
if (iter != storage_device_map_.end()) {
return_ids_ = (*iter).second.ids;
}
}
/*!
* \brief ceil(size/word_size) to get number of words.
* \param size The original size.
* \param word_size The element size.
*/
static size_t DivRoundUp(size_t size, size_t word_size) {
return (size + word_size - 1) / word_size;
}
/*!
 * \brief Get the memory requirement of a tensor type.
 * \param ttype The tensor type node.
 * \return The required memory size in bytes.
 */
size_t GetMemorySize(const TensorTypeNode* ttype) {
ICHECK(ttype != nullptr);
size_t size = 1;
for (IndexExpr dim : ttype->shape) {
const int64_t* pval = tir::as_const_int(dim);
ICHECK(pval != nullptr) << "Cannot allocate memory for a tensor with symbolic shape " << ttype->shape;
ICHECK_GE(*pval, 0) << "Cannot allocate memory for a tensor with negative shape dimension " << *pval;
size *= static_cast<size_t>(pval[0]);
}
size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8);
return size;
}
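// Worked example (illustrative): a float32 tensor of shape (2, 3) needs
// size = 2 * 3 * DivRoundUp(32 * 1, 8) = 6 * 4 = 24 bytes.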
/*!
 * \brief Get the storage info of an expression, visiting it first if necessary.
 * \param expr The expression.
 * \return The corresponding storage info.
 */
StorageInfo GetSid(const Expr& expr) {
this->VisitExpr(expr);
auto it = storage_device_map_.find(expr);
ICHECK(it != storage_device_map_.end());
return it->second;
}

void CreateSid(const ExprNode* op) {
StorageInfo sid;
Expr expr = GetRef<Expr>(op);
int device_type = node_device_map_.count(GetRef<Expr>(op)) ? node_device_map_[expr]->value : 0;
if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
for (Type t : tuple_type->fields) {
const auto* ttype = t.as<TensorTypeNode>();
ICHECK(ttype);
sid.ids.push_back(sid_++);
sid.dev_types.push_back(device_type);
sid.sizes_bytes.push_back(GetMemorySize(ttype));
}
} else {
const auto* ttype = op->checked_type().as<TensorTypeNode>();
ICHECK(ttype);
sid.ids.push_back(sid_++);
sid.dev_types.push_back(device_type);
sid.sizes_bytes.push_back(GetMemorySize(ttype));
}
storage_device_map_[expr] = sid;
}
/*! \brief mapping of expression -> StorageInfo */
StorageMap storage_device_map_;
/*! \brief mapping of expression -> device type */
Map<Expr, Integer> node_device_map_;
/*! \brief current id of the allocated temporaries */
int sid_{0};
/*! \brief the set of identifiers that are return variables */
std::vector<int> return_ids_;
};
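// Hypothetical usage sketch (not part of this commit), mirroring how
// Codegen() below drives the allocator:
//
//   AOTOnDemandAllocator aot_allocator;
//   aot_allocator.Run(func);                                 // assign sids to every expression
//   StorageMap storage = aot_allocator.GetStorageMap();      // expr -> StorageInfo
//   std::vector<int> outputs = aot_allocator.GetReturnIds(); // sids of the function outputs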

/*! \brief Code generator for AOT executor */
@@ -120,14 +248,14 @@ class AOTExecutorCodegen : public ExprVisitor {
* \brief Return a vector of variables that represents the sids for the given Relay Expr
*/
std::vector<tir::Var> PackSid(Expr expr) {
Array<IntegerArray> sids = storage_device_map_[expr];
auto sids = storage_device_map_[expr];
std::vector<tir::Var> sid_vars;

// Note that an expression can have multiple sids associated with it
// e.g., returning multiple values from a function
for (const auto& sid : sids[0]) {
for (const auto& sid : sids.ids) {
// Determine if an sid is an output buffer
int sid_int = static_cast<int>((sid.as<IntImmNode>())->value);
int sid_int = sid;
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid_int);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
@@ -390,8 +518,8 @@ class AOTExecutorCodegen : public ExprVisitor {
}

ICHECK_GE(storage_device_map_.count(expr), 0);
auto& device_type = storage_device_map_[expr][1];
auto call_dev_type = device_type[0]->value;
auto& device_type = storage_device_map_[expr].dev_types;
auto call_dev_type = device_type[0];
// Normal Relay Function
if (targets_.size() == 1) {
// homogeneous execution.
@@ -428,14 +556,14 @@

// If the Var node is an output node we need to copy the content of the variable to the output.
// It's safe to check the SID here because Var storage tokens are never reallocated.
Array<IntegerArray> sids = storage_device_map_[expr];
auto sids = storage_device_map_[expr];

auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
static_cast<int>((sids[0][0].as<IntImmNode>())->value));
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sids.ids[0]);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
auto var_expr = FindExpr(expr);
CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0], sids[2][0]);
CopyToOutput(main_signature_[input_vars_.size() + output_index], var_expr[0],
sids.sizes_bytes[0]);
}
}

@@ -444,18 +572,18 @@
size_t index = params_.size();
std::string name = "p" + std::to_string(index);

param_storage_ids_[name] = storage_device_map_[expr][0][0]->value;
param_storage_ids_[name] = storage_device_map_[expr].ids[0];
params_[name] = op->data;
params_by_expr_.Set(expr, name);

// If the Constant node is an output node we need to copy the content of the parameter to the
// output. A Constant node can only produce a single output.
Array<IntegerArray> sids = storage_device_map_[expr];
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(),
static_cast<int>((sids[0][0].as<IntImmNode>())->value));
auto sids = storage_device_map_[expr];
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sids.ids[0]);
if (output_iter != return_sid_.end()) {
int output_index = std::distance(return_sid_.begin(), output_iter);
CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr), sids[2][0]);
CopyToOutput(main_signature_[input_vars_.size() + output_index], PackParam(expr),
sids.sizes_bytes[0]);
}
}

@@ -511,9 +639,9 @@ class AOTExecutorCodegen : public ExprVisitor {
continue;
}

for (unsigned int i = 0; i < kv.second[0].size(); i++) {
int size = kv.second[2][i];
int sid = static_cast<int>((kv.second[0][i].as<IntImmNode>())->value);
for (unsigned int i = 0; i < kv.second.ids.size(); i++) {
int size = kv.second.sizes_bytes[i];
int sid = kv.second.ids[i];

if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) {
continue;
@@ -523,6 +651,8 @@
// so we don't pay the price of allocation for every inference
if (!allocated[sid]) {
body = tir::Allocate(sids_table_[sid], DataType::Int(8), {size}, tir::const_true(), body);
body = tir::AttrStmt(sids_table_[sid], tir::attr::storage_scope, tir::StringImm("global"),
body);
}
allocated[sid] = true;
}
@@ -566,7 +696,8 @@ class AOTExecutorCodegen : public ExprVisitor {
std::unordered_map<std::string, int64_t> param_storage_ids_;

/*! \brief plan memory of device result */
Map<Expr, Array<IntegerArray>> storage_device_map_;
StorageMap storage_device_map_;
/*! \brief mapping sid -> tir::Var */
std::unordered_map<int, te::Var> sids_table_;
/*! \brief lowered funcs */
std::unordered_map<std::string, IRModule> lowered_funcs_;
Expand All @@ -577,7 +708,7 @@ class AOTExecutorCodegen : public ExprVisitor {
/*! \brief the set of statements that make the program */
std::vector<tir::Stmt> stmts_;
/*! \brief the list of return sids (note that the function might return more than one output) */
IntegerArray return_sid_;
std::vector<int> return_sid_;

public:
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
@@ -588,9 +719,11 @@ class AOTExecutorCodegen : public ExprVisitor {
}

LoweredOutput Codegen(relay::Function func) {
// Get the module, storage map and token sizes
auto pf = GetPackedFunc("relay.backend.GraphPlanMemory");
storage_device_map_ = (*pf)(func);
auto aot_allocator = AOTOnDemandAllocator();
aot_allocator.Run(func);

// Retrieve the storage map
storage_device_map_ = aot_allocator.GetStorageMap();

int input_index = 0;
for (auto input : func->params) {
@@ -600,14 +733,14 @@

// Define the storage allocator ids
for (auto kv : storage_device_map_) {
for (const auto& sid : kv.second[0]) {
for (const auto& sid : kv.second.ids) {
te::Var sid_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8))));
sids_table_[sid] = sid_var;
}
}

// Find the return sid
return_sid_ = AotReturnSidVisitor(storage_device_map_).FindReturnSid(func);
// Retrieve the return sids
return_sid_ = aot_allocator.GetReturnIds();
for (unsigned int output_index = 0; output_index < return_sid_.size(); output_index++) {
main_signature_.push_back(tir::Var(MakeString("output_", output_index), DataType::Handle()));
}
@@ -635,14 +768,21 @@
}
ret.external_mods = compile_engine_->LowerExternalFunctions();

// Build the TIR IRModule
Map<GlobalVar, BaseFunc> symbol_map;
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
IRModule mod_run(symbol_map);

// Apply storage rewrite pass to the runner function to do memory planning
auto storage_rewrite = tir::transform::StorageRewrite();
mod_run = storage_rewrite(mod_run);

// Update the lowered functions
auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_str]->Add(
GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
ret.lowered_funcs[target_host_str]->Update(mod_run);
} else {
Map<GlobalVar, BaseFunc> symbol_map;
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_prefix), prim_func);
ret.lowered_funcs.Set(target_host_str, IRModule(symbol_map));
ret.lowered_funcs.Set(target_host_str, mod_run);
}
ret.function_metadata = std::move(function_metadata_);
ret.metadata =