myelin::ir::tensor_t*& myelin::ir::operand_t::tensor() #1541

Closed
quintetoy opened this issue Oct 8, 2021 · 20 comments
Labels
triaged Issue has been triaged by maintainers

Comments

@quintetoy

I have run into a problem with the PointPillars model (PillarScatter). When I convert the ONNX model to TensorRT, the following error occurs:
python: /root/gpgpu/MachineLearning/myelin/src/compiler/./ir/operand.h:166: myelin::ir::tensor_t*& myelin::ir::operand_t::tensor(): Assertion `is_tensor()' failed.

Can anyone tell me how to solve this problem?

My environment:
CUDA: 10.2
TensorRT: 8.2

@quintetoy
Author

The PointPillars scatter ops caused this problem. Has anyone else tried this?

@yasserkhalil93

I'm facing the same error. I used index_select and index_put in my PyTorch code, which exported to ONNX successfully. However, building the TensorRT engine gave me the following:

python: /root/gpgpu/MachineLearning/myelin/src/compiler/./ir/operand.h:166: myelin::ir::tensor_t*& myelin::ir::operand_t::tensor(): Assertion `is_tensor()' failed.
Aborted (core dumped)

Any help is appreciated.

My environment:
CUDA: 10.2
TensorRT: 8.2
ONNX: 1.9.0
PyTorch: 1.5.1
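For context, an index or slice assignment such as `index_put` is exported to ONNX as a scatter; with opset 11 I believe PyTorch emits the `ScatterND` operator, although this thread does not confirm the exact op. The NumPy sketch below follows the reference semantics of ONNX `ScatterND` (default `reduction='none'`). `scatter_nd` is a hypothetical helper name, and the example mirrors a `t[0] = x[0]` assignment rather than the exact graph PyTorch produces.

```python
import numpy as np


def scatter_nd(data: np.ndarray, indices: np.ndarray, updates: np.ndarray) -> np.ndarray:
    """Reference semantics of ONNX ScatterND with reduction='none'."""
    output = np.copy(data)
    # The last axis of `indices` is a (possibly partial) index into `data`;
    # the leading axes enumerate the individual updates.
    for idx in np.ndindex(indices.shape[:-1]):
        output[tuple(indices[idx])] = updates[idx]
    return output


# Mirror of `t = zeros_like(x); t[0] = x[0]` expressed as one ScatterND.
x = np.arange(6, dtype=np.float32).reshape(2, 3)
t = np.zeros_like(x)
out = scatter_nd(t, indices=np.array([[0]]), updates=x[:1])
print(out)
```

The input `data` is copied, so `t` itself stays all zeros, matching the out-of-place nature of the ONNX operator.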

@ttyio
Collaborator

ttyio commented Nov 8, 2021

Hello @quintetoy @yasserkhalil93, this sounds like a bug in TensorRT. Could you share the ONNX model with us so we can debug it? Thanks!

@ttyio ttyio added Release: 8.x triaged Issue has been triaged by maintainers labels Nov 8, 2021
@Keysmis

Keysmis commented Dec 5, 2021

@ttyio @yasserkhalil93 Hello, any update? Has it been resolved?

@ttyio
Collaborator

ttyio commented Dec 6, 2021

Hello @Keysmis, we have not received a repro to debug the issue. Could you try the 8.2 release? Please send us a repro if it still fails. Thanks!

@KeyKy

KeyKy commented Dec 10, 2021

I also get this error when using TensorRT 8.2:

python: /root/gpgpu/MachineLearning/myelin/src/compiler/./ir/operand.h:166: myelin::ir::tensor_t*& myelin::ir::operand_t::tensor(): Assertion `is_tensor()' failed.

My model works well with TensorRT 8.0.

@aeoleader

I am having the same issue with version 8.2.2.1.

@ttyio
Collaborator

ttyio commented Jan 7, 2022

@KeyKy @aeoleader, could you provide us with a repro to debug the issue? Thanks!

@aeoleader

aeoleader commented Jan 11, 2022

@ttyio
I have done a little bit of testing. It seems to be a problem with slice assignment.

Repro steps:

1. Convert the model to ONNX:

```python
import torch


def test_func(x):
    t = torch.zeros(x.shape)

    # This slice assignment triggers the is_tensor() assertion
    t[0] = x[0]
    return t


class MyModule(torch.nn.Module):
    def forward(self, x):
        return test_func(x)


if __name__ == '__main__':
    model = MyModule().eval().to('cuda:0')
    input_shape = (1, 16, 4, 224, 224)
    input_tensor = torch.randn(input_shape, device='cuda:0')
    output_file = 'debug.onnx'
    torch.onnx.export(
        model,
        input_tensor,
        output_file,
        export_params=True,
        keep_initializers_as_inputs=True,
        verbose=False,
        opset_version=11,
    )
```

2. Compile it with trtexec using the following command:

```
trtexec --onnx=debug.onnx --saveEngine=debug.trt --best --buildOnly --workspace=4096 --verbose
```

@ttyio
Collaborator

ttyio commented Jan 13, 2022

Thank you @aeoleader for the repro. I was able to reproduce the failure and have created an internal bug to track it.

@aeoleader

@ttyio
Can you try this as well? Constant padding seems to be broken too.

@ttyio
Collaborator

ttyio commented Jan 25, 2022

Hi @aeoleader, this `is_tensor()` assertion failure will be fixed in the next major release. Thanks!

@pycoco

pycoco commented Mar 3, 2022

Is it fixed now? Which TRT version should I use?

@pycoco

pycoco commented Mar 3, 2022

I hit this error:
Timing Runner: {ForeignNode[178...DequantizeLinear_40_dequantize_scale_node]} (Myelin)
trtexec: /root/gpgpu/MachineLearning/myelin/src/compiler/./ir/operand.h:166: myelin::ir::tensor_t*& myelin::ir::operand_t::tensor(): Assertion `is_tensor()' failed.

@ttyio
Collaborator

ttyio commented Mar 31, 2022

This will be fixed in 8.4 GA. Thanks, all!

@padeler

padeler commented May 11, 2022

Just my two cents: if you have to use TensorRT 8.2, you can replace your "slice assignment" operations (which are implemented with a scatter) with a concat. This is a workaround until 8.4 is out.
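A sketch of that workaround applied to the repro's `t[0] = x[0]` pattern, assuming PyTorch. `assign_row_via_cat` is a hypothetical helper that rebuilds the tensor with `torch.cat` instead of writing into it in place, so I would expect the exporter to emit slicing and `Concat` nodes rather than a scatter; I have not verified the resulting engine on TensorRT 8.2. `torch.zeros_like` is used here in place of the repro's `torch.zeros(x.shape)`.

```python
import torch


def assign_row_via_cat(t: torch.Tensor, value: torch.Tensor, i: int) -> torch.Tensor:
    """Out-of-place equivalent of `t[i] = value` (0 <= i < t.shape[0]) built with torch.cat."""
    # Keep rows before i, insert `value` as a new leading-dim row, keep rows after i.
    return torch.cat([t[:i], value.unsqueeze(0), t[i + 1:]], dim=0)


class SliceAssign(torch.nn.Module):
    def forward(self, x):
        t = torch.zeros_like(x)
        t[0] = x[0]  # in-place slice assignment, exported as a scatter
        return t


class ConcatAssign(torch.nn.Module):
    def forward(self, x):
        return assign_row_via_cat(torch.zeros_like(x), x[0], 0)


x = torch.randn(2, 16, 4, 8, 8)
same = torch.equal(SliceAssign()(x), ConcatAssign()(x))
print(same)
```

The trade-off is an extra copy of the untouched rows, which is usually negligible next to avoiding an unsupported scatter path.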

@wenqibiao

With TensorRT 8.2 I hit the same error, but with TensorRT 8.4 it works well.

@wenqibiao

> This will be fixed in 8.4GA, thanks all

Why NOT fix it in 8.2???

@Yebi1837
Copy link

I am facing the same error when trying to convert the torch.index_add_() operator. Is there a workaround for this in v8.2?
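One untested idea for v8.2, in the same spirit as the concat workaround suggested earlier in the thread: `index_add` along dim 0 can be rewritten as a dense matrix product with a 0/1 selection matrix, which avoids scatter ops entirely. The sketch assumes 2-D `x`/`source` and an int64 `index`; `index_add_via_matmul` is a hypothetical helper, it costs O(N·K) extra memory for N rows and K indices, and I have not checked that TensorRT 8.2 builds the exported graph.

```python
import torch


def index_add_via_matmul(x: torch.Tensor, index: torch.Tensor, source: torch.Tensor) -> torch.Tensor:
    """Out-of-place equivalent of x.index_add(0, index, source) for 2-D x and source."""
    rows = torch.arange(x.shape[0], device=x.device)
    # selector[n, k] == 1 where index[k] == n, so duplicate indices accumulate.
    selector = (rows.unsqueeze(1) == index.unsqueeze(0)).to(source.dtype)
    return x + selector @ source


x = torch.randn(5, 3)
index = torch.tensor([0, 2, 2, 4])
source = torch.randn(4, 3)
reference = x.index_add(0, index, source)
print(torch.allclose(reference, index_add_via_matmul(x, index, source)))
```

The comparison-plus-matmul form is chosen over `F.one_hot` so the exported graph uses only elementwise comparison, a cast and `MatMul`.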

@nvpohanh
Collaborator

> Why NOT fix it in 8.2???

We don't have the bandwidth to back-port every bug fix to older releases, because we are busy getting new features (like ND shape tensor support, better dynamic shape support, more ONNX op support, etc.) and new bug fixes into the latest TRT release as soon as possible.

TRT 8.4.1 (8.4 GA) has just been released. Closing this issue for now; please feel free to reopen it if the issue still exists. Thanks!
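Since the fix shipped in 8.4.1 (8.4 GA), a build script could check the installed version before relying on scatter-based graphs. A minimal sketch: `parse_trt_version` and `includes_is_tensor_fix` are hypothetical helpers that compare the first three numeric components of a version string such as `tensorrt.__version__`.

```python
def parse_trt_version(version: str) -> tuple:
    """Turn '8.4.1.5' into (8, 4, 1), padding short strings with zeros."""
    nums = tuple(int(part) for part in version.split(".")[:3])
    return nums + (0,) * (3 - len(nums))


def includes_is_tensor_fix(version: str) -> bool:
    # Threshold taken from this thread: fixed in TRT 8.4.1 (8.4 GA).
    return parse_trt_version(version) >= (8, 4, 1)


for v in ("8.2.2.1", "8.4.1.5"):
    print(v, includes_is_tensor_fix(v))
```

In a real script you would pass `tensorrt.__version__` after `import tensorrt`.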
