
[Bug] shape int32-int64 check error in trilu's te.compute #13029

Closed
ganler opened this issue Oct 11, 2022 · 0 comments · Fixed by #13123
Labels
needs-triage, type: bug


ganler commented Oct 11, 2022

Expected behavior

TVM should successfully compile a model whose operators are supported.

Actual behavior

The compilation could fail when the model contains the recently supported trilu operator.

In the Steps to reproduce section, the minimal reproducer is derived from an ONNX model exported by PyTorch, which uses int64 shape arguments. These mix with int32 constants in TVM's frontend translator, and the compilation fails with an int32-int64 mismatch in check_op:

check_position = check_op(row_index, col_index - k)

A quick fix would be to align the integer types of row_index and col_index - k before applying check_op.
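The mismatch and the proposed alignment can be modeled with a small NumPy sketch (illustrative only, not TVM's actual code: check_le here stands in for tir.LE's dtype check, and the variable values are made up):

```python
import numpy as np

def check_le(a, b):
    # Models tir.LE's strict dtype check: both operands must share one
    # dtype, otherwise TVM raises "mismatched types. int32 vs. int64".
    if a.dtype != b.dtype:
        raise TypeError(f"mismatched types. {a.dtype} vs. {b.dtype}")
    return bool(a <= b)

row_index = np.int32(2)   # te.compute loop indices default to int32
col_index = np.int32(1)
k = np.int64(0)           # k originates from the model's int64 scalar

# Unaligned: col_index - k promotes to int64, so the check fails.
try:
    check_le(row_index, col_index - k)
except TypeError as e:
    print(e)  # mismatched types. int32 vs. int64

# Proposed fix: cast k to the index dtype before subtracting, so both
# sides of the comparison are int32.
k_aligned = k.astype(row_index.dtype)
print(check_le(row_index, col_index - k_aligned))
```

In the real fix, the cast would use tvm.tir.Cast (or topi's cast helper) on k inside _apply_trilu, so the comparison operands share the index dtype.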

Environment

fa17da22c73fb9e95c27e4c28130835b628caf6b on Ubuntu 20.04.

Steps to reproduce

Minimal reproducer:

import tvm
from tvm import relay

x1 = relay.var("x1", shape=[2, 1], dtype="float32")
x2 = relay.var("x2", shape=(1, 1, 1, 1), dtype="float32")
x3 = relay.var("x3", shape=(), dtype="int64")
v0 = relay.broadcast_to(x1, shape=relay.const([2, 1], dtype="int64"))
v2 = relay.divide(x2, v0)
v3 = relay.trilu(v0, x3)

f = relay.Function([x1, x2, x3], relay.Tuple([v2, v3]))
relay.create_executor("graph", device=tvm.cpu(), target="llvm").evaluate(f)
Log:
"""
Traceback (most recent call last):
  File "test.py", line 12, in <module>
    relay.create_executor("graph", device=tvm.cpu(), target="llvm").evaluate(f)
 ...
  25: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  24: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
  23: _ZN3tvm5relay9
  22: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  21: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  20: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  19: tvm::NodeFunctor<tvm::RelayExpr (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const
  18: _ZZN3tvm5relay11ExprFunc
  17: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::TupleNode const*)
  16: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  15: tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  14: tvm::NodeFunctor<tvm::RelayExpr (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::RelayExpr (tvm::RelayExpr const&)>*) const
  13: _ZZN3tvm5relay11ExprFunc
  12: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  11: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
  10: tvm::relay::tec::TECompilerImpl::Lower(tvm::relay::tec::CCacheKey const&)
  9: tvm::relay::tec::TECompilerImpl::LowerInternal(tvm::relay::tec::CCacheKey const&, tvm::GlobalVarSupply)
  8: tvm::relay::tec::PrimFuncFor(tvm::relay::Function const&, tvm::Target const&, tvm::GlobalVarSupply)
  7: tvm::relay::tec::ScheduleBuilder::Create(tvm::relay::Function const&, tvm::GlobalVarSupply)
  6: tvm::relay::tec::LowerToTECompute::Lower(tvm::relay::Function const&)
  5: tvm::relay::backend::MemoizedExprTranslator<tvm::runtime::Array<tvm::te::Tensor, void> >::VisitExpr(tvm::RelayExpr const&)
  4: tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>::VisitExpr(tvm::RelayExpr const&)
  3: tvm::NodeFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*)>::operator()(tvm::runtime::ObjectRef const&, tvm::relay::ExprFunctor<tvm::runtime::Array<tvm::te::Tensor, void> (tvm::RelayExpr const&)>*) const
  2: _ZZN3tvm5relay11ExprFunc
  1: tvm::relay::tec::LowerToTECompute::VisitExpr_(tvm::relay::CallNode const*)
  0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::$_2> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/backend/te_compiler.py", line 317, in lower_call
    best_impl, outputs = select_implementation(op, call.attrs, inputs, ret_type, target)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/backend/te_compiler.py", line 207, in select_implementation
    outs = impl.compute(attrs, inputs, out_type)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/op/op.py", line 126, in compute
    return _OpImplementationCompute(self, attrs, inputs, out_type)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
  3: TVMFuncCall
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::$_3> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::relay::OpImplementation::Compute(tvm::Attrs const&, tvm::runtime::Array<tvm::te::Tensor, void> const&, tvm::Type const&)
  0: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<TVMFuncCreateFromCFunc::$_2> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 81, in cfun
    rv = local_pyfunc(*pyargs)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/relay/op/strategy/generic.py", line 1489, in _compute_trilu
    topi_compute(
  File "/home/jiawei/dev/tvm-official-release/python/tvm/topi/transform.py", line 1061, in trilu
    return te.compute(data.shape, _apply_trilu, name="trilu")
  File "/home/jiawei/dev/tvm-official-release/python/tvm/te/operation.py", line 132, in compute
    body = fcompute(*[v.var for v in dim_var])
  File "/home/jiawei/dev/tvm-official-release/python/tvm/topi/transform.py", line 1057, in _apply_trilu
    check_position = check_op(row_index, col_index - k)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/tir/expr.py", line 881, in __init__
    self.__init_handle_by_constructor__(_ffi_api.LE, a, b, span)  # type: ignore
  File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/object.py", line 145, in __init_handle_by_constructor__
    handle = __init_by_constructor__(fconstructor, args)
  File "/home/jiawei/dev/tvm-official-release/python/tvm/_ffi/_ctypes/packed_func.py", line 260, in __init_handle_by_constructor__
    raise get_last_ffi_error()
  2: TVMFuncCall
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::tir::LE (tvm::PrimExpr, tvm::PrimExpr, tvm::Span)>::AssignTypedLambda<tvm::tir::$_51>(tvm::tir::$_51, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::tir::LE::LE(tvm::PrimExpr, tvm::PrimExpr, tvm::Span)
  File "/home/jiawei/dev/tvm-official-release/src/tir/ir/expr.cc", line 459
TypeError: Check failed: (a.dtype() == b.dtype()) is false: mismatched types. int32 vs. int64
"""

Triage

Please refer to the list of label tags linked above to find the relevant tags and add them here in a bullet format (example below).

  • needs-triage

cc: @jwfromm

ganler added the needs-triage and type: bug labels on Oct 11, 2022