
# BlaGPT

Experimental playground for benchmarking language model (LM) architectures, layers, and tricks on smaller datasets. Designed for flexible experimentation and exploration.

## BlaGPT Model

BlaGPT is a flexible Transformer implementation that lets you turn the following features on and off from the config (a minimal config sketch follows the list):

- Multi-token prediction - link
- Weight tying - link
- Grouped query attention - link
- Capping logits - link
- QKV bias - link
- Zero-init projection layer - link
- Post- and pre-RMSNorm - link
- Setting RoPE base theta to 1_000_000 (Llama 3 style) - increased the final validation loss - best: 3.3324
- Z-loss regularization - link - increased the final validation loss by 0.02 - loss: 3.3527
- KV-Shifting attention - link - seems to improve performance - loss: 3.3310 -> 3.3138 - peak memory consumption: 42858 MiB
- Dilated Attention (LongNet) - link
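For a rough picture of how these switches fit together, here is a minimal configuration sketch. The field names mirror the example config under "Best Model So Far" below; the `ModelConfig` dataclass and the `BlaGPT(config)` call are assumptions for illustration, not the repository's exact API.

```python
# Minimal sketch, assuming the field names from the example config below.
# The ModelConfig dataclass and BlaGPT(config) call are illustrative only.
from dataclasses import dataclass

@dataclass
class ModelConfig:
    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_kv_head: int = 4                    # < n_head => grouped query attention
    n_embd: int = 768
    norm_layer: str = "rmsnorm"           # pre/post RMSNorm
    attention: str = "GQA"
    activation: str = "swiglu"
    use_rotary_emb: bool = True
    rmsnorm_before_qk: bool = True        # RMSNorm on Q/K before attention
    use_soft_logit_capping: bool = False  # capping logits
    tie_embed_weights: bool = True        # weight tying
    zero_init_proj_layers: bool = True    # zero-init projection layers

# Toggle features per experiment, e.g. enable logit capping and disable GQA:
config = ModelConfig(use_soft_logit_capping=True, n_kv_head=12)
# model = BlaGPT(config)  # hypothetical constructor; see the repo for the real entry point
```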

## Other Models

- MegaByte - link - loss: 3.810
- FTP (heavily modified) - link - loss: 3.901
- Rene - link - loss: 3.340
- Rwkv7 - link - loss: 4.450
- Zamba2 - link - Zamba2 > Rene > Rwkv7
- Hourglass Transformer (modified) - link - Hourglass > MegaByte > FTP - loss: 3.710
- Hymba - link - train step time is significantly slower than the Transformer models; best validation loss so far: 4.7505
- Tokenformer (in the BlaGPT model) - link - loss: 3.390

## Optimizers

- PaLMForeachSOAP - link - almost 2x slower than Adam but gives the best results
- Ademamix - link - unstable even after trying different learning rates
- Adopt - link - goes straight to NaN
- CAdamW - link - loss: 3.3517
- AdamW with independent weight decay - link - loss: 3.320 (see the sketch after this list)
- Adam - loss: 3.3224
- AdamW - loss: 3.3310, peak VRAM: 42053 MiB, step time: 533 ms
- DeMo - link - saves 7 GB of VRAM per GPU, but the loss is higher than the baseline and the step time is slower than Adam - loss: 3.4676, peak VRAM: 41534 MiB, step time: 820 ms
- Adam-Mini - link - loss is higher than Adam and AdamW and, surprisingly, it is also slower; saves a bit of VRAM - loss: 3.3324, peak VRAM: 41534 MiB, step time: 610 ms
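A note on the "AdamW with independent weight decay" entry: the general idea is to apply weight decay with its own fixed coefficient every step instead of letting it scale with the learning-rate schedule. The sketch below shows that technique with stock PyTorch; it is not the repository's optimizer code, and the coefficients are illustrative.

```python
# Sketch of independent (schedule-decoupled) weight decay: decay the weights
# with a fixed coefficient, then take a plain Adam step with weight_decay=0.
# Illustrative only; not the repo's implementation.
import torch

def step_with_independent_wd(optimizer: torch.optim.Adam, params, wd: float = 1e-4):
    with torch.no_grad():
        for p in params:
            if p.requires_grad:
                p.mul_(1.0 - wd)  # decay strength does not follow the LR schedule
    optimizer.step()

# Usage (illustrative):
# params = [p for p in model.parameters() if p.requires_grad]
# opt = torch.optim.Adam(params, lr=3e-4, weight_decay=0.0)
# loss.backward(); step_with_independent_wd(opt, params); opt.zero_grad()
```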

## Best Model So Far

BlaGPT with the following configuration:

```json
{
  "params": {
    "norm_layer": "rmsnorm",
    "attention": "GQA",
    "activation": "swiglu",
    "tie_embed_weights": true,
    "zero_init_proj_layers": true,
    "use_rotary_emb": true,
    "rmsnorm_before_qk": true
  },
  "config": {
    "block_size": 1024,
    "vocab_size": 50304,
    "n_layer": 12,
    "n_head": 12,
    "n_embd": 768,
    "dropout": 0.0,
    "bias": true,
    "norm_layer": "rmsnorm",
    "attention": "GQA",
    "activation": "swiglu",
    "use_soft_logit_capping": false,
    "n_kv_head": 4,
    "tie_embed_weights": true,
    "zero_init_proj_layers": true,
    "rmsnorm_before_qk": true,
    "use_rotary_emb": true
  },
  "val_loss": 3.2993,
  "memory_usage": 49403
}
```

## Adding a New Model

1. Implement the model.
2. Return the loss from the forward function.
3. Register the model.
4. Start training.

See one of the existing implementations for details; a minimal sketch of steps 1-3 follows.
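This is a minimal sketch only, assuming a `forward(idx, targets)` interface that returns `(logits, loss)` and a hypothetical `register_model` helper; check an existing implementation in the repo for the exact signatures.

```python
# Minimal sketch of a new model. The forward(idx, targets) -> (logits, loss)
# shape and the register_model helper are assumptions for illustration.
import torch.nn as nn
import torch.nn.functional as F

class TinyLM(nn.Module):
    def __init__(self, vocab_size: int = 50304, n_embd: int = 768):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

    def forward(self, idx, targets=None):
        logits = self.lm_head(self.embed(idx))
        loss = None
        if targets is not None:
            # returning the loss keeps the shared training loop model-agnostic
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        return logits, loss

# register_model("tiny_lm", TinyLM)  # hypothetical registration call
```

Once registered, the model would presumably be selectable through the `--model_name` flag used in the training command below.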

## Training

1. Get the data by running `data/fineweb10B_cached.py`.
2. Start training with:

```bash
torchrun --standalone --nproc_per_node=8 train.py --run_name pre_post_norm --model_name blagpt
```

## Acknowledgements

The initial code is based on:

- NanoGPT - link
- Modded NanoGPT - link