Fixed cat uint8 lowering (pytorch#112753)
Description:
- Fixed cat uint8 lowering

Without this fix, the following repro fails:
```python
def func(x):
    batch_shape = x.shape[:1]
    out = torch.cat([x.new_zeros(1).expand(batch_shape + (1,)), x], dim=-1)
    return out

cfunc = torch.compile(func)

x = torch.randint(0, 256, size=(3, 255), dtype=torch.uint8)
out = cfunc(x)
```
Error message:
```
  File "/pytorch/torch/_inductor/lowering.py", line 1037, in <genexpr>
    if all(len(input.layout.size) == 4 for input in inputs):
  File "/pytorch/torch/_inductor/ir.py", line 5795, in __getattr__
    fn = getattr(self.data, name)
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
LoweringException: AttributeError: 'ExpandView' object has no attribute 'layout'
  target: aten.cat.default
  args[0]: [TensorBox(
    ExpandView(data=StorageBox(
      ComputedBuffer(name='buf0', layout=FlexibleLayout('cpu', torch.uint8, size=[1], stride=[1]), data=Pointwise(
        'cpu',
        torch.uint8,
        def inner_fn(index):
            _ = index
            tmp0 = ops.constant(0, torch.uint8)
            return tmp0
        ,
        ranges=[1],
        origin_node=full,
        origins={full}
      ))
    ), size=[3, 1])
  ), TensorBox(StorageBox(
    InputBuffer(name='arg0_1', layout=FixedLayout('cpu', torch.uint8, size=[3, 255], stride=[255, 1]))
  ))]
  args[1]: 1

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```
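
The failure path is the generic attribute forwarding in the inductor IR: the traceback ends in `__getattr__` in `torch/_inductor/ir.py`, which delegates unknown attribute lookups to the wrapped node, and an `ExpandView` (a lazy view, unlike a realized buffer) carries no `layout`. A purely illustrative stand-in for that forwarding (the real classes are `TensorBox`/`ExpandView` in `ir.py`):
```python
# Toy model of the wrapper behavior the traceback points at; the real
# implementation lives in torch/_inductor/ir.py.
class Wrapper:
    def __init__(self, data):
        self.data = data  # the wrapped IR node, e.g. an ExpandView

    def __getattr__(self, name):
        # Unknown attributes are forwarded to the wrapped node. A view
        # node has no `layout` (only realized buffers do), so asking a
        # wrapped ExpandView for `.layout` raises the AttributeError
        # shown above.
        return getattr(self.data, name)
```
Sizes, by contrast, are available on views and buffers alike through `get_size()`, which is what the fix below switches to.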

Context: `torch.compile` currently fails on torchvision's `F.equalize` op for this reason: pytorch/vision#8056

Pull Request resolved: pytorch#112753
Approved by: https://github.com/peterbell10
vfdev-5 authored and xuhancn committed Nov 8, 2023
1 parent 933a54e commit a216d23
Showing 3 changed files with 15 additions and 1 deletion.
11 changes: 11 additions & 0 deletions test/inductor/test_torchinductor.py
```diff
@@ -3455,6 +3455,17 @@ def fn(a):
             (torch.randn([1, 3, 3, 16]).to(memory_format=torch.channels_last),),
         )
 
+    def test_cat_uint8(self):
+        def fn(x):
+            batch_shape = x.shape[:1]
+            out = torch.cat([x.new_zeros(1).expand(batch_shape + (1,)), x], dim=-1)
+            return out
+
+        self.common(
+            fn,
+            (torch.randint(0, 256, size=(3, 255), dtype=torch.uint8),),
+        )
+
     def test_cat_empty(self):
         def fn_2(*tensors):
             return torch.cat(tensors)
```
3 changes: 3 additions & 0 deletions test/inductor/test_torchinductor_codegen_dynamic_shapes.py
```diff
@@ -195,6 +195,9 @@ def run(*ex, **kwargs):
         ("cpu", "cuda")
     ),
     "test_zero_element_mutation_dynamic_shapes": TestFailure(("cpu", "cuda")),
+    "test_cat_uint8_dynamic_shapes": TestFailure(
+        ("cpu",)
+    ),  # cat on uint8 input is using aten fallback on cpu
     #
     # Tests not using 'common' or directly calling 'assertEqual':
     #
```
2 changes: 1 addition & 1 deletion torch/_inductor/lowering.py
```diff
@@ -1034,7 +1034,7 @@ def cat(inputs, dim=0):
         # code gen with uint8 data type directly.
         for input in inputs:
             input.realize()
-        if all(len(input.layout.size) == 4 for input in inputs):
+        if all(len(input.get_size()) == 4 for input in inputs):
            inputs, _ = require_channels_last(aten.cat, *inputs)
         return fallback_handler(aten.cat.default)(inputs, dim)
```
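
With the lowering fixed, the repro from the description compiles cleanly. A minimal sanity check, assuming a PyTorch build that includes this commit:
```python
import torch

def func(x):
    batch_shape = x.shape[:1]
    return torch.cat([x.new_zeros(1).expand(batch_shape + (1,)), x], dim=-1)

cfunc = torch.compile(func)
x = torch.randint(0, 256, size=(3, 255), dtype=torch.uint8)

# The compiled function should now agree with eager instead of raising
# a LoweringException during compilation.
torch.testing.assert_close(cfunc(x), func(x))
```
The new test can also be targeted directly, e.g. with `pytest test/inductor/test_torchinductor.py -k test_cat_uint8` (the exact invocation depends on the local setup).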
