🐛Fixed RNN when bidirectional is True
carefree0910 committed Mar 24, 2021
1 parent aec1846 commit be974df
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion cflearn/modules/extractors/rnn/core.py
@@ -32,6 +32,7 @@ def __init__(
         rnn_base = rnn_dict[cell]
         input_dimensions = [self.in_dim]
         self.hidden_size = cell_config["hidden_size"]
+        self.bidirectional = cell_config.setdefault("bidirectional", False)
         input_dimensions += [self.hidden_size] * (num_layers - 1)
         rnn_list = []
         for dim in input_dimensions:
@@ -51,7 +52,9 @@ def flatten_ts(self) -> bool:
 
     @property
     def out_dim(self) -> int:
-        return self.hidden_size
+        if not self.bidirectional:
+            return self.hidden_size
+        return 2 * self.hidden_size
 
     def forward(self, net: torch.Tensor) -> torch.Tensor:
         for rnn in self.rnn_list:
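
Note (not part of the commit): the out_dim change matches how PyTorch's recurrent layers behave when bidirectional=True — the forward and backward hidden states are concatenated along the feature axis, so the output feature size becomes 2 * hidden_size. A minimal sketch with illustrative shapes (module choice and sizes are assumptions, not taken from the repository):

import torch
import torch.nn as nn

hidden_size = 16
# bidirectional=True runs a forward and a backward pass and concatenates their outputs
rnn = nn.GRU(input_size=8, hidden_size=hidden_size, batch_first=True, bidirectional=True)
net = torch.rand(4, 10, 8)  # (batch, seq_len, in_dim)
out, _ = rnn(net)
print(out.shape)  # torch.Size([4, 10, 32]) -> feature dim == 2 * hidden_size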
