Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Add object path tracing to StructuralEqual #12101

Merged
merged 5 commits into from
Aug 3, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/tvm/node/reflection.h
Original file line number Diff line number Diff line change
Expand Up @@ -404,5 +404,11 @@ inline bool ReflectionVTable::GetReprBytes(const Object* self, std::string* repr
}
}

/*!
* \brief Given an object and an address of its attribute, return the key of the attribute.
* \param object The object whose attributes are searched.
* \param attr_address Address of the attribute to look up.
* \return The key of the attribute, or NullOpt if no attribute with the given address exists.
*/
Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address);

} // namespace tvm
#endif // TVM_NODE_REFLECTION_H_
153 changes: 136 additions & 17 deletions include/tvm/node/structural_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#define TVM_NODE_STRUCTURAL_EQUAL_H_

#include <tvm/node/functor.h>
#include <tvm/node/object_path.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/data_type.h>

Expand Down Expand Up @@ -56,6 +57,27 @@ class BaseValueEqual {
}
};

/*!
* \brief Pair of `ObjectPath`s, one for each object being tested for structural equality.
*/
class ObjectPathPairNode : public Object {
public:
/*! \brief Path associated with the left-hand side object of the comparison. */
ObjectPath lhs_path;
/*! \brief Path associated with the right-hand side object of the comparison. */
ObjectPath rhs_path;

/*!
* \brief Construct a pair from the two paths.
* \param lhs_path Path for the left-hand side object.
* \param rhs_path Path for the right-hand side object.
*/
ObjectPathPairNode(ObjectPath lhs_path, ObjectPath rhs_path);

/*! \brief Type key under which this node type is declared. */
static constexpr const char* _type_key = "ObjectPathPair";
TVM_DECLARE_FINAL_OBJECT_INFO(ObjectPathPairNode, Object);
};

/*!
* \brief Reference class for ObjectPathPairNode.
*
* Declared with the not-nullable object ref macro, so there is no default (null) state.
*/
class ObjectPathPair : public ObjectRef {
public:
/*!
* \brief Construct a pair from the two paths.
* \param lhs_path Path for the left-hand side object.
* \param rhs_path Path for the right-hand side object.
*/
ObjectPathPair(ObjectPath lhs_path, ObjectPath rhs_path);

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPathPair, ObjectRef, ObjectPathPairNode);
};

/*!
* \brief Content-aware structural equality comparator for objects.
*
Expand Down Expand Up @@ -99,7 +121,10 @@ class StructuralEqual : public BaseValueEqual {
* equality checking. Instead, it can store the necessary equality conditions
* and check later via an internally managed stack.
*/
class SEqualReducer : public BaseValueEqual {
class SEqualReducer {
gbonik marked this conversation as resolved.
Show resolved Hide resolved
private:
struct PathTracingData;

public:
/*! \brief Internal handler that defines custom behaviors. */
class Handler {
Expand All @@ -110,12 +135,24 @@ class SEqualReducer : public BaseValueEqual {
* \param lhs The left operand.
* \param rhs The right operand.
* \param map_free_vars Whether to allow remapping of free variables where possible.
* \param current_paths Optional paths to `lhs` and `rhs` objects, for error traceability.
*
* \return false if there is an immediate failure, true otherwise.
* \note This function may save the equality condition of (lhs == rhs) in an internal
* stack and try to resolve later.
*/
virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) = 0;
virtual bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const Optional<ObjectPathPair>& current_paths) = 0;

/*!
* \brief Mark the comparison as failed, but don't fail immediately.
*
* This is useful for producing better error messages when comparing containers.
* For example, if two array sizes mismatch, it's better to mark the comparison as failed
* but compare array elements anyway, so that we could find the true first mismatch.
*/
virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;

/*!
* \brief Lookup the graph node equal map for vars that are already mapped.
*
Expand All @@ -129,28 +166,72 @@ class SEqualReducer : public BaseValueEqual {
* \brief Mark current comparison as graph node equal comparison.
*/
virtual void MarkGraphNode() = 0;
};

using BaseValueEqual::operator();
protected:
using PathTracingData = SEqualReducer::PathTracingData;
};

/*! \brief default constructor */
SEqualReducer() = default;
/*!
* \brief Constructor with a specific handler.
* \param handler The equal handler for objects.
* \param tracing_data Optional pointer to the path tracing data.
* \param map_free_vars Whether or not to map free variables.
*/
explicit SEqualReducer(Handler* handler, bool map_free_vars)
: handler_(handler), map_free_vars_(map_free_vars) {}
explicit SEqualReducer(Handler* handler, const PathTracingData* tracing_data, bool map_free_vars)
: handler_(handler), tracing_data_(tracing_data), map_free_vars_(map_free_vars) {}

/*!
* \brief Reduce condition to comparison of two attribute values.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool operator()(const double& lhs, const double& rhs) const;
bool operator()(const int64_t& lhs, const int64_t& rhs) const;
bool operator()(const uint64_t& lhs, const uint64_t& rhs) const;
bool operator()(const int& lhs, const int& rhs) const;
bool operator()(const bool& lhs, const bool& rhs) const;
bool operator()(const std::string& lhs, const std::string& rhs) const;
bool operator()(const DataType& lhs, const DataType& rhs) const;

/*!
 * \brief Reduce condition to comparison of two enum attribute values.
 *
 * Only enums whose underlying type is `int` are accepted; anything else is
 * rejected at compile time. The addresses of both operands are forwarded
 * together with their integer values.
 *
 * \param lhs The left operand.
 * \param rhs The right operand.
 * \return the immediate check result.
 */
template <typename ENum, typename = typename std::enable_if<std::is_enum<ENum>::value>::type>
bool operator()(const ENum& lhs, const ENum& rhs) const {
  static_assert(std::is_same<typename std::underlying_type<ENum>::type, int>::value,
                "Enum must have `int` as the underlying type");
  const int lhs_value = static_cast<int>(lhs);
  const int rhs_value = static_cast<int>(rhs);
  return EnumAttrsEqual(lhs_value, rhs_value, &lhs, &rhs);
}

/*!
* \brief Reduce condition to comparison of two objects.
* \param lhs The left operand.
* \param rhs The right operand.
* \return the immediate check result.
*/
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const;

/*!
* \brief Reduce condition to comparison of two objects.
*
* Like `operator()`, but with an additional `paths` parameter that specifies explicit object
* paths for `lhs` and `rhs`. This is useful for implementing SEqualReduce() methods for container
* objects like Array and Map, or other custom objects that store nested objects that are not
* simply attributes.
*
* Can only be called when `IsPathTracingEnabled()` is `true`.
*
* \param lhs The left operand.
* \param rhs The right operand.
* \param paths Object paths for `lhs` and `rhs`.
* \return the immediate check result.
*/
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
return handler_->SEqualReduce(lhs, rhs, map_free_vars_);
bool operator()(const ObjectRef& lhs, const ObjectRef& rhs, const ObjectPathPair& paths) const {
// Explicit paths are only accepted while path tracing is enabled; this is checked up front.
ICHECK(IsPathTracingEnabled()) << "Path tracing must be enabled when calling this function";
// Forward the caller-supplied paths (by pointer) along with this reducer's map_free_vars_ setting.
return ObjectAttrsEqual(lhs, rhs, map_free_vars_, &paths);
}

/*!
* \brief Reduce condition to comparison of two definitions,
* where free vars can be mapped.
Expand All @@ -162,9 +243,8 @@ class SEqualReducer : public BaseValueEqual {
* \param rhs The right operand.
* \return the immediate check result.
*/
bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs) {
return handler_->SEqualReduce(lhs, rhs, true);
}
bool DefEqual(const ObjectRef& lhs, const ObjectRef& rhs);

/*!
* \brief Reduce condition to comparison of two arrays.
* \param lhs The left operand.
Expand All @@ -173,13 +253,20 @@ class SEqualReducer : public BaseValueEqual {
*/
template <typename T>
bool operator()(const Array<T>& lhs, const Array<T>& rhs) const {
// quick specialization for Array to reduce amount of recursion
// depth as array comparison is pretty common.
if (lhs.size() != rhs.size()) return false;
for (size_t i = 0; i < lhs.size(); ++i) {
if (!(operator()(lhs[i], rhs[i]))) return false;
if (tracing_data_ == nullptr) {
// quick specialization for Array to reduce amount of recursion
// depth as array comparison is pretty common.
if (lhs.size() != rhs.size()) return false;
for (size_t i = 0; i < lhs.size(); ++i) {
if (!(operator()(lhs[i], rhs[i]))) return false;
}
return true;
}
return true;

// If tracing is enabled, fall back to the regular path
const ObjectRef& lhs_obj = lhs;
const ObjectRef& rhs_obj = rhs;
return (*this)(lhs_obj, rhs_obj);
}
/*!
* \brief Implementation for equality rule of var type objects(e.g. TypeVar, tir::Var).
Expand All @@ -198,9 +285,41 @@ class SEqualReducer : public BaseValueEqual {
/*! \return Get the internal handler. */
Handler* operator->() const { return handler_; }

/*! \brief Check if this reducer is tracing paths to the first mismatch. */
bool IsPathTracingEnabled() const { return tracing_data_ != nullptr; }

/*!
* \brief Get the paths of the currently compared objects.
*
* Can only be called when `IsPathTracingEnabled()` is true.
*/
const ObjectPathPair& GetCurrentObjectPaths() const;

/*!
* \brief Specify the object paths of a detected mismatch.
*
* Can only be called when `IsPathTracingEnabled()` is true.
*/
void RecordMismatchPaths(const ObjectPathPair& paths) const;

private:
bool EnumAttrsEqual(int lhs, int rhs, const void* lhs_address, const void* rhs_address) const;

bool ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const ObjectPathPair* paths) const;

static void GetPathsFromAttrAddressesAndStoreMismatch(const void* lhs_address,
const void* rhs_address,
const PathTracingData* tracing_data);

template <typename T>
static bool CompareAttributeValues(const T& lhs, const T& rhs,
const PathTracingData* tracing_data);
Comment on lines +311 to +317
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just curious - do you think it would make more sense to move those two methods to the cc file? The primary concern I'm having is that CompareAttributeValues is a templated method whose instantiation is all inside a cc file. If we care about visibility, we could introduce a friend class SEqualReducerHelper which has those two methods as static members

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the actual concern about having a function template declared in the header?

I don't have a strong opinion, can move these to a helper friend class if you prefer it that way.

Copy link
Member

@junrushao junrushao Jul 27, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no concrete problem in this particular case AFAICT - because template instantiation is only defined and used in a single cc file.

On the other hand, in more generic use cases, we would prefer template instantiations to be defined in header files so that they are discoverable by the compiler when multiple cc files refer to this method.

Therefore, it's somewhat of a personal preference (so it's subjective, not any general requirement) that I either define both instantiation and declaration in the header file, or both in cc files.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I mostly understand your point, but I'm missing one thing: why is this specific to function templates? For example, the non-template helper function GetPathsFromAttrAddressesAndStoreMismatch just above is also private but we have to put its declaration in the header file, because it is a static function in our class (which we need because of C++ visibility rules).

Even if we go with the SEqualReducerHelper approach, we still need to leak some details in the header file because we need to either declare it as a static class or as a friend class.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah my personal (again it's just subjective) preference is that we hide anything that's not intended to be publicly used, except for non-static methods when it requires some boilerplate code (adding helper friend classes / methods). If a method is in a header file, I would prefer to document it more or less to make it easier for others to catch up


/*! \brief Internal class pointer. */
Handler* handler_;
/*! \brief Pointer to the current path tracing context, or nullptr if path tracing is disabled. */
const PathTracingData* tracing_data_;
/*! \brief Whether or not to map free vars. */
bool map_free_vars_;
};
Expand Down
34 changes: 32 additions & 2 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,8 +191,8 @@ def structural_equal(lhs, rhs, map_free_vars=False):
The left operand.

map_free_vars : bool
Whether or not shall we map free vars that does
not bound to any definitions as equal to each other.
Whether free variables (i.e. variables without a definition site) should be mapped
as equal to each other.

Return
------
Expand All @@ -209,6 +209,36 @@ def structural_equal(lhs, rhs, map_free_vars=False):
return bool(tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, False, map_free_vars))


def get_first_structural_mismatch(lhs, rhs, map_free_vars=False):
    """Like structural_equal(), but returns the ObjectPaths of the first detected mismatch.

    Parameters
    ----------
    lhs : Object
        The left operand.

    rhs : Object
        The right operand.

    map_free_vars : bool
        Whether free variables (i.e. variables without a definition site) should be mapped
        as equal to each other.

    Returns
    -------
    mismatch: Optional[Tuple[ObjectPath, ObjectPath]]
        `None` if `lhs` and `rhs` are structurally equal.
        Otherwise, a tuple of two ObjectPath objects that point to the first detected mismatch.
    """
    # Both operands go through tvm.runtime.convert before being handed to the FFI call.
    lhs = tvm.runtime.convert(lhs)
    rhs = tvm.runtime.convert(rhs)
    mismatch = tvm.runtime._ffi_node_api.GetFirstStructuralMismatch(lhs, rhs, map_free_vars)
    if mismatch is None:
        return None
    # Unpack the returned pair into a plain (lhs_path, rhs_path) tuple.
    return mismatch.lhs_path, mismatch.rhs_path


def assert_structural_equal(lhs, rhs, map_free_vars=False):
"""Assert lhs and rhs are structurally equal to each other.

Expand Down
1 change: 1 addition & 0 deletions python/tvm/runtime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# class exposures
from .packed_func import PackedFunc
from .object import Object
from .object_path import ObjectPath, ObjectPathPair
from .object_generic import ObjectGeneric, ObjectTypes
from .ndarray import NDArray, DataType, DataTypeCode, Device
from .module import Module, num_threads
Expand Down
12 changes: 12 additions & 0 deletions python/tvm/runtime/object_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"MissingArrayElementPath",
"MapValuePath",
"MissingMapEntryPath",
"ObjectPathPair",
)


Expand Down Expand Up @@ -122,3 +123,14 @@ class MapValuePath(ObjectPath):
@tvm._ffi.register_object("MissingMapEntryPath")
class MissingMapEntryPath(ObjectPath):
pass


@tvm._ffi.register_object("ObjectPathPair")
class ObjectPathPair(Object):
    """Python-side class registered for objects with the type key "ObjectPathPair".

    Exposes the two paths of the pair as read-only properties.
    """

    @property
    def lhs_path(self) -> ObjectPath:
        """The left-hand side path, obtained via `ObjectPathPairLhsPath`."""
        left = _ffi_node_api.ObjectPathPairLhsPath(self)
        return left

    @property
    def rhs_path(self) -> ObjectPath:
        """The right-hand side path, obtained via `ObjectPathPairRhsPath`."""
        right = _ffi_node_api.ObjectPathPairRhsPath(self)
        return right
junrushao marked this conversation as resolved.
Show resolved Hide resolved
44 changes: 44 additions & 0 deletions src/node/reflection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,4 +281,48 @@ TVM_REGISTER_GLOBAL("node.NodeGetAttr").set_body(NodeGetAttr);
TVM_REGISTER_GLOBAL("node.NodeListAttrNames").set_body(NodeListAttrNames);

TVM_REGISTER_GLOBAL("node.MakeNode").set_body(MakeNode);

namespace {
// AttrVisitor that looks for the attribute stored at a given address and
// remembers the key it was visited under.
class GetAttrKeyByAddressVisitor : public AttrVisitor {
 public:
  explicit GetAttrKeyByAddressVisitor(const void* attr_address)
      : attr_address_(attr_address), key_(nullptr) {}

  // Every overload funnels into the same address comparison; the value type is irrelevant.
  void Visit(const char* key, double* value) final { RecordIfMatch(key, value); }
  void Visit(const char* key, int64_t* value) final { RecordIfMatch(key, value); }
  void Visit(const char* key, uint64_t* value) final { RecordIfMatch(key, value); }
  void Visit(const char* key, int* value) final { RecordIfMatch(key, value); }
  void Visit(const char* key, bool* value) final { RecordIfMatch(key, value); }
  void Visit(const char* key, std::string* value) final { RecordIfMatch(key, value); }
  void Visit(const char* key, void** value) final { RecordIfMatch(key, value); }
  void Visit(const char* key, DataType* value) final { RecordIfMatch(key, value); }
  void Visit(const char* key, runtime::NDArray* value) final { RecordIfMatch(key, value); }
  void Visit(const char* key, runtime::ObjectRef* value) final { RecordIfMatch(key, value); }

  // Key of the matching attribute, or nullptr if no visited attribute matched.
  const char* GetKey() const { return key_; }

 private:
  // Stores `key` when `candidate` is the address being searched for.
  // A later match overwrites an earlier one.
  void RecordIfMatch(const char* key, const void* candidate) {
    if (candidate != attr_address_) return;
    key_ = key;
  }

  // Address of the attribute being searched for.
  const void* attr_address_;
  // Key recorded for the matching attribute; nullptr until a match is seen.
  const char* key_;
};
} // anonymous namespace

// Visits the attributes of `object` through the global reflection table and
// returns the key of the one located at `attr_address`, or NullOpt when no
// attribute has that address.
Optional<String> GetAttrKeyByAddress(const Object* object, const void* attr_address) {
  GetAttrKeyByAddressVisitor finder(attr_address);
  // The cast drops constness so the object can be handed to VisitAttrs.
  ReflectionVTable::Global()->VisitAttrs(const_cast<Object*>(object), &finder);
  if (const char* found_key = finder.GetKey()) {
    return String(found_key);
  }
  return NullOpt;
}

} // namespace tvm
Loading