Skip to content

Commit

Permalink
[commit] few more rollbacks
Browse files Browse the repository at this point in the history
[commit] rollback bogus rename

[checkpoint] Cleanup default device handling.

[checkpoint] bug fixes, working on trivial example again

The default device stuff is messed up.

[checkpoint] Cleanup on_device handling. Fix param device lookups.

[checkpoint] Merged LowerTE Pass

[checkpoint] Get going with interpreter.

- ToANormalForm considers the arg to on_devivce an inner scope.
- FuseOps does not consider on_device a primitive
- Interpreter knows on_device is id

[checkpoint] undo accidental rename

[checkpoint] starting unit test

[checkpoint] Get rid of device_map from LowerTE

 - Inserted the transform in I think the right place for VM, AOT, Interpreter
   and GraphExecutor.
 - LowerTE still needs the memory plan, so still a lot of re-doing of
   memory planning going on. But at least the device map does not need to
   be rebuilt.
 - Add logging context help -- preparing for the long climb to get all the
   tests going.

Still need to figure out all the default device stuff, I don't think that's being
handled correctly.

[checkpoint] Mixin helper, capture OnDeviceAttrs for params.

TODO:
 - Make sure device pass actually runs.
 - Handle default device when targets_.size() == 1.
 - Device vs int vs DLDeviceType confusion everywhere

[checkpoint] Make device assignment a pass.

VM compiler still needs explicit map.
All very rough.

Lots of mismatches between Device and DLDeviceType as unit of annotation.

[commit] better messages

[checkpoint] Merge in VLOG so can try it out with larger cl.

Will need to split it out again.

[checkpoint] rollback WithAttr node since seems using CallNode is the pattern

[checkpoint] Got rid of CollectDeviceInfo

[checkpoint] fiddling with WithAttr

[checkpoint] trivial
  • Loading branch information
mbs-octoml committed Aug 27, 2021
1 parent 6f45b86 commit d0c7a31
Show file tree
Hide file tree
Showing 48 changed files with 1,174 additions and 1,571 deletions.
23 changes: 2 additions & 21 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,22 +212,14 @@ TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Expr& expr, const IRModule& mod);
TVM_DLL tvm::Array<TypeVar> AllTypeVars(const Type& t, const IRModule& mod);

/*!
* \brief Collect the device mapping information of each expression.
*
* \param expr The expression.
*
* \return The device mapping.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceInfo(const Expr& expr);

/*!
* \brief Collect the device anntation operators.
* \brief Collect the device annotation operators.
*
* \param expr The expression.
*
* \return The annotated expression to device type mapping for annotation ops.
*/
TVM_DLL Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr);
TVM_DLL Map<Expr, Integer> CollectAllDeviceAnnotationOps(const IRModule& mod);

/*!
* \brief Finds cases that the given match expression does not catch, if any.
Expand Down Expand Up @@ -268,17 +260,6 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod);
*/
TVM_DLL Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& mod);

/*!
* \brief Analyze the device context of each IR node in a given relay module.
*
* \param mod The module for analysis.
* \param default_device The default device used by unassigned IR nodes.
*
* \return The mapping between an IR node and its associated device.
*/
TVM_DLL std::unordered_map<Expr, Device, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
ContextAnalysis(const IRModule& mod, const Device& default_device);

} // namespace relay
} // namespace tvm

Expand Down
12 changes: 10 additions & 2 deletions include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,22 @@ namespace tvm {
namespace relay {

/*!
* \brief Options for the device annotation operators.
* \brief Attributes for the "on_device" operator.
*
* The relay call
* \code
* on_device(expr, device_type=3)
* \endcode
* denotes that \p expr should be evaluated on the device with DLDeviceType 3. Semantically the
* operator is the identity.
*/
struct OnDeviceAttrs : public tvm::AttrsNode<OnDeviceAttrs> {
// TODO(mbs): Should be TargetDevice
int device_type;

TVM_DECLARE_ATTRS(OnDeviceAttrs, "relay.attrs.OnDeviceAttrs") {
TVM_ATTR_FIELD(device_type)
.describe("The virutal device/context type that an expression is annotated with.")
.describe("The virtual device type that an expression is annotated to execute on.")
.set_default(0);
}
};
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/attrs/device_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace relay {
* \brief Options for the device copy operators.
*/
struct DeviceCopyAttrs : public tvm::AttrsNode<DeviceCopyAttrs> {
// TODO(mbs): Should be TargetDevice.
int dst_dev_type;
int src_dev_type;

Expand Down
13 changes: 13 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,20 @@ class Call : public Expr {
TVM_DLL Call(Expr op, Array<Expr> args, Attrs attrs = Attrs(),
Array<Type> type_args = Array<Type>(), Span span = Span());

/*!
* \brief Returns a copy of this with given properties. A null property denotes 'no change'. Returns
* this if all properties are unchanged. Returns a modified this if this is the only reference
* to the underlying node.
*/
// TODO(mbs): Extend to all node types.
Call CopyWith(Optional<Expr> opt_op = Optional<Expr>(),
Optional<Array<Expr>> opt_args = Optional<Array<Expr>>(nullptr),
Optional<Attrs> opt_attrs = Optional<Attrs>(nullptr),
Optional<Array<Type>> opt_type_args = Optional<Array<Type>>(nullptr),
Optional<Span> opt_span = Optional<Span>(nullptr));

TVM_DEFINE_OBJECT_REF_METHODS(Call, RelayExpr, CallNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(CallNode);
};

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/expr_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ class ExprMutator : public ::tvm::relay::ExprFunctor<Expr(const Expr&)> {
*
* MixedModeVisitor provides the same recursive API as ExprVisitor, and uses
* recursion to traverse most forms of the IR, but under the hood it expands nested dataflow regions
* of the graph and processes them iteratatively to prevent stack overflows
* of the graph and processes them iteratively to prevent stack overflows
*/
class MixedModeVisitor : public ::tvm::relay::ExprVisitor {
public:
Expand Down
12 changes: 10 additions & 2 deletions include/tvm/relay/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,13 +430,21 @@ TVM_DLL Pass SimplifyExpr();
* \brief A pass for manifesting explicit memory allocations and rewriting
* specific dialects.
*
* \param target_host The target used by the host for compliation.
* \param targets The device type and target pairs for compliation.
* \param target_host The target used by the host for compilation.
* \param targets The device type and target pairs for compilation.
*
* \return The pass.
*/
TVM_DLL Pass ManifestAlloc(Target target_host, Map<tvm::Integer, tvm::Target> targets);

/*!
* \brief Inserts "on_device" Nodes to make explicit which device is executing each Relay
* expression.
*
* \param default_device_type DLDeviceType for default device.
*/
TVM_DLL Pass MakeDevicesExplicit(DLDeviceType default_device_type);

} // namespace transform

/*!
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/container/array.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ class ArrayNode : public Object, public InplaceArrayBase<ArrayNode, ObjectRef> {
};

/*!
* \brief Array, container representing a contigious sequence of ObjectRefs.
* \brief Array, container representing a contiguous sequence of ObjectRefs.
*
* Array implements in-place copy-on-write semantics.
*
Expand Down
99 changes: 87 additions & 12 deletions include/tvm/runtime/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#define TVM_RUNTIME_LOGGING_H_

#include <dmlc/common.h>
#include <dmlc/thread_local.h>
#include <tvm/runtime/c_runtime_api.h>

#include <ctime>
Expand Down Expand Up @@ -129,8 +130,9 @@
* a = ...
* b = ...
* // if quit_on_assertion is true, if a==b, continue, otherwise quit.
* // if quit_on_assertion is false, if a==b, continue, otherwise 'return false' (default
* behaviour) COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting"
* // if quit_on_assertion is false, if a==b, continue, otherwise 'return false'
* // (default behaviour)
* COND_CHECK_EQ(quit_on_assertion, a, b) << "some error message when quiting"
* ...
* for (int i = 0; i < N; i++) {
* a = ...
Expand Down Expand Up @@ -395,8 +397,8 @@ class LogMessageVoidify {
inline bool DebugLoggingEnabled() {
static int state = 0;
if (state == 0) {
if (auto var = std::getenv("TVM_LOG_DEBUG")) {
if (std::string(var) == "1") {
if (const char* var = std::getenv("TVM_LOG_DEBUG")) {
if (var[0] == '1') {
state = 1;
} else {
state = -1;
Expand All @@ -409,6 +411,58 @@ inline bool DebugLoggingEnabled() {
return state == 1;
}

/*!
* \brief Returns true if \p filename is mentioned in the environment
* variable \p TVM_LOG_DEBUG as enabled for 'verbose' logging at /p level
* or greater. Filenames are canonicalized to be w.r.t. the src/ dir of
* the TVM tree.
*
* To enable file \p relay/foo.cc for level <= 2 and \p ir/bar.cc for level <= 0 set:
*
* \code
* TVM_LOG_DEBUG="1;relay/foo.cc=2;ir/bar.cc=0;"
* \endcode
*/
bool VerboseLoggingEnabled(const char* filename, int level);

/*!
* A stack of VLOG context messages.
*
* For use by VLOG_CONTEXT only.
*/
class VLogContext {
public:
void Push(std::stringstream* stream) { context_stack.push_back(stream); }
void Pop() {
if (!context_stack.empty()) {
context_stack.pop_back();
}
}

std::string str() const;

private:
std::vector<std::stringstream*> context_stack;
};

/*! Thread local VLogContext for tracking a stack of VLOG context messages. */
using ThreadLocalVLogContext = dmlc::ThreadLocalStore<VLogContext>;

/*!
* \brief A RAII class to push/pos a VLOG context message onto the thread-local stack.
*
* For use by VLOG_CONTEXT only.
*/
class VLogContextEntry {
public:
VLogContextEntry() { ThreadLocalVLogContext::Get()->Push(&sstream_); }
~VLogContextEntry() { ThreadLocalVLogContext::Get()->Pop(); }
std::ostream& stream() { return sstream_; }

private:
std::stringstream sstream_;
};

constexpr const char* kTVM_INTERNAL_ERROR_MESSAGE =
"\n"
"---------------------------------------------------------------\n"
Expand Down Expand Up @@ -447,6 +501,7 @@ TVM_CHECK_FUNC(_GE, >=)
TVM_CHECK_FUNC(_EQ, ==)
TVM_CHECK_FUNC(_NE, !=)
#pragma GCC diagnostic pop

} // namespace detail

#define LOG(level) LOG_##level
Expand Down Expand Up @@ -487,17 +542,43 @@ TVM_CHECK_FUNC(_NE, !=)
#define DLOG_IF(severity, condition) \
LOG_IF(severity, ::tvm::runtime::detail::DebugLoggingEnabled() && (condition))

/*!
* \brief Push a context message onto an internal stack. All VLOG messages will include
* this stack as their prefix to help with debugging.
*/
#define VLOG_CONTEXT \
::tvm::runtime::detail::VLogContextEntry vlog_entry_; \
vlog_entry_.stream()

#else

#define LOG_DFATAL LOG_ERROR
#define DFATAL ERROR
#define DLOG(severity) true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity)
#define DLOG_IF(severity, condition) \
(true || !(condition)) ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(severity)
#define VLOG_CONTEXT true ? (void)0 : ::tvm::runtime::detail::LogMessageVoidify() & LOG(INFO)

#endif

/*!
* \brief If the containing file has been enabled at level or greater via
* TVM_LOG_DEBUG (see VerboseLoggingEnabled above for the format) then log a
* message. Otherwise no-op.
*/
#define VLOG(level) \
DLOG_IF(INFO, ::tvm::runtime::detail::VerboseLoggingEnabled(__FILE__, (level))) \
<< ::tvm::runtime::detail::ThreadLocalVLogContext::Get()->str()

#if TVM_LOG_DEBUG
#define DCHECK(x) CHECK(x)
#define DCHECK_LT(x, y) CHECK((x) < (y))
#define DCHECK_GT(x, y) CHECK((x) > (y))
#define DCHECK_LE(x, y) CHECK((x) <= (y))
#define DCHECK_GE(x, y) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) CHECK((x) == (y))
#define DCHECK_NE(x, y) CHECK((x) != (y))
#else
#define DCHECK(x) \
while (false) CHECK(x)
#define DCHECK_LT(x, y) \
Expand All @@ -512,14 +593,6 @@ TVM_CHECK_FUNC(_NE, !=)
while (false) CHECK((x) == (y))
#define DCHECK_NE(x, y) \
while (false) CHECK((x) != (y))
#else
#define DCHECK(x) CHECK(x)
#define DCHECK_LT(x, y) CHECK((x) < (y))
#define DCHECK_GT(x, y) CHECK((x) > (y))
#define DCHECK_LE(x, y) CHECK((x) <= (y))
#define DCHECK_GE(x, y) CHECK((x) >= (y))
#define DCHECK_EQ(x, y) CHECK((x) == (y))
#define DCHECK_NE(x, y) CHECK((x) != (y))
#endif

#define TVM_ICHECK_INDENT " "
Expand Down Expand Up @@ -549,8 +622,10 @@ TVM_CHECK_FUNC(_NE, !=)
(x) : (x)) // NOLINT(*)

} // namespace runtime

// Re-export error types
using runtime::Error;
using runtime::InternalError;

} // namespace tvm
#endif // TVM_RUNTIME_LOGGING_H_
26 changes: 16 additions & 10 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,19 @@
#include <vector>

namespace tvm {
namespace runtime {

typedef DLDevice Device;
// alias DLDevice
using Device = DLDevice;

// A 'null' device type, does not correspond to any DLDeviceType enum.
// TODO(mbs): This is to help us as we transition away from representing the 'homogenous' case
// as a singleton target map indexed by the invalid DLDeviceType '0'.
constexpr DLDeviceType kNullDeviceType = static_cast<DLDeviceType>(0);

// An 'invalid' device type, does not correspond to any DLDeviceType enum.
constexpr DLDeviceType kInvalidDeviceType = static_cast<DLDeviceType>(-1);

namespace runtime {

/*!
* \brief Managed NDArray.
Expand Down Expand Up @@ -481,23 +491,19 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
}

} // namespace runtime

// alias Device
using tvm::runtime::Device;

} // namespace tvm

namespace std {
template <>
struct hash<tvm::runtime::Device> {
std::size_t operator()(const tvm::runtime::Device& dev) const {
struct hash<tvm::Device> {
std::size_t operator()(const tvm::Device& dev) const {
return ((dev.device_id << 8) | dev.device_type);
}
};

template <>
struct equal_to<tvm::runtime::Device> {
bool operator()(const tvm::runtime::Device& lhs, const tvm::runtime::Device& rhs) const {
struct equal_to<tvm::Device> {
bool operator()(const tvm::Device& lhs, const tvm::Device& rhs) const {
return (lhs.device_type == rhs.device_type && lhs.device_id == rhs.device_id);
}
};
Expand Down
Loading

0 comments on commit d0c7a31

Please sign in to comment.