Skip to content

Commit

Permalink
[checkpoint] cross-device example working on vm!
Browse files Browse the repository at this point in the history
[checkpoint] Can't wrap on_device around big-lambdas.

[checkpoint] ToANF working I think

 - Cleanup pairs-of-pairs for OnDeviceProps etc.
 - Don't wrap OnDevice around expressions that don't need it.

[checkpoint] ANF working again

[checkpoint] Visitor helpers, lambda lifting tracks devices

1/2 way through ANF tracking devices but currently very broken

[checkpoint] Rollback ANF scope changes, need to revisit

[checkpoint] Standalone pass unit tests all pass :-)

[checkpoint] TupleGetItem is working

[checkpoint] Giving up on FindFixedAndFreeExpressions, will introduce rewrite 'phase 0' instead

[checkpoint] Unit tests starting to work

 - Add 'is_fixed' field and the 'implicit is_fixed=true' rule
 - Start porting original context planning and annotation tests

[checkpoint] comment polish, get rid of no-op overrides.

[checkpoint] handle pattern-bound vars

[checkpoint] Rework intro comment. Introduce UnifyCollapsed.

[checkpoint] fix 'stored on' vs 'executes on' confusion

[checkpoint] more tests and bug fixes

[checkpoint] improve test

[checkpoint] basic tests passing

[checkpoint] get basic test going again

[checkpoint] Fix merge snafu

[checkpoint] Rename to device_planner / PlanDevices, improve comments, kind checking

[checkpoint] Switch to higher-order domains, add defaulting visitor.

[checkpoint] builds again

[checkpoint] (Does not build) Rework handling of let- and param-bound functions.

[checkpoint] (won't build) device_copy can capture scope, cleanup var->function tracking

[checkpoint] Python uses Devices not types.

[checkpoint] Renames, restore lost make_devices_explicit.cc

[checkpoint] rename test_pass_context_analysis.py to test_pass_make_devices_explicit.py

[checkpoint] few more rollbacks

[checkpoint] rollback bogus rename

[checkpoint] Cleanup default device handling.

[checkpoint] bug fixes, working on trivial example again

The default device stuff is messed up.

[checkpoint] Cleanup on_device handling. Fix param device lookups.

[checkpoint] Merged LowerTE Pass

[checkpoint] Get going with interpreter.

- ToANormalForm considers the arg to on_devivce an inner scope.
- FuseOps does not consider on_device a primitive
- Interpreter knows on_device is id

[checkpoint] undo accidental rename

[checkpoint] starting unit test

[checkpoint] Get rid of device_map from LowerTE

 - Inserted the transform in I think the right place for VM, AOT, Interpreter
   and GraphExecutor.
 - LowerTE still needs the memory plan, so still a lot of re-doing of
   memory planning going on. But at least the device map does not need to
   be rebuilt.
 - Add logging context help -- preparing for the long climb to get all the
   tests going.

Still need to figure out all the default device stuff, I don't think that's being
handled correctly.

[checkpoint] Mixin helper, capture OnDeviceAttrs for params.

TODO:
 - Make sure device pass actually runs.
 - Handle default device when targets_.size() == 1.
 - Device vs int vs DLDeviceType confusion everywhere

[checkpoint] Make device assignment a pass.

VM compiler still needs explicit map.
All very rough.

Lots of mismatches between Device and DLDeviceType as unit of annotation.

[checkpoint] better messages

[checkpoint] Merge in VLOG so can try it out with larger cl.

Will need to split it out again.

[checkpoint] rollback WithAttr node since seems using CallNode is the pattern

[checkpoint] Got rid of CollectDeviceInfo

[checkpoint] fiddling with WithAttr

[checkpoint] trivial

[checkpoint] rename context_analysis.cc to make_devices_explicit.cc and move to transforms/
  • Loading branch information
mbs-octoml committed Sep 15, 2021
1 parent aedf15e commit ebf87be
Show file tree
Hide file tree
Showing 58 changed files with 4,818 additions and 2,197 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ file(GLOB_RECURSE RELAY_PASS_SRCS
src/relay/analysis/*.cc
src/relay/transforms/*.cc
src/relay/quantize/*.cc
src/relay/attrs/*.cc
)
file(GLOB RELAY_BACKEND_SRCS
src/relay/backend/*.cc
Expand Down
23 changes: 2 additions & 21 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,22 +212,14 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod);
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const IRModule& mod);

/*!
* \brief Collect the device mapping information of each expression.
*
* \param expr The expression.
*
* \return The device mapping.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*!
* \brief Collect the device anntation operators.
* \brief Collect the device annotation operators.
*
* \param expr The expression.
*
* \return The annotated expression to device type mapping for annotation ops.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
TVM_DLL Map<Expr, Integer> CollectAllDeviceAnnotationOps(const IRModule& mod);

/*!
* \brief Finds cases that the given match expression does not catch, if any.
Expand Down Expand Up @@ -268,17 +260,6 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod);
*/
TVM_DLL Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& mod);

/*!
* \brief Analyze the device context of each IR node in a given relay module.
*
* \param mod The module for analysis.
* \param default_device The default device used by unassigned IR nodes.
*
* \return The mapping between an IR node and its associated device.
*/
TVM_DLL std::unordered_map<Expr, Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
ContextAnalysis(const IRModule& mod, const Device& default_device);

} // namespace relay
} // namespace tvm

Expand Down
45 changes: 43 additions & 2 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,55 @@ namespace tvm {
namespace relay {

/*!
* \brief Options for the device annotation operators.
* \brief Attributes for the "on_device" operator.
*
* The relay call
* \code
* on_device(expr, device_type=2)
* \endcode
* denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2
* (i.e. \p kDLCuda). Semantically the operator is the identity function.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
// TODO(mbs): Replace device types with TargetDevice.
/*! \brief Device type on which argument expression should be evaluated. */
int device_type;
/*!
* \brief If true, the result device must also be \p device_type and device planning should
* not insert any "device_copy" calls to respect this annotation.
*
* This is used by the device planning pass itself when annotating the planned program.
*/
bool is_fixed;

TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(device_type)
.describe("The virutal device/context type that an expression is annotated with.")
.describe("The type of the virtual device which should hold the expression result.")
.set_default(0);
TVM_ATTR_FIELD(is_fixed)
.describe("If true, do not insert a \"device_copy\" call to respect this annotation.")
.set_default(false);
}
};

/*!
* \brief Attributes for Relay function definitions which capture the devices for the
* function parameters and result.
*/
struct FunctionOnDeviceAttrs : public tvm::AttrsNode<FunctionOnDeviceAttrs> {
constexpr static const char* kFunctionAttrsKey = "on_device";

/*! \brief Device type on which each of the function's arguments already resides. */
Array<Integer> param_device_types;
// TODO(mbs): Replace device types with TargetDevice.
/*! \brief Device type on which function body should be evaluated. */
int result_device_type;

TVM_DECLARE_ATTRS(FunctionOnDeviceAttrs, "relay.attrs.FunctionOnDeviceAttrs") {
TVM_ATTR_FIELD(param_device_types)
.describe("The type of the virtual device which holds each function parameters.");
TVM_ATTR_FIELD(result_device_type)
.describe("The type of the virtual device which will hold the function's result.")
.set_default(0);
}
};
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/attrs/device_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace relay {
* \brief Options for the device copy operators.
*/
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
// TODO(mbs): Should be TargetDevice.
int dst_dev_type;
int src_dev_type;

Expand Down
13 changes: 13 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,20 @@ class Call : public Expr {
TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
Array<Type> type_args = Array<Type>(), Span span = Span());

/*!
* \brief Returns a copy of this with given properties. A null property denotes 'no change'. Returns
* this if all properties are unchanged. Returns a modified this if this is the only reference
* to the underlying node.
*/
// TODO(mbs): Extend to all node types.
Call CopyWith(Optional<Expr> opt_op = Optional<Expr>(),
Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(nullptr),
Optional<Attrs> opt_attrs = Optional<Attrs>(nullptr),
Optional<Array<Type>> opt_type_args = Optional<Array<Type>>(nullptr),
Optional<Span> opt_span = Optional<Span>(nullptr));

TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};

/*!
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -227,7 +228,7 @@ class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
*
* MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
* of the graph and processes them iteratively to prevent stack overflows
*/
class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
public:
Expand Down
15 changes: 13 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,13 +430,24 @@ TVM_DLL Pass SimplifyExpr();
* \brief A pass for manifesting explicit memory allocations and rewriting
* specific dialects.
*
* \param target_host The target used by the host for compliation.
* \param targets The device type and target pairs for compliation.
* \param target_host The target used by the host for compilation.
* \param targets The device type and target pairs for compilation.
*
* \return The pass.
*/
TVM_DLL Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets);

/*!
* \brief Uses existing "on_device" and "device_copy" CallNodes to infer the device on which
* every Relay sub-expression should run (and the result stored). Captures the result of that
* analysis using new "on_device" and "device_copy" CallNodes. See
* tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator}
* for help recovering the device for an arbitrary sub-expression in downstream transformations.
*
* \param default_device_type DLDeviceType for default device.
*/
TVM_DLL Pass PlanDevices(DLDeviceType default_device_type);

} // namespace transform

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
};

/*!
* \brief Array, container representing a contigious sequence of ObjectRefs.
* \brief Array, container representing a contiguous sequence of ObjectRefs.
*
* Array implements in-place copy-on-write semantics.
*
Expand Down
26 changes: 16 additions & 10 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,19 @@
#include <vector>

namespace tvm {
namespace runtime {

typedef DLDevice Device;
// alias DLDevice
using Device = DLDevice;

// A 'null' device type, does not correspond to any DLDeviceType enum.
// TODO(mbs): This is to help us as we transition away from representing the 'homogenous' case
// as a singleton target map indexed by the invalid DLDeviceType '0'.
constexpr DLDeviceType kNullDeviceType = static_cast<DLDeviceType>(0);

// An 'invalid' device type, does not correspond to any DLDeviceType enum.
constexpr DLDeviceType kInvalidDeviceType = static_cast<DLDeviceType>(-1);

namespace runtime {

/*!
* \brief Managed NDArray.
Expand Down Expand Up @@ -481,23 +491,19 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
}

} // namespace runtime

// alias Device
using tvm::runtime::Device;

} // namespace tvm

namespace std {
template <>
struct hash<tvm::runtime::Device> {
std::size_t operator()(const tvm::runtime::Device& dev) const {
struct hash<tvm::Device> {
std::size_t operator()(const tvm::Device& dev) const {
return ((dev.device_id << 8) | dev.device_type);
}
};

template <>
struct equal_to<tvm::runtime::Device> {
bool operator()(const tvm::runtime::Device& lhs, const tvm::runtime::Device& rhs) const {
struct equal_to<tvm::Device> {
bool operator()(const tvm::Device& lhs, const tvm::Device& rhs) const {
return (lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id);
}
};
Expand Down
4 changes: 2 additions & 2 deletions include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ struct VMFunction {
/*! \brief The size of the frame for this function */
Index register_file_size;
/*! \brief The device type of each parameter for this function. */
std::vector<Index> params_device_type;
std::vector<DLDeviceType> params_device_type;

VMFunction(const std::string& name, std::vector<std::string> params,
const std::vector<Instruction>& instructions, Index register_file_size,
const std::vector<Index> params_device_type = {})
const std::vector<DLDeviceType> params_device_type = {})
: name(name),
params(params),
instructions(instructions),
Expand Down
32 changes: 0 additions & 32 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,6 @@
from .feature import Feature


def context_analysis(mod, default_device):
"""Analyze the device context information of each IR node in a Relay
program.
Parameters
----------
mod : tvm.IRModule
The input module.
default_device : tvm.runtime.Device
The default context allocated to an IR node.
"""
return _ffi_api.ContextAnalysis(mod, default_device)


def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node,
apply fvisit. Each node is guaranteed to be visited
Expand Down Expand Up @@ -268,23 +253,6 @@ def all_dtypes(expr):
return set(_ffi_api.all_dtypes(expr))


def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
Parameters
----------
expr : tvm.relay.Expr
The input expression.
Returns
-------
ret : Dict[tvm.relay.ir.expr, int]
A dictionary mapping tvm.relay.Expr to device type.
"""
return _ffi_api.CollectDeviceInfo(expr)


def collect_device_annotation_ops(expr):
"""Collect the device annotation ops for the given expression.
Expand Down
39 changes: 28 additions & 11 deletions python/tvm/relay/op/annotation/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,19 @@
from .. import op as reg


def on_device(data, device):
def _device_to_int(device):
if isinstance(device, _Device):
return device.device_type
elif isinstance(device, str):
return _nd.device(device).device_type
else:
raise ValueError(
"device is expected to be the type of Device or "
"str, but received %s" % (type(device))
)


def on_device(data, device, is_fixed=False):
"""Annotate an expression with a certain device type.
Parameters
Expand All @@ -33,21 +45,26 @@ def on_device(data, device):
device : Union[:py:class:`Device`, str]
The device type to annotate.
is_fixed : bool
If true, annotation does not imply a device_copy may be inserted.
(This parameter is used internally by the compiler and unit tests and
should not need to be set in user programs.)
Returns
-------
result : tvm.relay.Expr
The annotated expression.
"""
if isinstance(device, _Device):
device = device.device_type
elif isinstance(device, str):
device = _nd.device(device).device_type
else:
raise ValueError(
"device is expected to be the type of Device or "
"str, but received %s" % (type(device))
)
return _make.on_device(data, device)
return _make.on_device(data, _device_to_int(device), is_fixed)


# for testing only
def function_on_device(function, param_devices, result_device):
"""Attaches attribute to function indicating the devices for its parameters and result.
"""

return _make.function_on_device(function,
[_device_to_int(d) for d in param_devices], _device_to_int(result_device))


def stop_fusion(data):
Expand Down
Loading

0 comments on commit ebf87be

Please sign in to comment.