My PyTorch implementation of the original Transformer model from the paper Attention Is All You Need, inspired by the many codebases and blog posts I've read on the topic. There's nothing really special going on here, except that I tried to keep it as barebones as possible. There is also training code for a simple German -> English translator, written in pure PyTorch using the Torchtext library. This implementation draws on the following resources:
- The Illustrated Transformer by Jay Alammar
- The Original Transformer (PyTorch) by Aleksa Gordic
- Attention is all you need from scratch by Aladdin Persson
- PyTorch Seq2Seq by Ben Trevett
- Transformers: Attention in Disguise by Mihail Eric
- The Annotated Transformer by Harvard NLP
And probably a couple more that I don't remember...
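At the core of the model is the scaled dot-product attention from the paper: softmax(QKᵀ/√d_k)·V. For quick reference, here is a minimal, self-contained PyTorch sketch of that operation (the code in this repo may organize it differently, e.g. inside a multi-head attention module):

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, heads, seq_len, d_k); mask broadcastable to the score shape
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)   # (batch, heads, q_len, k_len)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)                 # attention distribution over keys
    return weights @ v, weights
```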
- Install the required pip packages:
pip install -r requirements.txt
- Install the spaCy models:
python -m spacy download de_core_news_sm
python -m spacy download en_core_web_sm
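These models provide the German and English tokenizers used to split sentences into tokens. With the new Torchtext API they can be wrapped with `torchtext.data.utils.get_tokenizer`; a minimal sketch (the repo may load them differently):

```python
from torchtext.data.utils import get_tokenizer

# spaCy-backed tokenizers for source (German) and target (English)
de_tokenizer = get_tokenizer("spacy", language="de_core_news_sm")
en_tokenizer = get_tokenizer("spacy", language="en_core_web_sm")

print(de_tokenizer("Eine Gruppe von Menschen steht vor einem Iglu"))
# ['Eine', 'Gruppe', 'von', 'Menschen', 'steht', 'vor', 'einem', 'Iglu']
```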
Note: This code uses Torchtext's new API (v0.10.0+): `dataset.py` contains a custom text dataset class that inherits from `torch.utils.data.Dataset`, unlike the classic approach using `Field` and `BucketIterator` (which have now been moved to `torchtext.legacy`). The `torchtext` library is still under heavy development, though, so this code will probably break with upcoming versions.
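For orientation, a map-style translation dataset can be as small as the following sketch (illustrative only, with hypothetical names; see `dataset.py` for the actual implementation):

```python
import torch
from torch.utils.data import Dataset

class TranslationDataset(Dataset):
    """Holds parallel sentences and returns (source, target) index tensors."""

    def __init__(self, src_sentences, trg_sentences,
                 src_vocab, trg_vocab, src_tokenizer, trg_tokenizer):
        self.src_sentences = src_sentences      # list of raw German strings
        self.trg_sentences = trg_sentences      # list of raw English strings
        self.src_vocab = src_vocab              # dict-like: token -> index
        self.trg_vocab = trg_vocab
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        src = [self.src_vocab[tok] for tok in self.src_tokenizer(self.src_sentences[idx])]
        trg = [self.trg_vocab[tok] for tok in self.trg_tokenizer(self.trg_sentences[idx])]
        return torch.tensor(src), torch.tensor(trg)
```

Batching variable-length pairs then only needs a `collate_fn` that pads sequences (e.g. with `torch.nn.utils.rnn.pad_sequence`) instead of the legacy `BucketIterator`.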
In `train.py` we train a simple German -> English translation model on the Multi30k dataset using the Transformer model. Make sure you configure the necessary paths for weights, logs, etc. in `config.py`. Then you can simply run the file as below:
python train.py
Epoch: 1/10 100%|######################################################################| 227/227 [00:10<00:00, 21.61batch/s, loss=4.33]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 45.25batch/s, loss=3.13]
Saved Model at weights/1.pt
Epoch: 2/10 100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=2.82]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.68batch/s, loss=2.55]
Saved Model at weights/2.pt
Epoch: 3/10 100%|######################################################################| 227/227 [00:10<00:00, 22.56batch/s, loss=2.22]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.98batch/s, loss=2.22]
Saved Model at weights/3.pt
Epoch: 4/10 100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=1.83]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 52.20batch/s, loss=2.07]
Saved Model at weights/4.pt
Epoch: 5/10 100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=1.55]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 52.12batch/s, loss=2]
Saved Model at weights/5.pt
Epoch: 6/10 100%|######################################################################| 227/227 [00:10<00:00, 22.25batch/s, loss=1.34]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.45batch/s, loss=1.95]
Saved Model at weights/6.pt
Epoch: 7/10 100%|######################################################################| 227/227 [00:10<00:00, 22.55batch/s, loss=1.17]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.34batch/s, loss=1.95]
Saved Model at weights/7.pt
Epoch: 8/10 100%|######################################################################| 227/227 [00:10<00:00, 22.46batch/s, loss=1.03]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.43batch/s, loss=1.96]
Saved Model at weights/8.pt
Epoch: 9/10 100%|######################################################################| 227/227 [00:10<00:00, 22.45batch/s, loss=0.91]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 52.84batch/s, loss=1.99]
Saved Model at weights/9.pt
Epoch: 10/10 100%|######################################################################| 227/227 [00:10<00:00, 22.50batch/s, loss=0.808]
Evaluating... 100%|######################################################################| 8/8 [00:00<00:00, 51.74batch/s, loss=2.01]
Saved Model at weights/10.pt
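For orientation, one epoch behind those progress bars boils down to a loop like the following sketch (hypothetical names and model signature; the real loop lives in `train.py`). The decoder input is the target shifted right by one position, and padding tokens are excluded from the loss:

```python
import torch.nn as nn

def train_epoch(model, loader, optimizer, pad_idx, device):
    model.train()
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)  # skip padding positions
    total_loss = 0.0
    for src, trg in loader:                  # src: (batch, src_len), trg: (batch, trg_len)
        src, trg = src.to(device), trg.to(device)
        optimizer.zero_grad()
        # feed target without its last token, predict it without the leading <sos>
        logits = model(src, trg[:, :-1])     # hypothetical signature -> (batch, trg_len-1, vocab)
        loss = criterion(logits.reshape(-1, logits.size(-1)), trg[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)
```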
Given the sentence "Eine Gruppe von Menschen steht vor einem Iglu" ("A group of people stands in front of an igloo") as input to `predict.py`, we get the following output, which is pretty decent even though our dataset is somewhat naive and simple:
python predict.py
"Translation: A group of people standing in front of a warehouse ."
TODO:
- `predict.py` for inference
- Add pretrained weights
- Visualize attentions
- An in-depth notebook