Removed generate dropout ones from recurrent. (#9892)
* Removed generate dropout ones from recurrent.

* Fixed index issue.
gabrieldemarmiesse authored and taehoonlee committed Apr 11, 2018
1 parent af804d0 commit e73199d
Showing 1 changed file with 6 additions and 16 deletions.
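
In substance, this commit replaces the backend-specific _generate_dropout_ones helper with direct K.ones_like calls. A minimal standalone sketch of the before and after, using the identifiers from the diff below (illustrative only, not code from the commit):

    from keras import backend as K

    x = K.placeholder(shape=(None, 8))  # symbolic batch dimension

    # Before: build a (batch_size, dims) ones tensor from a symbolic shape.
    # The removed helper special-cased CNTK, which cannot instantiate
    # `ones` with symbolic shapes.
    ones_old = K.ones((K.shape(x)[0], 8))

    # After: infer the shape directly from an existing tensor, which is
    # backend-agnostic and needs no workaround.
    ones_new = K.ones_like(x)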
22 changes: 6 additions & 16 deletions keras/layers/recurrent.py
@@ -888,13 +888,13 @@ def call(self, inputs, states, training=None):
         prev_output = states[0]
         if 0 < self.dropout < 1 and self._dropout_mask is None:
             self._dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, K.shape(inputs)[-1]),
+                K.ones_like(inputs),
                 self.dropout,
                 training=training)
         if (0 < self.recurrent_dropout < 1 and
                 self._recurrent_dropout_mask is None):
             self._recurrent_dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, self.units),
+                K.ones_like(prev_output),
                 self.recurrent_dropout,
                 training=training)

@@ -1329,14 +1329,14 @@ def call(self, inputs, states, training=None):

         if 0 < self.dropout < 1 and self._dropout_mask is None:
             self._dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, K.shape(inputs)[-1]),
+                K.ones_like(inputs),
                 self.dropout,
                 training=training,
                 count=3)
         if (0 < self.recurrent_dropout < 1 and
                 self._recurrent_dropout_mask is None):
             self._recurrent_dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, self.units),
+                K.ones_like(h_tm1),
                 self.recurrent_dropout,
                 training=training,
                 count=3)
@@ -1887,14 +1887,14 @@ def bias_initializer(_, *args, **kwargs):
     def call(self, inputs, states, training=None):
         if 0 < self.dropout < 1 and self._dropout_mask is None:
             self._dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, K.shape(inputs)[-1]),
+                K.ones_like(inputs),
                 self.dropout,
                 training=training,
                 count=4)
         if (0 < self.recurrent_dropout < 1 and
                 self._recurrent_dropout_mask is None):
             self._recurrent_dropout_mask = _generate_dropout_mask(
-                _generate_dropout_ones(inputs, self.units),
+                K.ones_like(states[0]),
                 self.recurrent_dropout,
                 training=training,
                 count=4)
@@ -2249,16 +2249,6 @@ def from_config(cls, config):
         return cls(**config)
 
 
-def _generate_dropout_ones(inputs, dims):
-    # Currently, CNTK can't instantiate `ones` with symbolic shapes.
-    # Will update workaround once CNTK supports it.
-    if K.backend() == 'cntk':
-        ones = K.ones_like(K.reshape(inputs[:, 0], (-1, 1)))
-        return K.tile(ones, (1, dims))
-    else:
-        return K.ones((K.shape(inputs)[0], dims))
-
-
 def _generate_dropout_mask(ones, rate, training=None, count=1):
     def dropped_inputs():
         return K.dropout(ones, rate)
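
For usage, the masks above come from passing such a ones tensor through K.dropout inside the training phase. A minimal sketch assuming the TensorFlow backend (names and values here are illustrative, not from the commit):

    import numpy as np
    from keras import backend as K

    x = K.placeholder(shape=(None, 4))
    ones = K.ones_like(x)
    # In the training phase, each mask entry is either 0.0 or 1 / (1 - rate).
    mask = K.in_train_phase(lambda: K.dropout(ones, 0.5), ones, training=True)
    f = K.function([x], [mask])
    print(f([np.zeros((2, 4))])[0])  # entries are 0.0 or 2.0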
