diff --git a/python/tvm/script/builder/tir/var.py b/python/tvm/script/builder/tir/var.py index f5c4e68dc2ce..7c89b173c7fe 100644 --- a/python/tvm/script/builder/tir/var.py +++ b/python/tvm/script/builder/tir/var.py @@ -15,10 +15,11 @@ # specific language governing permissions and limitations # under the License. """TVM Script TIR Buffer""" -from tvm.ir import PrimExpr, Array, Range -from tvm.tir import Var, IntImm, BufferLoad, BufferRegion from tvm._ffi import register_object as _register_object -from tvm.runtime import Object, DataType +from tvm.ir import Array, PrimExpr, Range +from tvm.runtime import DataType, Object +from tvm.tir import BufferLoad, BufferRegion, IntImm, Var + from . import _ffi_api @@ -94,6 +95,8 @@ def buffer_type(self) -> int: return self.buffer.buffer_type def __getitem__(self, indices): + if not isinstance(indices, (tuple, list)): + indices = [indices] if any(isinstance(index, slice) for index in indices): region = [] for index in indices: