tensor2tensor 是谷歌提出的一个transformer模型。其结构与end-to-end模型类似,但结构中不再使用RNN作为基础神经元,而是采用self-attention自注意力机制来实现上下文信息的传递
具体可以参考论文Attention Is All You Need
transfomer的优势在于它不再像RNN那样具有时序行,整个运算都是并行的。
- tensor2tensor
- tensorflow
- subword-nmt
安装
pip3.6 install -r requirements.txt
tensor2tensor训练模型分为以下两个步骤
1.数据准备 t2t-datagen
2.模型训练 t2t-trainer
首先需要准备好语料,这里使用的是nlp_chinese_corpus中的翻译语料,下载完成解压到项目根目录 tmp 下
使用 generate_trainset.py 生成训练数据 raw-train.zh-en 和 raw-dev.zh-en
def prepare_data(data_io, en_io, zh_io):
for line in data_io:
line = line.strip()
json_obj = json.loads(line)
en_io.write(json_obj["english"] + "\n")
zh_io.write(" ".join(list(json_obj["chinese"])) + "\n")
en_io.close()
zh_io.close()
data_io.close()
由于中文使用的是字符特征,所以保存的时候以空格隔开,方便以后读取,当然也可以使用分词工具,但都需要以空格分开。
python3.6 generate_trainset.py
接下来准备 vocab 英文通过 subword-nmt 生成词典,然后使用 generate_en_vocab.py 整理词典
subword-nmt get-vocab --input tmp/raw-train.zh-en.en --output en.vocab
python3.6 generate_en_vocab.py
python3.6 generate_zh_vocab.py
至此,所有的语料都准备好了,但如果使用自己的数据,就需要注册自己的problem 创建一个 TranslateEnzhSub50k 任务, 并在 user_dir/__init__.py 中导入
加载用户字典通过以下方法实现
def get_vocab(self, data_dir, is_target=False):
vocab_filename = os.path.join(data_dir, self.target_vocab_name if is_target else self.source_vocab_name)
if not tf.gfile.Exists(vocab_filename):
raise ValueError("Vocab %s not found" % vocab_filename)
return text_encoder.TokenTextEncoder(vocab_filename, replace_oov="UNK")
使用 t2t-datagen 生成训练数据
t2t-datagen --data_dir=data/ --problem=translate_enzh_sub50k --t2t_usr_dir=user_dir --tmp_dir=tmp/
参数说明
--data_dir 生成的训练数据的目录
--problem 自定义的problem名
--t2t_usr_dir problem目录
--tmp_dir 数据目录
t2t-trainer --data_dir=data --output_dir=model --problem=translate_enzh_sub50k --model=transformer --hparams_set=transformer_big --train_steps=200000 --eval_steps=100 --t2t_usr_dir=user_dir --tmp_dir=tmp/ --decode_hparams="batch_size=1024"
训练20w步后可以获得 30 的 BLEU
INFO:tensorflow:Saving dict for global step 207000:
global_step = 207000,
loss = 1.9731357,
metrics-translate_enzh_sub50k/targets/accuracy = 0.5956236,
metrics-translate_enzh_sub50k/targets/accuracy_per_sequence = 0.0,
metrics-translate_enzh_sub50k/targets/accuracy_top5 = 0.80242187,
metrics-translate_enzh_sub50k/targets/approx_bleu_score = 0.300216,
metrics-translate_enzh_sub50k/targets/neg_log_perplexity = -1.9725443,
metrics-translate_enzh_sub50k/targets/rouge_2_fscore = 0.36830086,
metrics-translate_enzh_sub50k/targets/rouge_L_fscore = 0.5878231
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 207000: model/model.ckpt-207000