Skip to content

Commit

Permalink
fix: improved fix for making lstm weights contiguous (#28368)
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Feb 21, 2024
1 parent c7e98b3 commit 44d3453
Showing 1 changed file with 1 addition and 6 deletions.
7 changes: 1 addition & 6 deletions ivy/functional/backends/torch/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,7 +887,7 @@ def lstm(
if weights_transposed:
# transpose the weights if they are in the wrong format
all_weights = [
torch.transpose(weight, 1, 0) if weight.dim() == 2 else weight
torch.transpose(weight, 1, 0).contiguous() if weight.dim() == 2 else weight
for weight in all_weights
]
else:
Expand All @@ -910,11 +910,6 @@ def lstm(
if initial_states[1].dim() == 2:
initial_states[1] = ivy.expand_dims(initial_states[1])

# ensure all weights are contiguous, so they will work on gpu
for i, w in enumerate(all_weights):
if not w.is_contiguous():
all_weights[i] = w.contiguous()

ret = torch.lstm(
input,
initial_states,
Expand Down

0 comments on commit 44d3453

Please sign in to comment.