
Converting Multi-input model #251

Closed
vgundecha opened this issue Feb 8, 2020 · 8 comments

@vgundecha

I'm trying to convert a multi-input model to TensorRT. The conversion succeeds, but I get the following error during inference:

[TensorRT] ERROR: INVALID_ARGUMENT: Can not find binding of given name
Traceback (most recent call last):
  File "model2trt.py", line 97, in <module>
    y_trt = model_trt(x, masks[i])
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 545, in __call__
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch2trt/torch2trt.py", line 332, in forward
    for i, input_name in enumerate(self.input_names):
TypeError: execute_async(): incompatible function arguments. The following argument types are supported:
    1. (self: tensorrt.tensorrt.IExecutionContext, batch_size: int = 1, bindings: List[int], stream_handle: int, input_consumed: capsule = None) -> bool

Invoked with: <tensorrt.tensorrt.IExecutionContext object at 0x7f9d6a369960>, 1, [140313058344960, None, 140313104482304], 0

I believe the error is at

idx = self.engine.get_binding_index(input_name)

This line returns -1 for both inputs. I checked it with pdb.
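Roughly, the check looks like this (a minimal sketch using the standard TensorRT Python API, run from the pdb prompt inside TRTModule.forward; the second loop just prints what the engine actually named its bindings):

# From pdb at the failing line in torch2trt.py:
for input_name in self.input_names:
    print(input_name, self.engine.get_binding_index(input_name))   # prints -1 for both inputs
# List the binding names the engine actually has:
for i in range(self.engine.num_bindings):
    print(i, self.engine.get_binding_name(i), self.engine.binding_is_input(i))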

I convert the model using:
model_trt = torch2trt.torch2trt(model, [x, masks[0]])

And I run inference as follows:
y_trt = model_trt(x, masks[i])

This is my forward pass:

def forward(self, input, current_mask):
    # do stuff with input and current_mask
    return output

Thank you!

@leixiaoning

I got the same problem. The conversion runs successfully with
model_trt = torch2trt(model, inputs=[dummy_input[0], dummy_input[1]])
but inference with
y_trt = model_trt(dummy_input[0], dummy_input[1])
fails with "Cannot find binding of given name: input_1".

@ma-siddiqui

Any solution?

@luhang-HPU

I met the same issue.

@GabbySuwichaya

GabbySuwichaya commented Jun 21, 2020

I have the same problem, although in my case there is actually only one input.

The solution I used was to set input_names in torch2trt.torch2trt to the same name the binding has in self.engine:

model_trt = torch2trt.torch2trt(model, [x], input_names=["<name reported by the engine>"])

The following procedure works for my case.

First (I think you all know this), move the Python script that calls torch2trt.torch2trt into the /torch2trt folder.

Next, check what name the binding has in self.engine around line 364 of torch2trt.py. To do this, import pdb at the top of torch2trt.py for debugging:

import pdb

Then check the name with self.engine.get_binding_name(). For example:

Line 364:  idx = self.engine.get_binding_index(input_name)
Line 365:  pdb.set_trace()
Line 366:  self.engine.get_binding_name()    # read the name that comes back
Line 367:  pdb.set_trace()
Line 368:  bindings[idx] = inputs[i].data_ptr()

In my case the engine binding is named "__img".
Then revert everything: remove the pdb.set_trace() and self.engine.get_binding_name() lines you just added, since they were only there to check the name.

Once you know the name ("__img"), go back to your torch2trt.torch2trt call and change input_names to "__img":

model_trt = torch2trt.torch2trt(model, [x],  input_names=["__img"])

Then you should be able to run it...
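If you prefer not to edit torch2trt.py at all, the same check can be done from outside (a sketch; it relies on the converted TRTModule keeping the built engine as model_trt.engine, which the traceback at the top of this thread shows it does):

# Sketch: read the input binding names straight from the converted module,
# then pass them back into torch2trt via input_names.
engine = model_trt.engine
input_names = [engine.get_binding_name(i) for i in range(engine.num_bindings)
               if engine.binding_is_input(i)]
print(input_names)   # e.g. ['__img'] -- illustrative output only
model_trt = torch2trt.torch2trt(model, [x], input_names=input_names)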

@ahangchen

@GabbySuwichaya
When I call get_binding_name, I get the following (running with TensorRT 7.1.3 on a Xavier NX):

TypeError: get_binding_name(): incompatible function arguments. The following argument types are supported:
1. (self: tensorrt.tensorrt.ICudaEngine, index: int) -> unicode

Invoked with: <tensorrt.tensorrt.ICudaEngine object at 0x7eed91a6f0>

@GabbySuwichaya

Hi @ahangchen, I am not sure about the Jetson Xavier NX.
But if you can debug that line, try dir(self) to see which attributes and methods are available.
The goal is simply to find the binding name; in my case, get_binding_name was the call that gave it.

@ahangchen

ahangchen commented Sep 5, 2020

@GabbySuwichaya I found that the API changed in TensorRT 7: we should use engine.get_binding_name(index), where index is the i-th binding. But the number of input bindings drops from 4 to 2: the network has 4 inputs while the engine reports only 2, so I cannot convert my multi-input network to a TRT model. For now I can only merge the inputs into a single input as a workaround. I think torch2trt has not been tested on any multi-input examples, which leads to this inconvenient situation.
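For reference, the index-based call on TensorRT 7 can be used like this to see how many inputs actually survived into the engine (a sketch; engine here is the ICudaEngine held by the converted TRTModule):

engine = model_trt.engine
n_inputs = sum(engine.binding_is_input(i) for i in range(engine.num_bindings))
print(n_inputs, engine.num_bindings - n_inputs)   # input bindings vs output bindings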

@GabbySuwichaya

@ahangchen, I am sorry to hear that.
I have never tried multiple inputs myself; I posted my answer here because my problem was also about the given input name, which is similar to the original post.

Have you tried something else, like ONNX to TRT? Check it out here >> https://github.com/onnx/onnx-tensorrt
That route seems to work for many people, but unfortunately I have not tried it yet because of a mismatched installation.
