Add support for constants in Bidirectional wrapper (#9260)
* Add support for `constants` in Bidirectional wrapper

* Add more tests for Bidirectional wrapper

* Fix `compute_mask` for Bidirectional with `return_state=True`

Fix `compute_mask` to properly support the `return_state` option introduced for Bidirectional in #8977

* Add test for Bidirectional with an unknown number of timesteps

* Skip the unknown-timesteps Bidirectional test on CNTK

* Avoid overriding the caller's constants when an RNN constant needs to be broadcast along the sequence axis in the CNTK backend

* Move _standardize_args to recurrent, remove duplication

* Fix `compute_mask` for Bidirectional when multiple masks are passed
nisargjhaveri authored and fchollet committed Apr 14, 2018
1 parent 083a41c commit e246250
Showing 5 changed files with 342 additions and 92 deletions.
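
Before the diffs, a rough usage sketch of what the headline change enables: passing per-sample `constants` through a `Bidirectional` wrapper to a cell whose `call` accepts them. The `CellWithConstants` class, shapes, and names below are illustrative (loosely modeled on the cells used in the Keras test suite), not part of this commit.

from keras import backend as K
from keras.layers import Input, Layer, RNN, Bidirectional
from keras.models import Model
from keras.utils import CustomObjectScope


class CellWithConstants(Layer):
    """Toy RNN cell whose `call` accepts a `constants` argument."""

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(CellWithConstants, self).__init__(**kwargs)

    def build(self, input_shape):
        # With constants, the cell is built on [step_input_shape, constant_shape].
        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            initializer='uniform', name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform', name='recurrent_kernel')
        self.constant_kernel = self.add_weight(
            shape=(input_shape[-1][-1], self.units),
            initializer='uniform', name='constant_kernel')
        self.built = True

    def call(self, inputs, states, constants):
        [prev_output] = states
        [constant] = constants
        h = K.dot(inputs, self.kernel)
        h += K.dot(prev_output, self.recurrent_kernel)
        h += K.dot(constant, self.constant_kernel)
        return h, [h]

    def get_config(self):
        config = {'units': self.units}
        base_config = super(CellWithConstants, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))


x = Input((5, 4))   # (timesteps, features)
c = Input((3,))     # per-sample constant, visible to the cell at every timestep

# The scope is needed because Bidirectional rebuilds the backward layer
# from config, which deserializes the custom cell.
with CustomObjectScope({'CellWithConstants': CellWithConstants}):
    y = Bidirectional(RNN(CellWithConstants(8)))(x, constants=c)
model = Model([x, c], y)
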
38 changes: 22 additions & 16 deletions keras/backend/cntk_backend.py
@@ -1353,31 +1353,37 @@ def rnn(step_function, inputs, initial_states,
'variable-length sequences. Please specify a '
'static length for your sequences.')

rnn_inputs = inputs
if need_convert:
if go_backwards:
inputs = reverse(inputs, 1)
rnn_inputs = reverse(rnn_inputs, 1)

inputs = C.to_sequence(inputs)
rnn_inputs = C.to_sequence(rnn_inputs)

j = 0
while j < len(constants):
if isinstance(constants[j], list):
i = 0
while i < len(constants[j]):
if _get_dynamic_axis_num(constants[j][i]) == 1:
constants[j][i] = C.sequence.broadcast_as(constants[j][i], inputs)
i += 1
rnn_constants = []
for constant in constants:
if isinstance(constant, list):
new_c = []
for c in constant:
if _get_dynamic_axis_num(c) == 1:
new_c.append(C.sequence.broadcast_as(c, rnn_inputs))
else:
new_c.append(c)
rnn_constants.append(new_c)
else:
if _get_dynamic_axis_num(constants[j]) == 1:
constants[j] = C.sequence.broadcast_as(constants[j], inputs)
j += 1
if _get_dynamic_axis_num(constant) == 1:
rnn_constants.append(C.sequence.broadcast_as(constant, rnn_inputs))
else:
rnn_constants.append(constant)
else:
rnn_constants = constants

if mask is not None and not has_seq_axis(mask):
if go_backwards:
mask = reverse(mask, 1)
if len(int_shape(mask)) == 2:
mask = expand_dims(mask)
mask = C.to_sequence_like(mask, inputs)
mask = C.to_sequence_like(mask, rnn_inputs)

states = tuple(initial)

@@ -1389,7 +1395,7 @@ def _recurrence(x, states, m):
for s, p in zip(states, place_holders):
past_values.append(C.sequence.past_value(p, s))
new_output, new_states = step_function(
x, tuple(past_values) + tuple(constants))
x, tuple(past_values) + tuple(rnn_constants))

if getattr(new_output, '_uses_learning_phase', False):
global uses_learning_phase
@@ -1404,7 +1410,7 @@ def _recurrence(x, states, m):
new_output = n_s[0]
return new_output, n_s

final_output, final_states = _recurrence(inputs, states, mask)
final_output, final_states = _recurrence(rnn_inputs, states, mask)
last_output = C.sequence.last(final_output)
last_states = [C.sequence.last(s) for s in final_states]
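
For context on the hunk above: the backend now copies each constant into a fresh `rnn_constants` list and broadcasts along the sequence axis only in the copy, so the caller's `constants` are never mutated. A minimal standalone illustration of that broadcast, assuming a CNTK 2.x install where `input_variable` and `sequence.input_variable` are available (variable names are made up for the example):

import cntk as C

x = C.sequence.input_variable(4)   # has both batch and sequence dynamic axes
k = C.input_variable(3)            # per-sample constant: batch axis only

# Tile the constant along the sequence axis of x so a recurrence can read it
# at every timestep; the original k is left untouched.
k_seq = C.sequence.broadcast_as(k, x)
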

5 changes: 3 additions & 2 deletions keras/layers/convolutional_recurrent.py
@@ -11,6 +11,7 @@
from .. import regularizers
from .. import constraints
from .recurrent import _generate_dropout_mask
from .recurrent import _standardize_args

import numpy as np
import warnings
@@ -270,8 +271,8 @@ def get_initial_state(self, inputs):
return [initial_state]

def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
inputs, initial_state, constants = self._standardize_args(
inputs, initial_state, constants)
inputs, initial_state, constants = _standardize_args(
inputs, initial_state, constants, self._num_constants)

if initial_state is None and constants is None:
return super(ConvRNN2D, self).__call__(inputs, **kwargs)
85 changes: 43 additions & 42 deletions keras/layers/recurrent.py
@@ -493,8 +493,8 @@ def get_initial_state(self, inputs):
return [K.tile(initial_state, [1, self.cell.state_size])]

def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
inputs, initial_state, constants = self._standardize_args(
inputs, initial_state, constants)
inputs, initial_state, constants = _standardize_args(
inputs, initial_state, constants, self._num_constants)

if initial_state is None and constants is None:
return super(RNN, self).__call__(inputs, **kwargs)
@@ -633,46 +633,6 @@ def step(inputs, states):
else:
return output

def _standardize_args(self, inputs, initial_state, constants):
"""Standardize `__call__` to a single list of tensor inputs.
When running a model loaded from file, the input tensors
`initial_state` and `constants` can be passed to `RNN.__call__` as part
of `inputs` instead of by the dedicated keyword arguments. This method
makes sure the arguments are separated and that `initial_state` and
`constants` are lists of tensors (or None).
# Arguments
inputs: tensor or list/tuple of tensors
initial_state: tensor or list of tensors or None
constants: tensor or list of tensors or None
# Returns
inputs: tensor
initial_state: list of tensors or None
constants: list of tensors or None
"""
if isinstance(inputs, list):
assert initial_state is None and constants is None
if self._num_constants is not None:
constants = inputs[-self._num_constants:]
inputs = inputs[:-self._num_constants]
if len(inputs) > 1:
initial_state = inputs[1:]
inputs = inputs[0]

def to_list_or_none(x):
if x is None or isinstance(x, list):
return x
if isinstance(x, tuple):
return list(x)
return [x]

initial_state = to_list_or_none(initial_state)
constants = to_list_or_none(constants)

return inputs, initial_state, constants

def reset_states(self, states=None):
if not self.stateful:
raise AttributeError('Layer must be stateful.')
@@ -2262,3 +2222,44 @@ def dropped_inputs():
dropped_inputs,
ones,
training=training)


def _standardize_args(inputs, initial_state, constants, num_constants):
"""Standardize `__call__` to a single list of tensor inputs.
When running a model loaded from file, the input tensors
`initial_state` and `constants` can be passed to `RNN.__call__` as part
of `inputs` instead of by the dedicated keyword arguments. This method
makes sure the arguments are separated and that `initial_state` and
`constants` are lists of tensors (or None).
# Arguments
inputs: tensor or list/tuple of tensors
initial_state: tensor or list of tensors or None
constants: tensor or list of tensors or None
# Returns
inputs: tensor
initial_state: list of tensors or None
constants: list of tensors or None
"""
if isinstance(inputs, list):
assert initial_state is None and constants is None
if num_constants is not None:
constants = inputs[-num_constants:]
inputs = inputs[:-num_constants]
if len(inputs) > 1:
initial_state = inputs[1:]
inputs = inputs[0]

def to_list_or_none(x):
if x is None or isinstance(x, list):
return x
if isinstance(x, tuple):
return list(x)
return [x]

initial_state = to_list_or_none(initial_state)
constants = to_list_or_none(constants)

return inputs, initial_state, constants
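
A quick illustration of what the relocated helper does when a loaded model packs states and constants into `inputs`. This is a private helper; the tensors below are made up for the example.

from keras.layers import Input
from keras.layers.recurrent import _standardize_args

x = Input((10, 8))   # main input
h = Input((32,))     # an initial state
k = Input((3,))      # a constant

inputs, initial_state, constants = _standardize_args([x, h, k], None, None, 1)
# inputs        -> x
# initial_state -> [h]
# constants     -> [k]
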
105 changes: 74 additions & 31 deletions keras/layers/wrappers.py
@@ -12,6 +12,8 @@
from ..utils.generic_utils import has_arg
from .. import backend as K

from . import recurrent


class Wrapper(Layer):
"""Abstract wrapper base class.
@@ -276,6 +278,7 @@ def __init__(self, layer, merge_mode='concat', weights=None, **kwargs):
self._trainable = True
super(Bidirectional, self).__init__(layer, **kwargs)
self.input_spec = layer.input_spec
self._num_constants = None

@property
def trainable(self):
@@ -314,36 +317,45 @@ def compute_output_shape(self, input_shape):
return [output_shape] + state_shape + copy.copy(state_shape)
return output_shape

def __call__(self, inputs, initial_state=None, **kwargs):
if isinstance(inputs, list):
if len(inputs) > 1:
initial_state = inputs[1:]
inputs = inputs[0]
def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
inputs, initial_state, constants = recurrent._standardize_args(
inputs, initial_state, constants, self._num_constants)

if initial_state is None:
if initial_state is None and constants is None:
return super(Bidirectional, self).__call__(inputs, **kwargs)

# Standardize `initial_state` into list
if isinstance(initial_state, tuple):
initial_state = list(initial_state)
elif not isinstance(initial_state, list):
initial_state = [initial_state]

# Check if `initial_state` can be splitted into half
num_states = len(initial_state)
if num_states % 2 > 0:
raise ValueError(
'When passing `initial_state` to a Bidirectional RNN, the state '
'should be a list containing the states of the underlying RNNs. '
'Found: ' + str(initial_state))

# Applies the same workaround as in `RNN.__call__`, without handling constants
kwargs['initial_state'] = initial_state
additional_inputs = initial_state
additional_specs = [InputSpec(shape=K.int_shape(state))
for state in initial_state]
self.forward_layer.state_spec = additional_specs[:num_states // 2]
self.backward_layer.state_spec = additional_specs[num_states // 2:]
# Applies the same workaround as in `RNN.__call__`
additional_inputs = []
additional_specs = []
if initial_state is not None:
# Check if `initial_state` can be splitted into half
num_states = len(initial_state)
if num_states % 2 > 0:
raise ValueError(
'When passing `initial_state` to a Bidirectional RNN, '
'the state should be a list containing the states of '
'the underlying RNNs. '
'Found: ' + str(initial_state))

kwargs['initial_state'] = initial_state
additional_inputs += initial_state
state_specs = [InputSpec(shape=K.int_shape(state))
for state in initial_state]
self.forward_layer.state_spec = state_specs[:num_states // 2]
self.backward_layer.state_spec = state_specs[num_states // 2:]
additional_specs += state_specs
if constants is not None:
kwargs['constants'] = constants
additional_inputs += constants
constants_spec = [InputSpec(shape=K.int_shape(constant))
for constant in constants]
self.forward_layer.constants_spec = constants_spec
self.backward_layer.constants_spec = constants_spec
additional_specs += constants_spec

self._num_constants = len(constants)
self.forward_layer._num_constants = self._num_constants
self.backward_layer._num_constants = self._num_constants

is_keras_tensor = K.is_keras_tensor(additional_inputs[0])
for tensor in additional_inputs:
@@ -368,12 +380,19 @@ def __call__(self, inputs, initial_state=None, **kwargs):
else:
return super(Bidirectional, self).__call__(inputs, **kwargs)

def call(self, inputs, training=None, mask=None, initial_state=None):
def call(self,
inputs,
mask=None,
training=None,
initial_state=None,
constants=None):
kwargs = {}
if has_arg(self.layer.call, 'training'):
kwargs['training'] = training
if has_arg(self.layer.call, 'mask'):
kwargs['mask'] = mask
if has_arg(self.layer.call, 'constants'):
kwargs['constants'] = constants

if initial_state is not None and has_arg(self.layer.call, 'initial_state'):
forward_state = initial_state[:len(initial_state) // 2]
@@ -429,13 +448,24 @@ def build(self, input_shape):
self.built = True

def compute_mask(self, inputs, mask):
if isinstance(mask, list):
mask = mask[0]
if self.return_sequences:
if not self.merge_mode:
return [mask, mask]
output_mask = [mask, mask]
else:
return mask
output_mask = mask
else:
return None
output_mask = [None, None] if not self.merge_mode else None

if self.return_state:
states = self.forward_layer.states
state_mask = [None for _ in states]
if isinstance(output_mask, list):
return output_mask + state_mask * 2
return [output_mask] + state_mask * 2

return output_mask

@property
def trainable_weights(self):
@@ -473,5 +503,18 @@ def constraints(self):

def get_config(self):
config = {'merge_mode': self.merge_mode}
if self._num_constants is not None:
config['num_constants'] = self._num_constants

base_config = super(Bidirectional, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

@classmethod
def from_config(cls, config, custom_objects=None):
from . import deserialize as deserialize_layer
rnn_layer = deserialize_layer(config.pop('layer'),
custom_objects=custom_objects)
num_constants = config.pop('num_constants', None)
layer = cls(rnn_layer, **config)
layer._num_constants = num_constants
return layer
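
Continuing the earlier CellWithConstants sketch (same hypothetical cell, `x`, and `c`): the new `num_constants` config entry lets a wrapper that was called with constants be rebuilt from its config, mirroring the round-trip pattern used in the test suite.

import copy
from keras.layers import RNN, Bidirectional
from keras.utils import CustomObjectScope

with CustomObjectScope({'CellWithConstants': CellWithConstants}):
    layer = Bidirectional(RNN(CellWithConstants(8)))
    y = layer(x, constants=c)
    config = layer.get_config()        # now includes 'num_constants': 1
    restored = Bidirectional.from_config(copy.deepcopy(config))
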
