Vanilla Transformer (PyTorch)

My PyTorch implementation of the original Transformer model from the paper Attention Is All You Need, inspired by all the code and blog posts I've read on this topic. There's nothing really special going on here except that I tried to make it as bare-bones as possible. There is also training code for a simple German -> English translator, written in pure PyTorch using the Torchtext library.
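
For reference, the base configuration in the paper is a 6-layer encoder-decoder with a model dimension of 512 and 8 attention heads. The repo builds these blocks from scratch, but purely as an illustration, PyTorch's built-in nn.Transformer takes the same hyperparameters:

import torch.nn as nn

# Base configuration from "Attention Is All You Need". The repo implements
# these blocks itself; nn.Transformer is shown here only for reference.
transformer = nn.Transformer(
    d_model=512,            # model / embedding dimension
    nhead=8,                # number of attention heads
    num_encoder_layers=6,
    num_decoder_layers=6,
    dim_feedforward=2048,   # hidden size of the position-wise feed-forward net
    dropout=0.1,
)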

My Inspirations

And probably a couple more which I don't remember ...

Prerequisites

  1. Install the required pip packages:
pip install -r requirements.txt
  2. Install the spaCy models:
python -m spacy download de_core_news_sm
python -m spacy download en_core_web_sm
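
These models back the German and English tokenizers. As a minimal sketch of how they are typically wired up (assuming Torchtext's get_tokenizer helper from the new API), loading them looks like this:

from torchtext.data.utils import get_tokenizer

# Wrap the downloaded spaCy pipelines as plain tokenizer callables
de_tokenizer = get_tokenizer("spacy", language="de_core_news_sm")
en_tokenizer = get_tokenizer("spacy", language="en_core_web_sm")

print(de_tokenizer("Eine Gruppe von Menschen steht vor einem Iglu"))
# ['Eine', 'Gruppe', 'von', 'Menschen', 'steht', 'vor', 'einem', 'Iglu']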

Note: This code uses Torchtext's new API (v0.10.0+). dataset.py contains a custom text dataset class that inherits from torch.utils.data.Dataset, which is different from the classic approach using Field and BucketIterator (both now moved to torchtext.legacy). Nevertheless, Torchtext is still under heavy development, so this code will probably break with upcoming versions.
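
The actual class lives in dataset.py; purely as an illustration (all names here are hypothetical), a map-style translation dataset boils down to something like:

from torch.utils.data import Dataset

class TranslationDataset(Dataset):
    """Hypothetical sketch of a map-style dataset over sentence pairs."""

    def __init__(self, pairs, src_tokenizer, trg_tokenizer):
        self.pairs = pairs                  # list of (German, English) strings
        self.src_tokenizer = src_tokenizer
        self.trg_tokenizer = trg_tokenizer

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        src, trg = self.pairs[idx]
        return self.src_tokenizer(src), self.trg_tokenizer(trg)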

Train

In train.py we train a simple German -> English translation model on the Multi30k dataset using the Transformer model. Make sure you configure the necessary paths for weights, logs, etc. in config.py. Then you can simply run the file as below:

python train.py
Epoch: 1/10     100%|######################################################################| 227/227 [00:10<00:00, 21.61batch/s, loss=4.33]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 45.25batch/s, loss=3.13]
Saved Model at weights/1.pt

Epoch: 2/10     100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=2.82]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.68batch/s, loss=2.55]
Saved Model at weights/2.pt

Epoch: 3/10     100%|######################################################################| 227/227 [00:10<00:00, 22.56batch/s, loss=2.22]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.98batch/s, loss=2.22]
Saved Model at weights/3.pt

Epoch: 4/10     100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=1.83]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 52.20batch/s, loss=2.07]
Saved Model at weights/4.pt

Epoch: 5/10     100%|######################################################################| 227/227 [00:10<00:00, 22.64batch/s, loss=1.55]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 52.12batch/s, loss=2]   
Saved Model at weights/5.pt

Epoch: 6/10     100%|######################################################################| 227/227 [00:10<00:00, 22.25batch/s, loss=1.34]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.45batch/s, loss=1.95]
Saved Model at weights/6.pt

Epoch: 7/10     100%|######################################################################| 227/227 [00:10<00:00, 22.55batch/s, loss=1.17]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.34batch/s, loss=1.95]
Saved Model at weights/7.pt

Epoch: 8/10     100%|######################################################################| 227/227 [00:10<00:00, 22.46batch/s, loss=1.03]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.43batch/s, loss=1.96]
Saved Model at weights/8.pt

Epoch: 9/10     100%|######################################################################| 227/227 [00:10<00:00, 22.45batch/s, loss=0.91] 
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 52.84batch/s, loss=1.99]
Saved Model at weights/9.pt

Epoch: 10/10    100%|######################################################################| 227/227 [00:10<00:00, 22.50batch/s, loss=0.808]
Evaluating...   100%|######################################################################| 8/8 [00:00<00:00, 51.74batch/s, loss=2.01]
Saved Model at weights/10.pt
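
Under the hood this is the standard teacher-forcing setup: the decoder receives the target shifted right and is trained to predict the next token, with padding positions ignored in the loss. A minimal sketch with hypothetical names (the real loop lives in train.py):

import torch.nn as nn

def train_epoch(model, loader, optimizer, pad_idx, device="cpu"):
    """One epoch of teacher-forced training; illustrative, not the repo's exact code."""
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)  # skip padding in the loss
    model.train()
    total_loss = 0.0
    for src, trg in loader:                        # batches of token-id tensors
        src, trg = src.to(device), trg.to(device)
        optimizer.zero_grad()
        # Teacher forcing: feed the target shifted right, predict the next token
        logits = model(src, trg[:, :-1])           # (batch, trg_len - 1, vocab)
        loss = criterion(logits.reshape(-1, logits.size(-1)),
                         trg[:, 1:].reshape(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)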

Inference

Given the sentence Eine Gruppe von Menschen steht vor einem Iglu ("A group of people is standing in front of an igloo") as input in predict.py, we get the following output, which is pretty decent even though our dataset is small and simple.

python predict.py
"Translation:  A group of people standing in front of a warehouse ."

TODO

  • predict.py for inference
  • Add pretrained weights
  • Visualize attentions
  • An in-depth notebook
