Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Refactor Interpreter to treat lowering as IRModule->IRModule rewrite. #8597

Merged
merged 2 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,10 +318,8 @@ def __call__(self, filename):
# collecting TVM packed function closures for any device memory to also be released. This
# is not a good setup for machines with lots of CPU ram but constrained GPU ram, so force
# a gc after each example.
def force_gc(gallery_cong, fname):
print("(Forcing Python gc after '{}' to avoid lag in reclaiming CUDA memory)".format(fname))
def force_gc(gallery_conf, fname):
gc.collect()
print("(Remaining garbage: {})".format(gc.garbage))


sphinx_gallery_conf = {
Expand All @@ -341,7 +339,7 @@ def force_gc(gallery_cong, fname):
"download_all_examples": False,
"min_reported_time": 60,
"expected_failing_examples": [],
"reset_modules": (force_gc, "matplotlib", "seaborn"),
"reset_modules": ("matplotlib", "seaborn", force_gc),
}

autodoc_default_options = {
Expand Down
2 changes: 1 addition & 1 deletion docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ Either match the first pattern or the second pattern.
Domination
**********

Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parrent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node betwen the child and the pattern matches the path pattern.
Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node between the child and the pattern matches the path pattern.

Function Pattern
****************
Expand Down
43 changes: 37 additions & 6 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace tvm {
Expand Down Expand Up @@ -307,6 +308,14 @@ class IRModuleNode : public Object {
/*! \brief Helper function for registering a typedef's constructors */
void RegisterConstructors(const GlobalTypeVar& var, const TypeData& type);

/*!
* \brief Returns a version of \p name which is unique amongst all function definitions in module.
*
* \param name The original name.
* \return Updated name which is unique.
*/
String GetUniqueName(const String& name);

/*! \brief A map from string names to global variables that
* ensures global uniqueness.
*/
Expand Down Expand Up @@ -361,16 +370,38 @@ class IRModule : public ObjectRef {
}

/*!
* \brief Construct a module from a standalone expression.
* \brief Constructs a module from a standalone expression \p expr.
*
* If \p expr is a function it will be bound directly. Otherwise a function over the free
* variables of \p expr (possibly none) with \p expr as body is created and bound.
*
* The function is bound to, in preference order:
* - The "global_symbol" attribute of \p expr, if it is a function with that attribute.
* - 'main'
* - A unique name derived from 'main' if 'main' is already bound in \p global_funcs.
*
* Allows one to optionally pass a global function map and
* map of type definitions as well.
* Additional global functions and type definitions may be included in the result module.
*
* See also \p FromExpr.
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map.
* \param type_definitions Map of global type definitions
* \param global_funcs The global function map. Default empty.
* \param type_definitions The global type definition map. Default empty.
* \param import_set Set of external modules already imported. Default empty.
*
* \returns A module with \p expr set as the main function, and the global var to which
* \p expr was bound (typcially 'main').
*
* \returns A module with expr set as the main function.
* TODO(mbs): Does import_set and the bound global var need to be exposed via ffi?
*/
static std::pair<IRModule, GlobalVar> FromExprInContext(
const RelayExpr& expr, const Map<GlobalVar, BaseFunc>& global_funcs = {},
const Map<GlobalTypeVar, TypeData>& type_definitions = {},
std::unordered_set<String> import_set = {});
mbs-octoml marked this conversation as resolved.
Show resolved Hide resolved

/*!
* \brief As for \p FromExprInContext, but assuming \p expr is bound to 'main' and no
* imports.
*/
TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
const Map<GlobalVar, BaseFunc>& global_funcs = {},
Expand Down
70 changes: 48 additions & 22 deletions include/tvm/relay/interpreter.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,31 +40,11 @@
#include <tvm/runtime/object.h>
#include <tvm/target/target.h>

#include <unordered_set>

namespace tvm {
namespace relay {

/*!
*\brief Create a Interpreter function that can
* evaluate an expression and produce a value.
*
* The resulting value can be passed to Python, making it easy to use
* for testing and debugging.
*
* The interpreter interprets the program fragments not supported by the
* TVM runtime, although the interpreter is naively implemented it uses
* TVM operators for evaluating all operators.
*
* Our intent is that this will never be the most efficient implementation of
* Relay's semantics, but a readable and clear one.
*
* \param mod The function module.
* \param device The primary device that the interepreter runs on.
* \param target Compiler target flag to compile the functions on the context.
* \return A function that takes in an expression and returns a value.
*/
runtime::TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, Device device,
Target target);

/*! \brief The container type of Closures used by the interpreter. */
class InterpreterClosureObj : public runtime::ClosureObj {
public:
Expand Down Expand Up @@ -164,6 +144,52 @@ class ConstructorValue : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
};

/*!
* \brief Returns a packed function over Relay expressions which will evaluate \p expr
* applied to those arguments, where \p expr is w.r.t. the definitions in \p mod.
*
* This function is intended to support the Python 'debug' executor.
*
* The given \p expr should have function type. The given \p mod may be empty or
* undefined if \p expr is self-contained. Relay arguments passed to the result
* packed function must be constants, references, or constructors/tuples over such.
* As much work as possible is done while constructing the result packed function, and
* that function may be reasonably efficiently applied multiple times without redoing
* unnecessary work.
*
* Primitives are lowered and compiled to packed functions for execution on \p device
* with properties given by \p target. All other Relay constructs are interpreted.
*
* The interpreter is intended to be a 'reference' implementation of the Relay semantics
* for testing and interactive use. It is not intended to be particularly efficient.
*
* \param mod A module containing definitions which can be referenced from
* \p expr. May be empty or undefined.
* \param expr An expression of function type to evaluate. May reference definitions from \p mod.
* \param device The device on which all primitives will be executed.
* \param target The compiler target flag for compiling primitives.
* \return A packed function that takes an array of Relay expressions and returns the
* result of applying \p expr to those arguments.
*/
TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, Device device,
Target target);

/*!
* \brief Evaluates \p expr and returns its result.
*
* This function is intended to support TVM constant evaluation.
*
* \param expr An expression to evaluate.
* \param type_definitions Global type definitions which \p expr may references.
* \param import_set Already imported external modules.
* \param device The device on which all primitives will be executed.
* \param target The compiler target flag for compiling primitives.
* @return The object representing the result.
*/
ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target target);

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_INTERPRETER_H_
3 changes: 1 addition & 2 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,8 +433,7 @@ def get_calibration_data(mod, data):
mod = _ffi_api.get_calibrate_module(mod)
mod = transform.Inline()(mod)

ref_ex = build_module.create_executor("graph", mod=mod, device=cpu(0))
ref_res = ref_ex.evaluate()(**data)
ref_res = build_module.create_executor("graph", mod=mod, device=cpu(0)).evaluate()(**data)
mbs-octoml marked this conversation as resolved.
Show resolved Hide resolved

calib_data = {}
for gvar, indices in output_map.items():
Expand Down
81 changes: 35 additions & 46 deletions python/tvm/relay/backend/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@

import tvm._ffi
from tvm.runtime import container, Object
from tvm.ir import IRModule

from . import _backend
from .. import _make, analysis, transform
from .. import _make, analysis
from ... import nd
from ..expr import Tuple, RefCreate, Call, Constant, GlobalVar, const
from ..function import Function
Expand Down Expand Up @@ -178,6 +177,7 @@ def evaluate(self, expr=None, binds=None):
return self._make_executor(expr)

# normal expression evaluated by running a function.
# TODO(mbs): This should really be type rather than syntax driven.
func = Function([], expr)
return self._make_executor(func)()

Expand All @@ -196,65 +196,54 @@ class Interpreter(Executor):

target : tvm.Target
The target option to build the function.

CAUTION: Despite the API the module is prepared upon each call to evaluate
rather than once in create_executor.
That is:
.. code-block:: python

executor = relay.create_executor(kind="debug", mod=module)
a = executor.evaluate(expr)(args1)
b = executor.evaluate(expr)(args2)

will prepare all the bindings in module twice. For efficiency, try to hoist
calls to evaluate as high as possible, preferably immediately after create_executor:
.. code-block:: python

func = relay.create_executor(kind="debug", mod=module).evaluate(expr)
a = func(args1)
b = func(args2)
"""

def __init__(self, mod, device, target):
self.mod = mod
self.device = device
self.target = target

def optimize(self):
"""Optimize functions in a module.

Returns
-------
opt_mod : tvm.IRModule
The optimized module.
"""
seq = tvm.transform.Sequential(
[
# tvm.parser.AnnotateSpans(),
transform.SimplifyInference(),
transform.FuseOps(0),
transform.ToANormalForm(),
transform.InferType(),
]
)
mod = seq(self.mod)
return mod

def _make_executor(self, expr=None):
if expr is None or isinstance(expr, GlobalVar):
assert self.mod is not None

_intrp = _backend.CreateInterpreter(self.optimize(), self.device, self.target)
if expr is None:
# A missing expr denotes 'main' in the given module.
expr = self.mod.get_global_var("main")

def _interp_wrapper(*args, **kwargs):
if expr is None:
args = self._convert_args(self.mod["main"], args, kwargs)
# Evaluate expr to a packed function we can efficiently re-apply
# to Relay arguments.
func = _backend.EvalFunction(self.mod, expr, self.device, self.target)

def _apply_args(*args, **kwargs):
if isinstance(expr, GlobalVar):
# When expanding args, look inside the actual global definition so kwargs
# can be matched.
args = self._convert_args(self.mod[expr.name_hint], args, kwargs)
else:
args = self._convert_args(expr, args, kwargs)

# Reflect python arguments up into Relay.
relay_args = []
for arg in args:
relay_args.append(_arg_to_ast(self.mod, arg))
# Apply func to Relay args
return func(relay_args)

# Set the entry function for the module.
if expr is None:
pass
elif isinstance(expr, GlobalVar):
self.mod["main"] = self.mod[expr]
else:
assert isinstance(expr, Function)
func = Function([], Call(expr, relay_args))
relay_args = []
if self.mod:
self.mod["main"] = func
else:
self.mod = IRModule.from_expr(func)

mod = self.optimize()
opt_expr = Call(mod["main"], relay_args)
return _intrp(opt_expr)

return _interp_wrapper
return _apply_args
3 changes: 3 additions & 0 deletions python/tvm/relay/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,9 @@ def _graph_wrapper(*args, **kwargs):
return _graph_wrapper


# TODO(mbs): Collapse the create_executor/evaluate phases together since a) most callers don't
# reuse the executor for multiple expressions and b) any preparation necessary for the expression
# evaluation needs to (currently) be done along with preparation for the module.
def create_executor(kind="debug", mod=None, device=None, target="llvm", params=None):
"""Factory function to create an executor.

Expand Down
5 changes: 3 additions & 2 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,11 +545,12 @@ def infer_value(input_val, params, mod=None):
mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
else:
mod = IRModule.from_expr(input_val)
exc = tvm.relay.create_executor("debug", mod=mod, device=tvm.cpu(), target="llvm")
inputs = []
for param in mod["main"].params:
inputs.append(params[param.name_hint])
result = exc.evaluate()(*inputs)
result = tvm.relay.create_executor(
"debug", mod=mod, device=tvm.cpu(), target="llvm"
).evaluate()(*inputs)
return result


Expand Down
8 changes: 5 additions & 3 deletions python/tvm/relay/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,13 @@ def check_grad(
test_inputs = inputs

for target, dev in enabled_targets():
intrp = relay.create_executor(device=dev, target=target)
# Eval the backward and forward functions
# TODO(mbs): Evaluate a pair of functions so can share preparation between them.
bwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(bwd_func)
fwd_func_compiled = relay.create_executor(device=dev, target=target).evaluate(fwd_func)

# Get analytic gradients.
_, grads = intrp.evaluate(bwd_func)(*inputs)
_, grads = bwd_func_compiled(*inputs)
grads = [grad.numpy().astype("float64") for grad in grads]

# Throw out gradients we aren't testing
Expand All @@ -154,7 +157,6 @@ def check_grad(
assert len(grads) > 0, "You must test at least one gradient."

# Get numeric gradients for each dimension of each param, using two-sided approximation.
fwd_func_compiled = intrp.evaluate(fwd_func)
approx_grads = []
for x in test_inputs:
approx_grad = np.zeros(x.shape)
Expand Down
Loading