
[checkpoint] Bug fixes. AOT is broken.
[checkpoint] Separate eval-to-closure and apply-closure phases at last
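To illustrate the split this checkpoint describes, here is a toy sketch (plain Python, not TVM code): evaluating a function expression produces a closure that captures its environment, and applying that closure to arguments is a separate, later phase.

```python
class Closure:
    """A function value: parameters, body, and the captured environment."""
    def __init__(self, params, body, env):
        self.params, self.body, self.env = params, body, env

def eval_expr(expr, env):
    """Phase 1: evaluate an expression to a value; 'fn' yields a Closure."""
    kind = expr[0]
    if kind == "const":
        return expr[1]
    if kind == "var":
        return env[expr[1]]
    if kind == "fn":                      # ("fn", [params], body)
        return Closure(expr[1], expr[2], dict(env))
    if kind == "add":
        return eval_expr(expr[1], env) + eval_expr(expr[2], env)
    if kind == "call":                    # ("call", fn_expr, [arg_exprs])
        clo = eval_expr(expr[1], env)
        return apply_closure(clo, [eval_expr(a, env) for a in expr[2]])
    raise ValueError("unknown node: " + kind)

def apply_closure(clo, args):
    """Phase 2: apply an already-evaluated closure to argument values."""
    call_env = dict(clo.env)
    call_env.update(zip(clo.params, args))
    return eval_expr(clo.body, call_env)

# Evaluate once to a closure, then apply it as many times as needed.
double = eval_expr(("fn", ["x"], ("add", ("var", "x"), ("var", "x"))), {})
print(apply_closure(double, [21]))  # → 42
```

Keeping the two phases distinct is what lets the expensive evaluation work happen once while applications stay cheap.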

[checkpoint] Fix GetType recursion, get debug going.

[checkpoint] Audit python to collapse create_executor and evaluate phases

There are just a few places where this doesn't work; they seem harmless.

[checkpoint] Get interpreter working using tec::LowerTE, but no dynamic shapes.

 - Hide TECompiler impl inside te_compiler.cc. However, I think it is already exposed
   into Python land, so this probably won't be possible now.
 - Move 'optimize' pre-transforms from interpreter.py to interpreter.cc so they can be
   applied uniformly to both mod and expr.
 - Don't push the expr into the mod in interpreter.py since it's done again in
   interpreter.cc. Instead just build the Call node with the reflected args.
 - Both the mod and the expr are prepared identically (same transforms, of which
   LowerTensorExpr should be one).
 - LowerTensorExpr can look through let-bound and global vars, eg
     let f = fn (..., Primitive=1) { ... } ... f(...)
     ==> @lowered_f = ... @lowered_f(...)
 - Lots of DLOGs that need to be removed or reorganized.
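The "look through" step in the `let f = ... ==> @lowered_f(...)` example can be sketched with made-up environment shapes (not TVM's real IR): the callee of a call is chased through let-bound and global vars until the underlying function is found.

```python
def resolve_callee(callee, let_env, global_env):
    """Chase var-to-var bindings until a non-var (the actual function) remains."""
    seen = set()
    while callee in let_env or callee in global_env:
        if callee in seen:
            raise RuntimeError("cyclic binding: %r" % (callee,))
        seen.add(callee)
        callee = let_env[callee] if callee in let_env else global_env[callee]
    return callee

# 'f' is let-bound to the global '@lowered_f', which holds the lowered function.
lets = {"f": "@lowered_f"}
globals_ = {"@lowered_f": ("prim_fn", "lowered body")}
print(resolve_callee("f", lets, globals_))  # → ('prim_fn', 'lowered body')
```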

[checkpoint] Support shape functions.

TODO:
 - Unit tests.
 - Cleanup logging (VLOG?)
 - Don't build all prims on each apply.

[checkpoint] typo

[checkpoint] Don't allow evaling expr independently of preparing module.

TODO:
 - Make eval(mod, expr) the interface.
 - GlobalVars don't line up.
 - Rework use of interpreter in fold_constant.cc to make clear
   it is evaling prim calls which have already been prepared.
 - Find a dynamic shape example that works at HEAD.
 - Unit tests.

[checkpoint] Interpreting expression with refs to module defs working

Commit to the interpreter evaling expr w.r.t. mod in a single phase. Thankfully it
turns out no existing uses broke that assumption, so we dodged a bullet.

The binding of expr into the mod to be evaled is a mess and needs to be fixed.

Still can't confirm dynamic shapes are working since we don't have an
example working at HEAD.

The change to partial_eval needs to be tested, but it smells ok.

[checkpoint] Dynamic shapes working

The use of TIRCallAttrs is pretty hacky but the shape function
calls are at least working.

Next is to tackle the 'build everything just to project one prim fun' problem.

[checkpoint] Cache built prims, make sure build with minimal deps.

[checkpoint] Cleanup expr-to-module hackery.
mbs-octoml committed Aug 3, 2021
1 parent b3e832a commit 01c73c5
Showing 51 changed files with 1,120 additions and 533 deletions.
2 changes: 1 addition & 1 deletion docs/langref/relay_pattern.rst
@@ -406,7 +406,7 @@ Either match the first pattern or the second pattern.
Domination
**********

Match child pattern, find a match for the parent pattern, insuring that the child ultimately dominates the parrent (i.e., no nodes outside the pattern use outputs of the parent), and that ever node betwen the child and the pattern matches the path pattern.
Match child pattern, find a match for the parent pattern, ensuring that the child ultimately dominates the parent (i.e., no nodes outside the pattern use outputs of the parent), and that every node between the child and the pattern matches the path pattern.

Function Pattern
****************
24 changes: 17 additions & 7 deletions include/tvm/ir/module.h
@@ -307,20 +307,30 @@ class IRModule : public ObjectRef {
}

/*!
* \brief Construct a module from a standalone expression.
* \brief Constructs a module from a standalone expression \p expr.
*
* Allows one to optionally pass a global function map and
* map of type definitions as well.
* If \p expr is a function it will be bound directly. Otherwise a function over the free
* variables of \p expr (possibly none) with \p expr as body is created and bound.
*
* The function is bound to, in preference order:
* - The "global_symbol" attribute of \p expr, if it is a function with that attribute.
* - \p name_hint, if non-empty.
* - "main"
*
* Additional global functions and type definitions may be included in the result module.
*
* \param expr The expression to set as the main function to the module.
* \param global_funcs The global function map.
* \param type_definitions Map of global type definitions
* \param global_funcs The global function map. Default empty.
* \param type_definitions Map of global type definitions. Default empty.
* \param name_hint Name hint for global var. Default empty.
*
* \returns A module with expr set as the main function.
* \returns A module with \p expr set as the main function.
*/
TVM_DLL static IRModule FromExpr(const RelayExpr& expr,
const Map<GlobalVar, BaseFunc>& global_funcs = {},
const Map<GlobalTypeVar, TypeData>& type_definitions = {});
const Map<GlobalTypeVar, TypeData>& type_definitions = {},
std::unordered_set<String> import_set = {},
const std::string& name_hint = std::string());
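The naming preference the doc comment describes can be sketched as a small hypothetical helper (illustrative only, not TVM API): the function's "global_symbol" attribute wins, then the caller's name hint, then "main".

```python
def choose_global_name(attrs, name_hint=""):
    """Pick the global var name a function is bound to, in preference order."""
    if "global_symbol" in attrs:      # attribute on the function wins
        return attrs["global_symbol"]
    if name_hint:                     # then the caller's non-empty hint
        return name_hint
    return "main"                     # then the default

print(choose_global_name({"global_symbol": "my_fn"}, "hint"))  # → my_fn
```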

/*!
* \brief Parse text format source file into an IRModule.
66 changes: 44 additions & 22 deletions include/tvm/relay/interpreter.h
@@ -43,28 +43,6 @@
namespace tvm {
namespace relay {

/*!
*\brief Create a Interpreter function that can
* evaluate an expression and produce a value.
*
* The resulting value can be passed to Python, making it easy to use
* for testing and debugging.
*
* The interpreter interprets the program fragments not supported by the
* TVM runtime, although the interpreter is naively implemented it uses
* TVM operators for evaluating all operators.
*
* Our intent is that this will never be the most efficient implementation of
* Relay's semantics, but a readable and clear one.
*
* \param mod The function module.
* \param device The primary device that the interepreter runs on.
* \param target Compiler target flag to compile the functions on the context.
* \return A function that takes in an expression and returns a value.
*/
runtime::TypedPackedFunc<ObjectRef(Expr)> CreateInterpreter(IRModule mod, Device device,
Target target);

/*! \brief The container type of Closures used by the interpreter. */
class InterpreterClosureObj : public runtime::ClosureObj {
public:
@@ -164,6 +142,50 @@ class ConstructorValue : public ObjectRef {
TVM_DEFINE_OBJECT_REF_METHODS(ConstructorValue, ObjectRef, ConstructorValueObj);
};

/*!
* \brief Returns a packed function over Relay expressions which will evaluate \p expr
* applied to those arguments, where \p expr is evaluated w.r.t. the definitions in \p mod.
*
* This function is intended to support the Python 'debug' executor.
*
* The given \p expr should have function type. The given \p mod may be empty or
* undefined if \p expr is self-contained. Relay arguments passed to the result
* packed function must be constants, references, or constructors/tuples over such.
* As much work as possible is done while constructing the result packed function, and
* that function may be reasonably efficiently applied multiple times without redoing
* unnecessary work.
*
* Primitives are lowered and compiled to packed functions for execution on \p device
* with properties given by \p target. All other Relay constructs are interpreted.
*
* The interpreter is intended to be a 'reference' implementation of the Relay semantics
* for testing and interactive use. It is not intended to be particularly efficient.
*
* \param mod A module containing definitions which can be referenced from
* \p expr. May be empty or undefined.
* \param expr An expression of function type to evaluate. May reference definitions from \p mod.
* \param device The device on which all primitives will be executed.
* \param target The compiler target flag for compiling primitives.
* \return A packed function that takes an array of Relay expressions and returns the
* result of applying \p expr to those arguments.
*/
TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, Device device,
Target target);

/*!
* \brief Evaluates \p expr and returns its result.
*
* This function is intended to support TVM constant evaluation.
*
* \param expr An expression to evaluate.
* \param type_definitions Global type definitions which \p expr may reference.
* \param import_set The set of modules already imported.
* \param device The device on which all primitives will be executed.
* \param target The compiler target flag for compiling primitives.
* \return The object representing the result.
*/
ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target target);

} // namespace relay
} // namespace tvm

#endif // TVM_RELAY_INTERPRETER_H_
43 changes: 29 additions & 14 deletions include/tvm/runtime/logging.h
@@ -487,41 +487,53 @@ TVM_CHECK_FUNC(_NE, !=)
#define DLOG_IF(severity, condition) \
LOG_IF(severity, ::tvm::runtime::detail::DebugLoggingEnabled() && (condition))

#ifdef VLOG_LEVEL
#define VLOG(level) DLOG_IF(INFO, (level) <= VLOG_LEVEL)
#else
#define VLOG(level) true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(INFO)
#endif

#else

#define LOG_DFATAL LOG_ERROR
#define DFATAL ERROR
#define DLOG(severity) true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity)
#define DLOG_IF(severity, condition) \
(true || !(condition)) ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity)
#define VLOG(level) true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(INFO)

#endif

#if TVM_LOG_DEBUG
#define DCHECK(x) CHECK(x)
#define DCHECK_LT(x, y) CHECK((x) < (y))
#define DCHECK_GT(x, y) CHECK((x) > (y))
#define DCHECK_LE(x, y) CHECK((x) <= (y))
#define DCHECK_GE(x, y) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) CHECK((x) == (y))
#define DCHECK_NE(x, y) CHECK((x) != (y))
#else
#define DCHECK(x) \
  while (false) CHECK(x)
#define DCHECK_LT(x, y) \
  while (false) CHECK((x) < (y))
#define DCHECK_GT(x, y) \
  while (false) CHECK((x) > (y))
#define DCHECK_LE(x, y) \
  while (false) CHECK((x) <= (y))
#define DCHECK_GE(x, y) \
  while (false) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) \
  while (false) CHECK((x) == (y))
#define DCHECK_NE(x, y) \
  while (false) CHECK((x) != (y))
#endif

#if TVM_LOG_DEBUG
#else
#define DCHECK(x) CHECK(x)
#define DCHECK_LT(x, y) CHECK((x) < (y))
#define DCHECK_GT(x, y) CHECK((x) > (y))
#define DCHECK_LE(x, y) CHECK((x) <= (y))
#define DCHECK_GE(x, y) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) CHECK((x) == (y))
#define DCHECK_NE(x, y) CHECK((x) != (y))
#endif


#define TVM_ICHECK_INDENT " "

#define ICHECK_BINARY_OP(name, op, x, y) \
@@ -552,5 +564,8 @@ TVM_CHECK_FUNC(_NE, !=)
// Re-export error types
using runtime::Error;
using runtime::InternalError;

void InitLogging();

} // namespace tvm
#endif // TVM_RUNTIME_LOGGING_H_
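A Python analogue of the VLOG gating added above, assuming the compile-time `VLOG_LEVEL` switch: a message fires only when its level is at or below the configured level, and an unset level disables VLOG entirely.

```python
VLOG_LEVEL = 1  # assume configured at "build time"; None disables VLOG entirely

out = []  # stands in for the log sink

def vlog(level, msg):
    # Mirrors: #define VLOG(level) DLOG_IF(INFO, (level) <= VLOG_LEVEL)
    if VLOG_LEVEL is not None and level <= VLOG_LEVEL:
        out.append(msg)

vlog(0, "always-on detail")
vlog(1, "level-1 detail")
vlog(2, "too verbose")   # suppressed: 2 > VLOG_LEVEL
```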
80 changes: 35 additions & 45 deletions python/tvm/relay/backend/interpreter.py
@@ -196,65 +196,55 @@ class Interpreter(Executor):
target : tvm.Target
The target option to build the function.
CAUTION: Despite the API, the module is prepared upon each call to evaluate
rather than once in create_executor.
That is:
.. code-block:: python
executor = relay.create_executor(kind="debug", mod=module)
a = executor.evaluate(expr)(args1)
b = executor.evaluate(expr)(args2)
will prepare all the bindings in module twice. For efficiency, try to hoist
calls to evaluate as high as possible, preferably immediately after create_executor:
.. code-block:: python
func = relay.create_executor(kind="debug", mod=module).evaluate(expr)
a = func(args1)
b = func(args2)
"""

def __init__(self, mod, device, target):
self.mod = mod
self.device = device
self.target = target

def optimize(self):
"""Optimize functions in a module.
Returns
-------
opt_mod : tvm.IRModule
The optimized module.
"""
seq = tvm.transform.Sequential(
[
# tvm.parser.AnnotateSpans(),
transform.SimplifyInference(),
transform.FuseOps(0),
transform.ToANormalForm(),
transform.InferType(),
]
)
mod = seq(self.mod)
return mod

def _make_executor(self, expr=None):
if expr is None or isinstance(expr, GlobalVar):
assert self.mod is not None

_intrp = _backend.CreateInterpreter(self.optimize(), self.device, self.target)
if expr is None:
# A missing expr denotes 'main' in the given module.
expr = self.mod.get_global_var("main")

def _interp_wrapper(*args, **kwargs):
if expr is None:
args = self._convert_args(self.mod["main"], args, kwargs)
else:
args = self._convert_args(expr, args, kwargs)
# Evaluate expr to a packed function we can efficiently re-apply
# to Relay arguments.
func = _backend.EvalFunction(self.mod, expr, self.device, self.target)

def _apply_args(*args, **kwargs):
if isinstance(expr, GlobalVar):
# When expanding args, look inside the actual global definition so kwargs can be matched.
args = self._convert_args(self.mod[expr.name_hint], args, kwargs)
else:
args = self._convert_args(expr, args, kwargs)
# Reflect python arguments up into Relay.
relay_args = []
for arg in args:
relay_args.append(_arg_to_ast(self.mod, arg))
# Apply func to Relay args
return func(relay_args)

# Set the entry function for the module.
if expr is None:
pass
elif isinstance(expr, GlobalVar):
self.mod["main"] = self.mod[expr]
else:
assert isinstance(expr, Function)
func = Function([], Call(expr, relay_args))
relay_args = []
if self.mod:
self.mod["main"] = func
else:
self.mod = IRModule.from_expr(func)

mod = self.optimize()
opt_expr = Call(mod["main"], relay_args)
return _intrp(opt_expr)

return _interp_wrapper
return _apply_args
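The "reflect python arguments up into Relay" step in `_apply_args` above can be sketched with made-up AST node shapes (illustrative only, not `_arg_to_ast`'s real output): host values become constant or tuple nodes before the prepared function is applied.

```python
def arg_to_ast(arg):
    """Turn a host value into a (made-up) AST constant or tuple node."""
    if isinstance(arg, (int, float)):
        return ("const", arg)
    if isinstance(arg, (list, tuple)):
        return ("tuple", [arg_to_ast(a) for a in arg])
    raise TypeError("cannot reflect %s into the AST" % type(arg).__name__)

print(arg_to_ast([1, 2]))  # → ('tuple', [('const', 1), ('const', 2)])
```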
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/common.py
@@ -545,10 +545,10 @@ def infer_value(input_val, params, mod=None):
mod["main"] = _function.Function(analysis.free_vars(input_val), input_val)
else:
mod = IRModule.from_expr(input_val)
exc = tvm.relay.create_executor("debug", mod=mod, device=tvm.cpu(), target="llvm")
inputs = []
for param in mod["main"].params:
inputs.append(params[param.name_hint])
exc = tvm.relay.create_executor("debug", mod=mod, device=tvm.cpu(), target="llvm")
result = exc.evaluate()(*inputs)
return result

6 changes: 4 additions & 2 deletions python/tvm/relay/testing/__init__.py
@@ -134,12 +134,15 @@ def check_grad(
test_inputs = inputs

for target, dev in enabled_targets():
# Eval the backward and forward functions
intrp = relay.create_executor(device=dev, target=target)
bwd_func_compiled, fwd_func_compiled = intrp.evaluate(relay.Tuple([bwd_func, fwd_func]))

# Get analytic gradients.
_, grads = intrp.evaluate(bwd_func)(*inputs)
_, grads = bwd_func_compiled(*inputs)
grads = [grad.numpy().astype("float64") for grad in grads]


# Throw out gradients we aren't testing
if inputs != test_inputs:
tmp = []
@@ -154,7 +157,6 @@ def check_grad(
assert len(grads) > 0, "You must test at least one gradient."

# Get numeric gradients for each dimension of each param, using two-sided approximation.
fwd_func_compiled = intrp.evaluate(fwd_func)
approx_grads = []
for x in test_inputs:
approx_grad = np.zeros(x.shape)
12 changes: 8 additions & 4 deletions src/ir/module.cc
@@ -349,20 +349,24 @@ void IRModuleNode::Update(const IRModule& mod) {

IRModule IRModule::FromExpr(const RelayExpr& expr,
const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions) {
auto mod = IRModule(global_funcs, type_definitions);
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions,
std::unordered_set<String> import_set,
const std::string& name_hint) {
auto mod = IRModule(global_funcs, type_definitions, std::move(import_set));
BaseFunc func;
std::string gv_name = "main";
std::string gv_name = name_hint;

if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
gv_name = opt.value();
}

} else {
func = relay::Function(relay::FreeVars(expr), expr, Type(), relay::FreeTypeVars(expr, mod), {});
}
if (gv_name.empty()) {
gv_name = "main";
}
auto main_gv = GlobalVar(gv_name);
mod->Add(main_gv, func);
return mod;
2 changes: 1 addition & 1 deletion src/ir/transform.cc
@@ -466,7 +466,7 @@ Pass GetPass(const String& pass_name) {
return (*f)();
}

// TODO(zhiics): we currenlty only sequentially execute each pass in
// TODO(zhiics): we currently only sequentially execute each pass in
// a Sequential without the consideration of their orders. The phase
// ordering problem needs to be handled in the future.
IRModule SequentialNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
