Skip to content

Commit

Permalink
[Bugfix] Handle function name properly in Relax TVMScript printer (ap…
Browse files Browse the repository at this point in the history
…ache#317)

* remove relax_func_name_ and change logic

* well_formed check for globalvar and gsymbol consistency

* revise the logic in well_formed and update test

* Remove `global_symbol` in test_function_attr.py

* Update docs

Co-authored-by: Ruihang Lai <ruihangl@cs.cmu.edu>
  • Loading branch information
2 people authored and junrushao committed Feb 5, 2023
1 parent 40fd2ef commit 0774958
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 21 deletions.
30 changes: 22 additions & 8 deletions src/relax/analysis/well_formed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@
* If it's malformed, messages will be logged as Warning.
* This pass will check:
* 1. GlobalVars are defined before use.
* 2. Vars are defined before use.
* 3. Vars are defined exactly once.
* 4. Symbolic Vars are defined before use.
* 5. DataflowVars cannot be defined inside BindingBlock.
* 6. Vars defined in IfNode, except the return Var, are invisible
* 2. When a Function has a corresponding GlobalVar and a `global_symbol`
* attribute, the name of the GlobalVar must equal the value of the
* `global_symbol` attribute value.
* 3. Vars are defined before use.
* 4. Vars are defined exactly once.
* 5. Symbolic Vars are defined before use.
* 6. DataflowVars cannot be defined inside BindingBlock.
* 7. Vars defined in IfNode, except the return Var, are invisible
* out of the If body.(May change for new AST designs)
* 6. SeqExpr only serves as function body, or in the true and
* 8. SeqExpr only serves as function body, or in the true and
* false branches in IfNode.
* 7. The IR is in ANF:
* 9. The IR is in ANF:
* (a) Expressions cannot contain nested complex expressions.
* Here are the expressions that may be nested inside other expressions:
* Var, DataflowVar, GlobalVar, Constant, ShapeExpr, RuntimeDepShape,
Expand All @@ -48,7 +51,7 @@
* * The cond field of If nodes
* * The op or args fields of Call nodes
* * Inside the fields of Tuple nodes
* 8. Expr always has checked_type_ (with the exception of Op).
* 10. Expr always has checked_type_ (with the exception of Op).
*/
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr.h>
Expand Down Expand Up @@ -86,6 +89,16 @@ class WellFormedChecker : public relax::ExprVisitor,

void RegisterGlobalVar(GlobalVar var) { global_var_set_.insert(var); }

void CheckGlobalVarAndGsymbolConsistency(GlobalVar var, Function func) {
// check name in global var and gsymbol
Optional<String> gsymbol = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (gsymbol.defined() && gsymbol != var->name_hint) {
Malformed(Diagnostic::Error(func->span)
<< "Name in GlobalVar is not equal to name in gsymbol: " << var->name_hint
<< " != " << gsymbol.value());
}
}

private:
// Possible mode of visitor
enum class VisitMode {
Expand Down Expand Up @@ -413,6 +426,7 @@ bool WellFormed(const IRModule& m, Optional<DiagnosticContext> diag_ctx) {
// visit relax.Function
if (auto* n = it.second.as<FunctionNode>()) {
Function func = GetRef<Function>(n);
well_formed_checker.CheckGlobalVarAndGsymbolConsistency(it.first, func);
well_formed_checker.VisitExpr(func);
}
}
Expand Down
14 changes: 9 additions & 5 deletions src/relay/printer/relax_script_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,11 +356,10 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::SeqExprNode* op) {
Doc RelaxScriptPrinter::VisitNode_(const relax::FunctionNode* op) {
Optional<String> gsymbol = op->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (gsymbol) {
ICHECK_EQ(gsymbol.value(), relax_func_name_);
return PrintFunctionDef(Doc::Text(relax_func_name_), GetRef<relax::Function>(op),
return PrintFunctionDef(Doc::Text(gsymbol.value()), GetRef<relax::Function>(op),
/*is_global=*/true);
} else {
return PrintFunctionDef(Doc::Text(relax_func_name_), GetRef<relax::Function>(op),
return PrintFunctionDef(Doc::Text(relax_default_func_name_), GetRef<relax::Function>(op),
/*is_global=*/true);
}
}
Expand Down Expand Up @@ -536,8 +535,13 @@ Doc RelaxScriptPrinter::PrintIRModule(const IRModule& mod) {
if (pr.second.as<tir::PrimFuncNode>()) {
func = PrintPrimFunc(pr.first->name_hint, Downcast<tir::PrimFunc>(pr.second));
} else {
relax_func_name_ = pr.first->name_hint;
func = Print(pr.second);
Doc func_name;
Optional<String> gsymbol = pr.second->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (gsymbol.defined()) {
ICHECK_EQ(gsymbol.value(), pr.first->name_hint);
}
func_name << pr.first->name_hint;
func = PrintFunctionDef(func_name, Downcast<Function>(pr.second), true);
}
doc << Doc::Indent(4, Doc::NewLine() << func);
}
Expand Down
4 changes: 2 additions & 2 deletions src/relay/printer/text_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,8 @@ class RelaxScriptPrinter : public relax::IRFunctor<Doc(const ObjectRef&)>,
size_t local_func_counter_ = 0;
/*! \brief meta data context. */
TextMetaDataContext* meta_;
/*! \brief the current relax function name. */
String relax_func_name_ = "foo";
/*! \brief default relax function name in printer. */
constexpr const static char* relax_default_func_name_ = "main";
/*!
* \brief A bool flag to indicate if we print symbolic shape as str, usually for global
* function.
Expand Down
14 changes: 14 additions & 0 deletions tests/python/relax/test_analysis_well_formed.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,5 +360,19 @@ def test_ANF():
assert not rx.analysis.well_formed(mod)


def test_global_var_vs_gsymbol():
# Error: gsymbol "main1" not equals to the name in global var "main"
gv0 = rx.Var("gv0", R.Tensor([m, n], "float32"))
bindings = [rx.VarBinding(gv0, x)]
blocks = [rx.DataflowBlock(bindings)]
func = rx.Function(
[x],
rx.SeqExpr(blocks, gv0),
R.Tensor(ndim=2, dtype="float32"),
).with_attr("global_symbol", "main1")
mod = tvm.IRModule({rx.GlobalVar("main"): func})
assert not rx.analysis.well_formed(mod)


if __name__ == "__main__":
pytest.main([__file__])
11 changes: 5 additions & 6 deletions tests/python/relax/test_function_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,20 @@ def test_func_attr_setter():
mod = InputModule
assert isinstance(mod, tvm.IRModule)

mod = annotate(mod, "relax_add", {"Codegen": "test-codegen", "global_symbol": "test-symbol"})
mod = annotate(mod, "relax_add", {"Codegen": "test-codegen"})
_check_save_roundtrip(mod)
annot_func = mod["relax_add"]

# Test annotation
assert annot_func.attrs
assert annot_func.attrs["Codegen"] == "test-codegen"
assert annot_func.attrs["global_symbol"] == "test-symbol"


def test_func_attr_roundtrip_and_equality():
mod = InputModule
assert isinstance(mod, tvm.IRModule)
mod1 = annotate(mod, "relax_add", {"Codegen": "test-codegen", "global_symbol": "test-symbol"})
mod2 = annotate(mod, "relax_add", {"Codegen": "test-codegen", "global_symbol": "test-symbol"})
mod1 = annotate(mod, "relax_add", {"Codegen": "test-codegen"})
mod2 = annotate(mod, "relax_add", {"Codegen": "test-codegen"})
_check_save_roundtrip(mod1)
_check_save_roundtrip(mod2)
_check_equal(mod1, mod2)
Expand All @@ -87,7 +86,7 @@ def test_func_attr_setter_with_passes():
mod = InputModule
assert isinstance(mod, tvm.IRModule)
# Annotate
mod = annotate(mod, "relax_add", {"Codegen": "test-codegen", "global_symbol": "test-symbol"})
mod = annotate(mod, "relax_add", {"Codegen": "test-codegen"})

# Test with passes
# Annotation should stay the same unless the pass needs to modify it
Expand All @@ -101,13 +100,13 @@ def test_func_attr_setter_with_passes():

# Apply passes
new_mod = seq(mod)
print(mod.script())
_check_save_roundtrip(new_mod)

# Test annotation
func = new_mod["relax_add"]
assert func.attrs
assert func.attrs["Codegen"] == "test-codegen"
assert func.attrs["global_symbol"] == "test-symbol"


def test_irmodule_attr_setter_with_passes():
Expand Down

0 comments on commit 0774958

Please sign in to comment.