diff --git a/python/tvm/script/parser/core/parser.py b/python/tvm/script/parser/core/parser.py index c6d43f11cbf5..7c699c42aecb 100644 --- a/python/tvm/script/parser/core/parser.py +++ b/python/tvm/script/parser/core/parser.py @@ -19,6 +19,7 @@ from collections import defaultdict from contextlib import contextmanager from typing import Any, Callable, Dict, List, Optional, Set, Union +import numpy as np from tvm._ffi.base import TVMError from tvm.error import DiagnosticError @@ -150,8 +151,11 @@ def add(self, var: str, value: Any, allow_shadowing: bool = False): The options of whether variable shadowing allwed for this variable. """ # Skip if the key and value are equal to those in the var_table - if self.name2value[var] and self.name2value[var][-1] == value: - return + if self.name2value[var] and isinstance(self.name2value[var][-1], type(value)): + if isinstance(value, np.ndarray) and (self.name2value[var][-1] == value).all(): + return + elif self.name2value[var][-1] == value: + return if allow_shadowing and var in self.frames[-1].vars: # Shadowing self.name2value[var][-1] = value diff --git a/tests/python/unittest/test_tvmscript_regression.py b/tests/python/unittest/test_tvmscript_regression.py index 3ad8090893eb..05c1665ea2a1 100644 --- a/tests/python/unittest/test_tvmscript_regression.py +++ b/tests/python/unittest/test_tvmscript_regression.py @@ -45,5 +45,20 @@ def test_multi_element_array_in_outmost_namespace(): tvm.ir.assert_structural_equal(func, rt_func) +def test_different_dtype_assignment_to_var(): + @T.prim_func + def test_case(): + a = T.alloc_buffer((10, 10), dtype="int8") + + @T.prim_func + def func_ref(): + a = T.alloc_buffer([10, 10], dtype="int8") + T.evaluate(0) + + tvm.ir.assert_structural_equal(test_case, func_ref) + + if __name__ == "__main__": + a = numpy.zeros((10, 10), dtype="int8") test_multi_element_array_in_outmost_namespace() + test_different_dtype_assignment_to_var()