diff --git a/docs/examples/machine_translation/gnmt.md b/docs/examples/machine_translation/gnmt.md index f195317b5a..9c8e45a324 100644 --- a/docs/examples/machine_translation/gnmt.md +++ b/docs/examples/machine_translation/gnmt.md @@ -337,12 +337,12 @@ feed the encoder and decoder to the `NMTModel` to construct the GNMT model. `model.hybridize` allows computation to be done using the symbolic backend. To understand what it means to be "hybridized," please refer to [this](https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/hybrid.html) page on MXNet hybridization and its advantages. ```{.python .input} -encoder, decoder = nmt.gnmt.get_gnmt_encoder_decoder(hidden_size=num_hidden, - dropout=dropout, - num_layers=num_layers, - num_bi_layers=num_bi_layers) +encoder, decoder, one_step_ahead_decoder = nmt.gnmt.get_gnmt_encoder_decoder( + hidden_size=num_hidden, dropout=dropout, num_layers=num_layers, + num_bi_layers=num_bi_layers) model = nlp.model.translation.NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, - decoder=decoder, embed_size=num_hidden, prefix='gnmt_') + decoder=decoder, one_step_ahead_decoder=one_step_ahead_decoder, + embed_size=num_hidden, prefix='gnmt_') model.initialize(init=mx.init.Uniform(0.1), ctx=ctx) static_alloc = True model.hybridize(static_alloc=static_alloc) diff --git a/scripts/machine_translation/gnmt.py b/scripts/machine_translation/gnmt.py index 6910ec3ac3..cf3c82aa13 100644 --- a/scripts/machine_translation/gnmt.py +++ b/scripts/machine_translation/gnmt.py @@ -15,14 +15,14 @@ # specific language governing permissions and limitations # under the License. """Encoder and decoder usded in sequence-to-sequence learning.""" -__all__ = ['GNMTEncoder', 'GNMTDecoder', 'get_gnmt_encoder_decoder'] +__all__ = ['GNMTEncoder', 'GNMTDecoder', 'GNMTOneStepDecoder', 'get_gnmt_encoder_decoder'] import mxnet as mx from mxnet.base import _as_list from mxnet.gluon import nn, rnn from mxnet.gluon.block import HybridBlock from gluonnlp.model.seq2seq_encoder_decoder import Seq2SeqEncoder, Seq2SeqDecoder, \ - _get_attention_cell, _get_cell_type, _nested_sequence_last + Seq2SeqOneStepDecoder, _get_attention_cell, _get_cell_type, _nested_sequence_last class GNMTEncoder(Seq2SeqEncoder): @@ -158,48 +158,14 @@ def forward(self, inputs, states=None, valid_length=None): #pylint: disable=arg return [outputs, new_states], [] -class GNMTDecoder(HybridBlock, Seq2SeqDecoder): - """Structure of the RNN Encoder similar to that used in the - Google Neural Machine Translation paper. - - We use gnmt_v2 strategy in tensorflow/nmt - - Parameters - ---------- - cell_type : str or type - attention_cell : AttentionCell or str - Arguments of the attention cell. - Can be 'scaled_luong', 'normed_mlp', 'dot' - num_layers : int - hidden_size : int - dropout : float - use_residual : bool - output_attention: bool - Whether to output the attention weights - i2h_weight_initializer : str or Initializer - Initializer for the input weights matrix, used for the linear - transformation of the inputs. - h2h_weight_initializer : str or Initializer - Initializer for the recurrent weights matrix, used for the linear - transformation of the recurrent state. - i2h_bias_initializer : str or Initializer - Initializer for the bias vector. - h2h_bias_initializer : str or Initializer - Initializer for the bias vector. - prefix : str, default 'rnn_' - Prefix for name of `Block`s - (and name of weight if params is `None`). - params : Parameter or None - Container for weight sharing between cells. 
- Created if `None`. - """ +class _BaseGNMTDecoder(HybridBlock): def __init__(self, cell_type='lstm', attention_cell='scaled_luong', num_layers=2, hidden_size=128, dropout=0.0, use_residual=True, output_attention=False, i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', prefix=None, params=None): - super(GNMTDecoder, self).__init__(prefix=prefix, params=params) + super().__init__(prefix=prefix, params=params) self._cell_type = _get_cell_type(cell_type) self._num_layers = num_layers self._hidden_size = hidden_size @@ -249,59 +215,7 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): decoder_states.append(mem_masks) return decoder_states - def decode_seq(self, inputs, states, valid_length=None): - """Decode the decoder inputs. This function is only used for training. - - Parameters - ---------- - inputs : NDArray, Shape (batch_size, length, C_in) - states : list of NDArrays or None - Initial states. The list of initial decoder states - valid_length : NDArray or None - Valid lengths of each sequence. This is usually used when part of sequence has - been padded. Shape (batch_size,) - - Returns - ------- - output : NDArray, Shape (batch_size, length, C_out) - states : list - The decoder states, includes: - - - rnn_states : NDArray - - attention_vec : NDArray - - mem_value : NDArray - - mem_masks : NDArray, optional - additional_outputs : list - Either be an empty list or contains the attention weights in this step. - The attention weights will have shape (batch_size, length, mem_length) or - (batch_size, num_heads, length, mem_length) - """ - length = inputs.shape[1] - output = [] - additional_outputs = [] - inputs = _as_list(mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True)) - rnn_states_l = [] - attention_output_l = [] - fixed_states = states[2:] - for i in range(length): - ele_output, states, ele_additional_outputs = self.forward(inputs[i], states) - rnn_states_l.append(states[0]) - attention_output_l.append(states[1]) - output.append(ele_output) - additional_outputs.extend(ele_additional_outputs) - output = mx.nd.stack(*output, axis=1) - if valid_length is not None: - states = [_nested_sequence_last(rnn_states_l, valid_length), - _nested_sequence_last(attention_output_l, valid_length)] + fixed_states - output = mx.nd.SequenceMask(output, - sequence_length=valid_length, - use_sequence_length=True, - axis=1) - if self._output_attention: - additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)] - return output, states, additional_outputs - - def __call__(self, step_input, states): #pylint: disable=arguments-differ + def forward(self, step_input, states): # pylint: disable=arguments-differ """One-step-ahead decoding of the GNMT decoder. Parameters @@ -326,11 +240,7 @@ def __call__(self, step_input, states): #pylint: disable=arguments-differ The attention weights will have shape (batch_size, 1, mem_length) or (batch_size, num_heads, 1, mem_length) """ - return super(GNMTDecoder, self).__call__(step_input, states) - - def forward(self, step_input, states): #pylint: disable=arguments-differ, missing-docstring - step_output, new_states, step_additional_outputs =\ - super(GNMTDecoder, self).forward(step_input, states) + step_output, new_states, step_additional_outputs = super().forward(step_input, states) # In hybrid_forward, only the rnn_states and attention_vec are calculated. # We directly append the mem_value and mem_masks in the forward() function. 
# We apply this trick because the memory value/mask can be directly appended to the next @@ -402,6 +312,148 @@ def hybrid_forward(self, F, step_input, states): #pylint: disable=arguments-dif return rnn_out, new_states, step_additional_outputs +class GNMTOneStepDecoder(_BaseGNMTDecoder, Seq2SeqOneStepDecoder): + """RNN Encoder similar to that used in the Google Neural Machine Translation paper. + + One-step ahead decoder used during inference. + + We use gnmt_v2 strategy in tensorflow/nmt + + Parameters + ---------- + cell_type : str or type + Can be "lstm", "gru" or constructor functions that can be directly called, + like rnn.LSTMCell + attention_cell : AttentionCell or str + Arguments of the attention cell. + Can be 'scaled_luong', 'normed_mlp', 'dot' + num_layers : int + Total number of layers + hidden_size : int + Number of hidden units + dropout : float + The dropout rate + use_residual : bool + Whether to use residual connection. Residual connection will be added in the + uni-directional RNN layers + output_attention: bool + Whether to output the attention weights + i2h_weight_initializer : str or Initializer + Initializer for the input weights matrix, used for the linear + transformation of the inputs. + h2h_weight_initializer : str or Initializer + Initializer for the recurrent weights matrix, used for the linear + transformation of the recurrent state. + i2h_bias_initializer : str or Initializer + Initializer for the bias vector. + h2h_bias_initializer : str or Initializer + Initializer for the bias vector. + prefix : str, default 'rnn_' + Prefix for name of `Block`s + (and name of weight if params is `None`). + params : Parameter or None + Container for weight sharing between cells. + Created if `None`. + """ + + +class GNMTDecoder(_BaseGNMTDecoder, Seq2SeqDecoder): + """RNN Encoder similar to that used in the Google Neural Machine Translation paper. + + Multi-step decoder used during training with teacher forcing. + + We use gnmt_v2 strategy in tensorflow/nmt + + Parameters + ---------- + cell_type : str or type + Can be "lstm", "gru" or constructor functions that can be directly called, + like rnn.LSTMCell + attention_cell : AttentionCell or str + Arguments of the attention cell. + Can be 'scaled_luong', 'normed_mlp', 'dot' + num_layers : int + Total number of layers + hidden_size : int + Number of hidden units + dropout : float + The dropout rate + use_residual : bool + Whether to use residual connection. Residual connection will be added in the + uni-directional RNN layers + output_attention: bool + Whether to output the attention weights + i2h_weight_initializer : str or Initializer + Initializer for the input weights matrix, used for the linear + transformation of the inputs. + h2h_weight_initializer : str or Initializer + Initializer for the recurrent weights matrix, used for the linear + transformation of the recurrent state. + i2h_bias_initializer : str or Initializer + Initializer for the bias vector. + h2h_bias_initializer : str or Initializer + Initializer for the bias vector. + prefix : str, default 'rnn_' + Prefix for name of `Block`s + (and name of weight if params is `None`). + params : Parameter or None + Container for weight sharing between cells. + Created if `None`. + """ + + def forward(self, inputs, states, valid_length=None): # pylint: disable=arguments-differ + """Decode the decoder inputs. This function is only used for training. 
+ + Parameters + ---------- + inputs : NDArray, Shape (batch_size, length, C_in) + states : list of NDArrays or None + Initial states. The list of initial decoder states + valid_length : NDArray or None + Valid lengths of each sequence. This is usually used when part of sequence has + been padded. Shape (batch_size,) + + Returns + ------- + output : NDArray, Shape (batch_size, length, C_out) + states : list + The decoder states, includes: + + - rnn_states : NDArray + - attention_vec : NDArray + - mem_value : NDArray + - mem_masks : NDArray, optional + additional_outputs : list + Either be an empty list or contains the attention weights in this step. + The attention weights will have shape (batch_size, length, mem_length) or + (batch_size, num_heads, length, mem_length) + """ + length = inputs.shape[1] + output = [] + additional_outputs = [] + inputs = _as_list(mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True)) + rnn_states_l = [] + attention_output_l = [] + fixed_states = states[2:] + for i in range(length): + ele_output, states, ele_additional_outputs = super().forward(inputs[i], states) + rnn_states_l.append(states[0]) + attention_output_l.append(states[1]) + output.append(ele_output) + additional_outputs.extend(ele_additional_outputs) + output = mx.nd.stack(*output, axis=1) + if valid_length is not None: + states = [_nested_sequence_last(rnn_states_l, valid_length), + _nested_sequence_last(attention_output_l, valid_length)] + fixed_states + output = mx.nd.SequenceMask(output, + sequence_length=valid_length, + use_sequence_length=True, + axis=1) + if self._output_attention: + additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)] + return output, states, additional_outputs + + def get_gnmt_encoder_decoder(cell_type='lstm', attention_cell='scaled_luong', num_layers=2, num_bi_layers=1, hidden_size=128, dropout=0.0, use_residual=False, i2h_weight_initializer=None, h2h_weight_initializer=None, @@ -435,19 +487,24 @@ def get_gnmt_encoder_decoder(cell_type='lstm', attention_cell='scaled_luong', nu decoder : GNMTDecoder """ encoder = GNMTEncoder(cell_type=cell_type, num_layers=num_layers, num_bi_layers=num_bi_layers, - hidden_size=hidden_size, dropout=dropout, - use_residual=use_residual, + hidden_size=hidden_size, dropout=dropout, use_residual=use_residual, i2h_weight_initializer=i2h_weight_initializer, h2h_weight_initializer=h2h_weight_initializer, i2h_bias_initializer=i2h_bias_initializer, - h2h_bias_initializer=h2h_bias_initializer, - prefix=prefix + 'enc_', params=params) + h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'enc_', + params=params) decoder = GNMTDecoder(cell_type=cell_type, attention_cell=attention_cell, num_layers=num_layers, - hidden_size=hidden_size, dropout=dropout, - use_residual=use_residual, + hidden_size=hidden_size, dropout=dropout, use_residual=use_residual, i2h_weight_initializer=i2h_weight_initializer, h2h_weight_initializer=h2h_weight_initializer, i2h_bias_initializer=i2h_bias_initializer, - h2h_bias_initializer=h2h_bias_initializer, - prefix=prefix + 'dec_', params=params) - return encoder, decoder + h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'dec_', + params=params) + one_step_ahead_decoder = GNMTOneStepDecoder( + cell_type=cell_type, attention_cell=attention_cell, num_layers=num_layers, + hidden_size=hidden_size, dropout=dropout, use_residual=use_residual, + i2h_weight_initializer=i2h_weight_initializer, + h2h_weight_initializer=h2h_weight_initializer, i2h_bias_initializer=i2h_bias_initializer, + 
h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'dec_', + params=decoder.collect_params()) + return encoder, decoder, one_step_ahead_decoder diff --git a/scripts/machine_translation/inference_transformer.py b/scripts/machine_translation/inference_transformer.py index 3128d5f64d..8c048c300b 100644 --- a/scripts/machine_translation/inference_transformer.py +++ b/scripts/machine_translation/inference_transformer.py @@ -154,17 +154,14 @@ else: tgt_max_len = max_len[1] -encoder, decoder = get_transformer_encoder_decoder(units=args.num_units, - hidden_size=args.hidden_size, - dropout=args.dropout, - num_layers=args.num_layers, - num_heads=args.num_heads, - max_src_length=max(src_max_len, 500), - max_tgt_length=max(tgt_max_len, 500), - scaled=args.scaled) +encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder( + units=args.num_units, hidden_size=args.hidden_size, dropout=args.dropout, + num_layers=args.num_layers, num_heads=args.num_heads, max_src_length=max(src_max_len, 500), + max_tgt_length=max(tgt_max_len, 500), scaled=args.scaled) model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, - share_embed=args.dataset != 'TOY', embed_size=args.num_units, - tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_') + one_step_ahead_decoder=one_step_ahead_decoder, share_embed=args.dataset != 'TOY', + embed_size=args.num_units, tie_weights=args.dataset != 'TOY', + embed_initializer=None, prefix='transformer_') param_name = args.model_parameter if (not os.path.exists(param_name)): diff --git a/scripts/machine_translation/train_gnmt.py b/scripts/machine_translation/train_gnmt.py index 11dbe69c99..c9b9b8aa28 100644 --- a/scripts/machine_translation/train_gnmt.py +++ b/scripts/machine_translation/train_gnmt.py @@ -122,12 +122,12 @@ else: ctx = mx.gpu(args.gpu) -encoder, decoder = get_gnmt_encoder_decoder(hidden_size=args.num_hidden, - dropout=args.dropout, - num_layers=args.num_layers, - num_bi_layers=args.num_bi_layers) +encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder( + hidden_size=args.num_hidden, dropout=args.dropout, num_layers=args.num_layers, + num_bi_layers=args.num_bi_layers) model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, - embed_size=args.num_hidden, prefix='gnmt_') + one_step_ahead_decoder=one_step_ahead_decoder, embed_size=args.num_hidden, + prefix='gnmt_') model.initialize(init=mx.init.Uniform(0.1), ctx=ctx) static_alloc = True model.hybridize(static_alloc=static_alloc) @@ -175,8 +175,8 @@ def evaluate(data_loader): avg_loss += loss * (tgt_seq.shape[1] - 1) avg_loss_denom += (tgt_seq.shape[1] - 1) # Translate - samples, _, sample_valid_length =\ - translator.translate(src_seq=src_seq, src_valid_length=src_valid_length) + samples, _, sample_valid_length = translator.translate( + src_seq=src_seq, src_valid_length=src_valid_length) max_score_sample = samples[:, 0, :].asnumpy() sample_valid_length = sample_valid_length[:, 0].asnumpy() for i in range(max_score_sample.shape[0]): diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index 09646bc650..a366c7e36e 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -33,25 +33,25 @@ # pylint:disable=redefined-outer-name,logging-format-interpolation import argparse -import time -import random -import os import logging import math +import os +import random +import time + 
import numpy as np import mxnet as mx from mxnet import gluon -import gluonnlp as nlp -from gluonnlp.loss import MaskedSoftmaxCELoss, LabelSmoothing +import gluonnlp as nlp +from gluonnlp.loss import LabelSmoothing, MaskedSoftmaxCELoss +from gluonnlp.model.transformer import ParallelTransformer, get_transformer_encoder_decoder from gluonnlp.model.translation import NMTModel -from gluonnlp.model.transformer import get_transformer_encoder_decoder, ParallelTransformer from gluonnlp.utils.parallel import Parallel +import dataprocessor +from bleu import _bpe_to_words, compute_bleu from translation import BeamSearchTranslator - from utils import logging_config -from bleu import _bpe_to_words, compute_bleu -import dataprocessor np.random.seed(100) random.seed(100) @@ -174,15 +174,12 @@ tgt_max_len = args.tgt_max_len else: tgt_max_len = max_len[1] -encoder, decoder = get_transformer_encoder_decoder(units=args.num_units, - hidden_size=args.hidden_size, - dropout=args.dropout, - num_layers=args.num_layers, - num_heads=args.num_heads, - max_src_length=max(src_max_len, 500), - max_tgt_length=max(tgt_max_len, 500), - scaled=args.scaled) +encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder( + units=args.num_units, hidden_size=args.hidden_size, dropout=args.dropout, + num_layers=args.num_layers, num_heads=args.num_heads, max_src_length=max(src_max_len, 500), + max_tgt_length=max(tgt_max_len, 500), scaled=args.scaled) model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, + one_step_ahead_decoder=one_step_ahead_decoder, share_embed=args.dataset not in ('TOY', 'IWSLT2015'), embed_size=args.num_units, tie_weights=args.dataset not in ('TOY', 'IWSLT2015'), embed_initializer=None, prefix='transformer_') diff --git a/scripts/tests/test_encoder_decoder.py b/scripts/tests/test_encoder_decoder.py index 9e358d4025..695785a619 100644 --- a/scripts/tests/test_encoder_decoder.py +++ b/scripts/tests/test_encoder_decoder.py @@ -15,12 +15,14 @@ # specific language governing permissions and limitations # under the License. 
+import pytest + import numpy as np import mxnet as mx from mxnet.test_utils import assert_almost_equal from ..machine_translation.gnmt import * from gluonnlp.model.transformer import * -from gluonnlp.model.transformer import TransformerDecoder +from gluonnlp.model.transformer import TransformerDecoder, TransformerOneStepDecoder def test_gnmt_encoder(): @@ -63,6 +65,11 @@ def test_gnmt_encoder_decoder(): output_attention=output_attention, use_residual=use_residual, prefix='gnmt_decoder_') decoder.initialize(ctx=ctx) decoder.hybridize() + one_step_decoder = GNMTOneStepDecoder(cell_type="lstm", num_layers=3, hidden_size=num_hidden, + dropout=0.0, output_attention=output_attention, + use_residual=use_residual, prefix='gnmt_decoder_', + params=decoder.collect_params()) + one_step_decoder.hybridize() for batch_size in [4]: for src_seq_length, tgt_seq_length in [(5, 10), (10, 5)]: src_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, src_seq_length, 4), ctx=ctx) @@ -75,9 +82,7 @@ def test_gnmt_encoder_decoder(): decoder_states = decoder.init_state_from_encoder(encoder_outputs, src_valid_length_nd) # Test multi step forwarding - output, new_states, additional_outputs = decoder.decode_seq(tgt_seq_nd, - decoder_states, - tgt_valid_length_nd) + output, new_states, additional_outputs = decoder(tgt_seq_nd, decoder_states, tgt_valid_length_nd) assert(output.shape == (batch_size, tgt_seq_length, num_hidden)) output_npy = output.asnumpy() for i in range(batch_size): @@ -98,6 +103,25 @@ def test_gnmt_encoder_decoder(): else: assert(len(additional_outputs) == 0) + # Test one-step forwarding + output, new_states, additional_outputs = one_step_decoder( + tgt_seq_nd[:, 0, :], decoder_states) + assert(output.shape == (batch_size, num_hidden)) + if output_attention: + assert(len(additional_outputs) == 1) + attention_out = additional_outputs[0].asnumpy() + assert(attention_out.shape == (batch_size, 1, src_seq_length)) + for i in range(batch_size): + mem_v_len = int(src_valid_length_npy[i]) + if mem_v_len < src_seq_length - 1: + assert((attention_out[i, :, mem_v_len:] == 0).all()) + if mem_v_len > 0: + assert_almost_equal(attention_out[i, :, :].sum(axis=-1), + np.ones(attention_out.shape[1])) + else: + assert(len(additional_outputs) == 0) + + def test_transformer_encoder(): ctx = mx.current_context() for num_layers in range(1, 3): @@ -136,51 +160,77 @@ def test_transformer_encoder(): else: assert(len(additional_outputs) == 0) -def test_transformer_encoder_decoder(): +@pytest.mark.parametrize('output_attention', [False, True]) +@pytest.mark.parametrize('use_residual', [False, True]) +@pytest.mark.parametrize('batch_size', [4]) +@pytest.mark.parametrize('src_tgt_seq_len', [(5, 10), (10, 5)]) +def test_transformer_encoder_decoder(output_attention, use_residual, batch_size, src_tgt_seq_len): ctx = mx.current_context() units = 16 encoder = TransformerEncoder(num_layers=3, units=units, hidden_size=32, num_heads=8, max_length=10, dropout=0.0, use_residual=True, prefix='transformer_encoder_') encoder.initialize(ctx=ctx) encoder.hybridize() - for output_attention in [True, False]: - for use_residual in [True, False]: - decoder = TransformerDecoder(num_layers=3, units=units, hidden_size=32, num_heads=8, max_length=10, dropout=0.0, - output_attention=output_attention, use_residual=use_residual, prefix='transformer_decoder_') - decoder.initialize(ctx=ctx) - decoder.hybridize() - for batch_size in [4]: - for src_seq_length, tgt_seq_length in [(5, 10), (10, 5)]: - src_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, 
src_seq_length, units), ctx=ctx) - tgt_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, tgt_seq_length, units), ctx=ctx) - src_valid_length_nd = mx.nd.array(np.random.randint(1, src_seq_length, size=(batch_size,)), ctx=ctx) - tgt_valid_length_nd = mx.nd.array(np.random.randint(1, tgt_seq_length, size=(batch_size,)), ctx=ctx) - src_valid_length_npy = src_valid_length_nd.asnumpy() - tgt_valid_length_npy = tgt_valid_length_nd.asnumpy() - encoder_outputs, _ = encoder(src_seq_nd, valid_length=src_valid_length_nd) - decoder_states = decoder.init_state_from_encoder(encoder_outputs, src_valid_length_nd) + decoder = TransformerDecoder(num_layers=3, units=units, hidden_size=32, + num_heads=8, max_length=10, dropout=0.0, + output_attention=output_attention, + use_residual=use_residual, + prefix='transformer_decoder_') + decoder.initialize(ctx=ctx) + decoder.hybridize() - # Test multi step forwarding - output, new_states, additional_outputs = decoder.decode_seq(tgt_seq_nd, - decoder_states, - tgt_valid_length_nd) - assert(output.shape == (batch_size, tgt_seq_length, units)) - output_npy = output.asnumpy() - for i in range(batch_size): - tgt_v_len = int(tgt_valid_length_npy[i]) - if tgt_v_len < tgt_seq_length - 1: - assert((output_npy[i, tgt_v_len:, :] == 0).all()) - if output_attention: - assert(len(additional_outputs) == 3) - attention_out = additional_outputs[0][1].asnumpy() - assert(attention_out.shape == (batch_size, 8, tgt_seq_length, src_seq_length)) - for i in range(batch_size): - mem_v_len = int(src_valid_length_npy[i]) - if mem_v_len < src_seq_length - 1: - assert((attention_out[i, :, :, mem_v_len:] == 0).all()) - if mem_v_len > 0: - assert_almost_equal(attention_out[i, :, :, :].sum(axis=-1), - np.ones(attention_out.shape[1:3])) - else: - assert(len(additional_outputs) == 0) + src_seq_length, tgt_seq_length = src_tgt_seq_len + src_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, src_seq_length, units), ctx=ctx) + tgt_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, tgt_seq_length, units), ctx=ctx) + src_valid_length_nd = mx.nd.array(np.random.randint(1, src_seq_length, size=(batch_size,)), ctx=ctx) + tgt_valid_length_nd = mx.nd.array(np.random.randint(1, tgt_seq_length, size=(batch_size,)), ctx=ctx) + src_valid_length_npy = src_valid_length_nd.asnumpy() + tgt_valid_length_npy = tgt_valid_length_nd.asnumpy() + encoder_outputs, _ = encoder(src_seq_nd, valid_length=src_valid_length_nd) + decoder_states = decoder.init_state_from_encoder(encoder_outputs, src_valid_length_nd) + + # Test multi step forwarding + output, new_states, additional_outputs = decoder(tgt_seq_nd, decoder_states, tgt_valid_length_nd) + assert(output.shape == (batch_size, tgt_seq_length, units)) + output_npy = output.asnumpy() + for i in range(batch_size): + tgt_v_len = int(tgt_valid_length_npy[i]) + if tgt_v_len < tgt_seq_length - 1: + assert((output_npy[i, tgt_v_len:, :] == 0).all()) + if output_attention: + assert(len(additional_outputs) == 3) + attention_out = additional_outputs[0][1].asnumpy() + assert(attention_out.shape == (batch_size, 8, tgt_seq_length, src_seq_length)) + for i in range(batch_size): + mem_v_len = int(src_valid_length_npy[i]) + if mem_v_len < src_seq_length - 1: + assert((attention_out[i, :, :, mem_v_len:] == 0).all()) + if mem_v_len > 0: + assert_almost_equal(attention_out[i, :, :, :].sum(axis=-1), + np.ones(attention_out.shape[1:3])) + else: + assert(len(additional_outputs) == 0) + # Test one step forwarding + decoder = TransformerOneStepDecoder(num_layers=3, units=units, 
hidden_size=32, + num_heads=8, max_length=10, dropout=0.0, + output_attention=output_attention, + use_residual=use_residual, + prefix='transformer_decoder_', + params=decoder.collect_params()) + decoder.hybridize() + output, new_states, additional_outputs = decoder(tgt_seq_nd[:, 0, :], decoder_states) + assert(output.shape == (batch_size, units)) + if output_attention: + assert(len(additional_outputs) == 3) + attention_out = additional_outputs[0][1].asnumpy() + assert(attention_out.shape == (batch_size, 8, 1, src_seq_length)) + for i in range(batch_size): + mem_v_len = int(src_valid_length_npy[i]) + if mem_v_len < src_seq_length - 1: + assert((attention_out[i, :, :, mem_v_len:] == 0).all()) + if mem_v_len > 0: + assert_almost_equal(attention_out[i, :, :, :].sum(axis=-1), + np.ones(attention_out.shape[1:3])) + else: + assert(len(additional_outputs) == 0) diff --git a/src/gluonnlp/model/seq2seq_encoder_decoder.py b/src/gluonnlp/model/seq2seq_encoder_decoder.py index 586893a72e..fc2c7a1da9 100644 --- a/src/gluonnlp/model/seq2seq_encoder_decoder.py +++ b/src/gluonnlp/model/seq2seq_encoder_decoder.py @@ -18,11 +18,13 @@ __all__ = ['Seq2SeqEncoder'] from functools import partial + import mxnet as mx from mxnet.gluon import rnn from mxnet.gluon.block import Block -from gluonnlp.model import AttentionCell, MLPAttentionCell, DotProductAttentionCell, \ - MultiHeadAttentionCell + +from .attention_cell import (AttentionCell, DotProductAttentionCell, + MLPAttentionCell, MultiHeadAttentionCell) def _get_cell_type(cell_type): @@ -128,6 +130,7 @@ def _nested_sequence_last(data, valid_length): class Seq2SeqEncoder(Block): r"""Base class of the encoders in sequence to sequence learning models. """ + def __call__(self, inputs, valid_length=None, states=None): #pylint: disable=arguments-differ """Encode the input sequence. @@ -153,11 +156,14 @@ def forward(self, inputs, valid_length=None, states=None): #pylint: disable=arg class Seq2SeqDecoder(Block): - r"""Base class of the decoders in sequence to sequence learning models. + """Base class of the decoders for sequence to sequence learning models. - In the forward function, it generates the one-step-ahead decoding output. + Given the inputs and the context computed by the encoder, generate the new + states. Used in the training phase where we set the inputs to be the target + sequence. """ + def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): r"""Generates the initial decoder states based on the encoder outputs. @@ -172,10 +178,11 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): """ raise NotImplementedError - def decode_seq(self, inputs, states, valid_length=None): - r"""Given the inputs and the context computed by the encoder, - generate the new states. This is usually used in the training phase where we set the inputs - to be the target sequence. + def forward(self, step_input, states, valid_length=None): #pylint: disable=arguments-differ + """Given the inputs and the context computed by the encoder, generate the new states. + + Used in the training phase where we set the inputs to be the target + sequence. Parameters ---------- @@ -193,11 +200,34 @@ def decode_seq(self, inputs, states, valid_length=None): The new states of the decoder additional_outputs : list Additional outputs of the decoder, e.g, the attention weights + + """ + raise NotImplementedError + + +class Seq2SeqOneStepDecoder(Block): + r"""Base class of the decoders in sequence to sequence learning models. 
+ + In the forward function, it generates the one-step-ahead decoding output. + + """ + + def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): + r"""Generates the initial decoder states based on the encoder outputs. + + Parameters + ---------- + encoder_outputs : list of NDArrays + encoder_valid_length : NDArray or None + + Returns + ------- + decoder_states : list """ raise NotImplementedError - def __call__(self, step_input, states): #pylint: disable=arguments-differ - r"""One-step decoding of the input + def forward(self, step_input, states): #pylint: disable=arguments-differ + """One-step decoding of the input Parameters ---------- @@ -213,7 +243,4 @@ def __call__(self, step_input, states): #pylint: disable=arguments-differ step_additional_outputs : list Additional outputs of the step, e.g, the attention weights """ - return super(Seq2SeqDecoder, self).__call__(step_input, states) - - def forward(self, step_input, states): #pylint: disable=arguments-differ raise NotImplementedError diff --git a/src/gluonnlp/model/train/__init__.py b/src/gluonnlp/model/train/__init__.py index 5f25621730..8b66c7e208 100644 --- a/src/gluonnlp/model/train/__init__.py +++ b/src/gluonnlp/model/train/__init__.py @@ -26,8 +26,7 @@ from .embedding import * from .language_model import * -__all__ = language_model.__all__ + cache.__all__ + embedding.__all__ + \ - ['get_cache_model'] +__all__ = language_model.__all__ + cache.__all__ + embedding.__all__ + ['get_cache_model'] def get_cache_model(name, dataset_name='wikitext-2', window=2000, diff --git a/src/gluonnlp/model/transformer.py b/src/gluonnlp/model/transformer.py index 73df2db59d..9209444518 100644 --- a/src/gluonnlp/model/transformer.py +++ b/src/gluonnlp/model/transformer.py @@ -20,22 +20,24 @@ __all__ = ['TransformerEncoder', 'PositionwiseFFN', 'TransformerEncoderCell', 'transformer_en_de_512'] +import math import os -import math import numpy as np import mxnet as mx from mxnet import cpu, gluon from mxnet.gluon import nn from mxnet.gluon.block import HybridBlock from mxnet.gluon.model_zoo import model_store -from gluonnlp.utils.parallel import Parallelizable -from .seq2seq_encoder_decoder import Seq2SeqEncoder, Seq2SeqDecoder, _get_attention_cell + +from ..base import get_home_dir +from ..utils.parallel import Parallelizable from .block import GELU +from .seq2seq_encoder_decoder import (Seq2SeqDecoder, Seq2SeqEncoder, + Seq2SeqOneStepDecoder, + _get_attention_cell) from .translation import NMTModel -from .utils import _load_vocab, _load_pretrained_params -from ..base import get_home_dir - +from .utils import _load_pretrained_params, _load_vocab ############################################################################### # BASE ENCODER BLOCKS # @@ -758,7 +760,9 @@ class TransformerDecoderCell(HybridBlock): Whether to scale the softmax input by the sqrt of the input dimension in multi-head attention dropout : float + Dropout probability. use_residual : bool + Whether to use residual connection. output_attention: bool Whether to output the attention weights weight_initializer : str or Initializer @@ -866,50 +870,12 @@ def hybrid_forward(self, F, inputs, mem_value, mask=None, mem_mask=None): #pyli return outputs, additional_outputs -class TransformerDecoder(HybridBlock, Seq2SeqDecoder): - """Structure of the Transformer Decoder. - - Parameters - ---------- - attention_cell : AttentionCell or str, default 'multi_head' - Arguments of the attention cell. 
- Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp' - num_layers : int - units : int - hidden_size : int - number of units in the hidden layer of position-wise feed-forward networks - max_length : int - Maximum length of the input sequence. This is used for constructing position encoding - num_heads : int - Number of heads in multi-head attention - scaled : bool - Whether to scale the softmax input by the sqrt of the input dimension - in multi-head attention - dropout : float - use_residual : bool - output_attention: bool - Whether to output the attention weights - weight_initializer : str or Initializer - Initializer for the input weights matrix, used for the linear - transformation of the inputs. - bias_initializer : str or Initializer - Initializer for the bias vector. - scale_embed : bool, default True - Scale the input embeddings by sqrt(embed_size). - prefix : str, default 'rnn_' - Prefix for name of `Block`s - (and name of weight if params is `None`). - params : Parameter or None - Container for weight sharing between cells. - Created if `None`. - """ - def __init__(self, attention_cell='multi_head', num_layers=2, - units=128, hidden_size=2048, max_length=50, - num_heads=4, scaled=True, dropout=0.0, - use_residual=True, output_attention=False, - weight_initializer=None, bias_initializer='zeros', +class _BaseTransformerDecoder(HybridBlock): + def __init__(self, attention_cell='multi_head', num_layers=2, units=128, hidden_size=2048, + max_length=50, num_heads=4, scaled=True, dropout=0.0, use_residual=True, + output_attention=False, weight_initializer=None, bias_initializer='zeros', scale_embed=True, prefix=None, params=None): - super(TransformerDecoder, self).__init__(prefix=prefix, params=params) + super().__init__(prefix=prefix, params=params) assert units % num_heads == 0, 'In TransformerDecoder, the units should be divided ' \ 'exactly by the number of heads. Received units={}, ' \ 'num_heads={}'.format(units, num_heads) @@ -928,22 +894,17 @@ def __init__(self, attention_cell='multi_head', num_layers=2, self.dropout_layer = nn.Dropout(rate=dropout) self.layer_norm = nn.LayerNorm() encoding = _position_encoding_init(max_length, units) - self.position_weight = self.params.get_constant('const', encoding) + self.position_weight = self.params.get_constant('const', encoding.astype(np.float32)) self.transformer_cells = nn.HybridSequential() for i in range(num_layers): self.transformer_cells.add( - TransformerDecoderCell( - units=units, - hidden_size=hidden_size, - num_heads=num_heads, - attention_cell=attention_cell, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - dropout=dropout, - scaled=scaled, - use_residual=use_residual, - output_attention=output_attention, - prefix='transformer%d_' % i)) + TransformerDecoderCell(units=units, hidden_size=hidden_size, + num_heads=num_heads, attention_cell=attention_cell, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, dropout=dropout, + scaled=scaled, use_residual=use_residual, + output_attention=output_attention, + prefix='transformer%d_' % i)) def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): """Initialize the state from the encoder outputs. 
@@ -959,7 +920,7 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): The decoder states, includes: - mem_value : NDArray - - mem_masks : NDArray, optional + - mem_masks : NDArray or None """ mem_value = encoder_outputs decoder_states = [mem_value] @@ -971,10 +932,12 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): mx.nd.arange(mem_length, ctx=ctx, dtype=dtype).reshape((1, -1)), encoder_valid_length.reshape((-1, 1))) decoder_states.append(mem_masks) - self._encoder_valid_length = encoder_valid_length + else: + decoder_states.append(None) return decoder_states - def decode_seq(self, inputs, states, valid_length=None): + def hybrid_forward(self, F, inputs, states, valid_length=None, position_weight=None): + #pylint: disable=arguments-differ """Decode the decoder inputs. This function is only used for training. Parameters @@ -990,58 +953,180 @@ def decode_seq(self, inputs, states, valid_length=None): ------- output : NDArray, Shape (batch_size, length, C_out) states : list - The decoder states, includes: - + The decoder states: - mem_value : NDArray - - mem_masks : NDArray, optional + - mem_masks : NDArray or None additional_outputs : list of list Either be an empty list or contains the attention weights in this step. The attention weights will have shape (batch_size, length, mem_length) or (batch_size, num_heads, length, mem_length) """ - batch_size = inputs.shape[0] - length = inputs.shape[1] - length_array = mx.nd.arange(length, ctx=inputs.context, dtype=inputs.dtype) - mask = mx.nd.broadcast_lesser_equal( - length_array.reshape((1, -1)), - length_array.reshape((-1, 1))) + + length_array = F.contrib.arange_like(inputs, axis=1) + mask = F.broadcast_lesser_equal(length_array.reshape((1, -1)), + length_array.reshape((-1, 1))) if valid_length is not None: - arange = mx.nd.arange(length, ctx=valid_length.context, dtype=valid_length.dtype) - batch_mask = mx.nd.broadcast_lesser( - arange.reshape((1, -1)), - valid_length.reshape((-1, 1))) - mask = mx.nd.broadcast_mul(mx.nd.expand_dims(batch_mask, -1), - mx.nd.expand_dims(mask, 0)) + batch_mask = F.broadcast_lesser(length_array.reshape((1, -1)), + valid_length.reshape((-1, 1))) + batch_mask = F.expand_dims(batch_mask, -1) + mask = F.broadcast_mul(batch_mask, F.expand_dims(mask, 0)) else: - mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0, size=batch_size) - states = [None] + states - output, states, additional_outputs = self.forward(inputs, states, mask) - states = states[1:] + mask = F.expand_dims(mask, axis=0) + mask = F.broadcast_like(mask, inputs, lhs_axes=(0, ), rhs_axes=(0, )) + + mem_value, mem_mask = states + if mem_mask is not None: + mem_mask = F.expand_dims(mem_mask, axis=1) + mem_mask = F.broadcast_like(mem_mask, inputs, lhs_axes=(1, ), rhs_axes=(1, )) + + if self._scale_embed: + # XXX: input.shape[-1] and self._units are expected to be the same + inputs = inputs * math.sqrt(self._units) + + # Positional Encoding + steps = F.contrib.arange_like(inputs, axis=1) + positional_embed = F.Embedding(steps, position_weight, self._max_length, self._units) + inputs = F.broadcast_add(inputs, F.expand_dims(positional_embed, axis=0)) + + if self._dropout: + inputs = self.dropout_layer(inputs) + inputs = self.layer_norm(inputs) + additional_outputs = [] + attention_weights_l = [] + outputs = inputs + for cell in self.transformer_cells: + outputs, attention_weights = cell(outputs, mem_value, mask, mem_mask) + if self._output_attention: + 
attention_weights_l.append(attention_weights) + if self._output_attention: + additional_outputs.extend(attention_weights_l) + if valid_length is not None: - output = mx.nd.SequenceMask(output, - sequence_length=valid_length, - use_sequence_length=True, - axis=1) - return output, states, additional_outputs + outputs = F.SequenceMask(outputs, sequence_length=valid_length, + use_sequence_length=True, axis=1) + return outputs, states, additional_outputs + + +class TransformerDecoder(_BaseTransformerDecoder, Seq2SeqDecoder): + """Transformer Decoder. + + Multi-step ahead decoder for use during training with teacher forcing. + + Parameters + ---------- + attention_cell : AttentionCell or str, default 'multi_head' + Arguments of the attention cell. + Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp' + num_layers : int + Number of attention layers. + units : int + Number of units for the output. + hidden_size : int + number of units in the hidden layer of position-wise feed-forward networks + max_length : int + Maximum length of the input sequence. This is used for constructing position encoding + num_heads : int + Number of heads in multi-head attention + scaled : bool + Whether to scale the softmax input by the sqrt of the input dimension + in multi-head attention + dropout : float + Dropout probability. + use_residual : bool + Whether to use residual connection. + output_attention: bool + Whether to output the attention weights + weight_initializer : str or Initializer + Initializer for the input weights matrix, used for the linear + transformation of the inputs. + bias_initializer : str or Initializer + Initializer for the bias vector. + scale_embed : bool, default True + Scale the input embeddings by sqrt(embed_size). + prefix : str, default 'rnn_' + Prefix for name of `Block`s + (and name of weight if params is `None`). + params : Parameter or None + Container for weight sharing between cells. + Created if `None`. + """ + + +class TransformerOneStepDecoder(_BaseTransformerDecoder, Seq2SeqOneStepDecoder): + """Transformer Decoder. - def __call__(self, step_input, states): #pylint: disable=arguments-differ + One-step ahead decoder for use during inference. + + Parameters + ---------- + attention_cell : AttentionCell or str, default 'multi_head' + Arguments of the attention cell. + Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp' + num_layers : int + Number of attention layers. + units : int + Number of units for the output. + hidden_size : int + number of units in the hidden layer of position-wise feed-forward networks + max_length : int + Maximum length of the input sequence. This is used for constructing position encoding + num_heads : int + Number of heads in multi-head attention + scaled : bool + Whether to scale the softmax input by the sqrt of the input dimension + in multi-head attention + dropout : float + Dropout probability. + use_residual : bool + Whether to use residual connection. + output_attention: bool + Whether to output the attention weights + weight_initializer : str or Initializer + Initializer for the input weights matrix, used for the linear + transformation of the inputs. + bias_initializer : str or Initializer + Initializer for the bias vector. + scale_embed : bool, default True + Scale the input embeddings by sqrt(embed_size). + prefix : str, default 'rnn_' + Prefix for name of `Block`s + (and name of weight if params is `None`). 
+ params : Parameter or None + Container for weight sharing between cells. + Created if `None`. + """ + + def forward(self, step_input, states): # pylint: disable=arguments-differ + # We implement forward, as the number of states changes between the + # first and later calls of the one-step ahead Transformer decoder. This + # is due to the lack of numpy shape semantics. Once we enable numpy + # shape semantic in the GluonNLP code-base, the number of states should + # stay constant, but the first state element will be an array of shape + # (batch_size, 0, C_in) at the first call. + if len(states) == 3: # step_input from prior call is included + last_embeds, _, _ = states + inputs = mx.nd.concat(last_embeds, mx.nd.expand_dims(step_input, axis=1), dim=1) + states = states[1:] + else: + inputs = mx.nd.expand_dims(step_input, axis=1) + return super().forward(inputs, states) + + def hybrid_forward(self, F, inputs, states, position_weight): + # pylint: disable=arguments-differ """One-step-ahead decoding of the Transformer decoder. Parameters ---------- - step_input : NDArray + step_input : NDArray, Shape (batch_size, C_in) states : list of NDArray Returns ------- step_output : NDArray - The output of the decoder. - In the train mode, Shape is (batch_size, length, C_out) - In the test mode, Shape is (batch_size, C_out) + The output of the decoder. Shape is (batch_size, C_out) new_states: list Includes - last_embeds : NDArray or None - It is only given during testing - mem_value : NDArray - mem_masks : NDArray, optional @@ -1050,107 +1135,15 @@ def __call__(self, step_input, states): #pylint: disable=arguments-differ The attention weights will have shape (batch_size, length, mem_length) or (batch_size, num_heads, length, mem_length) """ - return super(TransformerDecoder, self).__call__(step_input, states) - - def forward(self, step_input, states, mask=None): #pylint: disable=arguments-differ, missing-docstring - input_shape = step_input.shape - mem_mask = None - # If it is in testing, transform input tensor to a tensor with shape NTC - # Otherwise remove the None in states. 
- if len(input_shape) == 2: - if self._encoder_valid_length is not None: - has_last_embeds = len(states) == 3 - else: - has_last_embeds = len(states) == 2 - if has_last_embeds: - last_embeds = states[0] - step_input = mx.nd.concat(last_embeds, - mx.nd.expand_dims(step_input, axis=1), - dim=1) - states = states[1:] - else: - step_input = mx.nd.expand_dims(step_input, axis=1) - elif states[0] is None: - states = states[1:] - has_mem_mask = (len(states) == 2) - if has_mem_mask: - _, mem_mask = states - augmented_mem_mask = mx.nd.expand_dims(mem_mask, axis=1)\ - .broadcast_axes(axis=1, size=step_input.shape[1]) - states[-1] = augmented_mem_mask - if mask is None: - length_array = mx.nd.arange(step_input.shape[1], ctx=step_input.context, - dtype=step_input.dtype) - mask = mx.nd.broadcast_lesser_equal( - length_array.reshape((1, -1)), - length_array.reshape((-1, 1))) - mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), - axis=0, size=step_input.shape[0]) - steps = mx.nd.arange(step_input.shape[1], ctx=step_input.context) - states.append(steps) - if self._scale_embed: - scaled_step_input = step_input * math.sqrt(step_input.shape[-1]) - # pylint: disable=too-many-function-args - step_output, step_additional_outputs = \ - super(TransformerDecoder, self).forward(scaled_step_input, states, mask) - states = states[:-1] - if has_mem_mask: - states[-1] = mem_mask - new_states = [step_input] + states - # If it is in testing, only output the last one - if len(input_shape) == 2: - step_output = step_output[:, -1, :] - return step_output, new_states, step_additional_outputs - - def hybrid_forward(self, F, step_input, states, mask=None, position_weight=None): - #pylint: disable=arguments-differ - """ + outputs, states, additional_outputs = super().hybrid_forward( + F, inputs, states, valid_length=None, position_weight=position_weight) - Parameters - ---------- - step_input : NDArray or Symbol, Shape (batch_size, length, C_in) - states : list of NDArray or Symbol - mask : NDArray or Symbol - position_weight : NDArray or Symbol + # Append inputs to states: They are needed in the next one-step ahead decoding step + new_states = [inputs] + states + # Only return one-step ahead + step_output = F.slice_axis(outputs, axis=1, begin=-1, end=None).reshape((0, -1)) - Returns - ------- - step_output : NDArray or Symbol - The output of the decoder. Shape is (batch_size, length, C_out) - step_additional_outputs : list - Either be an empty list or contains the attention weights in this step. 
- The attention weights will have shape (batch_size, length, mem_length) or - (batch_size, num_heads, length, mem_length) - - """ - has_mem_mask = (len(states) == 3) - if has_mem_mask: - mem_value, mem_mask, steps = states - else: - mem_value, steps = states - mem_mask = None - # Positional Encoding - step_input = F.broadcast_add(step_input, - F.expand_dims(F.Embedding(steps, - position_weight, - self._max_length, - self._units), - axis=0)) - if self._dropout: - step_input = self.dropout_layer(step_input) - step_input = self.layer_norm(step_input) - inputs = step_input - outputs = inputs - step_additional_outputs = [] - attention_weights_l = [] - for cell in self.transformer_cells: - outputs, attention_weights = cell(inputs, mem_value, mask, mem_mask) - if self._output_attention: - attention_weights_l.append(attention_weights) - inputs = outputs - if self._output_attention: - step_additional_outputs.extend(attention_weights_l) - return outputs, step_additional_outputs + return step_output, new_states, additional_outputs @@ -1193,40 +1186,35 @@ def get_transformer_encoder_decoder(num_layers=2, Returns ------- encoder : TransformerEncoder - decoder :TransformerDecoder + decoder : TransformerDecoder + one_step_ahead_decoder : TransformerOneStepDecoder """ - encoder = TransformerEncoder(num_layers=num_layers, - num_heads=num_heads, - max_length=max_src_length, - units=units, - hidden_size=hidden_size, - dropout=dropout, - scaled=scaled, - use_residual=use_residual, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - prefix=prefix + 'enc_', params=params) - decoder = TransformerDecoder(num_layers=num_layers, - num_heads=num_heads, - max_length=max_tgt_length, - units=units, - hidden_size=hidden_size, - dropout=dropout, - scaled=scaled, - use_residual=use_residual, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - prefix=prefix + 'dec_', params=params) - return encoder, decoder - - -def _get_transformer_model(model_cls, model_name, dataset_name, src_vocab, tgt_vocab, - encoder, decoder, share_embed, embed_size, tie_weights, + encoder = TransformerEncoder( + num_layers=num_layers, num_heads=num_heads, max_length=max_src_length, units=units, + hidden_size=hidden_size, dropout=dropout, scaled=scaled, use_residual=use_residual, + weight_initializer=weight_initializer, bias_initializer=bias_initializer, + prefix=prefix + 'enc_', params=params) + decoder = TransformerDecoder( + num_layers=num_layers, num_heads=num_heads, max_length=max_tgt_length, units=units, + hidden_size=hidden_size, dropout=dropout, scaled=scaled, use_residual=use_residual, + weight_initializer=weight_initializer, bias_initializer=bias_initializer, + prefix=prefix + 'dec_', params=params) + one_step_ahead_decoder = TransformerOneStepDecoder( + num_layers=num_layers, num_heads=num_heads, max_length=max_tgt_length, units=units, + hidden_size=hidden_size, dropout=dropout, scaled=scaled, use_residual=use_residual, + weight_initializer=weight_initializer, bias_initializer=bias_initializer, + prefix=prefix + 'dec_', params=decoder.collect_params()) + return encoder, decoder, one_step_ahead_decoder + + +def _get_transformer_model(model_cls, model_name, dataset_name, src_vocab, tgt_vocab, encoder, + decoder, one_step_ahead_decoder, share_embed, embed_size, tie_weights, embed_initializer, pretrained, ctx, root, **kwargs): src_vocab = _load_vocab(dataset_name + '_src', src_vocab, root) tgt_vocab = _load_vocab(dataset_name + '_tgt', tgt_vocab, root) kwargs['encoder'] = encoder 
kwargs['decoder'] = decoder + kwargs['one_step_ahead_decoder'] = one_step_ahead_decoder kwargs['src_vocab'] = src_vocab kwargs['tgt_vocab'] = tgt_vocab kwargs['share_embed'] = share_embed @@ -1279,16 +1267,13 @@ def transformer_en_de_512(dataset_name=None, src_vocab=None, tgt_vocab=None, pre assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \ 'Cannot override predefined model settings.' predefined_args.update(kwargs) - encoder, decoder = get_transformer_encoder_decoder(units=predefined_args['num_units'], - hidden_size=predefined_args['hidden_size'], - dropout=predefined_args['dropout'], - num_layers=predefined_args['num_layers'], - num_heads=predefined_args['num_heads'], - max_src_length=530, - max_tgt_length=549, - scaled=predefined_args['scaled']) - return _get_transformer_model(NMTModel, 'transformer_en_de_512', dataset_name, - src_vocab, tgt_vocab, encoder, decoder, + encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder( + units=predefined_args['num_units'], hidden_size=predefined_args['hidden_size'], + dropout=predefined_args['dropout'], num_layers=predefined_args['num_layers'], + num_heads=predefined_args['num_heads'], max_src_length=530, max_tgt_length=549, + scaled=predefined_args['scaled']) + return _get_transformer_model(NMTModel, 'transformer_en_de_512', dataset_name, src_vocab, + tgt_vocab, encoder, decoder, one_step_ahead_decoder, predefined_args['share_embed'], predefined_args['embed_size'], predefined_args['tie_weights'], predefined_args['embed_initializer'], pretrained, ctx, root) diff --git a/src/gluonnlp/model/translation.py b/src/gluonnlp/model/translation.py index 940e43bdec..f0ac7754e9 100644 --- a/src/gluonnlp/model/translation.py +++ b/src/gluonnlp/model/translation.py @@ -37,7 +37,11 @@ class NMTModel(Block): encoder : Seq2SeqEncoder Encoder that encodes the input sentence. decoder : Seq2SeqDecoder - Decoder that generates the predictions based on the output of the encoder. + Decoder used during training phase. The decoder generates predictions + based on the output of the encoder. + one_step_ahead_decoder : Seq2SeqOneStepDecoder + One-step ahead decoder used during inference phase. The decoder + generates predictions based on the output of the encoder. embed_size : int or None, default None Size of the embedding vectors. It is used to generate the source and target embeddings if src_embed and tgt_embed are None. @@ -63,7 +67,7 @@ class NMTModel(Block): params : ParameterDict or None See document of `Block`. 
""" - def __init__(self, src_vocab, tgt_vocab, encoder, decoder, + def __init__(self, src_vocab, tgt_vocab, encoder, decoder, one_step_ahead_decoder, embed_size=None, embed_dropout=0.0, embed_initializer=mx.init.Uniform(0.1), src_embed=None, tgt_embed=None, share_embed=False, tie_weights=False, tgt_proj=None, prefix=None, params=None): @@ -72,6 +76,7 @@ def __init__(self, src_vocab, tgt_vocab, encoder, decoder, self.src_vocab = src_vocab self.encoder = encoder self.decoder = decoder + self.one_step_ahead_decoder = one_step_ahead_decoder self._shared_embed = share_embed if embed_dropout is None: embed_dropout = 0.0 @@ -158,10 +163,8 @@ def decode_seq(self, inputs, states, valid_length=None): additional_outputs : list Additional outputs of the decoder, e.g, the attention weights """ - outputs, states, additional_outputs =\ - self.decoder.decode_seq(inputs=self.tgt_embed(inputs), - states=states, - valid_length=valid_length) + outputs, states, additional_outputs = self.decoder(self.tgt_embed(inputs), states, + valid_length) outputs = self.tgt_proj(outputs) return outputs, states, additional_outputs @@ -183,7 +186,7 @@ def decode_step(self, step_input, states): Additional outputs of the step, e.g, the attention weights """ step_output, states, step_additional_outputs =\ - self.decoder(self.tgt_embed(step_input), states) + self.one_step_ahead_decoder(self.tgt_embed(step_input), states) step_output = self.tgt_proj(step_output) return step_output, states, step_additional_outputs diff --git a/tests/unittest/test_models.py b/tests/unittest/test_models.py index 26ad55dbfa..e446199db4 100644 --- a/tests/unittest/test_models.py +++ b/tests/unittest/test_models.py @@ -73,30 +73,27 @@ def test_big_text_models(wikitext2_val_and_counter): @pytest.mark.serial @pytest.mark.remote_required -def test_transformer_models(): - models = ['transformer_en_de_512'] - pretrained_to_test = {'transformer_en_de_512': 'WMT2014'} - dropout_rates = [0.1, 0.0] +@pytest.mark.parametrize('dropout_rate', [0.1, 0.0]) +@pytest.mark.parametrize('model_dataset', [('transformer_en_de_512', 'WMT2014')]) +def test_transformer_models(dropout_rate, model_dataset): + model_name, pretrained_dataset = model_dataset src = mx.nd.ones((2, 10)) tgt = mx.nd.ones((2, 8)) valid_len = mx.nd.ones((2,)) - for model_name in models: - for rate in dropout_rates: - eprint('testing forward for %s, dropout rate %f' % (model_name, rate)) - pretrained_dataset = pretrained_to_test.get(model_name) - with warnings.catch_warnings(): # TODO https://github.com/dmlc/gluon-nlp/issues/978 - warnings.simplefilter("ignore") - model, _, _ = nlp.model.get_model(model_name, dataset_name=pretrained_dataset, - pretrained=pretrained_dataset is not None, - dropout=rate) - - print(model) - if not pretrained_dataset: - model.initialize() - output, state = model(src, tgt, src_valid_length=valid_len, tgt_valid_length=valid_len) - output.wait_to_read() - del model - mx.nd.waitall() + eprint('testing forward for %s, dropout rate %f' % (model_name, dropout_rate)) + with warnings.catch_warnings(): # TODO https://github.com/dmlc/gluon-nlp/issues/978 + warnings.simplefilter("ignore") + model, _, _ = nlp.model.get_model(model_name, dataset_name=pretrained_dataset, + pretrained=pretrained_dataset is not None, + dropout=dropout_rate) + + print(model) + if not pretrained_dataset: + model.initialize() + output, state = model(src, tgt, src_valid_length=valid_len, tgt_valid_length=valid_len) + output.wait_to_read() + del model + mx.nd.waitall() @pytest.mark.serial