Skip to content

Commit

Permalink
Add check that dshape[dim] % multiplication of dimensions in unflatte…
Browse files Browse the repository at this point in the history
…ned_size == 0
  • Loading branch information
mshr-h committed Nov 16, 2023
1 parent bca8a84 commit 252a58b
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1552,6 +1552,13 @@ def unflatten(self, inputs, input_types):
unflattened_size = tuple(inputs[2])
dshape = get_const_tuple(self.infer_shape_with_prelude(data))
assert len(dshape) > dim

mult = 1
for s in unflattened_size:
if s is not -1:
mult *= s
assert dshape[dim] % mult == 0

new_shape = dshape[:dim] + unflattened_size + dshape[dim + 1 :]
out = _op.reshape(data, new_shape)
return out
Expand Down

0 comments on commit 252a58b

Please sign in to comment.