Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TVMScript] Connect assert_structural_equal with new TVMScript printer #13859

Merged
merged 1 commit into from
Feb 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions include/tvm/node/script_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@ class PrinterConfigNode : public Object {
DataType float_dtype = DataType::Void();
/*! \brief Whether or not to verbose print expressions. */
bool verbose_expr = false;
/* \brief Number of spaces used for indentation*/
/*! \brief Number of spaces used for indentation*/
int indent_spaces = 4;
/* \brief Whether to print line numbers */
/*! \brief Whether to print line numbers */
bool print_line_numbers = false;
/* \brief Number of context lines to print around the underlined text */
/*! \brief Number of context lines to print around the underlined text */
int num_context_lines = -1;
/* \brief Object path to be underlined */
/*! \brief Object path to be underlined */
Optional<ObjectPath> path_to_underline = NullOpt;
/*! \brief Whether to output with syntax sugar, set false for complete printing. */
bool syntax_sugar = true;

void VisitAttrs(AttrVisitor* v) {
v->Visit("ir_prefix", &ir_prefix);
Expand All @@ -72,6 +74,7 @@ class PrinterConfigNode : public Object {
v->Visit("print_line_numbers", &print_line_numbers);
v->Visit("num_context_lines", &num_context_lines);
v->Visit("path_to_underline", &path_to_underline);
v->Visit("syntax_sugar", &syntax_sugar);
}

static constexpr const char* _type_key = "node.PrinterConfig";
Expand Down
11 changes: 10 additions & 1 deletion include/tvm/node/structural_equal.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ class SEqualReducer {
*/
virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0;

/*!
* \brief Check if fail defferal is enabled.
*
* \return false if the fail deferral is not enabled, true otherwise.
*/
virtual bool IsFailDeferralEnabled() = 0;

/*!
* \brief Lookup the graph node equal map for vars that are already mapped.
*
Expand Down Expand Up @@ -331,12 +338,14 @@ class SEqualReducer {
*/
class SEqualHandlerDefault : public SEqualReducer::Handler {
public:
SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch);
SEqualHandlerDefault(bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
bool defer_fails);
virtual ~SEqualHandlerDefault();

bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const Optional<ObjectPathPair>& current_paths) override;
void DeferFail(const ObjectPathPair& mismatch_paths) override;
bool IsFailDeferralEnabled() override;
ObjectRef MapLhsToRhs(const ObjectRef& lhs) override;
void MarkGraphNode() override;

Expand Down
11 changes: 11 additions & 0 deletions python/tvm/runtime/script_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class PrinterConfig(Object):
print_line_numbers: bool
num_context_lines: int
path_to_underline: Optional[ObjectPath]
syntax_sugar: bool

def __init__(
self,
Expand All @@ -54,6 +55,7 @@ def __init__(
print_line_numbers: bool = False,
num_context_lines: Optional[int] = None,
path_to_underline: Optional[ObjectPath] = None,
syntax_sugar: bool = True,
) -> None:
if num_context_lines is None:
num_context_lines = -1
Expand All @@ -71,6 +73,7 @@ def __init__(
"print_line_numbers": print_line_numbers,
"num_context_lines": num_context_lines,
"path_to_underline": path_to_underline,
"syntax_sugar": syntax_sugar,
},
)

Expand All @@ -96,6 +99,7 @@ def script(
print_line_numbers: bool = False,
num_context_lines: int = -1,
path_to_underline: Optional[ObjectPath] = None,
syntax_sugar: bool = True,
) -> str:
"""Print TVM IR into TVMScript text format

Expand Down Expand Up @@ -123,6 +127,8 @@ def script(
The number of lines of context to print before and after the line to underline.
path_to_underline : Optional[ObjectPath] = None
Object path to be underlined
syntax_sugar: bool = True
Whether to output with syntax sugar, set false for complete printing.

Returns
-------
Expand All @@ -143,6 +149,7 @@ def script(
print_line_numbers=print_line_numbers,
num_context_lines=num_context_lines,
path_to_underline=path_to_underline,
syntax_sugar=syntax_sugar,
),
)

Expand All @@ -162,6 +169,7 @@ def show(
print_line_numbers: bool = False,
num_context_lines: int = -1,
path_to_underline: Optional[ObjectPath] = None,
syntax_sugar: bool = True,
) -> None:
"""A sugar for print highlighted TVM script.

Expand Down Expand Up @@ -194,6 +202,8 @@ def show(
The number of lines of context to print before and after the line to underline.
path_to_underline : Optional[ObjectPath] = None
Object path to be underlined
syntax_sugar: bool = True
Whether to output with syntax sugar, set false for complete printing.
"""
from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel
cprint,
Expand All @@ -212,6 +222,7 @@ def show(
print_line_numbers=print_line_numbers,
num_context_lines=num_context_lines,
path_to_underline=path_to_underline,
syntax_sugar=syntax_sugar,
),
style=style,
black_format=black_format,
Expand Down
2 changes: 1 addition & 1 deletion src/meta_schedule/module_equality.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ModuleEqualityStructural : public ModuleEquality {

class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault {
public:
SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr) {}
SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr, false) {}

protected:
bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
Expand Down
6 changes: 6 additions & 0 deletions src/node/script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ TVMScriptPrinter::FType& TVMScriptPrinter::vtable() {
}

std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional<PrinterConfig>& cfg) {
if (!TVMScriptPrinter::vtable().can_dispatch(node)) {
return AsLegacyRepr(node);
}
return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig()));
}

Expand Down Expand Up @@ -67,6 +70,9 @@ PrinterConfig::PrinterConfig(Map<String, ObjectRef> config_dict) {
if (auto v = config_dict.Get("path_to_underline")) {
n->path_to_underline = Downcast<ObjectPath>(v);
}
if (auto v = config_dict.Get("syntax_sugar")) {
n->syntax_sugar = Downcast<IntImm>(v)->value;
}
this->data_ = std::move(n);
}

Expand Down
73 changes: 61 additions & 12 deletions src/node/structural_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,12 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs,
*/
class SEqualHandlerDefault::Impl {
public:
Impl(SEqualHandlerDefault* parent, bool assert_mode, Optional<ObjectPathPair>* first_mismatch)
: parent_(parent), assert_mode_(assert_mode), first_mismatch_(first_mismatch) {}
Impl(SEqualHandlerDefault* parent, bool assert_mode, Optional<ObjectPathPair>* first_mismatch,
bool defer_fails)
: parent_(parent),
assert_mode_(assert_mode),
first_mismatch_(first_mismatch),
defer_fails_(defer_fails) {}

bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars,
const Optional<ObjectPathPair>& current_paths) {
Expand Down Expand Up @@ -245,6 +249,8 @@ class SEqualHandlerDefault::Impl {
pending_tasks_.emplace_back(Task::ForceFailTag{}, mismatch_paths);
}

bool IsFailDeferralEnabled() { return defer_fails_; }

void MarkGraphNode() {
// need to push to pending tasks in this case
ICHECK(!allow_push_to_stack_ && !task_stack_.empty());
Expand All @@ -264,6 +270,8 @@ class SEqualHandlerDefault::Impl {
pending_tasks_.clear();
equal_map_lhs_.clear();
equal_map_rhs_.clear();
root_lhs_ = lhs;
root_rhs_ = rhs;

Optional<ObjectPathPair> current_paths;
if (IsPathTracingEnabled()) {
Expand Down Expand Up @@ -313,10 +321,38 @@ class SEqualHandlerDefault::Impl {
*first_mismatch_ = current_paths;
}
if (assert_mode_ && !result) {
LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl
<< lhs << std::endl
<< "and rhs:" << std::endl
<< rhs;
std::ostringstream oss;
oss << "ValueError: StructuralEqual check failed, caused by lhs";
Hzfengsy marked this conversation as resolved.
Show resolved Hide resolved
if (first_mismatch_->defined()) {
oss << " at " << first_mismatch_->value()->lhs_path;
if (root_lhs_.defined()) {
Map<String, ObjectRef> dict = {{"path_to_underline", first_mismatch_->value()->lhs_path},
{"syntax_sugar", Bool(false)}};
PrinterConfig cfg(dict);
// The TVMScriptPrinter::Script will fallback to Repr printer,
// if the root node to print is not supported yet,
// e.g. Relay nodes, ArrayNode, MapNode, etc.
oss << ":" << std::endl << TVMScriptPrinter::Script(root_lhs_.value(), cfg);
}
} else {
oss << ":" << std::endl << lhs;
}
oss << std::endl << "and rhs";
if (first_mismatch_->defined()) {
oss << " at " << first_mismatch_->value()->rhs_path;
if (root_rhs_.defined()) {
Map<String, ObjectRef> dict = {{"path_to_underline", first_mismatch_->value()->rhs_path},
{"syntax_sugar", Bool(false)}};
PrinterConfig cfg(dict);
// The TVMScriptPrinter::Script will fallback to Repr printer,
// if the root node to print is not supported yet,
// e.g. Relay nodes, ArrayNode, MapNode, etc.
oss << ":" << std::endl << TVMScriptPrinter::Script(root_rhs_.value(), cfg);
}
} else {
oss << ":" << std::endl << rhs;
}
LOG(FATAL) << oss.str();
}
return result;
}
Expand Down Expand Up @@ -419,19 +455,27 @@ class SEqualHandlerDefault::Impl {
bool allow_push_to_stack_{true};
// If in assert mode, must return true, and will throw error otherwise.
bool assert_mode_{false};
// Location to store the paths to the first detected mismatch, or nullptr to disable path tracing.
// Location to store the paths to the first detected mismatch, or nullptr to disable path
// tracing.
Optional<ObjectPathPair>* first_mismatch_;
// reflection vtable
ReflectionVTable* vtable_ = ReflectionVTable::Global();
// map from lhs to rhs
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_lhs_;
// map from rhs to lhs
std::unordered_map<ObjectRef, ObjectRef, ObjectPtrHash, ObjectPtrEqual> equal_map_rhs_;
// root lhs for result printing
Optional<ObjectRef> root_lhs_;
// root rhs for result printing
Optional<ObjectRef> root_rhs_;
// whether to defer fails
bool defer_fails_;
};

SEqualHandlerDefault::SEqualHandlerDefault(bool assert_mode,
Optional<ObjectPathPair>* first_mismatch) {
impl = new Impl(this, assert_mode, first_mismatch);
Optional<ObjectPathPair>* first_mismatch,
bool defer_fails) {
impl = new Impl(this, assert_mode, first_mismatch, defer_fails);
}

SEqualHandlerDefault::~SEqualHandlerDefault() { delete impl; }
Expand All @@ -446,6 +490,8 @@ void SEqualHandlerDefault::DeferFail(const ObjectPathPair& mismatch_paths) {
impl->DeferFail(mismatch_paths);
}

bool SEqualHandlerDefault::IsFailDeferralEnabled() { return impl->IsFailDeferralEnabled(); }

ObjectRef SEqualHandlerDefault::MapLhsToRhs(const ObjectRef& lhs) { return impl->MapLhsToRhs(lhs); }

void SEqualHandlerDefault::MarkGraphNode() { impl->MarkGraphNode(); }
Expand All @@ -463,19 +509,22 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const Obje
TVM_REGISTER_GLOBAL("node.StructuralEqual")
.set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode,
bool map_free_vars) {
return SEqualHandlerDefault(assert_mode, nullptr).Equal(lhs, rhs, map_free_vars);
Optional<ObjectPathPair> first_mismatch;
return SEqualHandlerDefault(assert_mode, &first_mismatch, false)
.Equal(lhs, rhs, map_free_vars);
});

TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch")
.set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) {
Optional<ObjectPathPair> first_mismatch;
bool equal = SEqualHandlerDefault(false, &first_mismatch).Equal(lhs, rhs, map_free_vars);
bool equal =
SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs, rhs, map_free_vars);
ICHECK(equal == !first_mismatch.defined());
return first_mismatch;
});

bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const {
return SEqualHandlerDefault(false, nullptr).Equal(lhs, rhs, false);
return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs, false);
}

bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs,
Expand Down
20 changes: 11 additions & 9 deletions src/node/structural_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -467,16 +467,18 @@ struct ArrayNodeTrait {
// (2) a b c d e g h i j k l m
// ^
// error here
if (lhs->size() > min_size) {
equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size),
array_paths->rhs_path->MissingArrayElement(min_size)});
} else {
equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size),
array_paths->rhs_path->ArrayIndex(min_size)});
if (equal->IsFailDeferralEnabled()) {
if (lhs->size() > min_size) {
equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size),
array_paths->rhs_path->MissingArrayElement(min_size)});
} else {
equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size),
array_paths->rhs_path->ArrayIndex(min_size)});
}
// Can return `true` pretending that everything is good since we have deferred the failure.
return true;
}

// Can return `true` pretending that everything is good since we have deferred the failure.
return true;
return false;
}
};
TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait)
Expand Down
4 changes: 3 additions & 1 deletion src/script/printer/doc_printer/python_doc_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,9 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) {
void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) {
for (const StmtDoc& stmt : doc->stmts) {
PrintDoc(stmt);
NewLine();
if (stmt != doc->stmts.back()) {
NewLine();
}
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/script/printer/tir/block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, //

std::vector<int> remap_vars_indices;
auto add_remapped_iter_var = [&](int i) -> bool {
if (realize) {
if (realize && d->cfg->syntax_sugar) {
tir::ExprDeepEqual expr_equal;
tir::IterVar iter_var = block->iter_vars[i];
PrimExpr value = realize->iter_values[i];
Expand Down
24 changes: 13 additions & 11 deletions src/script/printer/tir/for_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return grid_loop_vars.count(v);
});
};
for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as<tir::ForNode>()) {
ICHECK(l->loop_var->dtype == l->min->dtype);
ICHECK(l->loop_var->dtype == l->extent->dtype);
if (l->kind != tir::ForKind::kSerial || //
!tir::is_zero(l->min) || //
!l->annotations.empty() || //
f_var_dep(l->extent)) {
break;
if (d->cfg->syntax_sugar) {
for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as<tir::ForNode>()) {
ICHECK(l->loop_var->dtype == l->min->dtype);
ICHECK(l->loop_var->dtype == l->extent->dtype);
if (l->kind != tir::ForKind::kSerial || //
!tir::is_zero(l->min) || //
!l->annotations.empty() || //
f_var_dep(l->extent)) {
break;
}
grid.push_back(l);
grid_loop_vars.insert(l->loop_var.get());
}
grid.push_back(l);
grid_loop_vars.insert(l->loop_var.get());
}
With<TIRFrame> f(d, loop);
// Step 2. Construct `T.grid`
Expand Down Expand Up @@ -114,7 +116,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
kwargs_values.push_back(annotations.value());
}
ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values);
AsDocBody(loop->body, loop_p, (*f).get(), d);
AsDocBody(loop->body, loop_p->Attr("body"), (*f).get(), d);
return ForDoc(lhs, rhs, (*f)->stmts);
});

Expand Down
Loading