From 0ac5dd7f645298e0ebf1e3be643842814efdc514 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Sat, 13 Aug 2022 20:07:45 -0700 Subject: [PATCH] [Fix] Fix errors in error checking and reporting (#12423) --- python/tvm/te/operation.py | 14 ++++++++------ python/tvm/tir/function.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/python/tvm/te/operation.py b/python/tvm/te/operation.py index ada5c369ad3b..b8e43c06386c 100644 --- a/python/tvm/te/operation.py +++ b/python/tvm/te/operation.py @@ -22,9 +22,9 @@ from typing import List import tvm._ffi +import tvm.arith._ffi_api import tvm.tir import tvm.tir._ffi_api -import tvm.arith._ffi_api from tvm._ffi.base import string_types from tvm.ir import Array from tvm.runtime import convert @@ -420,11 +420,13 @@ def before_split(a: T.handle, b: T.handle) -> None: ) for tensor, buffer in zip(input_tensors, input_buffers): # TODO(csullivan): Can a stronger comparison between Tensor<>Buffer be made? - assert tensor.shape == buffer.shape, ( - "The input input_tensors provided do not match the input buffers in the ", - "primfunc. Please check that the order of input te.Input_Tensors and the ", - "order of the primfunc variables in the params list agree.", - ) + assert len(tensor.shape) == len(buffer.shape) + for d1, d2 in zip(tensor.shape, buffer.shape): + assert d1 == d2, ( + "The input input_tensors provided do not match the input buffers in the ", + "primfunc. Please check that the order of input te.Input_Tensors and the ", + "order of the primfunc variables in the params list agree.", + ) output = extern( [buf.shape for buf in outputs], input_tensors, diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py index f06376147b9a..6c57e27b82e3 100644 --- a/python/tvm/tir/function.py +++ b/python/tvm/tir/function.py @@ -394,7 +394,7 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] = raise TypeError( "Expected mapping function to return list of " "either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. " - "Instead received {val} of type {type(val)}." + f"Instead received {val} of type {type(val)}." ) return IndexMap(initial_indices, final_indices), axis_separators