Skip to content

Commit

Permalink
[Fix][TVMScript] Fix index of metadata in printed script (#14130)
Browse files Browse the repository at this point in the history
Currently, if the same metadata object (e.g. a multi-line `tir.StringImm`) is referenced for more than one times in an IRModule, each reference will have different indices of the metadata array. For example, this code

```
str_imm = T.StringImm("aaa\nbbb\n")
@I.ir_module
class Module:
    @T.prim_func
    def foo() -> None:
        A = str_imm
        B = str_imm

    @T.prim_func
    def foo1() -> None:
        A = str_imm
Module.show()
```

where `str_imm` is referenced three times, will generate such output:

```
@I.ir_module
class Module:
    @T.prim_func
    def foo():
        A: T.handle = metadata["tir.StringImm"][0]
        B: T.handle = metadata["tir.StringImm"][1]
        T.evaluate(0)

    @T.prim_func
    def foo1():
        A: T.handle = metadata["tir.StringImm"][2]
        T.evaluate(0)
```

Each time has a different metadata index. 

This PR fixes this problem by detecting duplicate item in `IRDocsifierNode::AddMetadata`.
  • Loading branch information
Ubospica authored Feb 25, 2023
1 parent 9fab56c commit 1ad1994
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/script/printer/ir_docsifier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) {
ICHECK(obj.defined()) << "TypeError: Cannot add nullptr to metadata";
String key = obj->GetTypeKey();
Array<ObjectRef>& array = metadata[key];
int index = array.size();
array.push_back(obj);
return IdDoc("metadata") //
[{LiteralDoc::Str(key, NullOpt)}] //
[{LiteralDoc::Int(index, NullOpt)}];
int index = std::find(array.begin(), array.end(), obj) - array.begin();
if (index == static_cast<int>(array.size())) {
array.push_back(obj);
}
return IdDoc("metadata")[{LiteralDoc::Str(key, NullOpt)}][{LiteralDoc::Int(index, NullOpt)}];
}

bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return obj2info.count(obj); }
Expand Down
47 changes: 47 additions & 0 deletions tests/python/unittest/test_tvmscript_printer_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import tvm.testing
from tvm.script.parser import ir as I
from tvm.script.parser import tir as T


def test_str_metadata():
# This test is to check we reuse the existing metadata element for the same tir.StringImm
# So metadata["tir.StringImm"][0] will occur in the printed script for three times
str_imm = T.StringImm("aaa\nbbb\n")

@I.ir_module
class Module:
@T.prim_func
def foo() -> None:
A = str_imm
B = str_imm

@T.prim_func
def foo1() -> None:
A = str_imm

printed_str = Module.script(verbose_expr=True)
assert (
printed_str.count('metadata["tir.StringImm"][0]') == 3
and printed_str.count('metadata["tir.StringImm"][1]') == 0
)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 1ad1994

Please sign in to comment.