This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Update get_model and NMTModel API
leezu committed Oct 19, 2019
1 parent 16f40cc commit 2b3c537
Showing 4 changed files with 59 additions and 49 deletions.
34 changes: 16 additions & 18 deletions scripts/machine_translation/train_transformer.py
@@ -33,25 +33,26 @@
# pylint:disable=redefined-outer-name,logging-format-interpolation

import argparse
import time
import random
import os
import logging
import math
import numpy as np
import os
import random
import time

import mxnet as mx
import numpy as np
from mxnet import gluon
import gluonnlp as nlp

from gluonnlp.loss import MaskedSoftmaxCELoss, LabelSmoothing
import dataprocessor
import gluonnlp as nlp
from bleu import _bpe_to_words, compute_bleu
from gluonnlp.loss import LabelSmoothing, MaskedSoftmaxCELoss
from gluonnlp.model.transformer import (ParallelTransformer,
get_transformer_encoder_decoder)
from gluonnlp.model.translation import NMTModel
from gluonnlp.model.transformer import get_transformer_encoder_decoder, ParallelTransformer
from gluonnlp.utils.parallel import Parallel
from translation import BeamSearchTranslator

from utils import logging_config
from bleu import _bpe_to_words, compute_bleu
import dataprocessor

np.random.seed(100)
random.seed(100)
@@ -174,15 +175,12 @@
tgt_max_len = args.tgt_max_len
else:
tgt_max_len = max_len[1]
encoder, decoder = get_transformer_encoder_decoder(units=args.num_units,
hidden_size=args.hidden_size,
dropout=args.dropout,
num_layers=args.num_layers,
num_heads=args.num_heads,
max_src_length=max(src_max_len, 500),
max_tgt_length=max(tgt_max_len, 500),
scaled=args.scaled)
encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder(
units=args.num_units, hidden_size=args.hidden_size, dropout=args.dropout,
num_layers=args.num_layers, num_heads=args.num_heads, max_src_length=max(src_max_len, 500),
max_tgt_length=max(tgt_max_len, 500), scaled=args.scaled)
model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
one_step_ahead_decoder=one_step_ahead_decoder,
share_embed=args.dataset not in ('TOY', 'IWSLT2015'), embed_size=args.num_units,
tie_weights=args.dataset not in ('TOY', 'IWSLT2015'), embed_initializer=None,
prefix='transformer_')
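
For context, a minimal self-contained sketch of the call pattern the training script now uses: `get_transformer_encoder_decoder` returns a third block, the one-step-ahead decoder, which must be passed to `NMTModel` explicitly. The hyperparameter values and the toy vocabulary below are illustrative assumptions, not the script's defaults.

```python
import gluonnlp as nlp
from gluonnlp.model.transformer import get_transformer_encoder_decoder
from gluonnlp.model.translation import NMTModel

# Toy vocabulary, only to make the sketch runnable on its own.
vocab = nlp.Vocab(nlp.data.count_tokens('a small toy corpus'.split()))

# The factory now returns three blocks; the third is the one-step-ahead
# decoder used for token-by-token decoding at inference time.
encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder(
    units=512, hidden_size=2048, dropout=0.1, num_layers=6, num_heads=8,
    max_src_length=500, max_tgt_length=500, scaled=True)

model = NMTModel(src_vocab=vocab, tgt_vocab=vocab, encoder=encoder,
                 decoder=decoder, one_step_ahead_decoder=one_step_ahead_decoder,
                 share_embed=True, embed_size=512, tie_weights=True,
                 embed_initializer=None, prefix='transformer_')
model.initialize()
```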
2 changes: 1 addition & 1 deletion src/gluonnlp/model/train/transformer.py
@@ -93,7 +93,7 @@ def __init__(self, attention_cell='multi_head', num_layers=2, units=128, hidden_
self.dropout_layer = nn.Dropout(rate=dropout)
self.layer_norm = nn.LayerNorm()
encoding = _position_encoding_init(max_length, units)
self.position_weight = self.params.get_constant('const', encoding)
self.position_weight = self.params.get_constant('const', encoding.astype(np.float32))
self.transformer_cells = nn.HybridSequential()
for i in range(num_layers):
self.transformer_cells.add(
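
The only change in this file casts the sinusoidal position table to float32 before registering it as a constant. A plausible reading (an assumption, not stated in the commit): NumPy builds the table in float64, while Gluon parameters default to float32, so the cast avoids a dtype mismatch at runtime. A standalone sketch of the standard sinusoidal table, showing the dtype in question:

```python
import numpy as np

def sinusoid_table(max_length, units):
    # Standard Transformer position encoding (a sketch, not the library code).
    pos = np.arange(max_length).reshape(-1, 1)
    div = np.power(10000.0, (2.0 / units) * np.arange(units).reshape(1, -1))
    enc = pos / div
    enc[:, 0::2] = np.sin(enc[:, 0::2])
    enc[:, 1::2] = np.cos(enc[:, 1::2])
    return enc

table = sinusoid_table(500, 512)
print(table.dtype)                     # float64 -- NumPy's default
print(table.astype(np.float32).dtype)  # float32 -- what the constant now stores
```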
59 changes: 35 additions & 24 deletions src/gluonnlp/model/transformer.py
@@ -20,23 +20,23 @@
__all__ = ['TransformerEncoder', 'PositionwiseFFN', 'TransformerEncoderCell',
'transformer_en_de_512']

import math
import os

import math
import numpy as np
import mxnet as mx
import numpy as np
from mxnet import cpu, gluon
from mxnet.gluon import nn
from mxnet.gluon.block import HybridBlock
from mxnet.gluon.model_zoo import model_store
from gluonnlp.utils.parallel import Parallelizable
from .seq2seq_encoder_decoder import Seq2SeqEncoder, _get_attention_cell
from .block import GELU
from .translation import NMTModel
from .utils import _load_vocab, _load_pretrained_params

from ..base import get_home_dir
from ..utils.parallel import Parallelizable
from .block import GELU
from .seq2seq_encoder_decoder import Seq2SeqEncoder, _get_attention_cell
from .train import transformer

from .translation import NMTModel
from .utils import _load_pretrained_params, _load_vocab

###############################################################################
# BASE ENCODER BLOCKS #
@@ -882,7 +882,7 @@ def forward(self, step_input, states):
states = states[1:]
else:
inputs = mx.nd.expand_dims(step_input, axis=1)
super().forward(inputs, states)
return super().forward(inputs, states)

def hybrid_forward(self, F, inputs, states, position_weight=None):
#pylint: disable=arguments-differ
@@ -909,9 +909,8 @@ def hybrid_forward(self, F, inputs, states, position_weight=None):
(batch_size, num_heads, length, mem_length)
"""
# One-step ahead decoder is implemented as extension case of train.TransformerDecoder
outputs, states, additional_outputs = super().hybrid_forward(inputs, states,
valid_length=None,
position_weight=position_weight)
outputs, states, additional_outputs = super().hybrid_forward(
F, inputs, states, valid_length=None, position_weight=position_weight)

# Append inputs to states: They are needed in the next one-step ahead decoding step
new_states = [inputs] + states
@@ -974,18 +973,30 @@ def get_transformer_encoder_decoder(num_layers=2,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix=prefix + 'enc_', params=params)
decoder = TransformerDecoder(num_layers=num_layers,
num_heads=num_heads,
max_length=max_tgt_length,
units=units,
hidden_size=hidden_size,
dropout=dropout,
scaled=scaled,
use_residual=use_residual,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix=prefix + 'dec_', params=params)
return encoder, decoder
decoder = transformer.TransformerDecoder(num_layers=num_layers,
num_heads=num_heads,
max_length=max_tgt_length,
units=units,
hidden_size=hidden_size,
dropout=dropout,
scaled=scaled,
use_residual=use_residual,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix=prefix + 'dec_', params=params)
one_step_ahead_decoder = TransformerDecoder(num_layers=num_layers,
num_heads=num_heads,
max_length=max_tgt_length,
units=units,
hidden_size=hidden_size,
dropout=dropout,
scaled=scaled,
use_residual=use_residual,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
prefix=prefix + 'dec_',
params=decoder.collect_params())
return encoder, decoder, one_step_ahead_decoder


def _get_transformer_model(model_cls, model_name, dataset_name, src_vocab, tgt_vocab,
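
The factory now builds the training decoder from `train.transformer.TransformerDecoder` and a separate one-step-ahead decoder from this module's `TransformerDecoder`, wiring the latter to the former via `params=decoder.collect_params()` so no weights are duplicated. A small sketch of how that sharing can be checked (assuming GluonNLP at this commit; the hyperparameters are arbitrary):

```python
from gluonnlp.model.transformer import get_transformer_encoder_decoder

encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder(
    num_layers=2, num_heads=4, units=64, hidden_size=128)

# Every parameter the one-step-ahead decoder collects should be the same
# Parameter object as in the training decoder, because it was constructed
# with the same prefix and params=decoder.collect_params().
train_ids = {id(p) for p in decoder.collect_params().values()}
step_ids = {id(p) for p in one_step_ahead_decoder.collect_params().values()}
assert step_ids <= train_ids
print('one-step-ahead decoder introduces no new parameters')
```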
13 changes: 7 additions & 6 deletions src/gluonnlp/model/translation.py
@@ -38,6 +38,8 @@ class NMTModel(Block):
Encoder that encodes the input sentence.
decoder : Seq2SeqDecoder
Decoder that generates the predictions based on the output of the encoder.
one_step_ahead_decoder : Seq2SeqOneStepDecoder
Decoder that generates the one-step ahead prediction based on the output of the encoder.
embed_size : int or None, default None
Size of the embedding vectors. It is used to generate the source and target embeddings
if src_embed and tgt_embed are None.
@@ -63,7 +65,7 @@ class NMTModel(Block):
params : ParameterDict or None
See document of `Block`.
"""
def __init__(self, src_vocab, tgt_vocab, encoder, decoder,
def __init__(self, src_vocab, tgt_vocab, encoder, decoder, one_step_ahead_decoder,
embed_size=None, embed_dropout=0.0, embed_initializer=mx.init.Uniform(0.1),
src_embed=None, tgt_embed=None, share_embed=False, tie_weights=False,
tgt_proj=None, prefix=None, params=None):
@@ -72,6 +74,7 @@ def __init__(self, src_vocab, tgt_vocab, encoder, decoder,
self.src_vocab = src_vocab
self.encoder = encoder
self.decoder = decoder
self.one_step_ahead_decoder = one_step_ahead_decoder
self._shared_embed = share_embed
if embed_dropout is None:
embed_dropout = 0.0
@@ -158,10 +161,8 @@ def decode_seq(self, inputs, states, valid_length=None):
additional_outputs : list
Additional outputs of the decoder, e.g, the attention weights
"""
outputs, states, additional_outputs =\
self.decoder.decode_seq(inputs=self.tgt_embed(inputs),
states=states,
valid_length=valid_length)
outputs, states, additional_outputs = self.decoder(self.tgt_embed(inputs), states,
valid_length)
outputs = self.tgt_proj(outputs)
return outputs, states, additional_outputs

@@ -183,7 +184,7 @@ def decode_step(self, step_input, states):
Additional outputs of the step, e.g, the attention weights
"""
step_output, states, step_additional_outputs =\
self.decoder(self.tgt_embed(step_input), states)
self.one_step_ahead_decoder(self.tgt_embed(step_input), states)
step_output = self.tgt_proj(step_output)
return step_output, states, step_additional_outputs
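
Taken together, `NMTModel` now routes the two decoding paths through different sub-blocks: `decode_seq` calls the full-sequence decoder directly (teacher-forced decoding during training), while `decode_step` calls the new `one_step_ahead_decoder` (token-by-token decoding during beam search). A simplified paraphrase of that routing, not the library code:

```python
class NMTModelRoutingSketch:
    """Illustrative stand-in for NMTModel's decoding methods after this change."""

    def __init__(self, decoder, one_step_ahead_decoder, tgt_embed, tgt_proj):
        self.decoder = decoder                                # full-sequence decoder (training)
        self.one_step_ahead_decoder = one_step_ahead_decoder  # incremental decoder (inference)
        self.tgt_embed = tgt_embed
        self.tgt_proj = tgt_proj

    def decode_seq(self, inputs, states, valid_length=None):
        # The full-sequence decoder is now called directly instead of via decode_seq().
        outputs, states, additional = self.decoder(
            self.tgt_embed(inputs), states, valid_length)
        return self.tgt_proj(outputs), states, additional

    def decode_step(self, step_input, states):
        # One token at a time, through the dedicated one-step-ahead decoder.
        output, states, additional = self.one_step_ahead_decoder(
            self.tgt_embed(step_input), states)
        return self.tgt_proj(output), states, additional
```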
