Skip to content

Commit

Permalink
fix: Fix unfold method in torch frontend with shape mismatch issues w…
Browse files Browse the repository at this point in the history
…ith the native fw
  • Loading branch information
hmahmood24 committed Jul 2, 2024
1 parent bd4e642 commit 4396c35
Showing 1 changed file with 1 addition and 6 deletions.
7 changes: 1 addition & 6 deletions ivy/functional/frontends/torch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,14 +767,9 @@ def unfold(self, dimension, size, step):
slicing[dimension] = slice(i, i + size)
slices.append(self.ivy_array[tuple(slicing)])
stacked = torch_frontend.stack(slices, dim=dimension)
new_shape = list(self.shape)
num_slices = (self.shape[dimension] - size) // step + 1
new_shape[dimension] = num_slices
new_shape.insert(dimension + 1, size)
reshaped = stacked.reshape(new_shape)
dims = list(range(len(stacked.shape)))
dims[-2], dims[-1] = dims[-1], dims[-2]
return reshaped.permute(*dims)
return stacked.permute(*dims)

def long(self, memory_format=None):
self.ivy_array = ivy.astype(self.ivy_array, ivy.int64, copy=False)
Expand Down

0 comments on commit 4396c35

Please sign in to comment.