Skip to content

Commit

Permalink
[checkpoints] Woops, can't memoize in VMFunctionCompiler!
Browse files Browse the repository at this point in the history
  • Loading branch information
mbs-octoml committed Sep 15, 2021
1 parent cb22588 commit f5e2448
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 32 deletions.
12 changes: 6 additions & 6 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ int GetFallbackDevice() {
return fallback_dev->value;
}

class VMFunctionCompiler : DeviceAwareExprVisitor {
class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
public:
VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host)
: last_register_(0), registers_num_(0), context_(context), target_host_(target_host) {
Expand Down Expand Up @@ -313,7 +313,7 @@ class VMFunctionCompiler : DeviceAwareExprVisitor {
/*! \brief Allocates the next fresh virtual register and returns its index. */
size_t NewRegister() {
  size_t reg = registers_num_;
  registers_num_ = reg + 1;
  return reg;
}

inline void Emit(const Instruction& instr) {
DLOG(INFO) << "VMCompiler::Emit: instr=" << instr;
VLOG(1) << "VMCompiler::Emit: instr=" << instr;
ICHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
switch (instr.op) {
case Opcode::AllocADT:
Expand Down Expand Up @@ -343,7 +343,7 @@ class VMFunctionCompiler : DeviceAwareExprVisitor {
instructions_.push_back(instr);
}

using DeviceAwareExprVisitor::VisitExpr_;
using DeviceAwareExprFunctor<void(const Expr&)>::VisitExpr_;

void VisitExpr_(const ConstantNode* const_node) final {
// Check the shape is valid
Expand Down Expand Up @@ -698,8 +698,8 @@ class VMFunctionCompiler : DeviceAwareExprVisitor {
auto global = GetRef<GlobalVar>(global_node);
auto it = context_->global_map.find(global);
ICHECK(it != context_->global_map.end());
DLOG(INFO) << "VisitExpr_: generating invoke for " << global->name_hint
<< " with func_index=" << it->second;
VLOG(1) << "VisitExpr_: generating invoke for " << global->name_hint
<< " with func_index=" << it->second;

// TODO(tvm-team):
// Think about mixed call into global that is not a relay::Function
Expand Down Expand Up @@ -937,7 +937,7 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe

#if USE_RELAY_DEBUG
for (auto vm_func : exec_->functions) {
DLOG(INFO) << vm_func << "-------------";
VLOG(1) << vm_func << "-------------";
}
#endif // USE_RELAY_DEBUG

Expand Down
122 changes: 122 additions & 0 deletions src/relay/transforms/device_planner.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

#include <vector>

#include "../attrs/annotation.h"

namespace tvm {
namespace relay {
namespace transform {
Expand Down Expand Up @@ -97,6 +99,126 @@ class LexicalOnDeviceMixin {
var_device_types_;
};

template <typename FType>
class DeviceAwareExprFunctor;

/*!
 * \brief ExprFunctor which tracks devices. We only support 'visitor' style implementation
 * with no additional arguments, thus this is equivalent to \p DeviceAwareExprVisitor without
 * any memoization.
 *
 * Device scopes are maintained via the \p LexicalOnDeviceMixin base: bound variables,
 * function results and fixed "on_device" annotations are pushed on entry to their lexical
 * scope and popped (in reverse order) on exit. Subclasses override the
 * \p DeviceAwareVisitExpr_ and \p Pre/PostVisitLet* hooks rather than \p VisitExpr_ directly.
 */
template <>
class DeviceAwareExprFunctor<void(const Expr& n)> : public ExprFunctor<void(const Expr& n)>,
                                                    public LexicalOnDeviceMixin {
 private:
  using TSuper = ExprFunctor<void(const Expr& n)>;

 public:
  /*!
   * \brief Tracks parameter and result devices around the visit of \p function_node.
   * Primitive functions are dispatched without any device tracking.
   */
  void VisitExpr_(const FunctionNode* function_node) {
    if (function_node->HasNonzeroAttr(attr::kPrimitive)) {
      // No tracking inside primitive functions.
      return DeviceAwareVisitExpr_(function_node);
    } else {
      // Function parameters come into scope.
      for (size_t i = 0; i < function_node->params.size(); ++i) {
        PushBoundVar(function_node->params[i], GetFunctionParamDeviceType(function_node, i));
      }
      // Entering scope of function body.
      PushDeviceType(GetFunctionResultDeviceType(function_node));
      EnterFunctionBody();

      DeviceAwareVisitExpr_(function_node);

      // Leaving scope of function body. Pops mirror the pushes above, in reverse order.
      ExitFunctionBody();
      PopDeviceType();
      // Function parameters go out of scope.
      for (size_t i = 0; i < function_node->params.size(); ++i) {
        PopBoundVar(function_node->params[i]);
      }
    }
  }

  /*!
   * \brief Iteratively (not recursively) visits a chain of let bindings, tracking the
   * device of each let-bound variable, then fires the Post* hooks in reverse binding order.
   */
  void VisitExpr_(const LetNode* let_node) {
    PreVisitLetBlock_(let_node);
    std::vector<const LetNode*> bindings;
    Expr expr = GetRef<Expr>(let_node);
    while (const auto* inner_let_node = expr.as<LetNode>()) {
      // Let-bound var (in pre visited version) goes into scope.
      // (We'll just assume this is a letrec.)
      PushBoundVar(inner_let_node->var, GetInScopeDeviceType(inner_let_node->value));
      PreVisitLetBinding_(inner_let_node->var, inner_let_node->value);
      bindings.emplace_back(inner_let_node);
      expr = inner_let_node->body;
    }

    // Visit the body of the innermost let.
    VisitExpr(expr);

    for (auto itr = bindings.rbegin(); itr != bindings.rend(); ++itr) {
      // Let-bound var goes out of scope.
      const LetNode* visited_let_node = *itr;
      PopBoundVar(visited_let_node->var);
      PostVisitLet_(visited_let_node);
    }
    PostVisitLetBlock_(let_node);
  }

  /*!
   * \brief Unwraps fixed "on_device" annotation calls, tracking their device over the
   * visit of the annotated body; all other calls are dispatched unchanged.
   */
  void VisitExpr_(const CallNode* call_node) {
    auto props = GetOnDeviceProps(call_node);
    if (props.body.defined() && props.is_fixed) {
      // Entering lexical scope of fixed "on_device" call.
      PushDeviceType(props.device_type);
      VisitExpr(props.body);
      // Leaving lexical scope of "on_device" call.
      PopDeviceType();
    } else {
      DeviceAwareVisitExpr_(call_node);
    }
  }

  /*!
   * \brief These are as for VisitExpr_. Devices for expressions and function parameters will be
   * tracked automatically. Default implementation defers to ExprFunctor::VisitExpr_. For
   * functions the function_nesting count will already include that of \p function_node.
   */

  virtual void DeviceAwareVisitExpr_(const FunctionNode* function_node) {
    return TSuper::VisitExpr_(function_node);
  }

  virtual void DeviceAwareVisitExpr_(const CallNode* call_node) {
    return TSuper::VisitExpr_(call_node);
  }

  /*!
   * \brief Visit the first let in a chain of let expressions before any let bindings or final
   * body has been visited. Default implementation is a no-op.
   */
  virtual void PreVisitLetBlock_(const LetNode* let_node) { /* no-op */
  }

  /*!
   * \brief Visit a let-bound expression before the let body has been visited. Devices for the
   * let-bound variable will be tracked automatically. Default implementation just visits var and
   * value.
   */
  virtual void PreVisitLetBinding_(const Var& var, const Expr& value) {
    VisitExpr(var);
    VisitExpr(value);
  }

  /*!
   * \brief Visit a let expression after the let-bound value and body have been visited.
   * Default implementation is a no-op.
   */
  virtual void PostVisitLet_(const LetNode* let_node) { /* no-op */
  }

  /*!
   * \brief Visit the first let in a chain of let expressions after it has been visited.
   * Default implementation is a no-op.
   */
  virtual void PostVisitLetBlock_(const LetNode* let_node) {}
};

/*! \brief ExprVisitor which tracks devices. */
class DeviceAwareExprVisitor : public ExprVisitor, public LexicalOnDeviceMixin {
public:
Expand Down
61 changes: 35 additions & 26 deletions src/relay/transforms/memory_alloc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@
#include "./pattern_utils.h"

using namespace tvm::runtime;
using namespace tvm::relay::tec;

namespace tvm {
namespace relay {
Expand Down Expand Up @@ -255,16 +254,21 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
CHECK(imm) << "expect static int shape";
int_shape.push_back(imm->value);
}
Expr shape = MakeConstant(int_shape);
Expr size = ComputeStorage(type);
Expr alignment = ComputeAlignment(type->dtype);
Expr shape = OnDevice(MakeConstant(int_shape), cpu_device_.device_type, /*is_fixed=*/true);
Expr size = OnDevice(ComputeStorage(type), cpu_device_.device_type, /*is_fixed=*/true);
// Alignment is directly captured in the instruction rather than calculated, so we
// don't want to wrap it with an "on_device".
Expr alignment =
ComputeAlignment(type->dtype);
// Run type inference later to get the correct type.
Var var("storage_" + name_hint, Type(nullptr));
Expr value = AllocStorage(size, alignment, dev, type->dtype);
Expr value = OnDevice(AllocStorage(size, alignment, dev, type->dtype), dev.device_type,
/*is_fixed=*/true);
auto sto = scope->Push(var, value);

// TODO(@jroesch): There is a bug with typing based on the constant shape.
auto tensor = AllocTensor(sto, shape, type->dtype, type->shape);
auto tensor = OnDevice(AllocTensor(sto, shape, type->dtype, /*assert_shape=*/type->shape),
dev.device_type, /*is_fixed=*/true);
Var tensor_var("tensor_" + name_hint, Type(nullptr));
return scope->Push(tensor_var, tensor);
}
Expand All @@ -274,15 +278,14 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
const std::vector<Expr>& new_args) {
Array<Expr> shape_func_ins;

TECompiler compiler;
tec::TECompiler compiler;

CCacheKey key(func, target_host_);
tec::CCacheKey key(func, target_host_);
auto cfunc = compiler->LowerShapeFunc(key);
auto input_states = cfunc->shape_func_param_states;

Array<Integer> is_inputs;
int input_pos = 0;
Device cpu_dev = default_device_;
CHECK_EQ(new_args.size(), input_states.size());
for (size_t i = 0; i < new_args.size(); ++i) {
Expr arg = new_args[i];
Expand All @@ -294,27 +297,28 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
}
int state = input_states[i]->value;
// Pass Shapes
if (state == 2) {
if (state == tec::kNeedInputShape) {
std::vector<Expr> exprs = FromTupleType(ty, arg);
for (size_t j = 0; j < exprs.size(); ++j) {
Expr sh_of = Mutate(ShapeOf(exprs[j]));
Expr sh_of = Mutate(ShapeOf(exprs[j])); // already accounts for device
Var in_shape_var("in_shape_" + std::to_string(input_pos + j), Type(nullptr));
shape_func_ins.push_back(scope->Push(in_shape_var, sh_of));
input_pos++;
}
is_inputs.push_back(0);
} else if (state == 1) {
auto new_arg = Mutate(arg);
} else if (state == tec::kNeedInputData) {
auto new_arg = Mutate(arg); // already accounts for device
DLDeviceType device_type = GetInScopeDeviceType(arg);
if (device_type != cpu_dev.device_type) {
new_arg = DeviceCopy(new_arg, device_type, cpu_dev.device_type);
if (device_type != cpu_device_.device_type) {
new_arg = OnDevice(DeviceCopy(new_arg, device_type, cpu_device_.device_type),
cpu_device_.device_type, /*is_fixed=*/true);
}
Var in_shape_var("in_shape_" + std::to_string(input_pos), Type(nullptr));
shape_func_ins.push_back(scope->Push(in_shape_var, new_arg));
input_pos++;
is_inputs.push_back(1);
} else {
// TODO(@jroesch): handle 3rd case
// TODO(@jroesch): handle kNeedBoth
LOG(FATAL) << "unsupported shape function input state";
}
}
Expand All @@ -325,12 +329,14 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
auto tt = TensorType(out->shape, out->dtype);
// Put shape func on CPU. This also ensures that everything between
// shape_of and shape_func are on CPU.
auto alloc = MakeStaticAllocation(scope, tt, cpu_dev, std::to_string(i));
auto alloc = OnDevice(MakeStaticAllocation(scope, tt, cpu_device_, std::to_string(i)),
cpu_device_.device_type, /*is_fixed=*/true);
Var shape_func_out_var("shape_func_out_" + std::to_string(i), Type(nullptr));
alloc = scope->Push(shape_func_out_var, alloc);
out_shapes.push_back(alloc);
}
auto shape_call = ShapeFunc(func, Tuple(shape_func_ins), Tuple(out_shapes), is_inputs);
auto shape_call = OnDevice(ShapeFunc(func, Tuple(shape_func_ins), Tuple(out_shapes), is_inputs),
cpu_device_.device_type, /*is_fixed=*/true);
Var shape_func_var("shape_func", Type(nullptr));
scope->Push(shape_func_var, shape_call);
return out_shapes;
Expand All @@ -346,10 +352,13 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
for (size_t i = 0; i < out_shapes.size(); ++i) {
auto out_shape = out_shapes[i];
auto out_type = out_types[i];
auto size = ComputeStorageInRelay(out_shape, out_type);
auto alignment = ComputeAlignment(out_type->dtype);
auto size = OnDevice(ComputeStorageInRelay(out_shape, out_type),
cpu_device_.device_type, /*is_fixed=*/true);
auto alignment = OnDevice(ComputeAlignment(out_type->dtype),
cpu_device_.device_type, /*is_fixed=*/true);
Var sto_var("storage_" + std::to_string(i), Type(nullptr));
auto val = AllocStorage(size, alignment, dev, out_type->dtype);
auto val = OnDevice(AllocStorage(size, alignment, dev, out_type->dtype),
dev.device_type, /*is_fixed=*/true);
storages.push_back(scope->Push(sto_var, val));
}

Expand All @@ -358,15 +367,15 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
auto out_shape = out_shapes[i];
auto out_type = out_types[i];
auto storage = storages[i];
auto alloc = AllocTensor(storage, out_shape, out_type->dtype, out_type->shape);
auto alloc = OnDevice(AllocTensor(storage, out_shape, out_type->dtype, out_type->shape),
dev.device_type, /*is_fixed=*/true);
Var out_var("out_" + std::to_string(i), Type(nullptr));
outs.push_back(scope->Push(out_var, alloc));
}

Tuple tuple_outs(outs);
// TODO(mbs): Capture device in invoke attributes.
auto invoke = InvokeTVMOp(func, ins, tuple_outs);
scope->Push(OnDevice(invoke, dev.device_type, /*is_fixed=*/true));
auto invoke = OnDevice(InvokeTVMOp(func, ins, tuple_outs), dev.device_type, /*is_fixed=*/true);
scope->Push(invoke);
return ToTupleType(ret_type,
std::vector<Expr>(tuple_outs->fields.begin(), tuple_outs->fields.end()));
}
Expand Down Expand Up @@ -397,7 +406,7 @@ class DialectRewriter : public transform::DeviceAwareExprMutator {
std::vector<LetList> scopes_;

runtime::DataType compute_dtype_ = runtime::DataType::Int(64);
Device default_device_{kDLCPU, 0};
Device cpu_device_{kDLCPU, 0};
};

namespace transform {
Expand Down

0 comments on commit f5e2448

Please sign in to comment.