[Bug][Relay] cannot squeeze axis with dimension not equal to 1 at Keras frontend #14869

Open
jikechao opened this issue May 17, 2023 · 5 comments
Labels
needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it type: bug

Comments

@jikechao (Contributor) commented May 17, 2023:

For the LSTM model below, when batch_size != 1 (i.e., the size of the first input dimension), compilation crashes unexpectedly and throws Check failed: *axis_ptr == 1 (2 vs. 1) : cannot squeeze axis with dimension not equal to 1.
Note that when batch_size == 1, TVM runs fine.

Question:

For the LSTM model, why does batch_size != 1 lead to a crash?

Expected behavior

The model compiles successfully.

Actual behavior

Traceback (most recent call last):
  File "test.py", line 18, in <module>
    model = relay.build_module.create_executor("vm", mod, tvm.cpu(0), 'llvm', params).evaluate()
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/interpreter.py", line 171, in evaluate
    return self._make_executor()
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 219, in _make_executor
    self.executable = compile(self.mod, self.target)
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 67, in compile
    compiler.lower(mod, target, target_host)
  File "/workplace/software/tvm/tvm/python/tvm/relay/backend/vm.py", line 126, in lower
    self._lower(mod, raw_targets)
  File "/workplace/software/tvm/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  15: TVMFuncCall
  14: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::vm::VMCompiler::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  13: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&)
  12: tvm::relay::vm::VMCompiler::LowerImpl(tvm::IRModule)
  11: tvm::relay::vm::VMCompiler::OptimizeModuleImpl(tvm::IRModule)
  10: tvm::transform::Pass::operator()(tvm::IRModule) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  6: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  2: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  1: tvm::relay::TypeSolver::Solve()
  0: _ZN3tvm7runtime6detail
  19: TVMFuncCall
  18: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::vm::VMCompiler::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::$_0> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  17: tvm::relay::vm::VMCompiler::Lower(tvm::IRModule, tvm::runtime::Array<tvm::Target, void> const&)
  16: tvm::relay::vm::VMCompiler::LowerImpl(tvm::IRModule)
  15: tvm::relay::vm::VMCompiler::OptimizeModuleImpl(tvm::IRModule)
  14: tvm::transform::Pass::operator()(tvm::IRModule) const
  13: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  12: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  11: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  10: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  9: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  8: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  7: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::$_2>(tvm::relay::transform::InferType()::$_2)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  6: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  5: tvm::relay::TypeSolver::Solve()
  4: tvm::TypedEnvFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::operator()(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&) const
  3: _ZN3tvm7runtime13Pac
  2: tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*) const
  1: tvm::relay::SqueezeRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  0: _ZN3tvm7runtime6detail
  File "/workplace/software/tvm/tvm/src/relay/analysis/type_solver.cc", line 643
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: [07:40:29] /workplace/software/tvm/tvm/src/relay/op/tensor/transform.cc:2340: 
---------------------------------------------------------------

  Check failed: *axis_ptr == 1 (2 vs. 1) : cannot squeeze axis with dimension not equal to 1

Steps to reproduce

import tvm
import tvm.relay as relay
from tensorflow import keras
from tensorflow.keras import layers, models

# The first dimension is the batch size: 2 triggers the crash, 1 compiles fine.
input_shape = (2, 3, 4)
x = layers.Input(shape=input_shape[1:], dtype='float32')

layer = keras.layers.LSTM(units=2)
layer.set_weights(layer.get_weights())

y = layer(x)
model = models.Model(x, y)
model.summary()

# Import the Keras model into Relay with the full batched input shape.
mod, params = relay.frontend.from_keras(model, {'input_1': input_shape})
print(mod)
with tvm.transform.PassContext(opt_level=3):
    model = relay.build_module.create_executor("vm", mod, tvm.cpu(0), 'llvm', params).evaluate()
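
The failure can be reduced to a minimal Relay program that is independent of Keras (a sketch only, assuming the frontend emits a squeeze over an axis whose extent equals the batch size): squeeze only accepts axes with extent 1, so type inference rejects an axis of extent 2.

import tvm
from tvm import relay

# Minimal sketch of the type-inference failure (hypothetical, not frontend code):
# squeeze requires every squeezed axis to have extent 1.
x = relay.var("x", shape=(2, 2), dtype="float32")  # leading axis has extent 2
y = relay.squeeze(x, axis=[0])                     # axis 0 has extent 2, not 1
mod = tvm.IRModule.from_expr(relay.Function([x], y))
# Raises: cannot squeeze axis with dimension not equal to 1
mod = relay.transform.InferType()(mod)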

Triage

  • frontend:keras

cc @shingjan

@jikechao jikechao added needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it type: bug labels May 17, 2023
@jikechao (Contributor, Author) commented:

This is a bug in the LSTM layer conversion. I'm not very familiar with RNN networks. For the LSTM model, why does batch_size != 1 lead to a crash in TVM?

@echuraev Could you give me some suggestions? Thank you in advance.

@echuraev (Contributor) commented:

@jikechao I suppose that the problem is the same as in #14868. In LSTM we have the same logic here as for Simple RNN.

@jikechao (Contributor, Author) commented:

@echuraev Thanks for your explanation. I'm willing to fix this bug.

@echuraev (Contributor) commented Feb 5, 2024:

@jikechao I quickly took a look at this issue, and it looks like one possible solution might be to use reshape instead of squeeze in the LSTM layer implementation. I did the same trick in this PR: #16526
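
To illustrate the suggested trick (a sketch only, not the actual frontend change from #16526; the shapes and variable names are assumptions): reshape with a -1 target can drop a unit dimension the way squeeze does, while still type-checking when the leading batch dimension is greater than 1.

from tvm import relay

# Hypothetical illustration of squeeze vs. reshape on a sliced LSTM timestep.
batch, units = 2, 2
step = relay.var("step", shape=(batch, 1, units), dtype="float32")

# squeeze(axis=[1]) type-checks here because that axis really has extent 1,
# but squeezing an axis whose extent is the batch size would fail.
squeezed = relay.squeeze(step, axis=[1])

# reshape works for any batch size: -1 absorbs the batch dimension.
reshaped = relay.reshape(step, newshape=(-1, units))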

@jikechao (Contributor, Author) commented:

> @jikechao I quickly took a look at this issue, and it looks like one possible solution might be to use reshape instead of squeeze in the LSTM layer implementation. I did the same trick in this PR: #16526

@echuraev, thank you! I submitted a PR to fix it. Could you help review it?
