Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[Fix] Fix errors in error checking and reporting (apache#12423)
Browse files Browse the repository at this point in the history
  • Loading branch information
cyx-6 authored and xinetzone committed Nov 25, 2022
1 parent c95c13a commit 68ec456
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
14 changes: 8 additions & 6 deletions python/tvm/te/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from typing import List

import tvm._ffi
import tvm.arith._ffi_api
import tvm.tir
import tvm.tir._ffi_api
import tvm.arith._ffi_api
from tvm._ffi.base import string_types
from tvm.ir import Array
from tvm.runtime import convert
Expand Down Expand Up @@ -420,11 +420,13 @@ def before_split(a: T.handle, b: T.handle) -> None:
)
for tensor, buffer in zip(input_tensors, input_buffers):
# TODO(csullivan): Can a stronger comparison between Tensor<>Buffer be made?
assert tensor.shape == buffer.shape, (
"The input input_tensors provided do not match the input buffers in the ",
"primfunc. Please check that the order of input te.Input_Tensors and the ",
"order of the primfunc variables in the params list agree.",
)
assert len(tensor.shape) == len(buffer.shape)
for d1, d2 in zip(tensor.shape, buffer.shape):
assert d1 == d2, (
"The input input_tensors provided do not match the input buffers in the ",
"primfunc. Please check that the order of input te.Input_Tensors and the ",
"order of the primfunc variables in the params list agree.",
)
output = extern(
[buf.shape for buf in outputs],
input_tensors,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def from_func_with_separators(mapping_function: Callable, ndim: Optional[int] =
raise TypeError(
"Expected mapping function to return list of "
"either tvm.ir.PrimExpr or IndexMap.AXIS_SEPARATOR. "
"Instead received {val} of type {type(val)}."
f"Instead received {val} of type {type(val)}."
)

return IndexMap(initial_indices, final_indices), axis_separators
Expand Down

0 comments on commit 68ec456

Please sign in to comment.