Skip to content

Commit

Permalink
[checkpoint] cross-device example working on vm!
Browse files Browse the repository at this point in the history
[checkpoint] Can't wrap on_device around big-lambdas.

[checkpoint] ToANF working I think

 - Cleanup pairs-of-pairs for OnDeviceProps etc.
 - Don't wrap OnDevice around expressions that don't need it.

[checkpoint] ANF working again

[checkpoint] Visitor helpers, lambda lifting tracks devices

1/2 way through ANF tracking devices but currently very broken

[checkpoint] Rollback ANF scope changes, need to revisit

[checkpoint] Standalone pass unit tests all pass :-)

[checkpoint] TupleGetItem is working

[checkpoint] Giving up on FindFixedAndFreeExpressions, will introduce rewrite 'phase 0' instead

[checkpoint] Unit tests starting to work

 - Add 'is_fixed' field and the 'implicit is_fixed=true' rule
 - Start porting original context planning and annotation tests

[checkpoint] comment polish, get rid of no-op overrides.

[checkpoint] handle pattern-bound vars

[checkpoint] Rework intro comment. Introduce UnifyCollapsed.

[checkpoint] fix 'stored on' vs 'executes on' confusion

[checkpoint] more tests and bug fixes

[checkpoint] improve test

[checkpoint] basic tests passing

[checkpoint] get basic test going again

[checkpoint] Fix merge snafu

[checkpoint] Rename to device_planner / PlanDevices, improve comments, kind checking

[checkpoint] Switch to higher-order domains, add defaulting visitor.

[checkpoint] builds again

[checkpoint] (Does not build) Rework handling of let- and param-bound functions.

[checkpoint] (won't build) device_copy can capture scope, cleanup var->function tracking

[checkpoint] Python uses Devices not types.

[checkpoint] Renames, restore lost make_devices_explicit.cc

[checkpoint] rename test_pass_context_analysis.py to test_pass_make_devices_explicit.py

[checkpoint] few more rollbacks

[checkpoint] rollback bogus rename

[checkpoint] Cleanup default device handling.

[checkpoint] bug fixes, working on trivial example again

The default device stuff is messed up.

[checkpoint] Cleanup on_device handling. Fix param device lookups.

[checkpoint] Merged LowerTE Pass

[checkpoint] Get going with interpreter.

- ToANormalForm considers the arg to on_devivce an inner scope.
- FuseOps does not consider on_device a primitive
- Interpreter knows on_device is id

[checkpoint] undo accidental rename

[checkpoint] starting unit test

[checkpoint] Get rid of device_map from LowerTE

 - Inserted the transform in I think the right place for VM, AOT, Interpreter
   and GraphExecutor.
 - LowerTE still needs the memory plan, so still a lot of re-doing of
   memory planning going on. But at least the device map does not need to
   be rebuilt.
 - Add logging context help -- preparing for the long climb to get all the
   tests going.

Still need to figure out all the default device stuff, I don't think that's being
handled correctly.

[checkpoint] Mixin helper, capture OnDeviceAttrs for params.

TODO:
 - Make sure device pass actually runs.
 - Handle default device when targets_.size() == 1.
 - Device vs int vs DLDeviceType confusion everywhere

[checkpoint] Make device assignment a pass.

VM compiler still needs explicit map.
All very rough.

Lots of mismatches between Device and DLDeviceType as unit of annotation.

[checkpoint] better messages

[checkpoint] Merge in VLOG so can try it out with larger cl.

Will need to split it out again.

[checkpoint] rollback WithAttr node since seems using CallNode is the pattern

[checkpoint] Got rid of CollectDeviceInfo

[checkpoint] fiddling with WithAttr

[checkpoint] trivial

[checkpoint] rename context_analysis.cc to make_devices_explicit.cc and move to transforms/
  • Loading branch information
mbs-octoml committed Sep 14, 2021
1 parent e1ae821 commit 651b848
Show file tree
Hide file tree
Showing 67 changed files with 5,076 additions and 2,245 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ file(GLOB_RECURSE RELAY_PASS_SRCS
src/relay/analysis/*.cc
src/relay/transforms/*.cc
src/relay/quantize/*.cc
src/relay/attrs/*.cc
)
file(GLOB RELAY_BACKEND_SRCS
src/relay/backend/*.cc
Expand Down
23 changes: 2 additions & 21 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,22 +212,14 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod);
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const IRModule& mod);

/*!
* \brief Collect the device mapping information of each expression.
*
* \param expr The expression.
*
* \return The device mapping.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*!
* \brief Collect the device anntation operators.
* \brief Collect the device annotation operators.
*
* \param expr The expression.
*
* \return The annotated expression to device type mapping for annotation ops.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
TVM_DLL Map<Expr, Integer> CollectAllDeviceAnnotationOps(const IRModule& mod);

/*!
* \brief Finds cases that the given match expression does not catch, if any.
Expand Down Expand Up @@ -268,17 +260,6 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod);
*/
TVM_DLL Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& mod);

/*!
* \brief Analyze the device context of each IR node in a given relay module.
*
* \param mod The module for analysis.
* \param default_device The default device used by unassigned IR nodes.
*
* \return The mapping between an IR node and its associated device.
*/
TVM_DLL std::unordered_map<Expr, Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
ContextAnalysis(const IRModule& mod, const Device& default_device);

} // namespace relay
} // namespace tvm

Expand Down
45 changes: 43 additions & 2 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,55 @@ namespace tvm {
namespace relay {

/*!
* \brief Options for the device annotation operators.
* \brief Attributes for the "on_device" operator.
*
* The relay call
* \code
* on_device(expr, device_type=2)
* \endcode
* denotes that the result of \p expr should be stored on the device with \p DLDeviceType 2
* (i.e. \p kDLCuda). Semantically the operator is the identity function.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
// TODO(mbs): Replace device types with TargetDevice.
/*! \brief Device type on which argument expression should be evaluated. */
int device_type;
/*!
* \brief If true, the result device must also be \p device_type and device planning should
* not insert any "device_copy" calls to respect this annotation.
*
* This is used by the device planning pass itself when annotating the planned program.
*/
bool is_fixed;

TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(device_type)
.describe("The virutal device/context type that an expression is annotated with.")
.describe("The type of the virtual device which should hold the expression result.")
.set_default(0);
TVM_ATTR_FIELD(is_fixed)
.describe("If true, do not insert a \"device_copy\" call to respect this annotation.")
.set_default(false);
}
};

/*!
* \brief Attributes for Relay function definitions which capture the devices for the
* function parameters and result.
*/
struct FunctionOnDeviceAttrs : public tvm::AttrsNode<FunctionOnDeviceAttrs> {
constexpr static const char* kFunctionAttrsKey = "on_device";

/*! \brief Device type on which each of the function's arguments already resides. */
Array<Integer> param_device_types;
// TODO(mbs): Replace device types with TargetDevice.
/*! \brief Device type on which function body should be evaluated. */
int result_device_type;

TVM_DECLARE_ATTRS(FunctionOnDeviceAttrs, "relay.attrs.FunctionOnDeviceAttrs") {
TVM_ATTR_FIELD(param_device_types)
.describe("The type of the virtual device which holds each function parameters.");
TVM_ATTR_FIELD(result_device_type)
.describe("The type of the virtual device which will hold the function's result.")
.set_default(0);
}
};
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/attrs/device_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace relay {
* \brief Options for the device copy operators.
*/
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
// TODO(mbs): Should be TargetDevice.
int dst_dev_type;
int src_dev_type;

Expand Down
13 changes: 13 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,20 @@ class Call : public Expr {
TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
Array<Type> type_args = Array<Type>(), Span span = Span());

/*!
* \brief Returns a copy of this with given properties. A null property denotes 'no change'. Returns
* this if all properties are unchanged. Returns a modified this if this is the only reference
* to the underlying node.
*/
// TODO(mbs): Extend to all node types.
Call CopyWith(Optional<Expr> opt_op = Optional<Expr>(),
Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(nullptr),
Optional<Attrs> opt_attrs = Optional<Attrs>(nullptr),
Optional<Array<Type>> opt_type_args = Optional<Array<Type>>(nullptr),
Optional<Span> opt_span = Optional<Span>(nullptr));

TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};

/*!
Expand Down
3 changes: 2 additions & 1 deletion include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <unordered_map>
#include <utility>
#include <vector>

namespace tvm {
namespace relay {

Expand Down Expand Up @@ -227,7 +228,7 @@ class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
*
* MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
* of the graph and processes them iteratively to prevent stack overflows
*/
class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
public:
Expand Down
15 changes: 13 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,13 +430,24 @@ TVM_DLL Pass SimplifyExpr();
* \brief A pass for manifesting explicit memory allocations and rewriting
* specific dialects.
*
* \param target_host The target used by the host for compliation.
* \param targets The device type and target pairs for compliation.
* \param target_host The target used by the host for compilation.
* \param targets The device type and target pairs for compilation.
*
* \return The pass.
*/
TVM_DLL Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets);

/*!
* \brief Uses existing "on_device" and "device_copy" CallNodes to infer the device on which
* every Relay sub-expression should run (and the result stored). Captures the result of that
* analysis using new "on_device" and "device_copy" CallNodes. See
* tvm::relay::transform::{LexicalOnDeviceMixin,DeviceAwareExprVisitor,DeviceAwareExprMutator}
* for help recovering the device for an arbitrary sub-expression in downstream transformations.
*
* \param default_device_type DLDeviceType for default device.
*/
TVM_DLL Pass PlanDevices(DLDeviceType default_device_type);

} // namespace transform

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
};

/*!
* \brief Array, container representing a contigious sequence of ObjectRefs.
* \brief Array, container representing a contiguous sequence of ObjectRefs.
*
* Array implements in-place copy-on-write semantics.
*
Expand Down
104 changes: 92 additions & 12 deletions include/tvm/runtime/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#define TVM_RUNTIME_LOGGING_H_

#include <dmlc/common.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_runtime_api.h>

#include <ctime>
Expand Down Expand Up @@ -129,8 +130,9 @@
* a = ...
* b = ...
* // if quit_on_assertion is true, if a==b, continue, otherwise quit.
* // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default
* behaviour) COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting"
* // if quit_on_assertion is false, if a==b, continue, otherwise 'return false'
* // (default behaviour)
* COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting"
* ...
* for (int i = 0; i < N; i++) {
* a = ...
Expand Down Expand Up @@ -395,8 +397,8 @@ class LogMessageVoidify {
inline bool DebugLoggingEnabled() {
static int state = 0;
if (state == 0) {
if (auto var = std::getenv("TVM_LOG_DEBUG")) {
if (std::string(var) == "1") {
if (const char* var = std::getenv("TVM_LOG_DEBUG")) {
if (var[0] == '1') {
state = 1;
} else {
state = -1;
Expand All @@ -409,6 +411,63 @@ inline bool DebugLoggingEnabled() {
return state == 1;
}

/*!
* \brief Returns true if a VLOG statement in \p filename is enabled by the environment
* variable \p TVM_LOG_DEBUG for logging at verbosity \p level.
*
* Filenames are canonicalized to be w.r.t. the src/ dir of the TVM tree. (VLOG's should not
* appear under include/).
*
* To enable file \p relay/foo.cc up to level 2 and \p ir/bar.cc for level 0 only set:
* \code
* TVM_LOG_DEBUG="1;relay/foo.cc=2;ir/bar.cc=0;"
* \endcode
*
* To enable all files up to level 3 but disable \p ir/bar.cc set:
* \code
* TVM_LOG_DEBUG="1;*=2;ir/bar.cc=-1;"
* \endcode
*/
bool VerboseLoggingEnabled(const char* filename, int level);

/*!
* A stack of VLOG context messages.
*
* For use by VLOG_CONTEXT only.
*/
class VLogContext {
public:
void Push(std::stringstream* stream) { context_stack.push_back(stream); }
void Pop() {
if (!context_stack.empty()) {
context_stack.pop_back();
}
}

std::string str() const;

private:
std::vector<std::stringstream*> context_stack;
};

/*! Thread local VLogContext for tracking a stack of VLOG context messages. */
using ThreadLocalVLogContext = dmlc::ThreadLocalStore<VLogContext>;

/*!
* \brief A RAII class to push/pos a VLOG context message onto the thread-local stack.
*
* For use by VLOG_CONTEXT only.
*/
class VLogContextEntry {
public:
VLogContextEntry() { ThreadLocalVLogContext::Get()->Push(&sstream_); }
~VLogContextEntry() { ThreadLocalVLogContext::Get()->Pop(); }
std::ostream& stream() { return sstream_; }

private:
std::stringstream sstream_;
};

constexpr const char* kTVM_INTERNAL_ERROR_MESSAGE =
"\n"
"---------------------------------------------------------------\n"
Expand Down Expand Up @@ -447,6 +506,7 @@ TVM_CHECK_FUNC(_GE, >=)
TVM_CHECK_FUNC(_EQ, ==)
TVM_CHECK_FUNC(_NE, !=)
#pragma GCC diagnostic pop

} // namespace detail

#define LOG(level) LOG_##level
Expand Down Expand Up @@ -487,17 +547,43 @@ TVM_CHECK_FUNC(_NE, !=)
#define DLOG_IF(severity, condition) \
LOG_IF(severity, ::tvm::runtime::detail::DebugLoggingEnabled() && (condition))

/*!
* \brief Push a context message onto an internal stack. All VLOG messages will include
* this stack as their prefix to help with debugging.
*/
#define VLOG_CONTEXT \
::tvm::runtime::detail::VLogContextEntry vlog_entry_; \
vlog_entry_.stream()

#else

#define LOG_DFATAL LOG_ERROR
#define DFATAL ERROR
#define DLOG(severity) true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity)
#define DLOG_IF(severity, condition) \
(true || !(condition)) ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity)
#define VLOG_CONTEXT true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(INFO)

#endif

/*!
* \brief If the containing file has been enabled at level or greater via
* TVM_LOG_DEBUG (see VerboseLoggingEnabled above for the format) then log a
* message. Otherwise no-op.
*/
#define VLOG(level) \
DLOG_IF(INFO, ::tvm::runtime::detail::VerboseLoggingEnabled(__FILE__, (level))) \
<< ::tvm::runtime::detail::ThreadLocalVLogContext::Get()->str()

#if TVM_LOG_DEBUG
#define DCHECK(x) CHECK(x)
#define DCHECK_LT(x, y) CHECK((x) < (y))
#define DCHECK_GT(x, y) CHECK((x) > (y))
#define DCHECK_LE(x, y) CHECK((x) <= (y))
#define DCHECK_GE(x, y) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) CHECK((x) == (y))
#define DCHECK_NE(x, y) CHECK((x) != (y))
#else
#define DCHECK(x) \
while (false) CHECK(x)
#define DCHECK_LT(x, y) \
Expand All @@ -512,14 +598,6 @@ TVM_CHECK_FUNC(_NE, !=)
while (false) CHECK((x) == (y))
#define DCHECK_NE(x, y) \
while (false) CHECK((x) != (y))
#else
#define DCHECK(x) CHECK(x)
#define DCHECK_LT(x, y) CHECK((x) < (y))
#define DCHECK_GT(x, y) CHECK((x) > (y))
#define DCHECK_LE(x, y) CHECK((x) <= (y))
#define DCHECK_GE(x, y) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) CHECK((x) == (y))
#define DCHECK_NE(x, y) CHECK((x) != (y))
#endif

#define TVM_ICHECK_INDENT " "
Expand Down Expand Up @@ -549,8 +627,10 @@ TVM_CHECK_FUNC(_NE, !=)
(x) : (x)) // NOLINT(*)

} // namespace runtime

// Re-export error types
using runtime::Error;
using runtime::InternalError;

} // namespace tvm
#endif // TVM_RUNTIME_LOGGING_H_
Loading

0 comments on commit 651b848

Please sign in to comment.