From 44d345362565e65305c0acba0bbc5f4147789157 Mon Sep 17 00:00:00 2001 From: Sam Armstrong <88863522+Sam-Armstrong@users.noreply.github.com> Date: Wed, 21 Feb 2024 10:51:39 +0000 Subject: [PATCH] fix: improved fix for making lstm weights contiguous (#28368) --- ivy/functional/backends/torch/layers.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/ivy/functional/backends/torch/layers.py b/ivy/functional/backends/torch/layers.py index f1b7fecce5367..c470384987513 100644 --- a/ivy/functional/backends/torch/layers.py +++ b/ivy/functional/backends/torch/layers.py @@ -887,7 +887,7 @@ def lstm( if weights_transposed: # transpose the weights if they are in the wrong format all_weights = [ - torch.transpose(weight, 1, 0) if weight.dim() == 2 else weight + torch.transpose(weight, 1, 0).contiguous() if weight.dim() == 2 else weight for weight in all_weights ] else: @@ -910,11 +910,6 @@ def lstm( if initial_states[1].dim() == 2: initial_states[1] = ivy.expand_dims(initial_states[1]) - # ensure all weights are contiguous, so they will work on gpu - for i, w in enumerate(all_weights): - if not w.is_contiguous(): - all_weights[i] = w.contiguous() - ret = torch.lstm( input, initial_states,