[Relay] Refactor Interpreter to treat lowering as IRModule->IRModule rewrite. (apache#8597)

* This continues the work outlined in the RFC
  https://discuss.tvm.apache.org/t/rfc-relay-tecompiler-rewrite-existing-compile-engine-to-match-updated-compiler-flow/9233
This gets about halfway there for the Interpreter:

* Remove direct access to TECompiler from interpreter, and instead call
  tec::LowerTEExpr when 'preparing' a module and expression for evaluation.
* Make clear there's no phase distinction between create_interpreter and
  evaluate on the Python side -- both must be prepared together as a single IRModule.
* In return, the result of evaluate on the Python side is a packed func ready to
  directly apply 'simple' arguments to an already interpreted closure (see the
  sketch after this list).
* The interpreter builds and caches primitive TIR functions (and their corresponding
  dynamic shape functions) as packed funcs as they are encountered.
* Clean up uses of the interpreter for constant folding on the C++ side.
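
For example, the intended Python-side pattern (a sketch based on the interpreter
docstring below; 'module' is assumed to bind a 'main' function, and args1/args2
stand for ordinary argument values):

    import tvm
    from tvm import relay

    # Preparation happens once, when evaluate produces the packed func...
    func = relay.create_executor(kind="debug", mod=module).evaluate()
    # ...which can then be applied repeatedly to 'simple' arguments.
    a = func(args1)
    b = func(args2)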

Future work:
* Fold LoweredModule into IRModule so tec::LowerTEExpr is just another pass.
* Get rid of the implicit caching of lowered functions in TECompiler.
* Make calling convention from Relay to TIR explicit, and remove all the function
  attribute hackery currently needed so the interpreter can correctly invoke lowered
  functions as it encounters them.
* Make TECompiler private. Though we could do this now, it would make migrating the VM
  and AOT uses of CompileEngine harder.

Force a gc between sphinx-gallery items to reclaim GPU memory. (apache#8722)

GPU memory is only released once the PackedFunc for evaluating the model is gc'd
by Python. In CI we're noticing intermittent 'CUDA: Out of memory' failures
while processing the tutorials, and tracing showed there was no gc happening
between items. Not confident this will solve the problem, but worth a try.

* Get rid of log spam.
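
In outline, the fix registers a gc hook with sphinx-gallery's reset_modules, as in
the docs/conf.py diff below:

    import gc

    def force_gc(gallery_conf, fname):
        gc.collect()

    sphinx_gallery_conf = {
        # ... other gallery options elided ...
        "reset_modules": ("matplotlib", "seaborn", force_gc),
    }
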
mbs-octoml authored and ylc committed Sep 29, 2021
1 parent c85c9d2 commit fe4b85c
Showing 92 changed files with 2,033 additions and 1,330 deletions.
6 changes: 2 additions & 4 deletions docs/conf.py
@@ -318,10 +318,8 @@ def __call__(self, filename):
# collecting TVM packed function closures for any device memory to also be released. This
# is not a good setup for machines with lots of CPU ram but constrained GPU ram, so force
# a gc after each example.
-def force_gc(gallery_cong, fname):
-    print("(Forcing Python gc after '{}' to avoid lag in reclaiming CUDA memory)".format(fname))
+def force_gc(gallery_conf, fname):
    gc.collect()
-    print("(Remaining garbage: {})".format(gc.garbage))


sphinx_gallery_conf = {
@@ -341,7 +339,7 @@ def force_gc(gallery_cong, fname):
    "download_all_examples": False,
    "min_reported_time": 60,
    "expected_failing_examples": [],
-    "reset_modules": (force_gc, "matplotlib", "seaborn"),
+    "reset_modules": ("matplotlib", "seaborn", force_gc),
}

autodoc_default_options = {
2 changes: 1 addition & 1 deletion docs/langref/relay_pattern.rst
@@ -406,7 +406,7 @@ Either match the first pattern or the second pattern.
Domination
**********

-Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parrent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node betwen the child and the pattern matches the path pattern.
+Match the child pattern, find a match for the parent pattern, ensuring that the child ultimately dominates the parent (i.e., no nodes outside the pattern use outputs of the parent), and that every node between the child and the parent matches the path pattern.

Function Pattern
****************
43 changes: 37 additions & 6 deletions include/tvm/ir/module.h
@@ -36,6 +36,7 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
+#include <utility>
#include <vector>

namespace tvm {
@@ -307,6 +308,14 @@ class IRModuleNode : public Object {
/*! \brief Helper function for registering a typedef's constructors */
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);

+  /*!
+   * \brief Returns a version of \p name which is unique amongst all function definitions in module.
+   *
+   * \param name The original name.
+   * \return Updated name which is unique.
+   */
+  String GetUniqueName(const String& name);

/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
@@ -361,16 +370,38 @@ class IRModule : public ObjectRef {
}

/*!
- * \brief Construct a module from a standalone expression.
+ * \brief Constructs a module from a standalone expression \p expr.
+ *
+ * If \p expr is a function it will be bound directly. Otherwise a function over the free
+ * variables of \p expr (possibly none) with \p expr as body is created and bound.
+ *
+ * The function is bound to, in preference order:
+ *  - The "global_symbol" attribute of \p expr, if it is a function with that attribute.
+ *  - 'main'
+ *  - A unique name derived from 'main' if 'main' is already bound in \p global_funcs.
 *
- * Allows one to optionally pass a global function map and
- * map of type definitions as well.
+ * Additional global functions and type definitions may be included in the result module.
+ *
+ * See also \p FromExpr.
+ *
 * \param expr The expression to set as the main function to the module.
- * \param global_funcs The global function map.
- * \param type_definitions Map of global type definitions
+ * \param global_funcs The global function map. Default empty.
+ * \param type_definitions The global type definition map. Default empty.
+ * \param import_set Set of external modules already imported. Default empty.
+ *
+ * \returns A module with \p expr set as the main function, and the global var to which
+ * \p expr was bound (typically 'main').
 *
- * \returns A module with expr set as the main function.
+ * TODO(mbs): Does import_set and the bound global var need to be exposed via ffi?
 */
+static std::pair<IRModule, GlobalVar> FromExprInContext(
+    const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
+    const Map<GlobalTypeVar, TypeData>& type_definitions = {},
+    std::unordered_set<String> import_set = {});
+
+/*!
+ * \brief As for \p FromExprInContext, but assuming \p expr is bound to 'main' and no
+ * imports.
+ */
TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
const Map<GlobalVar, BaseFunc>& global_funcs = {},
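
For reference, the binding rules documented for FromExprInContext above mirror what the
Python helper IRModule.from_expr already does (a minimal sketch; FromExprInContext itself
is not exposed via ffi, per the TODO above):

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2, 2))
    # A non-function expression is wrapped in a function over its free
    # variables and bound to 'main'.
    mod = tvm.IRModule.from_expr(x + x)
    print(mod["main"])
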
70 changes: 48 additions & 22 deletions include/tvm/relay/interpreter.h
@@ -40,31 +40,11 @@
#include <tvm/runtime/object.h>
#include <tvm/target/target.h>

+#include <unordered_set>
+
namespace tvm {
namespace relay {

-/*!
- *\brief Create a Interpreter function that can
- * evaluate an expression and produce a value.
- *
- * The resulting value can be passed to Python, making it easy to use
- * for testing and debugging.
- *
- * The interpreter interprets the program fragments not supported by the
- * TVM runtime, although the interpreter is naively implemented it uses
- * TVM operators for evaluating all operators.
- *
- * Our intent is that this will never be the most efficient implementation of
- * Relay's semantics, but a readable and clear one.
- *
- * \param mod The function module.
- * \param device The primary device that the interepreter runs on.
- * \param target Compiler target flag to compile the functions on the context.
- * \return A function that takes in an expression and returns a value.
- */
-runtime::TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, Device device,
-                                                            Target target);
-
/*! \brief The container type of Closures used by the interpreter. */
class InterpreterClosureObj : public runtime::ClosureObj {
public:
@@ -164,6 +144,52 @@ class ConstructorValue : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
};

+/*!
+ * \brief Returns a packed function over Relay expressions which will evaluate \p expr
+ * applied to those arguments, interpreted w.r.t. the definitions in \p mod.
+ *
+ * This function is intended to support the Python 'debug' executor.
+ *
+ * The given \p expr should have function type. The given \p mod may be empty or
+ * undefined if \p expr is self-contained. Relay arguments passed to the result
+ * packed function must be constants, references, or constructors/tuples over such.
+ * As much work as possible is done while constructing the result packed function, and
+ * that function may be reasonably efficiently applied multiple times without redoing
+ * unnecessary work.
+ *
+ * Primitives are lowered and compiled to packed functions for execution on \p device
+ * with properties given by \p target. All other Relay constructs are interpreted.
+ *
+ * The interpreter is intended to be a 'reference' implementation of the Relay semantics
+ * for testing and interactive use. It is not intended to be particularly efficient.
+ *
+ * \param mod A module containing definitions which can be referenced from
+ * \p expr. May be empty or undefined.
+ * \param expr An expression of function type to evaluate. May reference definitions from \p mod.
+ * \param device The device on which all primitives will be executed.
+ * \param target The compiler target flag for compiling primitives.
+ * \return A packed function that takes an array of Relay expressions and returns the
+ * result of applying \p expr to those arguments.
+ */
+TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, Device device,
+                                                     Target target);
+
+/*!
+ * \brief Evaluates \p expr and returns its result.
+ *
+ * This function is intended to support TVM constant evaluation.
+ *
+ * \param expr An expression to evaluate.
+ * \param type_definitions Global type definitions which \p expr may reference.
+ * \param import_set Already imported external modules.
+ * \param device The device on which all primitives will be executed.
+ * \param target The compiler target flag for compiling primitives.
+ * \return The object representing the result.
+ */
+ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
+               std::unordered_set<String> import_set, Device device, Target target);

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_INTERPRETER_H_
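
From Python, both entry points surface through the 'debug' executor: evaluate on a
function-typed expression yields the reusable packed func (the EvalFunction path), while
any other expression is wrapped in a zero-argument function and evaluated immediately,
much like Eval. A minimal sketch, assuming an ordinary CPU target:

    import numpy as np
    import tvm
    from tvm import relay

    x = relay.var("x", shape=(2,), dtype="float32")
    f = relay.Function([x], x + x)
    executor = relay.create_executor("debug", device=tvm.cpu(), target="llvm")

    # Function-typed: returns a packed func that can be re-applied cheaply.
    double = executor.evaluate(f)
    print(double(np.array([1.0, 2.0], dtype="float32")))

    # Non-function: evaluated at once.
    print(executor.evaluate(relay.const(2.0) * relay.const(3.0)))
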
3 changes: 1 addition & 2 deletions python/tvm/relay/analysis/analysis.py
@@ -433,8 +433,7 @@ def get_calibration_data(mod, data):
    mod = _ffi_api.get_calibrate_module(mod)
    mod = transform.Inline()(mod)

-    ref_ex = build_module.create_executor("graph", mod=mod, device=cpu(0))
-    ref_res = ref_ex.evaluate()(**data)
+    ref_res = build_module.create_executor("graph", mod=mod, device=cpu(0)).evaluate()(**data)

    calib_data = {}
    for gvar, indices in output_map.items():
81 changes: 35 additions & 46 deletions python/tvm/relay/backend/interpreter.py
@@ -22,10 +22,9 @@

import tvm._ffi
from tvm.runtime import container, Object
-from tvm.ir import IRModule

from . import _backend
-from .. import _make, analysis, transform
+from .. import _make, analysis
from ... import nd
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const
from ..function import Function
@@ -178,6 +177,7 @@ def evaluate(self, expr=None, binds=None):
            return self._make_executor(expr)

        # normal expression evaluated by running a function.
+        # TODO(mbs): This should really be type rather than syntax driven.
        func = Function([], expr)
        return self._make_executor(func)()

@@ -196,65 +196,54 @@ class Interpreter(Executor):
    target : tvm.Target
        The target option to build the function.
+
+    CAUTION: Despite the API the module is prepared upon each call to evaluate
+    rather than once in create_executor.
+    That is:
+
+    .. code-block:: python
+
+        executor = relay.create_executor(kind="debug", mod=module)
+        a = executor.evaluate(expr)(args1)
+        b = executor.evaluate(expr)(args2)
+
+    will prepare all the bindings in module twice. For efficiency, try to hoist
+    calls to evaluate as high as possible, preferably immediately after create_executor:
+
+    .. code-block:: python
+
+        func = relay.create_executor(kind="debug", mod=module).evaluate(expr)
+        a = func(args1)
+        b = func(args2)
    """

    def __init__(self, mod, device, target):
        self.mod = mod
        self.device = device
        self.target = target

-    def optimize(self):
-        """Optimize functions in a module.
-        Returns
-        -------
-        opt_mod : tvm.IRModule
-            The optimized module.
-        """
-        seq = tvm.transform.Sequential(
-            [
-                # tvm.parser.AnnotateSpans(),
-                transform.SimplifyInference(),
-                transform.FuseOps(0),
-                transform.ToANormalForm(),
-                transform.InferType(),
-            ]
-        )
-        mod = seq(self.mod)
-        return mod

    def _make_executor(self, expr=None):
        if expr is None or isinstance(expr, GlobalVar):
            assert self.mod is not None

-        _intrp = _backend.CreateInterpreter(self.optimize(), self.device, self.target)
+        if expr is None:
+            # A missing expr denotes 'main' in the given module.
+            expr = self.mod.get_global_var("main")

-        def _interp_wrapper(*args, **kwargs):
-            if expr is None:
-                args = self._convert_args(self.mod["main"], args, kwargs)
+        # Evaluate expr to a packed function we can efficiently re-apply
+        # to Relay arguments.
+        func = _backend.EvalFunction(self.mod, expr, self.device, self.target)

+        def _apply_args(*args, **kwargs):
+            if isinstance(expr, GlobalVar):
+                # When expanding args, look inside the actual global definition so kwargs
+                # can be matched.
+                args = self._convert_args(self.mod[expr.name_hint], args, kwargs)
            else:
                args = self._convert_args(expr, args, kwargs)

            # Reflect python arguments up into Relay.
            relay_args = []
            for arg in args:
                relay_args.append(_arg_to_ast(self.mod, arg))
+            # Apply func to Relay args
+            return func(relay_args)

-            # Set the entry function for the module.
-            if expr is None:
-                pass
-            elif isinstance(expr, GlobalVar):
-                self.mod["main"] = self.mod[expr]
-            else:
-                assert isinstance(expr, Function)
-                func = Function([], Call(expr, relay_args))
-                relay_args = []
-                if self.mod:
-                    self.mod["main"] = func
-                else:
-                    self.mod = IRModule.from_expr(func)
-
-            mod = self.optimize()
-            opt_expr = Call(mod["main"], relay_args)
-            return _intrp(opt_expr)

-        return _interp_wrapper
+        return _apply_args
3 changes: 3 additions & 0 deletions python/tvm/relay/build_module.py
@@ -511,6 +511,9 @@ def _graph_wrapper(*args, **kwargs):
        return _graph_wrapper


+# TODO(mbs): Collapse the create_executor/evaluate phases together since a) most callers don't
+# reuse the executor for multiple expressions and b) any preparation necessary for the expression
+# evaluation needs to (currently) be done along with preparation for the module.
def create_executor(kind="debug", mod=None, device=None, target="llvm", params=None):
    """Factory function to create an executor.
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/common.py
@@ -545,11 +545,12 @@ def infer_value(input_val, params, mod=None):
        mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
    else:
        mod = IRModule.from_expr(input_val)
-    exc = tvm.relay.create_executor("debug", mod=mod, device=tvm.cpu(), target="llvm")
    inputs = []
    for param in mod["main"].params:
        inputs.append(params[param.name_hint])
-    result = exc.evaluate()(*inputs)
+    result = tvm.relay.create_executor(
+        "debug", mod=mod, device=tvm.cpu(), target="llvm"
+    ).evaluate()(*inputs)
    return result


8 changes: 5 additions & 3 deletions python/tvm/relay/testing/__init__.py
@@ -134,10 +134,13 @@ def check_grad(
    test_inputs = inputs

    for target, dev in enabled_targets():
-        intrp = relay.create_executor(device=dev, target=target)
+        # Eval the backward and forward functions
+        # TODO(mbs): Evaluate a pair of functions so can share preparation between them.
+        bwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(bwd_func)
+        fwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(fwd_func)

        # Get analytic gradients.
-        _, grads = intrp.evaluate(bwd_func)(*inputs)
+        _, grads = bwd_func_compiled(*inputs)
        grads = [grad.numpy().astype("float64") for grad in grads]

# Throw out gradients we aren't testing
@@ -154,7 +157,6 @@ def check_grad(
    assert len(grads) > 0, "You must test at least one gradient."

    # Get numeric gradients for each dimension of each param, using two-sided approximation.
-    fwd_func_compiled = intrp.evaluate(fwd_func)
    approx_grads = []
    for x in test_inputs:
        approx_grad = np.zeros(x.shape)
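
For context, the 'two-sided approximation' used by check_grad above is the standard
central-difference estimate. A standalone sketch (not TVM's exact helper):

    import numpy as np

    def numeric_grad(f, x, eps=1e-6):
        # df/dx_i ~= (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps)
        grad = np.zeros_like(x)
        it = np.nditer(x, flags=["multi_index"])
        while not it.finished:
            idx = it.multi_index
            x_plus, x_minus = x.copy(), x.copy()
            x_plus[idx] += eps
            x_minus[idx] -= eps
            grad[idx] = (f(x_plus) - f(x_minus)) / (2 * eps)
            it.iternext()
        return grad

    # e.g. the gradient of sum(v**2) is 2*v:
    print(numeric_grad(lambda v: float(np.sum(v ** 2)), np.ones((2, 2))))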