Skip to content

Commit

Permalink
[checkpoint] nuke device_annotation.cc and python bindings
Browse files Browse the repository at this point in the history
Probably lots of dangling refs on the python side.
Original handling of default devices and the '0' device type target not
brought over and needs to be recovered.
  • Loading branch information
mbs-octoml committed Sep 17, 2021
1 parent a677fb7 commit ed31ec7
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 1,177 deletions.
10 changes: 0 additions & 10 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,16 +211,6 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod);
*/
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const IRModule& mod);

/*!
* \brief Collect the device annotation operators.
*
* \param expr The expression.
*
* \return The annotated expression to device type mapping for annotation ops.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
TVM_DLL Map<Expr, Integer> CollectAllDeviceAnnotationOps(const IRModule& mod);

/*!
* \brief Finds cases that the given match expression does not catch, if any.
*
Expand Down
17 changes: 0 additions & 17 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,23 +253,6 @@ def all_dtypes(expr):
return set(_ffi_api.all_dtypes(expr))


def collect_device_annotation_ops(expr):
"""Collect the device annotation ops for the given expression.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
ret : Dict[tvm.relay.Expr, int]
A dictionary mapping tvm.relay.Expr to device type where the keys are
annotation expressions.
"""
return _ffi_api.CollectDeviceAnnotationOps(expr)


def get_total_mac_number(expr):
"""
Count the number of MACs (multiply-accumulate) of a model
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@
register_broadcast_schedule("fast_erf")
# a fake on_device schedule.
# this will not be used in actual computation
# as on_device will be removed during DeviceAnnotation pass
register_injective_schedule("on_device")


Expand Down
21 changes: 0 additions & 21 deletions python/tvm/relay/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,27 +544,6 @@ def MergeCompilerRegions():
return _ffi_api.MergeCompilerRegions()


def RewriteAnnotatedOps(fallback_device):
"""Rewrite the annotated program where annotation operators, e.g.
`on_device`, mark which device an expression should be scheduled to.
This pass helps heterogeneous execution where different operators may need
to be allocated on various devices.
Parameters
----------
fallback_device : int
The fallback device type. It is also used as the default device for
operators with no annotated device.
Returns
-------
ret: tvm.transform.Pass
The registered pass that rewrites an expression with annotated
`on_device` operators.
"""
return _ffi_api.RewriteAnnotatedOps(fallback_device)


def ToANormalForm():
"""Turn Graph Normal Form expression into A Normal Form Expression.
The scope of the root expression is the global scope.
Expand Down
56 changes: 10 additions & 46 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -325,15 +325,16 @@ class RelayBuildModule : public runtime::ModuleNode {
transform::PassContext pass_ctx = PassContext::Current();
Optional<Integer> opt_fallback_dev =
pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast<int>(kDLCPU)));
DLDeviceType fallback_dev = static_cast<DLDeviceType>(opt_fallback_dev.value()->value);
ICHECK_GT(fallback_dev, 0U);
#if 0
// TODO(mbs): Remove
if (targets_.size() > 1) {
relay_module = RunDeviceAnnotationPass(relay_module, fallback_dev);
}
#endif
pass_seqs.push_back(transform::PlanDevices(fallback_dev));
DLDeviceType default_device_type = static_cast<DLDeviceType>(opt_fallback_dev.value()->value);
ICHECK_GT(default_device_type, 0U);
// What about the implied 'default' target with 'device type' 0?
UpdateHeterogeneousInputs(default_device_type);
// Make sure the 'default' target is always available keyed by 0.
// This is the current convention for conveying which of the >= 2 targets is the default.
targets_.Set(kNullDeviceType, CreateDefaultTarget(default_device_type));
// TODO(mbs): Used to be some obsure logic for choosing a different fallback_device
// from the existing "on_device" annotations. What is that for?
pass_seqs.push_back(transform::PlanDevices(default_device_type));

// Fuse the operations if it is needed.
pass_seqs.push_back(transform::FuseOps());
Expand Down Expand Up @@ -411,43 +412,6 @@ class RelayBuildModule : public runtime::ModuleNode {
}
}

/*!
* \brief Execute the device annotation passes to update the input program and
* target information.
*
* \param relay_module The input Relay module.
* \param fallback_device The fallback device for heterogeneous execution.
*
* \return updated_module The updated module after device annotation.
*/
IRModule RunDeviceAnnotationPass(const IRModule& relay_module, int fallback_device) {
UpdateHeterogeneousInputs(fallback_device);

// If there's a unique device type used by all "on_device" CallNodes then use that
// as the fallback_device.
// TODO(mbs): This defaulting only roughly matches the original behavior. We should
// cleanup all the logic around default host and device targets.
Map<Expr, Integer> annotations = CollectAllDeviceAnnotationOps(relay_module);
if (!annotations.empty()) {
std::unordered_set<int> device_types;
for (const auto& pair : annotations) {
device_types.insert(static_cast<int>((*annotations.begin()).second->value));
}
if (device_types.size() == 1UL) {
fallback_device = *device_types.begin();
}
}
// Make sure the 'default' target is always available keyed by 0.
// This is the current convention for conveying which of the >= 2 targets is the default.
targets_.Set(kNullDeviceType, CreateDefaultTarget(fallback_device));

// Insert "device_copy" CallNodes to account for any user-supplied "on_device" CallNodes.
auto updated_module = transform::RewriteAnnotatedOps(fallback_device)(relay_module);
ICHECK(updated_module.defined());

return updated_module;
}

/*!
* \brief Compile a Relay IR module to runtime module.
*
Expand Down
2 changes: 2 additions & 0 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator {

auto device_copy = IsDeviceCopy(func);
if (std::get<0>(device_copy)) {
// Record that device copy source and destination devices so the device planner can
// still follow along.
auto source_device = std::get<1>(device_copy);
auto dst_device = std::get<2>(device_copy);
tir_call_attrs->metadata.Set("source_device", tvm::Integer(source_device));
Expand Down
Loading

0 comments on commit ed31ec7

Please sign in to comment.