[Relay] Merge analysis/context_analysis.cc and transforms/device_annotation.cc

Currently LowerTEPass (backend/te_compiler.cc) is a 'special' pass because it
depends on a side-input DeviceMap. We'd like to remove that side-input, and
instead recover the Device (and, ultimately, Target) for each (fused) primitive
call from the AST alone.

By doing so we also avoid needing to perform device planning twice:
 - It needs to be done before lowering so we know which primitives need
   to be compiled for which devices.
 - It then needs to be re-done after lowering and optimization as a prelude
   to memory planning.
By baking the device plan into the AST we can simply do device planning before
lowering, and run memory planning later, both as ordinary passes.
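The contrast can be sketched in a few lines of Python (a toy model, not TVM's API — the node and function names here are hypothetical): a 'special' pass must be handed a side-input device map and kept in sync with it, while an ordinary pass reads the devices the planner already baked into the AST.

```python
# Toy illustration of side-input device maps vs. devices carried in the AST.

def plan_devices(ast, default_device):
    """Annotate every node with a device, falling back to the default."""
    for node in ast:
        node.setdefault("device", default_device)
    return ast

def lower_with_side_input(ast, device_map):
    # 'Special' pass: correctness depends on device_map staying in sync.
    return [(node["op"], device_map[node["op"]]) for node in ast]

def lower_from_ast(ast):
    # Ordinary pass: the device is recovered from the node itself.
    return [(node["op"], node["device"]) for node in ast]

ast = [{"op": "add"}, {"op": "conv2d", "device": "cuda"}]
planned = plan_devices(ast, default_device="cpu")
print(lower_from_ast(planned))  # [('add', 'cpu'), ('conv2d', 'cuda')]
```

With the plan in the AST, memory planning can later consume the very same annotations without re-running the analysis.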

While working on that issue we realized we currently have 3 'device planners':
 - transforms/device_annotation.cc, which supports only a small subset of Relay
   and uses a simple top-down algorithm to assign a device to every
   sub-expression.
 - analysis/context_analysis.cc, which makes a gallant effort to support most of
   Relay, is based on unification rather than a top-down algorithm, but handles
   higher-order functions by ad hoc and fragile inlining.
 - transforms/annotate_target.cc, which works on Targets instead of Devices, but
   is otherwise like 'device planning'.
We'd like to bring these together.

In this PR we introduce a new transforms/device_planner.cc intended to replace
transforms/device_annotation.cc and analysis/context_analysis.cc. We don't
delete those two just yet since we need to switch all users off of them in the
next PR. We also leave transforms/annotate_target.cc alone pending a proper RFC
to bring devices and targets together sensibly, but have it firmly in our
sights.

transforms/device_planner.cc is based on analysis/context_analysis.cc, but
is heavily reworked to:
 1. Handle starting from existing "on_device" annotations as well as existing
    "device_copy" calls.
 2. Be idempotent, with the idea we'll probably need to re-run it to 'refine'
    device planning to account for storage scopes.
 3. Robustly handle all of Relay, particularly higher-order functions. For that
    we replace the inlining approach in analysis/context_analysis.cc with a
    higher-order unification domain.
 4. Be a little more systematic with defaulting.
 5. Capture the result of the analysis within the AST as new "device_copy" calls
    at device boundaries, and new/replaced "on_device" calls wherever the device
    for a sub-expression is not already 'obvious' from the sub-expression's
    lexical scope.
 6. Provide helper visitors for passes which need to ask for the device for
    any sub-expression they are processing and/or preserve device information
    on rewrites. Those passes include:
     - backend/aot_executor_codegen.cc (AOTOnDemandAllocator)
     - backend/graph_plan_memory.cc (StorageAllocaBaseVisitor etc)
     - backend/te_compiler.cc (LowerTensorExprMutator)
     - backend/vm/lambda_lift.cc (LambdaLifter)
     - transforms/memory_alloc.cc (DialectRewriter)
     - transforms/to_a_normal_form.cc (Fill)
     - backend/vm/compiler.cc (VMFunctionCompiler)
    However we won't change any of those in this PR.
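The unification idea behind points 2-4 can be sketched in pure Python (hypothetical names — the real implementation is C++ and handles a full higher-order domain): each sub-expression gets a device variable, annotations constrain variables, data flow unifies them via union-find, conflicts mark where a "device_copy" is needed, and anything still free is defaulted at the end.

```python
class DeviceDomain:
    """Union-find node: either free or resolved to a concrete device type."""
    def __init__(self, device_type=None):
        self.parent = self
        self.device_type = device_type  # None means 'still free'

    def find(self):
        root = self
        while root.parent is not root:
            root = root.parent
        return root

def unify(a, b):
    a, b = a.find(), b.find()
    if a is b:
        return a
    if (a.device_type is not None and b.device_type is not None
            and a.device_type != b.device_type):
        raise ValueError("device conflict: a device_copy is required here")
    if a.device_type is None:
        a, b = b, a  # prefer the constrained root so the constraint survives
    b.parent = a
    return a

def default_free(domains, default_device_type):
    # Defaulting is idempotent: already-resolved roots are left alone.
    for d in domains:
        root = d.find()
        if root.device_type is None:
            root.device_type = default_device_type

# x flows into a parameter pinned to device type 2 by an "on_device" annotation.
x, param = DeviceDomain(), DeviceDomain(device_type=2)
unify(x, param)
y = DeviceDomain()  # unconstrained, so it takes the default
default_free([x, param, y], default_device_type=1)
print(x.find().device_type, y.find().device_type)  # 2 1
```

Re-running `default_free` on an already-planned set of domains changes nothing, which is the idempotence property point 2 asks for.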

See the draft apache#8788 for the end game; I'll
be peeling PRs out of it.
mbs-octoml committed Sep 18, 2021
1 parent c401079 commit 7c6ed9a
Showing 26 changed files with 4,101 additions and 93 deletions.
45 changes: 43 additions & 2 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
@@ -32,14 +32,55 @@ namespace tvm {
namespace relay {

/*!
* \brief Options for the device annotation operators.
* \brief Attributes for the "on_device" operator.
*
* The relay call
* \code
* on_device(expr, device_type=2)
* \endcode
* denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2
* (i.e. \p kDLCuda). Semantically the operator is the identity function.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
// TODO(mbs): Replace device types with TargetDevice.
/*! \brief Device type on which argument expression should be evaluated. */
int device_type;
/*!
* \brief If true, the result device must also be \p device_type and device planning should
* not insert any "device_copy" calls to respect this annotation.
*
* This is used by the device planning pass itself when annotating the planned program.
*/
bool is_fixed;

TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(device_type)
.describe("The virutal device/context type that an expression is annotated with.")
.describe("The type of the virtual device which should hold the expression result.")
.set_default(0);
TVM_ATTR_FIELD(is_fixed)
.describe("If true, do not insert a \"device_copy\" call to respect this annotation.")
.set_default(false);
}
};
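The semantics documented above can be paraphrased in a short Python toy model (not TVM's actual classes): `on_device` is the identity on values and merely carries attributes for the device planner to consume.

```python
from dataclasses import dataclass

@dataclass
class OnDeviceAttrs:
    device_type: int = 0     # 0 is the 'null' device, i.e. unannotated
    is_fixed: bool = False   # if True, planning must not insert a device_copy

def on_device(value, device_type, is_fixed=False):
    # Semantically the identity function: the value passes through untouched;
    # only the attached attributes matter to the device planner.
    return value, OnDeviceAttrs(device_type, is_fixed)

result, attrs = on_device(42, device_type=2)  # 2 == kDLCUDA
print(result, attrs.device_type, attrs.is_fixed)  # 42 2 False
```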

/*!
* \brief Attributes for Relay function definitions which capture the devices for the
* function parameters and result.
*/
struct FunctionOnDeviceAttrs : public tvm::AttrsNode<FunctionOnDeviceAttrs> {
constexpr static const char* kFunctionAttrsKey = "on_device";

/*! \brief Device type on which each of the function's arguments already resides. */
Array<Integer> param_device_types;
// TODO(mbs): Replace device types with TargetDevice.
/*! \brief Device type on which function body should be evaluated. */
int result_device_type;

TVM_DECLARE_ATTRS(FunctionOnDeviceAttrs, "relay.attrs.FunctionOnDeviceAttrs") {
TVM_ATTR_FIELD(param_device_types)
.describe("The type of the virtual device which holds each of the function's parameters.");
TVM_ATTR_FIELD(result_device_type)
.describe("The type of the virtual device which will hold the function's result.")
.set_default(0);
}
};
1 change: 1 addition & 0 deletions include/tvm/relay/attrs/device_copy.h
@@ -35,6 +35,7 @@ namespace relay {
* \brief Options for the device copy operators.
*/
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
// TODO(mbs): Should be TargetDevice.
int dst_dev_type;
int src_dev_type;

3 changes: 2 additions & 1 deletion include/tvm/relay/expr_functor.h
@@ -37,6 +37,7 @@
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace relay {

@@ -227,7 +228,7 @@ class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
*
* MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
* of the graph and processes them iteratively to prevent stack overflows
*/
class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
public:
15 changes: 13 additions & 2 deletions include/tvm/relay/transform.h
@@ -437,13 +437,24 @@ TVM_DLL Pass RelayToTIRTargetHook();
* \brief A pass for manifesting explicit memory allocations and rewriting
* specific dialects.
*
* \param target_host The target used by the host for compliation.
* \param targets The device type and target pairs for compliation.
* \param target_host The target used by the host for compilation.
* \param targets The device type and target pairs for compilation.
*
* \return The pass.
*/
TVM_DLL Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets);

/*!
* \brief Uses existing "on_device" and "device_copy" CallNodes to infer the device on which
* every Relay sub-expression should run (and the result stored). Captures the result of that
* analysis using new "on_device" and "device_copy" CallNodes. See
* tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator}
* for help recovering the device for an arbitrary sub-expression in downstream transformations.
*
* \param default_device_type DLDeviceType for default device.
*/
TVM_DLL Pass PlanDevices(DLDeviceType default_device_type);

} // namespace transform

/*!
2 changes: 1 addition & 1 deletion include/tvm/runtime/container/array.h
@@ -249,7 +249,7 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
};

/*!
* \brief Array, container representing a contigious sequence of ObjectRefs.
* \brief Array, container representing a contiguous sequence of ObjectRefs.
*
* Array implements in-place copy-on-write semantics.
*
26 changes: 16 additions & 10 deletions include/tvm/runtime/ndarray.h
@@ -38,9 +38,19 @@
#include <vector>

namespace tvm {
namespace runtime {

typedef DLDevice Device;
// alias DLDevice
using Device = DLDevice;

// A 'null' device type, does not correspond to any DLDeviceType enum.
// TODO(mbs): This is to help us as we transition away from representing the 'homogeneous' case
// as a singleton target map indexed by the invalid DLDeviceType '0'.
constexpr DLDeviceType kNullDeviceType = static_cast<DLDeviceType>(0);

// An 'invalid' device type, does not correspond to any DLDeviceType enum.
constexpr DLDeviceType kInvalidDeviceType = static_cast<DLDeviceType>(-1);

namespace runtime {

/*!
* \brief Managed NDArray.
@@ -481,23 +491,19 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
}

} // namespace runtime

// alias Device
using tvm::runtime::Device;

} // namespace tvm

namespace std {
template <>
struct hash<tvm::runtime::Device> {
std::size_t operator()(const tvm::runtime::Device& dev) const {
struct hash<tvm::Device> {
std::size_t operator()(const tvm::Device& dev) const {
return ((dev.device_id << 8) | dev.device_type);
}
};

template <>
struct equal_to<tvm::runtime::Device> {
bool operator()(const tvm::runtime::Device& lhs, const tvm::runtime::Device& rhs) const {
struct equal_to<tvm::Device> {
bool operator()(const tvm::Device& lhs, const tvm::Device& rhs) const {
return (lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id);
}
};
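The hash above packs the device id and type into a single integer; since every `DLDeviceType` value is far below 256, distinct devices hash to distinct values. A quick Python check of the same formula:

```python
def device_hash(device_type, device_id):
    # Mirrors std::hash<tvm::Device>: (dev.device_id << 8) | dev.device_type.
    return (device_id << 8) | device_type

devices = [(1, 0), (2, 0), (2, 1)]  # (kDLCPU, 0), (kDLCUDA, 0), (kDLCUDA, 1)
hashes = [device_hash(t, i) for t, i in devices]
print(hashes)  # [1, 2, 258]
assert len(set(hashes)) == len(devices)  # injective while device_type < 256
```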
39 changes: 28 additions & 11 deletions python/tvm/relay/op/annotation/annotation.py
@@ -22,7 +22,19 @@
from .. import op as reg


def on_device(data, device):
def _device_to_int(device):
if isinstance(device, _Device):
return device.device_type
elif isinstance(device, str):
return _nd.device(device).device_type
else:
raise ValueError(
"device is expected to be the type of Device or "
"str, but received %s" % (type(device))
)


def on_device(data, device, is_fixed=False):
"""Annotate an expression with a certain device type.
Parameters
@@ -33,21 +45,26 @@ def on_device(data, device):
device : Union[:py:class:`Device`, str]
The device type to annotate.
is_fixed : bool
If true, annotation does not imply a device_copy may be inserted.
(This parameter is used internally by the compiler and unit tests and
should not need to be set in user programs.)
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
if isinstance(device, _Device):
device = device.device_type
elif isinstance(device, str):
device = _nd.device(device).device_type
else:
raise ValueError(
"device is expected to be the type of Device or "
"str, but received %s" % (type(device))
)
return _make.on_device(data, device)
return _make.on_device(data, _device_to_int(device), is_fixed)


# for testing only
def function_on_device(function, param_devices, result_device):
"""Attaches an attribute to the function indicating the devices for its parameters and result.
"""

return _make.function_on_device(function,
[_device_to_int(d) for d in param_devices], _device_to_int(result_device))
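The string-to-int conversion in `_device_to_int` ultimately bottoms out in DLPack's `DLDeviceType` enum. A self-contained sketch of the same idea, with a hard-coded subset of that table (the helper name is made up for illustration):

```python
# Subset of DLPack's DLDeviceType enum values.
DEVICE_TYPE = {"cpu": 1, "cuda": 2, "opencl": 4, "vulkan": 7, "metal": 8}

def device_to_int(device):
    """Toy stand-in for _device_to_int: accept an int or a device-name string."""
    if isinstance(device, int):
        return device
    if isinstance(device, str):
        try:
            return DEVICE_TYPE[device]
        except KeyError:
            raise ValueError("unknown device string: %r" % device) from None
    raise ValueError("device is expected to be Device, str or int, got %s" % type(device))

print(device_to_int("cuda"), device_to_int(1))  # 2 1
```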


def stop_fusion(data):
12 changes: 11 additions & 1 deletion python/tvm/relay/transform/transform.py
@@ -546,7 +546,7 @@ def MergeCompilerRegions():

def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_deivce`, mark which device an expression should be scheduled to.
`on_device`, mark which device an expression should be scheduled to.
This pass helps heterogeneous execution where different operators may need
to be allocated on various devices.
@@ -1167,6 +1167,16 @@ def SimplifyExpr():
return _ffi_api.SimplifyExpr()


def PlanDevices(default_device):
"""
Uses existing "on_device" and "device_copy" CallNodes to infer the device on which
every Relay sub-expression should run (and the result stored). Captures the result of that
analysis using new "on_device" and "device_copy" CallNodes. Note that the device_id of
the default_device is ignored.
"""
return _ffi_api.PlanDevices(default_device)


def FoldExplicitPadding():
"""
FoldExplicitPadding finds explict padding before an op that can support
7 changes: 5 additions & 2 deletions src/node/structural_equal.cc
@@ -19,6 +19,7 @@
/*!
* \file src/node/structural_equal.cc
*/
#include <tvm/ir/module.h>
#include <tvm/node/functor.h>
#include <tvm/node/node.h>
#include <tvm/node/reflection.h>
@@ -119,8 +120,10 @@ class RemapVarSEqualHandler : public SEqualReducer::Handler {
// Check the result.
bool CheckResult(bool result, const ObjectRef& lhs, const ObjectRef& rhs) {
if (assert_mode_ && !result) {
LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by\n"
<< "lhs = " << lhs << "\nrhs = " << rhs;
LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl
<< PrettyPrint(lhs) << std::endl
<< "and rhs:" << std::endl
<< PrettyPrint(rhs);
}
return result;
}
8 changes: 3 additions & 5 deletions src/relay/backend/te_compiler.cc
@@ -439,11 +439,10 @@ class LowerTensorExprMutator : public ExprMutator {
}

// Non-External Relay Function
DLOG(INFO) << "lowering to target '" << target->str() << "' for primitive:\n"
<< PrettyPrint(func);
VLOG(1) << "lowering to target '" << target->str() << "' for primitive:\n" << PrettyPrint(func);
CCacheKey key = CCacheKey(func, target);
CachedFunc lowered_func = compiler_->Lower(key, module_name_);
DLOG(INFO) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'";
VLOG(1) << "lowered primitive bound to '" << PrettyPrint(lowered_func->prim_fn_var) << "'";

// Collect all the lowered functions produced for this primitive function.
Map<GlobalVar, tir::PrimFunc> prim_fns;
@@ -452,8 +451,7 @@ class LowerTensorExprMutator : public ExprMutator {
CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
all_prim_fn_vars.push_back(prim_fn.first);
DLOG(INFO) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first)
<< "'";
VLOG(1) << "lowered primitive includes bindings for '" << PrettyPrint(prim_fn.first) << "'";
}

// TODO(@areusch, @jroesch): this metadata is for AOT, this should be our interface for AOT
4 changes: 2 additions & 2 deletions src/relay/backend/vm/inline_primitives.cc
@@ -136,13 +136,13 @@ struct PrimitiveInliner : ExprMutator {
if (n->GetAttr<String>(attr::kCompiler).defined()) continue;
auto func = GetRef<Function>(n);

DLOG(INFO) << "Before inlining primitives: " << global << std::endl << AsText(func, false);
VLOG(1) << "Before inlining primitives: " << global << std::endl << PrettyPrint(func);

func = Function(func->params, VisitExpr(func->body), func->ret_type, func->type_params,
func->attrs);
module_->Add(global, func, true);

DLOG(INFO) << "After inlining primitives: " << global << std::endl << AsText(func, false);
VLOG(1) << "After inlining primitives: " << global << std::endl << PrettyPrint(func);
}
}
return module_;
