Skip to content

Commit

Permalink
[BugFix][TVMScript] Parser crash (apache#13630)
Browse files Browse the repository at this point in the history
This PR tries to fix the crash of parser when the old value of a var is an array but the new value is not. For example:

```python
from tvm.script import tir as T
def func_wrapper(shape, dtype):
    @T.prim_func
    def test_case():
        a = T.alloc_buffer(shape, dtype=dtype)
    
    return test_case


if __name__ == "__main__":
    a = np.zeros((10, 10), dtype="int8")
    print(func_wrapper((256, 256), dtype="int8").script())
```

In the above code, there are two assignment to var 'a'. In the global scope, its value is a numpy array. But it is a Buffer in the prim function. There is a table named 'name2value' to track the value of vars like 'a' here.
When the parser wants to update its value, it will compare the value between the new and the old assignment. Here the problem comes. When we use '==' to compare an array with a value, the result is an array too, which can not be used as a condition of a if stmt directly. So, the code above will emit an error:

```shell
error: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
 --> /workspace/code_newest/tvm/private_test/test_meta_programming.py:16:9
    |  
 16 |          a = T.alloc_buffer(shape, dtype=dtype)
    |          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```

This PR fixes this by change "==" to "is".

Co-authored-by: lightzhan-intellif <zhan.liang@intellif.com>
  • Loading branch information
2 people authored and Mikael Sevenier committed Dec 29, 2022
1 parent 999690b commit 195f16b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
8 changes: 6 additions & 2 deletions python/tvm/script/parser/core/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from collections import defaultdict
from contextlib import contextmanager
from typing import Any, Callable, Dict, List, Optional, Set, Union
import numpy as np
from tvm._ffi.base import TVMError

from tvm.error import DiagnosticError
Expand Down Expand Up @@ -150,8 +151,11 @@ def add(self, var: str, value: Any, allow_shadowing: bool = False):
The options of whether variable shadowing allwed for this variable.
"""
# Skip if the key and value are equal to those in the var_table
if self.name2value[var] and self.name2value[var][-1] == value:
return
if self.name2value[var] and isinstance(self.name2value[var][-1], type(value)):
if isinstance(value, np.ndarray) and (self.name2value[var][-1] == value).all():
return
elif self.name2value[var][-1] == value:
return
if allow_shadowing and var in self.frames[-1].vars:
# Shadowing
self.name2value[var][-1] = value
Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_tvmscript_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,5 +45,20 @@ def test_multi_element_array_in_outmost_namespace():
tvm.ir.assert_structural_equal(func, rt_func)


def test_different_dtype_assignment_to_var():
@T.prim_func
def test_case():
a = T.alloc_buffer((10, 10), dtype="int8")

@T.prim_func
def func_ref():
a = T.alloc_buffer([10, 10], dtype="int8")
T.evaluate(0)

tvm.ir.assert_structural_equal(test_case, func_ref)


if __name__ == "__main__":
a = numpy.zeros((10, 10), dtype="int8")
test_multi_element_array_in_outmost_namespace()
test_different_dtype_assignment_to_var()

0 comments on commit 195f16b

Please sign in to comment.