# [torch.compile] Request for shape ranges in torch.compile workflow #115137
## Comments
The request here is reasonable, though I will note that if you do an assertion on the size inside the model, that will also appropriately update value ranges.
---

@ezyang Thanks for your response. Can you elaborate on the assertion part and how I can do this?
---

Literally `assert input.size(1) <= 256` or something.
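A minimal sketch of what this suggestion looks like (a hypothetical standalone example, not from the thread; note that later in this thread it turns out plain asserts can sometimes be optimized away):

```python
import torch

def fn(x):
    # The assert on a symbolic size feeds an upper bound into Dynamo's
    # ShapeEnv, refining the value range recorded for that dimension.
    assert x.size(0) <= 256
    return x * 2

compiled = torch.compile(fn, dynamic=True)
out = compiled(torch.randn(8, 4))
```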
---

@ezyang I tried your suggestion but could not see these shape ranges generated via asserts. Can you please confirm if the following approach is correct?

```python
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        assert x.size()[0] >= 1
        assert x.size()[0] <= 8
        out = self.conv(x)
        out = self.relu(out)
        return out
```

Code:

```python
model = MyModule().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")
compile_spec = {
    "inputs": [input],
    "device": torchtrt.Device("cuda:0"),
    "enabled_precisions": {torch.float},
    "pass_through_build_failures": True,
    "optimization_level": 1,
    "min_block_size": 1,
    "debug": True,
}
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
```

The graph that I see in the backend of torch.compile (in this case "tensorrt") is:

```
graph():
    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]
    %out : [num_users=1] = call_module[target=L__self___conv](args = (%l_x_,), kwargs = {})
    %out_1 : [num_users=1] = call_module[target=L__self___relu](args = (%out,), kwargs = {})
    return (out_1,)
```

```
nodes = list(gm.graph.nodes)
(Pdb) nodes[0]
l_x_
(Pdb) nodes[0].meta  # input node
{'stack_trace': ' File "/work/TensorRT/tests/py/dynamo/models/test_dyn_compile.py", line 25, in forward\n assert x.size()[0] <= 8\n', 'example_value': FakeTensor(..., device='cuda:0', size=(1, 3, 224, 224)), 'tensor_dict': {}, 'grapharg': GraphArg(source=LocalSource(local_name='x', cell_or_freevar=False), _example=<torch.utils.weak.TensorWeakRef object at 0x7f36a1477b50>, is_unspecialized=False, fake_tensor=FakeTensor(..., device='cuda:0', size=(1, 3, 224, 224)), is_tensor=True, example_strong_ref=None)}
(Pdb) nodes[1].meta  # conv node
{'nn_module_stack': {'L__self___conv': ("L['self'].conv", <class 'torch.nn.modules.conv.Conv2d'>)}, 'source_fn_stack': [('l__self___conv', <class 'torch.nn.modules.conv.Conv2d'>)], 'stack_trace': ' File "/work/TensorRT/tests/py/dynamo/models/test_dyn_compile.py", line 26, in forward\n out = self.conv(x)\n', 'example_value': FakeTensor(..., device='cuda:0', size=(1, 16, 222, 222), grad_fn=<ConvolutionBackward0>)}
```

The shape info in the above nodes has static shapes and does not have any range information.

Expectation (this is coming from the behavior of torch.export):

```python
input = torch.randn((2, 3, 224, 224)).to("cuda")
ep = torch.export.export(model, (input,), dynamic_shapes={"x": {0: torch.export.Dim("B", min=2, max=8)}})
gm = ep.module()
nodes = list(gm.graph.nodes)
```

```
(Pdb) nodes[2].meta  # input node
{'val': FakeTensor(..., device='cuda:0', size=(s0, 3, 224, 224)), 'tensor_meta': TensorMetadata(shape=torch.Size([s0, 3, 224, 224]), dtype=torch.float32, requires_grad=False, stride=(150528, 50176, 224, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
(Pdb) nodes[3].meta  # conv node
{'stack_trace': ' File "/work/TensorRT/tests/py/dynamo/models/test_dyn_compile.py", line 26, in forward\n out = self.conv(x)\n', 'nn_module_stack': {'L__self__': ('', <class '__main__.test_base_dynamic.<locals>.MyModule'>), 'L__self___conv': ('conv', <class 'torch.nn.modules.conv.Conv2d'>)}, 'source_fn_stack': [('l__self___conv', <class 'torch.nn.modules.conv.Conv2d'>)], 'original_aten': <OpOverload(op='aten.convolution', overload='default')>, 'from_node': [('out', 'L__self___conv'), ('convolution', <OpOverload(op='aten.convolution', overload='default')>)], 'seq_nr': 11, 'val': FakeTensor(..., device='cuda:0', size=(s0, 16, 222, 222)), 'tensor_meta': TensorMetadata(shape=torch.Size([s0, 16, 222, 222]), dtype=torch.float32, requires_grad=False, stride=(788544, 49284, 222, 1), memory_format=torch.contiguous_format, is_quantized=False, qparams={})}
(Pdb) nodes[3].meta['val'].size()[0]
s0
(Pdb) nodes[3].meta['val'].size()[0].node.shape_env.var_to_range
{s0: ValueRanges(lower=2, upper=8, is_bool=False)}
```

The question:
---

I'd suggest trying something like the following instead, to follow the suggestion by @ezyang:

Reasoning is that size
---

Thanks @avikchaudhuri for the response. Here's what I did as per your comment:

```python
model = MyModule().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")
torch._dynamo.mark_dynamic(input, 0)
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
trt_model(input)
```

The graph that got generated within the backend is as follows:

```
(Pdb) print(gm.graph)
graph():
    %s0 : torch.SymInt [num_users=0] = placeholder[target=s0]
    %l_x_ : torch.Tensor [num_users=2] = placeholder[target=L_x_]
    %size : [num_users=1] = call_method[target=size](args = (%l_x_,), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%size, 0), kwargs = {})
    %le : [num_users=1] = call_function[target=operator.le](args = (%getitem, 8), kwargs = {})
    %scalar_tensor : [num_users=1] = call_function[target=torch.scalar_tensor](args = (%le,), kwargs = {})
    %_assert_async : [num_users=0] = call_function[target=torch._assert_async](args = (%scalar_tensor, assertion error), kwargs = {})
    %out : [num_users=1] = call_module[target=L__self___conv](args = (%l_x_,), kwargs = {})
    %out_1 : [num_users=1] = call_module[target=L__self___relu](args = (%out,), kwargs = {})
    return (out_1,)
```

And now I check the metadata for

The
---

You didn't change the batch size.
---

@ezyang My bad. The snippet I posted might be misleading since it has

```python
my_model = torch.compile(model, backend="tensorrt")
```

And with:

```
nodes = list(gm.graph.nodes)
nodes[7].meta['example_value'].size()[0].node.shape_env.var_to_range
{s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
```

These ranges are not the ones I defined via assert though.
---

When I run my example with inductor backend, with a breakpoint in inductor compile_fx and I do

So there must be something wrong with how tensorrt is integrated with dynamo.
---

Hello @ezyang, thanks for the info. My torch version:

1) Inductor repro: Unfortunately, I'm unable to repro what you are seeing. Can you please confirm the following steps are correct?

Model:

```python
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        assert x.size()[0] >= 1
        assert x.size()[0] <= 8
        out = self.conv(x)
        out = self.relu(out)
        return out

model = MyModule().eval().cuda()
input = torch.randn((4, 3, 224, 224)).to("cuda")
torch._dynamo.mark_dynamic(input, 0)
ind_model = torch.compile(model, backend="inductor")
```

I placed a breakpoint at https://github.com/pytorch/pytorch/blob/main/torch/_inductor/compile_fx.py#L952

```
(Pdb) nodes=list(model_.graph.nodes)
(Pdb) nodes
[s0, l_x_, size, getitem, ge, scalar_tensor, _assert_async, size_1, getitem_4, le, scalar_tensor_1, _assert_async_1, out, out_1, output]
# Node 12 is the conv node in the inductor case
(Pdb) nodes[12].meta['example_value'].size()[0].node.shape_env.var_to_range
{s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
```

Any suggestion as to where the difference in outputs (between your runs and mine) can be?

2) Torch-TensorRT repro:

Code:

```python
model = MyModule().eval().cuda()
input = torch.randn((4, 3, 224, 224)).to("cuda")
torch._dynamo.mark_dynamic(input, 0)
compile_spec = {
    "inputs": [input],
    "min_block_size": 1,
    "debug": True,
}
trt_model = torch.compile(model, backend="tensorrt", options=compile_spec)
```

Entry point for torch.compile within the tensorrt backend is: https://github.com/pytorch/TensorRT/blob/main/py/torch_tensorrt/dynamo/backend/backends.py#L42

TensorRT does not come into the picture before this, and this is where I checked the metadata:

```
(Pdb) nodes=list(gm.graph.nodes)
(Pdb) nodes[12].meta['example_value'].size()[0].node.shape_env.var_to_range
{s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)}
```
---

I think what is going on is sometimes asserts are being optimized away. Change the assert call to
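The exact replacement is elided above; a plausible reading (an assumption, not confirmed by the thread) is `torch._check`, PyTorch's runtime assertion that also feeds the ShapeEnv. A sketch using the same `forward` as earlier:

```python
def forward(self, x):
    # torch._check records the bounds in the ShapeEnv and, unlike a plain
    # Python assert, is not optimized away during tracing.
    torch._check(x.size(0) >= 1)
    torch._check(x.size(0) <= 8)
    out = self.conv(x)
    out = self.relu(out)
    return out
```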
---

Thanks @ezyang. With
---

## 🚀 The feature, motivation and pitch

Torch-TensorRT has two workflows (torch.export based and torch.compile based) for optimizing PyTorch models using TensorRT.

The `torch.export.export()` flow can accept dynamic shapes with `min` and `max` ranges specified for any input dimension. This range information is accessible for intermediate nodes in the graph via `node.shape_env.var_to_range`. TensorRT uses this range information when building engines (with dynamic shapes).

`torch.compile` provides `torch._dynamo.mark_dynamic(tensor, dim)`, which is great. But can this be extended to accept ranges (or something similar to the `torch.export.Dim(name, min, max)` API)?

Ultimately, the ask here is for the graph (provided by `torch.compile` to its backend) to have the range information (similar to `torch.export.export()`).

Please let me know if there's a way to do this currently (or) any questions you have.
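For reference, a sketch contrasting the two flows (it assumes the `MyModule` model from the comments above; the `min`/`max` keyword arguments on `mark_dynamic` are an assumption here — newer PyTorch builds expose them, which is essentially what this issue asks for):

```python
import torch

model = MyModule().eval().cuda()
input = torch.randn(4, 3, 224, 224, device="cuda")

# torch.export flow: ranges come from Dim(min=..., max=...) and end up
# in node.shape_env.var_to_range for the backend to consume.
dim_b = torch.export.Dim("B", min=2, max=8)
ep = torch.export.export(model, (input,), dynamic_shapes={"x": {0: dim_b}})

# torch.compile flow: mark_dynamic with bounds (where available) would
# carry the same range information into the backend's graph.
torch._dynamo.mark_dynamic(input, 0, min=2, max=8)
compiled = torch.compile(model, backend="tensorrt")
```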
### Alternatives

No response

### Additional context

No response
cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519 @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4