feat: change ivy.lstm to always return sequences and states (ivy-llc#…
Sam-Armstrong authored Feb 1, 2024
1 parent 602a436 commit d7fc398
Showing 4 changed files with 6 additions and 39 deletions.
15 changes: 1 addition & 14 deletions ivy/functional/backends/torch/layers.py
@@ -883,8 +883,6 @@ def lstm(
weights_transposed: bool = False,
has_ih_bias: bool = True,
has_hh_bias: bool = True,
-    return_sequences: bool = True,
-    return_states: bool = True,
):
if weights_transposed:
# transpose the weights if they are in the wrong format
@@ -924,15 +922,4 @@ def lstm(
batch_first,
)

-    if return_states:
-        if return_sequences:
-            return ret
-        else:
-            return tuple(
-                [ret[0][:, -1], ret[1], ret[2]]
-            )  # TODO: this depends on batch_first
-    else:
-        if return_sequences:
-            return ret[0]
-        else:
-            return ret[0][:, -1]
+    return ret[0][:, -1], ret[0], (ret[1], ret[2])
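
For context, a minimal sketch (not part of this commit) of what a caller of the torch backend now receives; the helper name is hypothetical and the shapes assume a batch-first output layout, as flagged by the TODO in the removed branch above.

# Hedged sketch: `ret` stands for the (output, h_n, c_n) triple that torch's
# lstm produces; this mirrors the new fixed return line rather than adding
# any library API.
def unpack_backend_result(ret):
    last_output = ret[0][:, -1]       # final time step (assumes batch_first)
    full_sequence = ret[0]            # every time step
    final_states = (ret[1], ret[2])   # (hidden state, cell state)
    return last_output, full_sequence, final_states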
6 changes: 4 additions & 2 deletions
@@ -26,7 +26,7 @@ def _lstm_full(
bidirectional,
batch_first,
):
-    return ivy.lstm(
+    ret = ivy.lstm(
input,
hx,
params,
@@ -38,6 +38,7 @@ def _lstm_full(
has_ih_bias=has_biases,
has_hh_bias=has_biases,
)
+    return ret[1], ret[2][0], ret[2][1]


def _lstm_packed(
@@ -51,7 +52,7 @@ def _lstm_packed(
train,
bidirectional,
):
-    return ivy.lstm(
+    ret = ivy.lstm(
data,
hx,
params,
@@ -63,6 +64,7 @@ def _lstm_packed(
has_ih_bias=has_biases,
has_hh_bias=has_biases,
)
+    return ret[1], ret[2][0], ret[2][1]


# --- Main --- #
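Both frontend hunks above apply the same adapter pattern; a hedged sketch of the reordering (helper name hypothetical) that turns the new ivy.lstm result back into torch's (output, h_n, c_n) convention, equivalent to the ret[1], ret[2][0], ret[2][1] indexing used here:

# Illustrative only: ivy.lstm now always returns
# (last_output, output_sequence, (h_n, c_n)).
def to_torch_lstm_output(ivy_ret):
    _last_step, output, (h_n, c_n) = ivy_ret
    return output, h_n, c_n
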
18 changes: 1 addition & 17 deletions ivy/functional/ivy/layers.py
@@ -2397,8 +2397,6 @@ def lstm(
weights_transposed: bool = False,
has_ih_bias: bool = True,
has_hh_bias: bool = True,
-    return_sequences: bool = True,
-    return_states: bool = True,
):
"""Applies a multi-layer long-short term memory to an input sequence.
@@ -2442,11 +2440,6 @@
whether the `all_weights` argument includes a input-hidden bias
has_hh_bias
whether the `all_weights` argument includes a hidden-hidden bias
-    return_sequences
-        whether to return the last output in the output sequence,
-        or the full sequence
-    return_states
-        whether to return the final hidden and carry states in addition to the output
Returns
-------
@@ -2567,16 +2560,7 @@
if batch_sizes is not None:
output = _pack_padded_sequence(output, batch_sizes)[0]

-    if return_states:
-        if return_sequences:
-            return output, h_outs, c_outs
-        else:
-            return output[:, -1], h_outs, c_outs  # TODO: this depends on batch_first
-    else:
-        if return_sequences:
-            return output
-        else:
-            return output[:, -1]
+    return output[:, -1], output, (h_outs, c_outs)


# Helpers #
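Because the return_sequences and return_states switches are removed from ivy.lstm, callers that depended on them can rebuild the old behaviours from the fixed result; a hedged migration sketch (helper name hypothetical, batch-first layout assumed as in the removed TODO):

# Not part of this commit: reproduces the dropped flag combinations from the
# new fixed (last_output, sequence, (h, c)) return value.
def split_lstm_result(result, return_sequences=True, return_states=True):
    last_output, sequence, states = result
    out = sequence if return_sequences else last_output
    return (out, *states) if return_states else out
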
6 changes: 0 additions & 6 deletions ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
@@ -90,8 +90,6 @@ def _lstm_helper(draw):
has_ih_bias = draw(st.booleans())
has_hh_bias = draw(st.booleans())
weights_transposed = draw(st.booleans())
-    return_sequences = draw(st.booleans())
-    return_states = draw(st.booleans())
bidirectional = draw(st.booleans())
dropout = draw(st.floats(min_value=0, max_value=0.99))
train = draw(st.booleans()) and not dropout
@@ -217,8 +215,6 @@ def _lstm_helper(draw):
"weights_transposed": weights_transposed,
"has_ih_bias": has_ih_bias,
"has_hh_bias": has_hh_bias,
"return_sequences": return_sequences,
"return_states": return_states,
}
else:
dtypes = dtype
@@ -234,8 +230,6 @@ def _lstm_helper(draw):
"weights_transposed": weights_transposed,
"has_ih_bias": has_ih_bias,
"has_hh_bias": has_hh_bias,
"return_sequences": return_sequences,
"return_states": return_states,
}
return dtypes, kwargs

