JAX codebase demonstrating an application of ZeRO-style optimizer sharding using a combination of xmap and pjit. This codebase was used to train a 1.3B parameter transformer model on a TPU v3-32, something that would not be possible with standard data parallel training. I have a full post detailing my work, which you can read here.
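To give a rough sense of why sharding the optimizer state matters at this scale, here is a back-of-the-envelope estimate of the model state held by each device. This is only a sketch under assumptions that are not stated in this README: fp32 parameters and gradients, Adam with two fp32 moments per parameter, 32 devices, and activations ignored. The function name `bytes_per_device` is made up for this illustration.

```python
# Back-of-the-envelope estimate of model state per device.
# All byte counts below are assumptions, not measurements from this repo.
GIB = 1024 ** 3


def bytes_per_device(n_params, n_devices, shard_optimizer):
    """Estimate bytes of model state held by a single device.

    Assumes fp32 parameters (4 bytes), fp32 gradients (4 bytes) and
    Adam with two fp32 moments (8 bytes) per parameter.
    Activation memory is ignored entirely.
    """
    params = 4 * n_params
    grads = 4 * n_params
    optimizer = 8 * n_params
    if shard_optimizer:
        # ZeRO-style: each device keeps only its 1/n_devices slice
        # of the optimizer state instead of a full replica.
        optimizer = optimizer / n_devices
    return params + grads + optimizer


n_params = 1.3e9
n_devices = 32  # assumed: one device per core on a TPU v3-32

replicated = bytes_per_device(n_params, n_devices, shard_optimizer=False) / GIB
sharded = bytes_per_device(n_params, n_devices, shard_optimizer=True) / GIB

print(f"standard data parallel: {replicated:.1f} GiB per device")
print(f"optimizer state sharded: {sharded:.1f} GiB per device")
```

Under these assumptions the fully replicated case needs roughly 19 GiB per device, while sharding the optimizer state brings it to roughly 10 GiB. If each TPU v3 core has 16 GiB of HBM (my understanding of the hardware, not something this README states), only the sharded setup fits before activations are even counted.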
Add your model config to conf/model_config.yaml:
model_name:
embedding_dim:
vocab_size:
num_head:
block_size: # maximum context length
dropout:
N:
alibi_attn: # boolean for using ALiBi attention
All other configuration is handled in conf/config.yaml.
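As an illustration of what a filled-in model config might contain, the sketch below mirrors the fields above in a small dataclass with two sanity checks. The class `ModelConfig`, the checks, and every value are hypothetical: they are not taken from this repo, the example does not correspond to any released model, and I am assuming that `N` is the number of transformer blocks.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Hypothetical mirror of the fields in conf/model_config.yaml."""

    model_name: str
    embedding_dim: int
    vocab_size: int
    num_head: int
    block_size: int   # maximum context length
    dropout: float
    N: int            # assumed here to be the number of transformer blocks
    alibi_attn: bool  # whether to use ALiBi attention

    def __post_init__(self):
        # Common convention for multi-head attention, not verified
        # against this codebase: heads must evenly split the embedding.
        if self.embedding_dim % self.num_head != 0:
            raise ValueError("embedding_dim must be divisible by num_head")
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError("dropout must be in [0, 1)")


# Made-up values for illustration only.
example = ModelConfig(
    model_name="example_small",
    embedding_dim=768,
    vocab_size=50304,
    num_head=12,
    block_size=1024,
    dropout=0.0,
    N=12,
    alibi_attn=True,
)

head_dim = example.embedding_dim // example.num_head
print(f"{example.model_name}: {example.N} blocks, head dimension {head_dim}")
```

Checking the config up front like this is just one way to catch a mismatched head count before a long TPU job starts; the actual training script may validate its config differently or not at all.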
This assumes you have your data set up on a GCP bucket and .index files created for your datasets:
python main_zero.py
If resuming a run, pass the --resume flag to your script.
The following three models are available for download:
Their performance is roughly summarized here:
Model Size (M) | Training Tokens (B) | LAMBADA (PPL) | LAMBADA (ACC) | PIQA (Acc) | Winogrande (Acc) | Hellaswag (Acc Norm) |
---|---|---|---|---|---|---|
417 | 300 | 13.1534 | 48.11% | 65.02% | 51.93% | 36.00% |
760 | 330 | 8.6189 | 55.52% | 67.63% | 55.01% | 41.46% |
1300 | 200 | 7.6880 | 57.15% | 69.48% | 55.09% | 45.21% |
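For readers less familiar with the two kinds of metric in the table, the sketch below shows the usual definitions: perplexity as the exponential of the mean negative log-likelihood, and accuracy as the fraction of exact matches. The helper names and the toy inputs are made up, and I have not checked how the evaluation code for these models computes its numbers, so treat this as the textbook definition rather than this repo's implementation.

```python
import math


def perplexity(token_log_probs):
    """Exponential of the mean negative log-likelihood (natural log)."""
    mean_nll = -sum(token_log_probs) / len(token_log_probs)
    return math.exp(mean_nll)


def accuracy(predictions, targets):
    """Fraction of predictions that exactly match their target."""
    correct = sum(p == t for p, t in zip(predictions, targets))
    return correct / len(targets)


# Toy inputs, not real model outputs.
log_probs = [math.log(0.5), math.log(0.25), math.log(0.125)]
ppl = perplexity(log_probs)
acc = accuracy(["cat", "dog", "tree"], ["cat", "dog", "house"])

print(f"perplexity: {ppl:.4f}, accuracy: {acc:.2%}")
```

Lower perplexity and higher accuracy are better, which matches the trend in the table as model size grows.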
Once you've downloaded the weights, the following code is sufficient to load and run the models. For example, to load the 1.3B param model:
from torch_compatability.GPT2 import model_getter

# Build the 1.3B parameter model and load the downloaded weights
model = model_getter(
    size="1_3b",
    model_checkpoint="path/to/weights",
)
model.to('cuda')  # move the model to the GPU
If you're interested in accessing the flax models including optimizer state, feel free to open an issue in the repo.
Tests are written in a combination of unittest and pytest (yes, I know this is kinda silly). All tests can be run with:
pytest
from the base directory.
git clone https://github.com/fattorib/transformer.git
cd transformer
bash prepareTPUVM.sh
TPU development and training were supported with Cloud TPUs from Google's TPU Research Cloud (TRC). Thank you to the excellent TRC team for granting me access to upgraded TPU VMs and for the extensions I received while working on this project!