Skip to content

Commit

Permalink
[checkpoint] Woops, forgot VM exec tests won't work yet, disabled.
Browse files Browse the repository at this point in the history
  • Loading branch information
mbs-octoml committed Sep 18, 2021
1 parent e988a34 commit 966f413
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 30 deletions.
52 changes: 43 additions & 9 deletions src/relay/transforms/device_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,9 @@
* * Might be simpler to just let every type have a device annotation rather than work in
* a separate domain?
* * Switch to expr.CopyWith(...) form once implemented to avoid unnecessary copies.
* * The original device_annotation.cc RewriteAnnotatedOps removed all "on_device" calls
* in tuples at the top level of function bodies or main expression, irrespective of the
* "on_device" body. What's up with that?
*/

#include "./device_planner.h"
Expand Down Expand Up @@ -273,6 +276,29 @@ namespace transform {

namespace {

/*!
* \brief As for GetDeviceCopyProps, but for the call to the lowered TIR primitives rather
* than the original "device_copy" operator.
*
* See te_compiler.cc for where this rewriting occurs.
*/
DeviceCopyProps GetPrimitiveDeviceCopyProps(const CallNode* call_node) {
auto tir_call_attrs = call_node->attrs.as<TIRCallAttrs>();
if (tir_call_attrs == nullptr) {
return {};
}
if (tir_call_attrs->metadata.count("source_device") != 1 ||
tir_call_attrs->metadata.count("dst_device") != 1) {
return {};
}
ICHECK_EQ(call_node->args.size(), 1) << "device_copy is of arity 1";
return {
call_node->args[0],
static_cast<DLDeviceType>(
Downcast<Integer>(tir_call_attrs->metadata["source_device"])->value),
static_cast<DLDeviceType>(Downcast<Integer>(tir_call_attrs->metadata["dst_device"])->value)};
}

class DeviceDomain;
using DeviceDomainPtr = std::shared_ptr<DeviceDomain>;

Expand Down Expand Up @@ -581,25 +607,33 @@ class DeviceDomains {
return Lookup(itr->second);
}
std::vector<DeviceDomainPtr> args_and_result;
if (call->op == OnDeviceOp()) {

auto on_device_props = GetOnDeviceProps(call.get());
auto device_copy_props = GetDeviceCopyProps(call.get());
if (!device_copy_props.body.defined()) {
device_copy_props = GetPrimitiveDeviceCopyProps(call.get());
}

if (on_device_props.body.defined()) {
// on_device(expr, device_type=<t>, is_fixed=false)
// on_device : fn(<t>):?x?
//
// on_device(expr, device_type=<t>, is_fixed=true)
// on_device: fn(<t>):<t>
auto props = GetOnDeviceProps(call.get());
args_and_result.emplace_back(ForDeviceType(props.body->checked_type(), props.device_type));
if (props.is_fixed) {
args_and_result.emplace_back(
ForDeviceType(on_device_props.body->checked_type(), on_device_props.device_type));
if (on_device_props.is_fixed) {
args_and_result.emplace_back(args_and_result.front());
} else {
args_and_result.emplace_back(Free(props.body->checked_type()));
args_and_result.emplace_back(Free(on_device_props.body->checked_type()));
}
} else if (call->op == DeviceCopyOp()) {
} else if (device_copy_props.body.defined()) {
// device_copy(expr, src_dev_type=<s>, dst_dev_type=<d>)
// device_copy: fn(<s>):<d>
auto props = GetDeviceCopyProps(call.get());
args_and_result.emplace_back(ForDeviceType(props.body->checked_type(), props.src_dev_type));
args_and_result.emplace_back(ForDeviceType(props.body->checked_type(), props.dst_dev_type));
args_and_result.emplace_back(
ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.src_dev_type));
args_and_result.emplace_back(
ForDeviceType(device_copy_props.body->checked_type(), device_copy_props.dst_dev_type));
} else if (call->op == alloc_storage_op) {
ICHECK_EQ(call->args.size(), 2U);
// alloc_storage(size, alignment, device_type=<t>)
Expand Down
24 changes: 3 additions & 21 deletions tests/python/relay/test_pass_plan_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,11 @@

"""Unit tests for the PlanDevices pass. We check:
- The pass alone given the expected AST, though we need to manually run InferTypes.
- The pass is idempotent.
- Execution (where the pass is implicitly invoked) agrees with a reference Numpy
implementation. However not all backends support all Relay features and/or
multi-device programs, so we do what we can."""
- The pass is idempotent."""

# TODO(mbs): All the input/expected programs should be directly quoted using @script
# TODO(mbs): Not testing Match and Constructor since not supported by Python bindings?

# TODO(mbs): Add back reference implementation tests once VM is ready.

import tvm
from tvm import relay
Expand Down Expand Up @@ -57,19 +54,6 @@ def rewrite_and_assert(in_mod, expected_mod):
tvm.ir.assert_structural_equal(actual_mod, expected_mod)


def eval_and_assert(in_mod: tvm.IRModule, reference_func, args):
"""Test the standard compilation flow gives us a function which agrees with the Numpy
reference implementation."""
if not tvm.runtime.enabled(GPU):
print("Not evaluating since device %s is not enabled" % GPU)
return
with tvm.transform.PassContext(opt_level=3):
compiled = relay.create_executor("vm", mod=in_mod, device=GPU, target="cuda").evaluate()
actual = compiled(*args).numpy()
expected = reference_func(*args)
tvm.testing.assert_allclose(actual, expected)


def rand(shape):
return np.random.rand(*shape).astype("float32")

Expand All @@ -84,9 +68,7 @@ def exercise(in_mod: tvm.IRModule, expected_mod: tvm.IRModule, reference_func, a
rewrite_and_assert(in_mod, expected_mod)
# Idempotence
rewrite_and_assert(expected_mod, expected_mod)
# The VM can compile and possibly even run the module
if not (reference_func is None) and not (args is None):
eval_and_assert(in_mod, reference_func, args)
# TODO(mbs): Add back compiling and comparing to reference implementation once VM is ready.


#
Expand Down

0 comments on commit 966f413

Please sign in to comment.