flash-nanoGPT (Under development)

A JAX (Flax) re-write of Andrej Karpathy's nanoGPT. This repository showcases a collection of newer JAX/Flax features, such as the Pallas kernel language for Flash Attention on TPU and data/tensor sharding with JAX on TPU.
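As a taste of the sharding side, here is a minimal data-parallel sketch using `jax.experimental.shard_map` (the "shmap" mentioned above). The `loss_fn` below is a stand-in for the real Flax GPT forward pass, and the shapes are illustrative assumptions, not this repository's actual code.

```python
import numpy as np
from functools import partial

import jax
import jax.numpy as jnp
import optax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map


def loss_fn(params, tokens, targets):
    # Stand-in forward pass: an embedding lookup with a tied output projection.
    logits = params["wte"][tokens] @ params["wte"].T       # (batch, seq, vocab)
    return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()


# 1D device mesh used purely for data parallelism.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))


@partial(
    shard_map,
    mesh=mesh,
    in_specs=(P(), P("data"), P("data")),  # replicate params, shard the batch
    out_specs=P(),                         # loss/grads are replicated after pmean
)
def grad_step(params, tokens, targets):
    loss, grads = jax.value_and_grad(loss_fn)(params, tokens, targets)
    # Average the per-shard loss and gradients across the data axis.
    loss = jax.lax.pmean(loss, axis_name="data")
    grads = jax.lax.pmean(grads, axis_name="data")
    return loss, grads


# Illustrative shapes: the global batch must divide evenly across devices.
vocab, d_model, seq = 50304, 128, 64
batch = 8 * jax.device_count()
params = {"wte": jnp.zeros((vocab, d_model))}
tokens = jnp.zeros((batch, seq), dtype=jnp.int32)
loss, grads = grad_step(params, tokens, tokens)
```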

Todos

  • GPT-2-like model in Flax
  • Mixed precision training with jmp
  • Gradient accumulation with optax (a sketch combining jmp and optax follows this list)
  • Data sharding across GPUs/TPUs using the new JAX shard_map (shmap)
  • Loading and saving checkpoints
  • Reproduce the results on the shakespeare-char dataset
  • TFRecord reader/writer with support for data sharding across hosts
  • Multi-host training
  • Reproduce the results on the OpenWebText dataset
  • Loading Hugging Face pre-trained GPT models
  • Fine-tuning GPT-2 weights on the Shakespeare dataset
  • Sampling
  • Estimating MFU (model FLOPs utilization)
  • Profiling a training iteration
  • Flash attention with Pallas
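
The following is a minimal sketch of how the mixed-precision and gradient-accumulation items above can fit together, assuming a jmp policy with float32 parameters and bfloat16 compute plus optax.MultiSteps for accumulation. The stand-in loss_fn and all shapes/hyperparameters are illustrative assumptions, not this repository's exact training step.

```python
import jax
import jax.numpy as jnp
import jmp
import optax

# Mixed-precision policy: float32 parameters, bfloat16 compute (the usual TPU
# choice, so no loss scaling is needed), float32 outputs.
policy = jmp.Policy(param_dtype=jnp.float32,
                    compute_dtype=jnp.bfloat16,
                    output_dtype=jnp.float32)

# Gradient accumulation: MultiSteps averages gradients over `accum_steps`
# calls to `update` before applying a real AdamW step.
accum_steps = 4
optimizer = optax.MultiSteps(
    optax.adamw(learning_rate=3e-4, weight_decay=0.1),
    every_k_schedule=accum_steps,
)


def loss_fn(params, tokens, targets):
    params = policy.cast_to_compute(params)
    # Stand-in forward pass: embedding lookup with a tied output projection.
    logits = params["wte"][tokens] @ params["wte"].T
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()
    return policy.cast_to_output(loss)


@jax.jit
def train_step(params, opt_state, tokens, targets):
    loss, grads = jax.value_and_grad(loss_fn)(params, tokens, targets)
    # Make sure gradients are in the parameter dtype before the float32
    # optimizer state sees them.
    grads = policy.cast_to_param(grads)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss


params = {"wte": jnp.zeros((50304, 128), dtype=jnp.float32)}
opt_state = optimizer.init(params)
```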

Data generation

In order to run training on a TPU VM, copy the generated data files into a GCP bucket.
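
For reference, here is one hedged way to do the copy from Python with tf.io.gfile, which reads and writes gs:// paths directly. The local directory layout, the *.tfrecord file pattern, and the bucket name are assumptions; adjust them to your own setup (gsutil cp works just as well).

```python
import os
import tensorflow as tf

# Hypothetical paths: point `local_dir` at the prepared data shards and
# `gcs_dir` at your own bucket.
local_dir = "data/openwebtext"
gcs_dir = "gs://your-bucket/data/openwebtext"

tf.io.gfile.makedirs(gcs_dir)
for path in tf.io.gfile.glob(os.path.join(local_dir, "*.tfrecord")):
    dest = os.path.join(gcs_dir, os.path.basename(path))
    tf.io.gfile.copy(path, dest, overwrite=True)
    print(f"copied {path} -> {dest}")
```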

Acknowledgement

Big thanks to TPU Research Cloud for providing v2-8/v3-8/v3-32 TPU instances on Google Cloud.

References

  • Original nanoGPT repository [1]
  • JAX-based nanoGPT repositories [1] [2]
  • NVIDIA mixed-precision training [1]
  • Google Cloud documentation [1]
