From f7aeaf1d389881e408d29585ea62c6bb5ea65843 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sun, 5 Feb 2023 08:45:34 -0800 Subject: [PATCH] [TVMScript] Connect `assert_structural_equal` with new TVMScript printer (#13859) This PR refactors the output of `assert_structural_equal`. Different from the directly printing mismatching nodes, in the old version, the improved one will print the whole scripts, with mismatching nodes underlined. And we print the `ObjectPath` to the mismatching nodes for further better debug. For example, we have following functions ```python @T.prim_func def func1(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) @T.prim_func def func2(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 256)) ``` the log of `assert_structural_equal(func1, func2)` will be like ```python ValueError: StructuralEqual check failed, caused by lhs at .buffer_map[b].shape[1].value: # from tvm.script import tir as T @T.prim_func def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 128)) ^^^ T.evaluate(0) and rhs at .buffer_map[b].shape[1].value: # from tvm.script import tir as T @T.prim_func def main(a: T.handle, b: T.handle): A = T.match_buffer(a, (128, 128)) B = T.match_buffer(b, (128, 256)) ^^^ T.evaluate(0) ``` instead of ```python ValueError: StructuralEqual check failed, caused by lhs: 128 and rhs: 256 ``` which is not readable sometimes. --- include/tvm/node/script_printer.h | 11 +- include/tvm/node/structural_equal.h | 11 +- python/tvm/runtime/script_printer.py | 11 ++ src/meta_schedule/module_equality.cc | 2 +- src/node/script_printer.cc | 6 + src/node/structural_equal.cc | 73 ++++++++-- src/node/structural_hash.cc | 20 +-- .../printer/doc_printer/python_doc_printer.cc | 4 +- src/script/printer/tir/block.cc | 2 +- src/script/printer/tir/for_loop.cc | 24 ++-- src/script/printer/tir/function.cc | 11 +- src/script/printer/tir/stmt.cc | 2 +- src/script/printer/tir/utils.h | 4 +- src/tir/analysis/deep_equal.cc | 1 + ...test_tvmscript_printer_structural_equal.py | 134 ++++++++++++++++++ 15 files changed, 268 insertions(+), 48 deletions(-) create mode 100644 tests/python/unittest/test_tvmscript_printer_structural_equal.py diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index af50aae71a43..eca302b395b3 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -53,14 +53,16 @@ class PrinterConfigNode : public Object { DataType float_dtype = DataType::Void(); /*! \brief Whether or not to verbose print expressions. */ bool verbose_expr = false; - /* \brief Number of spaces used for indentation*/ + /*! \brief Number of spaces used for indentation*/ int indent_spaces = 4; - /* \brief Whether to print line numbers */ + /*! \brief Whether to print line numbers */ bool print_line_numbers = false; - /* \brief Number of context lines to print around the underlined text */ + /*! \brief Number of context lines to print around the underlined text */ int num_context_lines = -1; - /* \brief Object path to be underlined */ + /*! \brief Object path to be underlined */ Optional path_to_underline = NullOpt; + /*! \brief Whether to output with syntax sugar, set false for complete printing. */ + bool syntax_sugar = true; void VisitAttrs(AttrVisitor* v) { v->Visit("ir_prefix", &ir_prefix); @@ -72,6 +74,7 @@ class PrinterConfigNode : public Object { v->Visit("print_line_numbers", &print_line_numbers); v->Visit("num_context_lines", &num_context_lines); v->Visit("path_to_underline", &path_to_underline); + v->Visit("syntax_sugar", &syntax_sugar); } static constexpr const char* _type_key = "node.PrinterConfig"; diff --git a/include/tvm/node/structural_equal.h b/include/tvm/node/structural_equal.h index 371b8f9c7bd9..5bd76404a998 100644 --- a/include/tvm/node/structural_equal.h +++ b/include/tvm/node/structural_equal.h @@ -153,6 +153,13 @@ class SEqualReducer { */ virtual void DeferFail(const ObjectPathPair& mismatch_paths) = 0; + /*! + * \brief Check if fail defferal is enabled. + * + * \return false if the fail deferral is not enabled, true otherwise. + */ + virtual bool IsFailDeferralEnabled() = 0; + /*! * \brief Lookup the graph node equal map for vars that are already mapped. * @@ -331,12 +338,14 @@ class SEqualReducer { */ class SEqualHandlerDefault : public SEqualReducer::Handler { public: - SEqualHandlerDefault(bool assert_mode, Optional* first_mismatch); + SEqualHandlerDefault(bool assert_mode, Optional* first_mismatch, + bool defer_fails); virtual ~SEqualHandlerDefault(); bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const Optional& current_paths) override; void DeferFail(const ObjectPathPair& mismatch_paths) override; + bool IsFailDeferralEnabled() override; ObjectRef MapLhsToRhs(const ObjectRef& lhs) override; void MarkGraphNode() override; diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 23144c47f1ee..19d8e34ce85c 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -39,6 +39,7 @@ class PrinterConfig(Object): print_line_numbers: bool num_context_lines: int path_to_underline: Optional[ObjectPath] + syntax_sugar: bool def __init__( self, @@ -54,6 +55,7 @@ def __init__( print_line_numbers: bool = False, num_context_lines: Optional[int] = None, path_to_underline: Optional[ObjectPath] = None, + syntax_sugar: bool = True, ) -> None: if num_context_lines is None: num_context_lines = -1 @@ -71,6 +73,7 @@ def __init__( "print_line_numbers": print_line_numbers, "num_context_lines": num_context_lines, "path_to_underline": path_to_underline, + "syntax_sugar": syntax_sugar, }, ) @@ -96,6 +99,7 @@ def script( print_line_numbers: bool = False, num_context_lines: int = -1, path_to_underline: Optional[ObjectPath] = None, + syntax_sugar: bool = True, ) -> str: """Print TVM IR into TVMScript text format @@ -123,6 +127,8 @@ def script( The number of lines of context to print before and after the line to underline. path_to_underline : Optional[ObjectPath] = None Object path to be underlined + syntax_sugar: bool = True + Whether to output with syntax sugar, set false for complete printing. Returns ------- @@ -143,6 +149,7 @@ def script( print_line_numbers=print_line_numbers, num_context_lines=num_context_lines, path_to_underline=path_to_underline, + syntax_sugar=syntax_sugar, ), ) @@ -162,6 +169,7 @@ def show( print_line_numbers: bool = False, num_context_lines: int = -1, path_to_underline: Optional[ObjectPath] = None, + syntax_sugar: bool = True, ) -> None: """A sugar for print highlighted TVM script. @@ -194,6 +202,8 @@ def show( The number of lines of context to print before and after the line to underline. path_to_underline : Optional[ObjectPath] = None Object path to be underlined + syntax_sugar: bool = True + Whether to output with syntax sugar, set false for complete printing. """ from tvm.script.highlight import ( # pylint: disable=import-outside-toplevel cprint, @@ -212,6 +222,7 @@ def show( print_line_numbers=print_line_numbers, num_context_lines=num_context_lines, path_to_underline=path_to_underline, + syntax_sugar=syntax_sugar, ), style=style, black_format=black_format, diff --git a/src/meta_schedule/module_equality.cc b/src/meta_schedule/module_equality.cc index f5757adf08a8..0997aab9b6a6 100644 --- a/src/meta_schedule/module_equality.cc +++ b/src/meta_schedule/module_equality.cc @@ -38,7 +38,7 @@ class ModuleEqualityStructural : public ModuleEquality { class SEqualHandlerIgnoreNDArray : public SEqualHandlerDefault { public: - SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr) {} + SEqualHandlerIgnoreNDArray() : SEqualHandlerDefault(false, nullptr, false) {} protected: bool DispatchSEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 605d5208462f..d8787259b50e 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -29,6 +29,9 @@ TVMScriptPrinter::FType& TVMScriptPrinter::vtable() { } std::string TVMScriptPrinter::Script(const ObjectRef& node, const Optional& cfg) { + if (!TVMScriptPrinter::vtable().can_dispatch(node)) { + return AsLegacyRepr(node); + } return TVMScriptPrinter::vtable()(node, cfg.value_or(PrinterConfig())); } @@ -67,6 +70,9 @@ PrinterConfig::PrinterConfig(Map config_dict) { if (auto v = config_dict.Get("path_to_underline")) { n->path_to_underline = Downcast(v); } + if (auto v = config_dict.Get("syntax_sugar")) { + n->syntax_sugar = Downcast(v)->value; + } this->data_ = std::move(n); } diff --git a/src/node/structural_equal.cc b/src/node/structural_equal.cc index 80e390d9b0ad..788f3b7a1f3f 100644 --- a/src/node/structural_equal.cc +++ b/src/node/structural_equal.cc @@ -202,8 +202,12 @@ bool SEqualReducer::ObjectAttrsEqual(const ObjectRef& lhs, const ObjectRef& rhs, */ class SEqualHandlerDefault::Impl { public: - Impl(SEqualHandlerDefault* parent, bool assert_mode, Optional* first_mismatch) - : parent_(parent), assert_mode_(assert_mode), first_mismatch_(first_mismatch) {} + Impl(SEqualHandlerDefault* parent, bool assert_mode, Optional* first_mismatch, + bool defer_fails) + : parent_(parent), + assert_mode_(assert_mode), + first_mismatch_(first_mismatch), + defer_fails_(defer_fails) {} bool SEqualReduce(const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars, const Optional& current_paths) { @@ -245,6 +249,8 @@ class SEqualHandlerDefault::Impl { pending_tasks_.emplace_back(Task::ForceFailTag{}, mismatch_paths); } + bool IsFailDeferralEnabled() { return defer_fails_; } + void MarkGraphNode() { // need to push to pending tasks in this case ICHECK(!allow_push_to_stack_ && !task_stack_.empty()); @@ -264,6 +270,8 @@ class SEqualHandlerDefault::Impl { pending_tasks_.clear(); equal_map_lhs_.clear(); equal_map_rhs_.clear(); + root_lhs_ = lhs; + root_rhs_ = rhs; Optional current_paths; if (IsPathTracingEnabled()) { @@ -313,10 +321,38 @@ class SEqualHandlerDefault::Impl { *first_mismatch_ = current_paths; } if (assert_mode_ && !result) { - LOG(FATAL) << "ValueError: StructuralEqual check failed, caused by lhs:" << std::endl - << lhs << std::endl - << "and rhs:" << std::endl - << rhs; + std::ostringstream oss; + oss << "ValueError: StructuralEqual check failed, caused by lhs"; + if (first_mismatch_->defined()) { + oss << " at " << first_mismatch_->value()->lhs_path; + if (root_lhs_.defined()) { + Map dict = {{"path_to_underline", first_mismatch_->value()->lhs_path}, + {"syntax_sugar", Bool(false)}}; + PrinterConfig cfg(dict); + // The TVMScriptPrinter::Script will fallback to Repr printer, + // if the root node to print is not supported yet, + // e.g. Relay nodes, ArrayNode, MapNode, etc. + oss << ":" << std::endl << TVMScriptPrinter::Script(root_lhs_.value(), cfg); + } + } else { + oss << ":" << std::endl << lhs; + } + oss << std::endl << "and rhs"; + if (first_mismatch_->defined()) { + oss << " at " << first_mismatch_->value()->rhs_path; + if (root_rhs_.defined()) { + Map dict = {{"path_to_underline", first_mismatch_->value()->rhs_path}, + {"syntax_sugar", Bool(false)}}; + PrinterConfig cfg(dict); + // The TVMScriptPrinter::Script will fallback to Repr printer, + // if the root node to print is not supported yet, + // e.g. Relay nodes, ArrayNode, MapNode, etc. + oss << ":" << std::endl << TVMScriptPrinter::Script(root_rhs_.value(), cfg); + } + } else { + oss << ":" << std::endl << rhs; + } + LOG(FATAL) << oss.str(); } return result; } @@ -419,7 +455,8 @@ class SEqualHandlerDefault::Impl { bool allow_push_to_stack_{true}; // If in assert mode, must return true, and will throw error otherwise. bool assert_mode_{false}; - // Location to store the paths to the first detected mismatch, or nullptr to disable path tracing. + // Location to store the paths to the first detected mismatch, or nullptr to disable path + // tracing. Optional* first_mismatch_; // reflection vtable ReflectionVTable* vtable_ = ReflectionVTable::Global(); @@ -427,11 +464,18 @@ class SEqualHandlerDefault::Impl { std::unordered_map equal_map_lhs_; // map from rhs to lhs std::unordered_map equal_map_rhs_; + // root lhs for result printing + Optional root_lhs_; + // root rhs for result printing + Optional root_rhs_; + // whether to defer fails + bool defer_fails_; }; SEqualHandlerDefault::SEqualHandlerDefault(bool assert_mode, - Optional* first_mismatch) { - impl = new Impl(this, assert_mode, first_mismatch); + Optional* first_mismatch, + bool defer_fails) { + impl = new Impl(this, assert_mode, first_mismatch, defer_fails); } SEqualHandlerDefault::~SEqualHandlerDefault() { delete impl; } @@ -446,6 +490,8 @@ void SEqualHandlerDefault::DeferFail(const ObjectPathPair& mismatch_paths) { impl->DeferFail(mismatch_paths); } +bool SEqualHandlerDefault::IsFailDeferralEnabled() { return impl->IsFailDeferralEnabled(); } + ObjectRef SEqualHandlerDefault::MapLhsToRhs(const ObjectRef& lhs) { return impl->MapLhsToRhs(lhs); } void SEqualHandlerDefault::MarkGraphNode() { impl->MarkGraphNode(); } @@ -463,19 +509,22 @@ bool SEqualHandlerDefault::DispatchSEqualReduce(const ObjectRef& lhs, const Obje TVM_REGISTER_GLOBAL("node.StructuralEqual") .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool assert_mode, bool map_free_vars) { - return SEqualHandlerDefault(assert_mode, nullptr).Equal(lhs, rhs, map_free_vars); + Optional first_mismatch; + return SEqualHandlerDefault(assert_mode, &first_mismatch, false) + .Equal(lhs, rhs, map_free_vars); }); TVM_REGISTER_GLOBAL("node.GetFirstStructuralMismatch") .set_body_typed([](const ObjectRef& lhs, const ObjectRef& rhs, bool map_free_vars) { Optional first_mismatch; - bool equal = SEqualHandlerDefault(false, &first_mismatch).Equal(lhs, rhs, map_free_vars); + bool equal = + SEqualHandlerDefault(false, &first_mismatch, true).Equal(lhs, rhs, map_free_vars); ICHECK(equal == !first_mismatch.defined()); return first_mismatch; }); bool StructuralEqual::operator()(const ObjectRef& lhs, const ObjectRef& rhs) const { - return SEqualHandlerDefault(false, nullptr).Equal(lhs, rhs, false); + return SEqualHandlerDefault(false, nullptr, false).Equal(lhs, rhs, false); } bool NDArrayEqual(const runtime::NDArray::Container* lhs, const runtime::NDArray::Container* rhs, diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 0426b8454dce..fa77b47bd284 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -467,16 +467,18 @@ struct ArrayNodeTrait { // (2) a b c d e g h i j k l m // ^ // error here - if (lhs->size() > min_size) { - equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size), - array_paths->rhs_path->MissingArrayElement(min_size)}); - } else { - equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size), - array_paths->rhs_path->ArrayIndex(min_size)}); + if (equal->IsFailDeferralEnabled()) { + if (lhs->size() > min_size) { + equal->DeferFail({array_paths->lhs_path->ArrayIndex(min_size), + array_paths->rhs_path->MissingArrayElement(min_size)}); + } else { + equal->DeferFail({array_paths->lhs_path->MissingArrayElement(min_size), + array_paths->rhs_path->ArrayIndex(min_size)}); + } + // Can return `true` pretending that everything is good since we have deferred the failure. + return true; } - - // Can return `true` pretending that everything is good since we have deferred the failure. - return true; + return false; } }; TVM_REGISTER_REFLECTION_VTABLE(ArrayNode, ArrayNodeTrait) diff --git a/src/script/printer/doc_printer/python_doc_printer.cc b/src/script/printer/doc_printer/python_doc_printer.cc index 334f76f72280..9d20afa148b5 100644 --- a/src/script/printer/doc_printer/python_doc_printer.cc +++ b/src/script/printer/doc_printer/python_doc_printer.cc @@ -510,7 +510,9 @@ void PythonDocPrinter::PrintTypedDoc(const SliceDoc& doc) { void PythonDocPrinter::PrintTypedDoc(const StmtBlockDoc& doc) { for (const StmtDoc& stmt : doc->stmts) { PrintDoc(stmt); - NewLine(); + if (stmt != doc->stmts.back()) { + NewLine(); + } } } diff --git a/src/script/printer/tir/block.cc b/src/script/printer/tir/block.cc index a5b8d6609622..979a27135cca 100644 --- a/src/script/printer/tir/block.cc +++ b/src/script/printer/tir/block.cc @@ -44,7 +44,7 @@ Doc PrintBlock(IRDocsifier d, tir::Block block, ObjectPath block_p, // std::vector remap_vars_indices; auto add_remapped_iter_var = [&](int i) -> bool { - if (realize) { + if (realize && d->cfg->syntax_sugar) { tir::ExprDeepEqual expr_equal; tir::IterVar iter_var = block->iter_vars[i]; PrimExpr value = realize->iter_values[i]; diff --git a/src/script/printer/tir/for_loop.cc b/src/script/printer/tir/for_loop.cc index 7d21de27a1a2..107521c94791 100644 --- a/src/script/printer/tir/for_loop.cc +++ b/src/script/printer/tir/for_loop.cc @@ -32,17 +32,19 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return grid_loop_vars.count(v); }); }; - for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as()) { - ICHECK(l->loop_var->dtype == l->min->dtype); - ICHECK(l->loop_var->dtype == l->extent->dtype); - if (l->kind != tir::ForKind::kSerial || // - !tir::is_zero(l->min) || // - !l->annotations.empty() || // - f_var_dep(l->extent)) { - break; + if (d->cfg->syntax_sugar) { + for (const tir::ForNode* l = loop.get(); l != nullptr; l = l->body.as()) { + ICHECK(l->loop_var->dtype == l->min->dtype); + ICHECK(l->loop_var->dtype == l->extent->dtype); + if (l->kind != tir::ForKind::kSerial || // + !tir::is_zero(l->min) || // + !l->annotations.empty() || // + f_var_dep(l->extent)) { + break; + } + grid.push_back(l); + grid_loop_vars.insert(l->loop_var.get()); } - grid.push_back(l); - grid_loop_vars.insert(l->loop_var.get()); } With f(d, loop); // Step 2. Construct `T.grid` @@ -114,7 +116,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) kwargs_values.push_back(annotations.value()); } ExprDoc rhs = prefix->Call(args, kwargs_keys, kwargs_values); - AsDocBody(loop->body, loop_p, (*f).get(), d); + AsDocBody(loop->body, loop_p->Attr("body"), (*f).get(), d); return ForDoc(lhs, rhs, (*f)->stmts); }); diff --git a/src/script/printer/tir/function.cc b/src/script/printer/tir/function.cc index 65f3db5b4fec..c3f9244962d6 100644 --- a/src/script/printer/tir/function.cc +++ b/src/script/printer/tir/function.cc @@ -95,7 +95,8 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) for (int i = 0; i < n_args; ++i) { tir::Var var = func->params[i]; ObjectPath var_p = p->Attr("params")->ArrayIndex(i); - if (CountVarOccurrence(func, var) == 2 && func->buffer_map.count(var)) { + if (d->cfg->syntax_sugar && CountVarOccurrence(func, var) == 2 && + func->buffer_map.count(var)) { tir::Buffer buffer = func->buffer_map[var]; if (IsSimpleBuffer(buffer) && buffer_data_counter.at(buffer->data.get()) == 1) { ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(var); @@ -122,11 +123,11 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) if (buffer_inlined.count(buffer.get())) { continue; } - ExprDoc param = args[i]->lhs; + ExprDoc param_doc = args[i]->lhs; ObjectPath buffer_p = p->Attr("buffer_map")->MapValue(param); ExprDoc lhs = DefineBuffer(buffer, *frame, d); // TODO(@junrushao): switch `lhs` and `rhs` - ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param}, buffer_p, *frame, d); + ExprDoc rhs = BufferDecl(buffer, "match_buffer", {param_doc}, buffer_p, *frame, d); (*frame)->stmts.push_back(AssignDoc(lhs, rhs, NullOpt)); } } @@ -150,9 +151,9 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) } return NullOpt; }(); - if (implicit_root_block) { + if (d->cfg->syntax_sugar && implicit_root_block) { tir::Block root_block = implicit_root_block.value(); - ObjectPath root_block_p = p->Attr("body")->Attr("body"); + ObjectPath root_block_p = p->Attr("body")->Attr("block"); (*frame)->stmts.push_back(CommentDoc("with T.block(\"root\"):")); // Handle root block `alloc_buffer` for (int i = 0, n = root_block->alloc_buffers.size(); i < n; ++i) { diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 7556f820df74..b730dd5606ba 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -179,7 +179,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::Allocate stmt, ObjectPath stmt_p, IRDocsifier d) -> Doc { bool concise = AllowConciseScoping(d); - if (IsAllocateDeclBufferPattern(stmt.get())) { + if (d->cfg->syntax_sugar && IsAllocateDeclBufferPattern(stmt.get())) { return d->AsDoc(stmt->body, stmt_p->Attr("body")); } Array args; diff --git a/src/script/printer/tir/utils.h b/src/script/printer/tir/utils.h index 0eead9a57713..18c64c5edcfe 100644 --- a/src/script/printer/tir/utils.h +++ b/src/script/printer/tir/utils.h @@ -115,10 +115,10 @@ inline IdDoc DefineBuffer(const tir::Buffer& buffer, const Frame& frame, const I inline void AsDocBody(const tir::Stmt& stmt, ObjectPath p, TIRFrameNode* f, const IRDocsifier& d) { if (const auto* seq_stmt = stmt.as()) { Array body = seq_stmt->seq; - p = p->Attr("seq"); for (int i = 0, n = body.size(); i < n; ++i) { f->allow_concise_scoping = (i == n - 1); - Doc doc = d->AsDoc(body[i], p->ArrayIndex(i)); + Doc doc = d->AsDoc(body[i], p->Attr("seq")->ArrayIndex(i)); + doc->source_paths.push_back(p); if (const auto* block = doc.as()) { f->stmts.insert(f->stmts.end(), block->stmts.begin(), block->stmts.end()); } else { diff --git a/src/tir/analysis/deep_equal.cc b/src/tir/analysis/deep_equal.cc index 451855c8f870..1ec9fc5522c8 100644 --- a/src/tir/analysis/deep_equal.cc +++ b/src/tir/analysis/deep_equal.cc @@ -44,6 +44,7 @@ class DeepCmpSEqualHandler : public SEqualReducer::Handler { } void DeferFail(const ObjectPathPair&) final { fail_ = true; } + bool IsFailDeferralEnabled() final { return false; } ObjectRef MapLhsToRhs(const ObjectRef& lhs) final { return ObjectRef(nullptr); } void MarkGraphNode() final {} diff --git a/tests/python/unittest/test_tvmscript_printer_structural_equal.py b/tests/python/unittest/test_tvmscript_printer_structural_equal.py new file mode 100644 index 000000000000..4bd578eab768 --- /dev/null +++ b/tests/python/unittest/test_tvmscript_printer_structural_equal.py @@ -0,0 +1,134 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +import tvm +from tvm.ir import assert_structural_equal +from tvm.relay.op.transform import split +from tvm.runtime import ObjectPath +from tvm.script import tir as T + + +def _error_message(exception): + splitter = "ValueError: StructuralEqual" + return splitter + str(exception).split(splitter)[1] + + +def _expected_result(func1, func2, objpath1, objpath2): + return f"""ValueError: StructuralEqual check failed, caused by lhs at {objpath1}: +{func1.script(path_to_underline=objpath1, syntax_sugar=False)} +and rhs at {objpath2}: +{func2.script(path_to_underline=objpath2, syntax_sugar=False)}""" + + +def test_prim_func_buffer_map(): + @T.prim_func + def func1(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 128)) + + @T.prim_func + def func2(a: T.handle, b: T.handle): + A = T.match_buffer(a, (128, 128)) + B = T.match_buffer(b, (128, 256)) + + with pytest.raises(ValueError) as ve: + assert_structural_equal(func1, func2) + assert _error_message(ve.value) == _expected_result( + func1, + func2, + ObjectPath.root() + .attr("buffer_map") + .map_value(func1.params[1]) + .attr("shape") + .array_index(1) + .attr("value"), + ObjectPath.root() + .attr("buffer_map") + .map_value(func2.params[1]) + .attr("shape") + .array_index(1) + .attr("value"), + ) + + +def test_evaluate(): + @T.prim_func + def func1(): + T.evaluate(0) + + @T.prim_func + def func2(): + T.evaluate(1) + + with pytest.raises(ValueError) as ve: + assert_structural_equal(func1, func2) + assert _error_message(ve.value) == _expected_result( + func1, + func2, + ObjectPath.root().attr("body").attr("value").attr("value"), + ObjectPath.root().attr("body").attr("value").attr("value"), + ) + + +def test_allocate(): + @T.prim_func + def func1(): + a_data = T.allocate((128, 128), dtype="float32") + a = T.decl_buffer((128, 128), dtype="float32", data=a_data) + + @T.prim_func + def func2(): + a_data = T.allocate((256, 128), dtype="float32") + a = T.decl_buffer((256, 128), dtype="float32", data=a_data) + + with pytest.raises(ValueError) as ve: + assert_structural_equal(func1, func2) + assert _error_message(ve.value) == _expected_result( + func1, + func2, + ObjectPath.root().attr("body").attr("extents").array_index(0).attr("value"), + ObjectPath.root().attr("body").attr("extents").array_index(0).attr("value"), + ) + + +def test_for(): + @T.prim_func + def func1(): + for i, j in T.grid(128, 128): + with T.block(): + pass + + @T.prim_func + def func2(): + for i, j, k in T.grid(128, 128, 128): + with T.block(): + pass + + with pytest.raises(ValueError) as ve: + assert_structural_equal(func1, func2) + assert _error_message(ve.value) == _expected_result( + func1, + func2, + ObjectPath.root().attr("body").attr("block").attr("body").attr("body").attr("body"), + ObjectPath.root().attr("body").attr("block").attr("body").attr("body").attr("body"), + ) + + +if __name__ == "__main__": + tvm.testing.main()