This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[Refactor] TransformerDecoder #976

Merged: 3 commits, Oct 29, 2019
10 changes: 5 additions & 5 deletions docs/examples/machine_translation/gnmt.md
@@ -337,12 +337,12 @@ feed the encoder and decoder to the `NMTModel` to construct the GNMT model.
`model.hybridize` allows computation to be done using the symbolic backend. To understand what it means to be "hybridized," please refer to [this](https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/hybrid.html) page on MXNet hybridization and its advantages.
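As a quick, self-contained illustration of hybridization (separate from the GNMT model below; the layer sizes and input shapes here are arbitrary, chosen only for the sketch):

```python
import mxnet as mx
from mxnet.gluon import nn

# A HybridBlock-based network can run imperatively or as a cached symbolic graph.
net = nn.HybridSequential()
net.add(nn.Dense(16, activation='relu'), nn.Dense(2))
net.initialize()

x = mx.nd.random.uniform(shape=(4, 8))
out_imperative = net(x)           # eager, line-by-line execution

net.hybridize(static_alloc=True)  # compile to a symbolic graph; memory is allocated once
out_hybrid = net(x)               # later calls reuse the cached graph and typically run faster
```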

```{.python .input}
encoder, decoder = nmt.gnmt.get_gnmt_encoder_decoder(hidden_size=num_hidden,
dropout=dropout,
num_layers=num_layers,
num_bi_layers=num_bi_layers)
encoder, decoder, one_step_ahead_decoder = nmt.gnmt.get_gnmt_encoder_decoder(
hidden_size=num_hidden, dropout=dropout, num_layers=num_layers,
num_bi_layers=num_bi_layers)
model = nlp.model.translation.NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder,
decoder=decoder, embed_size=num_hidden, prefix='gnmt_')
decoder=decoder, one_step_ahead_decoder=one_step_ahead_decoder,
embed_size=num_hidden, prefix='gnmt_')
model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
static_alloc = True
model.hybridize(static_alloc=static_alloc)
267 changes: 162 additions & 105 deletions scripts/machine_translation/gnmt.py
@@ -15,14 +15,14 @@
# specific language governing permissions and limitations
# under the License.
"""Encoder and decoder usded in sequence-to-sequence learning."""
__all__ = ['GNMTEncoder', 'GNMTDecoder', 'get_gnmt_encoder_decoder']
__all__ = ['GNMTEncoder', 'GNMTDecoder', 'GNMTOneStepDecoder', 'get_gnmt_encoder_decoder']

import mxnet as mx
from mxnet.base import _as_list
from mxnet.gluon import nn, rnn
from mxnet.gluon.block import HybridBlock
from gluonnlp.model.seq2seq_encoder_decoder import Seq2SeqEncoder, Seq2SeqDecoder, \
_get_attention_cell, _get_cell_type, _nested_sequence_last
Seq2SeqOneStepDecoder, _get_attention_cell, _get_cell_type, _nested_sequence_last


class GNMTEncoder(Seq2SeqEncoder):
@@ -158,48 +158,14 @@ def forward(self, inputs, states=None, valid_length=None): #pylint: disable=arg
return [outputs, new_states], []


class GNMTDecoder(HybridBlock, Seq2SeqDecoder):
"""Structure of the RNN Encoder similar to that used in the
Google Neural Machine Translation paper.

We use gnmt_v2 strategy in tensorflow/nmt

Parameters
----------
cell_type : str or type
attention_cell : AttentionCell or str
Arguments of the attention cell.
Can be 'scaled_luong', 'normed_mlp', 'dot'
num_layers : int
hidden_size : int
dropout : float
use_residual : bool
output_attention: bool
Whether to output the attention weights
i2h_weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
h2h_weight_initializer : str or Initializer
Initializer for the recurrent weights matrix, used for the linear
transformation of the recurrent state.
i2h_bias_initializer : str or Initializer
Initializer for the bias vector.
h2h_bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""
class _BaseGNMTDecoder(HybridBlock):
def __init__(self, cell_type='lstm', attention_cell='scaled_luong',
num_layers=2, hidden_size=128,
dropout=0.0, use_residual=True, output_attention=False,
i2h_weight_initializer=None, h2h_weight_initializer=None,
i2h_bias_initializer='zeros', h2h_bias_initializer='zeros',
prefix=None, params=None):
super(GNMTDecoder, self).__init__(prefix=prefix, params=params)
super().__init__(prefix=prefix, params=params)
self._cell_type = _get_cell_type(cell_type)
self._num_layers = num_layers
self._hidden_size = hidden_size
@@ -249,59 +215,7 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None):
decoder_states.append(mem_masks)
return decoder_states

def decode_seq(self, inputs, states, valid_length=None):
"""Decode the decoder inputs. This function is only used for training.

Parameters
----------
inputs : NDArray, Shape (batch_size, length, C_in)
states : list of NDArrays or None
Initial states. The list of initial decoder states
valid_length : NDArray or None
Valid lengths of each sequence. This is usually used when part of sequence has
been padded. Shape (batch_size,)

Returns
-------
output : NDArray, Shape (batch_size, length, C_out)
states : list
The decoder states, includes:

- rnn_states : NDArray
- attention_vec : NDArray
- mem_value : NDArray
- mem_masks : NDArray, optional
additional_outputs : list
Either be an empty list or contains the attention weights in this step.
The attention weights will have shape (batch_size, length, mem_length) or
(batch_size, num_heads, length, mem_length)
"""
length = inputs.shape[1]
output = []
additional_outputs = []
inputs = _as_list(mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True))
rnn_states_l = []
attention_output_l = []
fixed_states = states[2:]
for i in range(length):
ele_output, states, ele_additional_outputs = self.forward(inputs[i], states)
rnn_states_l.append(states[0])
attention_output_l.append(states[1])
output.append(ele_output)
additional_outputs.extend(ele_additional_outputs)
output = mx.nd.stack(*output, axis=1)
if valid_length is not None:
states = [_nested_sequence_last(rnn_states_l, valid_length),
_nested_sequence_last(attention_output_l, valid_length)] + fixed_states
output = mx.nd.SequenceMask(output,
sequence_length=valid_length,
use_sequence_length=True,
axis=1)
if self._output_attention:
additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)]
return output, states, additional_outputs

def __call__(self, step_input, states): #pylint: disable=arguments-differ
def forward(self, step_input, states): # pylint: disable=arguments-differ
"""One-step-ahead decoding of the GNMT decoder.

Parameters
@@ -326,11 +240,7 @@ def __call__(self, step_input, states): #pylint: disable=arguments-differ
The attention weights will have shape (batch_size, 1, mem_length) or
(batch_size, num_heads, 1, mem_length)
"""
return super(GNMTDecoder, self).__call__(step_input, states)

def forward(self, step_input, states): #pylint: disable=arguments-differ, missing-docstring
step_output, new_states, step_additional_outputs =\
super(GNMTDecoder, self).forward(step_input, states)
step_output, new_states, step_additional_outputs = super().forward(step_input, states)
# In hybrid_forward, only the rnn_states and attention_vec are calculated.
# We directly append the mem_value and mem_masks in the forward() function.
# We apply this trick because the memory value/mask can be directly appended to the next
@@ -402,6 +312,148 @@ def hybrid_forward(self, F, step_input, states): #pylint: disable=arguments-dif
return rnn_out, new_states, step_additional_outputs


class GNMTOneStepDecoder(_BaseGNMTDecoder, Seq2SeqOneStepDecoder):
"""RNN Encoder similar to that used in the Google Neural Machine Translation paper.

One-step ahead decoder used during inference.

We use gnmt_v2 strategy in tensorflow/nmt

Parameters
----------
cell_type : str or type
Can be "lstm", "gru" or constructor functions that can be directly called,
like rnn.LSTMCell
attention_cell : AttentionCell or str
Arguments of the attention cell.
Can be 'scaled_luong', 'normed_mlp', 'dot'
num_layers : int
Total number of layers
hidden_size : int
Number of hidden units
dropout : float
The dropout rate
use_residual : bool
Whether to use residual connection. Residual connection will be added in the
uni-directional RNN layers
output_attention: bool
Whether to output the attention weights
i2h_weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
h2h_weight_initializer : str or Initializer
Initializer for the recurrent weights matrix, used for the linear
transformation of the recurrent state.
i2h_bias_initializer : str or Initializer
Initializer for the bias vector.
h2h_bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""


class GNMTDecoder(_BaseGNMTDecoder, Seq2SeqDecoder):
"""RNN Encoder similar to that used in the Google Neural Machine Translation paper.

Multi-step decoder used during training with teacher forcing.

We use gnmt_v2 strategy in tensorflow/nmt

Parameters
----------
cell_type : str or type
Can be "lstm", "gru" or constructor functions that can be directly called,
like rnn.LSTMCell
attention_cell : AttentionCell or str
Arguments of the attention cell.
Can be 'scaled_luong', 'normed_mlp', 'dot'
num_layers : int
Total number of layers
hidden_size : int
Number of hidden units
dropout : float
The dropout rate
use_residual : bool
Whether to use residual connection. Residual connection will be added in the
uni-directional RNN layers
output_attention: bool
Whether to output the attention weights
i2h_weight_initializer : str or Initializer
Initializer for the input weights matrix, used for the linear
transformation of the inputs.
h2h_weight_initializer : str or Initializer
Initializer for the recurrent weights matrix, used for the linear
transformation of the recurrent state.
i2h_bias_initializer : str or Initializer
Initializer for the bias vector.
h2h_bias_initializer : str or Initializer
Initializer for the bias vector.
prefix : str, default 'rnn_'
Prefix for name of `Block`s
(and name of weight if params is `None`).
params : Parameter or None
Container for weight sharing between cells.
Created if `None`.
"""

def forward(self, inputs, states, valid_length=None): # pylint: disable=arguments-differ
"""Decode the decoder inputs. This function is only used for training.

Parameters
----------
inputs : NDArray, Shape (batch_size, length, C_in)
states : list of NDArrays or None
Initial states. The list of initial decoder states
valid_length : NDArray or None
Valid lengths of each sequence. This is usually used when part of the sequence has
been padded. Shape (batch_size,)

Returns
-------
output : NDArray, Shape (batch_size, length, C_out)
states : list
The decoder states, which include:

- rnn_states : NDArray
- attention_vec : NDArray
- mem_value : NDArray
- mem_masks : NDArray, optional
additional_outputs : list
Either an empty list, or a list containing the attention weights from this step.
The attention weights will have shape (batch_size, length, mem_length) or
(batch_size, num_heads, length, mem_length)
"""
length = inputs.shape[1]
output = []
additional_outputs = []
inputs = _as_list(mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True))
rnn_states_l = []
attention_output_l = []
fixed_states = states[2:]
for i in range(length):
ele_output, states, ele_additional_outputs = super().forward(inputs[i], states)
rnn_states_l.append(states[0])
attention_output_l.append(states[1])
output.append(ele_output)
additional_outputs.extend(ele_additional_outputs)
output = mx.nd.stack(*output, axis=1)
if valid_length is not None:
states = [_nested_sequence_last(rnn_states_l, valid_length),
_nested_sequence_last(attention_output_l, valid_length)] + fixed_states
output = mx.nd.SequenceMask(output,
sequence_length=valid_length,
use_sequence_length=True,
axis=1)
if self._output_attention:
additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)]
return output, states, additional_outputs
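Note on the structure of the refactor above: `_BaseGNMTDecoder.forward` implements a single decoding step, `GNMTOneStepDecoder` inherits it unchanged for inference, and `GNMTDecoder` overrides `forward` to unroll that step over the full target sequence for training. A stripped-down sketch of this split (plain Python, all names illustrative, no MXNet specifics):

```python
class _BaseStepDecoder:
    """Stand-in for _BaseGNMTDecoder: knows how to perform one decoding step."""
    def forward(self, step_input, states):
        new_state = states[0] + step_input      # placeholder for the RNN + attention update
        return new_state, [new_state]

class OneStepDecoder(_BaseStepDecoder):
    """Stand-in for GNMTOneStepDecoder: inherits the single-step forward as-is."""

class TrainingDecoder(_BaseStepDecoder):
    """Stand-in for GNMTDecoder: unrolls the inherited step over a full sequence."""
    def forward(self, inputs, states, valid_length=None):
        outputs = []
        for step_input in inputs:               # one call to the base step per position
            out, states = super().forward(step_input, states)
            outputs.append(out)
        return outputs, states

# Both decoders run the same per-step logic, mirroring the parameter sharing above.
train_dec, infer_dec = TrainingDecoder(), OneStepDecoder()
print(train_dec.forward([1, 2, 3], [0]))   # ([1, 3, 6], [6])
print(infer_dec.forward(4, [6]))           # (10, [10])
```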


def get_gnmt_encoder_decoder(cell_type='lstm', attention_cell='scaled_luong', num_layers=2,
num_bi_layers=1, hidden_size=128, dropout=0.0, use_residual=False,
i2h_weight_initializer=None, h2h_weight_initializer=None,
@@ -435,19 +487,24 @@ def get_gnmt_encoder_decoder(cell_type='lstm', attention_cell='scaled_luong', nu
decoder : GNMTDecoder
"""
encoder = GNMTEncoder(cell_type=cell_type, num_layers=num_layers, num_bi_layers=num_bi_layers,
hidden_size=hidden_size, dropout=dropout,
use_residual=use_residual,
hidden_size=hidden_size, dropout=dropout, use_residual=use_residual,
i2h_weight_initializer=i2h_weight_initializer,
h2h_weight_initializer=h2h_weight_initializer,
i2h_bias_initializer=i2h_bias_initializer,
h2h_bias_initializer=h2h_bias_initializer,
prefix=prefix + 'enc_', params=params)
h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'enc_',
params=params)
decoder = GNMTDecoder(cell_type=cell_type, attention_cell=attention_cell, num_layers=num_layers,
hidden_size=hidden_size, dropout=dropout,
use_residual=use_residual,
hidden_size=hidden_size, dropout=dropout, use_residual=use_residual,
i2h_weight_initializer=i2h_weight_initializer,
h2h_weight_initializer=h2h_weight_initializer,
i2h_bias_initializer=i2h_bias_initializer,
h2h_bias_initializer=h2h_bias_initializer,
prefix=prefix + 'dec_', params=params)
return encoder, decoder
h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'dec_',
params=params)
one_step_ahead_decoder = GNMTOneStepDecoder(
cell_type=cell_type, attention_cell=attention_cell, num_layers=num_layers,
hidden_size=hidden_size, dropout=dropout, use_residual=use_residual,
i2h_weight_initializer=i2h_weight_initializer,
h2h_weight_initializer=h2h_weight_initializer, i2h_bias_initializer=i2h_bias_initializer,
h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'dec_',
params=decoder.collect_params())
return encoder, decoder, one_step_ahead_decoder
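Because `GNMTOneStepDecoder` is constructed with `params=decoder.collect_params()` and the same `'dec_'` prefix, the two decoders are expected to resolve to the same underlying parameters: weights learned through `decoder` during training should be picked up by `one_step_ahead_decoder` at inference time without any copying. A rough usage sketch (assuming the remaining default arguments and that `scripts/machine_translation` is on the Python path so the module imports as `gnmt`):

```python
from gnmt import get_gnmt_encoder_decoder  # assumes scripts/machine_translation is on the path

encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder(
    hidden_size=128, num_layers=2, num_bi_layers=1, prefix='gnmt_')

# The training decoder and the one-step-ahead decoder should share parameter objects,
# so their collected parameter names are expected to match.
train_params = decoder.collect_params()
infer_params = one_step_ahead_decoder.collect_params()
print(sorted(train_params.keys()) == sorted(infer_params.keys()))  # expected: True
```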
17 changes: 7 additions & 10 deletions scripts/machine_translation/inference_transformer.py
@@ -154,17 +154,14 @@
else:
tgt_max_len = max_len[1]

encoder, decoder = get_transformer_encoder_decoder(units=args.num_units,
hidden_size=args.hidden_size,
dropout=args.dropout,
num_layers=args.num_layers,
num_heads=args.num_heads,
max_src_length=max(src_max_len, 500),
max_tgt_length=max(tgt_max_len, 500),
scaled=args.scaled)
encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder(
units=args.num_units, hidden_size=args.hidden_size, dropout=args.dropout,
num_layers=args.num_layers, num_heads=args.num_heads, max_src_length=max(src_max_len, 500),
max_tgt_length=max(tgt_max_len, 500), scaled=args.scaled)
model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
share_embed=args.dataset != 'TOY', embed_size=args.num_units,
tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_')
one_step_ahead_decoder=one_step_ahead_decoder, share_embed=args.dataset != 'TOY',
embed_size=args.num_units, tie_weights=args.dataset != 'TOY',
embed_initializer=None, prefix='transformer_')

param_name = args.model_parameter
if (not os.path.exists(param_name)):
14 changes: 7 additions & 7 deletions scripts/machine_translation/train_gnmt.py
@@ -122,12 +122,12 @@
else:
ctx = mx.gpu(args.gpu)

encoder, decoder = get_gnmt_encoder_decoder(hidden_size=args.num_hidden,
dropout=args.dropout,
num_layers=args.num_layers,
num_bi_layers=args.num_bi_layers)
encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder(
hidden_size=args.num_hidden, dropout=args.dropout, num_layers=args.num_layers,
num_bi_layers=args.num_bi_layers)
model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
embed_size=args.num_hidden, prefix='gnmt_')
one_step_ahead_decoder=one_step_ahead_decoder, embed_size=args.num_hidden,
prefix='gnmt_')
model.initialize(init=mx.init.Uniform(0.1), ctx=ctx)
static_alloc = True
model.hybridize(static_alloc=static_alloc)
@@ -175,8 +175,8 @@ def evaluate(data_loader):
avg_loss += loss * (tgt_seq.shape[1] - 1)
avg_loss_denom += (tgt_seq.shape[1] - 1)
# Translate
samples, _, sample_valid_length =\
translator.translate(src_seq=src_seq, src_valid_length=src_valid_length)
samples, _, sample_valid_length = translator.translate(
src_seq=src_seq, src_valid_length=src_valid_length)
max_score_sample = samples[:, 0, :].asnumpy()
sample_valid_length = sample_valid_length[:, 0].asnumpy()
for i in range(max_score_sample.shape[0]):