Skip to content

Commit

Permalink
Fix compute_mask for Birectional with return_state=True
Browse files Browse the repository at this point in the history
Fix `compute_mask` to properly support `return_state` introduced in Birectional with #8977
  • Loading branch information
nisargjhaveri committed Feb 26, 2018
1 parent 9b52f74 commit 991d0bd
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions keras/layers/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,11 +490,20 @@ def build(self, input_shape):
def compute_mask(self, inputs, mask):
if self.return_sequences:
if not self.merge_mode:
return [mask, mask]
output_mask = [mask, mask]
else:
return mask
output_mask = mask
else:
return None
output_mask = None

if self.return_state:
states = self.forward_layer.states
state_mask = [None for _ in states]
if isinstance(output_mask, list):
return output_mask + state_mask * 2
return [output_mask] + state_mask * 2

return output_mask

@property
def trainable_weights(self):
Expand Down

0 comments on commit 991d0bd

Please sign in to comment.