From e73199d413395d7ede0dd86c06f7a9ad2d1fee17 Mon Sep 17 00:00:00 2001
From: Gabriel de Marmiesse <gabrieldemarmiesse@gmail.com>
Date: Wed, 11 Apr 2018 07:21:51 +0200
Subject: [PATCH] Removed generate dropout ones from recurrent. (#9892)

* Removed generate dropout ones from recurrent.

* Fixed index issue.
---
 keras/layers/recurrent.py | 22 ++++++----------------
 1 file changed, 6 insertions(+), 16 deletions(-)

diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py
index 952c6c77ce8..cbe6ce9a7e3 100644
--- a/keras/layers/recurrent.py
+++ b/keras/layers/recurrent.py
@@ -888,13 +888,13 @@ def call(self, inputs, states, training=None):
         prev_output = states[0]
         if 0 < self.dropout < 1 and self._dropout_mask is None:
             self._dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, K.shape(inputs)[-1]),
+                K.ones_like(inputs),
                 self.dropout,
                 training=training)
         if (0 < self.recurrent_dropout < 1 and
                 self._recurrent_dropout_mask is None):
             self._recurrent_dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, self.units),
+                K.ones_like(prev_output),
                 self.recurrent_dropout,
                 training=training)
 
@@ -1329,14 +1329,14 @@ def call(self, inputs, states, training=None):
 
         if 0 < self.dropout < 1 and self._dropout_mask is None:
             self._dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, K.shape(inputs)[-1]),
+                K.ones_like(inputs),
                 self.dropout,
                 training=training,
                 count=3)
         if (0 < self.recurrent_dropout < 1 and
                 self._recurrent_dropout_mask is None):
             self._recurrent_dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, self.units),
+                K.ones_like(h_tm1),
                 self.recurrent_dropout,
                 training=training,
                 count=3)
@@ -1887,14 +1887,14 @@ def bias_initializer(_, *args, **kwargs):
     def call(self, inputs, states, training=None):
         if 0 < self.dropout < 1 and self._dropout_mask is None:
             self._dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, K.shape(inputs)[-1]),
+                K.ones_like(inputs),
                 self.dropout,
                 training=training,
                 count=4)
         if (0 < self.recurrent_dropout < 1 and
                 self._recurrent_dropout_mask is None):
             self._recurrent_dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, self.units),
+                K.ones_like(states[0]),
                 self.recurrent_dropout,
                 training=training,
                 count=4)
@@ -2249,16 +2249,6 @@ def from_config(cls, config):
         return cls(**config)
 
 
-def _generate_dropout_ones(inputs, dims):
-    # Currently, CNTK can't instantiate `ones` with symbolic shapes.
-    # Will update workaround once CNTK supports it.
-    if K.backend() == 'cntk':
-        ones = K.ones_like(K.reshape(inputs[:, 0], (-1, 1)))
-        return K.tile(ones, (1, dims))
-    else:
-        return K.ones((K.shape(inputs)[0], dims))
-
-
 def _generate_dropout_mask(ones, rate, training=None, count=1):
     def dropped_inputs():
         return K.dropout(ones, rate)