Skip to content

Commit

Permalink
[TIR] Use IRModuleNode::Remove to remove None in PrimFuncPass (#14494)
Browse files Browse the repository at this point in the history
Prior to this commit, `PrimFuncPass` directly removed empty `PrimFunc`
objects from the module's `Map<GlobalVar, BaseFunc> functions`.
Because it didn't update the `global_var_map_` as well, these two maps
could become out of sync.  Since the `global_var_map_` is checked as
part of `StructuralEqual()`, but isn't displayed when printing to
TVMScript, this can result in identical printouts being flagged as
non-identical.

This commit updates `PrimFuncPass` to call the `IRModuleNode::Remove`
method, which updates both the `functions` and `global_var_map_`
variables.
  • Loading branch information
Lunderberg authored Apr 5, 2023
1 parent 4b6e635 commit deb11d3
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 5 deletions.
11 changes: 7 additions & 4 deletions src/tir/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ PrimFuncPass::PrimFuncPass(
// Perform Module -> Module optimizations at the PrimFunc level.
IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx) const {
ICHECK(mod.defined());
std::vector<ObjectRef> deleted_list;
std::vector<GlobalVar> deleted_list;

IRModuleNode* mod_ptr = mod.CopyOnWrite();
auto* func_dict = mod_ptr->functions.CopyOnWrite();
// directly loop over the underlying dict
Expand All @@ -101,14 +102,16 @@ IRModule PrimFuncPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
kv.second = std::move(func);

if (!kv.second.defined()) {
deleted_list.push_back(kv.first);
deleted_list.push_back(Downcast<GlobalVar>(kv.first));
}
}
}

// automatic removal of None
// Automatic removal of None. This uses IRModuleNode::Remove
// instead of manipulating func_dict directly, to ensure that both
// the function map and the global_var_map_ are correctly updated.
for (const auto& gv : deleted_list) {
func_dict->erase(gv);
mod_ptr->Remove(gv);
}
return mod;
}
Expand Down
30 changes: 29 additions & 1 deletion tests/python/unittest/test_tir_transform_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest

import tvm
from tvm.script import tir as T
from tvm.script import tir as T, ir as I
import tvm.testing


Expand Down Expand Up @@ -119,5 +119,33 @@ def checker_filter_out_both(func: tvm.tir.PrimFunc):
assert len(after.functions) == 0


class TestFilterRemovesGlobalVarMap(tvm.testing.CompareBeforeAfter):
"""Filtering out a function should be identical to never adding it
This test is to guard against hidden state in the IRModule that
remains after filtering. Previously, this was observed in the
`IRModuleNode::global_var_map_`, which retained entries of
filtered-out functions.
"""

transform = tvm.tir.transform.Filter(lambda prim_func: False)

def before(self):
@I.ir_module
class module:
@T.prim_func
def func():
T.evaluate(0)

return module

def expected(self):
@I.ir_module
class module:
pass

return module


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit deb11d3

Please sign in to comment.