From 94e7d732c5671429c4e6771b96107e502ac0886d Mon Sep 17 00:00:00 2001
From: Leonard Lausen
Date: Fri, 18 Oct 2019 17:16:43 +0000
Subject: [PATCH 1/3] Split up Seq2SeqDecoder into Seq2SeqDecoder and Seq2SeqOneStepDecoder

In the current Gluon API, each HybridBlock has to serve one purpose and can
only define a single callable interface. The previous Seq2SeqDecoder interface
required each Seq2SeqDecoder Block to perform two functionalities (multi-step
ahead and single-step ahead decoding). This means that neither of the two
functionalities can, in practice, be hybridized completely. Thus use two
separate Blocks for the two functionalities; they may share parameters. Update
the NMTModel API accordingly.

Further refactor TransformerDecoder to make it completely hybridizable.
TransformerOneStepDecoder still relies on a small hack but can be hybridized
completely when we enable numpy shape semantics.
---
 docs/examples/machine_translation/gnmt.md     |  10 +-
 scripts/machine_translation/gnmt.py           | 265 +++++++----
 .../inference_transformer.py                  |  17 +-
 scripts/machine_translation/train_gnmt.py     |  14 +-
 .../machine_translation/train_transformer.py  |  31 +-
 scripts/tests/test_encoder_decoder.py         |  88 ++--
 src/gluonnlp/model/seq2seq_encoder_decoder.py |  44 +-
 src/gluonnlp/model/train/__init__.py          |   3 +-
 src/gluonnlp/model/transformer.py             | 449 +++++++++---------
 src/gluonnlp/model/translation.py             |  13 +-
 tests/unittest/test_models.py                 |  39 +-
 11 files changed, 515 insertions(+), 458 deletions(-)

diff --git a/docs/examples/machine_translation/gnmt.md b/docs/examples/machine_translation/gnmt.md index f195317b5a..9c8e45a324 100644 --- a/docs/examples/machine_translation/gnmt.md +++ b/docs/examples/machine_translation/gnmt.md @@ -337,12 +337,12 @@ feed the encoder and decoder to the `NMTModel` to construct the GNMT model. `model.hybridize` allows computation to be done using the symbolic backend. To understand what it means to be "hybridized," please refer to [this](https://mxnet.incubator.apache.org/versions/master/tutorials/gluon/hybrid.html) page on MXNet hybridization and its advantages.
```{.python .input} -encoder, decoder = nmt.gnmt.get_gnmt_encoder_decoder(hidden_size=num_hidden, - dropout=dropout, - num_layers=num_layers, - num_bi_layers=num_bi_layers) +encoder, decoder, one_step_ahead_decoder = nmt.gnmt.get_gnmt_encoder_decoder( + hidden_size=num_hidden, dropout=dropout, num_layers=num_layers, + num_bi_layers=num_bi_layers) model = nlp.model.translation.NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, - decoder=decoder, embed_size=num_hidden, prefix='gnmt_') + decoder=decoder, one_step_ahead_decoder=one_step_ahead_decoder, + embed_size=num_hidden, prefix='gnmt_') model.initialize(init=mx.init.Uniform(0.1), ctx=ctx) static_alloc = True model.hybridize(static_alloc=static_alloc) diff --git a/scripts/machine_translation/gnmt.py b/scripts/machine_translation/gnmt.py index 6910ec3ac3..8c3240663e 100644 --- a/scripts/machine_translation/gnmt.py +++ b/scripts/machine_translation/gnmt.py @@ -22,7 +22,7 @@ from mxnet.gluon import nn, rnn from mxnet.gluon.block import HybridBlock from gluonnlp.model.seq2seq_encoder_decoder import Seq2SeqEncoder, Seq2SeqDecoder, \ - _get_attention_cell, _get_cell_type, _nested_sequence_last + Seq2SeqOneStepDecoder, _get_attention_cell, _get_cell_type, _nested_sequence_last class GNMTEncoder(Seq2SeqEncoder): @@ -158,48 +158,14 @@ def forward(self, inputs, states=None, valid_length=None): #pylint: disable=arg return [outputs, new_states], [] -class GNMTDecoder(HybridBlock, Seq2SeqDecoder): - """Structure of the RNN Encoder similar to that used in the - Google Neural Machine Translation paper. - - We use gnmt_v2 strategy in tensorflow/nmt - - Parameters - ---------- - cell_type : str or type - attention_cell : AttentionCell or str - Arguments of the attention cell. - Can be 'scaled_luong', 'normed_mlp', 'dot' - num_layers : int - hidden_size : int - dropout : float - use_residual : bool - output_attention: bool - Whether to output the attention weights - i2h_weight_initializer : str or Initializer - Initializer for the input weights matrix, used for the linear - transformation of the inputs. - h2h_weight_initializer : str or Initializer - Initializer for the recurrent weights matrix, used for the linear - transformation of the recurrent state. - i2h_bias_initializer : str or Initializer - Initializer for the bias vector. - h2h_bias_initializer : str or Initializer - Initializer for the bias vector. - prefix : str, default 'rnn_' - Prefix for name of `Block`s - (and name of weight if params is `None`). - params : Parameter or None - Container for weight sharing between cells. - Created if `None`. - """ +class _BaseGNMTDecoder(HybridBlock): def __init__(self, cell_type='lstm', attention_cell='scaled_luong', num_layers=2, hidden_size=128, dropout=0.0, use_residual=True, output_attention=False, i2h_weight_initializer=None, h2h_weight_initializer=None, i2h_bias_initializer='zeros', h2h_bias_initializer='zeros', prefix=None, params=None): - super(GNMTDecoder, self).__init__(prefix=prefix, params=params) + super().__init__(prefix=prefix, params=params) self._cell_type = _get_cell_type(cell_type) self._num_layers = num_layers self._hidden_size = hidden_size @@ -249,59 +215,7 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): decoder_states.append(mem_masks) return decoder_states - def decode_seq(self, inputs, states, valid_length=None): - """Decode the decoder inputs. This function is only used for training. 
- - Parameters - ---------- - inputs : NDArray, Shape (batch_size, length, C_in) - states : list of NDArrays or None - Initial states. The list of initial decoder states - valid_length : NDArray or None - Valid lengths of each sequence. This is usually used when part of sequence has - been padded. Shape (batch_size,) - - Returns - ------- - output : NDArray, Shape (batch_size, length, C_out) - states : list - The decoder states, includes: - - - rnn_states : NDArray - - attention_vec : NDArray - - mem_value : NDArray - - mem_masks : NDArray, optional - additional_outputs : list - Either be an empty list or contains the attention weights in this step. - The attention weights will have shape (batch_size, length, mem_length) or - (batch_size, num_heads, length, mem_length) - """ - length = inputs.shape[1] - output = [] - additional_outputs = [] - inputs = _as_list(mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True)) - rnn_states_l = [] - attention_output_l = [] - fixed_states = states[2:] - for i in range(length): - ele_output, states, ele_additional_outputs = self.forward(inputs[i], states) - rnn_states_l.append(states[0]) - attention_output_l.append(states[1]) - output.append(ele_output) - additional_outputs.extend(ele_additional_outputs) - output = mx.nd.stack(*output, axis=1) - if valid_length is not None: - states = [_nested_sequence_last(rnn_states_l, valid_length), - _nested_sequence_last(attention_output_l, valid_length)] + fixed_states - output = mx.nd.SequenceMask(output, - sequence_length=valid_length, - use_sequence_length=True, - axis=1) - if self._output_attention: - additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)] - return output, states, additional_outputs - - def __call__(self, step_input, states): #pylint: disable=arguments-differ + def forward(self, step_input, states): # pylint: disable=arguments-differ """One-step-ahead decoding of the GNMT decoder. Parameters @@ -326,11 +240,7 @@ def __call__(self, step_input, states): #pylint: disable=arguments-differ The attention weights will have shape (batch_size, 1, mem_length) or (batch_size, num_heads, 1, mem_length) """ - return super(GNMTDecoder, self).__call__(step_input, states) - - def forward(self, step_input, states): #pylint: disable=arguments-differ, missing-docstring - step_output, new_states, step_additional_outputs =\ - super(GNMTDecoder, self).forward(step_input, states) + step_output, new_states, step_additional_outputs = super().forward(step_input, states) # In hybrid_forward, only the rnn_states and attention_vec are calculated. # We directly append the mem_value and mem_masks in the forward() function. # We apply this trick because the memory value/mask can be directly appended to the next @@ -402,6 +312,148 @@ def hybrid_forward(self, F, step_input, states): #pylint: disable=arguments-dif return rnn_out, new_states, step_additional_outputs +class GNMTOneStepDecoder(_BaseGNMTDecoder, Seq2SeqOneStepDecoder): + """RNN Encoder similar to that used in the Google Neural Machine Translation paper. + + One-step ahead decoder used during inference. + + We use gnmt_v2 strategy in tensorflow/nmt + + Parameters + ---------- + cell_type : str or type + Can be "lstm", "gru" or constructor functions that can be directly called, + like rnn.LSTMCell + attention_cell : AttentionCell or str + Arguments of the attention cell. 
+ Can be 'scaled_luong', 'normed_mlp', 'dot' + num_layers : int + Total number of layers + hidden_size : int + Number of hidden units + dropout : float + The dropout rate + use_residual : bool + Whether to use residual connection. Residual connection will be added in the + uni-directional RNN layers + output_attention: bool + Whether to output the attention weights + i2h_weight_initializer : str or Initializer + Initializer for the input weights matrix, used for the linear + transformation of the inputs. + h2h_weight_initializer : str or Initializer + Initializer for the recurrent weights matrix, used for the linear + transformation of the recurrent state. + i2h_bias_initializer : str or Initializer + Initializer for the bias vector. + h2h_bias_initializer : str or Initializer + Initializer for the bias vector. + prefix : str, default 'rnn_' + Prefix for name of `Block`s + (and name of weight if params is `None`). + params : Parameter or None + Container for weight sharing between cells. + Created if `None`. + """ + + +class GNMTDecoder(_BaseGNMTDecoder, Seq2SeqDecoder): + """RNN Encoder similar to that used in the Google Neural Machine Translation paper. + + Multi-step decoder used during training with teacher forcing. + + We use gnmt_v2 strategy in tensorflow/nmt + + Parameters + ---------- + cell_type : str or type + Can be "lstm", "gru" or constructor functions that can be directly called, + like rnn.LSTMCell + attention_cell : AttentionCell or str + Arguments of the attention cell. + Can be 'scaled_luong', 'normed_mlp', 'dot' + num_layers : int + Total number of layers + hidden_size : int + Number of hidden units + dropout : float + The dropout rate + use_residual : bool + Whether to use residual connection. Residual connection will be added in the + uni-directional RNN layers + output_attention: bool + Whether to output the attention weights + i2h_weight_initializer : str or Initializer + Initializer for the input weights matrix, used for the linear + transformation of the inputs. + h2h_weight_initializer : str or Initializer + Initializer for the recurrent weights matrix, used for the linear + transformation of the recurrent state. + i2h_bias_initializer : str or Initializer + Initializer for the bias vector. + h2h_bias_initializer : str or Initializer + Initializer for the bias vector. + prefix : str, default 'rnn_' + Prefix for name of `Block`s + (and name of weight if params is `None`). + params : Parameter or None + Container for weight sharing between cells. + Created if `None`. + """ + + def forward(self, inputs, states, valid_length=None): # pylint: disable=arguments-differ + """Decode the decoder inputs. This function is only used for training. + + Parameters + ---------- + inputs : NDArray, Shape (batch_size, length, C_in) + states : list of NDArrays or None + Initial states. The list of initial decoder states + valid_length : NDArray or None + Valid lengths of each sequence. This is usually used when part of sequence has + been padded. Shape (batch_size,) + + Returns + ------- + output : NDArray, Shape (batch_size, length, C_out) + states : list + The decoder states, includes: + + - rnn_states : NDArray + - attention_vec : NDArray + - mem_value : NDArray + - mem_masks : NDArray, optional + additional_outputs : list + Either be an empty list or contains the attention weights in this step. 
+ The attention weights will have shape (batch_size, length, mem_length) or + (batch_size, num_heads, length, mem_length) + """ + length = inputs.shape[1] + output = [] + additional_outputs = [] + inputs = _as_list(mx.nd.split(inputs, num_outputs=length, axis=1, squeeze_axis=True)) + rnn_states_l = [] + attention_output_l = [] + fixed_states = states[2:] + for i in range(length): + ele_output, states, ele_additional_outputs = super().forward(inputs[i], states) + rnn_states_l.append(states[0]) + attention_output_l.append(states[1]) + output.append(ele_output) + additional_outputs.extend(ele_additional_outputs) + output = mx.nd.stack(*output, axis=1) + if valid_length is not None: + states = [_nested_sequence_last(rnn_states_l, valid_length), + _nested_sequence_last(attention_output_l, valid_length)] + fixed_states + output = mx.nd.SequenceMask(output, + sequence_length=valid_length, + use_sequence_length=True, + axis=1) + if self._output_attention: + additional_outputs = [mx.nd.concat(*additional_outputs, dim=-2)] + return output, states, additional_outputs + + def get_gnmt_encoder_decoder(cell_type='lstm', attention_cell='scaled_luong', num_layers=2, num_bi_layers=1, hidden_size=128, dropout=0.0, use_residual=False, i2h_weight_initializer=None, h2h_weight_initializer=None, @@ -435,19 +487,24 @@ def get_gnmt_encoder_decoder(cell_type='lstm', attention_cell='scaled_luong', nu decoder : GNMTDecoder """ encoder = GNMTEncoder(cell_type=cell_type, num_layers=num_layers, num_bi_layers=num_bi_layers, - hidden_size=hidden_size, dropout=dropout, - use_residual=use_residual, + hidden_size=hidden_size, dropout=dropout, use_residual=use_residual, i2h_weight_initializer=i2h_weight_initializer, h2h_weight_initializer=h2h_weight_initializer, i2h_bias_initializer=i2h_bias_initializer, - h2h_bias_initializer=h2h_bias_initializer, - prefix=prefix + 'enc_', params=params) + h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'enc_', + params=params) decoder = GNMTDecoder(cell_type=cell_type, attention_cell=attention_cell, num_layers=num_layers, - hidden_size=hidden_size, dropout=dropout, - use_residual=use_residual, + hidden_size=hidden_size, dropout=dropout, use_residual=use_residual, i2h_weight_initializer=i2h_weight_initializer, h2h_weight_initializer=h2h_weight_initializer, i2h_bias_initializer=i2h_bias_initializer, - h2h_bias_initializer=h2h_bias_initializer, - prefix=prefix + 'dec_', params=params) - return encoder, decoder + h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'dec_', + params=params) + one_step_ahead_decoder = GNMTOneStepDecoder( + cell_type=cell_type, attention_cell=attention_cell, num_layers=num_layers, + hidden_size=hidden_size, dropout=dropout, use_residual=use_residual, + i2h_weight_initializer=i2h_weight_initializer, + h2h_weight_initializer=h2h_weight_initializer, i2h_bias_initializer=i2h_bias_initializer, + h2h_bias_initializer=h2h_bias_initializer, prefix=prefix + 'dec_', + params=decoder.collect_params()) + return encoder, decoder, one_step_ahead_decoder diff --git a/scripts/machine_translation/inference_transformer.py b/scripts/machine_translation/inference_transformer.py index 3128d5f64d..8c048c300b 100644 --- a/scripts/machine_translation/inference_transformer.py +++ b/scripts/machine_translation/inference_transformer.py @@ -154,17 +154,14 @@ else: tgt_max_len = max_len[1] -encoder, decoder = get_transformer_encoder_decoder(units=args.num_units, - hidden_size=args.hidden_size, - dropout=args.dropout, - num_layers=args.num_layers, - num_heads=args.num_heads, - 
max_src_length=max(src_max_len, 500), - max_tgt_length=max(tgt_max_len, 500), - scaled=args.scaled) +encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder( + units=args.num_units, hidden_size=args.hidden_size, dropout=args.dropout, + num_layers=args.num_layers, num_heads=args.num_heads, max_src_length=max(src_max_len, 500), + max_tgt_length=max(tgt_max_len, 500), scaled=args.scaled) model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, - share_embed=args.dataset != 'TOY', embed_size=args.num_units, - tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_') + one_step_ahead_decoder=one_step_ahead_decoder, share_embed=args.dataset != 'TOY', + embed_size=args.num_units, tie_weights=args.dataset != 'TOY', + embed_initializer=None, prefix='transformer_') param_name = args.model_parameter if (not os.path.exists(param_name)): diff --git a/scripts/machine_translation/train_gnmt.py b/scripts/machine_translation/train_gnmt.py index 11dbe69c99..c9b9b8aa28 100644 --- a/scripts/machine_translation/train_gnmt.py +++ b/scripts/machine_translation/train_gnmt.py @@ -122,12 +122,12 @@ else: ctx = mx.gpu(args.gpu) -encoder, decoder = get_gnmt_encoder_decoder(hidden_size=args.num_hidden, - dropout=args.dropout, - num_layers=args.num_layers, - num_bi_layers=args.num_bi_layers) +encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder( + hidden_size=args.num_hidden, dropout=args.dropout, num_layers=args.num_layers, + num_bi_layers=args.num_bi_layers) model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, - embed_size=args.num_hidden, prefix='gnmt_') + one_step_ahead_decoder=one_step_ahead_decoder, embed_size=args.num_hidden, + prefix='gnmt_') model.initialize(init=mx.init.Uniform(0.1), ctx=ctx) static_alloc = True model.hybridize(static_alloc=static_alloc) @@ -175,8 +175,8 @@ def evaluate(data_loader): avg_loss += loss * (tgt_seq.shape[1] - 1) avg_loss_denom += (tgt_seq.shape[1] - 1) # Translate - samples, _, sample_valid_length =\ - translator.translate(src_seq=src_seq, src_valid_length=src_valid_length) + samples, _, sample_valid_length = translator.translate( + src_seq=src_seq, src_valid_length=src_valid_length) max_score_sample = samples[:, 0, :].asnumpy() sample_valid_length = sample_valid_length[:, 0].asnumpy() for i in range(max_score_sample.shape[0]): diff --git a/scripts/machine_translation/train_transformer.py b/scripts/machine_translation/train_transformer.py index 09646bc650..a366c7e36e 100644 --- a/scripts/machine_translation/train_transformer.py +++ b/scripts/machine_translation/train_transformer.py @@ -33,25 +33,25 @@ # pylint:disable=redefined-outer-name,logging-format-interpolation import argparse -import time -import random -import os import logging import math +import os +import random +import time + import numpy as np import mxnet as mx from mxnet import gluon -import gluonnlp as nlp -from gluonnlp.loss import MaskedSoftmaxCELoss, LabelSmoothing +import gluonnlp as nlp +from gluonnlp.loss import LabelSmoothing, MaskedSoftmaxCELoss +from gluonnlp.model.transformer import ParallelTransformer, get_transformer_encoder_decoder from gluonnlp.model.translation import NMTModel -from gluonnlp.model.transformer import get_transformer_encoder_decoder, ParallelTransformer from gluonnlp.utils.parallel import Parallel +import dataprocessor +from bleu import _bpe_to_words, compute_bleu from translation import BeamSearchTranslator - from utils import logging_config -from bleu 
import _bpe_to_words, compute_bleu -import dataprocessor np.random.seed(100) random.seed(100) @@ -174,15 +174,12 @@ tgt_max_len = args.tgt_max_len else: tgt_max_len = max_len[1] -encoder, decoder = get_transformer_encoder_decoder(units=args.num_units, - hidden_size=args.hidden_size, - dropout=args.dropout, - num_layers=args.num_layers, - num_heads=args.num_heads, - max_src_length=max(src_max_len, 500), - max_tgt_length=max(tgt_max_len, 500), - scaled=args.scaled) +encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder( + units=args.num_units, hidden_size=args.hidden_size, dropout=args.dropout, + num_layers=args.num_layers, num_heads=args.num_heads, max_src_length=max(src_max_len, 500), + max_tgt_length=max(tgt_max_len, 500), scaled=args.scaled) model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, + one_step_ahead_decoder=one_step_ahead_decoder, share_embed=args.dataset not in ('TOY', 'IWSLT2015'), embed_size=args.num_units, tie_weights=args.dataset not in ('TOY', 'IWSLT2015'), embed_initializer=None, prefix='transformer_') diff --git a/scripts/tests/test_encoder_decoder.py b/scripts/tests/test_encoder_decoder.py index 9e358d4025..0888aa468a 100644 --- a/scripts/tests/test_encoder_decoder.py +++ b/scripts/tests/test_encoder_decoder.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. +import pytest + import numpy as np import mxnet as mx from mxnet.test_utils import assert_almost_equal @@ -75,9 +77,7 @@ def test_gnmt_encoder_decoder(): decoder_states = decoder.init_state_from_encoder(encoder_outputs, src_valid_length_nd) # Test multi step forwarding - output, new_states, additional_outputs = decoder.decode_seq(tgt_seq_nd, - decoder_states, - tgt_valid_length_nd) + output, new_states, additional_outputs = decoder(tgt_seq_nd, decoder_states, tgt_valid_length_nd) assert(output.shape == (batch_size, tgt_seq_length, num_hidden)) output_npy = output.asnumpy() for i in range(batch_size): @@ -136,51 +136,53 @@ def test_transformer_encoder(): else: assert(len(additional_outputs) == 0) -def test_transformer_encoder_decoder(): +@pytest.mark.parametrize('output_attention', [False, True]) +@pytest.mark.parametrize('use_residual', [False, True]) +@pytest.mark.parametrize('batch_size', [4]) +@pytest.mark.parametrize('src_tgt_seq_len', [(5, 10), (10, 5)]) +def test_transformer_encoder_decoder(output_attention, use_residual, batch_size, src_tgt_seq_len): ctx = mx.current_context() units = 16 encoder = TransformerEncoder(num_layers=3, units=units, hidden_size=32, num_heads=8, max_length=10, dropout=0.0, use_residual=True, prefix='transformer_encoder_') encoder.initialize(ctx=ctx) encoder.hybridize() - for output_attention in [True, False]: - for use_residual in [True, False]: - decoder = TransformerDecoder(num_layers=3, units=units, hidden_size=32, num_heads=8, max_length=10, dropout=0.0, - output_attention=output_attention, use_residual=use_residual, prefix='transformer_decoder_') - decoder.initialize(ctx=ctx) - decoder.hybridize() - for batch_size in [4]: - for src_seq_length, tgt_seq_length in [(5, 10), (10, 5)]: - src_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, src_seq_length, units), ctx=ctx) - tgt_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, tgt_seq_length, units), ctx=ctx) - src_valid_length_nd = mx.nd.array(np.random.randint(1, src_seq_length, size=(batch_size,)), ctx=ctx) - tgt_valid_length_nd = mx.nd.array(np.random.randint(1, tgt_seq_length, size=(batch_size,)), ctx=ctx) - 
src_valid_length_npy = src_valid_length_nd.asnumpy() - tgt_valid_length_npy = tgt_valid_length_nd.asnumpy() - encoder_outputs, _ = encoder(src_seq_nd, valid_length=src_valid_length_nd) - decoder_states = decoder.init_state_from_encoder(encoder_outputs, src_valid_length_nd) + decoder = TransformerDecoder(num_layers=3, units=units, hidden_size=32, + num_heads=8, max_length=10, dropout=0.0, + output_attention=output_attention, + use_residual=use_residual, + prefix='transformer_decoder_') + decoder.initialize(ctx=ctx) + decoder.hybridize() - # Test multi step forwarding - output, new_states, additional_outputs = decoder.decode_seq(tgt_seq_nd, - decoder_states, - tgt_valid_length_nd) - assert(output.shape == (batch_size, tgt_seq_length, units)) - output_npy = output.asnumpy() - for i in range(batch_size): - tgt_v_len = int(tgt_valid_length_npy[i]) - if tgt_v_len < tgt_seq_length - 1: - assert((output_npy[i, tgt_v_len:, :] == 0).all()) - if output_attention: - assert(len(additional_outputs) == 3) - attention_out = additional_outputs[0][1].asnumpy() - assert(attention_out.shape == (batch_size, 8, tgt_seq_length, src_seq_length)) - for i in range(batch_size): - mem_v_len = int(src_valid_length_npy[i]) - if mem_v_len < src_seq_length - 1: - assert((attention_out[i, :, :, mem_v_len:] == 0).all()) - if mem_v_len > 0: - assert_almost_equal(attention_out[i, :, :, :].sum(axis=-1), - np.ones(attention_out.shape[1:3])) - else: - assert(len(additional_outputs) == 0) + src_seq_length, tgt_seq_length = src_tgt_seq_len + src_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, src_seq_length, units), ctx=ctx) + tgt_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, tgt_seq_length, units), ctx=ctx) + src_valid_length_nd = mx.nd.array(np.random.randint(1, src_seq_length, size=(batch_size,)), ctx=ctx) + tgt_valid_length_nd = mx.nd.array(np.random.randint(1, tgt_seq_length, size=(batch_size,)), ctx=ctx) + src_valid_length_npy = src_valid_length_nd.asnumpy() + tgt_valid_length_npy = tgt_valid_length_nd.asnumpy() + encoder_outputs, _ = encoder(src_seq_nd, valid_length=src_valid_length_nd) + decoder_states = decoder.init_state_from_encoder(encoder_outputs, src_valid_length_nd) + # Test multi step forwarding + output, new_states, additional_outputs = decoder(tgt_seq_nd, decoder_states, tgt_valid_length_nd) + assert(output.shape == (batch_size, tgt_seq_length, units)) + output_npy = output.asnumpy() + for i in range(batch_size): + tgt_v_len = int(tgt_valid_length_npy[i]) + if tgt_v_len < tgt_seq_length - 1: + assert((output_npy[i, tgt_v_len:, :] == 0).all()) + if output_attention: + assert(len(additional_outputs) == 3) + attention_out = additional_outputs[0][1].asnumpy() + assert(attention_out.shape == (batch_size, 8, tgt_seq_length, src_seq_length)) + for i in range(batch_size): + mem_v_len = int(src_valid_length_npy[i]) + if mem_v_len < src_seq_length - 1: + assert((attention_out[i, :, :, mem_v_len:] == 0).all()) + if mem_v_len > 0: + assert_almost_equal(attention_out[i, :, :, :].sum(axis=-1), + np.ones(attention_out.shape[1:3])) + else: + assert(len(additional_outputs) == 0) diff --git a/src/gluonnlp/model/seq2seq_encoder_decoder.py b/src/gluonnlp/model/seq2seq_encoder_decoder.py index 586893a72e..28f8c8b9e3 100644 --- a/src/gluonnlp/model/seq2seq_encoder_decoder.py +++ b/src/gluonnlp/model/seq2seq_encoder_decoder.py @@ -18,11 +18,13 @@ __all__ = ['Seq2SeqEncoder'] from functools import partial + import mxnet as mx from mxnet.gluon import rnn from mxnet.gluon.block import Block -from gluonnlp.model import 
AttentionCell, MLPAttentionCell, DotProductAttentionCell, \ - MultiHeadAttentionCell + +from .attention_cell import (AttentionCell, DotProductAttentionCell, + MLPAttentionCell, MultiHeadAttentionCell) def _get_cell_type(cell_type): @@ -153,9 +155,11 @@ def forward(self, inputs, valid_length=None, states=None): #pylint: disable=arg class Seq2SeqDecoder(Block): - r"""Base class of the decoders in sequence to sequence learning models. + """Base class of the decoders in sequence to sequence learning models. - In the forward function, it generates the one-step-ahead decoding output. + Given the inputs and the context computed by the encoder, generate the new + states. This is usually used in the training phase where we set the inputs + to be the target sequence. """ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): @@ -172,8 +176,8 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): """ raise NotImplementedError - def decode_seq(self, inputs, states, valid_length=None): - r"""Given the inputs and the context computed by the encoder, + def forward(self, step_input, states, valid_length=None): #pylint: disable=arguments-differ + """Given the inputs and the context computed by the encoder, generate the new states. This is usually used in the training phase where we set the inputs to be the target sequence. @@ -196,8 +200,29 @@ def decode_seq(self, inputs, states, valid_length=None): """ raise NotImplementedError - def __call__(self, step_input, states): #pylint: disable=arguments-differ - r"""One-step decoding of the input + +class Seq2SeqOneStepDecoder(Block): + r"""Base class of the decoders in sequence to sequence learning models. + + In the forward function, it generates the one-step-ahead decoding output. + + """ + def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): + r"""Generates the initial decoder states based on the encoder outputs. 
+ + Parameters + ---------- + encoder_outputs : list of NDArrays + encoder_valid_length : NDArray or None + + Returns + ------- + decoder_states : list + """ + raise NotImplementedError + + def forward(self, step_input, states): #pylint: disable=arguments-differ + """One-step decoding of the input Parameters ---------- @@ -213,7 +238,4 @@ def __call__(self, step_input, states): #pylint: disable=arguments-differ step_additional_outputs : list Additional outputs of the step, e.g, the attention weights """ - return super(Seq2SeqDecoder, self).__call__(step_input, states) - - def forward(self, step_input, states): #pylint: disable=arguments-differ raise NotImplementedError diff --git a/src/gluonnlp/model/train/__init__.py b/src/gluonnlp/model/train/__init__.py index 5f25621730..8b66c7e208 100644 --- a/src/gluonnlp/model/train/__init__.py +++ b/src/gluonnlp/model/train/__init__.py @@ -26,8 +26,7 @@ from .embedding import * from .language_model import * -__all__ = language_model.__all__ + cache.__all__ + embedding.__all__ + \ - ['get_cache_model'] +__all__ = language_model.__all__ + cache.__all__ + embedding.__all__ + ['get_cache_model'] def get_cache_model(name, dataset_name='wikitext-2', window=2000, diff --git a/src/gluonnlp/model/transformer.py b/src/gluonnlp/model/transformer.py index 73df2db59d..9209444518 100644 --- a/src/gluonnlp/model/transformer.py +++ b/src/gluonnlp/model/transformer.py @@ -20,22 +20,24 @@ __all__ = ['TransformerEncoder', 'PositionwiseFFN', 'TransformerEncoderCell', 'transformer_en_de_512'] +import math import os -import math import numpy as np import mxnet as mx from mxnet import cpu, gluon from mxnet.gluon import nn from mxnet.gluon.block import HybridBlock from mxnet.gluon.model_zoo import model_store -from gluonnlp.utils.parallel import Parallelizable -from .seq2seq_encoder_decoder import Seq2SeqEncoder, Seq2SeqDecoder, _get_attention_cell + +from ..base import get_home_dir +from ..utils.parallel import Parallelizable from .block import GELU +from .seq2seq_encoder_decoder import (Seq2SeqDecoder, Seq2SeqEncoder, + Seq2SeqOneStepDecoder, + _get_attention_cell) from .translation import NMTModel -from .utils import _load_vocab, _load_pretrained_params -from ..base import get_home_dir - +from .utils import _load_pretrained_params, _load_vocab ############################################################################### # BASE ENCODER BLOCKS # @@ -758,7 +760,9 @@ class TransformerDecoderCell(HybridBlock): Whether to scale the softmax input by the sqrt of the input dimension in multi-head attention dropout : float + Dropout probability. use_residual : bool + Whether to use residual connection. output_attention: bool Whether to output the attention weights weight_initializer : str or Initializer @@ -866,50 +870,12 @@ def hybrid_forward(self, F, inputs, mem_value, mask=None, mem_mask=None): #pyli return outputs, additional_outputs -class TransformerDecoder(HybridBlock, Seq2SeqDecoder): - """Structure of the Transformer Decoder. - - Parameters - ---------- - attention_cell : AttentionCell or str, default 'multi_head' - Arguments of the attention cell. - Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp' - num_layers : int - units : int - hidden_size : int - number of units in the hidden layer of position-wise feed-forward networks - max_length : int - Maximum length of the input sequence. 
This is used for constructing position encoding - num_heads : int - Number of heads in multi-head attention - scaled : bool - Whether to scale the softmax input by the sqrt of the input dimension - in multi-head attention - dropout : float - use_residual : bool - output_attention: bool - Whether to output the attention weights - weight_initializer : str or Initializer - Initializer for the input weights matrix, used for the linear - transformation of the inputs. - bias_initializer : str or Initializer - Initializer for the bias vector. - scale_embed : bool, default True - Scale the input embeddings by sqrt(embed_size). - prefix : str, default 'rnn_' - Prefix for name of `Block`s - (and name of weight if params is `None`). - params : Parameter or None - Container for weight sharing between cells. - Created if `None`. - """ - def __init__(self, attention_cell='multi_head', num_layers=2, - units=128, hidden_size=2048, max_length=50, - num_heads=4, scaled=True, dropout=0.0, - use_residual=True, output_attention=False, - weight_initializer=None, bias_initializer='zeros', +class _BaseTransformerDecoder(HybridBlock): + def __init__(self, attention_cell='multi_head', num_layers=2, units=128, hidden_size=2048, + max_length=50, num_heads=4, scaled=True, dropout=0.0, use_residual=True, + output_attention=False, weight_initializer=None, bias_initializer='zeros', scale_embed=True, prefix=None, params=None): - super(TransformerDecoder, self).__init__(prefix=prefix, params=params) + super().__init__(prefix=prefix, params=params) assert units % num_heads == 0, 'In TransformerDecoder, the units should be divided ' \ 'exactly by the number of heads. Received units={}, ' \ 'num_heads={}'.format(units, num_heads) @@ -928,22 +894,17 @@ def __init__(self, attention_cell='multi_head', num_layers=2, self.dropout_layer = nn.Dropout(rate=dropout) self.layer_norm = nn.LayerNorm() encoding = _position_encoding_init(max_length, units) - self.position_weight = self.params.get_constant('const', encoding) + self.position_weight = self.params.get_constant('const', encoding.astype(np.float32)) self.transformer_cells = nn.HybridSequential() for i in range(num_layers): self.transformer_cells.add( - TransformerDecoderCell( - units=units, - hidden_size=hidden_size, - num_heads=num_heads, - attention_cell=attention_cell, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - dropout=dropout, - scaled=scaled, - use_residual=use_residual, - output_attention=output_attention, - prefix='transformer%d_' % i)) + TransformerDecoderCell(units=units, hidden_size=hidden_size, + num_heads=num_heads, attention_cell=attention_cell, + weight_initializer=weight_initializer, + bias_initializer=bias_initializer, dropout=dropout, + scaled=scaled, use_residual=use_residual, + output_attention=output_attention, + prefix='transformer%d_' % i)) def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): """Initialize the state from the encoder outputs. 
@@ -959,7 +920,7 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): The decoder states, includes: - mem_value : NDArray - - mem_masks : NDArray, optional + - mem_masks : NDArray or None """ mem_value = encoder_outputs decoder_states = [mem_value] @@ -971,10 +932,12 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): mx.nd.arange(mem_length, ctx=ctx, dtype=dtype).reshape((1, -1)), encoder_valid_length.reshape((-1, 1))) decoder_states.append(mem_masks) - self._encoder_valid_length = encoder_valid_length + else: + decoder_states.append(None) return decoder_states - def decode_seq(self, inputs, states, valid_length=None): + def hybrid_forward(self, F, inputs, states, valid_length=None, position_weight=None): + #pylint: disable=arguments-differ """Decode the decoder inputs. This function is only used for training. Parameters @@ -990,58 +953,180 @@ def decode_seq(self, inputs, states, valid_length=None): ------- output : NDArray, Shape (batch_size, length, C_out) states : list - The decoder states, includes: - + The decoder states: - mem_value : NDArray - - mem_masks : NDArray, optional + - mem_masks : NDArray or None additional_outputs : list of list Either be an empty list or contains the attention weights in this step. The attention weights will have shape (batch_size, length, mem_length) or (batch_size, num_heads, length, mem_length) """ - batch_size = inputs.shape[0] - length = inputs.shape[1] - length_array = mx.nd.arange(length, ctx=inputs.context, dtype=inputs.dtype) - mask = mx.nd.broadcast_lesser_equal( - length_array.reshape((1, -1)), - length_array.reshape((-1, 1))) + + length_array = F.contrib.arange_like(inputs, axis=1) + mask = F.broadcast_lesser_equal(length_array.reshape((1, -1)), + length_array.reshape((-1, 1))) if valid_length is not None: - arange = mx.nd.arange(length, ctx=valid_length.context, dtype=valid_length.dtype) - batch_mask = mx.nd.broadcast_lesser( - arange.reshape((1, -1)), - valid_length.reshape((-1, 1))) - mask = mx.nd.broadcast_mul(mx.nd.expand_dims(batch_mask, -1), - mx.nd.expand_dims(mask, 0)) + batch_mask = F.broadcast_lesser(length_array.reshape((1, -1)), + valid_length.reshape((-1, 1))) + batch_mask = F.expand_dims(batch_mask, -1) + mask = F.broadcast_mul(batch_mask, F.expand_dims(mask, 0)) else: - mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), axis=0, size=batch_size) - states = [None] + states - output, states, additional_outputs = self.forward(inputs, states, mask) - states = states[1:] + mask = F.expand_dims(mask, axis=0) + mask = F.broadcast_like(mask, inputs, lhs_axes=(0, ), rhs_axes=(0, )) + + mem_value, mem_mask = states + if mem_mask is not None: + mem_mask = F.expand_dims(mem_mask, axis=1) + mem_mask = F.broadcast_like(mem_mask, inputs, lhs_axes=(1, ), rhs_axes=(1, )) + + if self._scale_embed: + # XXX: input.shape[-1] and self._units are expected to be the same + inputs = inputs * math.sqrt(self._units) + + # Positional Encoding + steps = F.contrib.arange_like(inputs, axis=1) + positional_embed = F.Embedding(steps, position_weight, self._max_length, self._units) + inputs = F.broadcast_add(inputs, F.expand_dims(positional_embed, axis=0)) + + if self._dropout: + inputs = self.dropout_layer(inputs) + inputs = self.layer_norm(inputs) + additional_outputs = [] + attention_weights_l = [] + outputs = inputs + for cell in self.transformer_cells: + outputs, attention_weights = cell(outputs, mem_value, mask, mem_mask) + if self._output_attention: + 
attention_weights_l.append(attention_weights) + if self._output_attention: + additional_outputs.extend(attention_weights_l) + if valid_length is not None: - output = mx.nd.SequenceMask(output, - sequence_length=valid_length, - use_sequence_length=True, - axis=1) - return output, states, additional_outputs + outputs = F.SequenceMask(outputs, sequence_length=valid_length, + use_sequence_length=True, axis=1) + return outputs, states, additional_outputs + + +class TransformerDecoder(_BaseTransformerDecoder, Seq2SeqDecoder): + """Transformer Decoder. + + Multi-step ahead decoder for use during training with teacher forcing. + + Parameters + ---------- + attention_cell : AttentionCell or str, default 'multi_head' + Arguments of the attention cell. + Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp' + num_layers : int + Number of attention layers. + units : int + Number of units for the output. + hidden_size : int + number of units in the hidden layer of position-wise feed-forward networks + max_length : int + Maximum length of the input sequence. This is used for constructing position encoding + num_heads : int + Number of heads in multi-head attention + scaled : bool + Whether to scale the softmax input by the sqrt of the input dimension + in multi-head attention + dropout : float + Dropout probability. + use_residual : bool + Whether to use residual connection. + output_attention: bool + Whether to output the attention weights + weight_initializer : str or Initializer + Initializer for the input weights matrix, used for the linear + transformation of the inputs. + bias_initializer : str or Initializer + Initializer for the bias vector. + scale_embed : bool, default True + Scale the input embeddings by sqrt(embed_size). + prefix : str, default 'rnn_' + Prefix for name of `Block`s + (and name of weight if params is `None`). + params : Parameter or None + Container for weight sharing between cells. + Created if `None`. + """ + + +class TransformerOneStepDecoder(_BaseTransformerDecoder, Seq2SeqOneStepDecoder): + """Transformer Decoder. - def __call__(self, step_input, states): #pylint: disable=arguments-differ + One-step ahead decoder for use during inference. + + Parameters + ---------- + attention_cell : AttentionCell or str, default 'multi_head' + Arguments of the attention cell. + Can be 'multi_head', 'scaled_luong', 'scaled_dot', 'dot', 'cosine', 'normed_mlp', 'mlp' + num_layers : int + Number of attention layers. + units : int + Number of units for the output. + hidden_size : int + number of units in the hidden layer of position-wise feed-forward networks + max_length : int + Maximum length of the input sequence. This is used for constructing position encoding + num_heads : int + Number of heads in multi-head attention + scaled : bool + Whether to scale the softmax input by the sqrt of the input dimension + in multi-head attention + dropout : float + Dropout probability. + use_residual : bool + Whether to use residual connection. + output_attention: bool + Whether to output the attention weights + weight_initializer : str or Initializer + Initializer for the input weights matrix, used for the linear + transformation of the inputs. + bias_initializer : str or Initializer + Initializer for the bias vector. + scale_embed : bool, default True + Scale the input embeddings by sqrt(embed_size). + prefix : str, default 'rnn_' + Prefix for name of `Block`s + (and name of weight if params is `None`). 
+ params : Parameter or None + Container for weight sharing between cells. + Created if `None`. + """ + + def forward(self, step_input, states): # pylint: disable=arguments-differ + # We implement forward, as the number of states changes between the + # first and later calls of the one-step ahead Transformer decoder. This + # is due to the lack of numpy shape semantics. Once we enable numpy + # shape semantic in the GluonNLP code-base, the number of states should + # stay constant, but the first state element will be an array of shape + # (batch_size, 0, C_in) at the first call. + if len(states) == 3: # step_input from prior call is included + last_embeds, _, _ = states + inputs = mx.nd.concat(last_embeds, mx.nd.expand_dims(step_input, axis=1), dim=1) + states = states[1:] + else: + inputs = mx.nd.expand_dims(step_input, axis=1) + return super().forward(inputs, states) + + def hybrid_forward(self, F, inputs, states, position_weight): + # pylint: disable=arguments-differ """One-step-ahead decoding of the Transformer decoder. Parameters ---------- - step_input : NDArray + step_input : NDArray, Shape (batch_size, C_in) states : list of NDArray Returns ------- step_output : NDArray - The output of the decoder. - In the train mode, Shape is (batch_size, length, C_out) - In the test mode, Shape is (batch_size, C_out) + The output of the decoder. Shape is (batch_size, C_out) new_states: list Includes - last_embeds : NDArray or None - It is only given during testing - mem_value : NDArray - mem_masks : NDArray, optional @@ -1050,107 +1135,15 @@ def __call__(self, step_input, states): #pylint: disable=arguments-differ The attention weights will have shape (batch_size, length, mem_length) or (batch_size, num_heads, length, mem_length) """ - return super(TransformerDecoder, self).__call__(step_input, states) - - def forward(self, step_input, states, mask=None): #pylint: disable=arguments-differ, missing-docstring - input_shape = step_input.shape - mem_mask = None - # If it is in testing, transform input tensor to a tensor with shape NTC - # Otherwise remove the None in states. 
- if len(input_shape) == 2: - if self._encoder_valid_length is not None: - has_last_embeds = len(states) == 3 - else: - has_last_embeds = len(states) == 2 - if has_last_embeds: - last_embeds = states[0] - step_input = mx.nd.concat(last_embeds, - mx.nd.expand_dims(step_input, axis=1), - dim=1) - states = states[1:] - else: - step_input = mx.nd.expand_dims(step_input, axis=1) - elif states[0] is None: - states = states[1:] - has_mem_mask = (len(states) == 2) - if has_mem_mask: - _, mem_mask = states - augmented_mem_mask = mx.nd.expand_dims(mem_mask, axis=1)\ - .broadcast_axes(axis=1, size=step_input.shape[1]) - states[-1] = augmented_mem_mask - if mask is None: - length_array = mx.nd.arange(step_input.shape[1], ctx=step_input.context, - dtype=step_input.dtype) - mask = mx.nd.broadcast_lesser_equal( - length_array.reshape((1, -1)), - length_array.reshape((-1, 1))) - mask = mx.nd.broadcast_axes(mx.nd.expand_dims(mask, axis=0), - axis=0, size=step_input.shape[0]) - steps = mx.nd.arange(step_input.shape[1], ctx=step_input.context) - states.append(steps) - if self._scale_embed: - scaled_step_input = step_input * math.sqrt(step_input.shape[-1]) - # pylint: disable=too-many-function-args - step_output, step_additional_outputs = \ - super(TransformerDecoder, self).forward(scaled_step_input, states, mask) - states = states[:-1] - if has_mem_mask: - states[-1] = mem_mask - new_states = [step_input] + states - # If it is in testing, only output the last one - if len(input_shape) == 2: - step_output = step_output[:, -1, :] - return step_output, new_states, step_additional_outputs - - def hybrid_forward(self, F, step_input, states, mask=None, position_weight=None): - #pylint: disable=arguments-differ - """ + outputs, states, additional_outputs = super().hybrid_forward( + F, inputs, states, valid_length=None, position_weight=position_weight) - Parameters - ---------- - step_input : NDArray or Symbol, Shape (batch_size, length, C_in) - states : list of NDArray or Symbol - mask : NDArray or Symbol - position_weight : NDArray or Symbol + # Append inputs to states: They are needed in the next one-step ahead decoding step + new_states = [inputs] + states + # Only return one-step ahead + step_output = F.slice_axis(outputs, axis=1, begin=-1, end=None).reshape((0, -1)) - Returns - ------- - step_output : NDArray or Symbol - The output of the decoder. Shape is (batch_size, length, C_out) - step_additional_outputs : list - Either be an empty list or contains the attention weights in this step. 
- The attention weights will have shape (batch_size, length, mem_length) or - (batch_size, num_heads, length, mem_length) - - """ - has_mem_mask = (len(states) == 3) - if has_mem_mask: - mem_value, mem_mask, steps = states - else: - mem_value, steps = states - mem_mask = None - # Positional Encoding - step_input = F.broadcast_add(step_input, - F.expand_dims(F.Embedding(steps, - position_weight, - self._max_length, - self._units), - axis=0)) - if self._dropout: - step_input = self.dropout_layer(step_input) - step_input = self.layer_norm(step_input) - inputs = step_input - outputs = inputs - step_additional_outputs = [] - attention_weights_l = [] - for cell in self.transformer_cells: - outputs, attention_weights = cell(inputs, mem_value, mask, mem_mask) - if self._output_attention: - attention_weights_l.append(attention_weights) - inputs = outputs - if self._output_attention: - step_additional_outputs.extend(attention_weights_l) - return outputs, step_additional_outputs + return step_output, new_states, additional_outputs @@ -1193,40 +1186,35 @@ def get_transformer_encoder_decoder(num_layers=2, Returns ------- encoder : TransformerEncoder - decoder :TransformerDecoder + decoder : TransformerDecoder + one_step_ahead_decoder : TransformerOneStepDecoder """ - encoder = TransformerEncoder(num_layers=num_layers, - num_heads=num_heads, - max_length=max_src_length, - units=units, - hidden_size=hidden_size, - dropout=dropout, - scaled=scaled, - use_residual=use_residual, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - prefix=prefix + 'enc_', params=params) - decoder = TransformerDecoder(num_layers=num_layers, - num_heads=num_heads, - max_length=max_tgt_length, - units=units, - hidden_size=hidden_size, - dropout=dropout, - scaled=scaled, - use_residual=use_residual, - weight_initializer=weight_initializer, - bias_initializer=bias_initializer, - prefix=prefix + 'dec_', params=params) - return encoder, decoder - - -def _get_transformer_model(model_cls, model_name, dataset_name, src_vocab, tgt_vocab, - encoder, decoder, share_embed, embed_size, tie_weights, + encoder = TransformerEncoder( + num_layers=num_layers, num_heads=num_heads, max_length=max_src_length, units=units, + hidden_size=hidden_size, dropout=dropout, scaled=scaled, use_residual=use_residual, + weight_initializer=weight_initializer, bias_initializer=bias_initializer, + prefix=prefix + 'enc_', params=params) + decoder = TransformerDecoder( + num_layers=num_layers, num_heads=num_heads, max_length=max_tgt_length, units=units, + hidden_size=hidden_size, dropout=dropout, scaled=scaled, use_residual=use_residual, + weight_initializer=weight_initializer, bias_initializer=bias_initializer, + prefix=prefix + 'dec_', params=params) + one_step_ahead_decoder = TransformerOneStepDecoder( + num_layers=num_layers, num_heads=num_heads, max_length=max_tgt_length, units=units, + hidden_size=hidden_size, dropout=dropout, scaled=scaled, use_residual=use_residual, + weight_initializer=weight_initializer, bias_initializer=bias_initializer, + prefix=prefix + 'dec_', params=decoder.collect_params()) + return encoder, decoder, one_step_ahead_decoder + + +def _get_transformer_model(model_cls, model_name, dataset_name, src_vocab, tgt_vocab, encoder, + decoder, one_step_ahead_decoder, share_embed, embed_size, tie_weights, embed_initializer, pretrained, ctx, root, **kwargs): src_vocab = _load_vocab(dataset_name + '_src', src_vocab, root) tgt_vocab = _load_vocab(dataset_name + '_tgt', tgt_vocab, root) kwargs['encoder'] = encoder 
kwargs['decoder'] = decoder + kwargs['one_step_ahead_decoder'] = one_step_ahead_decoder kwargs['src_vocab'] = src_vocab kwargs['tgt_vocab'] = tgt_vocab kwargs['share_embed'] = share_embed @@ -1279,16 +1267,13 @@ def transformer_en_de_512(dataset_name=None, src_vocab=None, tgt_vocab=None, pre assert all((k not in kwargs or k in mutable_args) for k in predefined_args), \ 'Cannot override predefined model settings.' predefined_args.update(kwargs) - encoder, decoder = get_transformer_encoder_decoder(units=predefined_args['num_units'], - hidden_size=predefined_args['hidden_size'], - dropout=predefined_args['dropout'], - num_layers=predefined_args['num_layers'], - num_heads=predefined_args['num_heads'], - max_src_length=530, - max_tgt_length=549, - scaled=predefined_args['scaled']) - return _get_transformer_model(NMTModel, 'transformer_en_de_512', dataset_name, - src_vocab, tgt_vocab, encoder, decoder, + encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder( + units=predefined_args['num_units'], hidden_size=predefined_args['hidden_size'], + dropout=predefined_args['dropout'], num_layers=predefined_args['num_layers'], + num_heads=predefined_args['num_heads'], max_src_length=530, max_tgt_length=549, + scaled=predefined_args['scaled']) + return _get_transformer_model(NMTModel, 'transformer_en_de_512', dataset_name, src_vocab, + tgt_vocab, encoder, decoder, one_step_ahead_decoder, predefined_args['share_embed'], predefined_args['embed_size'], predefined_args['tie_weights'], predefined_args['embed_initializer'], pretrained, ctx, root) diff --git a/src/gluonnlp/model/translation.py b/src/gluonnlp/model/translation.py index 940e43bdec..e56e9721cf 100644 --- a/src/gluonnlp/model/translation.py +++ b/src/gluonnlp/model/translation.py @@ -38,6 +38,8 @@ class NMTModel(Block): Encoder that encodes the input sentence. decoder : Seq2SeqDecoder Decoder that generates the predictions based on the output of the encoder. + one_step_ahead_decoder : Seq2SeqOneStepDecoder + Decoder that generates the one-step ahead prediction based on the output of the encoder. embed_size : int or None, default None Size of the embedding vectors. It is used to generate the source and target embeddings if src_embed and tgt_embed are None. @@ -63,7 +65,7 @@ class NMTModel(Block): params : ParameterDict or None See document of `Block`. 
""" - def __init__(self, src_vocab, tgt_vocab, encoder, decoder, + def __init__(self, src_vocab, tgt_vocab, encoder, decoder, one_step_ahead_decoder, embed_size=None, embed_dropout=0.0, embed_initializer=mx.init.Uniform(0.1), src_embed=None, tgt_embed=None, share_embed=False, tie_weights=False, tgt_proj=None, prefix=None, params=None): @@ -72,6 +74,7 @@ def __init__(self, src_vocab, tgt_vocab, encoder, decoder, self.src_vocab = src_vocab self.encoder = encoder self.decoder = decoder + self.one_step_ahead_decoder = one_step_ahead_decoder self._shared_embed = share_embed if embed_dropout is None: embed_dropout = 0.0 @@ -158,10 +161,8 @@ def decode_seq(self, inputs, states, valid_length=None): additional_outputs : list Additional outputs of the decoder, e.g, the attention weights """ - outputs, states, additional_outputs =\ - self.decoder.decode_seq(inputs=self.tgt_embed(inputs), - states=states, - valid_length=valid_length) + outputs, states, additional_outputs = self.decoder(self.tgt_embed(inputs), states, + valid_length) outputs = self.tgt_proj(outputs) return outputs, states, additional_outputs @@ -183,7 +184,7 @@ def decode_step(self, step_input, states): Additional outputs of the step, e.g, the attention weights """ step_output, states, step_additional_outputs =\ - self.decoder(self.tgt_embed(step_input), states) + self.one_step_ahead_decoder(self.tgt_embed(step_input), states) step_output = self.tgt_proj(step_output) return step_output, states, step_additional_outputs diff --git a/tests/unittest/test_models.py b/tests/unittest/test_models.py index 26ad55dbfa..e446199db4 100644 --- a/tests/unittest/test_models.py +++ b/tests/unittest/test_models.py @@ -73,30 +73,27 @@ def test_big_text_models(wikitext2_val_and_counter): @pytest.mark.serial @pytest.mark.remote_required -def test_transformer_models(): - models = ['transformer_en_de_512'] - pretrained_to_test = {'transformer_en_de_512': 'WMT2014'} - dropout_rates = [0.1, 0.0] +@pytest.mark.parametrize('dropout_rate', [0.1, 0.0]) +@pytest.mark.parametrize('model_dataset', [('transformer_en_de_512', 'WMT2014')]) +def test_transformer_models(dropout_rate, model_dataset): + model_name, pretrained_dataset = model_dataset src = mx.nd.ones((2, 10)) tgt = mx.nd.ones((2, 8)) valid_len = mx.nd.ones((2,)) - for model_name in models: - for rate in dropout_rates: - eprint('testing forward for %s, dropout rate %f' % (model_name, rate)) - pretrained_dataset = pretrained_to_test.get(model_name) - with warnings.catch_warnings(): # TODO https://github.com/dmlc/gluon-nlp/issues/978 - warnings.simplefilter("ignore") - model, _, _ = nlp.model.get_model(model_name, dataset_name=pretrained_dataset, - pretrained=pretrained_dataset is not None, - dropout=rate) - - print(model) - if not pretrained_dataset: - model.initialize() - output, state = model(src, tgt, src_valid_length=valid_len, tgt_valid_length=valid_len) - output.wait_to_read() - del model - mx.nd.waitall() + eprint('testing forward for %s, dropout rate %f' % (model_name, dropout_rate)) + with warnings.catch_warnings(): # TODO https://github.com/dmlc/gluon-nlp/issues/978 + warnings.simplefilter("ignore") + model, _, _ = nlp.model.get_model(model_name, dataset_name=pretrained_dataset, + pretrained=pretrained_dataset is not None, + dropout=dropout_rate) + + print(model) + if not pretrained_dataset: + model.initialize() + output, state = model(src, tgt, src_valid_length=valid_len, tgt_valid_length=valid_len) + output.wait_to_read() + del model + mx.nd.waitall() @pytest.mark.serial From 
46f5b64dc09a30ec77ccc79bec7049d744de8dd7 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Mon, 28 Oct 2019 18:35:51 +0000 Subject: [PATCH 2/3] Extend unit tests to include one-step decoding --- scripts/machine_translation/gnmt.py | 2 +- scripts/tests/test_encoder_decoder.py | 50 ++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/scripts/machine_translation/gnmt.py b/scripts/machine_translation/gnmt.py index 8c3240663e..cf3c82aa13 100644 --- a/scripts/machine_translation/gnmt.py +++ b/scripts/machine_translation/gnmt.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. """Encoder and decoder usded in sequence-to-sequence learning.""" -__all__ = ['GNMTEncoder', 'GNMTDecoder', 'get_gnmt_encoder_decoder'] +__all__ = ['GNMTEncoder', 'GNMTDecoder', 'GNMTOneStepDecoder', 'get_gnmt_encoder_decoder'] import mxnet as mx from mxnet.base import _as_list diff --git a/scripts/tests/test_encoder_decoder.py b/scripts/tests/test_encoder_decoder.py index 0888aa468a..695785a619 100644 --- a/scripts/tests/test_encoder_decoder.py +++ b/scripts/tests/test_encoder_decoder.py @@ -22,7 +22,7 @@ from mxnet.test_utils import assert_almost_equal from ..machine_translation.gnmt import * from gluonnlp.model.transformer import * -from gluonnlp.model.transformer import TransformerDecoder +from gluonnlp.model.transformer import TransformerDecoder, TransformerOneStepDecoder def test_gnmt_encoder(): @@ -65,6 +65,11 @@ def test_gnmt_encoder_decoder(): output_attention=output_attention, use_residual=use_residual, prefix='gnmt_decoder_') decoder.initialize(ctx=ctx) decoder.hybridize() + one_step_decoder = GNMTOneStepDecoder(cell_type="lstm", num_layers=3, hidden_size=num_hidden, + dropout=0.0, output_attention=output_attention, + use_residual=use_residual, prefix='gnmt_decoder_', + params=decoder.collect_params()) + one_step_decoder.hybridize() for batch_size in [4]: for src_seq_length, tgt_seq_length in [(5, 10), (10, 5)]: src_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, src_seq_length, 4), ctx=ctx) @@ -98,6 +103,25 @@ def test_gnmt_encoder_decoder(): else: assert(len(additional_outputs) == 0) + # Test one-step forwarding + output, new_states, additional_outputs = one_step_decoder( + tgt_seq_nd[:, 0, :], decoder_states) + assert(output.shape == (batch_size, num_hidden)) + if output_attention: + assert(len(additional_outputs) == 1) + attention_out = additional_outputs[0].asnumpy() + assert(attention_out.shape == (batch_size, 1, src_seq_length)) + for i in range(batch_size): + mem_v_len = int(src_valid_length_npy[i]) + if mem_v_len < src_seq_length - 1: + assert((attention_out[i, :, mem_v_len:] == 0).all()) + if mem_v_len > 0: + assert_almost_equal(attention_out[i, :, :].sum(axis=-1), + np.ones(attention_out.shape[1])) + else: + assert(len(additional_outputs) == 0) + + def test_transformer_encoder(): ctx = mx.current_context() for num_layers in range(1, 3): @@ -186,3 +210,27 @@ def test_transformer_encoder_decoder(output_attention, use_residual, batch_size, np.ones(attention_out.shape[1:3])) else: assert(len(additional_outputs) == 0) + + # Test one step forwarding + decoder = TransformerOneStepDecoder(num_layers=3, units=units, hidden_size=32, + num_heads=8, max_length=10, dropout=0.0, + output_attention=output_attention, + use_residual=use_residual, + prefix='transformer_decoder_', + params=decoder.collect_params()) + decoder.hybridize() + output, new_states, additional_outputs = decoder(tgt_seq_nd[:, 0, :], decoder_states) + 
assert(output.shape == (batch_size, units)) + if output_attention: + assert(len(additional_outputs) == 3) + attention_out = additional_outputs[0][1].asnumpy() + assert(attention_out.shape == (batch_size, 8, 1, src_seq_length)) + for i in range(batch_size): + mem_v_len = int(src_valid_length_npy[i]) + if mem_v_len < src_seq_length - 1: + assert((attention_out[i, :, :, mem_v_len:] == 0).all()) + if mem_v_len > 0: + assert_almost_equal(attention_out[i, :, :, :].sum(axis=-1), + np.ones(attention_out.shape[1:3])) + else: + assert(len(additional_outputs) == 0) From 74fcfce933055491da6c6239ec4fe9b9428b1247 Mon Sep 17 00:00:00 2001 From: Leonard Lausen Date: Mon, 28 Oct 2019 23:43:01 +0000 Subject: [PATCH 3/3] Improve doc --- src/gluonnlp/model/seq2seq_encoder_decoder.py | 17 +++++++++++------ src/gluonnlp/model/translation.py | 6 ++++-- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/gluonnlp/model/seq2seq_encoder_decoder.py b/src/gluonnlp/model/seq2seq_encoder_decoder.py index 28f8c8b9e3..fc2c7a1da9 100644 --- a/src/gluonnlp/model/seq2seq_encoder_decoder.py +++ b/src/gluonnlp/model/seq2seq_encoder_decoder.py @@ -130,6 +130,7 @@ def _nested_sequence_last(data, valid_length): class Seq2SeqEncoder(Block): r"""Base class of the encoders in sequence to sequence learning models. """ + def __call__(self, inputs, valid_length=None, states=None): #pylint: disable=arguments-differ """Encode the input sequence. @@ -155,13 +156,14 @@ def forward(self, inputs, valid_length=None, states=None): #pylint: disable=arg class Seq2SeqDecoder(Block): - """Base class of the decoders in sequence to sequence learning models. + """Base class of the decoders for sequence to sequence learning models. Given the inputs and the context computed by the encoder, generate the new - states. This is usually used in the training phase where we set the inputs - to be the target sequence. + states. Used in the training phase where we set the inputs to be the target + sequence. """ + def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): r"""Generates the initial decoder states based on the encoder outputs. @@ -177,9 +179,10 @@ def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): raise NotImplementedError def forward(self, step_input, states, valid_length=None): #pylint: disable=arguments-differ - """Given the inputs and the context computed by the encoder, - generate the new states. This is usually used in the training phase where we set the inputs - to be the target sequence. + """Given the inputs and the context computed by the encoder, generate the new states. + + Used in the training phase where we set the inputs to be the target + sequence. Parameters ---------- @@ -197,6 +200,7 @@ def forward(self, step_input, states, valid_length=None): #pylint: disable=argu The new states of the decoder additional_outputs : list Additional outputs of the decoder, e.g, the attention weights + """ raise NotImplementedError @@ -207,6 +211,7 @@ class Seq2SeqOneStepDecoder(Block): In the forward function, it generates the one-step-ahead decoding output. """ + def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): r"""Generates the initial decoder states based on the encoder outputs. 
diff --git a/src/gluonnlp/model/translation.py b/src/gluonnlp/model/translation.py index e56e9721cf..f0ac7754e9 100644 --- a/src/gluonnlp/model/translation.py +++ b/src/gluonnlp/model/translation.py @@ -37,9 +37,11 @@ class NMTModel(Block): encoder : Seq2SeqEncoder Encoder that encodes the input sentence. decoder : Seq2SeqDecoder - Decoder that generates the predictions based on the output of the encoder. + Decoder used during training phase. The decoder generates predictions + based on the output of the encoder. one_step_ahead_decoder : Seq2SeqOneStepDecoder - Decoder that generates the one-step ahead prediction based on the output of the encoder. + One-step ahead decoder used during inference phase. The decoder + generates predictions based on the output of the encoder. embed_size : int or None, default None Size of the embedding vectors. It is used to generate the source and target embeddings if src_embed and tgt_embed are None.
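
For quick reference, the following is a minimal usage sketch of the API after this patch series. It is an illustration rather than part of the patches: the toy vocabulary, the input tensors and the beam-search settings are assumptions, while the constructors and calls mirror the call sites updated above (gnmt.md, train_gnmt.py, translation.py) and the BeamSearchTranslator usage in the machine_translation scripts.

# --- Usage sketch (illustration only, not part of the patch series) ------------
# Assumes the scripts/machine_translation directory is on the Python path; the
# toy vocabulary and tensors below are placeholders.
import mxnet as mx
import gluonnlp as nlp
from gnmt import get_gnmt_encoder_decoder      # scripts/machine_translation/gnmt.py
from translation import BeamSearchTranslator   # scripts/machine_translation/translation.py

src_vocab = tgt_vocab = nlp.Vocab(nlp.data.count_tokens('hello world .'.split()))
src_seq = mx.nd.array([[src_vocab[t] for t in ['hello', 'world', '.']]])
src_valid_length = mx.nd.array([3])
tgt_seq = mx.nd.array([[tgt_vocab[t] for t in ['<bos>', 'hello', 'world', '.', '<eos>']]])
tgt_valid_length = mx.nd.array([5])

# get_gnmt_encoder_decoder now returns three blocks; the training decoder and the
# one-step-ahead decoder share their parameters.
encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder(
    hidden_size=128, dropout=0.0, num_layers=2, num_bi_layers=1)
model = nlp.model.translation.NMTModel(
    src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
    one_step_ahead_decoder=one_step_ahead_decoder, embed_size=128, prefix='gnmt_')
model.initialize(init=mx.init.Uniform(0.1), ctx=mx.cpu())
model.hybridize(static_alloc=True)

# Training path: the forward pass calls NMTModel.decode_seq, which drives the
# multi-step decoder with teacher forcing.
out, _ = model(src_seq, tgt_seq[:, :-1],
               src_valid_length=src_valid_length,
               tgt_valid_length=tgt_valid_length - 1)

# Inference path: beam search repeatedly calls NMTModel.decode_step, which now
# dispatches to the one-step-ahead decoder.
translator = BeamSearchTranslator(model=model, beam_size=4,
                                  scorer=nlp.model.BeamSearchScorer(),
                                  max_length=20)
samples, scores, sample_valid_length = translator.translate(
    src_seq=src_seq, src_valid_length=src_valid_length)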