Skip to content

Commit

Permalink
[TVMScript] Connect assert_structural_equal with new TVMScript prin…
Browse files Browse the repository at this point in the history
…ter (#13859)

This PR refactors the output of `assert_structural_equal`. Different from the directly printing mismatching nodes, in the old version, the improved one will print the whole scripts, with mismatching nodes underlined. And we print the `ObjectPath` to the mismatching nodes for further better debug. For example, we have following functions

```python
@T.prim_func
def func1(a: T.handle, b: T.handle):
  A = T.match_buffer(a, (128, 128))
  B = T.match_buffer(b, (128, 128))

@T.prim_func
def func2(a: T.handle, b: T.handle):
  A = T.match_buffer(a, (128, 128))
  B = T.match_buffer(b, (128, 256))
```

the log of `assert_structural_equal(func1, func2)` will be like

```python
ValueError: StructuralEqual check failed, caused by lhs at <root>.buffer_map[b].shape[1].value:
# from tvm.script import tir as T

@T.prim_func
def main(a: T.handle, b: T.handle):
  A = T.match_buffer(a, (128, 128))
  B = T.match_buffer(b, (128, 128))
                              ^^^
  T.evaluate(0)
and rhs at <root>.buffer_map[b].shape[1].value:
# from tvm.script import tir as T

@T.prim_func
def main(a: T.handle, b: T.handle):
  A = T.match_buffer(a, (128, 128))
  B = T.match_buffer(b, (128, 256))
                              ^^^
  T.evaluate(0)
```

instead of

```python
ValueError: StructuralEqual check failed, caused by lhs:
128
and rhs:
256
```

which is not readable sometimes.
  • Loading branch information
cyx-6 authored Feb 5, 2023
1 parent 98008c2 commit f7aeaf1
Show file tree
Hide file tree
Showing 15 changed files with 268 additions and 48 deletions.
11 changes: 7 additions & 4 deletions include/tvm/node/script_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ class PrinterConfigNode : public Object {
DataType float_dtype = DataType::Void();
/*! \brief Whether or not to verbose print expressions. */
bool verbose_expr = false;
/* \brief Number of spaces used for indentation*/
/*! \brief Number of spaces used for indentation*/
int indent_spaces = 4;
/* \brief Whether to print line numbers */
/*! \brief Whether to print line numbers */
bool print_line_numbers = false;
/* \brief Number of context lines to print around the underlined text */
/*! \brief Number of context lines to print around the underlined text */
int num_context_lines = -1;
/* \brief Object path to be underlined */
/*! \brief Object path to be underlined */
Optional<ObjectPath> path_to_underline = NullOpt;
/*! \brief Whether to output with syntax sugar, set false for complete printing. */
bool syntax_sugar = true;

void VisitAttrs(AttrVisitor* v) {
v->Visit("ir_prefix", &ir_prefix);
Expand All @@ -72,6 +74,7 @@ class PrinterConfigNode : public Object {
v->Visit("print_line_numbers", &print_line_numbers);
v->Visit("num_context_lines", &num_context_lines);
v->Visit("path_to_underline", &path_to_underline);
v->Visit("syntax_sugar", &syntax_sugar);
}

static constexpr const char* _type_key = "node.PrinterConfig";
Expand Down
11 changes: 10 additions & 1 deletion include/tvm/node/structural_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ class SEqualReducer {
*/
virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;

/*!
* \brief Check if fail defferal is enabled.
*
* \return false if the fail deferral is not enabled, true otherwise.
*/
virtual bool IsFailDeferralEnabled() = 0;

/*!
* \brief Lookup the graph node equal map for vars that are already mapped.
*
Expand Down Expand Up @@ -331,12 +338,14 @@ class SEqualReducer {
*/
class SEqualHandlerDefault : public SEqualReducer::Handler {
public:
SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch);
SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
bool defer_fails);
virtual ~SEqualHandlerDefault();

bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const Optional<ObjectPathPair>& current_paths) override;
void DeferFail(const ObjectPathPair& mismatch_paths) override;
bool IsFailDeferralEnabled() override;
ObjectRef MapLhsToRhs(const ObjectRef& lhs) override;
void MarkGraphNode() override;

Expand Down
11 changes: 11 additions & 0 deletions python/tvm/runtime/script_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class PrinterConfig(Object):
print_line_numbers: bool
num_context_lines: int
path_to_underline: Optional[ObjectPath]
syntax_sugar: bool

def __init__(
self,
Expand All @@ -54,6 +55,7 @@ def __init__(
print_line_numbers: bool = False,
num_context_lines: Optional[int] = None,
path_to_underline: Optional[ObjectPath] = None,
syntax_sugar: bool = True,
) -> None:
if num_context_lines is None:
num_context_lines = -1
Expand All @@ -71,6 +73,7 @@ def __init__(
"print_line_numbers": print_line_numbers,
"num_context_lines": num_context_lines,
"path_to_underline": path_to_underline,
"syntax_sugar": syntax_sugar,
},
)

Expand All @@ -96,6 +99,7 @@ def script(
print_line_numbers: bool = False,
num_context_lines: int = -1,
path_to_underline: Optional[ObjectPath] = None,
syntax_sugar: bool = True,
) -> str:
"""Print TVM IR into TVMScript text format
Expand Down Expand Up @@ -123,6 +127,8 @@ def script(
The number of lines of context to print before and after the line to underline.
path_to_underline : Optional[ObjectPath] = None
Object path to be underlined
syntax_sugar: bool = True
Whether to output with syntax sugar, set false for complete printing.
Returns
-------
Expand All @@ -143,6 +149,7 @@ def script(
print_line_numbers=print_line_numbers,
num_context_lines=num_context_lines,
path_to_underline=path_to_underline,
syntax_sugar=syntax_sugar,
),
)

Expand All @@ -162,6 +169,7 @@ def show(
print_line_numbers: bool = False,
num_context_lines: int = -1,
path_to_underline: Optional[ObjectPath] = None,
syntax_sugar: bool = True,
) -> None:
"""A sugar for print highlighted TVM script.
Expand Down Expand Up @@ -194,6 +202,8 @@ def show(
The number of lines of context to print before and after the line to underline.
path_to_underline : Optional[ObjectPath] = None
Object path to be underlined
syntax_sugar: bool = True
Whether to output with syntax sugar, set false for complete printing.
"""
from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel
cprint,
Expand All @@ -212,6 +222,7 @@ def show(
print_line_numbers=print_line_numbers,
num_context_lines=num_context_lines,
path_to_underline=path_to_underline,
syntax_sugar=syntax_sugar,
),
style=style,
black_format=black_format,
Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/module_equality.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ModuleEqualityStructural : public ModuleEquality {

class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault {
public:
SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr) {}
SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr, false) {}

protected:
bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
Expand Down
6 changes: 6 additions & 0 deletions src/node/script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ TVMScriptPrinter::FType& TVMScriptPrinter::vtable() {
}

std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional<PrinterConfig>& cfg) {
if (!TVMScriptPrinter::vtable().can_dispatch(node)) {
return AsLegacyRepr(node);
}
return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig()));
}

Expand Down Expand Up @@ -67,6 +70,9 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
if (auto v = config_dict.Get("path_to_underline")) {
n->path_to_underline = Downcast<ObjectPath>(v);
}
if (auto v = config_dict.Get("syntax_sugar")) {
n->syntax_sugar = Downcast<IntImm>(v)->value;
}
this->data_ = std::move(n);
}

Expand Down
73 changes: 61 additions & 12 deletions src/node/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,12 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs,
*/
class SEqualHandlerDefault::Impl {
public:
Impl(SEqualHandlerDefault* parent, bool assert_mode, Optional<ObjectPathPair>* first_mismatch)
: parent_(parent), assert_mode_(assert_mode), first_mismatch_(first_mismatch) {}
Impl(SEqualHandlerDefault* parent, bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
bool defer_fails)
: parent_(parent),
assert_mode_(assert_mode),
first_mismatch_(first_mismatch),
defer_fails_(defer_fails) {}

bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const Optional<ObjectPathPair>& current_paths) {
Expand Down Expand Up @@ -245,6 +249,8 @@ class SEqualHandlerDefault::Impl {
pending_tasks_.emplace_back(Task::ForceFailTag{}, mismatch_paths);
}

bool IsFailDeferralEnabled() { return defer_fails_; }

void MarkGraphNode() {
// need to push to pending tasks in this case
ICHECK(!allow_push_to_stack_ && !task_stack_.empty());
Expand All @@ -264,6 +270,8 @@ class SEqualHandlerDefault::Impl {
pending_tasks_.clear();
equal_map_lhs_.clear();
equal_map_rhs_.clear();
root_lhs_ = lhs;
root_rhs_ = rhs;

Optional<ObjectPathPair> current_paths;
if (IsPathTracingEnabled()) {
Expand Down Expand Up @@ -313,10 +321,38 @@ class SEqualHandlerDefault::Impl {
*first_mismatch_ = current_paths;
}
if (assert_mode_ && !result) {
LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl
<< lhs << std::endl
<< "and rhs:" << std::endl
<< rhs;
std::ostringstream oss;
oss << "ValueError: StructuralEqual check failed, caused by lhs";
if (first_mismatch_->defined()) {
oss << " at " << first_mismatch_->value()->lhs_path;
if (root_lhs_.defined()) {
Map<String, ObjectRef> dict = {{"path_to_underline", first_mismatch_->value()->lhs_path},
{"syntax_sugar", Bool(false)}};
PrinterConfig cfg(dict);
// The TVMScriptPrinter::Script will fallback to Repr printer,
// if the root node to print is not supported yet,
// e.g. Relay nodes, ArrayNode, MapNode, etc.
oss << ":" << std::endl << TVMScriptPrinter::Script(root_lhs_.value(), cfg);
}
} else {
oss << ":" << std::endl << lhs;
}
oss << std::endl << "and rhs";
if (first_mismatch_->defined()) {
oss << " at " << first_mismatch_->value()->rhs_path;
if (root_rhs_.defined()) {
Map<String, ObjectRef> dict = {{"path_to_underline", first_mismatch_->value()->rhs_path},
{"syntax_sugar", Bool(false)}};
PrinterConfig cfg(dict);
// The TVMScriptPrinter::Script will fallback to Repr printer,
// if the root node to print is not supported yet,
// e.g. Relay nodes, ArrayNode, MapNode, etc.
oss << ":" << std::endl << TVMScriptPrinter::Script(root_rhs_.value(), cfg);
}
} else {
oss << ":" << std::endl << rhs;
}
LOG(FATAL) << oss.str();
}
return result;
}
Expand Down Expand Up @@ -419,19 +455,27 @@ class SEqualHandlerDefault::Impl {
bool allow_push_to_stack_{true};
// If in assert mode, must return true, and will throw error otherwise.
bool assert_mode_{false};
// Location to store the paths to the first detected mismatch, or nullptr to disable path tracing.
// Location to store the paths to the first detected mismatch, or nullptr to disable path
// tracing.
Optional<ObjectPathPair>* first_mismatch_;
// reflection vtable
ReflectionVTable* vtable_ = ReflectionVTable::Global();
// map from lhs to rhs
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_lhs_;
// map from rhs to lhs
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_rhs_;
// root lhs for result printing
Optional<ObjectRef> root_lhs_;
// root rhs for result printing
Optional<ObjectRef> root_rhs_;
// whether to defer fails
bool defer_fails_;
};

SEqualHandlerDefault::SEqualHandlerDefault(bool assert_mode,
Optional<ObjectPathPair>* first_mismatch) {
impl = new Impl(this, assert_mode, first_mismatch);
Optional<ObjectPathPair>* first_mismatch,
bool defer_fails) {
impl = new Impl(this, assert_mode, first_mismatch, defer_fails);
}

SEqualHandlerDefault::~SEqualHandlerDefault() { delete impl; }
Expand All @@ -446,6 +490,8 @@ void SEqualHandlerDefault::DeferFail(const ObjectPathPair& mismatch_paths) {
impl->DeferFail(mismatch_paths);
}

bool SEqualHandlerDefault::IsFailDeferralEnabled() { return impl->IsFailDeferralEnabled(); }

ObjectRef SEqualHandlerDefault::MapLhsToRhs(const ObjectRef& lhs) { return impl->MapLhsToRhs(lhs); }

void SEqualHandlerDefault::MarkGraphNode() { impl->MarkGraphNode(); }
Expand All @@ -463,19 +509,22 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const Obje
TVM_REGISTER_GLOBAL("node.StructuralEqual")
.set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode,
bool map_free_vars) {
return SEqualHandlerDefault(assert_mode, nullptr).Equal(lhs, rhs, map_free_vars);
Optional<ObjectPathPair> first_mismatch;
return SEqualHandlerDefault(assert_mode, &first_mismatch, false)
.Equal(lhs, rhs, map_free_vars);
});

TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch")
.set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
Optional<ObjectPathPair> first_mismatch;
bool equal = SEqualHandlerDefault(false, &first_mismatch).Equal(lhs, rhs, map_free_vars);
bool equal =
SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs, rhs, map_free_vars);
ICHECK(equal == !first_mismatch.defined());
return first_mismatch;
});

bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
return SEqualHandlerDefault(false, nullptr).Equal(lhs, rhs, false);
return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs, false);
}

bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs,
Expand Down
20 changes: 11 additions & 9 deletions src/node/structural_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,16 +467,18 @@ struct ArrayNodeTrait {
// (2) a b c d e g h i j k l m
// ^
// error here
if (lhs->size() > min_size) {
equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size),
array_paths->rhs_path->MissingArrayElement(min_size)});
} else {
equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size),
array_paths->rhs_path->ArrayIndex(min_size)});
if (equal->IsFailDeferralEnabled()) {
if (lhs->size() > min_size) {
equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size),
array_paths->rhs_path->MissingArrayElement(min_size)});
} else {
equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size),
array_paths->rhs_path->ArrayIndex(min_size)});
}
// Can return `true` pretending that everything is good since we have deferred the failure.
return true;
}

// Can return `true` pretending that everything is good since we have deferred the failure.
return true;
return false;
}
};
TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
Expand Down
4 changes: 3 additions & 1 deletion src/script/printer/doc_printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,9 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) {
for (const StmtDoc& stmt : doc->stmts) {
PrintDoc(stmt);
NewLine();
if (stmt != doc->stmts.back()) {
NewLine();
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/script/printer/tir/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //

std::vector<int> remap_vars_indices;
auto add_remapped_iter_var = [&](int i) -> bool {
if (realize) {
if (realize && d->cfg->syntax_sugar) {
tir::ExprDeepEqual expr_equal;
tir::IterVar iter_var = block->iter_vars[i];
PrimExpr value = realize->iter_values[i];
Expand Down
24 changes: 13 additions & 11 deletions src/script/printer/tir/for_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return grid_loop_vars.count(v);
});
};
for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as<tir::ForNode>()) {
ICHECK(l->loop_var->dtype == l->min->dtype);
ICHECK(l->loop_var->dtype == l->extent->dtype);
if (l->kind != tir::ForKind::kSerial || //
!tir::is_zero(l->min) || //
!l->annotations.empty() || //
f_var_dep(l->extent)) {
break;
if (d->cfg->syntax_sugar) {
for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as<tir::ForNode>()) {
ICHECK(l->loop_var->dtype == l->min->dtype);
ICHECK(l->loop_var->dtype == l->extent->dtype);
if (l->kind != tir::ForKind::kSerial || //
!tir::is_zero(l->min) || //
!l->annotations.empty() || //
f_var_dep(l->extent)) {
break;
}
grid.push_back(l);
grid_loop_vars.insert(l->loop_var.get());
}
grid.push_back(l);
grid_loop_vars.insert(l->loop_var.get());
}
With<TIRFrame> f(d, loop);
// Step 2. Construct `T.grid`
Expand Down Expand Up @@ -114,7 +116,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
kwargs_values.push_back(annotations.value());
}
ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values);
AsDocBody(loop->body, loop_p, (*f).get(), d);
AsDocBody(loop->body, loop_p->Attr("body"), (*f).get(), d);
return ForDoc(lhs, rhs, (*f)->stmts);
});

Expand Down
Loading

0 comments on commit f7aeaf1

Please sign in to comment.