
Batch size not changing #1904

Closed
montmejat opened this issue Apr 5, 2022 · 5 comments
Labels
ONNX (Issues relating to ONNX usage and import) · Performance (General performance issues) · triaged (Issue has been triaged by maintainers)

Comments

@montmejat

Description

I'm trying to convert a RetinaNet model taken from torchvision, but I'm unable to use it with a batch size higher than 1. With batch_size set to 2, here is what my PyTorch output looks like:

example = torch.randn((batch_size, 3, 1024, 1024), dtype=torch.float32, device='cuda')
output = retinanet(example)

Output:

{
    'split_head_outputs': {
         'cls_logits': [torch.Size([2, 147456, 2]), torch.Size([2, 36864, 2]), torch.Size([2, 9216, 2]), torch.Size([2, 2304, 2]), torch.Size([2, 576, 2])],
         'bbox_regression': [torch.Size([2, 147456, 4]), torch.Size([2, 36864, 4]), torch.Size([2, 9216, 4]), torch.Size([2, 2304, 4]), torch.Size([2, 576, 4])]
    },
    'split_anchors': [
        [torch.Size([147456, 4]), torch.Size([36864, 4]), torch.Size([9216, 4]), torch.Size([2304, 4]), torch.Size([576, 4])],
        [torch.Size([147456, 4]), torch.Size([36864, 4]), torch.Size([9216, 4]), torch.Size([2304, 4]), torch.Size([576, 4])]
    ]
}

With the TensorRT engine:

# ... 

builder = trt.Builder(TRT_LOGGER)
builder.max_batch_size = batch_size

network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH) | 1 << int(trt.BuilderFlag.FP16))
parser = trt.OnnxParser(network, TRT_LOGGER)

success = parser.parse_from_file(onnx_model_path)
for idx in range(parser.num_errors):
    print(parser.get_error(idx))

if not success:
    raise RuntimeError('Failed to parse the ONNX file')

config = builder.create_builder_config()
config.max_workspace_size = 2 << 20

# ... 

context.set_binding_shape(0, (batch_size, 3, tile_size, tile_size))

# ...

output_tensors = []
for binding in engine:
    if engine.binding_is_input(binding):
        continue

    tensor = torch.empty(
        [i for i in engine.get_binding_shape(binding)],
        dtype=torch_dtypes[engine.get_binding_dtype(binding)],
        device='cuda'
    )
    buffers[engine[binding]] = tensor.data_ptr()

    output_tensors.append(tensor)
    print(tensor.size())

print(f'Engine loaded. Buffers size: {len(buffers)}.')

Output:

torch.Size([1, 147456, 2])
torch.Size([1, 36864, 2])
torch.Size([1, 9216, 2])
torch.Size([1, 2304, 2])
torch.Size([1, 576, 2])
torch.Size([1, 147456, 4])
torch.Size([1, 36864, 4])
torch.Size([1, 9216, 4])
torch.Size([1, 2304, 4])
torch.Size([1, 576, 4])
torch.Size([147456, 4])
torch.Size([36864, 4])
torch.Size([9216, 4])
torch.Size([2304, 4])
torch.Size([576, 4])
Engine loaded. Buffers size: 16. # I should have 21 here! These buffers are sized for a batch size of 1.

I'm also getting worse performance with TensorRT than with typical PyTorch inference:

# Getting about 0.25 seconds per inference
with torch.no_grad():
    out_batch += model(torch.stack(batch))
# Getting about 0.36 seconds per inference
buffers[input_idx] = torch.stack(batch).data_ptr()
torch_stream = torch.cuda.Stream()
context.execute_async_v2(buffers, torch_stream.cuda_stream)
torch.cuda.synchronize()

Any ideas? 😄

Environment

TensorRT Version: 8.2.1.8
NVIDIA GPU: Jetson Xavier NX
NVIDIA Driver Version:
CUDA Version: 10.2.300
CUDNN Version: 8.2.1.32
Operating System: Jetpack L4T 32.7.1
Python Version (if applicable): 3.6.9
PyTorch Version (if applicable): 1.10.0

@ttyio
Collaborator

ttyio commented Apr 7, 2022

Hello @aurelien-m ,
For the batch size problem, I don't fully understand the question here. Are you confused about the batch size setting, or about the buffer count? For the batch size, we need to export the ONNX with a dynamic batch dimension to use TRT with different batch sizes. For the buffer count, you can visualize your ONNX model with a tool like Netron.

For the performance issue, according to the PyTorch documentation we need to insert CUDA events to measure performance; otherwise we only measure the API launch latency, not the GPU workload. See https://pytorch.org/docs/stable/notes/cuda.html#asynchronous-execution
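
For reference, a minimal sketch of what event-based timing around the enqueue could look like (this reuses the context, buffers, and torch_stream names from the snippets above and is only an illustration, not code from this issue):

import torch

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record(torch_stream)
context.execute_async_v2(buffers, torch_stream.cuda_stream)
end.record(torch_stream)
torch.cuda.synchronize()  # wait until both recorded events have completed

print(f'Inference took {start.elapsed_time(end):.1f} ms')  # GPU time, not just launch latency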

Thanks!

@ttyio ttyio added the ONNX, Performance, Framework: PyTorch, and triaged labels on Apr 7, 2022
@montmejat
Author

Hey, thanks for trying to help me out! I think I was timing it incorrectly, and now it's indeed much better. I was also using the wrong batch size when initially converting to an ONNX model, my bad about that too.

However, I tried using a dynamic batch size, but I'm having trouble getting it to work. I'm looking at this documentation.

I built my ONNX model using:

torch.onnx.export(
    retinanet,
    example, # (2, 3, 1024, 1024)
    onnx_model_path,
    verbose=False,
    opset_version=11,
    input_names=["input"],
    dynamic_axes={
        "input": {0: "batch_size"},
        "output": {0: "batch_size"}
    }
)
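
One possible gotcha: dynamic_axes keys refer to tensors by name, so the "output" entry only applies if an output is actually named output. A sketch of an export with explicitly named outputs might look like the following (the output names below are assumptions for illustration, not necessarily what this wrapped RetinaNet produces):

# Sketch only: output names are hypothetical.
torch.onnx.export(
    retinanet,
    example,                      # (2, 3, 1024, 1024)
    onnx_model_path,
    opset_version=11,
    input_names=["input"],
    output_names=["cls_logits", "bbox_regression"],  # assumed names
    dynamic_axes={
        "input": {0: "batch_size"},
        "cls_logits": {0: "batch_size"},
        "bbox_regression": {0: "batch_size"},
    },
)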

And I'm building my TensorRT engine like this:

onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)

builder = trt.Builder(TRT_LOGGER)
builder.max_batch_size = batch_size

network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)

network.add_input("input", trt.float32, (-1, 3, 1024, 1024)) # crashes here

success = parser.parse_from_file(onnx_model_path)
for idx in range(parser.num_errors):
    print(parser.get_error(idx))

if not success:
    raise RuntimeError('Failed to parse the ONNX file') # not able to read the model

config = builder.create_builder_config()
config.max_workspace_size = 1 << 30
config.flags = 1 << int(trt.BuilderFlag.FP16)
# config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 30)

profile = builder.create_optimization_profile()
profile.set_shape("input", (1, 3, 1024, 1024), (2, 3, 1024, 1024), (2, 3, 1024, 1024))
config.add_optimization_profile(profile)

serialized_engine = builder.build_serialized_network(network, config)
with open(f'{PROJECT_PATH}/models/{model_filename}.engine', 'wb') as f:
    f.write(serialized_engine)

But I'm getting:

[04/07/2022-15:58:10] [TRT] [E] [network.cpp::addInput::1507] Error Code 3: API Usage Error (Parameter check failed at: optimizer/api/network.cpp::addInput::1507, condition: inName != knownInput->getName()
[04/07/2022-15:58:10] [TRT] [E] ModelImporter.cpp:779: ERROR: input:248 In function importInput:
[8] Assertion failed: (*tensor = ctx->network()->addInput(input.name().c_str(), trtDtype, trt_dims)) && "Failed to add input to the network."
In node -1 (importInput): UNSUPPORTED_NODE: Assertion failed: (*tensor = ctx->network()->addInput(input.name().c_str(), trtDtype, trt_dims)) && "Failed to add input to the network."

RuntimeError: Failed to parse the ONNX file

I'm not sure what I missed here.

@ttyio
Collaborator

ttyio commented Apr 19, 2022

Hello @aurelien-m , sorry for the late response.
There is no need to call addInput, because there is already an input named "input", and we cannot add two inputs with the same name. You could take a look at https://github.com/NVIDIA/TensorRT/blob/main/samples/sampleOnnxMNIST/sampleOnnxMNIST.cpp#L131 for how to use the ONNX parser in C++. Thanks!
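
A rough Python sketch of that flow, reusing TRT_LOGGER and onnx_model_path from the snippets above and assuming the profile name matches the exported input; this is only an illustration, not the exact fix:

# Sketch: no addInput call; the ONNX parser creates the "input" tensor,
# and the optimization profile alone supplies the dynamic batch range.
builder = trt.Builder(TRT_LOGGER)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
parser = trt.OnnxParser(network, TRT_LOGGER)

if not parser.parse_from_file(onnx_model_path):
    for idx in range(parser.num_errors):
        print(parser.get_error(idx))
    raise RuntimeError('Failed to parse the ONNX file')

config = builder.create_builder_config()
profile = builder.create_optimization_profile()
profile.set_shape(network.get_input(0).name,  # "input", as exported
                  (1, 3, 1024, 1024), (2, 3, 1024, 1024), (2, 3, 1024, 1024))
config.add_optimization_profile(profile)

serialized_engine = builder.build_serialized_network(network, config)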

@montmejat
Author

Hey, thanks for the help. I'm still looking at this documentation because I want to implement a dynamic batch size. If I'm not wrong, the example you gave doesn't implement a dynamic batch size.

Now, I removed this line:

network.add_input("input", trt.float32, (-1, 3, 1024, 1024))

And I'm now able to build the engine successfully. I also tried building with: trtexec --onnx=models/MyModel.onnx --minShapes='input':1x3x1024x1024 --optShapes='input':2x3x1024x1024 --maxShapes='input':2x3x1024x1024 --fp16 --saveEngine=models/MyModel.engine

I'm using the following code to load my engine:

torch_stream = torch.cuda.Stream()
runtime = trt.Runtime(TRT_LOGGER)

with open(tensorrt_model_path, "rb") as f:
    serialized_engine = f.read()

engine = runtime.deserialize_cuda_engine(serialized_engine)
context = engine.create_execution_context()

batch_size = 1 # changing it here, 2 works fine (which is what I've used for the max and optimal shape)

if -1 in engine.get_binding_shape(input_index):
    context.set_optimization_profile_async(0, torch_stream.cuda_stream)
    context.set_binding_shape(0, (batch_size, 3, 1024, 1024))

buffers = ...
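
A minimal sketch of one way the elided buffer setup might look, sizing the output tensors from the execution context (so the batch dimension resolved by set_binding_shape is used rather than -1) and reusing the torch_dtypes mapping from the earlier snippet:

# Sketch only, based on the snippets above.
buffers = [None] * engine.num_bindings
output_tensors = []

for binding in engine:
    idx = engine[binding]
    if engine.binding_is_input(binding):
        continue  # the input pointer is filled in just before inference

    shape = tuple(context.get_binding_shape(idx))  # resolved shape, no -1 left
    tensor = torch.empty(
        shape,
        dtype=torch_dtypes[engine.get_binding_dtype(binding)],
        device='cuda',
    )
    buffers[idx] = tensor.data_ptr()
    output_tensors.append(tensor)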

But when running inference:

torch.cuda.synchronize()

...

buffers[input_index] = torch.stack(batch).data_ptr()
context.execute_async_v2(buffers, torch_stream.cuda_stream)

When using batch_size = 2 ✅, it works fine. But when I use batch_size = 1 💀, I get this error:

[12/10/2021-13:29:40] [TRT] [E] 7: [shapeMachine.cpp::execute::565] Error Code 7: Internal Error (Split_0_0: ISliceLayer has out of bounds access on axis 0
condition '<' violated
Instruction: CHECK_LESS 1 1
)
[12/10/2021-13:29:40] [TRT] [E] 2: [executionContext.cpp::enqueueInternal::366] Error Code 2: Internal Error (Could not resolve slots: )

Any ideas? 😄

@montmejat
Copy link
Author

montmejat commented Apr 21, 2022

After looking at my network, I think it's just because my network is not compatible with dynamic batch sizes... I'm using RetinaNet, and bypassing some layers breaks the dynamic batch size feature.

I'm closing this because it's not a problem with TensorRT but with the network and the ONNX graph.
