Merge pull request PaddlePaddle#52 from guoshengCS/add-hapi-seq2seq-new
Add seq2seq example
Showing 24 changed files with 1,617 additions and 53 deletions.
README.md
@@ -0,0 +1,165 @@
Running the example model in this directory requires PaddlePaddle Fluid 1.7. If your installed PaddlePaddle version is lower than this requirement, please follow the instructions in the [installation guide](https://www.paddlepaddle.org.cn/#quick-start) to update your PaddlePaddle installation.

# Sequence to Sequence (Seq2Seq)

The directory structure of this example is outlined below:

```
.
├── README.md          # documentation, this file
├── args.py            # configuration of training, prediction, and model arguments
├── reader.py          # data reading
├── download.py        # data downloading
├── train.py           # main training script
├── predict.py         # main prediction script
├── seq2seq_attn.py    # translation model with attention
└── seq2seq_base.py    # translation model without attention
```

## Introduction

Sequence to Sequence (Seq2Seq) uses an encoder-decoder architecture: the encoder encodes the source sequence into a vector, and the decoder then decodes that vector into the target sequence. Seq2Seq is widely used in tasks such as machine translation, dialogue systems, automatic document summarization, and image captioning.

This directory contains a classic Seq2Seq application, machine translation, with two implementations: a base model (without attention) and a translation model with an attention mechanism. A Seq2Seq translation model imitates how humans translate: first parse the source sentence and understand its meaning, then write the target-language sentence according to that meaning. For more on the underlying principles and mathematical formulation of machine translation, we recommend the [machine translation case study](https://www.paddlepaddle.org.cn/documentation/docs/zh/user_guides/nlp_case/machine_translation/README.cn.html) on the PaddlePaddle website.

## Model Overview

In this model, the encoder is a multi-layer LSTM-based RNN encoder, and the decoder is an RNN decoder with an attention mechanism; a decoder without attention is also provided for comparison. At prediction time, the beam search algorithm is used to generate the target sentence.
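At each decoding step, attention re-weights the encoder outputs by how relevant they are to the current decoder state. As a rough, framework-agnostic illustration (not the code in `seq2seq_attn.py`), the NumPy sketch below computes dot-product attention weights and the resulting context vector for a single decoder step; the shapes and values are illustrative assumptions only.

```python
import numpy as np

def attention_step(decoder_state, encoder_outputs):
    """Toy dot-product attention for one decoder step.

    decoder_state:   shape (hidden_size,), current decoder hidden state
    encoder_outputs: shape (src_len, hidden_size), one vector per source token
    """
    scores = encoder_outputs @ decoder_state        # alignment score per source position
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                        # softmax over source positions
    context = weights @ encoder_outputs             # weighted sum of encoder outputs
    return context, weights

# illustrative call with random values (hidden_size=512 as in the commands below)
context, weights = attention_step(np.random.randn(512), np.random.randn(7, 512))
```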
## Data

This example uses the English-to-Vietnamese portion of the [IWSLT'15 English-Vietnamese data](https://nlp.stanford.edu/projects/nmt/) as the training corpus, tst2012 as the development set, and tst2013 as the test set.

### Downloading the data

```
python download.py
```
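Judging from the file list in `download.py`, the download should produce the following layout under `data/`:

```
data
└── en-vi
    ├── train.en
    ├── train.vi
    ├── tst2012.en
    ├── tst2012.vi
    ├── tst2013.en
    ├── tst2013.vi
    ├── vocab.en
    └── vocab.vi
```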
## Training

Run the following command to train the Seq2Seq machine translation model with attention:

```sh
export CUDA_VISIBLE_DEVICES=0

python train.py \
    --src_lang en --tar_lang vi \
    --attention True \
    --num_layers 2 \
    --hidden_size 512 \
    --src_vocab_size 17191 \
    --tar_vocab_size 7709 \
    --batch_size 128 \
    --dropout 0.2 \
    --init_scale 0.1 \
    --max_grad_norm 5.0 \
    --train_data_prefix data/en-vi/train \
    --eval_data_prefix data/en-vi/tst2012 \
    --test_data_prefix data/en-vi/tst2013 \
    --vocab_prefix data/en-vi/vocab \
    --use_gpu True \
    --model_path ./attention_models
```

Set the `attention` argument to False to train the Seq2Seq model without attention; see `args.py` for detailed descriptions of all arguments. The training program saves a model checkpoint at the end of every epoch.

Training runs in dynamic graph (eager) mode by default; set the `eager_run` argument to False to train in static graph mode, as follows:

```sh
export CUDA_VISIBLE_DEVICES=0

python train.py \
    --src_lang en --tar_lang vi \
    --attention True \
    --num_layers 2 \
    --hidden_size 512 \
    --src_vocab_size 17191 \
    --tar_vocab_size 7709 \
    --batch_size 128 \
    --dropout 0.2 \
    --init_scale 0.1 \
    --max_grad_norm 5.0 \
    --train_data_prefix data/en-vi/train \
    --eval_data_prefix data/en-vi/tst2012 \
    --test_data_prefix data/en-vi/tst2013 \
    --vocab_prefix data/en-vi/vocab \
    --use_gpu True \
    --model_path ./attention_models \
    --eager_run False
```
## Prediction

After training, the saved model (specified by `--reload_model`) can be used to decode the test set (specified by `--infer_file`) with beam search:

```sh
export CUDA_VISIBLE_DEVICES=0

python infer.py \
    --attention True \
    --src_lang en --tar_lang vi \
    --num_layers 2 \
    --hidden_size 512 \
    --src_vocab_size 17191 \
    --tar_vocab_size 7709 \
    --batch_size 128 \
    --dropout 0.2 \
    --init_scale 0.1 \
    --max_grad_norm 5.0 \
    --vocab_prefix data/en-vi/vocab \
    --infer_file data/en-vi/tst2013.en \
    --reload_model attention_models/10 \
    --infer_output_file infer_output.txt \
    --beam_size 10 \
    --use_gpu True
```
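Beam search keeps the `beam_size` highest-scoring partial translations at every step rather than greedily committing to the single best token. The sketch below is a minimal, framework-agnostic illustration of the idea, not the implementation used by this example; `step_log_probs` is a hypothetical callback returning next-token log-probabilities for a given prefix.

```python
def beam_search(step_log_probs, bos_id, eos_id, beam_size=10, max_len=50):
    """Minimal beam search over token ids.

    step_log_probs(prefix) -> dict {token_id: log_prob} of next-token scores.
    """
    beams = [([bos_id], 0.0)]          # (token sequence, accumulated log-prob)
    finished = []
    for _ in range(max_len):
        candidates = []
        for seq, score in beams:
            if seq[-1] == eos_id:      # hypothesis already ended
                finished.append((seq, score))
                continue
            for tok, lp in step_log_probs(seq).items():
                candidates.append((seq + [tok], score + lp))
        if not candidates:             # every hypothesis has emitted <eos>
            beams = []
            break
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_size]  # keep the beam_size best prefixes
    finished.extend(beams)              # unfinished hypotheses at max_len
    return max(finished, key=lambda c: c[1])[0]
```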
See `args.py` for detailed descriptions of all arguments, and note that the model hyperparameters used for prediction must match those used for training. As with training, prediction can also be run in static graph mode:

```sh
export CUDA_VISIBLE_DEVICES=0

python infer.py \
    --attention True \
    --src_lang en --tar_lang vi \
    --num_layers 2 \
    --hidden_size 512 \
    --src_vocab_size 17191 \
    --tar_vocab_size 7709 \
    --batch_size 128 \
    --dropout 0.2 \
    --init_scale 0.1 \
    --max_grad_norm 5.0 \
    --vocab_prefix data/en-vi/vocab \
    --infer_file data/en-vi/tst2013.en \
    --reload_model attention_models/10 \
    --infer_output_file infer_output.txt \
    --beam_size 10 \
    --use_gpu True \
    --eager_run False
```

## Evaluation

Use the [*multi-bleu.perl*](https://github.com/moses-smt/mosesdecoder.git) tool to evaluate the translation quality of the model's predictions:

```sh
mosesdecoder/scripts/generic/multi-bleu.perl tst2013.vi < infer_output.txt
```
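If Moses is not available, a corpus-level BLEU can also be computed in Python. This is an optional alternative, assuming the `sacrebleu` package is installed; scores may differ slightly from multi-bleu.perl because of tokenization differences.

```python
import sacrebleu

# paths follow the commands above
with open("infer_output.txt", encoding="utf-8") as f:
    hypotheses = [line.strip() for line in f]
with open("data/en-vi/tst2013.vi", encoding="utf-8") as f:
    references = [line.strip() for line in f]

bleu = sacrebleu.corpus_bleu(hypotheses, [references])
print(bleu.score)
```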
Each model was trained 10 times; for every run, the checkpoint saved at the 10th epoch was used for prediction with beam_size=10. The results are shown below (the 10 runs are sorted in ascending order for readability):

```
> no attention
tst2012 BLEU:
[10.75 10.85 10.9 10.94 10.97 11.01 11.01 11.04 11.13 11.4]
tst2013 BLEU:
[10.71 10.71 10.74 10.76 10.91 10.94 11.02 11.16 11.21 11.44]
> with attention
tst2012 BLEU:
[21.14 22.34 22.54 22.65 22.71 22.71 23.08 23.15 23.3 23.4]
tst2013 BLEU:
[23.41 24.79 25.11 25.12 25.19 25.24 25.39 25.61 25.61 25.63]
```
args.py
@@ -0,0 +1,140 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import distutils.util


def parse_args():
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--train_data_prefix", type=str, help="file prefix for train data")
    parser.add_argument(
        "--eval_data_prefix", type=str, help="file prefix for eval data")
    parser.add_argument(
        "--test_data_prefix", type=str, help="file prefix for test data")
    parser.add_argument(
        "--vocab_prefix", type=str, help="file prefix for vocab")
    parser.add_argument("--src_lang", type=str, help="source language suffix")
    parser.add_argument("--tar_lang", type=str, help="target language suffix")

    parser.add_argument(
        "--attention",
        type=eval,
        default=False,
        help="Whether to use the attention model")

    parser.add_argument(
        "--optimizer",
        type=str,
        default='adam',
        help="optimizer to use, only support [sgd|adam]")

    parser.add_argument(
        "--learning_rate",
        type=float,
        default=0.001,
        help="learning rate for optimizer")

    parser.add_argument(
        "--num_layers",
        type=int,
        default=1,
        help="number of layers of encoder and decoder")
    parser.add_argument(
        "--hidden_size",
        type=int,
        default=100,
        help="hidden size of encoder and decoder")
    parser.add_argument("--src_vocab_size", type=int, help="source vocab size")
    parser.add_argument("--tar_vocab_size", type=int, help="target vocab size")

    parser.add_argument(
        "--batch_size", type=int, help="batch size of each step")

    parser.add_argument(
        "--max_epoch", type=int, default=12, help="max epoch for the training")

    parser.add_argument(
        "--max_len",
        type=int,
        default=50,
        help="max length for source and target sentence")
    parser.add_argument(
        "--dropout", type=float, default=0.0, help="drop probability")
    parser.add_argument(
        "--init_scale",
        type=float,
        default=0.0,
        help="init scale for parameter")
    parser.add_argument(
        "--max_grad_norm",
        type=float,
        default=5.0,
        help="max grad norm for global norm clip")

    parser.add_argument(
        "--log_freq",
        type=int,
        default=100,
        help="The frequency to print training logs")

    parser.add_argument(
        "--model_path",
        type=str,
        default='model',
        help="model path for model to save")

    parser.add_argument(
        "--reload_model", type=str, help="reload model to inference")

    parser.add_argument(
        "--infer_file", type=str, help="file name for inference")
    parser.add_argument(
        "--infer_output_file",
        type=str,
        default='infer_output',
        help="file name for inference output")
    parser.add_argument(
        "--beam_size", type=int, default=10, help="beam size for inference")

    parser.add_argument(
        '--use_gpu',
        type=eval,
        default=False,
        help='Whether to use gpu [True|False]')

    parser.add_argument(
        '--eager_run', type=eval, default=False, help='Whether to use dygraph')

    parser.add_argument(
        "--enable_ce",
        action='store_true',
        help="The flag indicating whether to run the task "
        "for continuous evaluation.")

    parser.add_argument(
        "--profile", action='store_true', help="Whether to enable the profiler.")
    # NOTE: profiler args, used for benchmark
    parser.add_argument(
        "--profiler_path",
        type=str,
        default='./seq2seq.profile',
        help="the profiler output file path. (used for benchmark)")
    args = parser.parse_args()
    return args
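Note that `--attention`, `--use_gpu`, and `--eager_run` use `type=eval`, so the literal strings `True`/`False` passed on the command line are evaluated into Python booleans. A small hypothetical check of the parsed values (not part of this diff) could look like:

```python
# hypothetical demonstration of how the flags parse; not part of the example code
import sys
from args import parse_args

sys.argv = ["train.py", "--attention", "True", "--use_gpu", "False",
            "--src_lang", "en", "--tar_lang", "vi", "--batch_size", "128"]
args = parse_args()
assert args.attention is True and args.use_gpu is False   # type=eval gives real booleans
assert isinstance(args.batch_size, int)                    # type=int converts the string
print(args.attention, args.use_gpu, args.batch_size)
```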
download.py
@@ -0,0 +1,54 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''
Script for downloading training data.
'''
import os
import urllib
import sys

if sys.version_info >= (3, 0):
    import urllib.request
import zipfile

URLLIB = urllib
if sys.version_info >= (3, 0):
    URLLIB = urllib.request

remote_path = 'https://nlp.stanford.edu/projects/nmt/data/iwslt15.en-vi'
base_path = 'data'
tar_path = os.path.join(base_path, 'en-vi')
filenames = [
    'train.en', 'train.vi', 'tst2012.en', 'tst2012.vi', 'tst2013.en',
    'tst2013.vi', 'vocab.en', 'vocab.vi'
]


def main(arguments):
    print("Downloading data......")

    if not os.path.exists(tar_path):
        if not os.path.exists(base_path):
            os.mkdir(base_path)
        os.mkdir(tar_path)

    # fetch each data file from the Stanford NMT mirror into data/en-vi
    for filename in filenames:
        url = remote_path + '/' + filename
        tar_file = os.path.join(tar_path, filename)
        URLLIB.urlretrieve(url, tar_file)
    print("Downloaded successfully......")


if __name__ == '__main__':
    sys.exit(main(sys.argv[1:]))
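The downloaded `vocab.en`/`vocab.vi` files are token lists; assuming the usual one-token-per-line format of the IWSLT'15 release, a vocabulary lookup table could be built as in the sketch below (illustrative only; the example's own data loading lives in `reader.py`, which is not shown in this diff).

```python
# illustrative sketch: build token -> id mappings from the downloaded vocab files
# (assumes one token per line; not part of the example code)
import os

def load_vocab(path):
    with open(path, encoding='utf-8') as f:
        tokens = [line.rstrip('\n') for line in f]
    return {tok: idx for idx, tok in enumerate(tokens)}

src_vocab = load_vocab(os.path.join('data', 'en-vi', 'vocab.en'))
tar_vocab = load_vocab(os.path.join('data', 'en-vi', 'vocab.vi'))
print(len(src_vocab), len(tar_vocab))   # should match --src_vocab_size / --tar_vocab_size
```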