
[torch.compile] Request for shape ranges in torch.compile workflow #115137

Closed
peri044 opened this issue Dec 5, 2023 · 12 comments
Labels
module: dynamic shapes · oncall: pt2 · triaged

Comments

peri044 (Contributor) commented Dec 5, 2023

🚀 The feature, motivation and pitch

Torch-TensorRT has two workflows (torch.export-based and torch.compile-based) for optimizing PyTorch models using TensorRT.

  • The torch.export.export() flow can accept dynamic shapes with min and max ranges specified for any input dimension. This range information is accessible for intermediate nodes in the graph via node.shape_env.var_to_range. TensorRT uses this range information when building engines with dynamic shapes.

  • torch.compile provides torch._dynamo.mark_dynamic(tensor, dim), which is great. But can this be extended to accept ranges (or something similar to the torch.export.Dim(name, min, max) API)?

Ultimately, the ask here is for the graph (provided by torch.compile to its backend) to have the range information (similar to torch.export.export()).
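
For illustration, here is a minimal sketch of the two APIs side by side (MyModule stands for any example module, such as the one defined later in this thread; nothing here is proposed new API):

import torch

model = MyModule().eval().cuda()
x = torch.randn((2, 3, 224, 224), device="cuda")

# torch.export: min/max ranges are first-class via torch.export.Dim
batch = torch.export.Dim("batch", min=2, max=8)
ep = torch.export.export(model, (x,), dynamic_shapes={"x": {0: batch}})

# torch.compile: a dimension can be marked dynamic, but there is currently
# no argument for attaching a (min, max) range to it
torch._dynamo.mark_dynamic(x, 0)
compiled = torch.compile(model, backend="tensorrt")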

Please let me know if there's currently a way to do this, or if you have any questions.


cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

ezyang (Contributor) commented Dec 31, 2023

The request here is reasonable, though I will note that if you do an assertion on the size inside the model, that will also appropriately update value ranges.

peri044 (Author) commented Jan 9, 2024

> if you do an assertion on the size inside the model, that will also appropriately update value ranges

@ezyang Thanks for your response. Can you elaborate on the assertion part and how I can do this?

ezyang (Contributor) commented Jan 9, 2024

Literally assert input.size(1) <= 256 or something

anijain2305 added the triaged label Jan 12, 2024
peri044 (Author) commented Jan 17, 2024

@ezyang I tried your suggestion but could not see these shape ranges generated via asserts. Can you please confirm that the following approach is correct?
I want the batch size to be within 1 <= BS <= 8.
My input graph (with asserts added on batch size):

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        assert x.size()[0] >= 1
        assert x.size()[0] <= 8
        out = self.conv(x)
        out = self.relu(out)
        return out

Code:

import torch
import torch_tensorrt as torchtrt

model = MyModule().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

compile_spec = {
    "inputs": [input],
    "device": torchtrt.Device("cuda:0"),
    "enabled_precisions": {torch.float},
    "pass_through_build_failures": True,
    "optimization_level": 1,
    "min_block_size": 1,
    "debug": True,
}
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)

The graph that I see in the backend of torch.compile (in this case "tensorrt") is:

graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %out : [num_users=1] = call_module[target=L__self___conv](args = (%l_x_,), kwargs = {})
    %out_1 : [num_users=1] = call_module[target=L__self___relu](args = (%out,), kwargs = {})
    return (out_1,)
    
nodes = list(gm.graph.nodes)
(Pdb) nodes[0]
l_x_
(Pdb) nodes[0].meta # input node
{'stack_trace': '  File "/work/TensorRT/tests/py/dynamo/models/test_dyn_compile.py", line 25, in forward\n    assert x.size()[0] <= 8\n', 'example_value': FakeTensor(..., device='cuda:0', size=(1, 3, 224, 224)), 'tensor_dict': {}, 'grapharg': GraphArg(source=LocalSource(local_name='x', cell_or_freevar=False), _example=<torch.utils.weak.TensorWeakRef object at 0x7f36a1477b50>, is_unspecialized=False, fake_tensor=FakeTensor(..., device='cuda:0', size=(1, 3, 224, 224)), is_tensor=True, example_strong_ref=None)}

(Pdb) nodes[1].meta # conv node 
{'nn_module_stack': {'L__self___conv': ("L['self'].conv", <class 'torch.nn.modules.conv.Conv2d'>)}, 'source_fn_stack': [('l__self___conv', <class 'torch.nn.modules.conv.Conv2d'>)], 'stack_trace': '  File "/work/TensorRT/tests/py/dynamo/models/test_dyn_compile.py", line 26, in forward\n    out = self.conv(x)\n', 'example_value': FakeTensor(..., device='cuda:0', size=(1, 16, 222, 222),
           grad_fn=<ConvolutionBackward0>)}

The shape info in the above nodes is static and does not contain sym_int nodes capturing the max value of 8.

Expectation (this comes from the behavior of torch.export, with no asserts in the input graph):
Code in export:

input = torch.randn((2, 3, 224, 224)).to("cuda")
ep = torch.export.export(model, (input,), dynamic_shapes = {"x": {0: torch.export.Dim("B", min=2, max=8)}},)
gm = ep.module()
nodes = list(gm.graph.nodes)

(Pdb) nodes[2].meta # input node
{'val': FakeTensor(..., device='cuda:0', size=(s0, 3, 224, 224)), 'tensor_meta': TensorMetadata(shape=torch.Size([s0, 3, 224, 224]), dtype=torch.float32, requires_grad=False, stride=(150528, 50176, 224, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}

(Pdb) nodes[3].meta # conv node
{'stack_trace': '  File "/work/TensorRT/tests/py/dynamo/models/test_dyn_compile.py", line 26, in forward\n    out = self.conv(x)\n', 'nn_module_stack': {'L__self__': ('', <class '__main__.test_base_dynamic.<locals>.MyModule'>), 'L__self___conv': ('conv', <class 'torch.nn.modules.conv.Conv2d'>)}, 'source_fn_stack': [('l__self___conv', <class 'torch.nn.modules.conv.Conv2d'>)], 'original_aten': <OpOverload(op='aten.convolution', overload='default')>, 'from_node': [('out', 'L__self___conv'), ('convolution', <OpOverload(op='aten.convolution', overload='default')>)], 'seq_nr': 11, 'val': FakeTensor(..., device='cuda:0', size=(s0, 16, 222, 222)), 'tensor_meta': TensorMetadata(shape=torch.Size([s0, 16, 222, 222]), dtype=torch.float32, requires_grad=False, stride=(788544, 49284, 222, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}

(Pdb) nodes[3].meta['val'].size()[0]
s0
(Pdb) nodes[3].meta['val'].size()[0].node.shape_env.var_to_range
{s0: ValueRanges(lower=2, upper=8, is_bool=False)}

The sym_int node s0 has the shape ranges specified in the dynamic_shapes argument.
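
As an aside, a small helper sketch (my own, not a PyTorch API) showing how a backend could pull these per-dimension ranges out of a node, assuming each symbolic size is a plain symbol such as s0:

import torch

def symbolic_dim_ranges(node):
    # export graphs store the fake tensor under 'val'; torch.compile
    # (dynamo) graphs use 'example_value' instead
    fake = node.meta.get("val", node.meta.get("example_value"))
    ranges = {}
    for dim, size in enumerate(fake.size()):
        if isinstance(size, torch.SymInt):
            shape_env = size.node.shape_env
            expr = size.node.expr
            if expr in shape_env.var_to_range:  # plain symbols only
                vr = shape_env.var_to_range[expr]
                ranges[dim] = (vr.lower, vr.upper)
    return ranges

# e.g. symbolic_dim_ranges(nodes[3]) would give {0: (2, 8)} for the conv node above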

Question:

  • With the asserts, I was expecting torch.compile to add sym_int nodes on the dynamic dimensions (similar to the export behavior). Is this the right understanding? If not, how do I achieve this? Thanks!

avikchaudhuri (Contributor) commented Jan 17, 2024

I'd suggest trying something like the following instead, to follow the suggestion by @ezyang:

            # assert x.size()[0] >= 1
            assert x.size()[0] <= 8
...
inp = torch.randn((4, 3, 224, 224))
torch._dynamo.mark_dynamic(inp, 0)

The reasoning is that size 1 is always specialized, and we assume at compile time that all dynamic sizes are >= 2 (even though passing size 1 at run time might be OK). Using a different batch size, then, allows symbols in the generated code. You also do need to mark_dynamic whichever dimension you need to be dynamic.

peri044 (Author) commented Jan 17, 2024

Thanks @avikchaudhuri for the response. Here's what I did as per your comment:

model = MyModule().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

torch._dynamo.mark_dynamic(input, 0)
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
trt_model(input)

The graph that got generated within the backend is as follows

(Pdb) print(gm.graph)
graph():
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %l_x_ : torch.Tensor [num_users=2] = placeholder[target=L_x_]
    %size : [num_users=1] = call_method[target=size](args = (%l_x_,), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%size, 0), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%getitem, 8), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.scalar_tensor](args = (%le,), kwargs = {})
    %_assert_async : [num_users=0] = call_function[target=torch._assert_async](args = (%scalar_tensor, assertion error), kwargs = {})
    %out : [num_users=1] = call_module[target=L__self___conv](args = (%l_x_,), kwargs = {})
    %out_1 : [num_users=1] = call_module[target=L__self___relu](args = (%out,), kwargs = {})
    return (out_1,)

And now I check the metadata for the %out node, which is the conv layer (it has symbolic ints now, but the values don't reflect the max value set by the assert statement):

(Pdb) nodes=list(gm.graph.nodes)
(Pdb) nodes[7].meta['example_value'].size()
torch.Size([s0, 16, 222, 222])
(Pdb) nodes[7].meta['example_value'].size()[0].node.shape_env.var_to_range 
{s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}

s0 doesn't seem to have the max value of 8 in this case (as set by the assert). Any thoughts?

ezyang (Contributor) commented Jan 18, 2024

You didn't change the batch size.

import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        assert x.size()[0] >= 1
        assert x.size()[0] <= 8
        out = self.conv(x)
        out = self.relu(out)
        return out

model = MyModule().eval()
input = torch.randn((5, 3, 224, 224))
        
torch._dynamo.mark_dynamic(input, 0)

my_model = torch.compile(model, backend="eager")
my_model(input)

peri044 (Author) commented Jan 19, 2024

@ezyang My bad. The snippet I posted might be misleading since it uses batch size 1. I actually tried different batch sizes (e.g. 2, 5, etc.).
I used the tensorrt backend as follows:

my_model = torch.compile(model, backend="tensorrt")

And with batch_size=2 (even for batch_size=5, the lower value is 2), I print the conv layer metadata once the graph reaches the tensorrt backend:

nodes = list(gm.graph.nodes) 
nodes[7].meta['example_value'].size()[0].node.shape_env.var_to_range
{s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}

These ranges are not the ones I defined via assert though.

ezyang (Contributor) commented Jan 20, 2024

When I run my example with the inductor backend, with a breakpoint in inductor's compile_fx, I do

(Pdb) p list(model_.graph.nodes)[7].meta['example_value'][0].node.shape_env.var_to_range
{s0: ValueRanges(lower=2, upper=15, is_bool=False)}

So there must be something wrong with how tensorrt is integrated with dynamo.

peri044 (Author) commented Jan 23, 2024

Hello @ezyang, thanks for the info.

My torch version: 2.3.0.dev20240122+cu121

1) Inductor repro: Unfortunately, I'm unable to repro what you are seeing. Can you please confirm that the following steps are correct?

Model:

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        assert x.size()[0] >= 1
        assert x.size()[0] <= 8
        out = self.conv(x)
        out = self.relu(out)
        return out

model = MyModule().eval().cuda()
input = torch.randn((4, 3, 224, 224)).to("cuda")
torch._dynamo.mark_dynamic(input, 0)
ind_model = torch.compile(model, backend="inductor")

I placed a breakpoint at https://github.com/pytorch/pytorch/blob/main/torch/_inductor/compile_fx.py#L952

(Pdb) nodes=list(model_.graph.nodes)
(Pdb) nodes
[s0, l_x_, size, getitem, ge, scalar_tensor, _assert_async, size_1, getitem_4, le, scalar_tensor_1, _assert_async_1, out, out_1, output]
# Node 12 is the conv node in the inductor case
(Pdb) nodes[12].meta['example_value'].size()[0].node.shape_env.var_to_range 
{s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}

Any suggestions as to where the difference in outputs (between your runs and mine) might come from?

2) Torch-TensorRT repro:

Code:

model = MyModule().eval().cuda()
input = torch.randn((4, 3, 224, 224)).to("cuda")
torch._dynamo.mark_dynamic(input, 0)
compile_spec = {
    "inputs": [input],
    "min_block_size": 1,
    "debug": True,
}
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)

Entry point for torch.compile within tensorrt backend is : https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/backend/backends.py#L42

TensorRT does not come into the picture before this, and this is where I checked the conv node metadata, which is the same as in the inductor case:

(Pdb) nodes=list(gm.graph.nodes)
(Pdb) nodes[12].meta['example_value'].size()[0].node.shape_env.var_to_range
{s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}

ezyang (Contributor) commented Jan 27, 2024

I think what is going on is that the asserts are sometimes being optimized away. Change the assert calls to torch._check(...) instead.
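
For concreteness, a minimal sketch of the model above rewritten with torch._check (same bounds as the earlier asserts; per the discussion above, the lower bound will still show as 2 due to 0/1 specialization):

import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        # unlike plain asserts, torch._check is not optimized away, so these
        # bounds should end up in shape_env.var_to_range
        torch._check(x.size(0) >= 1)
        torch._check(x.size(0) <= 8)
        return self.relu(self.conv(x))

model = MyModule().eval()
inp = torch.randn((4, 3, 224, 224))
torch._dynamo.mark_dynamic(inp, 0)
torch.compile(model, backend="eager")(inp)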

peri044 (Author) commented Jan 29, 2024

Thanks @ezyang. With torch._check, I can see the ranges properly now. I'm testing this out on other models as well and I shall update this issue accordingly.
