🐛 [Bug] Encountered bug when using Torch-TensorRT #2328

Closed
balazon opened this issue Sep 20, 2023 · 1 comment · Fixed by #2193
Assignees: gs-olive
Labels: bug (Something isn't working), component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths)

Comments


balazon commented Sep 20, 2023

Bug Description

This might not be a bug; it may be a feature request instead, I'm not sure. I wanted to compile a module containing torch.einsum with torch_tensorrt, and I get an error back.

I was reading this tutorial about compiling transformers:
https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/torch_compile_transformers_example.html
Based on this I created a small example module containing einsum, and I get this error:

torch._dynamo.exc.BackendCompilerFailed: backend='torch_tensorrt' raised:
RuntimeError: Autograd has not been implemented for operator

While executing %einsum_1 : [num_users=1] = call_function[target=torch.ops.tensorrt.einsum](args = (i,ij->i, (%l_x_, %l__self___b)), kwargs = {})

torch.einsum is either not a supported op yet, or, if it is, I think it's buggy.
It's not listed among the supported ops here:
https://github.com/pytorch/TensorRT/blob/8ebb5991f8bc46fea6179593b882d5c160bc1a53/docs/_sources/indices/supported_ops.rst.txt

TensorRT itself supports it according to this (IEinsumLayer):
https://docs.nvidia.com/deeplearning/tensorrt/archives/tensorrt-861/operators/index.html
So I don't see why it wouldn't be supported in torch-tensorrt.

I see some issues/PRs that relate to einsum, but I don't know whether they cover this case. The closest issue I found is
#277
but it was closed due to inactivity.
Other related issues/PRs:
#1385
#1985
#1420
#1005

To Reproduce

Steps to reproduce the behavior:

  1. Run this script as-is (compile option 1):
import torch
import torch_tensorrt

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.B = torch.nn.Parameter(torch.tensor([
            [0, 1, 2, 3],
            [4, 5, 6, 7],
            [8, 9, 10, 11]
        ], dtype=torch.float32)).cuda()
    
    def forward(self, x):
        return torch.einsum("i,ij->i", x, self.B)

def compile_my_model():
    model = MyModule().eval()

    a = torch.tensor([0, 1, 2], dtype=torch.float32).cuda()
    inputs = [a]


    # Enabled precision for TensorRT optimization
    enabled_precisions = {torch.float}

    # Whether to print verbose logs
    debug = True

    # Workspace size for TensorRT
    workspace_size = 8 << 30

    # Minimum number of operators required per TRT engine block
    # (higher values force more operations to run in Torch)
    min_block_size = 7

    # Operations to Run in Torch, regardless of converter support
    torch_executed_ops = {}

    # Define backend compilation keyword arguments
    compilation_kwargs = {
        "enabled_precisions": enabled_precisions,
        "debug": debug,
        "workspace_size": workspace_size,
        "min_block_size": min_block_size,
        "torch_executed_ops": torch_executed_ops,
    }

    # compile option 1
    optimized_model = torch.compile(
        model,
        backend="torch_tensorrt",
        options=compilation_kwargs,
    )
    
    # compile option 2: is it the same as option 1? still fails
    # optimized_model = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs, **compilation_kwargs)

    # compile option 3: success
    # optimized_model = torch_tensorrt.compile(model, inputs=inputs, **compilation_kwargs)

    res = optimized_model(*inputs)
    print("res:", res)

    torch._dynamo.reset()

if __name__ == "__main__":
    compile_my_model()
    print("done")

  2. Run it again with compile option 2 (uncomment it and comment out option 1)
  3. Run it again with compile option 3

Only compile option 3 works, but I don't know what the difference between these 3 options is. Can somebody clear that up? I think options 1 and 2 are the same, but what about option 3?

Expected behavior

I expect all 3 options to work, but only the 3rd compile option seems to work.
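
For reference, here is a minimal eager-mode check of the same contraction (an added illustration, not part of the original report), which shows the value the compiled module should reproduce: "i,ij->i" computes out[i] = x[i] * sum_j B[i, j].

import torch

# Hypothetical eager-mode reference for MyModule.forward (CPU tensors for simplicity)
x = torch.tensor([0.0, 1.0, 2.0])
B = torch.tensor([
    [0.0, 1.0, 2.0, 3.0],
    [4.0, 5.0, 6.0, 7.0],
    [8.0, 9.0, 10.0, 11.0],
])
print(torch.einsum("i,ij->i", x, B))  # tensor([ 0., 22., 76.])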

Environment

I'm using this docker image:
nvcr.io/nvidia/pytorch:23.08-py3
https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-23-08.html

  • Torch-TensorRT Version: 2.0.0.dev0
  • PyTorch Version: 2.1.0a0+29c30b1
  • Python version: 3.10
@balazon added the bug label Sep 20, 2023
@gs-olive self-assigned this Sep 20, 2023
@gs-olive added the component: dynamo label Sep 20, 2023
@gs-olive (Collaborator) commented

Hello - thanks for sharing this - I am able to reproduce the issue. A quick workaround for options 1 and 2 is to wrap the call to compile_my_model() in a with torch.no_grad(): block. We are in the process of fixing the part of the system which causes this error for the einsum converter, so the context wrapper should not be necessary for much longer.
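
For clarity, applied to the reproduction script above, the suggested workaround would look roughly like this (a sketch; only the call site changes):

# Workaround for compile options 1 and 2: run compilation and inference
# with autograd disabled, as described above.
if __name__ == "__main__":
    with torch.no_grad():
        compile_my_model()
    print("done")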

To answer the second question - yes, methods 1 and 2 are the same, and both use the torch.compile functionality with the torch_tensorrt backend to compile the model in a JIT fashion. Method 3 uses the default ir, which is now ir="dynamo". This uses the Torch Dynamo export path, an ahead-of-time export system for models which need to be serialized or otherwise exported. Currently, einsum is one of the very few operators that is converted differently in methods 1/2 vs. method 3.
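
To make the distinction concrete, the three calls from the reproduction script map onto these paths as follows (a sketch reusing model, inputs, and compilation_kwargs from the script; ir="dynamo" is spelled out explicitly for method 3, even though it is the default):

# Methods 1 and 2: JIT compilation through the torch.compile path
optimized_jit_1 = torch.compile(model, backend="torch_tensorrt", options=compilation_kwargs)
optimized_jit_2 = torch_tensorrt.compile(model, ir="torch_compile", inputs=inputs, **compilation_kwargs)

# Method 3: ahead-of-time compilation through the Torch Dynamo export path (the current default ir)
optimized_aot = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs, **compilation_kwargs)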
