This repository has been archived by the owner on Jan 5, 2023. It is now read-only.

Merge pull request #4 from lium-lst/wip
Merge v1.1 changes
ozancaglayan authored Jan 25, 2018
2 parents db4beb7 + 1b590e9 commit 7891452
Showing 37 changed files with 1,979 additions and 874 deletions.
71 changes: 71 additions & 0 deletions README.md
@@ -7,6 +7,8 @@ a sequence-to-sequence framework which was originally a fork of
[dl4mt-tutorial](https://github.com/nyu-dl/dl4mt-tutorial).

The core parts of `nmtpytorch` depend on `numpy`, `torch` and `tqdm`.
For multimodal architectures, you also need to install `torchvision`, which
is used to integrate pre-trained CNN models.

`nmtpytorch` is developed and tested on Python 3.6 and will not support
Python 2.x whatsoever.
@@ -52,6 +54,75 @@ nmtpy train -C <config file> train.<opt>:<val> model.<opt>:<val> ...

## Release Notes

### v1.1 (25/01/2018)

- New experimental `Multi30kDataset` and `ImageFolderDataset` classes
- `torchvision` dependency added for CNN support
- `nmtpy-coco-metrics` now computes a single METEOR score (without `norm=True`)
- The main loop mechanism is completely refactored, with **backward-incompatible**
configuration option changes for the `[train]` section:
- `patience_delta` option is removed
- Added `eval_batch_size` to define batch size for GPU beam-search during training
- `eval_freq` now defaults to `3000`, i.e. evaluation every `3000` minibatches
- `eval_metrics` now defaults to `loss`. As before, you can provide a list
of metrics like `bleu,meteor,loss` to compute all of them and early-stop
based on the first one
- Added `eval_zero (default: False)` which evaluates the model once on the
dev set right before training starts. Useful as a sanity check when
fine-tuning a model initialized with pre-trained weights
- Removed `save_best_n`: we no longer save the best `N` models on the dev set
w.r.t. the early-stopping metric
- Added `save_best_metrics (default: True)` which saves the best models on
the dev set w.r.t. each metric provided in `eval_metrics`. This partly
compensates for the removal of `save_best_n`
- `checkpoint_freq` now defaults to `5000`, i.e. a checkpoint is saved every
`5000` minibatches
- Added `n_checkpoints (default: 5)` to define the number of last
checkpoints that will be kept if `checkpoint_freq > 0`, i.e. when checkpointing is enabled
- Added `ExtendedInterpolation` support to configuration files (see the sketch after this list):
- You can now define intermediate variables in `.conf` files to avoid
typing the same paths again and again. A variable can be referenced from
within its **section** using the `tensorboard_dir: ${save_path}/tb` notation.
Cross-section references are also possible: `${data:root}` will be replaced
by the value of the `root` variable defined in the `[data]` section.
- Added `-p/--pretrained` to `nmtpy train` to initialize the weights of
the model from another `.ckpt` checkpoint.
- Improved input/output handling for `nmtpy translate`:
- `-s` accepts a comma-separated list of test sets **defined** in the configuration
file of the experiment, to translate them all at once. Example: `-s val,newstest2016,newstest2017`
- The mutually exclusive counterpart of `-s` is `-S` which receives a
single input file of source sentences.
- For both cases, an output prefix **should now be** provided with `-o`.
In the case of multiple test sets, the name of the test set and the beam size
will be appended to the output prefix. If you just provide a single file with `-S`,
the final output name will only reflect the beam size.
- Two new arguments for `nmtpy-build-vocab`:
- `-f`: Also stores frequency counts inside the final `json` vocabulary
- `-x`: Does not add special markers `<eos>,<bos>,<unk>,<pad>` into the vocabulary
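
As a quick illustration of the new `ExtendedInterpolation` support, the following minimal sketch (with hypothetical paths and option names) shows how Python's stock `configparser` resolves intermediate variables and cross-section references:

```python
from configparser import ConfigParser, ExtendedInterpolation

# Hypothetical .conf snippet: ${save_path} is reused inside [train],
# while ${data:root} pulls the `root` value from the [data] section.
conf = """
[train]
save_path: /tmp/experiments
tensorboard_dir: ${save_path}/tb

[data]
root: /tmp/corpora
train_src: ${root}/train.en
val_src: ${data:root}/val.en
"""

parser = ConfigParser(interpolation=ExtendedInterpolation())
parser.read_string(conf)
print(parser['train']['tensorboard_dir'])  # /tmp/experiments/tb
print(parser['data']['val_src'])           # /tmp/corpora/val.en
```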

#### Layers/Architectures

- Added a `Fusion()` layer to `concat`, `sum` or `mul` an arbitrary number of inputs (a minimal sketch follows this list)
- Added *experimental* `ImageEncoder()` layer to seamlessly plug a VGG or ResNet
CNN using `torchvision` pretrained models
- `Attention` layer arguments are improved. You can now select the bottleneck
dimensionality for MLP attention with `att_bottleneck` (illustrated after this list). The `dot`
attention is **still not tested** and probably broken.
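
To make the `Fusion()` idea concrete, here is a minimal, hypothetical sketch of such a layer. It only illustrates the `concat`/`sum`/`mul` behaviour over equally-sized inputs and is not the actual nmtpytorch implementation:

```python
import torch
from torch import nn

class Fusion(nn.Module):
    """Illustrative fusion of an arbitrary number of input tensors."""
    def __init__(self, fusion_type='concat'):
        super().__init__()
        assert fusion_type in ('concat', 'sum', 'mul')
        self.fusion_type = fusion_type

    def forward(self, *inputs):
        if self.fusion_type == 'sum':
            return torch.stack(inputs).sum(dim=0)
        elif self.fusion_type == 'mul':
            out = inputs[0]
            for x in inputs[1:]:
                out = out * x
            return out
        return torch.cat(inputs, dim=-1)  # 'concat'
```

Similarly, a rough sketch of MLP (Bahdanau-style) attention shows what `att_bottleneck` controls: the size of the inner projection that both the encoder context and the decoder state are mapped to before scoring. This is an illustration under assumed tensor shapes, not the library's own layer:

```python
import torch
from torch import nn
import torch.nn.functional as F

class MLPAttention(nn.Module):
    """Illustrative MLP attention; `att_bottleneck` is the inner dimension."""
    def __init__(self, ctx_dim, hid_dim, att_bottleneck=256):
        super().__init__()
        self.ctx2att = nn.Linear(ctx_dim, att_bottleneck, bias=False)
        self.hid2att = nn.Linear(hid_dim, att_bottleneck, bias=False)
        self.score = nn.Linear(att_bottleneck, 1, bias=False)

    def forward(self, ctx, hid):
        # ctx: (n_src, batch, ctx_dim), hid: (batch, hid_dim)
        energy = self.score(torch.tanh(self.ctx2att(ctx) + self.hid2att(hid)))
        alpha = F.softmax(energy.squeeze(-1), dim=0)           # (n_src, batch)
        return alpha, (alpha.unsqueeze(-1) * ctx).sum(dim=0)   # weighted context
```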

New layers/architectures:

- Added **AttentiveMNMT** which implements modality-specific multimodal attention
from the paper [Multimodal Attention for Neural Machine Translation](https://arxiv.org/abs/1609.03976)
- Added **ShowAttendAndTell** [model](http://www.jmlr.org/proceedings/papers/v37/xuc15.pdf)

Changes in **NMT**:

- `dec_init` defaults to `mean_ctx`, i.e. the decoder will be initialized
with the mean context computed from the source encoder (see the sketch below)
- `enc_lnorm`, which was just a placeholder, is now removed since we do not
provide layer normalization for now
- Beam search now runs entirely on the GPU
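
For the `mean_ctx` initialization, the following sketch (with made-up dimensions) shows the gist: the encoder states are averaged over time and passed through a learned non-linear projection to form the decoder's initial hidden state:

```python
import torch
from torch import nn

# Hypothetical shapes: 12 source timesteps, batch of 32, 512-dim encoder states.
ctx = torch.randn(12, 32, 512)        # encoder states (n_src, batch, ctx_dim)
ff = nn.Linear(512, 256)              # ctx_dim -> decoder hidden size
h0 = torch.tanh(ff(ctx.mean(dim=0)))  # (batch, dec_dim) initial decoder state
```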

### Initial Release v1.0 (18/12/2017)

The initial release aims to be (as much as) feature compatible with respect
23 changes: 14 additions & 9 deletions bin/nmtpy
@@ -33,6 +33,8 @@ if __name__ == '__main__':
help="Experiment configuration file")
parser_train.add_argument('-s', '--suffix', type=str, default="",
help="Optional experiment suffix.")
parser_train.add_argument('-p', '--pretrained', type=str, default=None,
help=".ckpt file for model initialization")
parser_train.add_argument('overrides', nargs="*", default=[],
help="(section).key:value overrides for config")

@@ -62,22 +64,20 @@ if __name__ == '__main__':

group = parser_trans.add_mutually_exclusive_group(required=True)
# You can translate a set of splits defined in the .conf file
# In this case, outputs will be automatically <model_file>.<split>.<beam>
group.add_argument('-s', '--splits', type=str,
help='Comma separated splits from config file')
# Or you can provide another input file with -S
group.add_argument('-S', '--source', type=str,
help='A text file to translate (pass - for stdin)')
parser_trans.add_argument('-o', '--output', type=str,
help='Explicit output file if using -S <input>')
help='A text file to translate')
parser_trans.add_argument('-o', '--output', type=str, required=True,
help='Output filename prefix')

# Parse command-line arguments first
args = parser.parse_args()
if args.cmd is None:
parser.print_help()
sys.exit(1)

opts, history, weights = {}, {}, None

# Mode selection
if args.cmd == 'train':
# Parse configuration file and merge with the rest
@@ -86,13 +86,19 @@
# Setup experiment folders
setup_experiment(opts, args.suffix)

weights, history = None, None
if args.pretrained:
weights, _, _ = load_pt_file(args.pretrained)
opts.train['pretrained'] = args.pretrained

elif args.cmd == 'resume':
# Load everything to CPU without messing with storage tags
weights, history, opts = load_pt_file(args.checkpoint)
opts = Options.from_dict(opts)

# Detect device_id
device_id = opts.train['device_id'] if opts else args.device_id
device_id = (opts.train['device_id'] if args.cmd != 'translate' else
args.device_id)

# Reserve GPUs
gpu_devs = GPUManager()(device_id, strict=True)
@@ -102,7 +108,6 @@
# translate entry point
#######################
if args.cmd == 'translate':
assert len(args.models) == 1, "Ensembling not implemented yet."
Translator(**args.__dict__)()
sys.exit(0)

@@ -123,5 +128,5 @@
log.info("PyTorch {} (CUDA: {}) on '{}' (GPUs: {})".format(
torch.__version__, torch.version.cuda, platform.node(), gpu_devs))
loop = MainLoop(model, log, opts.train, history, weights, mode=args.cmd)
loop.run()
loop()
sys.exit(0)
37 changes: 25 additions & 12 deletions bin/nmtpy-build-vocab
@@ -13,33 +13,38 @@ from nmtpytorch.vocabulary import Vocabulary
from nmtpytorch.utils.misc import get_language, pbar


def freqs_to_dict(token_freqs, min_freq=0, max_items=0):
def freqs_to_dict(token_freqs, min_freq=0, max_items=0,
store_freqs=False, exclude_markers=False):
# Get list of tokens
tokens = list(token_freqs.keys())

# Collect their frequencies in a numpy array
freqs = np.array(list(token_freqs.values()))

tokendict = OrderedDict()
for key, value in Vocabulary.TOKENS.items():
tokendict[key] = value

offset = len(tokendict)
if not exclude_markers:
for key, value in Vocabulary.TOKENS.items():
tokendict[key] = value

# Sort in descending order of frequency
sorted_idx = np.argsort(freqs)
if min_freq > 0:
sorted_tokens = [tokens[ii] for ii in sorted_idx[::-1]
sorted_tokens = [(tokens[ii], freqs[ii]) for ii in sorted_idx[::-1]
if freqs[ii] >= min_freq]
else:
sorted_tokens = [tokens[ii] for ii in sorted_idx[::-1]]
sorted_tokens = [(tokens[ii], freqs[ii]) for ii in sorted_idx[::-1]]

if max_items > 0:
sorted_tokens = sorted_tokens[:max_items]

# Start inserting from index 2
for ii, ww in enumerate(sorted_tokens):
tokendict[ww] = ii + offset
# Start inserting from index offset
offset = len(tokendict)
if store_freqs:
for ii, (token, freq) in enumerate(sorted_tokens):
tokendict[token] = (ii + offset, int(freq))
else:
for ii, (token, freq) in enumerate(sorted_tokens):
tokendict[token] = ii + offset

return tokendict

@@ -84,10 +89,14 @@ if __name__ == '__main__':
help='Output directory')
parser.add_argument('-s', '--single', type=str, default=None,
help='Name of the combined vocabulary file')
parser.add_argument('-f', '--store-freqs', action='store_true',
help='Store occurrence counts inside .json.')
parser.add_argument('-m', '--min-freq', type=int, default=0,
help='Filter out tokens occurring < m times')
parser.add_argument('-M', '--max-items', type=int, default=0,
help='Keep the final vocabulary size less than this')
parser.add_argument('-x', '--exclude-symbols', action='store_true',
help='Do not add special <eos>, <bos>, <pad>, <unk>')
parser.add_argument('files', type=str, nargs='+',
help='Sentence files')
args = parser.parse_args()
@@ -110,7 +119,9 @@
# Get frequencies
freqs = get_freqs(filename)
# Build dictionary from frequencies
tokendict = freqs_to_dict(freqs, args.min_freq, args.max_items)
tokendict = freqs_to_dict(freqs, args.min_freq, args.max_items,
args.store_freqs,
args.exclude_symbols)

if args.min_freq > 0:
vocab_fname += "-min%d" % args.min_freq
@@ -120,5 +131,7 @@
write_dict(vocab_fname, tokendict)

if args.single:
tokendict = freqs_to_dict(all_freqs, args.min_freq, args.max_items)
tokendict = freqs_to_dict(all_freqs, args.min_freq, args.max_items,
args.store_freqs,
args.exclude_symbols)
write_dict(args.single, tokendict)
1 change: 0 additions & 1 deletion bin/nmtpy-coco-metrics
@@ -49,7 +49,6 @@ if __name__ == '__main__':
scorers = [
(Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
(Meteor(args.language), ["METEOR"]),
(Meteor(args.language, norm=True), ["METEOR (norm)"]),
(Cider(), ["CIDEr"]),
(Rouge(), ["ROUGE_L"]),
]
2 changes: 1 addition & 1 deletion nmtpytorch/__init__.py
@@ -1 +1 @@
__version__ = '1.0.0'
__version__ = '1.1.0'
52 changes: 27 additions & 25 deletions nmtpytorch/config.py
@@ -4,37 +4,39 @@

from collections import defaultdict

from configparser import ConfigParser
from configparser import ConfigParser, ExtendedInterpolation
from ast import literal_eval

# Default data types
INT = 'int64'
FLOAT = 'float32'


TRAIN_DEFAULTS = {
'seed': 1234, # RNG seed
'gclip': 5., # Clip gradients above clip_c
'l2_reg': 0., # L2 penalty factor
'patience': 10, # Early stopping patience
'optimizer': 'adam', # adadelta, sgd, rmsprop, adam
'lr': 0, # 0: Use default lr if not precised further
'device_id': 'auto_1', # auto_N for automatic N gpus
# 0,1,2 for manual N gpus
# 0 for 0th (single) GPU
'disp_freq': 30, # Training display frequency (/batch)
'batch_size': 32, # Training batch size
'max_epochs': 100, # Max number of epochs to train
'eval_beam': 6, # Validation beam_size
'eval_freq': 0, # 0: End of epochs
'eval_start': 1, # Epoch which validation will start
'save_best_n': 4, # Store a set of 4 best validation models
'eval_metrics': 'bleu', # comma sep. metrics, 1st -> earlystopping
'eval_filters': '', # comma sep. filters to apply to refs and hyps
'checkpoint_freq': 0, # Checkpoint frequency for resuming
'max_iterations': int(1e6), # Max number of updates to train
'patience_delta': 0., # Abs. difference that counts for metrics
'tensorboard_dir': '', # Enable TB and give global log folder
'device_id': 'auto_1', # auto_N for automatic N gpus
# 0,1,2 for manual N gpus
# 0 for 0th (single) GPU
'seed': 1234, # RNG seed. 0 -> Don't init
'gclip': 5., # Clip gradients above clip_c
'l2_reg': 0., # L2 penalty factor
'patience': 20, # Early stopping patience
'optimizer': 'adam', # adadelta, sgd, rmsprop, adam
'lr': 0.0004, # 0 -> Use default lr from Pytorch
'disp_freq': 30, # Training display frequency (/batch)
'batch_size': 32, # Training batch size
'max_epochs': 100, # Max number of epochs to train
'max_iterations': int(1e6), # Max number of updates to train
'eval_metrics': 'loss', # comma sep. metrics, 1st -> earlystopping
'eval_filters': '', # comma sep. filters to apply to refs and hyps
'eval_beam': 6, # Validation beam size
'eval_batch_size': 16, # batch_size for GPU beam-search
'eval_freq': 3000, # 0 means 'End of epochs'
'eval_start': 1, # Epoch which validation will start
'eval_zero': False, # Evaluate once before starting training
# Useful when using pretrained init.
'save_best_metrics': True, # Save best models for each eval_metric
'checkpoint_freq': 5000, # Periodic checkpoint frequency
'n_checkpoints': 5, # Number of checkpoints to keep
'tensorboard_dir': '', # Enable TB and give global log folder
}


@@ -80,7 +82,7 @@ def from_dict(cls, dict_):
return obj

def __init__(self, filename, overrides=None):
self.__parser = ConfigParser()
self.__parser = ConfigParser(interpolation=ExtendedInterpolation())
self.filename = filename
self.overrides = defaultdict(dict)
self.sections = []
2 changes: 2 additions & 0 deletions nmtpytorch/datasets/__init__.py
@@ -1 +1,3 @@
from .multi30k import Multi30kDataset
from .bitext import BitextDataset
from .imagefolder import ImageFolderDataset
11 changes: 9 additions & 2 deletions nmtpytorch/datasets/bitext.py
@@ -44,7 +44,10 @@ def __init__(self, split, data_dict, vocabs, topology,
#######################
path = self.data_dict[self.txt_split][self.sl]
fnames = sorted(path.parent.glob(path.name))
assert len(fnames) == 1, "Multiple source files not supported."
if len(fnames) == 0:
raise RuntimeError('{} does not exist.'.format(path))
elif len(fnames) > 1:
raise RuntimeError("Multiple source files not supported.")

self.data[self.sl], self.lens[self.sl] = \
read_sentences(fnames[0], self.src_vocab)
@@ -57,7 +60,11 @@ def __init__(self, split, data_dict, vocabs, topology,
if self.tl in self.data_dict[self.txt_split]:
path = self.data_dict[self.txt_split][self.tl]
fnames = sorted(path.parent.glob(path.name))
assert len(fnames) == 1, "Multiple target files not supported."
if len(fnames) == 0:
raise RuntimeError('{} does not exist.'.format(path))
elif len(fnames) > 1:
raise RuntimeError("Multiple source files not supported.")

self.data[self.tl], self.lens[self.tl] = \
read_sentences(fnames[0], self.trg_vocab, bos=True)
