diff --git a/ivy/functional/frontends/torch/tensor.py b/ivy/functional/frontends/torch/tensor.py index da9c3a2dc243..33179310b6dc 100644 --- a/ivy/functional/frontends/torch/tensor.py +++ b/ivy/functional/frontends/torch/tensor.py @@ -767,14 +767,9 @@ def unfold(self, dimension, size, step): slicing[dimension] = slice(i, i + size) slices.append(self.ivy_array[tuple(slicing)]) stacked = torch_frontend.stack(slices, dim=dimension) - new_shape = list(self.shape) - num_slices = (self.shape[dimension] - size) // step + 1 - new_shape[dimension] = num_slices - new_shape.insert(dimension + 1, size) - reshaped = stacked.reshape(new_shape) dims = list(range(len(stacked.shape))) dims[-2], dims[-1] = dims[-1], dims[-2] - return reshaped.permute(*dims) + return stacked.permute(*dims) def long(self, memory_format=None): self.ivy_array = ivy.astype(self.ivy_array, ivy.int64, copy=False)