[Relay] Prepare DeadCodeElimination for running post LowerTEPass/ManifestAlloc. (apache#9542)

* Prepare DeadCodeElimination for running post LowerTEPass/ManifestAlloc.

As part of apache#9483 we need to prepare some critical Relay passes for running after
lowering and conversion to DPS. For DCE we need to make sure we never remove
side-effecting let-bound expressions, such as allocations or calls to external
functions whose effects are unknown.
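
To illustrate the hazard, here is a minimal sketch using the Python API (my own
example, assuming this revision of TVM; it is not a test from this commit). The
binding of 'dead' is unreferenced, but its value is a reference write, so the
new DCE must keep it by default:

import tvm
from tvm import relay

x = relay.var("x", shape=(), dtype="float32")
ref = relay.var("ref")
dead = relay.var("dead")
# let ref = RefCreate(x); let dead = RefWrite(ref, x); x
body = relay.Let(ref, relay.RefCreate(x),
                 relay.Let(dead, relay.RefWrite(ref, x), x))
mod = tvm.IRModule.from_expr(relay.Function([x], body))
mod = relay.transform.InferType()(mod)
# 'dead' has no users, but RefWrite is side-effecting: the binding survives.
print(relay.transform.DeadCodeElimination(inline_once=True)(mod))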

Introduce a new purity pre-pass. It makes a half-hearted attempt at accounting
for functions by tracking both 'eval' and 'call' purity, but must fall back to
assuming call-impurity in more difficult cases (e.g. calling a function passed as
a parameter, calling a function projected from a tuple, etc). However, it seems
plenty good enough.
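
For instance (again my own sketch of the behaviour described above), when the
callee arrives as a parameter the analysis cannot see its body, so it falls
back to call-impurity and retains the otherwise-dead binding:

import tvm
from tvm import relay

g = relay.var("g", relay.FuncType([], relay.TensorType((), "float32")))
y = relay.var("y")
# let y = g(); 1.0 -- y is dead, but g's purity is unknown
body = relay.Let(y, relay.Call(g, []), relay.const(1.0, "float32"))
mod = tvm.IRModule.from_expr(relay.Function([g], body))
mod = relay.transform.InferType()(mod)
print(relay.transform.DeadCodeElimination(inline_once=True)(mod))  # binding kept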

Purity must also be accounted for when determining the usage count of let-bound
variables, so that logic was reworked. The let-bound value accumulation pass was
collapsed into the usage-counting pass to offset the cost of the new purity
analysis pass.

A few tests assume DCE eliminates dead reference writes. The previous
implementation certainly did that, but only by eliminating *all* writes, dead or live.

Filed CORE-118 to extend DCE to soundly eliminate dead writes (a very simple-minded
analysis would probably do just fine and we don't need to get hung up on alias
analysis). In the meantime, added an 'ignore_impurity' flag (default False)
and set it to True in just the few unit tests which rely on the unsound implementation.
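
Usage of the escape hatch looks like this (hypothetical snippet; 'mod' is any
Relay module):

# Old, unsound behaviour: treat every binding as pure, so dead reference
# writes (and other effectful bindings) are eliminated. Tests only.
mod = relay.transform.DeadCodeElimination(inline_once=True, ignore_impurity=True)(mod)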

* [checkpoint] Merge Lily's suggestions.
mbs-octoml authored and mehrdadh committed Dec 1, 2021
1 parent e11b5c5 commit d17f040
Showing 11 changed files with 773 additions and 159 deletions.
24 changes: 17 additions & 7 deletions include/tvm/relay/transform.h
@@ -62,21 +62,31 @@ TVM_DLL Pass CreateFunctionPass(
const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>& pass_func,
int opt_level, String name, tvm::Array<String> required);

/*! \brief Remove expressions which does not effect the program result.
/*! \brief Remove let-bound expressions which do not affect the program result.
*
* It will remove let bindings which are not referenced,
* and inline let bindings that are only used once.
* This pass will remove let bindings which are not referenced. If inline_once is True,
* let bindings which are only referenced once will also be inlined.
*
* For example, this pass should turn `let a = 1 in 2` into `2`,
* For example, this pass should turn `let a = 1; 2` into `2`,
* as the value of the expression does not depend on a.
*
* As another example, `let a = 1 in a` will be optimized into 1.
* As another example, `let a = 1; a` will be optimized into 1 if inline_once is True.
*
* \param inline_once whether or not to inline binding used one.
* If ignore_impurity is False, possibly side-effecting expressions (such as memory allocation,
* random number generation, reading/writing references, or calls to primitive or external
* functions) are never elided or inlined. This is sound, but ignore_impurity can be set to
* True to suppress this check.
*
* The analysis is fairly conservative: for example, it assumes all local functions
* may be called more than once, and that any function passed as an argument may
* have side effects, and so on.
*
* \param inline_once whether or not to inline bindings used exactly once.
* \param ignore_impurity whether to ignore possible side-effects in let-bound expressions.
*
* \return the pass.
*/
TVM_DLL Pass DeadCodeElimination(bool inline_once = false);
TVM_DLL Pass DeadCodeElimination(bool inline_once = false, bool ignore_impurity = false);

/*!
* \brief Convert all expressions of TensorType into GradCell,
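A quick sketch of the two documented behaviours via the Python bindings (my own
example, not part of this diff):

import tvm
from tvm import relay

a = relay.var("a")
for body in (relay.Let(a, relay.const(1), relay.const(2)),  # let a = 1; 2 -> 2
             relay.Let(a, relay.const(1), a)):              # let a = 1; a -> 1
    mod = tvm.IRModule.from_expr(relay.Function([], body))
    mod = relay.transform.InferType()(mod)
    print(relay.transform.DeadCodeElimination(inline_once=True)(mod))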
8 changes: 5 additions & 3 deletions python/tvm/relay/transform/transform.py
@@ -209,20 +209,22 @@ def CanonicalizeOps():
return _ffi_api.CanonicalizeOps()


def DeadCodeElimination(inline_once=False):
def DeadCodeElimination(inline_once=False, ignore_impurity=False):
"""Remove expressions that do not have any users (dead code).
Parameters
----------
inline_once: Optional[Bool]
Whether to inline binding that occurs only once.
Whether to inline a binding that is referenced exactly once.
ignore_impurity: Optional[Bool]
Whether to ignore possible side-effects in let-bound expressions.
Returns
-------
ret: tvm.transform.Pass
The registered pass that eliminates the dead code in a Relay program.
"""
return _ffi_api.DeadCodeElimination(inline_once)
return _ffi_api.DeadCodeElimination(inline_once, ignore_impurity)


def LazyGradientInit():
12 changes: 8 additions & 4 deletions src/relay/op/memory/memory.cc
@@ -90,10 +90,15 @@ RELAY_REGISTER_OP("memory.alloc_storage")
.set_attrs_type_key("relay.attrs.AllocStorageAttrs")
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpIsStateful>("TOpIsStateful", true)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

const Op& MemoryAllocTensorOp() {
static const Op& op = Op::Get("memory.alloc_tensor");
return op;
}

Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype,
Array<IndexExpr> assert_shape) {
auto attrs = make_object<AllocTensorAttrs>();
@@ -106,8 +111,7 @@ Expr AllocTensor(Expr storage, Expr offset, Expr shape, DataType dtype,
ICHECK(constant_node);
attrs->const_shape = GetRef<Constant>(constant_node);
}
static const Op& op = Op::Get("memory.alloc_tensor");
return Call(op, {storage, offset, shape}, Attrs(attrs), {});
return Call(MemoryAllocTensorOp(), {storage, offset, shape}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relay.op.memory._make.alloc_tensor").set_body_typed(AllocTensor);
@@ -196,7 +200,7 @@ RELAY_REGISTER_OP("memory.alloc_tensor")
.set_attrs_type_key("relay.attrs.AllocTensorAttrs")
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpIsStateful>("TOpIsStateful", true)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

2 changes: 2 additions & 0 deletions src/relay/op/memory/memory.h
@@ -35,6 +35,8 @@ namespace tvm {
namespace relay {

Expr AllocStorage(Expr size, Expr alignment, SEScope se_scope, DataType dtype_hint);
/*! \brief Returns the "memory.alloc_tensor" operator. */
const Op& MemoryAllocTensorOp();
Expr AllocTensor(Expr storage, Expr offset, tvm::relay::Expr shape, DataType dtype,
Array<IndexExpr> assert_shape);
Expr ToTupleType(const Type& ty, const std::vector<Expr>& exprs);
19 changes: 13 additions & 6 deletions src/relay/op/vm/vm.cc
@@ -61,7 +61,7 @@ Expr ShapeOf(Expr expr) {
auto attrs = make_object<ShapeOfAttrs>();
attrs->dtype = DataType::Int(64);
static const Op& op = Op::Get("vm.shape_of");
return Call(op, {expr}, Attrs(attrs), {});
return Call(op, {std::move(expr)}, Attrs(std::move(attrs)), {});
}

TVM_REGISTER_GLOBAL("relay.op.vm.shape_of").set_body_typed(ShapeOf);
@@ -156,7 +156,9 @@ bool InvokeTVMOpRel(const Array<Type>& types, int num_inputs, const Attrs& attrs
if (func_type->ret_type.as<TensorTypeNode>()) {
ex_output = TupleType({func_type->ret_type});
} else {
ICHECK(func_type->ret_type.as<TupleTypeNode>()) << "should be tuple type";
ICHECK(func_type->ret_type.as<TupleTypeNode>())
<< "expecting function result to be tuple type. Types:" << std::endl
<< PrettyPrint(types);
ex_output = func_type->ret_type;
}
auto ex_input = TupleType(func_type->arg_types);
@@ -167,10 +169,14 @@
}

Expr InvokeTVMOp(Expr func, Expr inputs, Expr outputs) {
return Call(Op::Get("vm.invoke_tvm_op"), {func, inputs, outputs}, Attrs());
static const Op& op = Op::Get("vm.invoke_tvm_op");
return Call(op, {std::move(func), std::move(inputs), std::move(outputs)}, {});
}

TVM_REGISTER_GLOBAL("relay.op.vm.invoke_tvm_op").set_body_typed(InvokeTVMOp);
TVM_REGISTER_GLOBAL("relay.op.vm.invoke_tvm_op")
.set_body_typed([](Expr func, Expr inputs, Expr outputs) {
return InvokeTVMOp(std::move(func), std::move(inputs), std::move(outputs));
});

RELAY_REGISTER_OP("vm.invoke_tvm_op")
.describe(R"code(Invoke an operation compiled by TVM.)code" TVM_ADD_FILELINE)
@@ -179,9 +185,10 @@ RELAY_REGISTER_OP("vm.invoke_tvm_op")
.add_argument("ins", "Tuple", "The input tensors.")
.add_argument("outs", "Tuple", "The output tensors.")
.add_type_rel("InvokeTVMOp", InvokeTVMOpRel)
.set_attrs_type_key("DictAttrs")
.set_support_level(10)
.set_attr<TOpPattern>("TOpPattern", kOpaque)
.set_attr<TOpIsStateful>("TOpIsStateful", false)
.set_attr<TOpIsStateful>("TOpIsStateful", true)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FInferCorrectLayout>("FInferCorrectLayout", ElemwiseArbitraryLayout);

@@ -217,7 +224,7 @@ Expr ReshapeTensor(Expr data, Expr shape, Array<PrimExpr> newshape) {
static const Op& op = Op::Get("vm.reshape_tensor");
auto attrs = make_object<ReshapeTensorAttrs>();
attrs->newshape = std::move(newshape);
return Call(op, {data, shape}, Attrs(attrs), {});
return Call(op, {std::move(data), std::move(shape)}, Attrs(std::move(attrs)), {});
}

TVM_REGISTER_GLOBAL("relay.op.vm.reshape_tensor").set_body_typed(ReshapeTensor);
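These registry flips (TOpIsStateful now true for the allocation and invoke ops)
are what the purity analysis consults. A small sketch (assuming this revision
of TVM) that reads the attribute back:

import tvm
import tvm.relay  # ensure the Relay op registry is loaded

for name in ("memory.alloc_storage", "memory.alloc_tensor", "vm.invoke_tvm_op"):
    op = tvm.ir.Op.get(name)
    # TOpIsStateful is now true, so calls to these ops are treated as impure
    # and their let-bindings are never eliminated by default.
    print(name, op.get_attr("TOpIsStateful"))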