diff --git a/.gitignore b/.gitignore
index c852262..c12b2d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,6 +2,7 @@
__pycache__/
*.py[cod]
*$py.class
+wandb/
*.DS_Store
@@ -160,3 +161,6 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
+
+# uv
+uv.lock
\ No newline at end of file
diff --git a/README.md b/README.md
index 35b1c34..3b7ed79 100644
--- a/README.md
+++ b/README.md
@@ -1,61 +1,170 @@
+
-Attempting to balance ergonomics and simplicity. This is meant to be easily hackable for research purposes.
+## 💡Purpose
-```
-torch, lightning-fabric (or) accelerate, datasets, rich (eyecandy) ~~tokenizers will be removed~~
-```
+Rafale provides an opinionated scaffolding for training transformers. It is solely built to be an efficient
+learning/research tool. It is **not** a fully fledged library for large scale training.
-@TODO :: (check out this stream on HF accelerate)[https://www.youtube.com/watch?v=X-Jx5-YskKY]
+It should be thought of as a starting point for research projects to bootstrap experiments on small LMs. The best way to
+use rafale is to simply fork it and build on top of it for your specific purposes.
+### Core dependencies
-## Purpose
+Attempting to balance ergonomics and simplicity. This is meant to be easily hackable for research purposes.
-This package is solely built to be an efficient research tool. It will not support data preprocessing/handling
-pipelines. It should be thought of as a starting point for research projects to bootstrap experiments on small LMs.
+```
+torch, composer, datasets, tokenizers
+```
-Should be used pip installable via git and setup to be easily hackable to build on top of it.
+## 🚀 Installation & Usage
-Datasets should be preshuffled and pretokenized, only load it from disk and feed it to the dataloader with the collator
-function.
+Set up with ```uv``` ([install uv](https://github.com/astral-sh/uv)).
+```sh
+$ git clone
+$ cd rafale
+$ uv venv
+$ . .venv/bin/activate
+$ uv pip install -r cuda-requirements.txt  # or cpu-requirements.txt
+$ uv pip install -e .
+```
+
+Launch a run with a configuration.
-## Usage
+```sh
+$ python rafale/main -c test/pythia_tinystories.yaml
+```
-Mostly optimized for SLURM clusters.
+What if I just want to prepare my dataset? ```DATA=1``` will run the data preparation and caching pipeline without
+launching the training run.
```sh
+$ DATA=1 python rafale/main -c test/pythia_tinystories.yaml
+```
-rafale run -c config.yaml # set DEBUG=1 for a sanity check
+What if I want to test my model to make sure that it's learning? ```DEBUG=1``` will run 10 epochs on a single training
+batch (the same batch is used for train and eval); the model should overfit it quickly if there are no bugs in the implementation.
+```sh
+$ DEBUG=1 python rafale/main -c test/pythia_tinystories.yaml
```
-## Roadmap
-
-v0.1
-- [ ] Local model weight loading
-- [ ] load weights from safetensors and include it in the config (BERT and GPT2)
-- [ ] integration with lighteval (?)
-- [ ] Logging/Progress/Debug outputs with Rich library
-- ~~RoBERTa BPE tokenizer with TikToken (compare w/ HF), on the fly tokenization to be handled by dataloader's
- collator (for MLM)~~
- - ~~model will be tied to the tokenizer, so dataloader will be defined after the model and use it's tokenizer~~
- - We don't want anything to do with preprocessing, all data should be split/augmented/shuffled/tokenized/etc. All we
- do with this tool is load it from disk, turn it to a tensor and send it to the model
-- [ ] Local dataloader
-- [ ] ```debug``` single batch debug
-- [ ] ```main.py``` handles both training and evaluation (together or separately)
-- [-] BERT/RoBERTa support (MLM objective)
- + [ ] move the testing in the notebook to a debug file in the modeling folder
- + **layerwise decay** for fine-tuning (https://kozodoi.me/blog/20220329/discriminative-lr)
- + optimizations : flash attn2, xformers layer_norm (triton) or RMSNorm, xformers fused_linear_layer
- + RMSNorm
-- [ ] simple trainer (see lightning-fabric simple trainer example and start from there)
- + bf16/fp16, gradient clipping, and gradient accumulation
-
-v0.2
-- DeepSpeed ZeRO
- - Streaming dataloader
+
+### 🔧 Under the hood
+
+The goal of rafale is to provide a single entry point for data preparation and training. You configure the model and
+the dataset, then call the training job.
+
+When calling a run, we first run the datapipeline. If the dataset has already been processed (tokenized, padded,
+chunked, etc.), it will be loaded from the cache (default location is ```~/.rafale_cache```).
+
+> [!NOTE]
+> #### Adding a new model
+> To add a new model, you need to write a new configuration in ```rafale/models/configurations.py``` and add its key to
+> ```model_config_dict``` in ```rafale/main.py```.
+>
+> Look at the ```ComposerLM``` wrapper class in ```rafale/models/decoder.py``` to check if all your building blocks are
+> there. Otherwise, you may need to modify it or write a new wrapper.
+>
+> #### Adding a new datapipeline
+>
+> If the dataset is hosted on huggingface, simply use git lfs to clone the repo locally or use the repo name as the
+> dataset path. Same goes for tokenizers since we use their tokenizer implementation.
+>
+> You will need to add a new datapipeline class in ```rafale/datapipes.py``` whose ```_prepare``` method handles all data
+> preprocessing (tokenization, chunking, truncation, etc.) **EXCEPT** padding. Padding will be performed by the datacollator.
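+
+As a rough, hypothetical sketch (the class and column names below are illustrative, not the actual
+```rafale/datapipes.py``` API; only the idea of a ```_prepare``` method that does everything except padding is taken
+from the note above), a minimal CLM datapipeline could look like this:
+
+```python
+from datasets import load_from_disk
+from tokenizers import Tokenizer
+
+
+class MyCLMDataPipeline:
+    """Illustrative sketch: tokenize and chunk a local dataset; padding is left to the datacollator."""
+
+    def __init__(self, dataset_path: str, tokenizer_path: str, block_size: int = 512):
+        self.dataset = load_from_disk(dataset_path)            # pre-downloaded/cloned dataset
+        self.tokenizer = Tokenizer.from_file(tokenizer_path)   # HF tokenizers implementation
+        self.block_size = block_size
+
+    def _prepare(self):
+        # tokenization + concatenation + chunking happen here; padding is deliberately NOT done
+        def tokenize_and_chunk(batch):
+            ids = [self.tokenizer.encode(text).ids for text in batch["text"]]
+            flat = [tok for seq in ids for tok in seq]          # concatenate all documents
+            chunks = [flat[i : i + self.block_size]             # split into fixed-size blocks
+                      for i in range(0, len(flat) - self.block_size + 1, self.block_size)]
+            return {"input_ids": chunks}
+
+        # assumes a "text" column; a batched map is allowed to change the number of rows
+        return self.dataset.map(tokenize_and_chunk, batched=True, remove_columns=["text"])
+```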
+
+### 📕 Docs
+
+Paste the contents of ```llm-docprompt.txt``` into your favorite LLM and ask away.
+
+### 🦾 Supported models
+
+
+| Name | Implemented | Inference test | Training test |
+|:------------|:------------|:---------------|:--------------|
+| BERT | ✅ | | |
+| RoBERTa | ✅ | | |
+| Pythia | ✅ | ✅ | ✅ |
+| CLIP/SigLIP | ⏳ | | |
+
+
+## 🔮 Roadmap
+
+
+### v0.1 - initial release
+- [x] single entrypoint CLI
+- [ ] simple deploy/build
+  - [x] CPU macOS build - OK, uv run works with this
+  - [x] local linux machine - for now, uv for the venv + requirements.txt
+  - [ ] SLURM compute-canada - TBD
+    - NOTE: because uv still does not fully play well with PyTorch, a semi-manual setup is recommended
+- [ ] load weights from safetensors and include it in the config (BERT/RoBERTa and Pythia)
+ - [x] pythia
+ - [ ] BERT/RoBERTa (need to move from HF to safetensors)
+ - [ ] MLM
+ - [ ] Classification
+- [x] Pythia KV-cache implementation
+- [x] greedy generation
+- [ ] datapipes for CLM and MLM
+ - local dataloader for now
+ - [x] CLM tinystories
+ - [ ] MLM tinystories
+ - [ ] Imdb classification
+- [x] ```main.py``` handles both training and evaluation (together or separately)
+- [x] Mosaic Composer/Trainer
+ + [x] fp16
+ + [x] gradient clipping
+ + [x] gradient accumulation (automatically handled by composer)
+ + [x] building blocks are nn.Modules, specific models are ComposerModel classes with methods to load safetensor weights
+ automatically (keep in a single separate file for each model)
+ + [x] set DEBUG=1 for 1 batch sanity check before launching a run
+
+Datapipelines
+1. [x] tokenize
+2. [x] concat and split w/ block size (pad w/ collator)
+3. [x] save to disk {source}_{tokname}_bs{int}_len{int}
+4. [x] data_collator: *next* pad (if desired), label shift right and return torch tensor # HF: does this in the model...
+5. [x] test with model training
+6. [ ] tiny stories but for MLM also
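+
+A minimal sketch of the collator described in step 4 above (the function name is hypothetical; labels are the input
+ids shifted by one position for next-token prediction, and ```-100``` positions are ignored by ```CrossEntropyLoss```):
+
+```python
+import torch
+
+
+def clm_data_collator(examples, pad_token_id=0):
+    """Illustrative sketch: pad a batch of {"input_ids": [...]} examples and build shifted labels."""
+    max_len = max(len(ex["input_ids"]) for ex in examples)
+
+    input_ids, labels = [], []
+    for ex in examples:
+        ids = ex["input_ids"]
+        n_pad = max_len - len(ids)
+        input_ids.append(ids + [pad_token_id] * n_pad)
+        # label at position t is the token at position t + 1; padding and the final position are ignored (-100)
+        labels.append(ids[1:] + [-100] * (n_pad + 1))
+
+    return {"input_ids": torch.tensor(input_ids), "labels": torch.tensor(labels)}
+```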
+
+
+### path to v1.0
+cleanup and additional features
+- [ ] clean up ```tests``` for pythia and bert models on tinystories
+- [ ] move the testing in the notebook to a debug file in the modeling folder
+- [ ] optimizations : flash attn2, xformers layer_norm (triton) or RMSNorm, xformers fused_linear_layer
+- [ ] try out schedulefree, SOAP, and other optimizers
+- [ ] **layerwise decay** for fine-tuning (https://kozodoi.me/blog/20220329/discriminative-lr)
+- [ ] multimodality CLIP
+- [ ] integration with lm-eval-harness ([guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage))
+
+
+
+## I am GPU-rich, what do I use?
+
+For large-scale experiments, other frameworks/libraries exist:
+- lingua (Facebookresearch)
+- torchtitan (Pytorch)
+- torchtune (Pytorch)
+- litGPT (LightningAI)
+- GPT-NeoX (EleutherAI)
+- nanotron (Huggingface)
+- llm-foundry (MosaicML)
diff --git a/cpu-requirements.txt b/cpu-requirements.txt
new file mode 100644
index 0000000..e69de29
diff --git a/cuda-requirements.txt b/cuda-requirements.txt
new file mode 100644
index 0000000..3338534
--- /dev/null
+++ b/cuda-requirements.txt
@@ -0,0 +1,98 @@
+aiohappyeyeballs==2.4.3
+aiohttp==3.10.10
+aiosignal==1.3.1
+anyio==4.6.2.post1
+argcomplete==3.5.1
+arrow==1.3.0
+attrs==24.2.0
+backoff==2.2.1
+certifi==2024.8.30
+charset-normalizer==3.4.0
+click==8.1.7
+composer==0.26.0
+coolname==2.2.0
+datasets==3.0.2
+dill==0.3.8
+docker-pycreds==0.4.0
+filelock==3.16.1
+frozenlist==1.5.0
+fsspec==2024.9.0
+gitdb==4.0.11
+gitpython==3.1.43
+gql==3.5.0
+graphql-core==3.2.5
+huggingface-hub==0.26.2
+idna==3.10
+importlib-metadata==8.5.0
+jinja2==3.1.4
+lightning-utilities==0.11.8
+markdown-it-py==3.0.0
+markupsafe==3.0.2
+mdurl==0.1.2
+mosaicml-cli==0.6.42
+mpmath==1.3.0
+multidict==6.1.0
+multiprocess==0.70.16
+networkx==3.4.2
+numpy==2.1.2
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.6.77
+nvidia-nvtx-cu12==12.1.105
+packaging==24.1
+pandas==2.2.3
+pillow==10.4.0
+platformdirs==4.3.6
+prompt-toolkit==3.0.48
+propcache==0.2.0
+protobuf==5.28.3
+psutil==6.1.0
+py-cpuinfo==9.0.0
+pyarrow==18.0.0
+pygments==2.18.0
+python-dateutil==2.9.0.post0
+pytorch-ranger==0.1.1
+pytz==2024.2
+pyyaml==6.0.2
+questionary==1.10.0
+-e file:///home/max/code/rafale
+requests==2.32.3
+rich==13.9.3
+ruamel-yaml==0.18.6
+ruamel-yaml-clib==0.2.12
+safetensors==0.4.5
+sentry-sdk==2.17.0
+setproctitle==1.3.3
+setuptools==75.3.0
+six==1.16.0
+smmap==5.0.1
+sniffio==1.3.1
+sympy==1.13.3
+tabulate==0.9.0
+termcolor==2.5.0
+tokenizers==0.20.1
+torch==2.4.0
+torch-optimizer==0.3.0
+torchmetrics==1.4.0.post0
+torchvision==0.19.0
+tqdm==4.66.6
+triton==3.0.0
+types-python-dateutil==2.9.0.20241003
+typing-extensions==4.12.2
+tzdata==2024.2
+urllib3==2.2.3
+validators==0.34.0
+wandb==0.18.5
+wcwidth==0.2.13
+websockets==11.0.3
+xxhash==3.5.0
+yarl==1.17.1
+zipp==3.20.2
diff --git a/generate_llm_docprompt.sh b/generate_llm_docprompt.sh
new file mode 100755
index 0000000..c07b362
--- /dev/null
+++ b/generate_llm_docprompt.sh
@@ -0,0 +1,43 @@
+#!/bin/bash
+
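+# Usage: ./generate_llm_docprompt.sh
+# Run from the repository root; the `tree` command is assumed to be installed.
+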
+# Output file
+OUTPUT_FILE="llm-docprompt.txt"
+
+# Clear the output file if it exists
+> "$OUTPUT_FILE"
+
+# Add repo structure, excluding unwanted directories like .venv
+echo "### Repository Structure ###" >> "$OUTPUT_FILE"
+tree . -I 'wandb|*__pycache__|media|*-requirements.txt|.venv' >> "$OUTPUT_FILE"
+echo -e "\n\n" >> "$OUTPUT_FILE"
+
+# Include README.md content
+if [[ -f "README.md" ]]; then
+ echo "### README.md ###" >> "$OUTPUT_FILE"
+ cat README.md >> "$OUTPUT_FILE"
+ echo -e "\n\n" >> "$OUTPUT_FILE"
+fi
+
+# Function to include content of a given file type, excluding .venv directory
+include_files() {
+ local pattern="$1"
+ local header="$2"
+
+ find . -type f -name "$pattern" ! -path "./.venv/*" | while read -r file; do
+ echo "### $file ###" >> "$OUTPUT_FILE"
+ cat "$file" >> "$OUTPUT_FILE"
+ echo -e "\n\n" >> "$OUTPUT_FILE"
+ done
+}
+
+# Include Python files, excluding those in .venv
+include_files "*.py" "Python File"
+
+# Include YAML files only from the 'test' folder
+find ./test -type f -name "*.yaml" | while read -r yaml_file; do
+ echo "### $yaml_file ###" >> "$OUTPUT_FILE"
+ cat "$yaml_file" >> "$OUTPUT_FILE"
+ echo -e "\n\n" >> "$OUTPUT_FILE"
+done
+
+echo "Documentation prompt has been generated in $OUTPUT_FILE"
diff --git a/llm-docprompt.txt b/llm-docprompt.txt
new file mode 100644
index 0000000..b631d96
--- /dev/null
+++ b/llm-docprompt.txt
@@ -0,0 +1,2926 @@
+### Repository Structure ###
+.
+├── generate_llm_docprompt.sh
+├── LICENSE
+├── llm-docprompt.txt
+├── pyproject.toml
+├── rafale
+│ ├── caches.py
+│ ├── datapipe.py
+│ ├── __init__.py
+│ ├── main.py
+│ └── models
+│ ├── configurations.py
+│ ├── convert_hf_weights.py
+│ ├── decoder.py
+│ ├── decoding_strategies.py
+│ ├── encoder.py
+│ ├── __init__.py
+│ ├── model_utils.py
+│ ├── roberta_config
+│ │ └── tokenizer.json
+│ └── roberta.py
+├── README.md
+└── test
+ ├── pythia_tinystories.yaml
+ ├── test_bert.py
+ ├── test_pythia_generation.py
+ ├── test_pythia.py
+ └── test.yaml
+
+5 directories, 23 files
+
+
+
+### README.md ###
+
+
+## 💡Purpose
+
+Rafale provides an opinionated scaffolding for training transformers. It is solely built to be an efficient
+learning/research tool. It is **not** a fully fledged library for large scale training.
+
+It should be thought of as a starting point for research projects to bootstrap experiments on small LMs. The best way to
+use rafale is to simply fork it and build on top of it for your specific purposes.
+
+### Core dependencies
+
+Attempting to balance ergonomics and simplicity. This is meant to be easily hackable for research purposes.
+
+```
+torch, composer, datasets, tokenizers
+```
+
+## 🚀 Installation & Usage
+
+Set up with ```uv``` ([install uv](https://github.com/astral-sh/uv)).
+```sh
+$ git clone
+$ cd rafale
+$ uv venv
+$ . .venv/bin/activate
+$ uv pip install -r cuda-requirements.txt  # or cpu-requirements.txt
+$ uv pip install -e .
+```
+
+Launch a run with a configuration.
+
+```sh
+$ python rafale/main -c test/pythia_tinystories.yaml
+```
+
+What if I just want to prepare my dataset? ```DATA=1``` will run the data preparation and caching pipeline without
+launching the training run.
+
+```sh
+$ DATA=1 python rafale/main -c test/pythia_tinystories.yaml
+```
+
+What if I want to test my model to make sure that it's learning? ```DEBUG=1``` will run 10 epochs on a single training
+batch (the same batch is used for train and eval); the model should overfit it quickly if there are no bugs in the implementation.
+
+```sh
+$ DEBUG=1 python rafale/main -c test/pythia_tinystories.yaml
+```
+
+
+### 🔧 Under the hood
+
+The goal of rafale is to provide a single entry point for data preparation and training. You configure the model and
+the dataset, then call the training job.
+
+When calling a run, we first run the datapipeline. If the dataset has already been processed (tokenized, padded,
+chunked, etc.), it will be loaded from the cache (default location is ```~/.rafale_cache```).
+
+> [!NOTE]
+> #### Adding a new model
+> To add a new model, you need to write a new configuration in ```rafale/models/configurations.py``` and add its key to
+> ```model_config_dict``` in ```rafale/main.py```.
+>
+> Look at the ```ComposerLM``` wrapper class in ```rafale/models/decoder.py``` to check if all your building blocks are
+> there. Otherwise, you may need to modify it or write a new wrapper.
+>
+> #### Adding a new datapipeline
+>
+> If the dataset is hosted on huggingface, simply use git lfs to clone the repo locally or use the repo name as the
+> dataset path. Same goes for tokenizers since we use their tokenizer implementation.
+>
+> You will need to add a new datapipeline class in ```rafale/datapipes.py``` whose ```_prepare``` method handles all data
+> preprocessing (tokenization, chunking, truncation, etc.) **EXCEPT** padding. Padding will be performed by the datacollator.
+
+### 📕 Docs
+
+Paste the contents of ```llm-docprompt.txt``` into your favorite LLM and ask away.
+
+### 🦾 Supported models
+
+
+| Name | Implemented | Inference test | Training test |
+|:------------|:------------|:---------------|:--------------|
+| BERT | ✅ | | |
+| RoBERTa | ✅ | | |
+| Pythia | ✅ | ✅ | ✅ |
+| CLIP/SigLIP | ⏳ | | |
+
+
+## 🔮 Roadmap
+
+
+### v0.1 - initial release
+- [x] single entrypoint CLI
+- [ ] simple deploy/build
+  - [x] CPU macOS build - OK, uv run works with this
+  - [x] local linux machine - for now, uv for the venv + requirements.txt
+  - [ ] SLURM compute-canada - TBD
+    - NOTE: because uv still does not fully play well with PyTorch, a semi-manual setup is recommended
+- [ ] load weights from safetensors and include it in the config (BERT/RoBERTa and Pythia)
+ - [x] pythia
+ - [ ] BERT/RoBERTa (need to move from HF to safetensors)
+ - [ ] MLM
+ - [ ] Classification
+- [x] Pythia KV-cache implementation
+- [x] greedy generation
+- [ ] datapipes for CLM and MLM
+ - local dataloader for now
+ - [x] CLM tinystories
+ - [ ] MLM tinystories
+ - [ ] Imdb classification
+- [x] ```main.py``` handles both training and evaluation (together or separately)
+- [x] Mosaic Composer/Trainer
+ + [x] fp16
+ + [x] gradient clipping
+ + [x] gradient accumulation (automatically handled by composer)
+ + [x] building blocks are nn.Modules, specific models are ComposerModel classes with methods to load safetensor weights
+ automatically (keep in a single separate file for each model)
+ + [x] set DEBUG=1 for 1 batch sanity check before launching a run
+
+Datapipelines
+1. [x] tokenize
+2. [x] concat and split w/ block size (pad w/ collator)
+3. [x] save to disk {source}_{tokname}_bs{int}_len{int}
+4. [x] data_collator: *next* pad (if desired), label shift right and return torch tensor # HF: does this in the model...
+5. [x] test with model training
+6. [ ] tiny stories but for MLM also
+
+
+### path to v1.0
+cleanup and additional features
+- [ ] clean up ```tests``` for pythia and bert models on tinystories
+- [ ] move the testing in the notebook to a debug file in the modeling folder
+- [ ] optimizations : flash attn2, xformers layer_norm (triton) or RMSNorm, xformers fused_linear_layer
+- [ ] try out schedulefree, SOAP, and other optimizers
+- [ ] **layerwise decay** for fine-tuning (https://kozodoi.me/blog/20220329/discriminative-lr)
+- [ ] multimodality CLIP
+- [ ] integration with lm-eval-harness ([guide](https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage))
+
+
+
+## I am GPU-rich, what do I use?
+
+For large-scale experiments, other frameworks/libraries exist:
+- lingua (Facebookresearch)
+- torchtitan (Pytorch)
+- torchtune (Pytorch)
+- litGPT (LightningAI)
+- GPT-NeoX (EleutherAI)
+- nanotron (Huggingface)
+- llm-foundry (MosaicML)
+
+
+
+### ./rafale/models/decoder.py ###
+#!/usr/bin/env python
+from typing import Optional
+import torch
+
+from torch import nn
+from torch import Tensor
+
+import torch.nn.functional as F
+
+from torch.nn.functional import scaled_dot_product_attention
+
+from torchmetrics import Metric
+from torchmetrics.collections import MetricCollection
+
+from composer.metrics import LossMetric, LanguagePerplexity
+from composer.models import ComposerModel
+
+
+###############################################################################
+# simple implementation of GPT building blocks
+class NeoXRoPE(nn.Module):
+ @classmethod
+ def precompute_sin_cos_cache(cls, dim=None, seq_len=None, base=10000, device=None):
+ """Computes the cos and sin angles to be applied to the token vectors.
+
+        We begin by computing the rotation angles (freqs) for each dimension pair (P=R/2) across the whole sequence
+        length (L), giving a matrix of shape LP. These angles are duplicated along the feature dimension (shape LR),
+        and their cos and sin are cached and returned, each with shape [1, 1, L, R].
+
+ Args:
+ dim (int): number of features dimension per token to apply rotations to (d*rotary_pct)
+ seq_len (int): sequence length of the input (use the maximum sequence length)
+ base (int): default 10000
+
+ Returns:
+            (cos, sin) tensors, each of shape [1, 1, L, R]
+
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+ - R rotary dimensions (d*rotary_pct)
+ """
+
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
+ )
+ t = torch.arange(seq_len, dtype=torch.int64).type_as(inv_freq)
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos_cached = emb.cos()[None, None, :, :] # shape is [1, 1, L, R]
+ sin_cached = emb.sin()[None, None, :, :]
+
+ return cos_cached, sin_cached
+
+ @staticmethod
+ def rotate_half(x):
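+        # split the last (rotary) dimension in half and rotate the pairs: (x1, x2) -> (-x2, x1)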
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1)
+
+ @classmethod
+ def apply_rotary_pos_emb(cls, q_BHLR, k_BHLR, cos, sin):
+ """Applies the rotation to the input queries and key features."""
+
+ tensor_device = q_BHLR.get_device()
+
+ cos = cos.to(device=tensor_device)
+ sin = sin.to(device=tensor_device)
+
+ return (q_BHLR * cos) + (cls.rotate_half(q_BHLR) * sin), (k_BHLR * cos) + (
+ cls.rotate_half(k_BHLR) * sin
+ )
+
+ @classmethod
+ def apply_rotary_pos_emb_offset(cls, q, k, cos, sin, offset: int = 0):
+ """
+ q and k are shape: BHLR
+ cos, sin are shape: 11LR
+ """
+ cos, sin = (
+ cos[:, :, offset : q.shape[2] + offset, :],
+ sin[:, :, offset : q.shape[2] + offset, :],
+ )
+
+ tensor_device = q.get_device()
+
+ cos = cos.to(device=tensor_device)
+ sin = sin.to(device=tensor_device)
+
+ return (q * cos) + (cls.rotate_half(q) * sin), (k * cos) + (
+ cls.rotate_half(k) * sin
+ )
+
+
+class DecoderEmbedding(nn.Module):
+ """Simply an input projection of the tokens here, since rotary position encodings are used.
+
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.input_embeddings = nn.Embedding(
+ num_embeddings=config.vocab_size,
+ embedding_dim=config.embed_dim,
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(self, x_BL):
+ x_BLD = self.input_embeddings(x_BL)
+ return self.dropout(x_BLD)
+
+
+class DecoderAttentionRotary(nn.Module):
+ """
+ Attention with rotary position embedding.
+
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.head_dim = config.embed_dim // config.num_heads
+ self.num_heads = config.num_heads
+ self.embed_dim = config.embed_dim
+ self.rotary_ndims = int(self.head_dim * config.rotary_pct)
+
+ self.attention_bias = True # @TODO: set bias to True or False from config.
+ self.query_key_value = nn.Linear(config.embed_dim, 3 * config.embed_dim)
+ self.dense = nn.Linear(config.embed_dim, config.embed_dim)
+ self.dropout = nn.Dropout(p=config.attention_dropout)
+ self.dropout_p = config.attention_dropout
+ self.norm_factor = self.head_dim**-0.5
+
+ def _split_heads(self, tensor: Tensor):
+ """
+ Splits hidden dim into attn_head_size and num_attention_heads
+
+ # input tensor: [bs, seq_len, hidden_size]
+ # returns: [bs, num_attention_heads, seq_len, attn_head_size]
+ """
+ batch_size = tensor.size(0)
+
+ return (
+ tensor.view(batch_size, -1, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ def _merge_heads(self, tensor: Tensor):
+ """
+ input tensor: [bs. num_attention_heads, seq_len, attn_head_size]
+ returns: [bs, seq_len, hidden_size]
+ """
+ # tensor [bs, num_attention_heads, seq_len, attn_head_size]
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ # -> [bs, seq_len, num_attention_heads, attn_head_size]
+ tensor = tensor.view(
+ tensor.size(0),
+ tensor.size(1),
+ self.num_heads * self.head_dim,
+ ).contiguous()
+ # -> [bs, seq_len, hidden_size]
+ return tensor
+
+ def forward(self, x_BLD, freqs_cis):
+ if not self.training:
+ self.dropout_p = 0
+
+ bsz, seq_len, _ = x_BLD.size()
+
+ assert freqs_cis is not None
+
+ # projections
+ qkv = self.query_key_value(x_BLD)
+
+ # [batch, seq_len, (num_heads * 3 * head_size)]
+ # --> [batch, seq_len, num_heads, 3 * head_size]
+ new_qkv_shape = qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
+ qkv = qkv.view(*new_qkv_shape)
+
+ # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
+ q_BHLd = qkv[..., : self.head_dim].permute(0, 2, 1, 3)
+ k_BHLd = qkv[..., self.head_dim : 2 * self.head_dim].permute(0, 2, 1, 3)
+ v_BHLd = qkv[..., 2 * self.head_dim :].permute(0, 2, 1, 3)
+
+ # Slice the precomputed freqs_cis based on actual seq_len --> [1, 1, seq_len, R]
+ cos = freqs_cis[0][:, :, :seq_len, :]
+ sin = freqs_cis[1][:, :, :seq_len, :]
+
+ q_rot = q_BHLd[..., : self.rotary_ndims]
+ q_pass = q_BHLd[..., self.rotary_ndims :]
+ k_rot = k_BHLd[..., : self.rotary_ndims]
+ k_pass = k_BHLd[..., self.rotary_ndims :]
+
+ q_rot, k_rot = NeoXRoPE.apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
+
+ q_BHLd = torch.cat((q_rot, q_pass), dim=-1)
+ k_BHLd = torch.cat((k_rot, k_pass), dim=-1)
+
+ # compute attention
+ attn_out_BHLd = scaled_dot_product_attention(
+ q_BHLd,
+ k_BHLd,
+ v_BHLd,
+ is_causal=True,
+ scale=self.norm_factor,
+ dropout_p=self.dropout_p,
+ )
+
+ attn_out_BLD = self._merge_heads(attn_out_BHLd)
+
+ attn_out_BLD = self.dense(attn_out_BLD)
+
+ return attn_out_BLD
+
+
+class DecoderAttentionRotaryKVCache(DecoderAttentionRotary):
+ """implements the KV cache mechanism"""
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ def forward(self, x_BLD, freqs_cis, causal_mask=None, past_kv=None):
+ # A) figure out if we passed a cached KV
+ assert freqs_cis is not None
+ has_past_kv = past_kv is not None and past_kv[0].numel() > 0
+
+ if not self.training:
+ self.dropout_p = 0
+
+ bsz, seq_len, _ = x_BLD.size()
+ # B) if we have a cached KV, apply the offset to the sequence length for RoPE
+ if has_past_kv:
+ offset = past_kv[0].shape[
+ 2
+            ]  # we want the length here, our kv shape is BHLd
+ seq_len += offset
+ else:
+ offset = 0
+
+ # projections
+ qkv = self.query_key_value(x_BLD)
+
+ # [batch, seq_len, (num_heads * 3 * head_size)]
+ # --> [batch, seq_len, num_heads, 3 * head_size]
+ new_qkv_shape = qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
+ qkv = qkv.view(*new_qkv_shape)
+
+ # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
+ q_BHLd = qkv[..., : self.head_dim].permute(0, 2, 1, 3)
+ k_BHLd = qkv[..., self.head_dim : 2 * self.head_dim].permute(0, 2, 1, 3)
+ v_BHLd = qkv[..., 2 * self.head_dim :].permute(0, 2, 1, 3)
+
+ # Slice the precomputed freqs_cis based on actual seq_len --> [1, 1, seq_len, R]
+ cos = freqs_cis[0][:, :, :seq_len, :]
+ sin = freqs_cis[1][:, :, :seq_len, :]
+
+ q_rot = q_BHLd[..., : self.rotary_ndims]
+ q_pass = q_BHLd[..., self.rotary_ndims :]
+ k_rot = k_BHLd[..., : self.rotary_ndims]
+ k_pass = k_BHLd[..., self.rotary_ndims :]
+
+ q_rot, k_rot = NeoXRoPE.apply_rotary_pos_emb_offset(
+ q_rot, k_rot, cos, sin, offset=offset
+ )
+
+ q_BHLd = torch.cat((q_rot, q_pass), dim=-1)
+ k_BHLd = torch.cat((k_rot, k_pass), dim=-1)
+
+ # C) before scaled_dot_product_attention we are going to
+ # Cache QKV values
+ if has_past_kv:
+ past_key, past_value = past_kv
+
+ k_BHLd = torch.cat((past_key.type_as(k_BHLd), k_BHLd), dim=2)
+ v_BHLd = torch.cat((past_value.type_as(v_BHLd), v_BHLd), dim=2)
+
+ # kv_cache = torch.stack((k_BHLd, v_BHLd))
+ kv_cache = (k_BHLd, v_BHLd) # let's keep them as a list of tuples
+ # ####################################################################
+
+ # print(f"key device {k_BHLd.get_device()}")
+ # print(f"query device {q_BHLd.get_device()}")
+ # print(f"value device {v_BHLd.get_device()}")
+ # print(f"causal mask device {causal_mask.get_device()}")
+
+ tensor_device = q_BHLd.get_device()
+ if causal_mask.get_device() != tensor_device:
+ causal_mask = causal_mask.to(device=tensor_device)
+
+ # compute attention here
+ attn_out_BHLd = scaled_dot_product_attention(
+ q_BHLd,
+ k_BHLd,
+ v_BHLd,
+ is_causal=False,
+ attn_mask=causal_mask,
+ scale=self.norm_factor,
+ dropout_p=self.dropout_p,
+        )  # even with rectangular q/k matrices, scaled_dot_product_attention applies the explicit attn_mask we pass,
+        # which gives the left-aligned causal masking we need when new tokens attend to cached keys.
+
+ attn_out_BLD = self._merge_heads(attn_out_BHLd)
+
+ attn_out_BLD = self.dense(attn_out_BLD)
+
+ return attn_out_BLD, kv_cache
+
+
+# ^^^^^ #######################################################################
+
+
+class DecoderFeedForward(nn.Module):
+ """
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.ff_in = nn.Linear(config.embed_dim, config.ff_dim)
+ self.gelu = nn.GELU()
+ self.ff_out = nn.Linear(config.ff_dim, config.embed_dim)
+
+ def forward(self, x_BLD):
+ x_BLF = self.ff_in(x_BLD)
+ x_BLF = self.gelu(x_BLF)
+ out_BLD = self.ff_out(x_BLF)
+ return out_BLD
+
+
+class DecoderBlock(nn.Module):
+ """A single trasnformer decoder block/layer.
+
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.attention = DecoderAttentionRotary(config)
+ self.feed_forward = DecoderFeedForward(config)
+ self.ffn_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
+ self.attention_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
+
+ def forward(self, x_BLD, freqs_cis, parallel_residual=True, use_cache=True):
+ assert freqs_cis is not None
+
+ if parallel_residual:
+ out_BLD = (
+ x_BLD
+ + self.attention(self.attention_norm(x_BLD), freqs_cis)
+ + self.feed_forward(self.ffn_norm(x_BLD))
+ )
+ else:
+ h_BLD = x_BLD + self.attention(self.attention_norm(x_BLD), freqs_cis)
+ out_BLD = h_BLD + self.feed_forward(self.ffn_norm(h_BLD))
+
+ return out_BLD
+
+
+# handle KV cache state
+class DecoderBlockKVcache(DecoderBlock):
+ """A single trasnformer decoder block/layer.
+
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.attention = DecoderAttentionRotaryKVCache(config)
+ self.feed_forward = DecoderFeedForward(config)
+ self.ffn_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
+ self.attention_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
+
+ def forward(
+ self, x_BLD, freqs_cis, causal_mask, layer_kv_cache, parallel_residual=True
+ ):
+ assert freqs_cis is not None
+
+ if parallel_residual:
+ attn_out_BLD, layer_kv_cache = self.attention(
+ self.attention_norm(x_BLD), freqs_cis, causal_mask, layer_kv_cache
+ )
+ out_BLD = x_BLD + attn_out_BLD + self.feed_forward(self.ffn_norm(x_BLD))
+
+ else:
+ attn_out_BLD, layer_kv_cache = self.attention(
+ self.attention_norm(x_BLD), freqs_cis, layer_kv_cache
+ )
+ h_BLD = x_BLD + attn_out_BLD
+ out_BLD = h_BLD + self.feed_forward(self.ffn_norm(h_BLD))
+
+ return out_BLD, layer_kv_cache
+
+
+class DecoderWrapper(nn.Module):
+ """Full model wrapper for causal language modelling."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.token_embeddings = DecoderEmbedding(config)
+
+ if self.config.use_cache:
+ self.layers = nn.ModuleList(
+ DecoderBlockKVcache(config) for _ in range(config.num_blocks)
+ )
+ else:
+ self.layers = nn.ModuleList(
+ DecoderBlock(config) for _ in range(config.num_blocks)
+ )
+
+ self.final_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
+ self.output = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
+
+ self.freqs_cis: Optional[Tensor] = None
+
+ self.max_batch_size = -1
+ self.max_seq_length = config.max_pos_embedding
+
+ self.rotary_pct = 0.25
+
+ self.vocab_size = config.vocab_size
+
+ def setup_caches(self):
+ head_dim = self.config.embed_dim // self.config.num_heads
+ dtype = self.output.weight.dtype
+
+ head_size = self.config.embed_dim // self.config.num_heads
+ rotary_ndims = int(head_size * self.rotary_pct)
+ self.cos, self.sin = NeoXRoPE.precompute_sin_cos_cache(
+ dim=rotary_ndims,
+ seq_len=self.config.max_pos_embedding,
+ )
+
+ self.freqs_cis = (self.cos, self.sin)
+
+ def _generate_causal_mask(self, cache_length: int, new_length: int) -> torch.Tensor:
+ """
+ Creates a causal mask for autoregressive attention with KV caching.
+
+ Args:
+ cache_length (int): Number of cached tokens (K).
+ new_length (int): Number of new tokens to generate (T).
+
+        Returns:
+            torch.Tensor: A T x (K + T) boolean mask where:
+                - Each new token can attend to all cached tokens.
+                - Each new token can attend to new tokens up to and including its own position.
+                - Attention to future new tokens is masked out.
+ """
+
+ causal_mask = torch.tril(torch.ones(new_length, new_length)).bool()
+
+ if cache_length > 0:
+ cache_tokens_mask = torch.ones((new_length, cache_length)).bool()
+ causal_mask = torch.cat((cache_tokens_mask, causal_mask), dim=1)
+
+ return causal_mask
+
+ def forward(self, batch: Tensor, past_kv_cache=None):
+ # if self.freqs_cis is None or self.causal_mask is None:
+ if self.freqs_cis is None:
+ self.setup_caches() # Caches must be initialized first
+
+ freqs_cis = self.freqs_cis
+
+ idx = batch["input_ids"]
+ num_new_tokens = idx.size(1)
+
+ x = self.token_embeddings(idx)
+
+ if self.config.use_cache and past_kv_cache is None:
+ past_kv_cache = [None] * self.config.num_blocks
+
+        if past_kv_cache is not None and past_kv_cache[0] is not None:
+ num_cache_tokens = past_kv_cache[0][0].size(2)
+ else:
+ num_cache_tokens = 0
+ causal_mask = self._generate_causal_mask(num_cache_tokens, num_new_tokens)
+
+ kv_cache_list = []
+
+ for i, layer in enumerate(self.layers):
+ if self.config.use_cache:
+ x, layer_kv_cache = layer(x, freqs_cis, causal_mask, past_kv_cache[i])
+ kv_cache_list.append(layer_kv_cache)
+ else:
+ x = layer(x, freqs_cis)
+ x = self.final_norm(x)
+ logits = self.output(x)
+
+ return logits, kv_cache_list
+
+
+class ComposerLM(ComposerModel):
+ """wrapper with nice properties for simple training and evaluation"""
+
+ def __init__(self, config):
+ "docstring"
+ super().__init__()
+ self.model = DecoderWrapper(config)
+ self.ce_loss = nn.CrossEntropyLoss()
+ self.train_metrics = MetricCollection(
+ [LossMetric(self.ce_loss), LanguagePerplexity()]
+ )
+ self.eval_metrics = MetricCollection(
+ [LossMetric(self.ce_loss), LanguagePerplexity()]
+ )
+
+ def forward(self, batch): # batch is the output of the dataloader
+ """batch is a dict with "input_ids" key, model also takes past_kv"""
+ # specify how batches are passed through the model
+ return self.model(batch)
+
+ def eval_forward(self, batch, outputs=False):
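+        # Composer may pass the outputs of forward() here; in either case unwrap the (logits, kv_cache_list) tuple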
+ if outputs:
+ if type(outputs) is tuple:
+ outputs, _ = outputs
+ return outputs
+
+ outputs = self.model(batch)
+ if type(outputs) is tuple:
+ outputs, _ = outputs
+
+ return outputs
+
+ def update_metric(self, batch, outputs, metric) -> None:
+ targets = batch["labels"]
+ metric.update(outputs.view(-1, self.model.vocab_size), targets.view(-1))
+
+ def get_metrics(self, is_train=False) -> dict[str, Metric]:
+ # defines which metrics to use in each phase of training
+ return self.train_metrics if is_train else self.eval_metrics
+
+ def loss(self, outputs, batch):
+ targets = batch["labels"]
+
+ if type(outputs) is tuple:
+ outputs, _ = outputs
+
+ return self.ce_loss(outputs.view(-1, self.model.vocab_size), targets.view(-1))
+
+
+
+### ./rafale/models/convert_hf_weights.py ###
+import argparse
+import torch
+import torch.nn.functional as F
+
+# from encoder import EncoderWrapper
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+
+# helpers
+def list_params(model):
+ for k, v in model.items():
+ print(k)
+
+
+def get_name_params(model):
+ all_params = {}
+ for name, param in model.named_parameters():
+ all_params[name] = param
+
+ return all_params
+
+
+def convert_bert_params_dict(target, source):
+ conversion = [
+ ("bert.embeddings", "embedding_layer"),
+ ("bert.encoder.layer", "blocks"),
+ ("attention.self", "attention.self_attn"),
+ ("attention.output.dense", "attention.out"),
+ ("attention.output.LayerNorm", "add_norm_1.ln"),
+ ("intermediate.dense", "ff.ff_in"),
+ ("output.dense", "ff.ff_out"),
+ ("output.LayerNorm", "add_norm_2.ln"),
+ ("cls.predictions.bias", "mlm_head.bias"),
+ ("cls.predictions.transform.dense", "mlm_head.dense"),
+ ("cls.predictions.transform.LayerNorm", "mlm_head.ln"),
+ ("cls.predictions.decoder", "mlm_head.decoder"),
+ ]
+
+ source_parameters = source.state_dict()
+
+ updated_parameters = {}
+ for k, v in source_parameters.items():
+ for hf_term, my_term in conversion:
+ if hf_term in k:
+ k = k.replace(hf_term, my_term)
+
+ updated_parameters[k] = v
+
+ # return updated_parameters
+ # assert new_dict.keys() == target.keys(), was Ok but different for the state dict
+
+ # here we transfer weights for all layers
+ target.load_state_dict(updated_parameters, strict=True)
+
+ return target
+
+
+def convert_pythia_params_dict(target, source):
+ """
+ Source safetensors dict to our rafale model class.
+ """
+
+ # not needed for our implementation
+ unused = ["rotary_emb.inv_freq", "masked_bias", "attention.bias"]
+ for k, v in list(source.items()):
+ if True in [x in k for x in unused]:
+ del source[k]
+
+ conversion = [
+ ("gpt_neox.embed_in", "token_embeddings.input_embeddings"),
+ ("gpt_neox.layers", "layers"),
+ ("input_layernorm", "attention_norm"),
+ ("post_attention_layernorm", "ffn_norm"),
+ ("mlp", "feed_forward"),
+ ("dense_4h_to_h", "ff_out"),
+ ("dense_h_to_4h", "ff_in"),
+ ("embed_out", "output"),
+ ("gpt_neox.final_layer_norm", "final_norm"),
+ ]
+
+ updated_parameters = {}
+ for k, v in source.items():
+ for hf_term, my_term in conversion:
+ if hf_term in k:
+ k = k.replace(hf_term, my_term)
+
+ updated_parameters[k] = v
+
+ # here we transfer weights for all layers
+ target.load_state_dict(updated_parameters, strict=True)
+
+ return target
+
+
+def convert_roberta_params_dict(target, source):
+ conversion = [
+ ("roberta.embeddings", "embedding_layer"),
+ ("roberta.encoder.layer", "blocks"),
+ ("attention.self", "attention.self_attn"),
+ ("attention.output.dense", "attention.out"),
+ ("attention.output.LayerNorm", "add_norm_1.ln"),
+ ("intermediate.dense", "ff.ff_in"),
+ ("output.dense", "ff.ff_out"),
+ ("output.LayerNorm", "add_norm_2.ln"),
+ ("lm_head.bias", "mlm_head.bias"),
+ ("lm_head.dense", "mlm_head.dense"),
+ ("lm_head.layer_norm", "mlm_head.ln"),
+ ("lm_head.decoder", "mlm_head.decoder"),
+ ]
+
+ source_parameters = source.state_dict()
+
+ updated_parameters = {}
+ for k, v in source_parameters.items():
+ for hf_term, my_term in conversion:
+ if hf_term in k:
+ k = k.replace(hf_term, my_term)
+
+ updated_parameters[k] = v
+
+ # return updated_parameters
+ # assert new_dict.keys() == target.keys(), was Ok but different for the state dict
+
+ # here we transfer weights for all layers
+ target.load_state_dict(updated_parameters, strict=True)
+
+ return target
+
+
+def test_conversion():
+ """Convert the weights modify the dictionary for the blank and choice tokens, export the resulting tokenizer and
+ model checkpoints into assets
+ """
+ torch.manual_seed(0)
+
+ tokenizer = AutoTokenizer.from_pretrained(
+ "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
+ )
+ input_tokens = tokenizer("paris is the [MASK] of France.", return_tensors="pt")
+ hf_model = AutoModelForMaskedLM.from_pretrained(
+ "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
+ )
+ extra_tokens = [
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ "",
+ ]
+
+ # check if the tokens are already in the vocabulary
+ extra_tokens = set(extra_tokens) - set(tokenizer.vocab.keys())
+
+ # add the tokens to the tokenizer vocabulary
+ tokenizer.add_tokens(list(extra_tokens))
+
+ # add new, random embeddings for the new tokens
+ hf_model.resize_token_embeddings(len(tokenizer))
+
+ cfg = Config
+ slam = SlamEncoder(cfg)
+ slam = convert_bert_params_dict(slam, hf_model)
+
+ def get_test_preds(model, sample, tokenizer, hf=False):
+ model.eval()
+ output = model(**sample)
+
+ if hf:
+ output = output.logits
+
+ probs = F.softmax(output, dim=-1)
+ tokens = torch.argmax(probs, dim=-1)
+ sequence = tokenizer.batch_decode(tokens)
+ print(tokens)
+
+ return (
+ output,
+ tokens,
+ probs,
+ sequence,
+ )
+
+ _, _, _, hf_tokens = get_test_preds(hf_model, input_tokens, tokenizer, hf=True)
+
+ _, _, _, my_tokens = get_test_preds(slam, input_tokens, tokenizer)
+ # the output logits are different, however the output tokens predicted seem to be almost always the same
+ print(f"my implementation: {my_tokens}\n hf implementation: {hf_tokens}")
+
+
+def add_new_tokens(model, tokenizer, token_list):
+ token_list = set(token_list) - set(tokenizer.vocab.keys())
+
+ # add the tokens to the tokenizer vocabulary
+ tokenizer.add_tokens(list(token_list))
+
+ # add new, random embeddings for the new tokens
+ model.resize_token_embeddings(len(tokenizer))
+
+ return model, tokenizer
+
+
+def convert_and_save(dump_dir):
+ """Convert the weights modify the dictionary for the blank and choice tokens, export the resulting tokenizer and
+ model checkpoints into assets
+ """
+ tokenizer = AutoTokenizer.from_pretrained(
+ "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
+ )
+ hf_model = AutoModelForMaskedLM.from_pretrained(
+ "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext"
+ )
+
+ cfg = Config
+ slam = SlamEncoder(cfg)
+ slam = convert_bert_params_dict(slam, hf_model)
+
+ # save model
+ tokenizer.save_pretrained(f"{dump_dir}/tokenizer/")
+ torch.save(slam.state_dict(), f"{dump_dir}/model.pt")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Convert bert weights and tokenizer for Slamminnnnn"
+ )
+
+ # Adding arguments
+ parser.add_argument(
+ "--dump_dir",
+ required=True,
+ type=str,
+ help="Where to put the files",
+ )
+
+ # Parse arguments
+ args = parser.parse_args()
+ convert_and_save(args.dump_dir)
+
+
+if __name__ == "__main__":
+ main()
+
+
+
+### ./rafale/models/decoding_strategies.py ###
+import torch
+import torch.nn.functional as F
+
+
+def repeat_ngram(input_ids, ngram_max_length=4):
+ """
+    Clean up output by checking for repeated n-grams of length up to ngram_max_length.
+ If it finds a repeated n-gram, it returns the n-gram; otherwise, it returns 0.
+
+ Args:
+ input_ids: Tensor of shape (seq_length,), the sequence of input token IDs.
+ ngram_max_length: The maximum length of the n-gram to check for repetition.
+
+ Returns:
+ A list of tokens representing the repeated n-gram if found, otherwise 0.
+ """
+ ngram_found = False
+ ngram_tokens = None
+
+ # Create a set to store seen n-grams
+ seen_ngrams = set()
+
+ print(input_ids)
+ # Iterate through possible n-gram lengths
+ for n in range(1, ngram_max_length + 1):
+ for i in range(len(input_ids) - n + 1):
+ ngram = tuple(input_ids[i : i + n].tolist())
+ if ngram in seen_ngrams:
+ ngram_found = True
+ ngram_tokens = list(ngram)
+ break
+ seen_ngrams.add(ngram)
+ if ngram_found:
+ break
+
+ if ngram_found:
+ return ngram_tokens
+ else:
+ return 0
+
+
+def greedy_decode(model, batch, max_length, eos_token_id, check_repeat_ngrams=True):
+ """
+ Implements greedy decoding for the rafale transformer model.
+
+ Args:
+ model: The decoder model (e.g., DecoderWrapper).
+ batch: Dictionary containing input_ids of shape (batch_size, seq_length), the input prompt.
+ max_length: The maximum length of the generated sequence.
+ eos_token_id: The ID of the end-of-sequence token.
+
+ Returns:
+ Dictionary containing input_ids of shape (batch_size, max_length) with the generated tokens.
+ """
+ batch_size = batch["input_ids"].size(0)
+ if batch_size != 1:
+ raise ValueError(
+ "greedy_decode currently only supports batch_size=1. Provided batch_size: {}".format(
+ batch_size
+ )
+ )
+
+ input_seq_len = batch["input_ids"].size(1)
+ kv_cache_list = None
+
+ # Generate tokens until max_length or eos_token is generated
+ for _ in range(max_length - input_seq_len):
+ # Forward pass through the model
+ outputs, kv_cache_list = model(batch, kv_cache_list)
+ logits = outputs[:, -1, :] # Get the logits for the last generated token
+
+ # Greedily select the token with the highest probability
+ next_token = torch.argmax(logits, dim=-1).unsqueeze(-1) # Shape: (1, 1)
+
+ # Append the predicted token to the generated sequence
+ batch["input_ids"] = torch.cat((batch["input_ids"], next_token), dim=1)
+
+ # Check for repeated n-grams and stop if detected
+ if check_repeat_ngrams:
+ repeated_ngram = repeat_ngram(
+ batch["input_ids"].squeeze(), ngram_max_length=4
+ )
+ if repeated_ngram != 0:
+ print(repeated_ngram)
+ break
+
+ # Check if the sequence has generated the eos_token_id
+ if next_token.item() == eos_token_id:
+ break
+
+ return batch
+
+
+
+### ./rafale/models/model_utils.py ###
+import torch
+import torch.nn.functional as F
+
+
+def get_tokens_from_logits(logits, tokenizer=None):
+ """
+    Return the predicted tokens for all of the inputs.
+ """
+ # Apply softmax to convert logits to probabilities
+ probabilities = F.softmax(logits, dim=-1)
+
+ # Get the predicted token IDs
+ predicted_token_ids = torch.argmax(probabilities, dim=-1)
+
+ predicted_tokens = [
+ tokenizer.convert_ids_to_tokens(seq.numpy())
+ for seq in torch.unbind(predicted_token_ids, dim=0)
+ ]
+ return predicted_tokens
+
+
+
+### ./rafale/models/encoder.py ###
+from dataclasses import dataclass
+
+import torch
+import torch.nn.functional as F
+
+from torch import nn
+from torch.nn.functional import scaled_dot_product_attention
+
+
+###############################################################################
+# SIMPLE BERT-like BUILDING BLOCKS #
+###############################################################################
+
+
+# @TODO :: Refactor, improve documentation and add tensor dimension keys for the names
+
+
+class Embedding(nn.Module):
+ """Embeddings
+
+    In addition to the word embeddings, BERT uses learned absolute position embeddings. We also have token type
+    embeddings for one of the BERT pretraining objectives (next sentence prediction).
+
+ Tensor dimension keys:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+ """
+
+ def __init__(
+ self,
+ vocab_size=None,
+ hidden_size=None,
+ pad_token_id=None,
+ max_sequence_length=512,
+ num_token_type=None, # technically should be 2, in HF they use type_vocab_size
+ layer_norm_eps=None,
+ hidden_dropout_prob=None,
+ ):
+ super().__init__()
+ # nn.Embedding is just a lookup table,
+ self.word_embeddings = nn.Embedding(
+ num_embeddings=vocab_size,
+ embedding_dim=hidden_size,
+ padding_idx=pad_token_id,
+ )
+ self.position_embeddings = nn.Embedding(
+ num_embeddings=max_sequence_length,
+ embedding_dim=hidden_size,
+ padding_idx=pad_token_id, # ROBERTA only?
+ )
+ self.token_type_embeddings = nn.Embedding(
+ num_token_type, hidden_size
+ ) # NOTE :: these are actually the segment embeddings
+ # from the original BERT paper... maybe rename?
+ self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+ self.dropout = nn.Dropout(hidden_dropout_prob)
+
+ # not considered as model parameter
+ self.register_buffer(
+ "position_ids",
+ torch.arange(max_sequence_length).expand((1, -1)),
+ persistent=False,
+ )
+
+ def forward(self, input_ids, token_type_ids):
+ seq_length = input_ids.size(1)
+
+ position_ids = torch.index_select(
+ self.position_ids, 1, torch.arange(seq_length)
+ )
+ position_ids = position_ids.expand_as(input_ids)
+
+ # we assume absolute positional encoding here like in the original BERT and sum everything up
+ W = self.word_embeddings(input_ids)
+ P = self.position_embeddings(position_ids)
+ T = self.token_type_embeddings(token_type_ids)
+
+ E = W + P + T
+ E = self.LayerNorm(E)
+ E = self.dropout(E)
+
+ return E
+
+
+class EncoderSelfAttention(nn.Module):
+ """Bidirectional multi-head self attention.
+
+ Tensor dimension keys:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+ """
+
+ def __init__(self, n_heads, embed_dim, dropout_p=0.1, fast_attn=False):
+ super().__init__()
+ self.dropout_p = dropout_p
+ self.fast_attn = fast_attn
+ assert embed_dim % n_heads == 0
+
+ # We assume d_v always equals d_k
+ self.head_dim = embed_dim // n_heads
+ self.n_heads = n_heads
+ self.embed_dim = embed_dim
+ self.all_head_size = n_heads * self.head_dim
+
+ # get linear projections
+ self.query = nn.Linear(embed_dim, embed_dim)
+ self.key = nn.Linear(embed_dim, embed_dim)
+ self.value = nn.Linear(embed_dim, embed_dim)
+
+ def forward(self, q, k, v):
+ """"""
+ batch_size = q.size(0)
+ if not self.training:
+ self.dropout_p = 0
+
+ # check transformation again here....
+ q = (
+ self.query(q)
+ .view(batch_size, -1, self.n_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ k = (
+ self.key(k)
+ .view(batch_size, -1, self.n_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+ v = (
+ self.value(v)
+ .view(batch_size, -1, self.n_heads, self.head_dim)
+ .transpose(1, 2)
+ )
+
+ attn_output = scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.dropout_p,
+ )
+
+ # concatenate heads and put through final linear layer
+ attn_output = (
+ attn_output.transpose(1, 2)
+ .contiguous()
+ .view(batch_size, -1, self.embed_dim)
+ )
+
+ return attn_output
+
+
+class AttentionModule(nn.Module):
+ """the actual block with the output projections"""
+
+ # output
+ def __init__(self, n_heads, embed_dim, dropout_p=None, fast_attn=False):
+ super().__init__()
+ self.self_attn = EncoderSelfAttention(
+ n_heads, embed_dim, dropout_p=dropout_p, fast_attn=fast_attn
+ )
+ self.out = nn.Linear(embed_dim, embed_dim)
+
+ def forward(self, x):
+ attn_output = self.self_attn(x, x, x)
+ out = self.out(attn_output)
+ return out
+
+
+class FeedForward(nn.Module):
+ def __init__(self, embed_dim, ff_dim):
+ super().__init__()
+ self.ff_in = nn.Linear(embed_dim, ff_dim)
+ self.gelu = nn.GELU()
+ self.ff_out = nn.Linear(ff_dim, embed_dim)
+
+ def forward(self, x):
+ x = self.ff_in(x)
+ x = self.gelu(x)
+ x = self.ff_out(x)
+ return x
+
+
+class AddNorm(nn.Module):
+ def __init__(self, embed_dim, eps=None, dropout_p=None):
+ super().__init__()
+ self.ln = nn.LayerNorm(embed_dim, eps=eps)
+ self.dropout = nn.Dropout(dropout_p)
+
+ def forward(self, x, residual):
+ x = self.dropout(x) # @TODO :: make sure this should be here...
+ x = self.ln(x + residual)
+
+ return x
+
+
+class EncoderBlock(nn.Module):
+ def __init__(self, embed_dim, n_heads, ff_dim, eps=None, dropout_p=None):
+ super().__init__()
+
+ self.attention = AttentionModule(
+ n_heads=n_heads, embed_dim=embed_dim, dropout_p=dropout_p
+ )
+ self.add_norm_1 = AddNorm(embed_dim, eps=eps, dropout_p=dropout_p)
+ self.ff = FeedForward(embed_dim, ff_dim=ff_dim)
+ self.add_norm_2 = AddNorm(embed_dim, eps=eps, dropout_p=dropout_p)
+
+ def forward(self, x):
+ residual_1 = x
+ x = self.attention(x)
+ x = self.add_norm_1(x, residual_1)
+
+ residual_2 = x
+ x = self.ff(x)
+ x = self.add_norm_2(x, residual_2)
+
+ return x
+
+
+class MLMHead(nn.Module):
+ def __init__(self, embed_dim, vocab_size, eps=None):
+ super().__init__()
+ self.dense = nn.Linear(embed_dim, embed_dim)
+ self.gelu = nn.GELU()
+ self.ln = nn.LayerNorm(embed_dim, eps=eps)
+
+ self.decoder = nn.Linear(embed_dim, vocab_size, bias=False)
+ self.bias = nn.Parameter(torch.zeros(vocab_size))
+
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
+ self.decoder.bias = self.bias
+
+ def forward(self, x):
+ x = self.dense(x)
+ x = self.gelu(x)
+ x = self.ln(x)
+ x = self.decoder(x)
+
+ return x
+
+
+class EncoderWrapper(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.embedding_layer = Embedding(
+ vocab_size=config.vocab_size,
+ hidden_size=config.embed_dim,
+ pad_token_id=config.pad_token_id,
+ max_sequence_length=config.max_pos_embedding,
+ num_token_type=config.num_token_type, # technically should be 2, in HF they use type_vocab_size
+ layer_norm_eps=config.layer_norm_eps,
+ hidden_dropout_prob=config.hidden_dropout,
+ )
+
+ self.blocks = nn.ModuleList()
+ for i in range(config.num_blocks):
+ self.blocks.append(
+ EncoderBlock(
+ config.embed_dim,
+ config.num_heads,
+ config.ff_dim,
+ eps=config.layer_norm_eps,
+ dropout_p=config.hidden_dropout,
+ )
+ )
+
+ self.mlm_head = MLMHead(
+ embed_dim=config.embed_dim,
+ eps=config.layer_norm_eps,
+ vocab_size=config.vocab_size,
+ )
+
+ # Tie the weights
+ self.mlm_head.decoder.weight = self.embedding_layer.word_embeddings.weight
+ # @NOTE :: bias are tied too with the HF model
+
+ # no bias for MLM head (?), let's keep it since the HF implementation keeps it as well
+ # self.mlm_head.mlm[-1].bias = None
+ def forward(self, **kwargs):
+ input_ids = kwargs["input_ids"]
+ token_type_ids = kwargs["token_type_ids"]
+
+ x = self.embedding_layer(input_ids, token_type_ids)
+ # x = self.encoder_blocks(x)
+ for block in self.blocks:
+ x = block(x)
+ x = self.mlm_head(x)
+
+ return x
+
+ def compute_loss(self, logits, labels):
+ """ """
+ ce_loss = nn.CrossEntropyLoss(ignore_index=-100)
+
+ # Flatten the logits and labels
+ logits = logits.view(
+ -1, self.config.vocab_size
+ ) # Adjust vocab_size as per your config
+ labels = labels.view(-1)
+
+ # Compute and return the loss
+ return ce_loss(logits, labels)
+
+
+
+### ./rafale/models/__init__.py ###
+
+
+
+### ./rafale/models/configurations.py ###
+import requests
+import os
+
+from dataclasses import dataclass
+
+from safetensors import safe_open
+
+from ..caches import MODEL_CACHE_DIR
+
+"""
+to simplify model loading add a configuration for the pre-trained weight loading using safetensors instead of loading
+the full model.
+> then save to a folder named ".pretrained/" in this directory
+"""
+
+
+def download_file(url, save_path):
+ """
+ Downloads a file from the specified URL to the given save path.
+
+ :param url: The URL of the file to download.
+ :param save_path: The local path where the file will be saved.
+ """
+ try:
+ # Make sure the directory exists
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+ with requests.get(url, stream=True) as response:
+ response.raise_for_status() # Check for HTTP errors
+ total_size = int(response.headers.get("content-length", 0))
+ block_size = 1024 # 1 Kilobyte
+ progress = 0
+
+ with open(save_path, "wb") as file:
+ for data in response.iter_content(block_size):
+ file.write(data)
+ progress += len(data)
+ print(
+ f"Downloaded {progress} of {total_size} bytes ({(progress/total_size)*100:.2f}%)",
+ end="\r",
+ )
+
+ print(f"\nDownload completed successfully! File saved to: {save_path}")
+
+ except requests.exceptions.HTTPError as http_err:
+ print(f"HTTP error occurred: {http_err}") # HTTP error
+ except Exception as err:
+ print(f"An error occurred: {err}") # Other errors
+
+
+def load_safetensors(rafale_model, model_config):
+ """Transfer the pretrained safetensors to rafale model"""
+ tensors = {}
+
+ safetensors_path = os.path.join(MODEL_CACHE_DIR, model_config.name + ".safetensors")
+
+ if os.path.isfile(safetensors_path):
+ pass
+ else:
+ download_file(model_config.safetensors_url, safetensors_path)
+
+ with safe_open(safetensors_path, framework="pt") as f:
+ for k in f.keys():
+ tensors[k] = f.get_tensor(k)
+
+ rafale_model = model_config.convert_params_dict(rafale_model, tensors)
+
+ return rafale_model
+
+
+@dataclass
+class BertConfig:
+ embed_dim: int = 768
+ vocab_size: int = 30522
+ attention_dropout: float = 0.1
+ hidden_dropout: float = 0.1
+ num_heads: int = 12
+ ff_dim: int = 3072
+ max_pos_embedding: int = 512
+ layer_norm_eps: float = 1e-12
+ num_blocks: int = 12
+ pad_token_id: int = 0
+ num_token_type: int = 2
+ fast_attention: bool = (
+ False # use xformers (todo: add FlashAttention2), NOT IMPLEMENTED*
+ )
+
+
+@dataclass
+class BertTinyConfig:
+ embed_dim: int = 128
+    vocab_size: int = 30522  # common usage would be 30522 + num_extra_tokens
+ attention_dropout: float = 0.1
+ hidden_dropout: float = 0.1
+ num_heads: int = 2
+ ff_dim: int = 512
+ max_pos_embedding: int = 512
+ layer_norm_eps: float = 1e-12
+ num_blocks: int = 2
+ pad_token_id: int = 0
+ num_token_type: int = 2
+ fast_attention: bool = False # use xformers (todo: add FlashAttention2)
+
+
+@dataclass
+class RobertaConfig:
+ embed_dim: int = 768
+ vocab_size: int = 50265
+ attention_dropout: float = 0.1
+ hidden_dropout: float = 0.1
+ num_heads: int = 12
+ ff_dim: int = 3072
+ max_pos_embedding: int = 514
+ layer_norm_eps: float = 1e-05
+ num_blocks: int = 12
+ pad_token_id: int = 1
+ num_token_type: int = 1
+ bos_token_id: int = 0
+ eos_token_id: int = 2
+ fast_attention: bool = False
+
+
+@dataclass
+class Pythia14MConfig:
+ name: str = "pythia14m"
+ safetensors_url: str = (
+ "https://huggingface.co/EleutherAI/pythia-14m/resolve/main/model.safetensors"
+ )
+
+ embed_dim: int = 128
+ num_heads: int = 4
+ ff_dim: int = 512
+ hidden_act: str = "gelu"
+ max_pos_embedding: int = 2048
+ vocab_size: int = 50304
+ use_cache: bool = True
+
+ parallel_residual: bool = True
+
+ attention_dropout: float = 0.1
+ hidden_dropout: float = 0.1
+
+ layer_norm_eps: float = 1e-05
+ num_blocks: int = 6
+
+ bos_token_id: int = 0
+ eos_token_id: int = 0
+ fast_attention: bool = False
+
+ rotary_emb_base: int = 10000
+ rotary_pct: float = 0.25
+
+ tie_word_embeddings: bool = False
+
+ @classmethod
+ def convert_params_dict(cls, target, source):
+ """
+ Source safetensors dict to our rafale model class.
+ """
+ # not needed for our implementation
+ unused = ["rotary_emb.inv_freq", "masked_bias", "attention.bias"]
+ for k, v in list(source.items()):
+ if True in [x in k for x in unused]:
+ del source[k]
+
+ conversion = [
+ ("gpt_neox.embed_in", "token_embeddings.input_embeddings"),
+ ("gpt_neox.layers", "layers"),
+ ("input_layernorm", "attention_norm"),
+ ("post_attention_layernorm", "ffn_norm"),
+ ("mlp", "feed_forward"),
+ ("dense_4h_to_h", "ff_out"),
+ ("dense_h_to_4h", "ff_in"),
+ ("embed_out", "output"),
+ ("gpt_neox.final_layer_norm", "final_norm"),
+ ]
+
+ updated_parameters = {}
+ for k, v in source.items():
+ for hf_term, my_term in conversion:
+ if hf_term in k:
+ k = k.replace(hf_term, my_term)
+
+ updated_parameters[k] = v
+
+ # here we transfer weights for all layers
+ target.load_state_dict(updated_parameters, strict=True)
+
+ return target
+
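+# A hypothetical example of the renaming performed by convert_params_dict above (key names follow
+# the HF GPT-NeoX checkpoint layout):
+#   "gpt_neox.layers.0.input_layernorm.weight"  -> "layers.0.attention_norm.weight"
+#   "gpt_neox.layers.0.mlp.dense_4h_to_h.bias"  -> "layers.0.feed_forward.ff_out.bias"
+#   "embed_out.weight"                          -> "output.weight"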
+
+
+### ./rafale/models/roberta.py ###
+import torch
+
+from encoder import EncoderWrapper
+from dataclasses import dataclass
+
+# from composer.models import ComposerModel
+
+
+@dataclass
+class RobertaConfig:
+ embed_dim: int = 768
+ vocab_size: int = 50265
+ attention_dropout: float = 0.1
+ hidden_dropout: float = 0.1
+ num_heads: int = 12
+ ff_dim: int = 3072
+ max_pos_embedding: int = 514
+ layer_norm_eps: float = 1e-05
+ num_blocks: int = 12
+ pad_token_id: int = 1
+ num_token_type: int = 1
+ bos_token_id: int = 0
+ eos_token_id: int = 2
+ fast_attention: bool = False # use xformers (todo: add FlashAttention2)
+
+
+class RobertaMLM(EncoderWrapper):
+ def __init__(self, config):
+ super().__init__(config)
+ self.embedding_layer.forward = self.roberta_embedding_forward
+ # monkey patched forward method to fix the position_id embedding without changing the original encoder embedding class.
+
+    def mlm_hook(self):
+        """TBD"""
+        pass
+
+ def roberta_embedding_forward(self, input_ids, token_type_ids):
+ position_ids = self.create_position_ids_from_input_ids(input_ids, 1)
+
+ # we assume absolute positional encoding here like in the original BERT and sum everything up
+ W = self.embedding_layer.word_embeddings(input_ids)
+ P = self.embedding_layer.position_embeddings(position_ids)
+ T = self.embedding_layer.token_type_embeddings(token_type_ids)
+
+ E = W + P + T
+ E = self.embedding_layer.LayerNorm(E)
+ E = self.embedding_layer.dropout(E)
+
+ return E
+
+ def create_position_ids_from_input_ids(
+ self, input_ids, padding_idx, past_key_values_length=0
+ ):
+        """
+        MAX NOTE: the HuggingFace implementation creates the position_ids differently for RoBERTa than for BERT;
+        without this the model breaks. It simply modifies how the array is cast.
+
+        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
+        are ignored. This is modified from fairseq's `utils.make_positions`.
+
+        Args:
+            input_ids: torch.Tensor
+
+        Returns: torch.Tensor
+        """
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+ mask = input_ids.ne(padding_idx).int()
+ incremental_indices = (
+ torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length
+ ) * mask
+ return incremental_indices.long() + padding_idx
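+    # Worked example (padding_idx=1, as passed above): for input_ids = [[0, 31414, 232, 2, 1, 1]]
+    # the padding mask is [1, 1, 1, 1, 0, 0], the masked cumulative sum is [1, 2, 3, 4, 0, 0],
+    # and the returned position_ids are [[2, 3, 4, 5, 1, 1]] (padding positions collapse to padding_idx).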
+
+
+
+### ./rafale/datapipe.py ###
+import os
+import warnings
+from abc import ABC, abstractmethod
+
+from datasets import load_dataset, DatasetDict
+
+from tokenizers import Tokenizer
+from tokenizers.processors import TemplateProcessing
+
+import torch
+from torch.utils.data import DataLoader
+from torch.nn.utils.rnn import pad_sequence
+
+from rafale.caches import DATA_CACHE_DIR
+
+
+class DataPipeline(ABC):
+ """
+ Base class
+
+ A data pipeline is initiated with a path and some configurations parameters:
+
+ Args:
+ - dataset_path : local path to the dataset
+ - collator : function to be applied when sending a batch through the dataloader
+
+
+ Datasets should be saved using the following format:
+ /.json
+
+ Returns:
+ """
+
+ def __init__(self, **kwargs):
+ self.name: str = kwargs["name"]
+ self.tokenizer_name: str = kwargs["tokenizer_name"]
+ self.is_prepared: bool = kwargs["is_prepared"]
+ self.dataset_path: str = os.path.expanduser(kwargs["dataset_path"])
+
+ self.max_sequence_length: int = kwargs["max_sequence_length"]
+ self.train_batch_size: int = kwargs["train_batch_size"]
+ self.eval_batch_size: int = kwargs["eval_batch_size"]
+ self.pad_inputs: bool = kwargs["pad_inputs"]
+ self.pad_token_id: int = kwargs["pad_token_id"]
+ self.input_id_key: str = kwargs["input_id_key"]
+ self.shuffle_train: bool = kwargs["shuffle_train"]
+
+ self.num_processes: int = kwargs["num_processes"]
+ self.tokenizer = Tokenizer.from_pretrained(kwargs["tokenizer_path"])
+
+ self.data_collator = None
+
+ self.use_cached = False
+
+ def _load(self):
+ # either load directly from disk or from an hf dataset repo
+ try:
+ self.dataset = DatasetDict.load_from_disk(self.dataset_path)
+ except:
+ try:
+ self.dataset = load_dataset(self.dataset_path)
+ except:
+ raise OSError(
+ f"Wrong dataset file and/or path configuration! path: {self.dataset_path}"
+ )
+
+ @abstractmethod
+ def _prepare(self):
+ """Perform all data preprocessing here: tokenization, chunking, truncation, etc. (EXCEPT padding!). Padding will be performed by
+ the datacollator.
+ """
+ pass
+
+ def _check_path(self):
+ """make sure that the dataset has not already been parsed at location"""
+ output_path = f"{self.name}_{self.tokenizer_name}_bs{self.train_batch_size}_len{self.max_sequence_length}"
+
+ assert DATA_CACHE_DIR[-1] == "/"
+
+ save_path = os.path.abspath(os.path.join(DATA_CACHE_DIR, output_path))
+
+ if os.path.isdir(DATA_CACHE_DIR):
+ pass
+ else:
+ os.makedirs(DATA_CACHE_DIR)
+
+ if os.path.isdir(save_path):
+ warnings.warn(
+ f"Dataset already exists at location:\n\t {save_path} \n ABORTING PREPARATION, USING CACHED DATASET!"
+ )
+
+ self.is_prepared = True
+ self.use_cached = True
+
+ return save_path
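+    # Example: with the values from test/pythia_tinystories.yaml the cached dataset lands at
+    # something like ~/.rafale_cache/data/tinystories_neox_bs1024_len512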
+
+ def __call__(self):
+ # returns a or multiple dataloaders
+ self._load()
+
+ self.dataloaders = {}
+
+ if not self.is_prepared:
+ cache_dataset_path = self._check_path()
+
+ if type(self.dataset) == DatasetDict:
+ for subset in self.dataset:
+ if subset == "train":
+ shuffle = self.shuffle_train
+ batch_size = self.train_batch_size
+ else:
+ shuffle = False
+ batch_size = self.eval_batch_size
+
+ # if the data is not ready to be passed to the dataloader
+ if not self.is_prepared:
+ print(f"preparing subset {subset}")
+ self.dataset[subset] = self._prepare(self.dataset[subset])
+
+ if self.use_cached:
+ self.dataset = DatasetDict.load_from_disk(cache_dataset_path)
+
+ self.dataloaders[subset] = DataLoader(
+ self.dataset[subset],
+ collate_fn=self.data_collator,
+ batch_size=batch_size,
+ )
+ print(f"✅ Dataloader ready for subset {subset}.")
+
+ if not self.is_prepared:
+ self.dataset.save_to_disk(cache_dataset_path)
+ print(f"✅ Saved prepared dataset at {cache_dataset_path}.")
+ else:
+ raise TypeError(
+ f"self.dataset is type {type(self.dataset)}, but should be DatasetDict."
+ )
+
+ return self.dataloaders
+
+
+class InferenceDatapipeline:
+ def __init__(self, tokenizer_path):
+ self.tokenizer = Tokenizer.from_pretrained(tokenizer_path)
+
+ def _tokenizer_templating(self, tokenizer, add_eos=True):
+ if add_eos:
+ tokenizer.post_processor = TemplateProcessing(
+ single="<|endoftext|> $A",
+ special_tokens=[
+ ("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
+ ],
+ )
+
+ return tokenizer
+
+ def __call__(self, input_str, use_template: bool = True):
+ """
+ tokenize input_str
+ convert to torch tensor (batch_size=1)
+ add the endoftext token
+ """
+ if use_template:
+ self.tokenizer = self._tokenizer_templating(self.tokenizer)
+
+ tokenized_inputs = {
+ "input_ids": torch.LongTensor(
+ self.tokenizer.encode(input_str).ids
+ ).unsqueeze(dim=0)
+ }
+
+ return tokenized_inputs
+
+    def ids_to_str(self, tensor):
+        return self.tokenizer.decode(tensor.squeeze().detach().numpy())
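+    # Sketch of a round trip (see test/test_pythia_generation.py):
+    #   ifdp = InferenceDatapipeline("EleutherAI/pythia-14m")
+    #   batch = ifdp("Once upon a time,")           # {"input_ids": LongTensor of shape (1, seq_len)}
+    #   text = ifdp.ids_to_str(batch["input_ids"])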
+
+
+class CausalCollator:
+ def __init__(self, pad_token_id: int, input_id_key: str, pad_inputs: bool):
+ self.pad_token_id = pad_token_id
+ self.input_id_key = input_id_key
+ self.pad_inputs = pad_inputs
+
+ def __call__(self, features):
+ # Extract the input IDs from the batch
+ input_ids = [torch.tensor(example[self.input_id_key]) for example in features]
+
+ # Pad the inputs if required
+ if self.pad_inputs:
+ input_ids = pad_sequence(
+ input_ids, batch_first=True, padding_value=self.pad_token_id
+ )
+
+ # Set the last token of each label sequence to pad_token_id to ignore loss for the last prediction
+ # shift left the ids
+ labels = [
+ torch.cat([ids[1:], torch.tensor([self.pad_token_id])]) for ids in input_ids
+ ]
+ labels = torch.stack(labels, dim=0)
+
+ return {
+ "input_ids": input_ids,
+ "labels": labels,
+ }
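+    # Example: for a single sequence [5, 6, 7, 8] with pad_token_id=-100 the collator returns
+    #   input_ids = [[5, 6, 7, 8]]
+    #   labels    = [[6, 7, 8, -100]]
+    # i.e. labels are the inputs shifted left by one, and the final position is ignored by the loss.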
+
+
+class MLMCollator:
+ def __init__(
+ self,
+ mask_p: float = 0.15,
+ whole_word_mask: bool = False,
+ mask_span: bool = False,
+ pad_token_id: int = -100,
+ input_id_key: str = "input_ids",
+ pad_inputs: bool = True,
+ ):
+ """masks some % of tokens for MLM objective"""
+ pass
+
+
+class DefaultCollator:
+ def __init__(
+ self,
+ pad_token_id: int = -100,
+ input_id_key: str = "input_ids",
+ pad_inputs: bool = True,
+ ):
+ """for task data where labels are already set"""
+ pass
+
+
+class TinyStoriesCausalNeoX(DataPipeline):
+    """Sample data pipeline for the TinyStories dataset.
+
+    The dataset is prepared for causal language modelling using the GPT-NeoX tokenizer (EleutherAI).
+
+    Usage:
+        ts_dict = {
+            "name": "tinystories_testing",
+            "tokenizer_name": "neox",
+            "is_prepared": False,
+            "input_id_key": "input_ids",
+            "train_batch_size": 16,
+            "eval_batch_size": 16,
+            "shuffle_train": False,
+            "num_processes": 1,
+            "dataset_path": "~/code/data/micro_tinystories",
+            "tokenizer_path": "EleutherAI/pythia-14m",
+            "max_sequence_length": 128,
+            "pad_token_id": -100,
+            "pad_inputs": True,
+        }
+        ts_dpipe = TinyStoriesCausalNeoX(**ts_dict)
+        dataloaders = ts_dpipe()
+
+    Returns:
+        dict of torch DataLoaders keyed by dataset subset.
+    """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.data_collator = CausalCollator(
+ pad_token_id=self.pad_token_id,
+ input_id_key=self.input_id_key,
+ pad_inputs=self.pad_inputs,
+ )
+
+ # TODO: figure out if they really only use endoftext for everything...
+ def _tokenizer_templating(self, tokenizer, add_eos=True):
+ if add_eos:
+ tokenizer.post_processor = TemplateProcessing(
+ single="$A <|endoftext|>",
+ special_tokens=[
+ ("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
+ ],
+ )
+
+ return tokenizer
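+    # With add_eos=True the post-processor appends the end-of-text token, so encoding "Hello"
+    # yields the ids for "Hello <|endoftext|>" and every story is terminated before chunking.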
+
+ def _tokenize(self, example, tokenizer, key="text"):
+ return {self.input_id_key: tokenizer.encode(example[key]).ids}
+
+ def _group_and_chunk(self, examples, key="input_ids", block_size=None, pad=False):
+ concatenated_tokens = sum(examples[key], [])
+ total_length = len(concatenated_tokens)
+
+ if total_length >= block_size:
+ total_length = (total_length // block_size) * block_size
+
+ result = {
+ "input_ids": [
+ concatenated_tokens[i : i + block_size]
+ for i in range(0, total_length, block_size)
+ ]
+ }
+
+ return result
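+    # Example: with block_size=4, two tokenized examples [1, 2, 3, 4, 5] and [6, 7, 8, 9, 10] are
+    # concatenated into 10 tokens and regrouped as [[1, 2, 3, 4], [5, 6, 7, 8]]; the trailing
+    # 2 tokens are dropped since only full blocks are kept.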
+
+ def _prepare(self, data):
+        """Tokenize the raw text, then group and chunk it into fixed-length blocks."""
+
+ # apply functions above to dataset
+ self.tokenizer = self._tokenizer_templating(self.tokenizer)
+
+ data = data.map(
+ lambda example: self._tokenize(example, self.tokenizer),
+ remove_columns=data.column_names,
+ num_proc=self.num_processes,
+ )
+
+ data = data.map(
+ lambda example: self._group_and_chunk(
+ example, block_size=self.max_sequence_length
+ ),
+ batched=True,
+ num_proc=self.num_processes,
+ )
+
+ return data
+
+
+class TinyStoriesMLM(DataPipeline):
+ """ """
+
+ pass
+
+
+class ImdbCLS(DataPipeline):
+ pass
+
+
+'''
+class ImdbClsPipe(DataPipeline):
+ """A pipeline for the imdb dataset for """
+
+ def __init__(self, **kwargs):
+ self.path = kwargs["path"]
+ self.name = kwargs["name"] # name
+ self.is_tokenized = kwargs["is_tokenized"]
+
+ self.padding = kwargs["padding"] # "max_length"
+ self.max_sequence_length = kwargs["max_sequence_length"] # 512
+
+ self.shuffle_train = kwargs["shuffle_train"] # False
+ self.batch_size = kwargs["batch_size"]
+ self.tokenizer = kwargs["tokenizer"]
+ self.collator_fn = DataCollatorWithPadding(
+ tokenizer=self.tokenizer,
+ padding=self.padding,
+ max_length=self.max_sequence_length,
+ return_tensors='pt'
+ )
+
+ self.data = datasets.DatasetDict.load_from_disk(self.path)
+
+ def _post_tokenize(self, dataset):
+ return dataset.remove_columns(["text"])
+
+ def _tokenize(
+ self,
+ examples,
+ ):
+ source_tokenized = self.tokenizer(
+ examples["text"],
+ truncation=True,
+ max_length=self.max_sequence_length,
+ return_token_type_ids=True,
+ return_tensors='pt'
+ )
+
+ batch = {k: v for k, v in source_tokenized.items()}
+
+ return batch
+
+ def _map_tokenize(self, subsets=None):
+ # tokenize
+ print("tokenizing training")
+ self.data["train"] = self.data["train"].map(self._tokenize, batched=True)
+ self.data["train"] = self.data["train"].remove_columns("text")
+
+ print("tokenizing test")
+ self.data["test"] = self.data["test"].map(self._tokenize, batched=True)
+ self.data["test"] = self.data["test"].remove_columns("text")
+
+
+ def _save_tokenized(self):
+ # preprocess
+ self.path += "_tokenized"
+ print(f"saving tokenized data to disk at location : {self.path}")
+ assert os.path.isdir(self.path) == False
+ self.data.save_to_disk(self.path)
+
+ def string_to_tokens(self, input_str):
+ # tokenize
+ tensor = self._tokenize({"text": input_str})
+
+ return self.collator_fn(tensor)
+
+ def __call__(self, subsets = ["train", "test"]):
+ dataloaders = {}
+
+ if self.is_tokenized:
+ print("data tokenized")
+ else:
+ self._map_tokenize()
+ self._save_tokenized()
+
+ for _set in subsets:
+ if _set == "train":
+ shuffle = self.shuffle_train
+ else:
+ shuffle = False
+
+ dataloaders[_set] = DataLoader(
+ self.data[_set],
+ collate_fn=self.collator_fn, # DEBUG
+ batch_size=self.batch_size,
+ shuffle=shuffle, # DEBUG
+ )
+ return dataloaders
+'''
+
+
+
+### ./rafale/main.py ###
+import os
+import argparse
+import yaml
+import time
+from datetime import datetime
+
+import torch
+import torch.utils.data
+
+import composer
+from composer import Trainer, Time
+from composer.loggers import InMemoryLogger, WandBLogger, FileLogger
+from composer.algorithms import GradientClipping
+from composer.optim.scheduler import (
+ CosineAnnealingWithWarmupScheduler,
+ CosineAnnealingScheduler,
+ LinearScheduler,
+ LinearWithWarmupScheduler,
+)
+
+from rafale.models.decoder import ComposerLM
+from rafale.models.encoder import EncoderWrapper
+from rafale.models.configurations import (
+ load_safetensors,
+ Pythia14MConfig,
+ BertConfig,
+ RobertaConfig,
+)
+
+from rafale.caches import CHECKPOINT_CACHE_DIR
+from rafale.datapipe import TinyStoriesCausalNeoX
+
+
+ENV_VARS = {key: value for key, value in os.environ.items()}
+
+parser = argparse.ArgumentParser(description="launch a training run")
+
+parser.add_argument(
+ "-c",
+ "--training_config",
+ type=str,
+ help="path to yaml run configuration file",
+ required=True,
+)
+args = parser.parse_args()
+
+model_config_dict = {
+ "pythia14m": Pythia14MConfig,
+ "bert": BertConfig,
+ "roberta": RobertaConfig,
+}
+
+data_pipeline_dict = {
+ "tinystories_neox": TinyStoriesCausalNeoX,
+}
+
+
+def main():
+ # CONFIG ##################################################################
+ with open(args.training_config, "r") as f:
+ config = yaml.safe_load(f)
+
+ run_name = config["run"]["name"]
+ run_n_epochs = config["run"]["n_epochs"]
+ run_seed = config["run"]["seed"]
+ run_save_interval = config["run"]["save_interval"]
+
+ run_clip_type = config["run"]["clip_type"]
+ run_clip_value = float(config["run"]["clip_value"])
+
+ device_bs = config["run"]["device_bs"] # int or "auto"
+
+ # schedule
+ run_schedule_type = config["run"]["schedule"]
+ run_max_lr = float(config["run"]["max_lr"]) # learning rate
+ run_warmup_pct = float(config["run"]["warmup_pct"])
+ if run_schedule_type == "cosine-warmup":
+ run_scheduler = CosineAnnealingWithWarmupScheduler(
+ t_warmup=Time(run_warmup_pct, "dur"), alpha_f=0.1
+ )
+ else:
+        raise TypeError(
+            f"Schedule '{run_schedule_type}' is not valid! Supports: cosine-warmup (linear, cosine, and linear-warmup planned)."
+        )
+
+ model_config_key = config["model"]["config"]
+ model_type = config["model"]["type"]
+ model_use_pretrained = config["model"]["use_pretrained"]
+
+ data_pipeline_key = config["data"]["pipeline"]
+ dataset_config = config["data"]["config"]
+ print(dataset_config)
+
+ # DATALOADERS #############################################################
+ data_pipeline = data_pipeline_dict[data_pipeline_key](**dataset_config)
+ dataloaders = data_pipeline()
+ if "DATA" in ENV_VARS.keys() and ENV_VARS["DATA"] == "1":
+ print("Data processing complete, exiting...")
+ return 0 # just do data preprocessing if we pass DATA=1
+
+ # MODEL #######################################################################
+ model_config = model_config_dict[model_config_key]
+
+ if model_type == "decoder":
+ rafale_model = ComposerLM(model_config)
+ elif model_type == "encoder":
+ rafale_model = EncoderWrapper(model_config)
+ else:
+ raise TypeError(
+ f"Model type {model_type} is not valid! Supports: encoder, decoder."
+ )
+
+ if model_use_pretrained:
+ rafale_model.model = load_safetensors(rafale_model.model, model_config)
+
+ # LOGGING #################################################################
+ # mem_logger = InMemoryLogger()
+ # @TODO :: add some logging options in the yaml
+ wandb_logger = WandBLogger(project="rafale", name=run_name)
+ # file_logger = FileLogger(filename=f"{run_name}-{time}".txt)
+
+    # GRADIENT CLIPPING #######################################################
+    # clip_type can be 'norm', 'adaptive', or 'value'; both values come from the run config
+    gradient_clip = GradientClipping(
+        clipping_type=run_clip_type, clipping_threshold=run_clip_value
+    )
+
+ # DEVICES #################################################################
+ device = "gpu" if torch.cuda.is_available() else "cpu" # select the device
+ if device == "gpu":
+ run_precision = "amp_fp16"
+ else:
+ run_precision = "fp32"
+
+ # DEBUG RUN ###############################################################
+ if "DEBUG" in ENV_VARS.keys() and ENV_VARS["DEBUG"] == "1":
+ from torch.utils.data import Subset, DataLoader, default_collate
+ from datasets import Dataset
+
+ # single batch, same for train and test 10 epochs
+ bs = 4
+ debug_batch = next(iter(dataloaders["train"]))
+ debug_batch = Dataset.from_dict(
+ {k: v[:bs] for k, v in debug_batch.items()}
+ ).with_format("torch")
+ debug_batch = DataLoader(
+ debug_batch,
+ batch_size=bs,
+ shuffle=False,
+ collate_fn=default_collate,
+ )
+
+ trainer = Trainer(
+ model=rafale_model,
+ seed=run_seed,
+ train_dataloader=debug_batch,
+ eval_dataloader=debug_batch,
+ optimizers=torch.optim.AdamW(rafale_model.parameters(), lr=1e-4),
+ max_duration=10, # num epochs
+ device=device,
+ loggers=None,
+ precision=run_precision,
+ progress_bar=True,
+ )
+
+        trainer.fit()
+
+        return 0
+
+ # TRAIN ###################################################################
+    # the training subset must be keyed "train"; the validation subset may be named anything (test, val,
+    # validation, eval, etc.) as long as there is exactly one other subset, which is used for evaluation
+ dl_keys = list(dataloaders.keys())
+ assert "train" in dl_keys
+ dl_keys.remove("train")
+ assert len(dl_keys) == 1
+ eval_subset_key = dl_keys[0]
+
+ # get datetime for checkpoint, directories are created by composer
+ now = datetime.now()
+    formatted_date = now.strftime("%dd%mm%Yy_%Hh%Mm")  # format as DDdMMmYYYYy_HHhMMm
+ checkpoint_folder = os.path.abspath(
+ os.path.join(CHECKPOINT_CACHE_DIR, f"{run_name}-{formatted_date}/")
+ )
+
+ trainer = Trainer(
+ model=rafale_model,
+ seed=run_seed,
+ train_dataloader=dataloaders["train"],
+ eval_dataloader=dataloaders[eval_subset_key],
+ optimizers=torch.optim.AdamW(rafale_model.parameters(), lr=run_max_lr),
+ max_duration=run_n_epochs, # num epochs
+ eval_interval="50ba", # default is 1ep !
+ device_train_microbatch_size=device_bs, # will handle gradient accumulation automatically
+ device=device,
+ loggers=[wandb_logger],
+ precision=run_precision,
+ progress_bar=True,
+ schedulers=run_scheduler,
+ algorithms=[gradient_clip],
+ save_folder=checkpoint_folder,
+ save_latest_filename="latest",
+ save_interval=run_save_interval,
+ )
+
+ # launch
+ trainer.fit()
+ print(f"🍻 TRAINING COMPLETE\n💾 CHECKPOINTS SAVED AT LOCATION: {checkpoint_folder}")
+
+
+if __name__ == "__main__":
+ main()
+
+
+
+### ./rafale/caches.py ###
+import os
+
+DATA_CACHE_DIR = os.path.expanduser("~/.rafale_cache/data/")
+MODEL_CACHE_DIR = os.path.expanduser("~/.rafale_cache/models/")
+CHECKPOINT_CACHE_DIR = os.path.expanduser("~/.rafale_cache/checkpoints/")
+
+
+
+### ./rafale/__init__.py ###
+
+
+
+### ./test/test_pythia.py ###
+import torch
+import numpy as np
+
+from safetensors import safe_open
+
+from rafale.models.decoder import DecoderWrapper
+from rafale.models.configurations import Pythia14MConfig, load_safetensors
+
+from transformers import AutoTokenizer, GPTNeoXForCausalLM
+
+
+def test_layer_and_outputs(rafale_model, hf_model, tokenizer, layer=0, tol=1e-05):
+ """
+    # @NOTE :: this currently only evaluates with the KV cache enabled; write the test to also run this function without the KV cache
+ """
+ hf_activation = {}
+ hf_input_activation = {}
+ rafale_activation = {}
+ rafale_input_activation = {}
+
+ # tuple of shape num_layers, 2 (keys, values), tensor BHLd
+ # make a fake kv-cache of length 4
+ kv_cache = []
+ n_layers = 6
+ cache_len = 4
+ n_heads = 4
+ head_dim = 32
+    for i in range(n_layers):
+        k = torch.randn(1, n_heads, cache_len, head_dim)
+        v = torch.randn(1, n_heads, cache_len, head_dim)
+        kv_cache.append((k, v))
+
+ def get_hf_activation(name):
+ def hook(model, input, output):
+ hf_activation[name] = output.detach()
+
+ return hook
+
+ def get_hf_input_activation(name):
+ def hook(model, _input, output):
+ hf_input_activation[name] = _input[0].detach()
+
+ return hook
+
+ def get_rafale_activation(name):
+ def hook(model, input, output):
+ rafale_activation[name] = output.detach()
+
+ return hook
+
+ def get_rafale_input_activation(name):
+ def hook(model, _input, output):
+ rafale_input_activation[name] = _input[0].detach()
+
+ return hook
+
+ # embeddings
+ rafale_model.token_embeddings.register_forward_hook(
+ get_rafale_activation("input_embeddings")
+ )
+ hf_model.gpt_neox.embed_in.register_forward_hook(
+ get_hf_activation("input_embeddings")
+ )
+
+ # input layernorm
+ rafale_model.layers[layer].attention_norm.register_forward_hook(
+ get_rafale_activation("attn_norm")
+ )
+ hf_model.gpt_neox.layers[layer].input_layernorm.register_forward_hook(
+ get_hf_activation("attn_norm")
+ )
+
+ rafale_model.layers[layer].attention_norm.register_forward_hook(
+ get_rafale_input_activation("input_attn_norm")
+ )
+ hf_model.gpt_neox.layers[layer].input_layernorm.register_forward_hook(
+ get_hf_input_activation("input_attn_norm")
+ )
+
+ # attention projection query_key_values (pre RoPE)
+ rafale_model.layers[layer].attention.query_key_value.register_forward_hook(
+ get_rafale_activation("attn_inproj")
+ )
+
+ hf_model.gpt_neox.layers[layer].attention.query_key_value.register_forward_hook(
+ get_hf_activation("attn_inproj")
+ )
+
+ # out proj
+ rafale_model.layers[layer].attention.dense.register_forward_hook(
+ get_rafale_activation("attn_dense")
+ )
+
+ hf_model.gpt_neox.layers[layer].attention.dense.register_forward_hook(
+ get_hf_activation("attn_dense")
+ )
+
+ # INPUT check before attention dense* (if this fails then RoPE is probably the problem...)
+ rafale_model.layers[layer].attention.dense.register_forward_hook(
+ get_rafale_input_activation("attn_dense")
+ )
+ hf_model.gpt_neox.layers[layer].attention.dense.register_forward_hook(
+ get_hf_input_activation("attn_dense")
+ )
+
+ # feed forward out
+ rafale_model.layers[layer].feed_forward.ff_out.register_forward_hook(
+ get_rafale_activation("ffout")
+ )
+ hf_model.gpt_neox.layers[layer].mlp.dense_4h_to_h.register_forward_hook(
+ get_hf_activation("ffout")
+ )
+
+ # hf_model.gpt_neox.layers[0].attention(tensor)
+
+ input_str = "Hello World from pythia!"
+ tokens = tokenizer(input_str, return_tensors="pt")
+
+ hf_model.eval()
+ rafale_model.eval()
+
+ with torch.no_grad():
+ # hf_out = hf_model(tokens["input_ids"])["logits"].detach().numpy()
+
+ hf_out = hf_model(tokens["input_ids"], use_cache=True, past_key_values=kv_cache)
+ hf_out = hf_out["logits"].detach().numpy()
+
+ # rafale_out = rafale_model(tokens)[0].detach().numpy()
+
+ rafale_out = rafale_model(tokens, past_kv_cache=kv_cache)[0].detach().numpy()
+
+ print(f"Dropout p should be 0: {rafale_model.layers[layer].attention.dropout_p}")
+ print(f"Testing layer {layer}")
+ # EMBEDDING ###################################################################
+ try:
+ np.testing.assert_allclose(
+ rafale_activation["input_embeddings"].numpy(),
+ hf_activation["input_embeddings"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+
+ print(f"✅ embeddings OK!")
+ except:
+ print("⚠️ Embedding difference!")
+
+ # PRE-ATTENTION NORM ######################################################
+ try:
+ np.testing.assert_allclose(
+ rafale_input_activation["input_attn_norm"].numpy(),
+ hf_input_activation["input_attn_norm"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ INPUTS of pre-attention norm OK!")
+ except:
+ print("⚠️ INPUTS pre-attention norm difference")
+
+ try:
+ np.testing.assert_allclose(
+ rafale_activation["attn_norm"].numpy(),
+ hf_activation["attn_norm"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ pre-attention norm OK!")
+ except:
+ print("⚠️ pre-attention norm difference")
+
+ # LINEAR PROJECTION FOR ATTENTION ######################################################
+
+ try:
+ np.testing.assert_allclose(
+ rafale_activation["attn_inproj"].numpy(),
+ hf_activation["attn_inproj"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ attention in-projection OK!")
+
+ except:
+ print(f"⚠️ attention in-projection difference")
+
+ # INPUTS OF ATTENTION DENSE LAYER ########################################
+ try:
+ np.testing.assert_allclose(
+ rafale_input_activation["attn_dense"].numpy(),
+ hf_input_activation["attn_dense"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ inputs of attention dense OK")
+
+ except:
+ r = rafale_input_activation["attn_dense"].numpy()
+ h = hf_input_activation["attn_dense"].numpy()
+
+ print(r.shape)
+ print(h.shape)
+ print("⚠️ inputs of attention dense difference")
+
+ np.testing.assert_allclose(
+ rafale_input_activation["attn_dense"].numpy(),
+ hf_input_activation["attn_dense"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+
+ # OUTPUT OF ATTENTION DENSE LAYER ########################################
+ try:
+ np.testing.assert_allclose(
+ rafale_activation["attn_dense"].numpy(),
+ hf_activation["attn_dense"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ attention out dense OK!")
+ except:
+ print("⚠️ attention out dense difference")
+
+ try:
+ np.testing.assert_allclose(
+ rafale_activation["ffout"].numpy(),
+ hf_activation["ffout"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ feedforward out dense OK!")
+ except:
+ print("⚠️ ff out dense difference")
+
+ # Final model output ########################################
+ try:
+ np.testing.assert_allclose(rafale_out, hf_out, rtol=tol, atol=tol)
+ print(f"🎉 Model outputs match reference implementation!")
+ except:
+ print("❌ Model outputs do not match")
+
+
+def main():
+ """ """
+ torch.manual_seed(0)
+ np.random.seed(0)
+
+ hf_pythia = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-14m")
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
+
+ rafale_pythia = DecoderWrapper(Pythia14MConfig) # check forward pass OK
+ rafale_pythia = load_safetensors(rafale_pythia, Pythia14MConfig)
+
+ test_layer_and_outputs(rafale_pythia, hf_pythia, tokenizer)
+
+
+if __name__ == "__main__":
+ main()
+
+
+
+### ./test/test_bert.py ###
+#!/usr/bin/env python
+import os
+
+from datapipe import WikiMLMPipe
+from encoder import EncoderWrapper, BertConfig
+from roberta import RobertaConfig, RobertaMLM
+from convert_hf_weights import convert_bert_params_dict, convert_roberta_params_dict
+
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+import torch
+
+import numpy as np
+import random
+
+torch.set_deterministic_debug_mode(1)
+torch.use_deterministic_algorithms(True)
+
+SEED = 42
+torch.manual_seed(SEED)
+np.random.seed(SEED)
+random.seed(SEED)
+
+# dump modeling debugging code here...
+# @TODO :: call from cli, layer by layer check, add proper logging
+
+
+# roberta
+def test_roberta():
+ # roberta base
+ roberta_cfg = RobertaConfig()
+ r_roberta = RobertaMLM(roberta_cfg)
+
+ roberta_tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
+ hf_roberta = AutoModelForMaskedLM.from_pretrained("FacebookAI/roberta-base")
+
+ r_roberta = convert_roberta_params_dict(r_roberta, hf_roberta)
+
+ args = {
+ "path": "~/code/data/enwiki1m",
+ "truncation": True,
+ "max_sequence_length": 128,
+ "shuffle_train": False,
+ "batch_size": 1,
+ "padding": "max_length",
+ "tokenizer": roberta_tokenizer,
+ }
+
+ wikipipe = WikiMLMPipe(**args)
+ dl = wikipipe()
+
+ batch = next(iter(dl["train"]))
+
+ with torch.no_grad():
+ r_roberta.eval()
+ hf_roberta.eval()
+ r_output = r_roberta(**batch)
+ r_output2 = r_roberta(**batch)
+ hf_output = hf_roberta(**batch)
+ hf_output2 = hf_roberta(**batch)
+ np.testing.assert_allclose(
+ hf_output.logits.detach().numpy(),
+ hf_output2.logits.detach().numpy(),
+ atol=5e-4,
+ rtol=5e-4,
+ )
+ print("hf deterministic")
+ np.testing.assert_allclose(
+ r_output.detach().numpy(), r_output2.detach().numpy(), atol=5e-4, rtol=5e-4
+ )
+ print("rafale deterministic")
+ np.testing.assert_allclose(
+ r_output.detach().numpy(),
+ hf_output.logits.detach().numpy(),
+ atol=5e-4,
+ rtol=5e-4,
+ )
+
+
+def test_bert():
+ bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
+ hf_bert = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")
+ bert_cfg = BertConfig()
+ r_bert = EncoderWrapper(bert_cfg)
+ r_bert = convert_bert_params_dict(r_bert, hf_bert)
+
+ args = {
+ "path": "~/code/data/enwiki1m",
+ "truncation": True,
+ "max_sequence_length": 128,
+ "shuffle_train": False,
+ "batch_size": 1,
+ "padding": "max_length",
+ "tokenizer": bert_tokenizer,
+ }
+
+ wikipipe = WikiMLMPipe(**args)
+ dl = wikipipe()
+
+ batch = next(iter(dl["train"]))
+
+ with torch.no_grad():
+ r_bert.eval()
+ hf_bert.eval()
+ r_output = r_bert(**batch)
+ r_output2 = r_bert(**batch)
+ hf_output = hf_bert(**batch)
+ hf_output2 = hf_bert(**batch)
+ np.testing.assert_allclose(
+ hf_output.logits.detach().numpy(),
+ hf_output2.logits.detach().numpy(),
+ atol=5e-4,
+ rtol=5e-4,
+ )
+ print("hf deterministic")
+ np.testing.assert_allclose(
+ r_output.detach().numpy(), r_output2.detach().numpy(), atol=5e-4, rtol=5e-4
+ )
+ print("rafale deterministic")
+ np.testing.assert_allclose(
+ r_output.detach().numpy(),
+ hf_output.logits.detach().numpy(),
+ atol=5e-4,
+ rtol=5e-4,
+ )
+
+
+if __name__ == "__main__":
+    test_bert()
+    test_roberta()
+
+
+
+### ./test/test_pythia_generation.py ###
+import torch
+import numpy as np
+
+# from safetensors import safe_open
+# from tokenizers import Tokenizer
+
+# these imports are required for the script below to run
+from rafale.datapipe import InferenceDatapipeline
+from rafale.models.decoding_strategies import greedy_decode
+from rafale.models.decoder import DecoderWrapper
+from rafale.models.configurations import Pythia14MConfig, load_safetensors
+
+
+# Example usage
+# Initialize the rafale_pythia model
+rafale_pythia = DecoderWrapper(Pythia14MConfig)
+rafale_pythia = load_safetensors(rafale_pythia, Pythia14MConfig)
+ifdp = InferenceDatapipeline("EleutherAI/pythia-14m")
+test_str = "Once upon a time,"
+
+# Define input_ids (e.g., starting with a token)
+# input_ids = torch.tensor([[Pythia14MConfig.bos_token_id]]) # Shape: (1, 1)
+
+# Define maximum sequence length for generation
+max_length = 32
+
+# Generate sequence using greedy decoding
+generated_sequence = greedy_decode(
+ rafale_pythia,
+ ifdp(test_str),
+ max_length,
+ Pythia14MConfig.eos_token_id,
+ check_repeat_ngrams=True,
+)
+
+generated_str = ifdp.ids_to_str(generated_sequence["input_ids"])
+
+print("Generated sequence:", generated_str)
+
+
+
+### ./test/pythia_tinystories.yaml ###
+run:
+ name: "pythia14m-tinystories" # name of your experiment, used for checkpointing
+ seed: 42
+ n_epochs: 1
+ max_lr: 6e-04
+ warmup_pct: 0.01
+ schedule: "cosine-warmup" # linear, linear-warmup, cosine, cosine-warmup
+ optimizer: "AdamW"
+ eval_interval: "50ba"
+ clip_type: "norm"
+ clip_value: 1.0
+ device_bs: "auto"
+ save_interval: "200ba"
+
+model:
+ config: "pythia14m" # config key
+ type: "decoder"
+ use_pretrained: True
+
+data:
+ pipeline: "tinystories_neox" # the preprocessing/tokenization pipeline
+ config:
+ name: "tinystories"
+ num_processes: 8
+ tokenizer_name: "neox"
+ is_prepared: False
+ input_id_key: "input_ids"
+ train_batch_size: 1024
+ eval_batch_size: 16
+ shuffle_train: False
+ dataset_path: "~/code/data/TinyStories"
+ tokenizer_path: "EleutherAI/pythia-14m"
+ max_sequence_length: 512
+ pad_token_id: -100
+ pad_inputs: True
+ is_prepared: False
+
+logging: # @TODO :: not implemented
+ use_wandb: True
+ use_file: False
+ eval_interval: "10ba"
+ log_dir: "./run_logs"
+ checkpoint_dir: "./checkpoints"
+
+
+
+### ./test/test.yaml ###
+# we want data and model configurations to be in files rather than in yaml, leave training hyperparams to yaml config only
+
+run:
+ name: "test-ministories" # name of your experiment, used for checkpointing
+ seed: 42
+ n_epochs: 1
+ max_lr: 3e-04
+ warmup_pct: 0.01
+ schedule: "cosine-warmup" # linear, linear-warmup, cosine, cosine-warmup
+ optimizer: "AdamW"
+ eval_interval: "50ba"
+ clip_type: "norm"
+ clip_value: 1.0
+ device_bs: "auto"
+ save_interval: "1ep"
+
+model:
+ config: "pythia14m" # config key
+ type: "decoder"
+ use_pretrained: True
+
+data:
+ pipeline: "tinystories_neox" # the preprocessing/tokenization pipeline
+ config:
+ name: "tinystories_testing"
+ num_processes: 1
+ tokenizer_name: "neox"
+ is_prepared: False
+ input_id_key: "input_ids"
+ train_batch_size: 16
+ eval_batch_size: 16
+ shuffle_train: False
+ dataset_path: "~/code/data/micro_tinystories"
+ tokenizer_path: "EleutherAI/pythia-14m"
+ max_sequence_length: 128
+ pad_token_id: -100
+ pad_inputs: True
+ is_prepared: False
+
+logging:
+ use_wandb: True
+ use_file: False
+
+
+
diff --git a/media/rafale-logo.png b/media/rafale-logo.png
new file mode 100644
index 0000000..ca2d590
Binary files /dev/null and b/media/rafale-logo.png differ
diff --git a/media/rafale-logo.svg b/media/rafale-logo.svg
new file mode 100644
index 0000000..7dd4a4a
--- /dev/null
+++ b/media/rafale-logo.svg
@@ -0,0 +1,80 @@
+
+
+
+
diff --git a/pyproject.toml b/pyproject.toml
index 4999a7b..e034a72 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,23 +1,11 @@
-[build-system]
-requires = ["setuptools>=42", "wheel"]
-build-backend = "setuptools.build_meta"
-
[project]
name = "rafale"
version = "0.1"
-description = "simple transformer training lib"
+description = "opinionated transformer training cli"
readme = "README.md"
-requires-python = ">=3.6"
-license = {text = "MIT"}
-
-[project.urls]
-homepage = "https://github.com/maxrousseau/rafale"
+requires-python = "==3.12"
+packages = [{include = "rafale"}]
-[project.authors]
-name = "Maxime Rousseau"
-
-[tool.setuptools.packages.find]
-where = ["rafale"]
-
-[project.optional-dependencies]
-# Add any optional dependencies here, if applicable
\ No newline at end of file
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
\ No newline at end of file
diff --git a/rafale/caches.py b/rafale/caches.py
new file mode 100644
index 0000000..daddac3
--- /dev/null
+++ b/rafale/caches.py
@@ -0,0 +1,5 @@
+import os
+
+DATA_CACHE_DIR = os.path.expanduser("~/.rafale_cache/data/")
+MODEL_CACHE_DIR = os.path.expanduser("~/.rafale_cache/models/")
+CHECKPOINT_CACHE_DIR = os.path.expanduser("~/.rafale_cache/checkpoints/")
diff --git a/rafale/datapipe.py b/rafale/datapipe.py
index 93a39e2..64d39a2 100644
--- a/rafale/datapipe.py
+++ b/rafale/datapipe.py
@@ -1,102 +1,362 @@
import os
+import warnings
+from abc import ABC, abstractmethod
from datasets import load_dataset, DatasetDict
-from transformers import DataCollatorForLanguageModeling
+from tokenizers import Tokenizer
+from tokenizers.processors import TemplateProcessing
+import torch
from torch.utils.data import DataLoader
+from torch.nn.utils.rnn import pad_sequence
+from rafale.caches import DATA_CACHE_DIR
-class DataPipeline:
+
+class DataPipeline(ABC):
"""
Base class
A data pipeline is initiated with a path and some configurations parameters:
- - path : local path to the dataset
- - collator_fn : function to be applied at parse
+ Args:
+ - dataset_path : local path to the dataset
+ - collator : function to be applied when sending a batch through the dataloader
+
Datasets should be saved using the following format:
/.json
+
+ Returns:
"""
- def __init__(self):
- self.path = None
- self.collator_fn = None
- self.padding = None
- self.truncation = None
- self.max_sequence_length = None
- self.shuffle_train = False
- self.batch_size = 4
- self.tokenizer = None
+ def __init__(self, **kwargs):
+ self.name: str = kwargs["name"]
+ self.tokenizer_name: str = kwargs["tokenizer_name"]
+ self.is_prepared: bool = kwargs["is_prepared"]
+ self.dataset_path: str = os.path.expanduser(kwargs["dataset_path"])
+
+ self.max_sequence_length: int = kwargs["max_sequence_length"]
+ self.train_batch_size: int = kwargs["train_batch_size"]
+ self.eval_batch_size: int = kwargs["eval_batch_size"]
+ self.pad_inputs: bool = kwargs["pad_inputs"]
+ self.pad_token_id: int = kwargs["pad_token_id"]
+ self.input_id_key: str = kwargs["input_id_key"]
+ self.shuffle_train: bool = kwargs["shuffle_train"]
+
+ self.num_processes: int = kwargs["num_processes"]
+ self.tokenizer = Tokenizer.from_pretrained(kwargs["tokenizer_path"])
+
+ self.data_collator = None
+
+ self.use_cached = False
def _load(self):
+ # either load directly from disk or from an hf dataset repo
try:
- self.dataset_pre = DatasetDict.load_from_disk(self.path)
+ self.dataset = DatasetDict.load_from_disk(self.dataset_path)
except:
- raise OSError(f"Wrong dataset file and/or path configuration!{self.path}")
+ try:
+ self.dataset = load_dataset(self.dataset_path)
+ except:
+ raise OSError(
+ f"Wrong dataset file and/or path configuration! path: {self.dataset_path}"
+ )
+
+ @abstractmethod
+ def _prepare(self):
+ """Perform all data preprocessing here: tokenization, chunking, truncation, etc. (EXCEPT padding!). Padding will be performed by
+ the datacollator.
+ """
+ pass
+
+ def _check_path(self):
+ """make sure that the dataset has not already been parsed at location"""
+ output_path = f"{self.name}_{self.tokenizer_name}_bs{self.train_batch_size}_len{self.max_sequence_length}"
+
+ assert DATA_CACHE_DIR[-1] == "/"
- def _post_tokenize(self):
- None
+ save_path = os.path.abspath(os.path.join(DATA_CACHE_DIR, output_path))
+
+ if os.path.isdir(DATA_CACHE_DIR):
+ pass
+ else:
+ os.makedirs(DATA_CACHE_DIR)
+
+ if os.path.isdir(save_path):
+ warnings.warn(
+ f"Dataset already exists at location:\n\t {save_path} \n ABORTING PREPARATION, USING CACHED DATASET!"
+ )
+
+ self.is_prepared = True
+ self.use_cached = True
+
+ return save_path
def __call__(self):
# returns a or multiple dataloaders
self._load()
- dataloaders = {}
+ self.dataloaders = {}
- if type(self.dataset_pre) == DatasetDict:
- for subset in self.dataset_pre:
+ if not self.is_prepared:
+ cache_dataset_path = self._check_path()
+
+ if type(self.dataset) == DatasetDict:
+ for subset in self.dataset:
if subset == "train":
shuffle = self.shuffle_train
+ batch_size = self.train_batch_size
else:
shuffle = False
+ batch_size = self.eval_batch_size
- # @DEBUG
- self.dataset_pre[subset] = self.dataset_pre[subset].select(range(10))
- _set = self.dataset_pre[subset].map(
- lambda example: self._tokenize(
- example,
- )
- )
- _set = self._post_tokenize(_set)
- print(len(_set["input_ids"][0]))
- print(_set)
- dataloaders[subset] = DataLoader(
- _set,
- collate_fn=self.collator_fn, # DEBUG
- batch_size=self.batch_size,
- shuffle=shuffle, # DEBUG
+ # if the data is not ready to be passed to the dataloader
+ if not self.is_prepared:
+ print(f"preparing subset {subset}")
+ self.dataset[subset] = self._prepare(self.dataset[subset])
+
+ if self.use_cached:
+ self.dataset = DatasetDict.load_from_disk(cache_dataset_path)
+
+ self.dataloaders[subset] = DataLoader(
+ self.dataset[subset],
+ collate_fn=self.data_collator,
+ batch_size=batch_size,
)
+ print(f"✅ Dataloader ready for subset {subset}.")
+
+ if not self.is_prepared:
+ self.dataset.save_to_disk(cache_dataset_path)
+ print(f"✅ Saved prepared dataset at {cache_dataset_path}.")
else:
- # process a single set
- None
+ raise TypeError(
+ f"self.dataset is type {type(self.dataset)}, but should be DatasetDict."
+ )
- return dataloaders
+ return self.dataloaders
+
+
+class InferenceDatapipeline:
+ def __init__(self, tokenizer_path):
+ self.tokenizer = Tokenizer.from_pretrained(tokenizer_path)
+
+ def _tokenizer_templating(self, tokenizer, add_eos=True):
+ if add_eos:
+ tokenizer.post_processor = TemplateProcessing(
+ single="<|endoftext|> $A",
+ special_tokens=[
+ ("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
+ ],
+ )
+
+ return tokenizer
+
+ def __call__(self, input_str, use_template: bool = True):
+ """
+ tokenize input_str
+ convert to torch tensor (batch_size=1)
+ add the endoftext token
+ """
+ if use_template:
+ self.tokenizer = self._tokenizer_templating(self.tokenizer)
+
+ tokenized_inputs = {
+ "input_ids": torch.LongTensor(
+ self.tokenizer.encode(input_str).ids
+ ).unsqueeze(dim=0)
+ }
+
+ return tokenized_inputs
+
+    def ids_to_str(self, tensor):
+        return self.tokenizer.decode(tensor.squeeze().detach().numpy())
+
+
+class CausalCollator:
+ def __init__(self, pad_token_id: int, input_id_key: str, pad_inputs: bool):
+ self.pad_token_id = pad_token_id
+ self.input_id_key = input_id_key
+ self.pad_inputs = pad_inputs
+ def __call__(self, features):
+ # Extract the input IDs from the batch
+ input_ids = [torch.tensor(example[self.input_id_key]) for example in features]
+
+ # Pad the inputs if required
+ if self.pad_inputs:
+ input_ids = pad_sequence(
+ input_ids, batch_first=True, padding_value=self.pad_token_id
+ )
+
+ # Set the last token of each label sequence to pad_token_id to ignore loss for the last prediction
+ # shift left the ids
+ labels = [
+ torch.cat([ids[1:], torch.tensor([self.pad_token_id])]) for ids in input_ids
+ ]
+ labels = torch.stack(labels, dim=0)
+
+ return {
+ "input_ids": input_ids,
+ "labels": labels,
+ }
+
+
+class MLMCollator:
+ def __init__(
+ self,
+ mask_p: float = 0.15,
+ whole_word_mask: bool = False,
+ mask_span: bool = False,
+ pad_token_id: int = -100,
+ input_id_key: str = "input_ids",
+ pad_inputs: bool = True,
+ ):
+ """masks some % of tokens for MLM objective"""
+ pass
+
+
+class DefaultCollator:
+ def __init__(
+ self,
+ pad_token_id: int = -100,
+ input_id_key: str = "input_ids",
+ pad_inputs: bool = True,
+ ):
+ """for task data where labels are already set"""
+ pass
-# TODO:
-# download 10k wiki random subset (or "simple") - shuffle then download, do that in a colab notebook...
-# then apply MLM loading to it for testing...
+class TinyStoriesCausalNeoX(DataPipeline):
+    """Sample data pipeline for the TinyStories dataset.
-# datapipe is simple, you call the function and get the dataloader(s) you need for running
-# a debug_mode flag can be included (single batch)
-class WikiMLMPipe(DataPipeline):
- """a first testing pipeline for wikipedia for MLM"""
+
+    The dataset is prepared for causal language modelling using the GPT-NeoX tokenizer (EleutherAI).
+
+    Usage:
+        ts_dict = {
+            "name": "tinystories_testing",
+            "tokenizer_name": "neox",
+            "is_prepared": False,
+            "input_id_key": "input_ids",
+            "train_batch_size": 16,
+            "eval_batch_size": 16,
+            "shuffle_train": False,
+            "num_processes": 1,
+            "dataset_path": "~/code/data/micro_tinystories",
+            "tokenizer_path": "EleutherAI/pythia-14m",
+            "max_sequence_length": 128,
+            "pad_token_id": -100,
+            "pad_inputs": True,
+        }
+        ts_dpipe = TinyStoriesCausalNeoX(**ts_dict)
+        dataloaders = ts_dpipe()
+
+    Returns:
+        dict of torch DataLoaders keyed by dataset subset.
+
+ """
+
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.data_collator = CausalCollator(
+ pad_token_id=self.pad_token_id,
+ input_id_key=self.input_id_key,
+ pad_inputs=self.pad_inputs,
+ )
+
+ # TODO: figure out if they really only use endoftext for everything...
+ def _tokenizer_templating(self, tokenizer, add_eos=True):
+ if add_eos:
+ tokenizer.post_processor = TemplateProcessing(
+ single="$A <|endoftext|>",
+ special_tokens=[
+ ("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")),
+ ],
+ )
+
+ return tokenizer
+
+ def _tokenize(self, example, tokenizer, key="text"):
+ return {self.input_id_key: tokenizer.encode(example[key]).ids}
+
+ def _group_and_chunk(self, examples, key="input_ids", block_size=None, pad=False):
+ concatenated_tokens = sum(examples[key], [])
+ total_length = len(concatenated_tokens)
+
+ if total_length >= block_size:
+ total_length = (total_length // block_size) * block_size
+
+ result = {
+ "input_ids": [
+ concatenated_tokens[i : i + block_size]
+ for i in range(0, total_length, block_size)
+ ]
+ }
+
+ return result
+
+ def _prepare(self, data):
+        """Tokenize the raw text, then group and chunk it into fixed-length blocks."""
+
+ # apply functions above to dataset
+ self.tokenizer = self._tokenizer_templating(self.tokenizer)
+
+ data = data.map(
+ lambda example: self._tokenize(example, self.tokenizer),
+ remove_columns=data.column_names,
+ num_proc=self.num_processes,
+ )
+
+ data = data.map(
+ lambda example: self._group_and_chunk(
+ example, block_size=self.max_sequence_length
+ ),
+ batched=True,
+ num_proc=self.num_processes,
+ )
+
+ return data
+
+
+class TinyStoriesMLM(DataPipeline):
+ """ """
+
+ pass
+
+
+class ImdbCLS(DataPipeline):
+ pass
+
+
+'''
+class ImdbClsPipe(DataPipeline):
+ """A pipeline for the imdb dataset for """
def __init__(self, **kwargs):
self.path = kwargs["path"]
+ self.name = kwargs["name"] # name
+ self.is_tokenized = kwargs["is_tokenized"]
+
self.padding = kwargs["padding"] # "max_length"
- self.truncation = kwargs["truncation"] # "max_length"
self.max_sequence_length = kwargs["max_sequence_length"] # 512
- self.shuffle_train = kwargs["shuffle_train"] # False
+
+ self.shuffle_train = kwargs["shuffle_train"] # False
self.batch_size = kwargs["batch_size"]
self.tokenizer = kwargs["tokenizer"]
- self.collator_fn = DataCollatorForLanguageModeling(tokenizer=self.tokenizer)
+ self.collator_fn = DataCollatorWithPadding(
+ tokenizer=self.tokenizer,
+ padding=self.padding,
+ max_length=self.max_sequence_length,
+ return_tensors='pt'
+ )
+
+ self.data = datasets.DatasetDict.load_from_disk(self.path)
def _post_tokenize(self, dataset):
- return dataset.remove_columns(["url", "title", "text", "id"])
+ return dataset.remove_columns(["text"])
def _tokenize(
self,
@@ -104,26 +364,60 @@ def _tokenize(
):
source_tokenized = self.tokenizer(
examples["text"],
- padding=self.padding,
+ truncation=True,
max_length=self.max_sequence_length,
- truncation=self.truncation,
return_token_type_ids=True,
+ return_tensors='pt'
)
batch = {k: v for k, v in source_tokenized.items()}
return batch
+ def _map_tokenize(self, subsets=None):
+ # tokenize
+ print("tokenizing training")
+ self.data["train"] = self.data["train"].map(self._tokenize, batched=True)
+ self.data["train"] = self.data["train"].remove_columns("text")
-"""
-args = {"path" : "~/code/data/enwiki1m", "truncation": True, "max_sequence_length": 128, "shuffle_train" : False,
-"batch_size":4, "padding": "max_length", "tokenizer" : tokenizer}
+ print("tokenizing test")
+ self.data["test"] = self.data["test"].map(self._tokenize, batched=True)
+ self.data["test"] = self.data["test"].remove_columns("text")
-wikipipe = WikiMLMPipe(**args)
-dloaderz = wikipipe()
-next(iter(dloaderz["train"]))
+ def _save_tokenized(self):
+ # preprocess
+ self.path += "_tokenized"
+ print(f"saving tokenized data to disk at location : {self.path}")
+ assert os.path.isdir(self.path) == False
+ self.data.save_to_disk(self.path)
-"""
+ def string_to_tokens(self, input_str):
+ # tokenize
+ tensor = self._tokenize({"text": input_str})
+
+ return self.collator_fn(tensor)
+
+ def __call__(self, subsets = ["train", "test"]):
+ dataloaders = {}
-# class WikiSLaMPipe(DataPipeline):
+ if self.is_tokenized:
+ print("data tokenized")
+ else:
+ self._map_tokenize()
+ self._save_tokenized()
+
+ for _set in subsets:
+ if _set == "train":
+ shuffle = self.shuffle_train
+ else:
+ shuffle = False
+
+ dataloaders[_set] = DataLoader(
+ self.data[_set],
+ collate_fn=self.collator_fn, # DEBUG
+ batch_size=self.batch_size,
+ shuffle=shuffle, # DEBUG
+ )
+ return dataloaders
+'''
diff --git a/rafale/main.py b/rafale/main.py
index 54fe9be..aa00691 100644
--- a/rafale/main.py
+++ b/rafale/main.py
@@ -1,38 +1,211 @@
+import os
import argparse
import yaml
+import time
+from datetime import datetime
+import torch
+import torch.utils.data
-# generate datasetmodule/model given a ".py" file containing the setup code
-from datasets.data_utils import DatasetWrapper
-from models.model_utils import ModelWrapper
+import composer
+from composer import Trainer, Time
+from composer.loggers import InMemoryLogger, WandBLogger, FileLogger
+from composer.algorithms import GradientClipping
+from composer.optim.scheduler import (
+ CosineAnnealingWithWarmupScheduler,
+ CosineAnnealingScheduler,
+ LinearScheduler,
+ LinearWithWarmupScheduler,
+)
+
+from rafale.models.decoder import ComposerLM
+from rafale.models.encoder import EncoderWrapper
+from rafale.models.configurations import (
+ load_safetensors,
+ Pythia14MConfig,
+ BertConfig,
+ RobertaConfig,
+)
+
+from rafale.caches import CHECKPOINT_CACHE_DIR
+from rafale.datapipe import TinyStoriesCausalNeoX
+
+
+ENV_VARS = {key: value for key, value in os.environ.items()}
parser = argparse.ArgumentParser(description="launch a training run")
+
parser.add_argument(
- "-c", "--config", type=str, help="path to yaml configuration file", required=True
+ "-c",
+ "--training_config",
+ type=str,
+ help="path to yaml run configuration file",
+ required=True,
)
args = parser.parse_args()
+model_config_dict = {
+ "pythia14m": Pythia14MConfig,
+ "bert": BertConfig,
+ "roberta": RobertaConfig,
+}
+
+data_pipeline_dict = {
+ "tinystories_neox": TinyStoriesCausalNeoX,
+}
+
def main():
- # load/parse yaml config
- with open(args.config, "r") as f:
+ # CONFIG ##################################################################
+ with open(args.training_config, "r") as f:
config = yaml.safe_load(f)
- print(config["run"]["name"])
- print(config["model"])
+ run_name = config["run"]["name"]
+ run_n_epochs = config["run"]["n_epochs"]
+ run_seed = config["run"]["seed"]
+ run_save_interval = config["run"]["save_interval"]
- # build & load the model
- # @HERE
+ run_clip_type = config["run"]["clip_type"]
+ run_clip_value = float(config["run"]["clip_value"])
- # build & load the dataloader
+ device_bs = config["run"]["device_bs"] # int or "auto"
- # setup logging?
+ # schedule
+ run_schedule_type = config["run"]["schedule"]
+ run_max_lr = float(config["run"]["max_lr"]) # learning rate
+ run_warmup_pct = float(config["run"]["warmup_pct"])
+ if run_schedule_type == "cosine-warmup":
+ run_scheduler = CosineAnnealingWithWarmupScheduler(
+ t_warmup=Time(run_warmup_pct, "dur"), alpha_f=0.1
+ )
+ else:
+        raise TypeError(
+            f"Schedule '{run_schedule_type}' is not valid! Supports: cosine-warmup (linear, cosine, and linear-warmup planned)."
+        )
- # build trainer
+ model_config_key = config["model"]["config"]
+ model_type = config["model"]["type"]
+ model_use_pretrained = config["model"]["use_pretrained"]
- # launch
+ data_pipeline_key = config["data"]["pipeline"]
+ dataset_config = config["data"]["config"]
+ print(dataset_config)
+
+ # DATALOADERS #############################################################
+ data_pipeline = data_pipeline_dict[data_pipeline_key](**dataset_config)
+ dataloaders = data_pipeline()
+ if "DATA" in ENV_VARS.keys() and ENV_VARS["DATA"] == "1":
+ print("Data processing complete, exiting...")
+ return 0 # just do data preprocessing if we pass DATA=1
+
+ # MODEL #######################################################################
+ model_config = model_config_dict[model_config_key]
+
+ if model_type == "decoder":
+ rafale_model = ComposerLM(model_config)
+ elif model_type == "encoder":
+ rafale_model = EncoderWrapper(model_config)
+ else:
+ raise TypeError(
+ f"Model type {model_type} is not valid! Supports: encoder, decoder."
+ )
- return None
+ if model_use_pretrained:
+ rafale_model.model = load_safetensors(rafale_model.model, model_config)
+
+ # LOGGING #################################################################
+ # mem_logger = InMemoryLogger()
+ # @TODO :: add some logging options in the yaml
+ wandb_logger = WandBLogger(project="rafale", name=run_name)
+ # file_logger = FileLogger(filename=f"{run_name}-{time}".txt)
+
+    # GRADIENT CLIPPING #######################################################
+    # clip_type can be 'norm', 'adaptive', or 'value'; both values come from the run config
+    gradient_clip = GradientClipping(
+        clipping_type=run_clip_type, clipping_threshold=run_clip_value
+    )
+
+ # DEVICES #################################################################
+ device = "gpu" if torch.cuda.is_available() else "cpu" # select the device
+ if device == "gpu":
+ run_precision = "amp_fp16"
+ else:
+ run_precision = "fp32"
+
+ # DEBUG RUN ###############################################################
+ if "DEBUG" in ENV_VARS.keys() and ENV_VARS["DEBUG"] == "1":
+ from torch.utils.data import Subset, DataLoader, default_collate
+ from datasets import Dataset
+
+ # single batch, same for train and test 10 epochs
+ bs = 4
+ debug_batch = next(iter(dataloaders["train"]))
+ debug_batch = Dataset.from_dict(
+ {k: v[:bs] for k, v in debug_batch.items()}
+ ).with_format("torch")
+ debug_batch = DataLoader(
+ debug_batch,
+ batch_size=bs,
+ shuffle=False,
+ collate_fn=default_collate,
+ )
+
+ trainer = Trainer(
+ model=rafale_model,
+ seed=run_seed,
+ train_dataloader=debug_batch,
+ eval_dataloader=debug_batch,
+ optimizers=torch.optim.AdamW(rafale_model.parameters(), lr=1e-4),
+ max_duration=10, # num epochs
+ device=device,
+ loggers=None,
+ precision=run_precision,
+ progress_bar=True,
+ )
+
+        trainer.fit()
+
+        return 0
+
+ # TRAIN ###################################################################
+    # the training subset must be keyed "train"; the validation subset may be named anything (test, val,
+    # validation, eval, etc.) as long as there is exactly one other subset, which is used for evaluation
+ dl_keys = list(dataloaders.keys())
+ assert "train" in dl_keys
+ dl_keys.remove("train")
+ assert len(dl_keys) == 1
+ eval_subset_key = dl_keys[0]
+
+ # get datetime for checkpoint, directories are created by composer
+ now = datetime.now()
+    formatted_date = now.strftime("%dd%mm%Yy_%Hh%Mm")  # format as DDdMMmYYYYy_HHhMMm
+ checkpoint_folder = os.path.abspath(
+ os.path.join(CHECKPOINT_CACHE_DIR, f"{run_name}-{formatted_date}/")
+ )
+
+ trainer = Trainer(
+ model=rafale_model,
+ seed=run_seed,
+ train_dataloader=dataloaders["train"],
+ eval_dataloader=dataloaders[eval_subset_key],
+ optimizers=torch.optim.AdamW(rafale_model.parameters(), lr=run_max_lr),
+ max_duration=run_n_epochs, # num epochs
+ eval_interval="50ba", # default is 1ep !
+ device_train_microbatch_size=device_bs, # will handle gradient accumulation automatically
+ device=device,
+ loggers=[wandb_logger],
+ precision=run_precision,
+ progress_bar=True,
+ schedulers=run_scheduler,
+ algorithms=[gradient_clip],
+ save_folder=checkpoint_folder,
+ save_latest_filename="latest",
+ save_interval=run_save_interval,
+ )
+
+ # launch
+ trainer.fit()
+ print(f"🍻 TRAINING COMPLETE\n💾 CHECKPOINTS SAVED AT LOCATION: {checkpoint_folder}")
if __name__ == "__main__":
diff --git a/rafale/models/configurations.py b/rafale/models/configurations.py
new file mode 100644
index 0000000..d11177e
--- /dev/null
+++ b/rafale/models/configurations.py
@@ -0,0 +1,189 @@
+import requests
+import os
+
+from dataclasses import dataclass
+
+from safetensors import safe_open
+
+from ..caches import MODEL_CACHE_DIR
+
+"""
+to simplify model loading add a configuration for the pre-trained weight loading using safetensors instead of loading
+the full model.
+> then save to a folder named ".pretrained/" in this directory
+"""
+
+
+def download_file(url, save_path):
+ """
+ Downloads a file from the specified URL to the given save path.
+
+ :param url: The URL of the file to download.
+ :param save_path: The local path where the file will be saved.
+ """
+ try:
+ # Make sure the directory exists
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+ with requests.get(url, stream=True) as response:
+ response.raise_for_status() # Check for HTTP errors
+ total_size = int(response.headers.get("content-length", 0))
+ block_size = 1024 # 1 Kilobyte
+ progress = 0
+
+ with open(save_path, "wb") as file:
+ for data in response.iter_content(block_size):
+ file.write(data)
+ progress += len(data)
+                    pct = (progress / total_size) * 100 if total_size else 0.0
+                    print(
+                        f"Downloaded {progress} of {total_size} bytes ({pct:.2f}%)",
+                        end="\r",
+                    )
+
+ print(f"\nDownload completed successfully! File saved to: {save_path}")
+
+ except requests.exceptions.HTTPError as http_err:
+ print(f"HTTP error occurred: {http_err}") # HTTP error
+ except Exception as err:
+ print(f"An error occurred: {err}") # Other errors
+
+
+def load_safetensors(rafale_model, model_config):
+ """Transfer the pretrained safetensors to rafale model"""
+ tensors = {}
+
+ safetensors_path = os.path.join(MODEL_CACHE_DIR, model_config.name + ".safetensors")
+
+    if not os.path.isfile(safetensors_path):
+        download_file(model_config.safetensors_url, safetensors_path)
+
+ with safe_open(safetensors_path, framework="pt") as f:
+ for k in f.keys():
+ tensors[k] = f.get_tensor(k)
+
+ rafale_model = model_config.convert_params_dict(rafale_model, tensors)
+
+ return rafale_model
+
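+# Example usage (illustrative sketch, mirrors test/test_pythia.py):
+#
+#   from rafale.models.decoder import DecoderWrapper
+#   from rafale.models.configurations import Pythia14MConfig, load_safetensors
+#
+#   model = DecoderWrapper(Pythia14MConfig)
+#   model = load_safetensors(model, Pythia14MConfig)  # downloads the weights if missing, then maps them
+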
+
+@dataclass
+class BertConfig:
+ embed_dim: int = 768
+ vocab_size: int = 30522
+ attention_dropout: float = 0.1
+ hidden_dropout: float = 0.1
+ num_heads: int = 12
+ ff_dim: int = 3072
+ max_pos_embedding: int = 512
+ layer_norm_eps: float = 1e-12
+ num_blocks: int = 12
+ pad_token_id: int = 0
+ num_token_type: int = 2
+    fast_attention: bool = False  # use xformers (todo: add FlashAttention2), NOT IMPLEMENTED
+
+
+@dataclass
+class BertTinyConfig:
+ embed_dim: int = 128
+    vocab_size: int = 30522  # common usage would be 30522 + num_extra_tokens
+ attention_dropout: float = 0.1
+ hidden_dropout: float = 0.1
+ num_heads: int = 2
+ ff_dim: int = 512
+ max_pos_embedding: int = 512
+ layer_norm_eps: float = 1e-12
+ num_blocks: int = 2
+ pad_token_id: int = 0
+ num_token_type: int = 2
+ fast_attention: bool = False # use xformers (todo: add FlashAttention2)
+
+
+@dataclass
+class RobertaConfig:
+ embed_dim: int = 768
+ vocab_size: int = 50265
+ attention_dropout: float = 0.1
+ hidden_dropout: float = 0.1
+ num_heads: int = 12
+ ff_dim: int = 3072
+ max_pos_embedding: int = 514
+ layer_norm_eps: float = 1e-05
+ num_blocks: int = 12
+ pad_token_id: int = 1
+ num_token_type: int = 1
+ bos_token_id: int = 0
+ eos_token_id: int = 2
+ fast_attention: bool = False
+
+
+@dataclass
+class Pythia14MConfig:
+ name: str = "pythia14m"
+ safetensors_url: str = (
+ "https://huggingface.co/EleutherAI/pythia-14m/resolve/main/model.safetensors"
+ )
+
+ embed_dim: int = 128
+ num_heads: int = 4
+ ff_dim: int = 512
+ hidden_act: str = "gelu"
+ max_pos_embedding: int = 2048
+ vocab_size: int = 50304
+ use_cache: bool = True
+
+ parallel_residual: bool = True
+
+ attention_dropout: float = 0.1
+ hidden_dropout: float = 0.1
+
+ layer_norm_eps: float = 1e-05
+ num_blocks: int = 6
+
+ bos_token_id: int = 0
+ eos_token_id: int = 0
+ fast_attention: bool = False
+
+ rotary_emb_base: int = 10000
+ rotary_pct: float = 0.25
+
+ tie_word_embeddings: bool = False
+
+ @classmethod
+ def convert_params_dict(cls, target, source):
+ """
+ Source safetensors dict to our rafale model class.
+ """
+ # not needed for our implementation
+ unused = ["rotary_emb.inv_freq", "masked_bias", "attention.bias"]
+        for k in list(source.keys()):
+            if any(x in k for x in unused):
+                del source[k]
+
+ conversion = [
+ ("gpt_neox.embed_in", "token_embeddings.input_embeddings"),
+ ("gpt_neox.layers", "layers"),
+ ("input_layernorm", "attention_norm"),
+ ("post_attention_layernorm", "ffn_norm"),
+ ("mlp", "feed_forward"),
+ ("dense_4h_to_h", "ff_out"),
+ ("dense_h_to_4h", "ff_in"),
+ ("embed_out", "output"),
+ ("gpt_neox.final_layer_norm", "final_norm"),
+ ]
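+        # e.g. "gpt_neox.layers.0.mlp.dense_4h_to_h.weight" maps to
+        # "layers.0.feed_forward.ff_out.weight"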
+
+ updated_parameters = {}
+ for k, v in source.items():
+ for hf_term, my_term in conversion:
+ if hf_term in k:
+ k = k.replace(hf_term, my_term)
+
+ updated_parameters[k] = v
+
+ # here we transfer weights for all layers
+ target.load_state_dict(updated_parameters, strict=True)
+
+ return target
diff --git a/rafale/models/convert_hf_weights.py b/rafale/models/convert_hf_weights.py
index a6329a4..b187425 100644
--- a/rafale/models/convert_hf_weights.py
+++ b/rafale/models/convert_hf_weights.py
@@ -55,6 +55,43 @@ def convert_bert_params_dict(target, source):
return target
+def convert_pythia_params_dict(target, source):
+ """
+ Source safetensors dict to our rafale model class.
+ """
+
+ # not needed for our implementation
+ unused = ["rotary_emb.inv_freq", "masked_bias", "attention.bias"]
+    for k in list(source.keys()):
+        if any(x in k for x in unused):
+            del source[k]
+
+ conversion = [
+ ("gpt_neox.embed_in", "token_embeddings.input_embeddings"),
+ ("gpt_neox.layers", "layers"),
+ ("input_layernorm", "attention_norm"),
+ ("post_attention_layernorm", "ffn_norm"),
+ ("mlp", "feed_forward"),
+ ("dense_4h_to_h", "ff_out"),
+ ("dense_h_to_4h", "ff_in"),
+ ("embed_out", "output"),
+ ("gpt_neox.final_layer_norm", "final_norm"),
+ ]
+
+ updated_parameters = {}
+ for k, v in source.items():
+ for hf_term, my_term in conversion:
+ if hf_term in k:
+ k = k.replace(hf_term, my_term)
+
+ updated_parameters[k] = v
+
+ # here we transfer weights for all layers
+ target.load_state_dict(updated_parameters, strict=True)
+
+ return target
+
+
def convert_roberta_params_dict(target, source):
conversion = [
("roberta.embeddings", "embedding_layer"),
diff --git a/rafale/models/decoder.py b/rafale/models/decoder.py
new file mode 100644
index 0000000..3f2dea2
--- /dev/null
+++ b/rafale/models/decoder.py
@@ -0,0 +1,581 @@
+#!/usr/bin/env python
+from typing import Optional
+import torch
+
+from torch import nn
+from torch import Tensor
+
+import torch.nn.functional as F
+
+from torch.nn.functional import scaled_dot_product_attention
+
+from torchmetrics import Metric
+from torchmetrics.collections import MetricCollection
+
+from composer.metrics import LossMetric, LanguagePerplexity
+from composer.models import ComposerModel
+
+
+###############################################################################
+# simple implementation of GPT building
+class NeoXRoPE(nn.Module):
+ @classmethod
+ def precompute_sin_cos_cache(cls, dim=None, seq_len=None, base=10000, device=None):
+ """Computes the cos and sin angles to be applied to the token vectors.
+
+        We begin by computing the inverse frequencies (thetas) for each rotary dimension pair (P = dim/2),
+        then take their outer product with the positions 0..seq_len-1, giving an L x P matrix of angles.
+        The angles are duplicated along the last dimension, and their cos and sin are cached as two
+        tensors of shape [1, 1, L, R] to be applied later by apply_rotary_pos_emb.
+
+ Args:
+ dim (int): number of features dimension per token to apply rotations to (d*rotary_pct)
+ seq_len (int): sequence length of the input (use the maximum sequence length)
+ base (int): default 10000
+
+ Returns:
+            (cos, sin): two Tensors, each of shape [1, 1, L, R]
+
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+ - R rotary dimensions (d*rotary_pct)
+ """
+
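+        # e.g. for Pythia-14M, head_dim=32 and rotary_pct=0.25 give dim=8; with seq_len=2048 the
+        # returned cos and sin each have shape [1, 1, 2048, 8]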
+ inv_freq = 1.0 / (
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim)
+ )
+ t = torch.arange(seq_len, dtype=torch.int64).type_as(inv_freq)
+ freqs = torch.einsum("i,j->ij", t, inv_freq)
+ emb = torch.cat((freqs, freqs), dim=-1)
+ cos_cached = emb.cos()[None, None, :, :] # shape is [1, 1, L, R]
+ sin_cached = emb.sin()[None, None, :, :]
+
+ return cos_cached, sin_cached
+
+ @staticmethod
+ def rotate_half(x):
+ x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
+ return torch.cat((-x2, x1), dim=x1.ndim - 1)
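+        # e.g. rotate_half(torch.tensor([1., 2., 3., 4.])) -> tensor([-3., -4., 1., 2.])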
+
+ @classmethod
+ def apply_rotary_pos_emb(cls, q_BHLR, k_BHLR, cos, sin):
+ """Applies the rotation to the input queries and key features."""
+
+ tensor_device = q_BHLR.get_device()
+
+ cos = cos.to(device=tensor_device)
+ sin = sin.to(device=tensor_device)
+
+ return (q_BHLR * cos) + (cls.rotate_half(q_BHLR) * sin), (k_BHLR * cos) + (
+ cls.rotate_half(k_BHLR) * sin
+ )
+
+ @classmethod
+ def apply_rotary_pos_emb_offset(cls, q, k, cos, sin, offset: int = 0):
+ """
+ q and k are shape: BHLR
+ cos, sin are shape: 11LR
+ """
+ cos, sin = (
+ cos[:, :, offset : q.shape[2] + offset, :],
+ sin[:, :, offset : q.shape[2] + offset, :],
+ )
+
+ tensor_device = q.get_device()
+
+ cos = cos.to(device=tensor_device)
+ sin = sin.to(device=tensor_device)
+
+ return (q * cos) + (cls.rotate_half(q) * sin), (k * cos) + (
+ cls.rotate_half(k) * sin
+ )
+
+
+class DecoderEmbedding(nn.Module):
+ """Simply an input projection of the tokens here, since rotary position encodings are used.
+
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.input_embeddings = nn.Embedding(
+ num_embeddings=config.vocab_size,
+ embedding_dim=config.embed_dim,
+ )
+ self.dropout = nn.Dropout(config.hidden_dropout)
+
+ def forward(self, x_BL):
+ x_BLD = self.input_embeddings(x_BL)
+ return self.dropout(x_BLD)
+
+
+class DecoderAttentionRotary(nn.Module):
+ """
+ Attention with rotary position embedding.
+
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+
+ self.head_dim = config.embed_dim // config.num_heads
+ self.num_heads = config.num_heads
+ self.embed_dim = config.embed_dim
+ self.rotary_ndims = int(self.head_dim * config.rotary_pct)
+
+ self.attention_bias = True # @TODO: set bias to True or False from config.
+ self.query_key_value = nn.Linear(config.embed_dim, 3 * config.embed_dim)
+ self.dense = nn.Linear(config.embed_dim, config.embed_dim)
+ self.dropout = nn.Dropout(p=config.attention_dropout)
+ self.dropout_p = config.attention_dropout
+ self.norm_factor = self.head_dim**-0.5
+
+ def _split_heads(self, tensor: Tensor):
+ """
+ Splits hidden dim into attn_head_size and num_attention_heads
+
+ # input tensor: [bs, seq_len, hidden_size]
+ # returns: [bs, num_attention_heads, seq_len, attn_head_size]
+ """
+ batch_size = tensor.size(0)
+
+ return (
+ tensor.view(batch_size, -1, self.num_heads, self.head_dim)
+ .transpose(1, 2)
+ .contiguous()
+ )
+
+ def _merge_heads(self, tensor: Tensor):
+ """
+ input tensor: [bs. num_attention_heads, seq_len, attn_head_size]
+ returns: [bs, seq_len, hidden_size]
+ """
+ # tensor [bs, num_attention_heads, seq_len, attn_head_size]
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ # -> [bs, seq_len, num_attention_heads, attn_head_size]
+ tensor = tensor.view(
+ tensor.size(0),
+ tensor.size(1),
+ self.num_heads * self.head_dim,
+ ).contiguous()
+ # -> [bs, seq_len, hidden_size]
+ return tensor
+
+ def forward(self, x_BLD, freqs_cis):
+ if not self.training:
+ self.dropout_p = 0
+
+ bsz, seq_len, _ = x_BLD.size()
+
+ assert freqs_cis is not None
+
+ # projections
+ qkv = self.query_key_value(x_BLD)
+
+ # [batch, seq_len, (num_heads * 3 * head_size)]
+ # --> [batch, seq_len, num_heads, 3 * head_size]
+ new_qkv_shape = qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
+ qkv = qkv.view(*new_qkv_shape)
+
+ # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
+ q_BHLd = qkv[..., : self.head_dim].permute(0, 2, 1, 3)
+ k_BHLd = qkv[..., self.head_dim : 2 * self.head_dim].permute(0, 2, 1, 3)
+ v_BHLd = qkv[..., 2 * self.head_dim :].permute(0, 2, 1, 3)
+
+ # Slice the precomputed freqs_cis based on actual seq_len --> [1, 1, seq_len, R]
+ cos = freqs_cis[0][:, :, :seq_len, :]
+ sin = freqs_cis[1][:, :, :seq_len, :]
+
+ q_rot = q_BHLd[..., : self.rotary_ndims]
+ q_pass = q_BHLd[..., self.rotary_ndims :]
+ k_rot = k_BHLd[..., : self.rotary_ndims]
+ k_pass = k_BHLd[..., self.rotary_ndims :]
+
+ q_rot, k_rot = NeoXRoPE.apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
+
+ q_BHLd = torch.cat((q_rot, q_pass), dim=-1)
+ k_BHLd = torch.cat((k_rot, k_pass), dim=-1)
+
+ # compute attention
+ attn_out_BHLd = scaled_dot_product_attention(
+ q_BHLd,
+ k_BHLd,
+ v_BHLd,
+ is_causal=True,
+ scale=self.norm_factor,
+ dropout_p=self.dropout_p,
+ )
+
+ attn_out_BLD = self._merge_heads(attn_out_BHLd)
+
+ attn_out_BLD = self.dense(attn_out_BLD)
+
+ return attn_out_BLD
+
+
+class DecoderAttentionRotaryKVCache(DecoderAttentionRotary):
+ """implements the KV cache mechanism"""
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ def forward(self, x_BLD, freqs_cis, causal_mask=None, past_kv=None):
+ # A) figure out if we passed a cached KV
+ assert freqs_cis is not None
+ has_past_kv = past_kv is not None and past_kv[0].numel() > 0
+
+ if not self.training:
+ self.dropout_p = 0
+
+ bsz, seq_len, _ = x_BLD.size()
+ # B) if we have a cached KV, apply the offset to the sequence length for RoPE
+ if has_past_kv:
+            offset = past_kv[0].shape[2]  # we want the length here; our kv shape is BHLd
+ seq_len += offset
+ else:
+ offset = 0
+
+ # projections
+ qkv = self.query_key_value(x_BLD)
+
+ # [batch, seq_len, (num_heads * 3 * head_size)]
+ # --> [batch, seq_len, num_heads, 3 * head_size]
+ new_qkv_shape = qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim)
+ qkv = qkv.view(*new_qkv_shape)
+
+ # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
+ q_BHLd = qkv[..., : self.head_dim].permute(0, 2, 1, 3)
+ k_BHLd = qkv[..., self.head_dim : 2 * self.head_dim].permute(0, 2, 1, 3)
+ v_BHLd = qkv[..., 2 * self.head_dim :].permute(0, 2, 1, 3)
+
+ # Slice the precomputed freqs_cis based on actual seq_len --> [1, 1, seq_len, R]
+ cos = freqs_cis[0][:, :, :seq_len, :]
+ sin = freqs_cis[1][:, :, :seq_len, :]
+
+ q_rot = q_BHLd[..., : self.rotary_ndims]
+ q_pass = q_BHLd[..., self.rotary_ndims :]
+ k_rot = k_BHLd[..., : self.rotary_ndims]
+ k_pass = k_BHLd[..., self.rotary_ndims :]
+
+ q_rot, k_rot = NeoXRoPE.apply_rotary_pos_emb_offset(
+ q_rot, k_rot, cos, sin, offset=offset
+ )
+
+ q_BHLd = torch.cat((q_rot, q_pass), dim=-1)
+ k_BHLd = torch.cat((k_rot, k_pass), dim=-1)
+
+ # C) before scaled_dot_product_attention we are going to
+ # Cache QKV values
+ if has_past_kv:
+ past_key, past_value = past_kv
+
+ k_BHLd = torch.cat((past_key.type_as(k_BHLd), k_BHLd), dim=2)
+ v_BHLd = torch.cat((past_value.type_as(v_BHLd), v_BHLd), dim=2)
+
+ # kv_cache = torch.stack((k_BHLd, v_BHLd))
+ kv_cache = (k_BHLd, v_BHLd) # let's keep them as a list of tuples
+ # ####################################################################
+
+ # print(f"key device {k_BHLd.get_device()}")
+ # print(f"query device {q_BHLd.get_device()}")
+ # print(f"value device {v_BHLd.get_device()}")
+ # print(f"causal mask device {causal_mask.get_device()}")
+
+ tensor_device = q_BHLd.get_device()
+ if causal_mask.get_device() != tensor_device:
+ causal_mask = causal_mask.to(device=tensor_device)
+
+ # compute attention here
+ attn_out_BHLd = scaled_dot_product_attention(
+ q_BHLd,
+ k_BHLd,
+ v_BHLd,
+ is_causal=False,
+ attn_mask=causal_mask,
+ scale=self.norm_factor,
+ dropout_p=self.dropout_p,
+        )  # even with rectangular q/k matrices, scaled_dot_product_attention applies the explicit
+        # left-aligned causal mask we pass in, which is exactly what we need here
+
+ attn_out_BLD = self._merge_heads(attn_out_BHLd)
+
+ attn_out_BLD = self.dense(attn_out_BLD)
+
+ return attn_out_BLD, kv_cache
+
+
+# ^^^^^ #######################################################################
+
+
+class DecoderFeedForward(nn.Module):
+ """
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.ff_in = nn.Linear(config.embed_dim, config.ff_dim)
+ self.gelu = nn.GELU()
+ self.ff_out = nn.Linear(config.ff_dim, config.embed_dim)
+
+ def forward(self, x_BLD):
+ x_BLF = self.ff_in(x_BLD)
+ x_BLF = self.gelu(x_BLF)
+ out_BLD = self.ff_out(x_BLF)
+ return out_BLD
+
+
+class DecoderBlock(nn.Module):
+ """A single trasnformer decoder block/layer.
+
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+ """
+
+ def __init__(self, config):
+ super().__init__()
+ self.attention = DecoderAttentionRotary(config)
+ self.feed_forward = DecoderFeedForward(config)
+ self.ffn_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
+ self.attention_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
+
+ def forward(self, x_BLD, freqs_cis, parallel_residual=True, use_cache=True):
+ assert freqs_cis is not None
+
+ if parallel_residual:
+ out_BLD = (
+ x_BLD
+ + self.attention(self.attention_norm(x_BLD), freqs_cis)
+ + self.feed_forward(self.ffn_norm(x_BLD))
+ )
+ else:
+ h_BLD = x_BLD + self.attention(self.attention_norm(x_BLD), freqs_cis)
+ out_BLD = h_BLD + self.feed_forward(self.ffn_norm(h_BLD))
+
+ return out_BLD
+
+
+# handle KV cache state
+class DecoderBlockKVcache(DecoderBlock):
+ """A single trasnformer decoder block/layer.
+
+ Tensor dimension names:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+ """
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.attention = DecoderAttentionRotaryKVCache(config)
+ self.feed_forward = DecoderFeedForward(config)
+ self.ffn_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
+ self.attention_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
+
+ def forward(
+ self, x_BLD, freqs_cis, causal_mask, layer_kv_cache, parallel_residual=True
+ ):
+ assert freqs_cis is not None
+
+ if parallel_residual:
+ attn_out_BLD, layer_kv_cache = self.attention(
+ self.attention_norm(x_BLD), freqs_cis, causal_mask, layer_kv_cache
+ )
+ out_BLD = x_BLD + attn_out_BLD + self.feed_forward(self.ffn_norm(x_BLD))
+
+ else:
+ attn_out_BLD, layer_kv_cache = self.attention(
+                self.attention_norm(x_BLD), freqs_cis, causal_mask, layer_kv_cache
+ )
+ h_BLD = x_BLD + attn_out_BLD
+ out_BLD = h_BLD + self.feed_forward(self.ffn_norm(h_BLD))
+
+ return out_BLD, layer_kv_cache
+
+
+class DecoderWrapper(nn.Module):
+ """Full model wrapper for causal language modelling."""
+
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.token_embeddings = DecoderEmbedding(config)
+
+ if self.config.use_cache:
+ self.layers = nn.ModuleList(
+ DecoderBlockKVcache(config) for _ in range(config.num_blocks)
+ )
+ else:
+ self.layers = nn.ModuleList(
+ DecoderBlock(config) for _ in range(config.num_blocks)
+ )
+
+ self.final_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps)
+ self.output = nn.Linear(config.embed_dim, config.vocab_size, bias=False)
+
+ self.freqs_cis: Optional[Tensor] = None
+
+ self.max_batch_size = -1
+ self.max_seq_length = config.max_pos_embedding
+
+ self.rotary_pct = 0.25
+
+ self.vocab_size = config.vocab_size
+
+ def setup_caches(self):
+        head_dim = self.config.embed_dim // self.config.num_heads
+        rotary_ndims = int(head_dim * self.rotary_pct)
+ self.cos, self.sin = NeoXRoPE.precompute_sin_cos_cache(
+ dim=rotary_ndims,
+ seq_len=self.config.max_pos_embedding,
+ )
+
+ self.freqs_cis = (self.cos, self.sin)
+
+ def _generate_causal_mask(self, cache_length: int, new_length: int) -> torch.Tensor:
+ """
+ Creates a causal mask for autoregressive attention with KV caching.
+
+ Args:
+ cache_length (int): Number of cached tokens (K).
+ new_length (int): Number of new tokens to generate (T).
+
+        Returns:
+            torch.Tensor: A boolean T x (K + T) mask where:
+                - New tokens can attend to all cached tokens.
+                - New tokens can attend to new tokens up to and including their own position.
+                - Future new tokens are masked to prevent attention.
+ """
+
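+        # Illustrative example: _generate_causal_mask(cache_length=2, new_length=3) returns
+        #   [[True, True, True,  False, False],
+        #    [True, True, True,  True,  False],
+        #    [True, True, True,  True,  True ]]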
+ causal_mask = torch.tril(torch.ones(new_length, new_length)).bool()
+
+ if cache_length > 0:
+ cache_tokens_mask = torch.ones((new_length, cache_length)).bool()
+ causal_mask = torch.cat((cache_tokens_mask, causal_mask), dim=1)
+
+ return causal_mask
+
+ def forward(self, batch: Tensor, past_kv_cache=None):
+ # if self.freqs_cis is None or self.causal_mask is None:
+ if self.freqs_cis is None:
+ self.setup_caches() # Caches must be initialized first
+
+ freqs_cis = self.freqs_cis
+
+ idx = batch["input_ids"]
+ num_new_tokens = idx.size(1)
+
+ x = self.token_embeddings(idx)
+
+ if self.config.use_cache and past_kv_cache is None:
+ past_kv_cache = [None] * self.config.num_blocks
+
+        if past_kv_cache is not None and past_kv_cache[0] is not None:
+ num_cache_tokens = past_kv_cache[0][0].size(2)
+ else:
+ num_cache_tokens = 0
+ causal_mask = self._generate_causal_mask(num_cache_tokens, num_new_tokens)
+
+ kv_cache_list = []
+
+ for i, layer in enumerate(self.layers):
+ if self.config.use_cache:
+ x, layer_kv_cache = layer(x, freqs_cis, causal_mask, past_kv_cache[i])
+ kv_cache_list.append(layer_kv_cache)
+ else:
+ x = layer(x, freqs_cis)
+ x = self.final_norm(x)
+ logits = self.output(x)
+
+ return logits, kv_cache_list
+
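+# Minimal usage sketch (assumes a tokenized batch dict, as in test/test_pythia.py):
+#
+#   model = DecoderWrapper(Pythia14MConfig)
+#   logits, kv_cache = model({"input_ids": torch.tensor([[1, 2, 3]])})
+#   # logits has shape [B, L, vocab_size]; kv_cache holds one (k, v) tuple per layer when use_cache=True
+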
+
+class ComposerLM(ComposerModel):
+ """wrapper with nice properties for simple training and evaluation"""
+
+ def __init__(self, config):
+ "docstring"
+ super().__init__()
+ self.model = DecoderWrapper(config)
+ self.ce_loss = nn.CrossEntropyLoss()
+ self.train_metrics = MetricCollection(
+ [LossMetric(self.ce_loss), LanguagePerplexity()]
+ )
+ self.eval_metrics = MetricCollection(
+ [LossMetric(self.ce_loss), LanguagePerplexity()]
+ )
+
+ def forward(self, batch): # batch is the output of the dataloader
+ """batch is a dict with "input_ids" key, model also takes past_kv"""
+ # specify how batches are passed through the model
+ return self.model(batch)
+
+ def eval_forward(self, batch, outputs=False):
+ if outputs:
+ if type(outputs) is tuple:
+ outputs, _ = outputs
+ return outputs
+
+ outputs = self.model(batch)
+ if type(outputs) is tuple:
+ outputs, _ = outputs
+
+ return outputs
+
+ def update_metric(self, batch, outputs, metric) -> None:
+ targets = batch["labels"]
+ metric.update(outputs.view(-1, self.model.vocab_size), targets.view(-1))
+
+ def get_metrics(self, is_train=False) -> dict[str, Metric]:
+ # defines which metrics to use in each phase of training
+ return self.train_metrics if is_train else self.eval_metrics
+
+ def loss(self, outputs, batch):
+ targets = batch["labels"]
+
+ if type(outputs) is tuple:
+ outputs, _ = outputs
+
+ return self.ce_loss(outputs.view(-1, self.model.vocab_size), targets.view(-1))
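+
+
+# Training usage sketch (mirrors rafale/main.py): wrap a config with ComposerLM and hand the model
+# to composer's Trainer, e.g.
+#
+#   model = ComposerLM(Pythia14MConfig)
+#   trainer = Trainer(model=model, train_dataloader=..., eval_dataloader=..., max_duration=1)
+#   trainer.fit()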
diff --git a/rafale/models/decoding_strategies.py b/rafale/models/decoding_strategies.py
new file mode 100644
index 0000000..8e11eb0
--- /dev/null
+++ b/rafale/models/decoding_strategies.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn.functional as F
+
+
+def repeat_ngram(input_ids, ngram_max_length=4):
+ """
+    Clean up output by checking for repeated n-grams of length up to ngram_max_length.
+ If it finds a repeated n-gram, it returns the n-gram; otherwise, it returns 0.
+
+ Args:
+ input_ids: Tensor of shape (seq_length,), the sequence of input token IDs.
+ ngram_max_length: The maximum length of the n-gram to check for repetition.
+
+ Returns:
+ A list of tokens representing the repeated n-gram if found, otherwise 0.
+ """
+ ngram_found = False
+ ngram_tokens = None
+
+ # Create a set to store seen n-grams
+ seen_ngrams = set()
+
+ # Iterate through possible n-gram lengths
+ for n in range(1, ngram_max_length + 1):
+ for i in range(len(input_ids) - n + 1):
+ ngram = tuple(input_ids[i : i + n].tolist())
+ if ngram in seen_ngrams:
+ ngram_found = True
+ ngram_tokens = list(ngram)
+ break
+ seen_ngrams.add(ngram)
+ if ngram_found:
+ break
+
+ if ngram_found:
+ return ngram_tokens
+ else:
+ return 0
+
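+# e.g. repeat_ngram(torch.tensor([5, 6, 5]), ngram_max_length=4) returns [5] (token 5 repeats),
+# while repeat_ngram(torch.tensor([5, 6, 7]), ngram_max_length=4) returns 0
+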
+
+def greedy_decode(model, batch, max_length, eos_token_id, check_repeat_ngrams=True):
+ """
+ Implements greedy decoding for the rafale transformer model.
+
+ Args:
+ model: The decoder model (e.g., DecoderWrapper).
+ batch: Dictionary containing input_ids of shape (batch_size, seq_length), the input prompt.
+ max_length: The maximum length of the generated sequence.
+ eos_token_id: The ID of the end-of-sequence token.
+
+ Returns:
+ Dictionary containing input_ids of shape (batch_size, max_length) with the generated tokens.
+ """
+ batch_size = batch["input_ids"].size(0)
+ if batch_size != 1:
+ raise ValueError(
+ "greedy_decode currently only supports batch_size=1. Provided batch_size: {}".format(
+ batch_size
+ )
+ )
+
+ input_seq_len = batch["input_ids"].size(1)
+ kv_cache_list = None
+
+ # Generate tokens until max_length or eos_token is generated
+ for _ in range(max_length - input_seq_len):
+ # Forward pass through the model
+ outputs, kv_cache_list = model(batch, kv_cache_list)
+ logits = outputs[:, -1, :] # Get the logits for the last generated token
+
+ # Greedily select the token with the highest probability
+ next_token = torch.argmax(logits, dim=-1).unsqueeze(-1) # Shape: (1, 1)
+
+ # Append the predicted token to the generated sequence
+ batch["input_ids"] = torch.cat((batch["input_ids"], next_token), dim=1)
+
+ # Check for repeated n-grams and stop if detected
+ if check_repeat_ngrams:
+ repeated_ngram = repeat_ngram(
+ batch["input_ids"].squeeze(), ngram_max_length=4
+ )
+ if repeated_ngram != 0:
+                print(f"Stopping generation: repeated n-gram {repeated_ngram} detected")
+ break
+
+ # Check if the sequence has generated the eos_token_id
+ if next_token.item() == eos_token_id:
+ break
+
+ return batch
diff --git a/rafale/models/encoder.py b/rafale/models/encoder.py
index cdabd85..4c7ee04 100644
--- a/rafale/models/encoder.py
+++ b/rafale/models/encoder.py
@@ -1,15 +1,34 @@
from dataclasses import dataclass
+
import torch
import torch.nn.functional as F
-# import xformers.ops as xops
-
from torch import nn
from torch.nn.functional import scaled_dot_product_attention
+###############################################################################
+# SIMPLE BERT-like BUILDING BLOCKS #
+###############################################################################
+
+
+# @TODO :: Refactor, improve documentation and add tensor dimension keys for the names
+
+
class Embedding(nn.Module):
- """embeddings from word, absolute/fixe position (and token_type embedding?)"""
+ """Embeddings
+
+    In addition to the word embeddings, BERT uses learned absolute position embeddings. We also have
+    token type embeddings for one of the BERT pretraining objectives (next sentence prediction).
+
+ Tensor dimension keys:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
+ """
def __init__(
self,
@@ -67,98 +86,19 @@ def forward(self, input_ids, token_type_ids):
return E
-class MultiHeadAttention(nn.Module):
- def __init__(self, n_heads, embed_dim, dropout_p=0.1, fast_attn=False):
- "uses xformers memory effeicient attention"
- super().__init__()
- self.dropout_p = dropout_p
- self.fast_attn = fast_attn
- assert embed_dim % n_heads == 0
- # We assume d_v always equals d_k
- self.head_dim = embed_dim // n_heads
- self.n_heads = n_heads
- self.embed_dim = embed_dim
- self.all_head_size = n_heads * self.head_dim
-
- # @TODO get linear projections
- self.query = nn.Linear(embed_dim, embed_dim)
- self.key = nn.Linear(embed_dim, embed_dim)
- self.value = nn.Linear(embed_dim, embed_dim)
-
- # output
- self.out = nn.Linear(embed_dim, embed_dim)
-
- def forward(self, q, k, v):
- # @HERE TODO
- """
- so input q, k, v are essentially the same tensor since we're doing self attention but the idea here is that this
- same implementation could be used for cross attention.
-
- in self-attention they have dimension [batch_size, sequence_length, embedding_dimension] so for bert that would
- be like [4, 512, 768] for example.
-
- each attention head will handle part of the embedding dimensions (wow I didn't know that and don't fully
- understand why...). So this is why we want to have embed_dim % n_head == 0.
-
- (1) we use view to reshape the tensor into shape [batch_size, seq_length, n_head, head_embed] --> .view(batch_size,
- -1, self.num_heads, self.head_dim)
- (2) then we transpose seq_length and n_head to parrellalize the computations during the attention computations
- --> .transpose(1, 2)
-
-
- ## Summary of Shape Changes
- Input: [batch_size, seq_length, embed_dim]
- Post Linear Layer: [batch_size, seq_length, embed_dim] (same shape, but transformed)
- View for Heads: [batch_size, seq_length, num_heads, head_dim]
- Transpose for Heads: [batch_size, num_heads, seq_length, head_dim]
-
- ## after having applied attention
- We receive a tensor of shape [batch_size, num_heads, seq_length, head_dim] (same as before)
- Now we want to get back to our original embedding and sequence shape so first we swap back num_head and
- seq_length with --> .transpose(1,2)
- Then we want to aggregate our head_dim to have our full embedding space back up together again with -->
- .view(batch_size, -1, self.embed_dim)
- and we get shape [batch_size, seq_length, embed_dim] at the end
-
- """
-
-
-# @TODO: separate the attention modules to more easily inspect tensors
class EncoderSelfAttention(nn.Module):
- """just MHA with the various implementations
- so input q, k, v are essentially the same tensor since we're doing self attention but the idea here is that this
- same implementation could be used for cross attention.
-
- in self-attention they have dimension [batch_size, sequence_length, embedding_dimension] so for bert that would
- be like [4, 512, 768] for example.
-
- each attention head will handle part of the embedding dimensions (wow I didn't know that and don't fully
- understand why...). So this is why we want to have embed_dim % n_head == 0.
-
- (1) we use view to reshape the tensor into shape [batch_size, seq_length, n_head, head_embed] --> .view(batch_size,
- -1, self.num_heads, self.head_dim)
- (2) then we transpose seq_length and n_head to parrellalize the computations during the attention computations
- --> .transpose(1, 2)
-
-
- ## Summary of Shape Changes
- Input: [batch_size, seq_length, embed_dim]
- Post Linear Layer: [batch_size, seq_length, embed_dim] (same shape, but transformed)
- View for Heads: [batch_size, seq_length, num_heads, head_dim]
- Transpose for Heads: [batch_size, num_heads, seq_length, head_dim]
-
- ## after having applied attention
- We receive a tensor of shape [batch_size, num_heads, seq_length, head_dim] (same as before)
- Now we want to get back to our original embedding and sequence shape so first we swap back num_head and
- seq_length with --> .transpose(1,2)
- Then we want to aggregate our head_dim to have our full embedding space back up together again with -->
- .view(batch_size, -1, self.embed_dim)
- and we get shape [batch_size, seq_length, embed_dim] at the end
-
+ """Bidirectional multi-head self attention.
+
+ Tensor dimension keys:
+ - B batch size
+ - L sequence length
+ - H number of attention heads
+ - D embedding dimension
+ - d attention head dimension D//H
+ - F feedforward dimension
"""
def __init__(self, n_heads, embed_dim, dropout_p=0.1, fast_attn=False):
- "uses xformers memory effeicient attention"
super().__init__()
self.dropout_p = dropout_p
self.fast_attn = fast_attn
@@ -176,12 +116,10 @@ def __init__(self, n_heads, embed_dim, dropout_p=0.1, fast_attn=False):
self.value = nn.Linear(embed_dim, embed_dim)
def forward(self, q, k, v):
- # @HERE TODO
""""""
batch_size = q.size(0)
if not self.training:
self.dropout_p = 0
- print("model not training, attention dropout is 0")
# check transformation again here....
q = (
@@ -200,26 +138,14 @@ def forward(self, q, k, v):
.transpose(1, 2)
)
- # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention
- # check for flash-attn-2? optional
- if self.fast_attn:
- raise Exception("fast attention not implemented yet")
- # attn_output = xops.memory_efficient_attention(
- # q,
- # k,
- # v,
- # p=self.dropout_p,
- # )
-
- else:
- attn_output = scaled_dot_product_attention(
- q,
- k,
- v,
- dropout_p=self.dropout_p,
- )
+ attn_output = scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ dropout_p=self.dropout_p,
+ )
- # Concatenate heads and put through final linear layer
+ # concatenate heads and put through final linear layer
attn_output = (
attn_output.transpose(1, 2)
.contiguous()
@@ -267,7 +193,7 @@ def __init__(self, embed_dim, eps=None, dropout_p=None):
self.dropout = nn.Dropout(dropout_p)
def forward(self, x, residual):
- x = self.dropout(x)
+ x = self.dropout(x) # @TODO :: make sure this should be here...
x = self.ln(x + residual)
return x
@@ -277,10 +203,6 @@ class EncoderBlock(nn.Module):
def __init__(self, embed_dim, n_heads, ff_dim, eps=None, dropout_p=None):
super().__init__()
- # self.mha = MultiHeadAttention(
- # n_heads=n_heads, embed_dim=embed_dim, dropout_p=dropout_p
- # )
-
self.attention = AttentionModule(
n_heads=n_heads, embed_dim=embed_dim, dropout_p=dropout_p
)
@@ -372,82 +294,6 @@ def forward(self, **kwargs):
return x
- def compute_decoupled_label_loss(self, logits, labels, permutations, tokenizer):
- """ """
- batch_size = logits.size(0)
- F.softmax(logits, dim=-1)
- letters = list(map(chr, range(97, 123)))
-
- # Define the loss functions
- # same used as adapet
-
- nn.BCELoss(reduction="none")
- nn.BCEWithLogitsLoss(reduction="none")
-
- losses = []
-
- for i in range(batch_size):
- num_labels = permutations[i]
- # get all the relevant ids (choices) from the sample
- relevant_ids = tokenizer.convert_tokens_to_ids(letters[:num_labels])
-
- # the token id of the positive label of the example
- example_label = labels[i]
- example_label_id = example_label[example_label != -100].item()
-
- if len(relevant_ids) == 1:
- # caveat if example only has one choice...
- example_bad_ids = []
- else:
- relevant_ids.remove(example_label_id)
- example_bad_ids = relevant_ids
-
- # get the logit predictions from the mask token
- l = logits[i]
- # l = probabilities[i]
- mask_token_logits = l[labels[i] != -100]
- mask_token_logits = torch.flatten(mask_token_logits)
-
- indices = [x for x in range(l.size(1)) if x not in relevant_ids]
- non_choice_logits = torch.index_select(
- mask_token_logits, 0, torch.tensor(indices)
- )
-
- # probability logits for the positive label
- positive_prediction = mask_token_logits[example_label_id]
-
- negative_losses = []
- negative_label = torch.zeros([])
- # probability logits for the negative labels
- for idx in example_bad_ids:
- # negative_predictions.append(mask_token_logits[idx])
- negative_losses.append(
- self.bcel_loss(mask_token_logits[idx], negative_label)
- )
-
- nulltoken_labels = torch.zeros(len(indices)) # device="cuda:0"
- positive_labels = torch.ones([])
-
- # mean of bad labels bcel
- nulltoken_label_loss = torch.mean(
- self.bcel_loss(non_choice_logits, nulltoken_labels)
- )
- negative_label_loss = torch.sum(torch.stack(negative_losses))
- positive_label_loss = self.bcel_loss(positive_prediction, positive_labels)
-
- losses.append(
- torch.sum(
- torch.stack(
- [
- positive_label_loss.view(1),
- nulltoken_label_loss.view(1),
- negative_label_loss.view(1),
- ]
- )
- )
- )
- return torch.mean(torch.stack(losses))
-
def compute_loss(self, logits, labels):
""" """
ce_loss = nn.CrossEntropyLoss(ignore_index=-100)
@@ -460,37 +306,3 @@ def compute_loss(self, logits, labels):
# Compute and return the loss
return ce_loss(logits, labels)
-
-
-# DELETE THIS
-def get_tokens_from_logits(logits, tokenizer=None):
- """
- return the prediced tokens for all of the inputs
- """
- # Apply softmax to convert logits to probabilities
- probabilities = F.softmax(logits, dim=-1)
-
- # Get the predicted token IDs
- predicted_token_ids = torch.argmax(probabilities, dim=-1)
-
- predicted_tokens = [
- tokenizer.convert_ids_to_tokens(seq.numpy())
- for seq in torch.unbind(predicted_token_ids, dim=0)
- ]
- return predicted_tokens
-
-
-@dataclass
-class BertConfig:
- embed_dim: int = 768
- vocab_size: int = 30522 # could usage would be to 30522 + num_extra_tokens
- attention_dropout: float = 0.1
- hidden_dropout: float = 0.1
- num_heads: int = 12
- ff_dim: int = 3072
- max_pos_embedding: int = 512
- layer_norm_eps: float = 1e-12
- num_blocks: int = 12
- pad_token_id: int = 0
- num_token_type: int = 2
- fast_attention: bool = False # use xformers (todo: add FlashAttention2)
diff --git a/rafale/models/model_utils.py b/rafale/models/model_utils.py
new file mode 100644
index 0000000..04caed1
--- /dev/null
+++ b/rafale/models/model_utils.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn.functional as F
+
+
+def get_tokens_from_logits(logits, tokenizer=None):
+ """
+ return the prediced tokens for all of the inputs
+ """
+ # Apply softmax to convert logits to probabilities
+ probabilities = F.softmax(logits, dim=-1)
+
+ # Get the predicted token IDs
+ predicted_token_ids = torch.argmax(probabilities, dim=-1)
+
+ predicted_tokens = [
+ tokenizer.convert_ids_to_tokens(seq.numpy())
+ for seq in torch.unbind(predicted_token_ids, dim=0)
+ ]
+ return predicted_tokens
diff --git a/rafale/test.yaml b/rafale/test.yaml
deleted file mode 100644
index 690e14e..0000000
--- a/rafale/test.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-run:
- name: "test"
-
-model:
- name: "test-model"
- bin: "./"
-
-data:
- training_path: "training_path"
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 49f3772..0000000
--- a/setup.py
+++ /dev/null
@@ -1,11 +0,0 @@
-from setuptools import setup
-
-setup(
- name="rafale",
- version="0.1",
- description="simple transformer training lib",
- url="https://github.com/maxrousseau/rafale",
- author="maxime rousseau",
- packages=["rafale"],
- zip_safe=False,
-)
diff --git a/test/pythia_tinystories.yaml b/test/pythia_tinystories.yaml
new file mode 100644
index 0000000..b9d4c04
--- /dev/null
+++ b/test/pythia_tinystories.yaml
@@ -0,0 +1,43 @@
+run:
+ name: "pythia14m-tinystories" # name of your experiment, used for checkpointing
+ seed: 42
+ n_epochs: 1
+ max_lr: 6e-04
+ warmup_pct: 0.01
+ schedule: "cosine-warmup" # linear, linear-warmup, cosine, cosine-warmup
+ optimizer: "AdamW"
+ eval_interval: "50ba"
+ clip_type: "norm"
+ clip_value: 1.0
+ device_bs: "auto"
+ save_interval: "200ba"
+
+model:
+ config: "pythia14m" # config key
+ type: "decoder"
+ use_pretrained: True
+
+data:
+ pipeline: "tinystories_neox" # the preprocessing/tokenization pipeline
+ config:
+ name: "tinystories"
+ num_processes: 8
+ tokenizer_name: "neox"
+ is_prepared: False
+ input_id_key: "input_ids"
+ train_batch_size: 1024
+ eval_batch_size: 16
+ shuffle_train: False
+ dataset_path: "~/code/data/TinyStories"
+ tokenizer_path: "EleutherAI/pythia-14m"
+ max_sequence_length: 512
+ pad_token_id: -100
+ pad_inputs: True
+
+logging: # @TODO :: not implemented
+ use_wandb: True
+ use_file: False
+ eval_interval: "10ba"
+ log_dir: "./run_logs"
+ checkpoint_dir: "./checkpoints"
diff --git a/test/test.yaml b/test/test.yaml
new file mode 100644
index 0000000..856f851
--- /dev/null
+++ b/test/test.yaml
@@ -0,0 +1,42 @@
+# we want data and model configurations to be in files rather than in yaml, leave training hyperparams to yaml config only
+
+run:
+ name: "test-ministories" # name of your experiment, used for checkpointing
+ seed: 42
+ n_epochs: 1
+ max_lr: 3e-04
+ warmup_pct: 0.01
+ schedule: "cosine-warmup" # linear, linear-warmup, cosine, cosine-warmup
+ optimizer: "AdamW"
+ eval_interval: "50ba"
+ clip_type: "norm"
+ clip_value: 1.0
+ device_bs: "auto"
+ save_interval: "1ep"
+
+model:
+ config: "pythia14m" # config key
+ type: "decoder"
+ use_pretrained: True
+
+data:
+ pipeline: "tinystories_neox" # the preprocessing/tokenization pipeline
+ config:
+ name: "tinystories_testing"
+ num_processes: 1
+ tokenizer_name: "neox"
+ is_prepared: False
+ input_id_key: "input_ids"
+ train_batch_size: 16
+ eval_batch_size: 16
+ shuffle_train: False
+ dataset_path: "~/code/data/micro_tinystories"
+ tokenizer_path: "EleutherAI/pythia-14m"
+ max_sequence_length: 128
+ pad_token_id: -100
+ pad_inputs: True
+
+logging:
+ use_wandb: True
+ use_file: False
diff --git a/rafale/models/debug_encoders.py b/test/test_bert.py
similarity index 98%
rename from rafale/models/debug_encoders.py
rename to test/test_bert.py
index ce5b11b..d36b984 100644
--- a/rafale/models/debug_encoders.py
+++ b/test/test_bert.py
@@ -1,3 +1,4 @@
+#!/usr/bin/env python
import os
from datapipe import WikiMLMPipe
@@ -122,3 +123,7 @@ def test_bert():
atol=5e-4,
rtol=5e-4,
)
+
+
+if __name__ == "__main__":
+    test_bert()
diff --git a/test/test_pythia.py b/test/test_pythia.py
new file mode 100644
index 0000000..cd1b422
--- /dev/null
+++ b/test/test_pythia.py
@@ -0,0 +1,255 @@
+import torch
+import numpy as np
+
+from safetensors import safe_open
+
+from rafale.models.decoder import DecoderWrapper
+from rafale.models.configurations import Pythia14MConfig, load_safetensors
+
+from transformers import AutoTokenizer, GPTNeoXForCausalLM
+
+
+def test_layer_and_outputs(rafale_model, hf_model, tokenizer, layer=0, tol=1e-05):
+ """
+    # @NOTE :: this currently only evaluates with the KV cache enabled; extend the test to also run without the KV cache
+ """
+ hf_activation = {}
+ hf_input_activation = {}
+ rafale_activation = {}
+ rafale_input_activation = {}
+
+ # tuple of shape num_layers, 2 (keys, values), tensor BHLd
+ # make a fake kv-cache of length 4
+ kv_cache = []
+ n_layers = 6
+ cache_len = 4
+ n_heads = 4
+ head_dim = 32
+ for i in range(n_layers):
+        k = torch.randn(1, n_heads, cache_len, head_dim)
+        v = torch.randn(1, n_heads, cache_len, head_dim)
+ kv_cache.append((k, v))
+
+ def get_hf_activation(name):
+ def hook(model, input, output):
+ hf_activation[name] = output.detach()
+
+ return hook
+
+ def get_hf_input_activation(name):
+ def hook(model, _input, output):
+ hf_input_activation[name] = _input[0].detach()
+
+ return hook
+
+ def get_rafale_activation(name):
+ def hook(model, input, output):
+ rafale_activation[name] = output.detach()
+
+ return hook
+
+ def get_rafale_input_activation(name):
+ def hook(model, _input, output):
+ rafale_input_activation[name] = _input[0].detach()
+
+ return hook
+
+ # embeddings
+ rafale_model.token_embeddings.register_forward_hook(
+ get_rafale_activation("input_embeddings")
+ )
+ hf_model.gpt_neox.embed_in.register_forward_hook(
+ get_hf_activation("input_embeddings")
+ )
+
+ # input layernorm
+ rafale_model.layers[layer].attention_norm.register_forward_hook(
+ get_rafale_activation("attn_norm")
+ )
+ hf_model.gpt_neox.layers[layer].input_layernorm.register_forward_hook(
+ get_hf_activation("attn_norm")
+ )
+
+ rafale_model.layers[layer].attention_norm.register_forward_hook(
+ get_rafale_input_activation("input_attn_norm")
+ )
+ hf_model.gpt_neox.layers[layer].input_layernorm.register_forward_hook(
+ get_hf_input_activation("input_attn_norm")
+ )
+
+ # attention projection query_key_values (pre RoPE)
+ rafale_model.layers[layer].attention.query_key_value.register_forward_hook(
+ get_rafale_activation("attn_inproj")
+ )
+
+ hf_model.gpt_neox.layers[layer].attention.query_key_value.register_forward_hook(
+ get_hf_activation("attn_inproj")
+ )
+
+ # out proj
+ rafale_model.layers[layer].attention.dense.register_forward_hook(
+ get_rafale_activation("attn_dense")
+ )
+
+ hf_model.gpt_neox.layers[layer].attention.dense.register_forward_hook(
+ get_hf_activation("attn_dense")
+ )
+
+ # INPUT check before attention dense* (if this fails then RoPE is probably the problem...)
+ rafale_model.layers[layer].attention.dense.register_forward_hook(
+ get_rafale_input_activation("attn_dense")
+ )
+ hf_model.gpt_neox.layers[layer].attention.dense.register_forward_hook(
+ get_hf_input_activation("attn_dense")
+ )
+
+ # feed forward out
+ rafale_model.layers[layer].feed_forward.ff_out.register_forward_hook(
+ get_rafale_activation("ffout")
+ )
+ hf_model.gpt_neox.layers[layer].mlp.dense_4h_to_h.register_forward_hook(
+ get_hf_activation("ffout")
+ )
+
+ # hf_model.gpt_neox.layers[0].attention(tensor)
+
+ input_str = "Hello World from pythia!"
+ tokens = tokenizer(input_str, return_tensors="pt")
+
+ hf_model.eval()
+ rafale_model.eval()
+
+ with torch.no_grad():
+ # hf_out = hf_model(tokens["input_ids"])["logits"].detach().numpy()
+
+ hf_out = hf_model(tokens["input_ids"], use_cache=True, past_key_values=kv_cache)
+ hf_out = hf_out["logits"].detach().numpy()
+
+ # rafale_out = rafale_model(tokens)[0].detach().numpy()
+
+ rafale_out = rafale_model(tokens, past_kv_cache=kv_cache)[0].detach().numpy()
+
+ print(f"Dropout p should be 0: {rafale_model.layers[layer].attention.dropout_p}")
+ print(f"Testing layer {layer}")
+ # EMBEDDING ###################################################################
+ try:
+ np.testing.assert_allclose(
+ rafale_activation["input_embeddings"].numpy(),
+ hf_activation["input_embeddings"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+
+ print(f"✅ embeddings OK!")
+ except:
+ print("⚠️ Embedding difference!")
+
+ # PRE-ATTENTION NORM ######################################################
+ try:
+ np.testing.assert_allclose(
+ rafale_input_activation["input_attn_norm"].numpy(),
+ hf_input_activation["input_attn_norm"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ INPUTS of pre-attention norm OK!")
+ except:
+ print("⚠️ INPUTS pre-attention norm difference")
+
+ try:
+ np.testing.assert_allclose(
+ rafale_activation["attn_norm"].numpy(),
+ hf_activation["attn_norm"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ pre-attention norm OK!")
+ except:
+ print("⚠️ pre-attention norm difference")
+
+ # LINEAR PROJECTION FOR ATTENTION ######################################################
+
+ try:
+ np.testing.assert_allclose(
+ rafale_activation["attn_inproj"].numpy(),
+ hf_activation["attn_inproj"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ attention in-projection OK!")
+
+ except:
+ print(f"⚠️ attention in-projection difference")
+
+ # INPUTS OF ATTENTION DENSE LAYER ########################################
+ try:
+ np.testing.assert_allclose(
+ rafale_input_activation["attn_dense"].numpy(),
+ hf_input_activation["attn_dense"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ inputs of attention dense OK")
+
+ except:
+ r = rafale_input_activation["attn_dense"].numpy()
+ h = hf_input_activation["attn_dense"].numpy()
+
+ print(r.shape)
+ print(h.shape)
+ print("⚠️ inputs of attention dense difference")
+
+ np.testing.assert_allclose(
+ rafale_input_activation["attn_dense"].numpy(),
+ hf_input_activation["attn_dense"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+
+ # OUTPUT OF ATTENTION DENSE LAYER ########################################
+ try:
+ np.testing.assert_allclose(
+ rafale_activation["attn_dense"].numpy(),
+ hf_activation["attn_dense"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ attention out dense OK!")
+ except:
+ print("⚠️ attention out dense difference")
+
+ try:
+ np.testing.assert_allclose(
+ rafale_activation["ffout"].numpy(),
+ hf_activation["ffout"].numpy(),
+ rtol=tol,
+ atol=tol,
+ )
+ print(f"✅ feedforward out dense OK!")
+ except:
+ print("⚠️ ff out dense difference")
+
+ # Final model output ########################################
+ try:
+ np.testing.assert_allclose(rafale_out, hf_out, rtol=tol, atol=tol)
+ print(f"🎉 Model outputs match reference implementation!")
+ except:
+ print("❌ Model outputs do not match")
+
+
+def main():
+ """ """
+ torch.manual_seed(0)
+ np.random.seed(0)
+
+ hf_pythia = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-14m")
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
+
+ rafale_pythia = DecoderWrapper(Pythia14MConfig) # check forward pass OK
+ rafale_pythia = load_safetensors(rafale_pythia, Pythia14MConfig)
+
+ test_layer_and_outputs(rafale_pythia, hf_pythia, tokenizer)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/test/test_pythia_generation.py b/test/test_pythia_generation.py
new file mode 100644
index 0000000..572b218
--- /dev/null
+++ b/test/test_pythia_generation.py
@@ -0,0 +1,37 @@
+import torch
+import numpy as np
+
+# from safetensors import safe_open
+# from tokenizers import Tokenizer
+
+from rafale.datapipe import InferenceDatapipeline
+from rafale.models.decoding_strategies import greedy_decode
+from rafale.models.decoder import DecoderWrapper
+from rafale.models.configurations import Pythia14MConfig, load_safetensors
+
+
+# Example usage
+# Initialize the rafale_pythia model
+rafale_pythia = DecoderWrapper(Pythia14MConfig)
+rafale_pythia = load_safetensors(rafale_pythia, Pythia14MConfig)
+ifdp = InferenceDatapipeline("EleutherAI/pythia-14m")
+test_str = "Once upon a time,"
+
+# Define input_ids (e.g., starting with a token)
+# input_ids = torch.tensor([[Pythia14MConfig.bos_token_id]]) # Shape: (1, 1)
+
+# Define maximum sequence length for generation
+max_length = 32
+
+# Generate sequence using greedy decoding
+generated_sequence = greedy_decode(
+ rafale_pythia,
+ ifdp(test_str),
+ max_length,
+ Pythia14MConfig.eos_token_id,
+ check_repeat_ngrams=True,
+)
+
+generated_str = ifdp.ids_to_str(generated_sequence["input_ids"])
+
+print("Generated sequence:", generated_str)