diff --git a/.gitignore b/.gitignore index c852262..c12b2d2 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ __pycache__/ *.py[cod] *$py.class +wandb/ *.DS_Store @@ -160,3 +161,6 @@ cython_debug/ # and can be added to the global gitignore or merged into this file. For a more nuclear # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ + +# uv +uv.lock \ No newline at end of file diff --git a/README.md b/README.md index 35b1c34..3b7ed79 100644 --- a/README.md +++ b/README.md @@ -1,61 +1,170 @@ +
+ # rafale -Rafale is (for now) a simple and opinionated transformer encoder training CLI. + -## Dependencies +
-Attempting to balance ergonomics and simplicity. This is meant to be easily hackable for research purposes. +## 💡Purpose -``` -torch, lightning-fabric (or) accelerate, datasets, rich (eyecandy) ~~tokenizers will be removed~~ -``` +Rafale provides an opinionated scaffolding for training transformers. It is solely built to be an efficient +learning/research tool. It is **not** a fully fledged library for large scale training. -@TODO :: (check out this stream on HF accelerate)[https://www.youtube.com/watch?v=X-Jx5-YskKY] +It should be thought of as a starting point for research projects to bootstrap experiments on small LMs. The best way to +use rafale is to simply fork it and build on top of it for your specific purposes. +### Core dependencies -## Purpose +Attempting to balance ergonomics and simplicity. This is meant to be easily hackable for research purposes. -This package is solely built to be an efficient research tool. It will not support data preprocessing/handling -pipelines. It should be thought of as a starting point for research projects to bootstrap experiments on small LMs. +``` +torch, composer, datasets, tokenizers +``` -Should be used pip installable via git and setup to be easily hackable to build on top of it. +## 🚀 Installation & Usage -Datasets should be preshuffled and pretokenized, only load it from disk and feed it to the dataloader with the collator -function. +Setup with ```uv``` ([install uv](https://github.com/astral-sh/uv)). +```sh +$ git clone +$ cd rafale +$ uv venv +$ . .venv/bin/activate +$ uv pip install -r cuda-requirements.txt (or cpu-requirements.txt) +$ uv pip install -e . +``` + +Launch a run with a configuration. -## Usage +```sh +$ python rafale/main -c test/pythia_tinystories.yaml +``` -Mostly optimized for SLURM clusters. +What if I just want to prepare my dataset? ```DATA=1``` will run the data preparation and caching pipeline without +launching the training run. ```sh +$ DATA=1 python rafale/main -c test/pythia_tinystories.yaml +``` -rafale run -c config.yaml # set DEBUG=1 for a sanity check +What if I want to test my model to make sure that its learning? ```DEBUG=1``` will run 10 epochs on a single training +batch (same for train/eval), the model should fit quickly if there are no bugs in the implementation. +```sh +$ DEBUG=1 python rafale/main -c test/pythia_tinystories.yaml ``` -## Roadmap - -v0.1 -- [ ] Local model weight loading -- [ ] load weights from safetensors and include it in the config (BERT and GPT2) -- [ ] integration with lighteval (?) -- [ ] Logging/Progress/Debug outputs with Rich library -- ~~RoBERTa BPE tokenizer with TikToken (compare w/ HF), on the fly tokenization to be handled by dataloader's - collator (for MLM)~~ - - ~~model will be tied to the tokenizer, so dataloader will be defined after the model and use it's tokenizer~~ - - We don't want anything to do with preprocessing, all data should be split/augmented/shuffled/tokenized/etc. 
All we - do with this tool is load it from disk, turn it to a tensor and send it to the model -- [ ] Local dataloader -- [ ] ```debug``` single batch debug -- [ ] ```main.py``` handles both training and evaluation (together or separately) -- [-] BERT/RoBERTa support (MLM objective) - + [ ] move the testing in the notebook to a debug file in the modeling folder - + **layerwise decay** for fine-tuning (https://kozodoi.me/blog/20220329/discriminative-lr) - + optimizations : flash attn2, xformers layer_norm (triton) or RMSNorm, xformers fused_linear_layer - + RMSNorm -- [ ] simple trainer (see lightning-fabric simple trainer example and start from there) - + bf16/fp16, gradient clipping, and gradient accumulation - -v0.2 -- DeepSpeed ZeRO - - Streaming dataloader + +### 🔧 Under the hood + +The goal of rafale is to provide a single entry point for data preparation and training. You configure the model and +dataset. Then call the training job. + +When calling a run, first we run the datapipepline. If the dataset has already been processed (tokenized, padded, +chunked, etc.), it will be loaded from the cache (default location is ```~/.rafale_cache```. + +> [!NOTE] +> #### Adding a new model +> To add a new model, you need write a new configuration to ```rafale/models/configurations.py```, and add it's key to +> ```model_config_dict``` in ```rafale/main.py```. +> +> Look at the ```ComposerLM``` wrapper class in ```rafale/models/decoder.py``` to check if all your building blocks are +> there. Otherwise you may need to modify/write a new wrapper. +> +> #### Adding a new datapipeline +> +> If the dataset is hosted on huggingface, simply use git lfs to clone the repo locally or use the repo name as the +> dataset path. Same goes for tokenizers since we use their tokenizer implementation. +> +> You will need to add a new datapipeline class in ```rafale/datapipes.py``` where the ```_prepare``` method all data +> preprocessing (tokenization, chunking, truncation, etc.) **EXCEPT** padding. Padding will be performed by the datacollator. + +### 📕 Docs + +Append this file ```llm-docprompt.txt``` to your favorite LLM and ask away. + +### 🦾 Supported models + + +| Name | Implemented | Inference test | Training test | +|:------------|:------------|:---------------|:--------------| +| BERT | ✅ | | | +| RoBERTa | ✅ | | | +| Pythia | ✅ | ✅ | ✅ | +| CLIP/SigLIP | ⏳ | | | + + +## 🔮 Roadmap + +
+ v0.1 + + +### v0.1 - initial release +- [x] single entrypoint CLI +- [ ] simple deploy/build + - [x] CPU macos build - Ok, uv run works with this + - [x] local linux machine - for now uv for venv + requirements.txt + - [ ] SLURM compute-canada - TBD + - NOTE: because uv still does not fully play well with pytorch recommend semi-manual setup* +- [ ] load weights from safetensors and include it in the config (BERT/RoBERTa and Pythia) + - [x] pythia + - [ ] BERT/RoBERTa (need to move from HF to safetensors) + - [ ] MLM + - [ ] Classification +- [x] Pythia KV-cache implementation +- [x] greedy generation +- [ ] datapipes for CLM and MLM + - local dataloader for now + - [x] CLM tinystories + - [ ] MLM tinystories + - [ ] Imdb classification +- [x] ```main.py``` handles both training and evaluation (together or separately) +- [x] Mosaic Composer/Trainer + + [x] fp16 + + [x] gradient clipping + + [x] gradient accumulation (automatically handled by composer) + + [x] building blocks are nn.Modules, specific models are ComposerModel classes with methods to load safetensor weights + automatically (keep in a single separate file for each model) + + [x] set DEBUG=1 for 1 batch sanity check before launching a run + +Datapipelines +1. [x] tokenize +2. [x] concat and split w/ block size (pad w/ collator) +3. [x] save to disk {source}_{tokname}_bs{int}_len{int} +4. [x] data_collator: *next* pad (if desired), label shift right and return torch tensor # HF: does this in the model... +5. [x] test with model training +6. [ ] tiny stories but for MLM also +
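+
+A rough sketch of what the collator in step 4 boils down to (illustrative only: the function name and defaults
+here are assumptions, see the datapipes in ```rafale/datapipe.py``` for how padding is actually handled):
+
+```python
+import torch
+from torch.nn.utils.rnn import pad_sequence
+
+
+def clm_collator(examples, pad_token_id=0, ignore_index=-100):
+    """Pad a batch of pre-tokenized examples and build next-token labels.
+
+    Assumes right padding with a dedicated pad id; loss is never computed on padding.
+    """
+    input_ids = pad_sequence(
+        [torch.tensor(ex["input_ids"], dtype=torch.long) for ex in examples],
+        batch_first=True,
+        padding_value=pad_token_id,
+    )
+    labels = torch.full_like(input_ids, ignore_index)
+    labels[:, :-1] = input_ids[:, 1:]              # target at position i is token i + 1
+    labels[labels == pad_token_id] = ignore_index  # never score padded targets
+    return {"input_ids": input_ids, "labels": labels}
+```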
+ +
+ v1.0 + +### path to v1.0 +cleanup and additional features +- [ ] clean up ```tests``` for pythia and bert models on tinystories +- [ ] move the testing in the notebook to a debug file in the modeling folder +- [ ] optimizations : flash attn2, xformers layer_norm (triton) or RMSNorm, xformers fused_linear_layer +- [ ] try out schedulefree, SOAP, and other optimizers +- [ ] **layerwise decay** for fine-tuning (https://kozodoi.me/blog/20220329/discriminative-lr) +- [ ] multimodality CLIP +- [ ] integration with lm-eval-harness (guide)[https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage] + +
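+
+For the layerwise decay item above, the usual recipe is one optimizer param group per block, with the learning
+rate shrinking towards the embeddings. A minimal sketch against the current ```DecoderWrapper``` attributes
+(the function name and constants are placeholders; nothing here is wired into the Composer trainer yet):
+
+```python
+import torch
+
+
+def layerwise_lr_groups(model, base_lr=1e-4, decay=0.9):
+    """One param group per decoder block: the last block keeps base_lr, earlier blocks get decayed lrs."""
+    blocks = list(model.layers)  # DecoderWrapper stores its blocks in an nn.ModuleList named .layers
+    n = len(blocks)
+    groups = [
+        {"params": block.parameters(), "lr": base_lr * decay ** (n - 1 - i)}
+        for i, block in enumerate(blocks)
+    ]
+    # embeddings get the most decayed lr, the final norm and output head keep the base lr
+    groups.append({"params": model.token_embeddings.parameters(), "lr": base_lr * decay ** n})
+    groups.append({"params": list(model.final_norm.parameters()) + list(model.output.parameters()), "lr": base_lr})
+    return groups
+
+
+# e.g. torch.optim.AdamW(layerwise_lr_groups(composer_model.model), weight_decay=0.01)
+```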
+ +## I am GPU-rich what do I use? + +For large scale experiments other frameworks/libraries exist: +- lingua (Facebookresearch) +- torchtitan (Pytorch) +- torchtune (Pytorch) +- litGPT (LightningAI) +- GPT-NeoX (EleutherAI) +- nanotron (Huggingface) +- llm-foundry (MosaicML) diff --git a/cpu-requirements.txt b/cpu-requirements.txt new file mode 100644 index 0000000..e69de29 diff --git a/cuda-requirements.txt b/cuda-requirements.txt new file mode 100644 index 0000000..3338534 --- /dev/null +++ b/cuda-requirements.txt @@ -0,0 +1,98 @@ +aiohappyeyeballs==2.4.3 +aiohttp==3.10.10 +aiosignal==1.3.1 +anyio==4.6.2.post1 +argcomplete==3.5.1 +arrow==1.3.0 +attrs==24.2.0 +backoff==2.2.1 +certifi==2024.8.30 +charset-normalizer==3.4.0 +click==8.1.7 +composer==0.26.0 +coolname==2.2.0 +datasets==3.0.2 +dill==0.3.8 +docker-pycreds==0.4.0 +filelock==3.16.1 +frozenlist==1.5.0 +fsspec==2024.9.0 +gitdb==4.0.11 +gitpython==3.1.43 +gql==3.5.0 +graphql-core==3.2.5 +huggingface-hub==0.26.2 +idna==3.10 +importlib-metadata==8.5.0 +jinja2==3.1.4 +lightning-utilities==0.11.8 +markdown-it-py==3.0.0 +markupsafe==3.0.2 +mdurl==0.1.2 +mosaicml-cli==0.6.42 +mpmath==1.3.0 +multidict==6.1.0 +multiprocess==0.70.16 +networkx==3.4.2 +numpy==2.1.2 +nvidia-cublas-cu12==12.1.3.1 +nvidia-cuda-cupti-cu12==12.1.105 +nvidia-cuda-nvrtc-cu12==12.1.105 +nvidia-cuda-runtime-cu12==12.1.105 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.0.2.54 +nvidia-curand-cu12==10.3.2.106 +nvidia-cusolver-cu12==11.4.5.107 +nvidia-cusparse-cu12==12.1.0.106 +nvidia-nccl-cu12==2.20.5 +nvidia-nvjitlink-cu12==12.6.77 +nvidia-nvtx-cu12==12.1.105 +packaging==24.1 +pandas==2.2.3 +pillow==10.4.0 +platformdirs==4.3.6 +prompt-toolkit==3.0.48 +propcache==0.2.0 +protobuf==5.28.3 +psutil==6.1.0 +py-cpuinfo==9.0.0 +pyarrow==18.0.0 +pygments==2.18.0 +python-dateutil==2.9.0.post0 +pytorch-ranger==0.1.1 +pytz==2024.2 +pyyaml==6.0.2 +questionary==1.10.0 +-e file:///home/max/code/rafale +requests==2.32.3 +rich==13.9.3 +ruamel-yaml==0.18.6 +ruamel-yaml-clib==0.2.12 +safetensors==0.4.5 +sentry-sdk==2.17.0 +setproctitle==1.3.3 +setuptools==75.3.0 +six==1.16.0 +smmap==5.0.1 +sniffio==1.3.1 +sympy==1.13.3 +tabulate==0.9.0 +termcolor==2.5.0 +tokenizers==0.20.1 +torch==2.4.0 +torch-optimizer==0.3.0 +torchmetrics==1.4.0.post0 +torchvision==0.19.0 +tqdm==4.66.6 +triton==3.0.0 +types-python-dateutil==2.9.0.20241003 +typing-extensions==4.12.2 +tzdata==2024.2 +urllib3==2.2.3 +validators==0.34.0 +wandb==0.18.5 +wcwidth==0.2.13 +websockets==11.0.3 +xxhash==3.5.0 +yarl==1.17.1 +zipp==3.20.2 diff --git a/generate_llm_docprompt.sh b/generate_llm_docprompt.sh new file mode 100755 index 0000000..c07b362 --- /dev/null +++ b/generate_llm_docprompt.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# Output file +OUTPUT_FILE="llm-docprompt.txt" + +# Clear the output file if it exists +> "$OUTPUT_FILE" + +# Add repo structure, excluding unwanted directories like .venv +echo "### Repository Structure ###" >> "$OUTPUT_FILE" +tree . -I 'wandb|*__pycache__|media|*-requirements.txt|.venv' >> "$OUTPUT_FILE" +echo -e "\n\n" >> "$OUTPUT_FILE" + +# Include README.md content +if [[ -f "README.md" ]]; then + echo "### README.md ###" >> "$OUTPUT_FILE" + cat README.md >> "$OUTPUT_FILE" + echo -e "\n\n" >> "$OUTPUT_FILE" +fi + +# Function to include content of a given file type, excluding .venv directory +include_files() { + local pattern="$1" + local header="$2" + + find . -type f -name "$pattern" ! 
-path "./.venv/*" | while read -r file; do + echo "### $file ###" >> "$OUTPUT_FILE" + cat "$file" >> "$OUTPUT_FILE" + echo -e "\n\n" >> "$OUTPUT_FILE" + done +} + +# Include Python files, excluding those in .venv +include_files "*.py" "Python File" + +# Include YAML files only from the 'test' folder +find ./test -type f -name "*.yaml" | while read -r yaml_file; do + echo "### $yaml_file ###" >> "$OUTPUT_FILE" + cat "$yaml_file" >> "$OUTPUT_FILE" + echo -e "\n\n" >> "$OUTPUT_FILE" +done + +echo "Documentation prompt has been generated in $OUTPUT_FILE" diff --git a/llm-docprompt.txt b/llm-docprompt.txt new file mode 100644 index 0000000..b631d96 --- /dev/null +++ b/llm-docprompt.txt @@ -0,0 +1,2926 @@ +### Repository Structure ### +. +├── generate_llm_docprompt.sh +├── LICENSE +├── llm-docprompt.txt +├── pyproject.toml +├── rafale +│   ├── caches.py +│   ├── datapipe.py +│   ├── __init__.py +│   ├── main.py +│   └── models +│   ├── configurations.py +│   ├── convert_hf_weights.py +│   ├── decoder.py +│   ├── decoding_strategies.py +│   ├── encoder.py +│   ├── __init__.py +│   ├── model_utils.py +│   ├── roberta_config +│   │   └── tokenizer.json +│   └── roberta.py +├── README.md +└── test + ├── pythia_tinystories.yaml + ├── test_bert.py + ├── test_pythia_generation.py + ├── test_pythia.py + └── test.yaml + +5 directories, 23 files + + + +### README.md ### +
+ +# rafale + + + +
+ +## 💡Purpose + +Rafale provides an opinionated scaffolding for training transformers. It is solely built to be an efficient +learning/research tool. It is **not** a fully fledged library for large scale training. + +It should be thought of as a starting point for research projects to bootstrap experiments on small LMs. The best way to +use rafale is to simply fork it and build on top of it for your specific purposes. + +### Core dependencies + +Attempting to balance ergonomics and simplicity. This is meant to be easily hackable for research purposes. + +``` +torch, composer, datasets, tokenizers +``` + +## 🚀 Installation & Usage + +Setup with ```uv``` ([install uv](https://github.com/astral-sh/uv)). +```sh +$ git clone +$ cd rafale +$ uv venv +$ . .venv/bin/activate +$ uv pip install -r cuda-requirements.txt (or cpu-requirements.txt) +$ uv pip install -e . +``` + +Launch a run with a configuration. + +```sh +$ python rafale/main -c test/pythia_tinystories.yaml +``` + +What if I just want to prepare my dataset? ```DATA=1``` will run the data preparation and caching pipeline without +launching the training run. + +```sh +$ DATA=1 python rafale/main -c test/pythia_tinystories.yaml +``` + +What if I want to test my model to make sure that its learning? ```DEBUG=1``` will run 10 epochs on a single training +batch (same for train/eval), the model should fit quickly if there are no bugs in the implementation. + +```sh +$ DEBUG=1 python rafale/main -c test/pythia_tinystories.yaml +``` + + +### 🔧 Under the hood + +The goal of rafale is to provide a single entry point for data preparation and training. You configure the model and +dataset. Then call the training job. + +When calling a run, first we run the datapipepline. If the dataset has already been processed (tokenized, padded, +chunked, etc.), it will be loaded from the cache (default location is ```~/.rafale_cache```. + +> [!NOTE] +> #### Adding a new model +> To add a new model, you need write a new configuration to ```rafale/models/configurations.py```, and add it's key to +> ```model_config_dict``` in ```rafale/main.py```. +> +> Look at the ```ComposerLM``` wrapper class in ```rafale/models/decoder.py``` to check if all your building blocks are +> there. Otherwise you may need to modify/write a new wrapper. +> +> #### Adding a new datapipeline +> +> If the dataset is hosted on huggingface, simply use git lfs to clone the repo locally or use the repo name as the +> dataset path. Same goes for tokenizers since we use their tokenizer implementation. +> +> You will need to add a new datapipeline class in ```rafale/datapipes.py``` where the ```_prepare``` method all data +> preprocessing (tokenization, chunking, truncation, etc.) **EXCEPT** padding. Padding will be performed by the datacollator. + +### 📕 Docs + +Append this file ```llm-docprompt.txt``` to your favorite LLM and ask away. + +### 🦾 Supported models + + +| Name | Implemented | Inference test | Training test | +|:------------|:------------|:---------------|:--------------| +| BERT | ✅ | | | +| RoBERTa | ✅ | | | +| Pythia | ✅ | ✅ | ✅ | +| CLIP/SigLIP | ⏳ | | | + + +## 🔮 Roadmap + +
+ v0.1 + + +### v0.1 - initial release +- [x] single entrypoint CLI +- [ ] simple deploy/build + - [x] CPU macos build - Ok, uv run works with this + - [x] local linux machine - for now uv for venv + requirements.txt + - [ ] SLURM compute-canada - TBD + - NOTE: because uv still does not fully play well with pytorch recommend semi-manual setup* +- [ ] load weights from safetensors and include it in the config (BERT/RoBERTa and Pythia) + - [x] pythia + - [ ] BERT/RoBERTa (need to move from HF to safetensors) + - [ ] MLM + - [ ] Classification +- [x] Pythia KV-cache implementation +- [x] greedy generation +- [ ] datapipes for CLM and MLM + - local dataloader for now + - [x] CLM tinystories + - [ ] MLM tinystories + - [ ] Imdb classification +- [x] ```main.py``` handles both training and evaluation (together or separately) +- [x] Mosaic Composer/Trainer + + [x] fp16 + + [x] gradient clipping + + [x] gradient accumulation (automatically handled by composer) + + [x] building blocks are nn.Modules, specific models are ComposerModel classes with methods to load safetensor weights + automatically (keep in a single separate file for each model) + + [x] set DEBUG=1 for 1 batch sanity check before launching a run + +Datapipelines +1. [x] tokenize +2. [x] concat and split w/ block size (pad w/ collator) +3. [x] save to disk {source}_{tokname}_bs{int}_len{int} +4. [x] data_collator: *next* pad (if desired), label shift right and return torch tensor # HF: does this in the model... +5. [x] test with model training +6. [ ] tiny stories but for MLM also +
+ +
+ v1.0 + +### path to v1.0 +cleanup and additional features +- [ ] clean up ```tests``` for pythia and bert models on tinystories +- [ ] move the testing in the notebook to a debug file in the modeling folder +- [ ] optimizations : flash attn2, xformers layer_norm (triton) or RMSNorm, xformers fused_linear_layer +- [ ] try out schedulefree, SOAP, and other optimizers +- [ ] **layerwise decay** for fine-tuning (https://kozodoi.me/blog/20220329/discriminative-lr) +- [ ] multimodality CLIP +- [ ] integration with lm-eval-harness (guide)[https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage] + +
+ +## I am GPU-rich what do I use? + +For large scale experiments other frameworks/libraries exist: +- lingua (Facebookresearch) +- torchtitan (Pytorch) +- torchtune (Pytorch) +- litGPT (LightningAI) +- GPT-NeoX (EleutherAI) +- nanotron (Huggingface) +- llm-foundry (MosaicML) + + + +### ./rafale/models/decoder.py ### +#!/usr/bin/env python +from typing import Optional +import torch + +from torch import nn +from torch import Tensor + +import torch.nn.functional as F + +from torch.nn.functional import scaled_dot_product_attention + +from torchmetrics import Metric +from torchmetrics.collections import MetricCollection + +from composer.metrics import LossMetric, LanguagePerplexity +from composer.models import ComposerModel + + +############################################################################### +# simple implementation of GPT building +class NeoXRoPE(nn.Module): + @classmethod + def precompute_sin_cos_cache(cls, dim=None, seq_len=None, base=10000, device=None): + """Computes the cos and sin angles to be applied to the token vectors. + + We begin by computing thetas (freqs) across each dimension pair (P=D/2) for the whole sequence length (L). + Then we convert this matrix of shape LP into complex numbers of the same shape. + Finally the real and imaginary parts of these complex numbers are stored in a stacked matrix and returned. + + Args: + dim (int): number of features dimension per token to apply rotations to (d*rotary_pct) + seq_len (int): sequence length of the input (use the maximum sequence length) + base (int): default 10000 + + Returns: + Tensor # of shape [1,1,L,R] + + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + - R rotary dimensions (d*rotary_pct) + """ + + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim) + ) + t = torch.arange(seq_len, dtype=torch.int64).type_as(inv_freq) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos_cached = emb.cos()[None, None, :, :] # shape is [1, 1, L, R] + sin_cached = emb.sin()[None, None, :, :] + + return cos_cached, sin_cached + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=x1.ndim - 1) + + @classmethod + def apply_rotary_pos_emb(cls, q_BHLR, k_BHLR, cos, sin): + """Applies the rotation to the input queries and key features.""" + + tensor_device = q_BHLR.get_device() + + cos = cos.to(device=tensor_device) + sin = sin.to(device=tensor_device) + + return (q_BHLR * cos) + (cls.rotate_half(q_BHLR) * sin), (k_BHLR * cos) + ( + cls.rotate_half(k_BHLR) * sin + ) + + @classmethod + def apply_rotary_pos_emb_offset(cls, q, k, cos, sin, offset: int = 0): + """ + q and k are shape: BHLR + cos, sin are shape: 11LR + """ + cos, sin = ( + cos[:, :, offset : q.shape[2] + offset, :], + sin[:, :, offset : q.shape[2] + offset, :], + ) + + tensor_device = q.get_device() + + cos = cos.to(device=tensor_device) + sin = sin.to(device=tensor_device) + + return (q * cos) + (cls.rotate_half(q) * sin), (k * cos) + ( + cls.rotate_half(k) * sin + ) + + +class DecoderEmbedding(nn.Module): + """Simply an input projection of the tokens here, since rotary position encodings are used. 
+ + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + """ + + def __init__(self, config): + super().__init__() + + self.input_embeddings = nn.Embedding( + num_embeddings=config.vocab_size, + embedding_dim=config.embed_dim, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, x_BL): + x_BLD = self.input_embeddings(x_BL) + return self.dropout(x_BLD) + + +class DecoderAttentionRotary(nn.Module): + """ + Attention with rotary position embedding. + + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + + """ + + def __init__(self, config): + super().__init__() + + self.head_dim = config.embed_dim // config.num_heads + self.num_heads = config.num_heads + self.embed_dim = config.embed_dim + self.rotary_ndims = int(self.head_dim * config.rotary_pct) + + self.attention_bias = True # @TODO: set bias to True or False from config. + self.query_key_value = nn.Linear(config.embed_dim, 3 * config.embed_dim) + self.dense = nn.Linear(config.embed_dim, config.embed_dim) + self.dropout = nn.Dropout(p=config.attention_dropout) + self.dropout_p = config.attention_dropout + self.norm_factor = self.head_dim**-0.5 + + def _split_heads(self, tensor: Tensor): + """ + Splits hidden dim into attn_head_size and num_attention_heads + + # input tensor: [bs, seq_len, hidden_size] + # returns: [bs, num_attention_heads, seq_len, attn_head_size] + """ + batch_size = tensor.size(0) + + return ( + tensor.view(batch_size, -1, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def _merge_heads(self, tensor: Tensor): + """ + input tensor: [bs. 
num_attention_heads, seq_len, attn_head_size] + returns: [bs, seq_len, hidden_size] + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view( + tensor.size(0), + tensor.size(1), + self.num_heads * self.head_dim, + ).contiguous() + # -> [bs, seq_len, hidden_size] + return tensor + + def forward(self, x_BLD, freqs_cis): + if not self.training: + self.dropout_p = 0 + + bsz, seq_len, _ = x_BLD.size() + + assert freqs_cis is not None + + # projections + qkv = self.query_key_value(x_BLD) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + q_BHLd = qkv[..., : self.head_dim].permute(0, 2, 1, 3) + k_BHLd = qkv[..., self.head_dim : 2 * self.head_dim].permute(0, 2, 1, 3) + v_BHLd = qkv[..., 2 * self.head_dim :].permute(0, 2, 1, 3) + + # Slice the precomputed freqs_cis based on actual seq_len --> [1, 1, seq_len, R] + cos = freqs_cis[0][:, :, :seq_len, :] + sin = freqs_cis[1][:, :, :seq_len, :] + + q_rot = q_BHLd[..., : self.rotary_ndims] + q_pass = q_BHLd[..., self.rotary_ndims :] + k_rot = k_BHLd[..., : self.rotary_ndims] + k_pass = k_BHLd[..., self.rotary_ndims :] + + q_rot, k_rot = NeoXRoPE.apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + + q_BHLd = torch.cat((q_rot, q_pass), dim=-1) + k_BHLd = torch.cat((k_rot, k_pass), dim=-1) + + # compute attention + attn_out_BHLd = scaled_dot_product_attention( + q_BHLd, + k_BHLd, + v_BHLd, + is_causal=True, + scale=self.norm_factor, + dropout_p=self.dropout_p, + ) + + attn_out_BLD = self._merge_heads(attn_out_BHLd) + + attn_out_BLD = self.dense(attn_out_BLD) + + return attn_out_BLD + + +class DecoderAttentionRotaryKVCache(DecoderAttentionRotary): + """implements the KV cache mechanism""" + + def __init__(self, config): + super().__init__(config) + + def forward(self, x_BLD, freqs_cis, causal_mask=None, past_kv=None): + # A) figure out if we passed a cached KV + assert freqs_cis is not None + has_past_kv = past_kv is not None and past_kv[0].numel() > 0 + + if not self.training: + self.dropout_p = 0 + + bsz, seq_len, _ = x_BLD.size() + # B) if we have a cached KV, apply the offset to the sequence length for RoPE + if has_past_kv: + offset = past_kv[0].shape[ + 2 + ] # we want the lenght here, our kv shape is BHLd + seq_len += offset + else: + offset = 0 + + # projections + qkv = self.query_key_value(x_BLD) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + q_BHLd = qkv[..., : self.head_dim].permute(0, 2, 1, 3) + k_BHLd = qkv[..., self.head_dim : 2 * self.head_dim].permute(0, 2, 1, 3) + v_BHLd = qkv[..., 2 * self.head_dim :].permute(0, 2, 1, 3) + + # Slice the precomputed freqs_cis based on actual seq_len --> [1, 1, seq_len, R] + cos = freqs_cis[0][:, :, :seq_len, :] + sin = freqs_cis[1][:, :, :seq_len, :] + + q_rot = q_BHLd[..., : self.rotary_ndims] + q_pass = q_BHLd[..., self.rotary_ndims :] + k_rot = k_BHLd[..., : self.rotary_ndims] + k_pass = k_BHLd[..., self.rotary_ndims :] + + q_rot, 
k_rot = NeoXRoPE.apply_rotary_pos_emb_offset( + q_rot, k_rot, cos, sin, offset=offset + ) + + q_BHLd = torch.cat((q_rot, q_pass), dim=-1) + k_BHLd = torch.cat((k_rot, k_pass), dim=-1) + + # C) before scaled_dot_product_attention we are going to + # Cache QKV values + if has_past_kv: + past_key, past_value = past_kv + + k_BHLd = torch.cat((past_key.type_as(k_BHLd), k_BHLd), dim=2) + v_BHLd = torch.cat((past_value.type_as(v_BHLd), v_BHLd), dim=2) + + # kv_cache = torch.stack((k_BHLd, v_BHLd)) + kv_cache = (k_BHLd, v_BHLd) # let's keep them as a list of tuples + # #################################################################### + + # print(f"key device {k_BHLd.get_device()}") + # print(f"query device {q_BHLd.get_device()}") + # print(f"value device {v_BHLd.get_device()}") + # print(f"causal mask device {causal_mask.get_device()}") + + tensor_device = q_BHLd.get_device() + if causal_mask.get_device() != tensor_device: + causal_mask = causal_mask.to(device=tensor_device) + + # compute attention here + attn_out_BHLd = scaled_dot_product_attention( + q_BHLd, + k_BHLd, + v_BHLd, + is_causal=False, + attn_mask=causal_mask, + scale=self.norm_factor, + dropout_p=self.dropout_p, + ) # even with rectangular matrices scaled_dot_product_attention will handle the causal mask by apply left bias + # causal mask which is exactly what we need. + + attn_out_BLD = self._merge_heads(attn_out_BHLd) + + attn_out_BLD = self.dense(attn_out_BLD) + + return attn_out_BLD, kv_cache + + +# ^^^^^ ####################################################################### + + +class DecoderFeedForward(nn.Module): + """ + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + + """ + + def __init__(self, config): + super().__init__() + self.ff_in = nn.Linear(config.embed_dim, config.ff_dim) + self.gelu = nn.GELU() + self.ff_out = nn.Linear(config.ff_dim, config.embed_dim) + + def forward(self, x_BLD): + x_BLF = self.ff_in(x_BLD) + x_BLF = self.gelu(x_BLF) + out_BLD = self.ff_out(x_BLF) + return out_BLD + + +class DecoderBlock(nn.Module): + """A single trasnformer decoder block/layer. + + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + """ + + def __init__(self, config): + super().__init__() + self.attention = DecoderAttentionRotary(config) + self.feed_forward = DecoderFeedForward(config) + self.ffn_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + self.attention_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + + def forward(self, x_BLD, freqs_cis, parallel_residual=True, use_cache=True): + assert freqs_cis is not None + + if parallel_residual: + out_BLD = ( + x_BLD + + self.attention(self.attention_norm(x_BLD), freqs_cis) + + self.feed_forward(self.ffn_norm(x_BLD)) + ) + else: + h_BLD = x_BLD + self.attention(self.attention_norm(x_BLD), freqs_cis) + out_BLD = h_BLD + self.feed_forward(self.ffn_norm(h_BLD)) + + return out_BLD + + +# handle KV cache state +class DecoderBlockKVcache(DecoderBlock): + """A single trasnformer decoder block/layer. 
+ + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + """ + + def __init__(self, config): + super().__init__(config) + self.attention = DecoderAttentionRotaryKVCache(config) + self.feed_forward = DecoderFeedForward(config) + self.ffn_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + self.attention_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + + def forward( + self, x_BLD, freqs_cis, causal_mask, layer_kv_cache, parallel_residual=True + ): + assert freqs_cis is not None + + if parallel_residual: + attn_out_BLD, layer_kv_cache = self.attention( + self.attention_norm(x_BLD), freqs_cis, causal_mask, layer_kv_cache + ) + out_BLD = x_BLD + attn_out_BLD + self.feed_forward(self.ffn_norm(x_BLD)) + + else: + attn_out_BLD, layer_kv_cache = self.attention( + self.attention_norm(x_BLD), freqs_cis, layer_kv_cache + ) + h_BLD = x_BLD + attn_out_BLD + out_BLD = h_BLD + self.feed_forward(self.ffn_norm(h_BLD)) + + return out_BLD, layer_kv_cache + + +class DecoderWrapper(nn.Module): + """Full model wrapper for causal language modelling.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.token_embeddings = DecoderEmbedding(config) + + if self.config.use_cache: + self.layers = nn.ModuleList( + DecoderBlockKVcache(config) for _ in range(config.num_blocks) + ) + else: + self.layers = nn.ModuleList( + DecoderBlock(config) for _ in range(config.num_blocks) + ) + + self.final_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + self.output = nn.Linear(config.embed_dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + + self.max_batch_size = -1 + self.max_seq_length = config.max_pos_embedding + + self.rotary_pct = 0.25 + + self.vocab_size = config.vocab_size + + def setup_caches(self): + head_dim = self.config.embed_dim // self.config.num_heads + dtype = self.output.weight.dtype + + head_size = self.config.embed_dim // self.config.num_heads + rotary_ndims = int(head_size * self.rotary_pct) + self.cos, self.sin = NeoXRoPE.precompute_sin_cos_cache( + dim=rotary_ndims, + seq_len=self.config.max_pos_embedding, + ) + + self.freqs_cis = (self.cos, self.sin) + + def _generate_causal_mask(self, cache_length: int, new_length: int) -> torch.Tensor: + """ + Creates a causal mask for autoregressive attention with KV caching. + + Args: + cache_length (int): Number of cached tokens (K). + new_length (int): Number of new tokens to generate (T). + device (torch.device): The device on which to create the mask. + dtype (torch.dtype): The data type of the mask tensor. + + Returns: + torch.Tensor: A (K + T) x (K + T) mask tensor where: + - Cached tokens can attend to all cached and new tokens. + - New tokens can attend to all cached tokens and up to their current position in new tokens. + - Future new tokens are masked to prevent attention. 
+ """ + + causal_mask = torch.tril(torch.ones(new_length, new_length)).bool() + + if cache_length > 0: + cache_tokens_mask = torch.ones((new_length, cache_length)).bool() + causal_mask = torch.cat((cache_tokens_mask, causal_mask), dim=1) + + return causal_mask + + def forward(self, batch: Tensor, past_kv_cache=None): + # if self.freqs_cis is None or self.causal_mask is None: + if self.freqs_cis is None: + self.setup_caches() # Caches must be initialized first + + freqs_cis = self.freqs_cis + + idx = batch["input_ids"] + num_new_tokens = idx.size(1) + + x = self.token_embeddings(idx) + + if self.config.use_cache and past_kv_cache is None: + past_kv_cache = [None] * self.config.num_blocks + + if past_kv_cache[0] is not None: + num_cache_tokens = past_kv_cache[0][0].size(2) + else: + num_cache_tokens = 0 + causal_mask = self._generate_causal_mask(num_cache_tokens, num_new_tokens) + + kv_cache_list = [] + + for i, layer in enumerate(self.layers): + if self.config.use_cache: + x, layer_kv_cache = layer(x, freqs_cis, causal_mask, past_kv_cache[i]) + kv_cache_list.append(layer_kv_cache) + else: + x = layer(x, freqs_cis) + x = self.final_norm(x) + logits = self.output(x) + + return logits, kv_cache_list + + +class ComposerLM(ComposerModel): + """wrapper with nice properties for simple training and evaluation""" + + def __init__(self, config): + "docstring" + super().__init__() + self.model = DecoderWrapper(config) + self.ce_loss = nn.CrossEntropyLoss() + self.train_metrics = MetricCollection( + [LossMetric(self.ce_loss), LanguagePerplexity()] + ) + self.eval_metrics = MetricCollection( + [LossMetric(self.ce_loss), LanguagePerplexity()] + ) + + def forward(self, batch): # batch is the output of the dataloader + """batch is a dict with "input_ids" key, model also takes past_kv""" + # specify how batches are passed through the model + return self.model(batch) + + def eval_forward(self, batch, outputs=False): + if outputs: + if type(outputs) is tuple: + outputs, _ = outputs + return outputs + + outputs = self.model(batch) + if type(outputs) is tuple: + outputs, _ = outputs + + return outputs + + def update_metric(self, batch, outputs, metric) -> None: + targets = batch["labels"] + metric.update(outputs.view(-1, self.model.vocab_size), targets.view(-1)) + + def get_metrics(self, is_train=False) -> dict[str, Metric]: + # defines which metrics to use in each phase of training + return self.train_metrics if is_train else self.eval_metrics + + def loss(self, outputs, batch): + targets = batch["labels"] + + if type(outputs) is tuple: + outputs, _ = outputs + + return self.ce_loss(outputs.view(-1, self.model.vocab_size), targets.view(-1)) + + + +### ./rafale/models/convert_hf_weights.py ### +import argparse +import torch +import torch.nn.functional as F + +# from encoder import EncoderWrapper +from transformers import AutoTokenizer, AutoModelForMaskedLM + + +# helpers +def list_params(model): + for k, v in model.items(): + print(k) + + +def get_name_params(model): + all_params = {} + for name, param in model.named_parameters(): + all_params[name] = param + + return all_params + + +def convert_bert_params_dict(target, source): + conversion = [ + ("bert.embeddings", "embedding_layer"), + ("bert.encoder.layer", "blocks"), + ("attention.self", "attention.self_attn"), + ("attention.output.dense", "attention.out"), + ("attention.output.LayerNorm", "add_norm_1.ln"), + ("intermediate.dense", "ff.ff_in"), + ("output.dense", "ff.ff_out"), + ("output.LayerNorm", "add_norm_2.ln"), + ("cls.predictions.bias", 
"mlm_head.bias"), + ("cls.predictions.transform.dense", "mlm_head.dense"), + ("cls.predictions.transform.LayerNorm", "mlm_head.ln"), + ("cls.predictions.decoder", "mlm_head.decoder"), + ] + + source_parameters = source.state_dict() + + updated_parameters = {} + for k, v in source_parameters.items(): + for hf_term, my_term in conversion: + if hf_term in k: + k = k.replace(hf_term, my_term) + + updated_parameters[k] = v + + # return updated_parameters + # assert new_dict.keys() == target.keys(), was Ok but different for the state dict + + # here we transfer weights for all layers + target.load_state_dict(updated_parameters, strict=True) + + return target + + +def convert_pythia_params_dict(target, source): + """ + Source safetensors dict to our rafale model class. + """ + + # not needed for our implementation + unused = ["rotary_emb.inv_freq", "masked_bias", "attention.bias"] + for k, v in list(source.items()): + if True in [x in k for x in unused]: + del source[k] + + conversion = [ + ("gpt_neox.embed_in", "token_embeddings.input_embeddings"), + ("gpt_neox.layers", "layers"), + ("input_layernorm", "attention_norm"), + ("post_attention_layernorm", "ffn_norm"), + ("mlp", "feed_forward"), + ("dense_4h_to_h", "ff_out"), + ("dense_h_to_4h", "ff_in"), + ("embed_out", "output"), + ("gpt_neox.final_layer_norm", "final_norm"), + ] + + updated_parameters = {} + for k, v in source.items(): + for hf_term, my_term in conversion: + if hf_term in k: + k = k.replace(hf_term, my_term) + + updated_parameters[k] = v + + # here we transfer weights for all layers + target.load_state_dict(updated_parameters, strict=True) + + return target + + +def convert_roberta_params_dict(target, source): + conversion = [ + ("roberta.embeddings", "embedding_layer"), + ("roberta.encoder.layer", "blocks"), + ("attention.self", "attention.self_attn"), + ("attention.output.dense", "attention.out"), + ("attention.output.LayerNorm", "add_norm_1.ln"), + ("intermediate.dense", "ff.ff_in"), + ("output.dense", "ff.ff_out"), + ("output.LayerNorm", "add_norm_2.ln"), + ("lm_head.bias", "mlm_head.bias"), + ("lm_head.dense", "mlm_head.dense"), + ("lm_head.layer_norm", "mlm_head.ln"), + ("lm_head.decoder", "mlm_head.decoder"), + ] + + source_parameters = source.state_dict() + + updated_parameters = {} + for k, v in source_parameters.items(): + for hf_term, my_term in conversion: + if hf_term in k: + k = k.replace(hf_term, my_term) + + updated_parameters[k] = v + + # return updated_parameters + # assert new_dict.keys() == target.keys(), was Ok but different for the state dict + + # here we transfer weights for all layers + target.load_state_dict(updated_parameters, strict=True) + + return target + + +def test_conversion(): + """Convert the weights modify the dictionary for the blank and choice tokens, export the resulting tokenizer and + model checkpoints into assets + """ + torch.manual_seed(0) + + tokenizer = AutoTokenizer.from_pretrained( + "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" + ) + input_tokens = tokenizer("paris is the [MASK] of France.", return_tensors="pt") + hf_model = AutoModelForMaskedLM.from_pretrained( + "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" + ) + extra_tokens = [ + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + "", + ] + + # check if the tokens are already in the vocabulary + extra_tokens = set(extra_tokens) - set(tokenizer.vocab.keys()) + + # add the tokens to the tokenizer vocabulary + tokenizer.add_tokens(list(extra_tokens)) + + # add new, random 
embeddings for the new tokens + hf_model.resize_token_embeddings(len(tokenizer)) + + cfg = Config + slam = SlamEncoder(cfg) + slam = convert_bert_params_dict(slam, hf_model) + + def get_test_preds(model, sample, tokenizer, hf=False): + model.eval() + output = model(**sample) + + if hf: + output = output.logits + + probs = F.softmax(output, dim=-1) + tokens = torch.argmax(probs, dim=-1) + sequence = tokenizer.batch_decode(tokens) + print(tokens) + + return ( + output, + tokens, + probs, + sequence, + ) + + _, _, _, hf_tokens = get_test_preds(hf_model, input_tokens, tokenizer, hf=True) + + _, _, _, my_tokens = get_test_preds(slam, input_tokens, tokenizer) + # the output logits are different, however the output tokens predicted seem to be almost always the same + print(f"my implementation: {my_tokens}\n hf implementation: {hf_tokens}") + + +def add_new_tokens(model, tokenizer, token_list): + token_list = set(token_list) - set(tokenizer.vocab.keys()) + + # add the tokens to the tokenizer vocabulary + tokenizer.add_tokens(list(token_list)) + + # add new, random embeddings for the new tokens + model.resize_token_embeddings(len(tokenizer)) + + return model, tokenizer + + +def convert_and_save(dump_dir): + """Convert the weights modify the dictionary for the blank and choice tokens, export the resulting tokenizer and + model checkpoints into assets + """ + tokenizer = AutoTokenizer.from_pretrained( + "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" + ) + hf_model = AutoModelForMaskedLM.from_pretrained( + "microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext" + ) + + cfg = Config + slam = SlamEncoder(cfg) + slam = convert_bert_params_dict(slam, hf_model) + + # save model + tokenizer.save_pretrained(f"{dump_dir}/tokenizer/") + torch.save(slam.state_dict(), f"{dump_dir}/model.pt") + + +def main(): + parser = argparse.ArgumentParser( + description="Convert bert weights and tokenizer for Slamminnnnn" + ) + + # Adding arguments + parser.add_argument( + "--dump_dir", + required=True, + type=str, + help="Where to put the files", + ) + + # Parse arguments + args = parser.parse_args() + convert_and_save(args.dump_dir) + + +if __name__ == "__main__": + main() + + + +### ./rafale/models/decoding_strategies.py ### +import torch +import torch.nn.functional as F + + +def repeat_ngram(input_ids, ngram_max_length=4): + """ + Clean up output by checking for repeated n-grams of length less than ngram_max_length. + If it finds a repeated n-gram, it returns the n-gram; otherwise, it returns 0. + + Args: + input_ids: Tensor of shape (seq_length,), the sequence of input token IDs. + ngram_max_length: The maximum length of the n-gram to check for repetition. + + Returns: + A list of tokens representing the repeated n-gram if found, otherwise 0. + """ + ngram_found = False + ngram_tokens = None + + # Create a set to store seen n-grams + seen_ngrams = set() + + print(input_ids) + # Iterate through possible n-gram lengths + for n in range(1, ngram_max_length + 1): + for i in range(len(input_ids) - n + 1): + ngram = tuple(input_ids[i : i + n].tolist()) + if ngram in seen_ngrams: + ngram_found = True + ngram_tokens = list(ngram) + break + seen_ngrams.add(ngram) + if ngram_found: + break + + if ngram_found: + return ngram_tokens + else: + return 0 + + +def greedy_decode(model, batch, max_length, eos_token_id, check_repeat_ngrams=True): + """ + Implements greedy decoding for the rafale transformer model. + + Args: + model: The decoder model (e.g., DecoderWrapper). 
+ batch: Dictionary containing input_ids of shape (batch_size, seq_length), the input prompt. + max_length: The maximum length of the generated sequence. + eos_token_id: The ID of the end-of-sequence token. + + Returns: + Dictionary containing input_ids of shape (batch_size, max_length) with the generated tokens. + """ + batch_size = batch["input_ids"].size(0) + if batch_size != 1: + raise ValueError( + "greedy_decode currently only supports batch_size=1. Provided batch_size: {}".format( + batch_size + ) + ) + + input_seq_len = batch["input_ids"].size(1) + kv_cache_list = None + + # Generate tokens until max_length or eos_token is generated + for _ in range(max_length - input_seq_len): + # Forward pass through the model + outputs, kv_cache_list = model(batch, kv_cache_list) + logits = outputs[:, -1, :] # Get the logits for the last generated token + + # Greedily select the token with the highest probability + next_token = torch.argmax(logits, dim=-1).unsqueeze(-1) # Shape: (1, 1) + + # Append the predicted token to the generated sequence + batch["input_ids"] = torch.cat((batch["input_ids"], next_token), dim=1) + + # Check for repeated n-grams and stop if detected + if check_repeat_ngrams: + repeated_ngram = repeat_ngram( + batch["input_ids"].squeeze(), ngram_max_length=4 + ) + if repeated_ngram != 0: + print(repeated_ngram) + break + + # Check if the sequence has generated the eos_token_id + if next_token.item() == eos_token_id: + break + + return batch + + + +### ./rafale/models/model_utils.py ### +def get_tokens_from_logits(logits, tokenizer=None): + """ + return the prediced tokens for all of the inputs + """ + # Apply softmax to convert logits to probabilities + probabilities = F.softmax(logits, dim=-1) + + # Get the predicted token IDs + predicted_token_ids = torch.argmax(probabilities, dim=-1) + + predicted_tokens = [ + tokenizer.convert_ids_to_tokens(seq.numpy()) + for seq in torch.unbind(predicted_token_ids, dim=0) + ] + return predicted_tokens + + + +### ./rafale/models/encoder.py ### +from dataclasses import dataclass + +import torch +import torch.nn.functional as F + +from torch import nn +from torch.nn.functional import scaled_dot_product_attention + + +############################################################################### +# SIMPLE BERT-like BUILDING BLOCKS # +############################################################################### + + +# @TODO :: Refactor, improve documentation and add tensor dimension keys for the names + + +class Embedding(nn.Module): + """Embeddings + + In addition to the word embedding, BERT uses learned absolute position embeddings. We also have token type embedding for one the BERT pretraining + objectives. + + Tensor dimension keys: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + """ + + def __init__( + self, + vocab_size=None, + hidden_size=None, + pad_token_id=None, + max_sequence_length=512, + num_token_type=None, # technically should be 2, in HF they use type_vocab_size + layer_norm_eps=None, + hidden_dropout_prob=None, + ): + super().__init__() + # nn.Embedding is just a lookup table, + self.word_embeddings = nn.Embedding( + num_embeddings=vocab_size, + embedding_dim=hidden_size, + padding_idx=pad_token_id, + ) + self.position_embeddings = nn.Embedding( + num_embeddings=max_sequence_length, + embedding_dim=hidden_size, + padding_idx=pad_token_id, # ROBERTA only? 
+ ) + self.token_type_embeddings = nn.Embedding( + num_token_type, hidden_size + ) # NOTE :: these are actually the segment embeddings + # from the original BERT paper... maybe rename? + self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.dropout = nn.Dropout(hidden_dropout_prob) + + # not considered as model parameter + self.register_buffer( + "position_ids", + torch.arange(max_sequence_length).expand((1, -1)), + persistent=False, + ) + + def forward(self, input_ids, token_type_ids): + seq_length = input_ids.size(1) + + position_ids = torch.index_select( + self.position_ids, 1, torch.arange(seq_length) + ) + position_ids = position_ids.expand_as(input_ids) + + # we assume absolute positional encoding here like in the original BERT and sum everything up + W = self.word_embeddings(input_ids) + P = self.position_embeddings(position_ids) + T = self.token_type_embeddings(token_type_ids) + + E = W + P + T + E = self.LayerNorm(E) + E = self.dropout(E) + + return E + + +class EncoderSelfAttention(nn.Module): + """Bidirectional multi-head self attention. + + Tensor dimension keys: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + """ + + def __init__(self, n_heads, embed_dim, dropout_p=0.1, fast_attn=False): + super().__init__() + self.dropout_p = dropout_p + self.fast_attn = fast_attn + assert embed_dim % n_heads == 0 + + # We assume d_v always equals d_k + self.head_dim = embed_dim // n_heads + self.n_heads = n_heads + self.embed_dim = embed_dim + self.all_head_size = n_heads * self.head_dim + + # get linear projections + self.query = nn.Linear(embed_dim, embed_dim) + self.key = nn.Linear(embed_dim, embed_dim) + self.value = nn.Linear(embed_dim, embed_dim) + + def forward(self, q, k, v): + """""" + batch_size = q.size(0) + if not self.training: + self.dropout_p = 0 + + # check transformation again here.... 
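+        # project q/k/v to (B, L, D), reshape to (B, L, H, d), then transpose to (B, H, L, d)
+        # so scaled_dot_product_attention treats the head axis as part of the batch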
+ q = ( + self.query(q) + .view(batch_size, -1, self.n_heads, self.head_dim) + .transpose(1, 2) + ) + k = ( + self.key(k) + .view(batch_size, -1, self.n_heads, self.head_dim) + .transpose(1, 2) + ) + v = ( + self.value(v) + .view(batch_size, -1, self.n_heads, self.head_dim) + .transpose(1, 2) + ) + + attn_output = scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout_p, + ) + + # concatenate heads and put through final linear layer + attn_output = ( + attn_output.transpose(1, 2) + .contiguous() + .view(batch_size, -1, self.embed_dim) + ) + + return attn_output + + +class AttentionModule(nn.Module): + """the actual block with the output projections""" + + # output + def __init__(self, n_heads, embed_dim, dropout_p=None, fast_attn=False): + super().__init__() + self.self_attn = EncoderSelfAttention( + n_heads, embed_dim, dropout_p=dropout_p, fast_attn=fast_attn + ) + self.out = nn.Linear(embed_dim, embed_dim) + + def forward(self, x): + attn_output = self.self_attn(x, x, x) + out = self.out(attn_output) + return out + + +class FeedForward(nn.Module): + def __init__(self, embed_dim, ff_dim): + super().__init__() + self.ff_in = nn.Linear(embed_dim, ff_dim) + self.gelu = nn.GELU() + self.ff_out = nn.Linear(ff_dim, embed_dim) + + def forward(self, x): + x = self.ff_in(x) + x = self.gelu(x) + x = self.ff_out(x) + return x + + +class AddNorm(nn.Module): + def __init__(self, embed_dim, eps=None, dropout_p=None): + super().__init__() + self.ln = nn.LayerNorm(embed_dim, eps=eps) + self.dropout = nn.Dropout(dropout_p) + + def forward(self, x, residual): + x = self.dropout(x) # @TODO :: make sure this should be here... + x = self.ln(x + residual) + + return x + + +class EncoderBlock(nn.Module): + def __init__(self, embed_dim, n_heads, ff_dim, eps=None, dropout_p=None): + super().__init__() + + self.attention = AttentionModule( + n_heads=n_heads, embed_dim=embed_dim, dropout_p=dropout_p + ) + self.add_norm_1 = AddNorm(embed_dim, eps=eps, dropout_p=dropout_p) + self.ff = FeedForward(embed_dim, ff_dim=ff_dim) + self.add_norm_2 = AddNorm(embed_dim, eps=eps, dropout_p=dropout_p) + + def forward(self, x): + residual_1 = x + x = self.attention(x) + x = self.add_norm_1(x, residual_1) + + residual_2 = x + x = self.ff(x) + x = self.add_norm_2(x, residual_2) + + return x + + +class MLMHead(nn.Module): + def __init__(self, embed_dim, vocab_size, eps=None): + super().__init__() + self.dense = nn.Linear(embed_dim, embed_dim) + self.gelu = nn.GELU() + self.ln = nn.LayerNorm(embed_dim, eps=eps) + + self.decoder = nn.Linear(embed_dim, vocab_size, bias=False) + self.bias = nn.Parameter(torch.zeros(vocab_size)) + + # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings` + self.decoder.bias = self.bias + + def forward(self, x): + x = self.dense(x) + x = self.gelu(x) + x = self.ln(x) + x = self.decoder(x) + + return x + + +class EncoderWrapper(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.embedding_layer = Embedding( + vocab_size=config.vocab_size, + hidden_size=config.embed_dim, + pad_token_id=config.pad_token_id, + max_sequence_length=config.max_pos_embedding, + num_token_type=config.num_token_type, # technically should be 2, in HF they use type_vocab_size + layer_norm_eps=config.layer_norm_eps, + hidden_dropout_prob=config.hidden_dropout, + ) + + self.blocks = nn.ModuleList() + for i in range(config.num_blocks): + self.blocks.append( + EncoderBlock( + config.embed_dim, + config.num_heads, + config.ff_dim, 
+ eps=config.layer_norm_eps, + dropout_p=config.hidden_dropout, + ) + ) + + self.mlm_head = MLMHead( + embed_dim=config.embed_dim, + eps=config.layer_norm_eps, + vocab_size=config.vocab_size, + ) + + # Tie the weights + self.mlm_head.decoder.weight = self.embedding_layer.word_embeddings.weight + # @NOTE :: bias are tied too with the HF model + + # no bias for MLM head (?), let's keep it since the HF implementation keeps it as well + # self.mlm_head.mlm[-1].bias = None + def forward(self, **kwargs): + input_ids = kwargs["input_ids"] + token_type_ids = kwargs["token_type_ids"] + + x = self.embedding_layer(input_ids, token_type_ids) + # x = self.encoder_blocks(x) + for block in self.blocks: + x = block(x) + x = self.mlm_head(x) + + return x + + def compute_loss(self, logits, labels): + """ """ + ce_loss = nn.CrossEntropyLoss(ignore_index=-100) + + # Flatten the logits and labels + logits = logits.view( + -1, self.config.vocab_size + ) # Adjust vocab_size as per your config + labels = labels.view(-1) + + # Compute and return the loss + return ce_loss(logits, labels) + + + +### ./rafale/models/__init__.py ### + + + +### ./rafale/models/configurations.py ### +import requests +import os + +from dataclasses import dataclass + +from safetensors import safe_open + +from ..caches import MODEL_CACHE_DIR + +""" +to simplify model loading add a configuration for the pre-trained weight loading using safetensors instead of loading +the full model. +> then save to a folder named ".pretrained/" in this directory +""" + + +def download_file(url, save_path): + """ + Downloads a file from the specified URL to the given save path. + + :param url: The URL of the file to download. + :param save_path: The local path where the file will be saved. + """ + try: + # Make sure the directory exists + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + with requests.get(url, stream=True) as response: + response.raise_for_status() # Check for HTTP errors + total_size = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 Kilobyte + progress = 0 + + with open(save_path, "wb") as file: + for data in response.iter_content(block_size): + file.write(data) + progress += len(data) + print( + f"Downloaded {progress} of {total_size} bytes ({(progress/total_size)*100:.2f}%)", + end="\r", + ) + + print(f"\nDownload completed successfully! 
File saved to: {save_path}") + + except requests.exceptions.HTTPError as http_err: + print(f"HTTP error occurred: {http_err}") # HTTP error + except Exception as err: + print(f"An error occurred: {err}") # Other errors + + +def load_safetensors(rafale_model, model_config): + """Transfer the pretrained safetensors to rafale model""" + tensors = {} + + safetensors_path = os.path.join(MODEL_CACHE_DIR, model_config.name + ".safetensors") + + if os.path.isfile(safetensors_path): + pass + else: + download_file(model_config.safetensors_url, safetensors_path) + + with safe_open(safetensors_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = f.get_tensor(k) + + rafale_model = model_config.convert_params_dict(rafale_model, tensors) + + return rafale_model + + +@dataclass +class BertConfig: + embed_dim: int = 768 + vocab_size: int = 30522 + attention_dropout: float = 0.1 + hidden_dropout: float = 0.1 + num_heads: int = 12 + ff_dim: int = 3072 + max_pos_embedding: int = 512 + layer_norm_eps: float = 1e-12 + num_blocks: int = 12 + pad_token_id: int = 0 + num_token_type: int = 2 + fast_attention: bool = ( + False # use xformers (todo: add FlashAttention2), NOT IMPLEMENTED* + ) + + +@dataclass +class BertTinyConfig: + embed_dim: int = 128 + vocab_size: int = 30522 # could usage would be to 30522 + num_extra_tokens + attention_dropout: float = 0.1 + hidden_dropout: float = 0.1 + num_heads: int = 2 + ff_dim: int = 512 + max_pos_embedding: int = 512 + layer_norm_eps: float = 1e-12 + num_blocks: int = 2 + pad_token_id: int = 0 + num_token_type: int = 2 + fast_attention: bool = False # use xformers (todo: add FlashAttention2) + + +@dataclass +class RobertaConfig: + embed_dim: int = 768 + vocab_size: int = 50265 + attention_dropout: float = 0.1 + hidden_dropout: float = 0.1 + num_heads: int = 12 + ff_dim: int = 3072 + max_pos_embedding: int = 514 + layer_norm_eps: float = 1e-05 + num_blocks: int = 12 + pad_token_id: int = 1 + num_token_type: int = 1 + bos_token_id: int = 0 + eos_token_id: int = 2 + fast_attention: bool = False + + +@dataclass +class Pythia14MConfig: + name: str = "pythia14m" + safetensors_url: str = ( + "https://huggingface.co/EleutherAI/pythia-14m/resolve/main/model.safetensors" + ) + + embed_dim: int = 128 + num_heads: int = 4 + ff_dim: int = 512 + hidden_act: str = "gelu" + max_pos_embedding: int = 2048 + vocab_size: int = 50304 + use_cache: bool = True + + parallel_residual: bool = True + + attention_dropout: float = 0.1 + hidden_dropout: float = 0.1 + + layer_norm_eps: float = 1e-05 + num_blocks: int = 6 + + bos_token_id: int = 0 + eos_token_id: int = 0 + fast_attention: bool = False + + rotary_emb_base: int = 10000 + rotary_pct: float = 0.25 + + tie_word_embeddings: bool = False + + @classmethod + def convert_params_dict(cls, target, source): + """ + Source safetensors dict to our rafale model class. 
+ """ + # not needed for our implementation + unused = ["rotary_emb.inv_freq", "masked_bias", "attention.bias"] + for k, v in list(source.items()): + if True in [x in k for x in unused]: + del source[k] + + conversion = [ + ("gpt_neox.embed_in", "token_embeddings.input_embeddings"), + ("gpt_neox.layers", "layers"), + ("input_layernorm", "attention_norm"), + ("post_attention_layernorm", "ffn_norm"), + ("mlp", "feed_forward"), + ("dense_4h_to_h", "ff_out"), + ("dense_h_to_4h", "ff_in"), + ("embed_out", "output"), + ("gpt_neox.final_layer_norm", "final_norm"), + ] + + updated_parameters = {} + for k, v in source.items(): + for hf_term, my_term in conversion: + if hf_term in k: + k = k.replace(hf_term, my_term) + + updated_parameters[k] = v + + # here we transfer weights for all layers + target.load_state_dict(updated_parameters, strict=True) + + return target + + + +### ./rafale/models/roberta.py ### +import torch + +from encoder import EncoderWrapper +from dataclasses import dataclass + +# from composer.models import ComposerModel + + +@dataclass +class RobertaConfig: + embed_dim: int = 768 + vocab_size: int = 50265 + attention_dropout: float = 0.1 + hidden_dropout: float = 0.1 + num_heads: int = 12 + ff_dim: int = 3072 + max_pos_embedding: int = 514 + layer_norm_eps: float = 1e-05 + num_blocks: int = 12 + pad_token_id: int = 1 + num_token_type: int = 1 + bos_token_id: int = 0 + eos_token_id: int = 2 + fast_attention: bool = False # use xformers (todo: add FlashAttention2) + + +class RobertaMLM(EncoderWrapper): + def __init__(self, config): + super().__init__(config) + self.embedding_layer.forward = self.roberta_embedding_forward + # monkey patched forward method to fix the position_id embedding without changing the original encoder embedding class. + + def mlm_hook(self): + """TBD""" + None + + def roberta_embedding_forward(self, input_ids, token_type_ids): + position_ids = self.create_position_ids_from_input_ids(input_ids, 1) + + # we assume absolute positional encoding here like in the original BERT and sum everything up + W = self.embedding_layer.word_embeddings(input_ids) + P = self.embedding_layer.position_embeddings(position_ids) + T = self.embedding_layer.token_type_embeddings(token_type_ids) + + E = W + P + T + E = self.embedding_layer.LayerNorm(E) + E = self.embedding_layer.dropout(E) + + return E + + def create_position_ids_from_input_ids( + self, input_ids, padding_idx, past_key_values_length=0 + ): + """ + MAX NOTE: from the huggingface implementation, they use a different method to create the positon_ids in roberta than + bert. whithout this the model breaks... simply modifies the method used to cast the array. + + Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA. 
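+        # Worked example of the computation below (padding_idx=1, no past KV cache):
+        #   input_ids           = [[5, 7, 1, 1]]
+        #   mask                = [[1, 1, 0, 0]]
+        #   cumsum(mask) * mask = [[1, 2, 0, 0]]
+        #   position_ids        = [[2, 3, 1, 1]]   <- padding positions collapse to padding_idx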
+ mask = input_ids.ne(padding_idx).int() + incremental_indices = ( + torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length + ) * mask + return incremental_indices.long() + padding_idx + + + +### ./rafale/datapipe.py ### +import os +import warnings +from abc import ABC, abstractmethod + +from datasets import load_dataset, DatasetDict + +from tokenizers import Tokenizer +from tokenizers.processors import TemplateProcessing + +import torch +from torch.utils.data import DataLoader +from torch.nn.utils.rnn import pad_sequence + +from rafale.caches import DATA_CACHE_DIR + + +class DataPipeline(ABC): + """ + Base class + + A data pipeline is initiated with a path and some configurations parameters: + + Args: + - dataset_path : local path to the dataset + - collator : function to be applied when sending a batch through the dataloader + + + Datasets should be saved using the following format: + /.json + + Returns: + """ + + def __init__(self, **kwargs): + self.name: str = kwargs["name"] + self.tokenizer_name: str = kwargs["tokenizer_name"] + self.is_prepared: bool = kwargs["is_prepared"] + self.dataset_path: str = os.path.expanduser(kwargs["dataset_path"]) + + self.max_sequence_length: int = kwargs["max_sequence_length"] + self.train_batch_size: int = kwargs["train_batch_size"] + self.eval_batch_size: int = kwargs["eval_batch_size"] + self.pad_inputs: bool = kwargs["pad_inputs"] + self.pad_token_id: int = kwargs["pad_token_id"] + self.input_id_key: str = kwargs["input_id_key"] + self.shuffle_train: bool = kwargs["shuffle_train"] + + self.num_processes: int = kwargs["num_processes"] + self.tokenizer = Tokenizer.from_pretrained(kwargs["tokenizer_path"]) + + self.data_collator = None + + self.use_cached = False + + def _load(self): + # either load directly from disk or from an hf dataset repo + try: + self.dataset = DatasetDict.load_from_disk(self.dataset_path) + except: + try: + self.dataset = load_dataset(self.dataset_path) + pass + except: + raise OSError( + f"Wrong dataset file and/or path configuration! path: {self.dataset_path}" + ) + + @abstractmethod + def _prepare(self): + """Perform all data preprocessing here: tokenization, chunking, truncation, etc. (EXCEPT padding!). Padding will be performed by + the datacollator. + """ + pass + + def _check_path(self): + """make sure that the dataset has not already been parsed at location""" + output_path = f"{self.name}_{self.tokenizer_name}_bs{self.train_batch_size}_len{self.max_sequence_length}" + + assert DATA_CACHE_DIR[-1] == "/" + + save_path = os.path.abspath(os.path.join(DATA_CACHE_DIR, output_path)) + + if os.path.isdir(DATA_CACHE_DIR): + pass + else: + os.makedirs(DATA_CACHE_DIR) + + if os.path.isdir(save_path): + warnings.warn( + f"Dataset already exists at location:\n\t {save_path} \n ABORTING PREPARATION, USING CACHED DATASET!" 
+ ) + + self.is_prepared = True + self.use_cached = True + + return save_path + + def __call__(self): + # returns a or multiple dataloaders + self._load() + + self.dataloaders = {} + + if not self.is_prepared: + cache_dataset_path = self._check_path() + + if type(self.dataset) == DatasetDict: + for subset in self.dataset: + if subset == "train": + shuffle = self.shuffle_train + batch_size = self.train_batch_size + else: + shuffle = False + batch_size = self.eval_batch_size + + # if the data is not ready to be passed to the dataloader + if not self.is_prepared: + print(f"preparing subset {subset}") + self.dataset[subset] = self._prepare(self.dataset[subset]) + + if self.use_cached: + self.dataset = DatasetDict.load_from_disk(cache_dataset_path) + + self.dataloaders[subset] = DataLoader( + self.dataset[subset], + collate_fn=self.data_collator, + batch_size=batch_size, + ) + print(f"✅ Dataloader ready for subset {subset}.") + + if not self.is_prepared: + self.dataset.save_to_disk(cache_dataset_path) + print(f"✅ Saved prepared dataset at {cache_dataset_path}.") + else: + raise TypeError( + f"self.dataset is type {type(self.dataset)}, but should be DatasetDict." + ) + + return self.dataloaders + + +class InferenceDatapipeline: + def __init__(self, tokenizer_path): + self.tokenizer = Tokenizer.from_pretrained(tokenizer_path) + + def _tokenizer_templating(self, tokenizer, add_eos=True): + if add_eos: + tokenizer.post_processor = TemplateProcessing( + single="<|endoftext|> $A", + special_tokens=[ + ("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")), + ], + ) + + return tokenizer + + def __call__(self, input_str, use_template: bool = True): + """ + tokenize input_str + convert to torch tensor (batch_size=1) + add the endoftext token + """ + if use_template: + self.tokenizer = self._tokenizer_templating(self.tokenizer) + + tokenized_inputs = { + "input_ids": torch.LongTensor( + self.tokenizer.encode(input_str).ids + ).unsqueeze(dim=0) + } + + return tokenized_inputs + + def ids_to_str(self, tensor): + return ifdp.tokenizer.decode(tensor.squeeze().detach().numpy()) + + +class CausalCollator: + def __init__(self, pad_token_id: int, input_id_key: str, pad_inputs: bool): + self.pad_token_id = pad_token_id + self.input_id_key = input_id_key + self.pad_inputs = pad_inputs + + def __call__(self, features): + # Extract the input IDs from the batch + input_ids = [torch.tensor(example[self.input_id_key]) for example in features] + + # Pad the inputs if required + if self.pad_inputs: + input_ids = pad_sequence( + input_ids, batch_first=True, padding_value=self.pad_token_id + ) + + # Set the last token of each label sequence to pad_token_id to ignore loss for the last prediction + # shift left the ids + labels = [ + torch.cat([ids[1:], torch.tensor([self.pad_token_id])]) for ids in input_ids + ] + labels = torch.stack(labels, dim=0) + + return { + "input_ids": input_ids, + "labels": labels, + } + + +class MLMCollator: + def __init__( + self, + mask_p: float = 0.15, + whole_word_mask: bool = False, + mask_span: bool = False, + pad_token_id: int = -100, + input_id_key: str = "input_ids", + pad_inputs: bool = True, + ): + """masks some % of tokens for MLM objective""" + pass + + +class DefaultCollator: + def __init__( + self, + pad_token_id: int = -100, + input_id_key: str = "input_ids", + pad_inputs: bool = True, + ): + """for task data where labels are already set""" + pass + + +class TinyStoriesCausalNeoX(DataPipeline): + """This is sample datapipelin for the TinyStories dataset. 
+ + + This dataset is prepared for causal language modelling using the gpt neox tokenizer (eleutherai). We + + Usage: + ts_dict = { + "name": "tinystories_testing", + "tokenizer_name": "neox", + "is_prepared": False, + "input_id_key": "input_ids", + "batch_size": 16, + "shuffle_train": False, + "dataset_path": "~/code/data/micro_tinystories", + "tokenizer_path": "EleutherAI/pythia-14m", + "max_sequence_length": 128, + "pad_token_id": -100, + "pad_inputs": True, + "is_prepared": False, + } + ts_dpipe = TinyStoriesCausalNeoX(**ts_dict) + dataloaders = ts_causal() + + Args: + + Returns: + + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.data_collator = CausalCollator( + pad_token_id=self.pad_token_id, + input_id_key=self.input_id_key, + pad_inputs=self.pad_inputs, + ) + + # TODO: figure out if they really only use endoftext for everything... + def _tokenizer_templating(self, tokenizer, add_eos=True): + if add_eos: + tokenizer.post_processor = TemplateProcessing( + single="$A <|endoftext|>", + special_tokens=[ + ("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")), + ], + ) + + return tokenizer + + def _tokenize(self, example, tokenizer, key="text"): + return {self.input_id_key: tokenizer.encode(example[key]).ids} + + def _group_and_chunk(self, examples, key="input_ids", block_size=None, pad=False): + concatenated_tokens = sum(examples[key], []) + total_length = len(concatenated_tokens) + + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + + result = { + "input_ids": [ + concatenated_tokens[i : i + block_size] + for i in range(0, total_length, block_size) + ] + } + + return result + + def _prepare(self, data): + """""" + + # apply functions above to dataset + self.tokenizer = self._tokenizer_templating(self.tokenizer) + + data = data.map( + lambda example: self._tokenize(example, self.tokenizer), + remove_columns=data.column_names, + num_proc=self.num_processes, + ) + + data = data.map( + lambda example: self._group_and_chunk( + example, block_size=self.max_sequence_length + ), + batched=True, + num_proc=self.num_processes, + ) + + return data + + +class TinyStoriesMLM(DataPipeline): + """ """ + + pass + + +class ImdbCLS(DataPipeline): + pass + + +''' +class ImdbClsPipe(DataPipeline): + """A pipeline for the imdb dataset for """ + + def __init__(self, **kwargs): + self.path = kwargs["path"] + self.name = kwargs["name"] # name + self.is_tokenized = kwargs["is_tokenized"] + + self.padding = kwargs["padding"] # "max_length" + self.max_sequence_length = kwargs["max_sequence_length"] # 512 + + self.shuffle_train = kwargs["shuffle_train"] # False + self.batch_size = kwargs["batch_size"] + self.tokenizer = kwargs["tokenizer"] + self.collator_fn = DataCollatorWithPadding( + tokenizer=self.tokenizer, + padding=self.padding, + max_length=self.max_sequence_length, + return_tensors='pt' + ) + + self.data = datasets.DatasetDict.load_from_disk(self.path) + + def _post_tokenize(self, dataset): + return dataset.remove_columns(["text"]) + + def _tokenize( + self, + examples, + ): + source_tokenized = self.tokenizer( + examples["text"], + truncation=True, + max_length=self.max_sequence_length, + return_token_type_ids=True, + return_tensors='pt' + ) + + batch = {k: v for k, v in source_tokenized.items()} + + return batch + + def _map_tokenize(self, subsets=None): + # tokenize + print("tokenizing training") + self.data["train"] = self.data["train"].map(self._tokenize, batched=True) + self.data["train"] = 
self.data["train"].remove_columns("text") + + print("tokenizing test") + self.data["test"] = self.data["test"].map(self._tokenize, batched=True) + self.data["test"] = self.data["test"].remove_columns("text") + + + def _save_tokenized(self): + # preprocess + self.path += "_tokenized" + print(f"saving tokenized data to disk at location : {self.path}") + assert os.path.isdir(self.path) == False + self.data.save_to_disk(self.path) + + def string_to_tokens(self, input_str): + # tokenize + tensor = self._tokenize({"text": input_str}) + + return self.collator_fn(tensor) + + def __call__(self, subsets = ["train", "test"]): + dataloaders = {} + + if self.is_tokenized: + print("data tokenized") + else: + self._map_tokenize() + self._save_tokenized() + + for _set in subsets: + if _set == "train": + shuffle = self.shuffle_train + else: + shuffle = False + + dataloaders[_set] = DataLoader( + self.data[_set], + collate_fn=self.collator_fn, # DEBUG + batch_size=self.batch_size, + shuffle=shuffle, # DEBUG + ) + return dataloaders +''' + + + +### ./rafale/main.py ### +import os +import argparse +import yaml +import time +from datetime import datetime + +import torch +import torch.utils.data + +import composer +from composer import Trainer, Time +from composer.loggers import InMemoryLogger, WandBLogger, FileLogger +from composer.algorithms import GradientClipping +from composer.optim.scheduler import ( + CosineAnnealingWithWarmupScheduler, + CosineAnnealingScheduler, + LinearScheduler, + LinearWithWarmupScheduler, +) + +from rafale.models.decoder import ComposerLM +from rafale.models.encoder import EncoderWrapper +from rafale.models.configurations import ( + load_safetensors, + Pythia14MConfig, + BertConfig, + RobertaConfig, +) + +from rafale.caches import CHECKPOINT_CACHE_DIR +from rafale.datapipe import TinyStoriesCausalNeoX + + +ENV_VARS = {key: value for key, value in os.environ.items()} + +parser = argparse.ArgumentParser(description="launch a training run") + +parser.add_argument( + "-c", + "--training_config", + type=str, + help="path to yaml run configuration file", + required=True, +) +args = parser.parse_args() + +model_config_dict = { + "pythia14m": Pythia14MConfig, + "bert": BertConfig, + "roberta": RobertaConfig, +} + +data_pipeline_dict = { + "tinystories_neox": TinyStoriesCausalNeoX, +} + + +def main(): + # CONFIG ################################################################## + with open(args.training_config, "r") as f: + config = yaml.safe_load(f) + + run_name = config["run"]["name"] + run_n_epochs = config["run"]["n_epochs"] + run_seed = config["run"]["seed"] + run_save_interval = config["run"]["save_interval"] + + run_clip_type = config["run"]["clip_type"] + run_clip_value = float(config["run"]["clip_value"]) + + device_bs = config["run"]["device_bs"] # int or "auto" + + # schedule + run_schedule_type = config["run"]["schedule"] + run_max_lr = float(config["run"]["max_lr"]) # learning rate + run_warmup_pct = float(config["run"]["warmup_pct"]) + if run_schedule_type == "cosine-warmup": + run_scheduler = CosineAnnealingWithWarmupScheduler( + t_warmup=Time(run_warmup_pct, "dur"), alpha_f=0.1 + ) + else: + raise TypeError( + f"Model type {model_type} is not valid! 
Supports: cosine-warmup.\nlinear, cosine, and linear-warmup planned" + ) + + model_config_key = config["model"]["config"] + model_type = config["model"]["type"] + model_use_pretrained = config["model"]["use_pretrained"] + + data_pipeline_key = config["data"]["pipeline"] + dataset_config = config["data"]["config"] + print(dataset_config) + + # DATALOADERS ############################################################# + data_pipeline = data_pipeline_dict[data_pipeline_key](**dataset_config) + dataloaders = data_pipeline() + if "DATA" in ENV_VARS.keys() and ENV_VARS["DATA"] == "1": + print("Data processing complete, exiting...") + return 0 # just do data preprocessing if we pass DATA=1 + + # MODEL ####################################################################### + model_config = model_config_dict[model_config_key] + + if model_type == "decoder": + rafale_model = ComposerLM(model_config) + elif model_type == "encoder": + rafale_model = EncoderWrapper(model_config) + else: + raise TypeError( + f"Model type {model_type} is not valid! Supports: encoder, decoder." + ) + + if model_use_pretrained: + rafale_model.model = load_safetensors(rafale_model.model, model_config) + + # LOGGING ################################################################# + # mem_logger = InMemoryLogger() + # @TODO :: add some logging options in the yaml + wandb_logger = WandBLogger(project="rafale", name=run_name) + # file_logger = FileLogger(filename=f"{run_name}-{time}".txt) + + # GRADIENT CLIPPING ####################################################### + clipping_type = "norm" # can also be 'adaptive' or 'value' + gradient_clip = GradientClipping( + clipping_type=clipping_type, clipping_threshold=0.1 + ) + + # DEVICES ################################################################# + device = "gpu" if torch.cuda.is_available() else "cpu" # select the device + if device == "gpu": + run_precision = "amp_fp16" + else: + run_precision = "fp32" + + # DEBUG RUN ############################################################### + if "DEBUG" in ENV_VARS.keys() and ENV_VARS["DEBUG"] == "1": + from torch.utils.data import Subset, DataLoader, default_collate + from datasets import Dataset + + # single batch, same for train and test 10 epochs + bs = 4 + debug_batch = next(iter(dataloaders["train"])) + debug_batch = Dataset.from_dict( + {k: v[:bs] for k, v in debug_batch.items()} + ).with_format("torch") + debug_batch = DataLoader( + debug_batch, + batch_size=bs, + shuffle=False, + collate_fn=default_collate, + ) + + trainer = Trainer( + model=rafale_model, + seed=run_seed, + train_dataloader=debug_batch, + eval_dataloader=debug_batch, + optimizers=torch.optim.AdamW(rafale_model.parameters(), lr=1e-4), + max_duration=10, # num epochs + device=device, + loggers=None, + precision=run_precision, + progress_bar=True, + ) + + return 0 + + # TRAIN ################################################################### + # training subset must have key "train" then whatever is called the validation subset (i.e. 
test, val, validation, + # eval, etc) as long as there is only 1 other subset, we call it + dl_keys = list(dataloaders.keys()) + assert "train" in dl_keys + dl_keys.remove("train") + assert len(dl_keys) == 1 + eval_subset_key = dl_keys[0] + + # get datetime for checkpoint, directories are created by composer + now = datetime.now() + formatted_date = now.strftime( + "%d" + "d" + "%m" + "m" + "%Y" + "y" + "_%H" + "h" + "%M" + "m" + ) # Format it as DDdMMmYYYYy_HHhMMm + checkpoint_folder = os.path.abspath( + os.path.join(CHECKPOINT_CACHE_DIR, f"{run_name}-{formatted_date}/") + ) + + trainer = Trainer( + model=rafale_model, + seed=run_seed, + train_dataloader=dataloaders["train"], + eval_dataloader=dataloaders[eval_subset_key], + optimizers=torch.optim.AdamW(rafale_model.parameters(), lr=run_max_lr), + max_duration=run_n_epochs, # num epochs + eval_interval="50ba", # default is 1ep ! + device_train_microbatch_size=device_bs, # will handle gradient accumulation automatically + device=device, + loggers=[wandb_logger], + precision=run_precision, + progress_bar=True, + schedulers=run_scheduler, + algorithms=[gradient_clip], + save_folder=checkpoint_folder, + save_latest_filename="latest", + save_interval=run_save_interval, + ) + + # launch + trainer.fit() + print(f"🍻 TRAINING COMPLETE\n💾 CHECKPOINTS SAVED AT LOCATION: {checkpoint_folder}") + + +if __name__ == "__main__": + main() + + + +### ./rafale/caches.py ### +import os + +DATA_CACHE_DIR = os.path.expanduser("~/.rafale_cache/data/") +MODEL_CACHE_DIR = os.path.expanduser("~/.rafale_cache/models/") +CHECKPOINT_CACHE_DIR = os.path.expanduser("~/.rafale_cache/checkpoints/") + + + +### ./rafale/__init__.py ### + + + +### ./test/test_pythia.py ### +import torch +import numpy as np + +from safetensors import safe_open + +from rafale.models.decoder import DecoderWrapper +from rafale.models.configurations import Pythia14MConfig, load_safetensors + +from transformers import AutoTokenizer, GPTNeoXForCausalLM + + +def test_layer_and_outputs(rafale_model, hf_model, tokenizer, layer=0, tol=1e-05): + """ + # @NOTE :: this currently on evaluates with KV cache enabled, write the test to run this function without the KV cache + """ + hf_activation = {} + hf_input_activation = {} + rafale_activation = {} + rafale_input_activation = {} + + # tuple of shape num_layers, 2 (keys, values), tensor BHLd + # make a fake kv-cache of length 4 + kv_cache = [] + n_layers = 6 + cache_len = 4 + n_heads = 4 + head_dim = 32 + for i in range(n_layers): + k = torch.randn(1, n_heads, cache_len, 32) + v = torch.randn(1, n_heads, cache_len, 32) + kv_cache.append((k, v)) + + def get_hf_activation(name): + def hook(model, input, output): + hf_activation[name] = output.detach() + + return hook + + def get_hf_input_activation(name): + def hook(model, _input, output): + hf_input_activation[name] = _input[0].detach() + + return hook + + def get_rafale_activation(name): + def hook(model, input, output): + rafale_activation[name] = output.detach() + + return hook + + def get_rafale_input_activation(name): + def hook(model, _input, output): + rafale_input_activation[name] = _input[0].detach() + + return hook + + # embeddings + rafale_model.token_embeddings.register_forward_hook( + get_rafale_activation("input_embeddings") + ) + hf_model.gpt_neox.embed_in.register_forward_hook( + get_hf_activation("input_embeddings") + ) + + # input layernorm + rafale_model.layers[layer].attention_norm.register_forward_hook( + get_rafale_activation("attn_norm") + ) + 
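+    # register the matching hook on the HF reference layer so both activations
+    # can be compared element-wise further down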
hf_model.gpt_neox.layers[layer].input_layernorm.register_forward_hook( + get_hf_activation("attn_norm") + ) + + rafale_model.layers[layer].attention_norm.register_forward_hook( + get_rafale_input_activation("input_attn_norm") + ) + hf_model.gpt_neox.layers[layer].input_layernorm.register_forward_hook( + get_hf_input_activation("input_attn_norm") + ) + + # attention projection query_key_values (pre RoPE) + rafale_model.layers[layer].attention.query_key_value.register_forward_hook( + get_rafale_activation("attn_inproj") + ) + + hf_model.gpt_neox.layers[layer].attention.query_key_value.register_forward_hook( + get_hf_activation("attn_inproj") + ) + + # out proj + rafale_model.layers[layer].attention.dense.register_forward_hook( + get_rafale_activation("attn_dense") + ) + + hf_model.gpt_neox.layers[layer].attention.dense.register_forward_hook( + get_hf_activation("attn_dense") + ) + + # INPUT check before attention dense* (if this fails then RoPE is probably the problem...) + rafale_model.layers[layer].attention.dense.register_forward_hook( + get_rafale_input_activation("attn_dense") + ) + hf_model.gpt_neox.layers[layer].attention.dense.register_forward_hook( + get_hf_input_activation("attn_dense") + ) + + # feed forward out + rafale_model.layers[layer].feed_forward.ff_out.register_forward_hook( + get_rafale_activation("ffout") + ) + hf_model.gpt_neox.layers[layer].mlp.dense_4h_to_h.register_forward_hook( + get_hf_activation("ffout") + ) + + # hf_model.gpt_neox.layers[0].attention(tensor) + + input_str = "Hello World from pythia!" + tokens = tokenizer(input_str, return_tensors="pt") + + hf_model.eval() + rafale_model.eval() + + with torch.no_grad(): + # hf_out = hf_model(tokens["input_ids"])["logits"].detach().numpy() + + hf_out = hf_model(tokens["input_ids"], use_cache=True, past_key_values=kv_cache) + hf_out = hf_out["logits"].detach().numpy() + + # rafale_out = rafale_model(tokens)[0].detach().numpy() + + rafale_out = rafale_model(tokens, past_kv_cache=kv_cache)[0].detach().numpy() + + print(f"Dropout p should be 0: {rafale_model.layers[layer].attention.dropout_p}") + print(f"Testing layer {layer}") + # EMBEDDING ################################################################### + try: + np.testing.assert_allclose( + rafale_activation["input_embeddings"].numpy(), + hf_activation["input_embeddings"].numpy(), + rtol=tol, + atol=tol, + ) + + print(f"✅ embeddings OK!") + except: + print("⚠️ Embedding difference!") + + # PRE-ATTENTION NORM ###################################################### + try: + np.testing.assert_allclose( + rafale_input_activation["input_attn_norm"].numpy(), + hf_input_activation["input_attn_norm"].numpy(), + rtol=tol, + atol=tol, + ) + print(f"✅ INPUTS of pre-attention norm OK!") + except: + print("⚠️ INPUTS pre-attention norm difference") + + try: + np.testing.assert_allclose( + rafale_activation["attn_norm"].numpy(), + hf_activation["attn_norm"].numpy(), + rtol=tol, + atol=tol, + ) + print(f"✅ pre-attention norm OK!") + except: + print("⚠️ pre-attention norm difference") + + # LINEAR PROJECTION FOR ATTENTION ###################################################### + + try: + np.testing.assert_allclose( + rafale_activation["attn_inproj"].numpy(), + hf_activation["attn_inproj"].numpy(), + rtol=tol, + atol=tol, + ) + print(f"✅ attention in-projection OK!") + + except: + print(f"⚠️ attention in-projection difference") + + # INPUTS OF ATTENTION DENSE LAYER ######################################## + try: + np.testing.assert_allclose( + 
rafale_input_activation["attn_dense"].numpy(), + hf_input_activation["attn_dense"].numpy(), + rtol=tol, + atol=tol, + ) + print(f"✅ inputs of attention dense OK") + + except: + r = rafale_input_activation["attn_dense"].numpy() + h = hf_input_activation["attn_dense"].numpy() + + print(r.shape) + print(h.shape) + print("⚠️ inputs of attention dense difference") + + np.testing.assert_allclose( + rafale_input_activation["attn_dense"].numpy(), + hf_input_activation["attn_dense"].numpy(), + rtol=tol, + atol=tol, + ) + + # OUTPUT OF ATTENTION DENSE LAYER ######################################## + try: + np.testing.assert_allclose( + rafale_activation["attn_dense"].numpy(), + hf_activation["attn_dense"].numpy(), + rtol=tol, + atol=tol, + ) + print(f"✅ attention out dense OK!") + except: + print("⚠️ attention out dense difference") + + try: + np.testing.assert_allclose( + rafale_activation["ffout"].numpy(), + hf_activation["ffout"].numpy(), + rtol=tol, + atol=tol, + ) + print(f"✅ feedforward out dense OK!") + except: + print("⚠️ ff out dense difference") + + # Final model output ######################################## + try: + np.testing.assert_allclose(rafale_out, hf_out, rtol=tol, atol=tol) + print(f"🎉 Model outputs match reference implementation!") + except: + print("❌ Model outputs do not match") + + +def main(): + """ """ + torch.manual_seed(0) + np.random.seed(0) + + hf_pythia = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-14m") + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m") + + rafale_pythia = DecoderWrapper(Pythia14MConfig) # check forward pass OK + rafale_pythia = load_safetensors(rafale_pythia, Pythia14MConfig) + + test_layer_and_outputs(rafale_pythia, hf_pythia, tokenizer) + + +if __name__ == "__main__": + main() + + + +### ./test/test_bert.py ### +#!/usr/bin/env python +import os + +from datapipe import WikiMLMPipe +from encoder import EncoderWrapper, BertConfig +from roberta import RobertaConfig, RobertaMLM +from convert_hf_weights import convert_bert_params_dict, convert_roberta_params_dict + +from transformers import AutoTokenizer, AutoModelForMaskedLM + +import torch + +import numpy as np +import random + +torch.set_deterministic_debug_mode(1) +torch.use_deterministic_algorithms(True) + +SEED = 42 +torch.manual_seed(SEED) +np.random.seed(SEED) +random.seed(SEED) + +# dump modeling debugging code here... 
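+# Each test below loads the HF reference weights into the rafale implementation,
+# runs the same batch through both models in eval mode, and checks that
+# (1) each model is deterministic across two forward passes and
+# (2) the rafale logits match the HF logits within 5e-4 (atol/rtol).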
+# @TODO :: call from cli, layer by layer check, add proper logging + + +# roberta +def test_roberta(): + # roberta base + roberta_cfg = RobertaConfig() + r_roberta = RobertaMLM(roberta_cfg) + + roberta_tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base") + hf_roberta = AutoModelForMaskedLM.from_pretrained("FacebookAI/roberta-base") + + r_roberta = convert_roberta_params_dict(r_roberta, hf_roberta) + + args = { + "path": "~/code/data/enwiki1m", + "truncation": True, + "max_sequence_length": 128, + "shuffle_train": False, + "batch_size": 1, + "padding": "max_length", + "tokenizer": roberta_tokenizer, + } + + wikipipe = WikiMLMPipe(**args) + dl = wikipipe() + + batch = next(iter(dl["train"])) + + with torch.no_grad(): + r_roberta.eval() + hf_roberta.eval() + r_output = r_roberta(**batch) + r_output2 = r_roberta(**batch) + hf_output = hf_roberta(**batch) + hf_output2 = hf_roberta(**batch) + np.testing.assert_allclose( + hf_output.logits.detach().numpy(), + hf_output2.logits.detach().numpy(), + atol=5e-4, + rtol=5e-4, + ) + print("hf deterministic") + np.testing.assert_allclose( + r_output.detach().numpy(), r_output2.detach().numpy(), atol=5e-4, rtol=5e-4 + ) + print("rafale deterministic") + np.testing.assert_allclose( + r_output.detach().numpy(), + hf_output.logits.detach().numpy(), + atol=5e-4, + rtol=5e-4, + ) + + +def test_bert(): + bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") + hf_bert = AutoModelForMaskedLM.from_pretrained("bert-base-uncased") + bert_cfg = BertConfig() + r_bert = EncoderWrapper(bert_cfg) + r_bert = convert_bert_params_dict(r_bert, hf_bert) + + args = { + "path": "~/code/data/enwiki1m", + "truncation": True, + "max_sequence_length": 128, + "shuffle_train": False, + "batch_size": 1, + "padding": "max_length", + "tokenizer": bert_tokenizer, + } + + wikipipe = WikiMLMPipe(**args) + dl = wikipipe() + + batch = next(iter(dl["train"])) + + with torch.no_grad(): + r_bert.eval() + hf_bert.eval() + r_output = r_bert(**batch) + r_output2 = r_bert(**batch) + hf_output = hf_bert(**batch) + hf_output2 = hf_bert(**batch) + np.testing.assert_allclose( + hf_output.logits.detach().numpy(), + hf_output2.logits.detach().numpy(), + atol=5e-4, + rtol=5e-4, + ) + print("hf deterministic") + np.testing.assert_allclose( + r_output.detach().numpy(), r_output2.detach().numpy(), atol=5e-4, rtol=5e-4 + ) + print("rafale deterministic") + np.testing.assert_allclose( + r_output.detach().numpy(), + hf_output.logits.detach().numpy(), + atol=5e-4, + rtol=5e-4, + ) + + +if __name__ == "__main__": + main() + + + +### ./test/test_pythia_generation.py ### +import torch +import numpy as np + +# from safetensors import safe_open +# from tokenizers import Tokenizer + +# from rafale.datapipe import InferenceDatapipeline +# from rafale.models.decoding_strategies import greedy_decode +# from rafale.models.decoder import DecoderWrapper +# from rafale.models.configurations import Pythia14MConfig, load_safetensors + + +# Example usage +# Initialize the rafale_pythia model +rafale_pythia = DecoderWrapper(Pythia14MConfig) +rafale_pythia = load_safetensors(rafale_pythia, Pythia14MConfig) +ifdp = InferenceDatapipeline("EleutherAI/pythia-14m") +test_str = "Once upon a time," + +# Define input_ids (e.g., starting with a token) +# input_ids = torch.tensor([[Pythia14MConfig.bos_token_id]]) # Shape: (1, 1) + +# Define maximum sequence length for generation +max_length = 32 + +# Generate sequence using greedy decoding +generated_sequence = greedy_decode( + rafale_pythia, + 
ifdp(test_str), + max_length, + Pythia14MConfig.eos_token_id, + check_repeat_ngrams=True, +) + +generated_str = ifdp.ids_to_str(generated_sequence["input_ids"]) + +print("Generated sequence:", generated_str) + + + +### ./test/pythia_tinystories.yaml ### +run: + name: "pythia14m-tinystories" # name of your experiment, used for checkpointing + seed: 42 + n_epochs: 1 + max_lr: 6e-04 + warmup_pct: 0.01 + schedule: "cosine-warmup" # linear, linear-warmup, cosine, cosine-warmup + optimizer: "AdamW" + eval_interval: "50ba" + clip_type: "norm" + clip_value: 1.0 + device_bs: "auto" + save_interval: "200ba" + +model: + config: "pythia14m" # config key + type: "decoder" + use_pretrained: True + +data: + pipeline: "tinystories_neox" # the preprocessing/tokenization pipeline + config: + name: "tinystories" + num_processes: 8 + tokenizer_name: "neox" + is_prepared: False + input_id_key: "input_ids" + train_batch_size: 1024 + eval_batch_size: 16 + shuffle_train: False + dataset_path: "~/code/data/TinyStories" + tokenizer_path: "EleutherAI/pythia-14m" + max_sequence_length: 512 + pad_token_id: -100 + pad_inputs: True + is_prepared: False + +logging: # @TODO :: not implemented + use_wandb: True + use_file: False + eval_interval: "10ba" + log_dir: "./run_logs" + checkpoint_dir: "./checkpoints" + + + +### ./test/test.yaml ### +# we want data and model configurations to be in files rather than in yaml, leave training hyperparams to yaml config only + +run: + name: "test-ministories" # name of your experiment, used for checkpointing + seed: 42 + n_epochs: 1 + max_lr: 3e-04 + warmup_pct: 0.01 + schedule: "cosine-warmup" # linear, linear-warmup, cosine, cosine-warmup + optimizer: "AdamW" + eval_interval: "50ba" + clip_type: "norm" + clip_value: 1.0 + device_bs: "auto" + save_interval: "1ep" + +model: + config: "pythia14m" # config key + type: "decoder" + use_pretrained: True + +data: + pipeline: "tinystories_neox" # the preprocessing/tokenization pipeline + config: + name: "tinystories_testing" + num_processes: 1 + tokenizer_name: "neox" + is_prepared: False + input_id_key: "input_ids" + train_batch_size: 16 + eval_batch_size: 16 + shuffle_train: False + dataset_path: "~/code/data/micro_tinystories" + tokenizer_path: "EleutherAI/pythia-14m" + max_sequence_length: 128 + pad_token_id: -100 + pad_inputs: True + is_prepared: False + +logging: + use_wandb: True + use_file: False + + + diff --git a/media/rafale-logo.png b/media/rafale-logo.png new file mode 100644 index 0000000..ca2d590 Binary files /dev/null and b/media/rafale-logo.png differ diff --git a/media/rafale-logo.svg b/media/rafale-logo.svg new file mode 100644 index 0000000..7dd4a4a --- /dev/null +++ b/media/rafale-logo.svg @@ -0,0 +1,80 @@ + + + + diff --git a/pyproject.toml b/pyproject.toml index 4999a7b..e034a72 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,23 +1,11 @@ -[build-system] -requires = ["setuptools>=42", "wheel"] -build-backend = "setuptools.build_meta" - [project] name = "rafale" version = "0.1" -description = "simple transformer training lib" +description = "opinionated transformer training cli" readme = "README.md" -requires-python = ">=3.6" -license = {text = "MIT"} - -[project.urls] -homepage = "https://github.com/maxrousseau/rafale" +requires-python = "==3.12" +packages = [{include = "rafale"}] -[project.authors] -name = "Maxime Rousseau" - -[tool.setuptools.packages.find] -where = ["rafale"] - -[project.optional-dependencies] -# Add any optional dependencies here, if applicable \ No newline at end of file +[build-system] 
+requires = ["hatchling"] +build-backend = "hatchling.build" \ No newline at end of file diff --git a/rafale/caches.py b/rafale/caches.py new file mode 100644 index 0000000..daddac3 --- /dev/null +++ b/rafale/caches.py @@ -0,0 +1,5 @@ +import os + +DATA_CACHE_DIR = os.path.expanduser("~/.rafale_cache/data/") +MODEL_CACHE_DIR = os.path.expanduser("~/.rafale_cache/models/") +CHECKPOINT_CACHE_DIR = os.path.expanduser("~/.rafale_cache/checkpoints/") diff --git a/rafale/datapipe.py b/rafale/datapipe.py index 93a39e2..64d39a2 100644 --- a/rafale/datapipe.py +++ b/rafale/datapipe.py @@ -1,102 +1,362 @@ import os +import warnings +from abc import ABC, abstractmethod from datasets import load_dataset, DatasetDict -from transformers import DataCollatorForLanguageModeling +from tokenizers import Tokenizer +from tokenizers.processors import TemplateProcessing +import torch from torch.utils.data import DataLoader +from torch.nn.utils.rnn import pad_sequence +from rafale.caches import DATA_CACHE_DIR -class DataPipeline: + +class DataPipeline(ABC): """ Base class A data pipeline is initiated with a path and some configurations parameters: - - path : local path to the dataset - - collator_fn : function to be applied at parse + Args: + - dataset_path : local path to the dataset + - collator : function to be applied when sending a batch through the dataloader + Datasets should be saved using the following format: /.json + + Returns: """ - def __init__(self): - self.path = None - self.collator_fn = None - self.padding = None - self.truncation = None - self.max_sequence_length = None - self.shuffle_train = False - self.batch_size = 4 - self.tokenizer = None + def __init__(self, **kwargs): + self.name: str = kwargs["name"] + self.tokenizer_name: str = kwargs["tokenizer_name"] + self.is_prepared: bool = kwargs["is_prepared"] + self.dataset_path: str = os.path.expanduser(kwargs["dataset_path"]) + + self.max_sequence_length: int = kwargs["max_sequence_length"] + self.train_batch_size: int = kwargs["train_batch_size"] + self.eval_batch_size: int = kwargs["eval_batch_size"] + self.pad_inputs: bool = kwargs["pad_inputs"] + self.pad_token_id: int = kwargs["pad_token_id"] + self.input_id_key: str = kwargs["input_id_key"] + self.shuffle_train: bool = kwargs["shuffle_train"] + + self.num_processes: int = kwargs["num_processes"] + self.tokenizer = Tokenizer.from_pretrained(kwargs["tokenizer_path"]) + + self.data_collator = None + + self.use_cached = False def _load(self): + # either load directly from disk or from an hf dataset repo try: - self.dataset_pre = DatasetDict.load_from_disk(self.path) + self.dataset = DatasetDict.load_from_disk(self.dataset_path) except: - raise OSError(f"Wrong dataset file and/or path configuration!{self.path}") + try: + self.dataset = load_dataset(self.dataset_path) + pass + except: + raise OSError( + f"Wrong dataset file and/or path configuration! path: {self.dataset_path}" + ) + + @abstractmethod + def _prepare(self): + """Perform all data preprocessing here: tokenization, chunking, truncation, etc. (EXCEPT padding!). Padding will be performed by + the datacollator. 
+ """ + pass + + def _check_path(self): + """make sure that the dataset has not already been parsed at location""" + output_path = f"{self.name}_{self.tokenizer_name}_bs{self.train_batch_size}_len{self.max_sequence_length}" + + assert DATA_CACHE_DIR[-1] == "/" - def _post_tokenize(self): - None + save_path = os.path.abspath(os.path.join(DATA_CACHE_DIR, output_path)) + + if os.path.isdir(DATA_CACHE_DIR): + pass + else: + os.makedirs(DATA_CACHE_DIR) + + if os.path.isdir(save_path): + warnings.warn( + f"Dataset already exists at location:\n\t {save_path} \n ABORTING PREPARATION, USING CACHED DATASET!" + ) + + self.is_prepared = True + self.use_cached = True + + return save_path def __call__(self): # returns a or multiple dataloaders self._load() - dataloaders = {} + self.dataloaders = {} - if type(self.dataset_pre) == DatasetDict: - for subset in self.dataset_pre: + if not self.is_prepared: + cache_dataset_path = self._check_path() + + if type(self.dataset) == DatasetDict: + for subset in self.dataset: if subset == "train": shuffle = self.shuffle_train + batch_size = self.train_batch_size else: shuffle = False + batch_size = self.eval_batch_size - # @DEBUG - self.dataset_pre[subset] = self.dataset_pre[subset].select(range(10)) - _set = self.dataset_pre[subset].map( - lambda example: self._tokenize( - example, - ) - ) - _set = self._post_tokenize(_set) - print(len(_set["input_ids"][0])) - print(_set) - dataloaders[subset] = DataLoader( - _set, - collate_fn=self.collator_fn, # DEBUG - batch_size=self.batch_size, - shuffle=shuffle, # DEBUG + # if the data is not ready to be passed to the dataloader + if not self.is_prepared: + print(f"preparing subset {subset}") + self.dataset[subset] = self._prepare(self.dataset[subset]) + + if self.use_cached: + self.dataset = DatasetDict.load_from_disk(cache_dataset_path) + + self.dataloaders[subset] = DataLoader( + self.dataset[subset], + collate_fn=self.data_collator, + batch_size=batch_size, ) + print(f"✅ Dataloader ready for subset {subset}.") + + if not self.is_prepared: + self.dataset.save_to_disk(cache_dataset_path) + print(f"✅ Saved prepared dataset at {cache_dataset_path}.") else: - # process a single set - None + raise TypeError( + f"self.dataset is type {type(self.dataset)}, but should be DatasetDict." 
+ ) - return dataloaders + return self.dataloaders + + +class InferenceDatapipeline: + def __init__(self, tokenizer_path): + self.tokenizer = Tokenizer.from_pretrained(tokenizer_path) + + def _tokenizer_templating(self, tokenizer, add_eos=True): + if add_eos: + tokenizer.post_processor = TemplateProcessing( + single="<|endoftext|> $A", + special_tokens=[ + ("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")), + ], + ) + + return tokenizer + + def __call__(self, input_str, use_template: bool = True): + """ + tokenize input_str + convert to torch tensor (batch_size=1) + add the endoftext token + """ + if use_template: + self.tokenizer = self._tokenizer_templating(self.tokenizer) + + tokenized_inputs = { + "input_ids": torch.LongTensor( + self.tokenizer.encode(input_str).ids + ).unsqueeze(dim=0) + } + + return tokenized_inputs + + def ids_to_str(self, tensor): + return ifdp.tokenizer.decode(tensor.squeeze().detach().numpy()) + + +class CausalCollator: + def __init__(self, pad_token_id: int, input_id_key: str, pad_inputs: bool): + self.pad_token_id = pad_token_id + self.input_id_key = input_id_key + self.pad_inputs = pad_inputs + def __call__(self, features): + # Extract the input IDs from the batch + input_ids = [torch.tensor(example[self.input_id_key]) for example in features] + + # Pad the inputs if required + if self.pad_inputs: + input_ids = pad_sequence( + input_ids, batch_first=True, padding_value=self.pad_token_id + ) + + # Set the last token of each label sequence to pad_token_id to ignore loss for the last prediction + # shift left the ids + labels = [ + torch.cat([ids[1:], torch.tensor([self.pad_token_id])]) for ids in input_ids + ] + labels = torch.stack(labels, dim=0) + + return { + "input_ids": input_ids, + "labels": labels, + } + + +class MLMCollator: + def __init__( + self, + mask_p: float = 0.15, + whole_word_mask: bool = False, + mask_span: bool = False, + pad_token_id: int = -100, + input_id_key: str = "input_ids", + pad_inputs: bool = True, + ): + """masks some % of tokens for MLM objective""" + pass + + +class DefaultCollator: + def __init__( + self, + pad_token_id: int = -100, + input_id_key: str = "input_ids", + pad_inputs: bool = True, + ): + """for task data where labels are already set""" + pass -# TODO: -# download 10k wiki random subset (or "simple") - shuffle then download, do that in a colab notebook... -# then apply MLM loading to it for testing... +class TinyStoriesCausalNeoX(DataPipeline): + """This is sample datapipelin for the TinyStories dataset. -# datapipe is simple, you call the function and get the dataloader(s) you need for running -# a debug_mode flag can be included (single batch) -class WikiMLMPipe(DataPipeline): - """a first testing pipeline for wikipedia for MLM""" + + This dataset is prepared for causal language modelling using the gpt neox tokenizer (eleutherai). 
We + + Usage: + ts_dict = { + "name": "tinystories_testing", + "tokenizer_name": "neox", + "is_prepared": False, + "input_id_key": "input_ids", + "batch_size": 16, + "shuffle_train": False, + "dataset_path": "~/code/data/micro_tinystories", + "tokenizer_path": "EleutherAI/pythia-14m", + "max_sequence_length": 128, + "pad_token_id": -100, + "pad_inputs": True, + "is_prepared": False, + } + ts_dpipe = TinyStoriesCausalNeoX(**ts_dict) + dataloaders = ts_causal() + + Args: + + Returns: + + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.data_collator = CausalCollator( + pad_token_id=self.pad_token_id, + input_id_key=self.input_id_key, + pad_inputs=self.pad_inputs, + ) + + # TODO: figure out if they really only use endoftext for everything... + def _tokenizer_templating(self, tokenizer, add_eos=True): + if add_eos: + tokenizer.post_processor = TemplateProcessing( + single="$A <|endoftext|>", + special_tokens=[ + ("<|endoftext|>", tokenizer.token_to_id("<|endoftext|>")), + ], + ) + + return tokenizer + + def _tokenize(self, example, tokenizer, key="text"): + return {self.input_id_key: tokenizer.encode(example[key]).ids} + + def _group_and_chunk(self, examples, key="input_ids", block_size=None, pad=False): + concatenated_tokens = sum(examples[key], []) + total_length = len(concatenated_tokens) + + if total_length >= block_size: + total_length = (total_length // block_size) * block_size + + result = { + "input_ids": [ + concatenated_tokens[i : i + block_size] + for i in range(0, total_length, block_size) + ] + } + + return result + + def _prepare(self, data): + """""" + + # apply functions above to dataset + self.tokenizer = self._tokenizer_templating(self.tokenizer) + + data = data.map( + lambda example: self._tokenize(example, self.tokenizer), + remove_columns=data.column_names, + num_proc=self.num_processes, + ) + + data = data.map( + lambda example: self._group_and_chunk( + example, block_size=self.max_sequence_length + ), + batched=True, + num_proc=self.num_processes, + ) + + return data + + +class TinyStoriesMLM(DataPipeline): + """ """ + + pass + + +class ImdbCLS(DataPipeline): + pass + + +''' +class ImdbClsPipe(DataPipeline): + """A pipeline for the imdb dataset for """ def __init__(self, **kwargs): self.path = kwargs["path"] + self.name = kwargs["name"] # name + self.is_tokenized = kwargs["is_tokenized"] + self.padding = kwargs["padding"] # "max_length" - self.truncation = kwargs["truncation"] # "max_length" self.max_sequence_length = kwargs["max_sequence_length"] # 512 - self.shuffle_train = kwargs["shuffle_train"] # False + + self.shuffle_train = kwargs["shuffle_train"] # False self.batch_size = kwargs["batch_size"] self.tokenizer = kwargs["tokenizer"] - self.collator_fn = DataCollatorForLanguageModeling(tokenizer=self.tokenizer) + self.collator_fn = DataCollatorWithPadding( + tokenizer=self.tokenizer, + padding=self.padding, + max_length=self.max_sequence_length, + return_tensors='pt' + ) + + self.data = datasets.DatasetDict.load_from_disk(self.path) def _post_tokenize(self, dataset): - return dataset.remove_columns(["url", "title", "text", "id"]) + return dataset.remove_columns(["text"]) def _tokenize( self, @@ -104,26 +364,60 @@ def _tokenize( ): source_tokenized = self.tokenizer( examples["text"], - padding=self.padding, + truncation=True, max_length=self.max_sequence_length, - truncation=self.truncation, return_token_type_ids=True, + return_tensors='pt' ) batch = {k: v for k, v in source_tokenized.items()} return batch + def _map_tokenize(self, 
subsets=None): + # tokenize + print("tokenizing training") + self.data["train"] = self.data["train"].map(self._tokenize, batched=True) + self.data["train"] = self.data["train"].remove_columns("text") -""" -args = {"path" : "~/code/data/enwiki1m", "truncation": True, "max_sequence_length": 128, "shuffle_train" : False, -"batch_size":4, "padding": "max_length", "tokenizer" : tokenizer} + print("tokenizing test") + self.data["test"] = self.data["test"].map(self._tokenize, batched=True) + self.data["test"] = self.data["test"].remove_columns("text") -wikipipe = WikiMLMPipe(**args) -dloaderz = wikipipe() -next(iter(dloaderz["train"])) + def _save_tokenized(self): + # preprocess + self.path += "_tokenized" + print(f"saving tokenized data to disk at location : {self.path}") + assert os.path.isdir(self.path) == False + self.data.save_to_disk(self.path) -""" + def string_to_tokens(self, input_str): + # tokenize + tensor = self._tokenize({"text": input_str}) + + return self.collator_fn(tensor) + + def __call__(self, subsets = ["train", "test"]): + dataloaders = {} -# class WikiSLaMPipe(DataPipeline): + if self.is_tokenized: + print("data tokenized") + else: + self._map_tokenize() + self._save_tokenized() + + for _set in subsets: + if _set == "train": + shuffle = self.shuffle_train + else: + shuffle = False + + dataloaders[_set] = DataLoader( + self.data[_set], + collate_fn=self.collator_fn, # DEBUG + batch_size=self.batch_size, + shuffle=shuffle, # DEBUG + ) + return dataloaders +''' diff --git a/rafale/main.py b/rafale/main.py index 54fe9be..aa00691 100644 --- a/rafale/main.py +++ b/rafale/main.py @@ -1,38 +1,211 @@ +import os import argparse import yaml +import time +from datetime import datetime +import torch +import torch.utils.data -# generate datasetmodule/model given a ".py" file containing the setup code -from datasets.data_utils import DatasetWrapper -from models.model_utils import ModelWrapper +import composer +from composer import Trainer, Time +from composer.loggers import InMemoryLogger, WandBLogger, FileLogger +from composer.algorithms import GradientClipping +from composer.optim.scheduler import ( + CosineAnnealingWithWarmupScheduler, + CosineAnnealingScheduler, + LinearScheduler, + LinearWithWarmupScheduler, +) + +from rafale.models.decoder import ComposerLM +from rafale.models.encoder import EncoderWrapper +from rafale.models.configurations import ( + load_safetensors, + Pythia14MConfig, + BertConfig, + RobertaConfig, +) + +from rafale.caches import CHECKPOINT_CACHE_DIR +from rafale.datapipe import TinyStoriesCausalNeoX + + +ENV_VARS = {key: value for key, value in os.environ.items()} parser = argparse.ArgumentParser(description="launch a training run") + parser.add_argument( - "-c", "--config", type=str, help="path to yaml configuration file", required=True + "-c", + "--training_config", + type=str, + help="path to yaml run configuration file", + required=True, ) args = parser.parse_args() +model_config_dict = { + "pythia14m": Pythia14MConfig, + "bert": BertConfig, + "roberta": RobertaConfig, +} + +data_pipeline_dict = { + "tinystories_neox": TinyStoriesCausalNeoX, +} + def main(): - # load/parse yaml config - with open(args.config, "r") as f: + # CONFIG ################################################################## + with open(args.training_config, "r") as f: config = yaml.safe_load(f) - print(config["run"]["name"]) - print(config["model"]) + run_name = config["run"]["name"] + run_n_epochs = config["run"]["n_epochs"] + run_seed = config["run"]["seed"] + run_save_interval = 
config["run"]["save_interval"] - # build & load the model - # @HERE + run_clip_type = config["run"]["clip_type"] + run_clip_value = float(config["run"]["clip_value"]) - # build & load the dataloader + device_bs = config["run"]["device_bs"] # int or "auto" - # setup logging? + # schedule + run_schedule_type = config["run"]["schedule"] + run_max_lr = float(config["run"]["max_lr"]) # learning rate + run_warmup_pct = float(config["run"]["warmup_pct"]) + if run_schedule_type == "cosine-warmup": + run_scheduler = CosineAnnealingWithWarmupScheduler( + t_warmup=Time(run_warmup_pct, "dur"), alpha_f=0.1 + ) + else: + raise TypeError( + f"Model type {model_type} is not valid! Supports: cosine-warmup.\nlinear, cosine, and linear-warmup planned" + ) - # build trainer + model_config_key = config["model"]["config"] + model_type = config["model"]["type"] + model_use_pretrained = config["model"]["use_pretrained"] - # launch + data_pipeline_key = config["data"]["pipeline"] + dataset_config = config["data"]["config"] + print(dataset_config) + + # DATALOADERS ############################################################# + data_pipeline = data_pipeline_dict[data_pipeline_key](**dataset_config) + dataloaders = data_pipeline() + if "DATA" in ENV_VARS.keys() and ENV_VARS["DATA"] == "1": + print("Data processing complete, exiting...") + return 0 # just do data preprocessing if we pass DATA=1 + + # MODEL ####################################################################### + model_config = model_config_dict[model_config_key] + + if model_type == "decoder": + rafale_model = ComposerLM(model_config) + elif model_type == "encoder": + rafale_model = EncoderWrapper(model_config) + else: + raise TypeError( + f"Model type {model_type} is not valid! Supports: encoder, decoder." + ) - return None + if model_use_pretrained: + rafale_model.model = load_safetensors(rafale_model.model, model_config) + + # LOGGING ################################################################# + # mem_logger = InMemoryLogger() + # @TODO :: add some logging options in the yaml + wandb_logger = WandBLogger(project="rafale", name=run_name) + # file_logger = FileLogger(filename=f"{run_name}-{time}".txt) + + # GRADIENT CLIPPING ####################################################### + clipping_type = "norm" # can also be 'adaptive' or 'value' + gradient_clip = GradientClipping( + clipping_type=clipping_type, clipping_threshold=0.1 + ) + + # DEVICES ################################################################# + device = "gpu" if torch.cuda.is_available() else "cpu" # select the device + if device == "gpu": + run_precision = "amp_fp16" + else: + run_precision = "fp32" + + # DEBUG RUN ############################################################### + if "DEBUG" in ENV_VARS.keys() and ENV_VARS["DEBUG"] == "1": + from torch.utils.data import Subset, DataLoader, default_collate + from datasets import Dataset + + # single batch, same for train and test 10 epochs + bs = 4 + debug_batch = next(iter(dataloaders["train"])) + debug_batch = Dataset.from_dict( + {k: v[:bs] for k, v in debug_batch.items()} + ).with_format("torch") + debug_batch = DataLoader( + debug_batch, + batch_size=bs, + shuffle=False, + collate_fn=default_collate, + ) + + trainer = Trainer( + model=rafale_model, + seed=run_seed, + train_dataloader=debug_batch, + eval_dataloader=debug_batch, + optimizers=torch.optim.AdamW(rafale_model.parameters(), lr=1e-4), + max_duration=10, # num epochs + device=device, + loggers=None, + precision=run_precision, + progress_bar=True, + ) + + return 0 
+ + # TRAIN ################################################################### + # training subset must have key "train" then whatever is called the validation subset (i.e. test, val, validation, + # eval, etc) as long as there is only 1 other subset, we call it + dl_keys = list(dataloaders.keys()) + assert "train" in dl_keys + dl_keys.remove("train") + assert len(dl_keys) == 1 + eval_subset_key = dl_keys[0] + + # get datetime for checkpoint, directories are created by composer + now = datetime.now() + formatted_date = now.strftime( + "%d" + "d" + "%m" + "m" + "%Y" + "y" + "_%H" + "h" + "%M" + "m" + ) # Format it as DDdMMmYYYYy_HHhMMm + checkpoint_folder = os.path.abspath( + os.path.join(CHECKPOINT_CACHE_DIR, f"{run_name}-{formatted_date}/") + ) + + trainer = Trainer( + model=rafale_model, + seed=run_seed, + train_dataloader=dataloaders["train"], + eval_dataloader=dataloaders[eval_subset_key], + optimizers=torch.optim.AdamW(rafale_model.parameters(), lr=run_max_lr), + max_duration=run_n_epochs, # num epochs + eval_interval="50ba", # default is 1ep ! + device_train_microbatch_size=device_bs, # will handle gradient accumulation automatically + device=device, + loggers=[wandb_logger], + precision=run_precision, + progress_bar=True, + schedulers=run_scheduler, + algorithms=[gradient_clip], + save_folder=checkpoint_folder, + save_latest_filename="latest", + save_interval=run_save_interval, + ) + + # launch + trainer.fit() + print(f"🍻 TRAINING COMPLETE\n💾 CHECKPOINTS SAVED AT LOCATION: {checkpoint_folder}") if __name__ == "__main__": diff --git a/rafale/models/configurations.py b/rafale/models/configurations.py new file mode 100644 index 0000000..d11177e --- /dev/null +++ b/rafale/models/configurations.py @@ -0,0 +1,189 @@ +import requests +import os + +from dataclasses import dataclass + +from safetensors import safe_open + +from ..caches import MODEL_CACHE_DIR + +""" +to simplify model loading add a configuration for the pre-trained weight loading using safetensors instead of loading +the full model. +> then save to a folder named ".pretrained/" in this directory +""" + + +def download_file(url, save_path): + """ + Downloads a file from the specified URL to the given save path. + + :param url: The URL of the file to download. + :param save_path: The local path where the file will be saved. + """ + try: + # Make sure the directory exists + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + with requests.get(url, stream=True) as response: + response.raise_for_status() # Check for HTTP errors + total_size = int(response.headers.get("content-length", 0)) + block_size = 1024 # 1 Kilobyte + progress = 0 + + with open(save_path, "wb") as file: + for data in response.iter_content(block_size): + file.write(data) + progress += len(data) + print( + f"Downloaded {progress} of {total_size} bytes ({(progress/total_size)*100:.2f}%)", + end="\r", + ) + + print(f"\nDownload completed successfully! 
File saved to: {save_path}") + + except requests.exceptions.HTTPError as http_err: + print(f"HTTP error occurred: {http_err}") # HTTP error + except Exception as err: + print(f"An error occurred: {err}") # Other errors + + +def load_safetensors(rafale_model, model_config): + """Transfer the pretrained safetensors to rafale model""" + tensors = {} + + safetensors_path = os.path.join(MODEL_CACHE_DIR, model_config.name + ".safetensors") + + if os.path.isfile(safetensors_path): + pass + else: + download_file(model_config.safetensors_url, safetensors_path) + + with safe_open(safetensors_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = f.get_tensor(k) + + rafale_model = model_config.convert_params_dict(rafale_model, tensors) + + return rafale_model + + +@dataclass +class BertConfig: + embed_dim: int = 768 + vocab_size: int = 30522 + attention_dropout: float = 0.1 + hidden_dropout: float = 0.1 + num_heads: int = 12 + ff_dim: int = 3072 + max_pos_embedding: int = 512 + layer_norm_eps: float = 1e-12 + num_blocks: int = 12 + pad_token_id: int = 0 + num_token_type: int = 2 + fast_attention: bool = ( + False # use xformers (todo: add FlashAttention2), NOT IMPLEMENTED* + ) + + +@dataclass +class BertTinyConfig: + embed_dim: int = 128 + vocab_size: int = 30522 # could usage would be to 30522 + num_extra_tokens + attention_dropout: float = 0.1 + hidden_dropout: float = 0.1 + num_heads: int = 2 + ff_dim: int = 512 + max_pos_embedding: int = 512 + layer_norm_eps: float = 1e-12 + num_blocks: int = 2 + pad_token_id: int = 0 + num_token_type: int = 2 + fast_attention: bool = False # use xformers (todo: add FlashAttention2) + + +@dataclass +class RobertaConfig: + embed_dim: int = 768 + vocab_size: int = 50265 + attention_dropout: float = 0.1 + hidden_dropout: float = 0.1 + num_heads: int = 12 + ff_dim: int = 3072 + max_pos_embedding: int = 514 + layer_norm_eps: float = 1e-05 + num_blocks: int = 12 + pad_token_id: int = 1 + num_token_type: int = 1 + bos_token_id: int = 0 + eos_token_id: int = 2 + fast_attention: bool = False + + +@dataclass +class Pythia14MConfig: + name: str = "pythia14m" + safetensors_url: str = ( + "https://huggingface.co/EleutherAI/pythia-14m/resolve/main/model.safetensors" + ) + + embed_dim: int = 128 + num_heads: int = 4 + ff_dim: int = 512 + hidden_act: str = "gelu" + max_pos_embedding: int = 2048 + vocab_size: int = 50304 + use_cache: bool = True + + parallel_residual: bool = True + + attention_dropout: float = 0.1 + hidden_dropout: float = 0.1 + + layer_norm_eps: float = 1e-05 + num_blocks: int = 6 + + bos_token_id: int = 0 + eos_token_id: int = 0 + fast_attention: bool = False + + rotary_emb_base: int = 10000 + rotary_pct: float = 0.25 + + tie_word_embeddings: bool = False + + @classmethod + def convert_params_dict(cls, target, source): + """ + Source safetensors dict to our rafale model class. 
+ """ + # not needed for our implementation + unused = ["rotary_emb.inv_freq", "masked_bias", "attention.bias"] + for k, v in list(source.items()): + if True in [x in k for x in unused]: + del source[k] + + conversion = [ + ("gpt_neox.embed_in", "token_embeddings.input_embeddings"), + ("gpt_neox.layers", "layers"), + ("input_layernorm", "attention_norm"), + ("post_attention_layernorm", "ffn_norm"), + ("mlp", "feed_forward"), + ("dense_4h_to_h", "ff_out"), + ("dense_h_to_4h", "ff_in"), + ("embed_out", "output"), + ("gpt_neox.final_layer_norm", "final_norm"), + ] + + updated_parameters = {} + for k, v in source.items(): + for hf_term, my_term in conversion: + if hf_term in k: + k = k.replace(hf_term, my_term) + + updated_parameters[k] = v + + # here we transfer weights for all layers + target.load_state_dict(updated_parameters, strict=True) + + return target diff --git a/rafale/models/convert_hf_weights.py b/rafale/models/convert_hf_weights.py index a6329a4..b187425 100644 --- a/rafale/models/convert_hf_weights.py +++ b/rafale/models/convert_hf_weights.py @@ -55,6 +55,43 @@ def convert_bert_params_dict(target, source): return target +def convert_pythia_params_dict(target, source): + """ + Source safetensors dict to our rafale model class. + """ + + # not needed for our implementation + unused = ["rotary_emb.inv_freq", "masked_bias", "attention.bias"] + for k, v in list(source.items()): + if True in [x in k for x in unused]: + del source[k] + + conversion = [ + ("gpt_neox.embed_in", "token_embeddings.input_embeddings"), + ("gpt_neox.layers", "layers"), + ("input_layernorm", "attention_norm"), + ("post_attention_layernorm", "ffn_norm"), + ("mlp", "feed_forward"), + ("dense_4h_to_h", "ff_out"), + ("dense_h_to_4h", "ff_in"), + ("embed_out", "output"), + ("gpt_neox.final_layer_norm", "final_norm"), + ] + + updated_parameters = {} + for k, v in source.items(): + for hf_term, my_term in conversion: + if hf_term in k: + k = k.replace(hf_term, my_term) + + updated_parameters[k] = v + + # here we transfer weights for all layers + target.load_state_dict(updated_parameters, strict=True) + + return target + + def convert_roberta_params_dict(target, source): conversion = [ ("roberta.embeddings", "embedding_layer"), diff --git a/rafale/models/decoder.py b/rafale/models/decoder.py new file mode 100644 index 0000000..3f2dea2 --- /dev/null +++ b/rafale/models/decoder.py @@ -0,0 +1,581 @@ +#!/usr/bin/env python +from typing import Optional +import torch + +from torch import nn +from torch import Tensor + +import torch.nn.functional as F + +from torch.nn.functional import scaled_dot_product_attention + +from torchmetrics import Metric +from torchmetrics.collections import MetricCollection + +from composer.metrics import LossMetric, LanguagePerplexity +from composer.models import ComposerModel + + +############################################################################### +# simple implementation of GPT building +class NeoXRoPE(nn.Module): + @classmethod + def precompute_sin_cos_cache(cls, dim=None, seq_len=None, base=10000, device=None): + """Computes the cos and sin angles to be applied to the token vectors. + + We begin by computing thetas (freqs) across each dimension pair (P=D/2) for the whole sequence length (L). + Then we convert this matrix of shape LP into complex numbers of the same shape. + Finally the real and imaginary parts of these complex numbers are stored in a stacked matrix and returned. 
+ + Args: + dim (int): number of features dimension per token to apply rotations to (d*rotary_pct) + seq_len (int): sequence length of the input (use the maximum sequence length) + base (int): default 10000 + + Returns: + Tensor # of shape [1,1,L,R] + + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + - R rotary dimensions (d*rotary_pct) + """ + + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim) + ) + t = torch.arange(seq_len, dtype=torch.int64).type_as(inv_freq) + freqs = torch.einsum("i,j->ij", t, inv_freq) + emb = torch.cat((freqs, freqs), dim=-1) + cos_cached = emb.cos()[None, None, :, :] # shape is [1, 1, L, R] + sin_cached = emb.sin()[None, None, :, :] + + return cos_cached, sin_cached + + @staticmethod + def rotate_half(x): + x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=x1.ndim - 1) + + @classmethod + def apply_rotary_pos_emb(cls, q_BHLR, k_BHLR, cos, sin): + """Applies the rotation to the input queries and key features.""" + + tensor_device = q_BHLR.get_device() + + cos = cos.to(device=tensor_device) + sin = sin.to(device=tensor_device) + + return (q_BHLR * cos) + (cls.rotate_half(q_BHLR) * sin), (k_BHLR * cos) + ( + cls.rotate_half(k_BHLR) * sin + ) + + @classmethod + def apply_rotary_pos_emb_offset(cls, q, k, cos, sin, offset: int = 0): + """ + q and k are shape: BHLR + cos, sin are shape: 11LR + """ + cos, sin = ( + cos[:, :, offset : q.shape[2] + offset, :], + sin[:, :, offset : q.shape[2] + offset, :], + ) + + tensor_device = q.get_device() + + cos = cos.to(device=tensor_device) + sin = sin.to(device=tensor_device) + + return (q * cos) + (cls.rotate_half(q) * sin), (k * cos) + ( + cls.rotate_half(k) * sin + ) + + +class DecoderEmbedding(nn.Module): + """Simply an input projection of the tokens here, since rotary position encodings are used. + + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + """ + + def __init__(self, config): + super().__init__() + + self.input_embeddings = nn.Embedding( + num_embeddings=config.vocab_size, + embedding_dim=config.embed_dim, + ) + self.dropout = nn.Dropout(config.hidden_dropout) + + def forward(self, x_BL): + x_BLD = self.input_embeddings(x_BL) + return self.dropout(x_BLD) + + +class DecoderAttentionRotary(nn.Module): + """ + Attention with rotary position embedding. + + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + + """ + + def __init__(self, config): + super().__init__() + + self.head_dim = config.embed_dim // config.num_heads + self.num_heads = config.num_heads + self.embed_dim = config.embed_dim + self.rotary_ndims = int(self.head_dim * config.rotary_pct) + + self.attention_bias = True # @TODO: set bias to True or False from config. 
+ self.query_key_value = nn.Linear(config.embed_dim, 3 * config.embed_dim) + self.dense = nn.Linear(config.embed_dim, config.embed_dim) + self.dropout = nn.Dropout(p=config.attention_dropout) + self.dropout_p = config.attention_dropout + self.norm_factor = self.head_dim**-0.5 + + def _split_heads(self, tensor: Tensor): + """ + Splits hidden dim into attn_head_size and num_attention_heads + + # input tensor: [bs, seq_len, hidden_size] + # returns: [bs, num_attention_heads, seq_len, attn_head_size] + """ + batch_size = tensor.size(0) + + return ( + tensor.view(batch_size, -1, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) + + def _merge_heads(self, tensor: Tensor): + """ + input tensor: [bs. num_attention_heads, seq_len, attn_head_size] + returns: [bs, seq_len, hidden_size] + """ + # tensor [bs, num_attention_heads, seq_len, attn_head_size] + tensor = tensor.permute(0, 2, 1, 3).contiguous() + # -> [bs, seq_len, num_attention_heads, attn_head_size] + tensor = tensor.view( + tensor.size(0), + tensor.size(1), + self.num_heads * self.head_dim, + ).contiguous() + # -> [bs, seq_len, hidden_size] + return tensor + + def forward(self, x_BLD, freqs_cis): + if not self.training: + self.dropout_p = 0 + + bsz, seq_len, _ = x_BLD.size() + + assert freqs_cis is not None + + # projections + qkv = self.query_key_value(x_BLD) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + (self.num_heads, 3 * self.head_dim) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + q_BHLd = qkv[..., : self.head_dim].permute(0, 2, 1, 3) + k_BHLd = qkv[..., self.head_dim : 2 * self.head_dim].permute(0, 2, 1, 3) + v_BHLd = qkv[..., 2 * self.head_dim :].permute(0, 2, 1, 3) + + # Slice the precomputed freqs_cis based on actual seq_len --> [1, 1, seq_len, R] + cos = freqs_cis[0][:, :, :seq_len, :] + sin = freqs_cis[1][:, :, :seq_len, :] + + q_rot = q_BHLd[..., : self.rotary_ndims] + q_pass = q_BHLd[..., self.rotary_ndims :] + k_rot = k_BHLd[..., : self.rotary_ndims] + k_pass = k_BHLd[..., self.rotary_ndims :] + + q_rot, k_rot = NeoXRoPE.apply_rotary_pos_emb(q_rot, k_rot, cos, sin) + + q_BHLd = torch.cat((q_rot, q_pass), dim=-1) + k_BHLd = torch.cat((k_rot, k_pass), dim=-1) + + # compute attention + attn_out_BHLd = scaled_dot_product_attention( + q_BHLd, + k_BHLd, + v_BHLd, + is_causal=True, + scale=self.norm_factor, + dropout_p=self.dropout_p, + ) + + attn_out_BLD = self._merge_heads(attn_out_BHLd) + + attn_out_BLD = self.dense(attn_out_BLD) + + return attn_out_BLD + + +class DecoderAttentionRotaryKVCache(DecoderAttentionRotary): + """implements the KV cache mechanism""" + + def __init__(self, config): + super().__init__(config) + + def forward(self, x_BLD, freqs_cis, causal_mask=None, past_kv=None): + # A) figure out if we passed a cached KV + assert freqs_cis is not None + has_past_kv = past_kv is not None and past_kv[0].numel() > 0 + + if not self.training: + self.dropout_p = 0 + + bsz, seq_len, _ = x_BLD.size() + # B) if we have a cached KV, apply the offset to the sequence length for RoPE + if has_past_kv: + offset = past_kv[0].shape[ + 2 + ] # we want the lenght here, our kv shape is BHLd + seq_len += offset + else: + offset = 0 + + # projections + qkv = self.query_key_value(x_BLD) + + # [batch, seq_len, (num_heads * 3 * head_size)] + # --> [batch, seq_len, num_heads, 3 * head_size] + new_qkv_shape = qkv.size()[:-1] + 
(self.num_heads, 3 * self.head_dim) + qkv = qkv.view(*new_qkv_shape) + + # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size] + q_BHLd = qkv[..., : self.head_dim].permute(0, 2, 1, 3) + k_BHLd = qkv[..., self.head_dim : 2 * self.head_dim].permute(0, 2, 1, 3) + v_BHLd = qkv[..., 2 * self.head_dim :].permute(0, 2, 1, 3) + + # Slice the precomputed freqs_cis based on actual seq_len --> [1, 1, seq_len, R] + cos = freqs_cis[0][:, :, :seq_len, :] + sin = freqs_cis[1][:, :, :seq_len, :] + + q_rot = q_BHLd[..., : self.rotary_ndims] + q_pass = q_BHLd[..., self.rotary_ndims :] + k_rot = k_BHLd[..., : self.rotary_ndims] + k_pass = k_BHLd[..., self.rotary_ndims :] + + q_rot, k_rot = NeoXRoPE.apply_rotary_pos_emb_offset( + q_rot, k_rot, cos, sin, offset=offset + ) + + q_BHLd = torch.cat((q_rot, q_pass), dim=-1) + k_BHLd = torch.cat((k_rot, k_pass), dim=-1) + + # C) before scaled_dot_product_attention we are going to + # Cache QKV values + if has_past_kv: + past_key, past_value = past_kv + + k_BHLd = torch.cat((past_key.type_as(k_BHLd), k_BHLd), dim=2) + v_BHLd = torch.cat((past_value.type_as(v_BHLd), v_BHLd), dim=2) + + # kv_cache = torch.stack((k_BHLd, v_BHLd)) + kv_cache = (k_BHLd, v_BHLd) # let's keep them as a list of tuples + # #################################################################### + + # print(f"key device {k_BHLd.get_device()}") + # print(f"query device {q_BHLd.get_device()}") + # print(f"value device {v_BHLd.get_device()}") + # print(f"causal mask device {causal_mask.get_device()}") + + tensor_device = q_BHLd.get_device() + if causal_mask.get_device() != tensor_device: + causal_mask = causal_mask.to(device=tensor_device) + + # compute attention here + attn_out_BHLd = scaled_dot_product_attention( + q_BHLd, + k_BHLd, + v_BHLd, + is_causal=False, + attn_mask=causal_mask, + scale=self.norm_factor, + dropout_p=self.dropout_p, + ) # even with rectangular matrices scaled_dot_product_attention will handle the causal mask by apply left bias + # causal mask which is exactly what we need. + + attn_out_BLD = self._merge_heads(attn_out_BHLd) + + attn_out_BLD = self.dense(attn_out_BLD) + + return attn_out_BLD, kv_cache + + +# ^^^^^ ####################################################################### + + +class DecoderFeedForward(nn.Module): + """ + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + + """ + + def __init__(self, config): + super().__init__() + self.ff_in = nn.Linear(config.embed_dim, config.ff_dim) + self.gelu = nn.GELU() + self.ff_out = nn.Linear(config.ff_dim, config.embed_dim) + + def forward(self, x_BLD): + x_BLF = self.ff_in(x_BLD) + x_BLF = self.gelu(x_BLF) + out_BLD = self.ff_out(x_BLF) + return out_BLD + + +class DecoderBlock(nn.Module): + """A single trasnformer decoder block/layer. 
+ + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + """ + + def __init__(self, config): + super().__init__() + self.attention = DecoderAttentionRotary(config) + self.feed_forward = DecoderFeedForward(config) + self.ffn_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + self.attention_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + + def forward(self, x_BLD, freqs_cis, parallel_residual=True, use_cache=True): + assert freqs_cis is not None + + if parallel_residual: + out_BLD = ( + x_BLD + + self.attention(self.attention_norm(x_BLD), freqs_cis) + + self.feed_forward(self.ffn_norm(x_BLD)) + ) + else: + h_BLD = x_BLD + self.attention(self.attention_norm(x_BLD), freqs_cis) + out_BLD = h_BLD + self.feed_forward(self.ffn_norm(h_BLD)) + + return out_BLD + + +# handle KV cache state +class DecoderBlockKVcache(DecoderBlock): + """A single trasnformer decoder block/layer. + + Tensor dimension names: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + """ + + def __init__(self, config): + super().__init__(config) + self.attention = DecoderAttentionRotaryKVCache(config) + self.feed_forward = DecoderFeedForward(config) + self.ffn_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + self.attention_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + + def forward( + self, x_BLD, freqs_cis, causal_mask, layer_kv_cache, parallel_residual=True + ): + assert freqs_cis is not None + + if parallel_residual: + attn_out_BLD, layer_kv_cache = self.attention( + self.attention_norm(x_BLD), freqs_cis, causal_mask, layer_kv_cache + ) + out_BLD = x_BLD + attn_out_BLD + self.feed_forward(self.ffn_norm(x_BLD)) + + else: + attn_out_BLD, layer_kv_cache = self.attention( + self.attention_norm(x_BLD), freqs_cis, layer_kv_cache + ) + h_BLD = x_BLD + attn_out_BLD + out_BLD = h_BLD + self.feed_forward(self.ffn_norm(h_BLD)) + + return out_BLD, layer_kv_cache + + +class DecoderWrapper(nn.Module): + """Full model wrapper for causal language modelling.""" + + def __init__(self, config): + super().__init__() + self.config = config + self.token_embeddings = DecoderEmbedding(config) + + if self.config.use_cache: + self.layers = nn.ModuleList( + DecoderBlockKVcache(config) for _ in range(config.num_blocks) + ) + else: + self.layers = nn.ModuleList( + DecoderBlock(config) for _ in range(config.num_blocks) + ) + + self.final_norm = nn.LayerNorm(config.embed_dim, config.layer_norm_eps) + self.output = nn.Linear(config.embed_dim, config.vocab_size, bias=False) + + self.freqs_cis: Optional[Tensor] = None + + self.max_batch_size = -1 + self.max_seq_length = config.max_pos_embedding + + self.rotary_pct = 0.25 + + self.vocab_size = config.vocab_size + + def setup_caches(self): + head_dim = self.config.embed_dim // self.config.num_heads + dtype = self.output.weight.dtype + + head_size = self.config.embed_dim // self.config.num_heads + rotary_ndims = int(head_size * self.rotary_pct) + self.cos, self.sin = NeoXRoPE.precompute_sin_cos_cache( + dim=rotary_ndims, + seq_len=self.config.max_pos_embedding, + ) + + self.freqs_cis = (self.cos, self.sin) + + def _generate_causal_mask(self, cache_length: int, new_length: int) -> torch.Tensor: + """ + Creates a causal mask for autoregressive attention with KV caching. 
+
+        Args:
+            cache_length (int): Number of cached tokens (K).
+            new_length (int): Number of new tokens to generate (T).
+
+        Returns:
+            torch.Tensor: A boolean T x (K + T) mask tensor where:
+                - Each new token can attend to every cached token.
+                - Each new token can attend to new tokens up to and including its own position.
+                - Attention to future new tokens is masked out.
+        """
+
+        causal_mask = torch.tril(torch.ones(new_length, new_length)).bool()
+
+        if cache_length > 0:
+            cache_tokens_mask = torch.ones((new_length, cache_length)).bool()
+            causal_mask = torch.cat((cache_tokens_mask, causal_mask), dim=1)
+
+        return causal_mask
+
+    def forward(self, batch: Tensor, past_kv_cache=None):
+        # if self.freqs_cis is None or self.causal_mask is None:
+        if self.freqs_cis is None:
+            self.setup_caches()  # Caches must be initialized first
+
+        freqs_cis = self.freqs_cis
+
+        idx = batch["input_ids"]
+        num_new_tokens = idx.size(1)
+
+        x = self.token_embeddings(idx)
+
+        if self.config.use_cache and past_kv_cache is None:
+            past_kv_cache = [None] * self.config.num_blocks
+
+        # guard against past_kv_cache being None when use_cache is disabled
+        if past_kv_cache is not None and past_kv_cache[0] is not None:
+            num_cache_tokens = past_kv_cache[0][0].size(2)
+        else:
+            num_cache_tokens = 0
+        causal_mask = self._generate_causal_mask(num_cache_tokens, num_new_tokens)
+
+        kv_cache_list = []
+
+        for i, layer in enumerate(self.layers):
+            if self.config.use_cache:
+                x, layer_kv_cache = layer(x, freqs_cis, causal_mask, past_kv_cache[i])
+                kv_cache_list.append(layer_kv_cache)
+            else:
+                x = layer(x, freqs_cis)
+        x = self.final_norm(x)
+        logits = self.output(x)
+
+        return logits, kv_cache_list
+
+
+class ComposerLM(ComposerModel):
+    """wrapper with nice properties for simple training and evaluation"""
+
+    def __init__(self, config):
+        """Wrap the rafale decoder with a loss and metrics for the Composer Trainer."""
+        super().__init__()
+        self.model = DecoderWrapper(config)
+        self.ce_loss = nn.CrossEntropyLoss()
+        self.train_metrics = MetricCollection(
+            [LossMetric(self.ce_loss), LanguagePerplexity()]
+        )
+        self.eval_metrics = MetricCollection(
+            [LossMetric(self.ce_loss), LanguagePerplexity()]
+        )
+
+    def forward(self, batch):  # batch is the output of the dataloader
+        """batch is a dict with "input_ids" key, model also takes past_kv"""
+        # specify how batches are passed through the model
+        return self.model(batch)
+
+    def eval_forward(self, batch, outputs=False):
+        if outputs:
+            if isinstance(outputs, tuple):
+                outputs, _ = outputs
+            return outputs
+
+        outputs = self.model(batch)
+        if isinstance(outputs, tuple):
+            outputs, _ = outputs
+
+        return outputs
+
+    def update_metric(self, batch, outputs, metric) -> None:
+        targets = batch["labels"]
+        metric.update(outputs.view(-1, self.model.vocab_size), targets.view(-1))
+
+    def get_metrics(self, is_train=False) -> dict[str, Metric]:
+        # defines which metrics to use in each phase of training
+        return self.train_metrics if is_train else self.eval_metrics
+
+    def loss(self, outputs, batch):
+        targets = batch["labels"]
+
+        if isinstance(outputs, tuple):
+            outputs, _ = outputs
+
+        return self.ce_loss(outputs.view(-1, self.model.vocab_size), targets.view(-1))
diff --git a/rafale/models/decoding_strategies.py b/rafale/models/decoding_strategies.py
new file mode 100644
index 0000000..8e11eb0
--- /dev/null
+++ b/rafale/models/decoding_strategies.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn.functional as F
+
+
+def repeat_ngram(input_ids, ngram_max_length=4):
+    """
+    Clean up output by checking for repeated n-grams of
length less than ngram_max_length. + If it finds a repeated n-gram, it returns the n-gram; otherwise, it returns 0. + + Args: + input_ids: Tensor of shape (seq_length,), the sequence of input token IDs. + ngram_max_length: The maximum length of the n-gram to check for repetition. + + Returns: + A list of tokens representing the repeated n-gram if found, otherwise 0. + """ + ngram_found = False + ngram_tokens = None + + # Create a set to store seen n-grams + seen_ngrams = set() + + print(input_ids) + # Iterate through possible n-gram lengths + for n in range(1, ngram_max_length + 1): + for i in range(len(input_ids) - n + 1): + ngram = tuple(input_ids[i : i + n].tolist()) + if ngram in seen_ngrams: + ngram_found = True + ngram_tokens = list(ngram) + break + seen_ngrams.add(ngram) + if ngram_found: + break + + if ngram_found: + return ngram_tokens + else: + return 0 + + +def greedy_decode(model, batch, max_length, eos_token_id, check_repeat_ngrams=True): + """ + Implements greedy decoding for the rafale transformer model. + + Args: + model: The decoder model (e.g., DecoderWrapper). + batch: Dictionary containing input_ids of shape (batch_size, seq_length), the input prompt. + max_length: The maximum length of the generated sequence. + eos_token_id: The ID of the end-of-sequence token. + + Returns: + Dictionary containing input_ids of shape (batch_size, max_length) with the generated tokens. + """ + batch_size = batch["input_ids"].size(0) + if batch_size != 1: + raise ValueError( + "greedy_decode currently only supports batch_size=1. Provided batch_size: {}".format( + batch_size + ) + ) + + input_seq_len = batch["input_ids"].size(1) + kv_cache_list = None + + # Generate tokens until max_length or eos_token is generated + for _ in range(max_length - input_seq_len): + # Forward pass through the model + outputs, kv_cache_list = model(batch, kv_cache_list) + logits = outputs[:, -1, :] # Get the logits for the last generated token + + # Greedily select the token with the highest probability + next_token = torch.argmax(logits, dim=-1).unsqueeze(-1) # Shape: (1, 1) + + # Append the predicted token to the generated sequence + batch["input_ids"] = torch.cat((batch["input_ids"], next_token), dim=1) + + # Check for repeated n-grams and stop if detected + if check_repeat_ngrams: + repeated_ngram = repeat_ngram( + batch["input_ids"].squeeze(), ngram_max_length=4 + ) + if repeated_ngram != 0: + print(repeated_ngram) + break + + # Check if the sequence has generated the eos_token_id + if next_token.item() == eos_token_id: + break + + return batch diff --git a/rafale/models/encoder.py b/rafale/models/encoder.py index cdabd85..4c7ee04 100644 --- a/rafale/models/encoder.py +++ b/rafale/models/encoder.py @@ -1,15 +1,34 @@ from dataclasses import dataclass + import torch import torch.nn.functional as F -# import xformers.ops as xops - from torch import nn from torch.nn.functional import scaled_dot_product_attention +############################################################################### +# SIMPLE BERT-like BUILDING BLOCKS # +############################################################################### + + +# @TODO :: Refactor, improve documentation and add tensor dimension keys for the names + + class Embedding(nn.Module): - """embeddings from word, absolute/fixe position (and token_type embedding?)""" + """Embeddings + + In addition to the word embedding, BERT uses learned absolute position embeddings. We also have token type embedding for one the BERT pretraining + objectives. 
+ + Tensor dimension keys: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension + """ def __init__( self, @@ -67,98 +86,19 @@ def forward(self, input_ids, token_type_ids): return E -class MultiHeadAttention(nn.Module): - def __init__(self, n_heads, embed_dim, dropout_p=0.1, fast_attn=False): - "uses xformers memory effeicient attention" - super().__init__() - self.dropout_p = dropout_p - self.fast_attn = fast_attn - assert embed_dim % n_heads == 0 - # We assume d_v always equals d_k - self.head_dim = embed_dim // n_heads - self.n_heads = n_heads - self.embed_dim = embed_dim - self.all_head_size = n_heads * self.head_dim - - # @TODO get linear projections - self.query = nn.Linear(embed_dim, embed_dim) - self.key = nn.Linear(embed_dim, embed_dim) - self.value = nn.Linear(embed_dim, embed_dim) - - # output - self.out = nn.Linear(embed_dim, embed_dim) - - def forward(self, q, k, v): - # @HERE TODO - """ - so input q, k, v are essentially the same tensor since we're doing self attention but the idea here is that this - same implementation could be used for cross attention. - - in self-attention they have dimension [batch_size, sequence_length, embedding_dimension] so for bert that would - be like [4, 512, 768] for example. - - each attention head will handle part of the embedding dimensions (wow I didn't know that and don't fully - understand why...). So this is why we want to have embed_dim % n_head == 0. - - (1) we use view to reshape the tensor into shape [batch_size, seq_length, n_head, head_embed] --> .view(batch_size, - -1, self.num_heads, self.head_dim) - (2) then we transpose seq_length and n_head to parrellalize the computations during the attention computations - --> .transpose(1, 2) - - - ## Summary of Shape Changes - Input: [batch_size, seq_length, embed_dim] - Post Linear Layer: [batch_size, seq_length, embed_dim] (same shape, but transformed) - View for Heads: [batch_size, seq_length, num_heads, head_dim] - Transpose for Heads: [batch_size, num_heads, seq_length, head_dim] - - ## after having applied attention - We receive a tensor of shape [batch_size, num_heads, seq_length, head_dim] (same as before) - Now we want to get back to our original embedding and sequence shape so first we swap back num_head and - seq_length with --> .transpose(1,2) - Then we want to aggregate our head_dim to have our full embedding space back up together again with --> - .view(batch_size, -1, self.embed_dim) - and we get shape [batch_size, seq_length, embed_dim] at the end - - """ - - -# @TODO: separate the attention modules to more easily inspect tensors class EncoderSelfAttention(nn.Module): - """just MHA with the various implementations - so input q, k, v are essentially the same tensor since we're doing self attention but the idea here is that this - same implementation could be used for cross attention. - - in self-attention they have dimension [batch_size, sequence_length, embedding_dimension] so for bert that would - be like [4, 512, 768] for example. - - each attention head will handle part of the embedding dimensions (wow I didn't know that and don't fully - understand why...). So this is why we want to have embed_dim % n_head == 0. 
- - (1) we use view to reshape the tensor into shape [batch_size, seq_length, n_head, head_embed] --> .view(batch_size, - -1, self.num_heads, self.head_dim) - (2) then we transpose seq_length and n_head to parrellalize the computations during the attention computations - --> .transpose(1, 2) - - - ## Summary of Shape Changes - Input: [batch_size, seq_length, embed_dim] - Post Linear Layer: [batch_size, seq_length, embed_dim] (same shape, but transformed) - View for Heads: [batch_size, seq_length, num_heads, head_dim] - Transpose for Heads: [batch_size, num_heads, seq_length, head_dim] - - ## after having applied attention - We receive a tensor of shape [batch_size, num_heads, seq_length, head_dim] (same as before) - Now we want to get back to our original embedding and sequence shape so first we swap back num_head and - seq_length with --> .transpose(1,2) - Then we want to aggregate our head_dim to have our full embedding space back up together again with --> - .view(batch_size, -1, self.embed_dim) - and we get shape [batch_size, seq_length, embed_dim] at the end - + """Bidirectional multi-head self attention. + + Tensor dimension keys: + - B batch size + - L sequence length + - H number of attention heads + - D embedding dimension + - d attention head dimension D//H + - F feedforward dimension """ def __init__(self, n_heads, embed_dim, dropout_p=0.1, fast_attn=False): - "uses xformers memory effeicient attention" super().__init__() self.dropout_p = dropout_p self.fast_attn = fast_attn @@ -176,12 +116,10 @@ def __init__(self, n_heads, embed_dim, dropout_p=0.1, fast_attn=False): self.value = nn.Linear(embed_dim, embed_dim) def forward(self, q, k, v): - # @HERE TODO """""" batch_size = q.size(0) if not self.training: self.dropout_p = 0 - print("model not training, attention dropout is 0") # check transformation again here.... q = ( @@ -200,26 +138,14 @@ def forward(self, q, k, v): .transpose(1, 2) ) - # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html#torch.nn.functional.scaled_dot_product_attention - # check for flash-attn-2? optional - if self.fast_attn: - raise Exception("fast attention not implemented yet") - # attn_output = xops.memory_efficient_attention( - # q, - # k, - # v, - # p=self.dropout_p, - # ) - - else: - attn_output = scaled_dot_product_attention( - q, - k, - v, - dropout_p=self.dropout_p, - ) + attn_output = scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.dropout_p, + ) - # Concatenate heads and put through final linear layer + # concatenate heads and put through final linear layer attn_output = ( attn_output.transpose(1, 2) .contiguous() @@ -267,7 +193,7 @@ def __init__(self, embed_dim, eps=None, dropout_p=None): self.dropout = nn.Dropout(dropout_p) def forward(self, x, residual): - x = self.dropout(x) + x = self.dropout(x) # @TODO :: make sure this should be here... 
x = self.ln(x + residual) return x @@ -277,10 +203,6 @@ class EncoderBlock(nn.Module): def __init__(self, embed_dim, n_heads, ff_dim, eps=None, dropout_p=None): super().__init__() - # self.mha = MultiHeadAttention( - # n_heads=n_heads, embed_dim=embed_dim, dropout_p=dropout_p - # ) - self.attention = AttentionModule( n_heads=n_heads, embed_dim=embed_dim, dropout_p=dropout_p ) @@ -372,82 +294,6 @@ def forward(self, **kwargs): return x - def compute_decoupled_label_loss(self, logits, labels, permutations, tokenizer): - """ """ - batch_size = logits.size(0) - F.softmax(logits, dim=-1) - letters = list(map(chr, range(97, 123))) - - # Define the loss functions - # same used as adapet - - nn.BCELoss(reduction="none") - nn.BCEWithLogitsLoss(reduction="none") - - losses = [] - - for i in range(batch_size): - num_labels = permutations[i] - # get all the relevant ids (choices) from the sample - relevant_ids = tokenizer.convert_tokens_to_ids(letters[:num_labels]) - - # the token id of the positive label of the example - example_label = labels[i] - example_label_id = example_label[example_label != -100].item() - - if len(relevant_ids) == 1: - # caveat if example only has one choice... - example_bad_ids = [] - else: - relevant_ids.remove(example_label_id) - example_bad_ids = relevant_ids - - # get the logit predictions from the mask token - l = logits[i] - # l = probabilities[i] - mask_token_logits = l[labels[i] != -100] - mask_token_logits = torch.flatten(mask_token_logits) - - indices = [x for x in range(l.size(1)) if x not in relevant_ids] - non_choice_logits = torch.index_select( - mask_token_logits, 0, torch.tensor(indices) - ) - - # probability logits for the positive label - positive_prediction = mask_token_logits[example_label_id] - - negative_losses = [] - negative_label = torch.zeros([]) - # probability logits for the negative labels - for idx in example_bad_ids: - # negative_predictions.append(mask_token_logits[idx]) - negative_losses.append( - self.bcel_loss(mask_token_logits[idx], negative_label) - ) - - nulltoken_labels = torch.zeros(len(indices)) # device="cuda:0" - positive_labels = torch.ones([]) - - # mean of bad labels bcel - nulltoken_label_loss = torch.mean( - self.bcel_loss(non_choice_logits, nulltoken_labels) - ) - negative_label_loss = torch.sum(torch.stack(negative_losses)) - positive_label_loss = self.bcel_loss(positive_prediction, positive_labels) - - losses.append( - torch.sum( - torch.stack( - [ - positive_label_loss.view(1), - nulltoken_label_loss.view(1), - negative_label_loss.view(1), - ] - ) - ) - ) - return torch.mean(torch.stack(losses)) - def compute_loss(self, logits, labels): """ """ ce_loss = nn.CrossEntropyLoss(ignore_index=-100) @@ -460,37 +306,3 @@ def compute_loss(self, logits, labels): # Compute and return the loss return ce_loss(logits, labels) - - -# DELETE THIS -def get_tokens_from_logits(logits, tokenizer=None): - """ - return the prediced tokens for all of the inputs - """ - # Apply softmax to convert logits to probabilities - probabilities = F.softmax(logits, dim=-1) - - # Get the predicted token IDs - predicted_token_ids = torch.argmax(probabilities, dim=-1) - - predicted_tokens = [ - tokenizer.convert_ids_to_tokens(seq.numpy()) - for seq in torch.unbind(predicted_token_ids, dim=0) - ] - return predicted_tokens - - -@dataclass -class BertConfig: - embed_dim: int = 768 - vocab_size: int = 30522 # could usage would be to 30522 + num_extra_tokens - attention_dropout: float = 0.1 - hidden_dropout: float = 0.1 - num_heads: int = 12 - ff_dim: int = 3072 - 
max_pos_embedding: int = 512 - layer_norm_eps: float = 1e-12 - num_blocks: int = 12 - pad_token_id: int = 0 - num_token_type: int = 2 - fast_attention: bool = False # use xformers (todo: add FlashAttention2) diff --git a/rafale/models/model_utils.py b/rafale/models/model_utils.py new file mode 100644 index 0000000..04caed1 --- /dev/null +++ b/rafale/models/model_utils.py @@ -0,0 +1,15 @@ +def get_tokens_from_logits(logits, tokenizer=None): + """ + return the prediced tokens for all of the inputs + """ + # Apply softmax to convert logits to probabilities + probabilities = F.softmax(logits, dim=-1) + + # Get the predicted token IDs + predicted_token_ids = torch.argmax(probabilities, dim=-1) + + predicted_tokens = [ + tokenizer.convert_ids_to_tokens(seq.numpy()) + for seq in torch.unbind(predicted_token_ids, dim=0) + ] + return predicted_tokens diff --git a/rafale/test.yaml b/rafale/test.yaml deleted file mode 100644 index 690e14e..0000000 --- a/rafale/test.yaml +++ /dev/null @@ -1,9 +0,0 @@ -run: - name: "test" - -model: - name: "test-model" - bin: "./" - -data: - training_path: "training_path" diff --git a/setup.py b/setup.py deleted file mode 100644 index 49f3772..0000000 --- a/setup.py +++ /dev/null @@ -1,11 +0,0 @@ -from setuptools import setup - -setup( - name="rafale", - version="0.1", - description="simple transformer training lib", - url="https://github.com/maxrousseau/rafale", - author="maxime rousseau", - packages=["rafale"], - zip_safe=False, -) diff --git a/test/pythia_tinystories.yaml b/test/pythia_tinystories.yaml new file mode 100644 index 0000000..b9d4c04 --- /dev/null +++ b/test/pythia_tinystories.yaml @@ -0,0 +1,43 @@ +run: + name: "pythia14m-tinystories" # name of your experiment, used for checkpointing + seed: 42 + n_epochs: 1 + max_lr: 6e-04 + warmup_pct: 0.01 + schedule: "cosine-warmup" # linear, linear-warmup, cosine, cosine-warmup + optimizer: "AdamW" + eval_interval: "50ba" + clip_type: "norm" + clip_value: 1.0 + device_bs: "auto" + save_interval: "200ba" + +model: + config: "pythia14m" # config key + type: "decoder" + use_pretrained: True + +data: + pipeline: "tinystories_neox" # the preprocessing/tokenization pipeline + config: + name: "tinystories" + num_processes: 8 + tokenizer_name: "neox" + is_prepared: False + input_id_key: "input_ids" + train_batch_size: 1024 + eval_batch_size: 16 + shuffle_train: False + dataset_path: "~/code/data/TinyStories" + tokenizer_path: "EleutherAI/pythia-14m" + max_sequence_length: 512 + pad_token_id: -100 + pad_inputs: True + is_prepared: False + +logging: # @TODO :: not implemented + use_wandb: True + use_file: False + eval_interval: "10ba" + log_dir: "./run_logs" + checkpoint_dir: "./checkpoints" diff --git a/test/test.yaml b/test/test.yaml new file mode 100644 index 0000000..856f851 --- /dev/null +++ b/test/test.yaml @@ -0,0 +1,42 @@ +# we want data and model configurations to be in files rather than in yaml, leave training hyperparams to yaml config only + +run: + name: "test-ministories" # name of your experiment, used for checkpointing + seed: 42 + n_epochs: 1 + max_lr: 3e-04 + warmup_pct: 0.01 + schedule: "cosine-warmup" # linear, linear-warmup, cosine, cosine-warmup + optimizer: "AdamW" + eval_interval: "50ba" + clip_type: "norm" + clip_value: 1.0 + device_bs: "auto" + save_interval: "1ep" + +model: + config: "pythia14m" # config key + type: "decoder" + use_pretrained: True + +data: + pipeline: "tinystories_neox" # the preprocessing/tokenization pipeline + config: + name: "tinystories_testing" + num_processes: 1 + 
tokenizer_name: "neox" + is_prepared: False + input_id_key: "input_ids" + train_batch_size: 16 + eval_batch_size: 16 + shuffle_train: False + dataset_path: "~/code/data/micro_tinystories" + tokenizer_path: "EleutherAI/pythia-14m" + max_sequence_length: 128 + pad_token_id: -100 + pad_inputs: True + is_prepared: False + +logging: + use_wandb: True + use_file: False diff --git a/rafale/models/debug_encoders.py b/test/test_bert.py similarity index 98% rename from rafale/models/debug_encoders.py rename to test/test_bert.py index ce5b11b..d36b984 100644 --- a/rafale/models/debug_encoders.py +++ b/test/test_bert.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python import os from datapipe import WikiMLMPipe @@ -122,3 +123,7 @@ def test_bert(): atol=5e-4, rtol=5e-4, ) + + +if __name__ == "__main__": + main() diff --git a/test/test_pythia.py b/test/test_pythia.py new file mode 100644 index 0000000..cd1b422 --- /dev/null +++ b/test/test_pythia.py @@ -0,0 +1,255 @@ +import torch +import numpy as np + +from safetensors import safe_open + +from rafale.models.decoder import DecoderWrapper +from rafale.models.configurations import Pythia14MConfig, load_safetensors + +from transformers import AutoTokenizer, GPTNeoXForCausalLM + + +def test_layer_and_outputs(rafale_model, hf_model, tokenizer, layer=0, tol=1e-05): + """ + # @NOTE :: this currently on evaluates with KV cache enabled, write the test to run this function without the KV cache + """ + hf_activation = {} + hf_input_activation = {} + rafale_activation = {} + rafale_input_activation = {} + + # tuple of shape num_layers, 2 (keys, values), tensor BHLd + # make a fake kv-cache of length 4 + kv_cache = [] + n_layers = 6 + cache_len = 4 + n_heads = 4 + head_dim = 32 + for i in range(n_layers): + k = torch.randn(1, n_heads, cache_len, 32) + v = torch.randn(1, n_heads, cache_len, 32) + kv_cache.append((k, v)) + + def get_hf_activation(name): + def hook(model, input, output): + hf_activation[name] = output.detach() + + return hook + + def get_hf_input_activation(name): + def hook(model, _input, output): + hf_input_activation[name] = _input[0].detach() + + return hook + + def get_rafale_activation(name): + def hook(model, input, output): + rafale_activation[name] = output.detach() + + return hook + + def get_rafale_input_activation(name): + def hook(model, _input, output): + rafale_input_activation[name] = _input[0].detach() + + return hook + + # embeddings + rafale_model.token_embeddings.register_forward_hook( + get_rafale_activation("input_embeddings") + ) + hf_model.gpt_neox.embed_in.register_forward_hook( + get_hf_activation("input_embeddings") + ) + + # input layernorm + rafale_model.layers[layer].attention_norm.register_forward_hook( + get_rafale_activation("attn_norm") + ) + hf_model.gpt_neox.layers[layer].input_layernorm.register_forward_hook( + get_hf_activation("attn_norm") + ) + + rafale_model.layers[layer].attention_norm.register_forward_hook( + get_rafale_input_activation("input_attn_norm") + ) + hf_model.gpt_neox.layers[layer].input_layernorm.register_forward_hook( + get_hf_input_activation("input_attn_norm") + ) + + # attention projection query_key_values (pre RoPE) + rafale_model.layers[layer].attention.query_key_value.register_forward_hook( + get_rafale_activation("attn_inproj") + ) + + hf_model.gpt_neox.layers[layer].attention.query_key_value.register_forward_hook( + get_hf_activation("attn_inproj") + ) + + # out proj + rafale_model.layers[layer].attention.dense.register_forward_hook( + get_rafale_activation("attn_dense") + ) + + 
hf_model.gpt_neox.layers[layer].attention.dense.register_forward_hook( + get_hf_activation("attn_dense") + ) + + # INPUT check before attention dense* (if this fails then RoPE is probably the problem...) + rafale_model.layers[layer].attention.dense.register_forward_hook( + get_rafale_input_activation("attn_dense") + ) + hf_model.gpt_neox.layers[layer].attention.dense.register_forward_hook( + get_hf_input_activation("attn_dense") + ) + + # feed forward out + rafale_model.layers[layer].feed_forward.ff_out.register_forward_hook( + get_rafale_activation("ffout") + ) + hf_model.gpt_neox.layers[layer].mlp.dense_4h_to_h.register_forward_hook( + get_hf_activation("ffout") + ) + + # hf_model.gpt_neox.layers[0].attention(tensor) + + input_str = "Hello World from pythia!" + tokens = tokenizer(input_str, return_tensors="pt") + + hf_model.eval() + rafale_model.eval() + + with torch.no_grad(): + # hf_out = hf_model(tokens["input_ids"])["logits"].detach().numpy() + + hf_out = hf_model(tokens["input_ids"], use_cache=True, past_key_values=kv_cache) + hf_out = hf_out["logits"].detach().numpy() + + # rafale_out = rafale_model(tokens)[0].detach().numpy() + + rafale_out = rafale_model(tokens, past_kv_cache=kv_cache)[0].detach().numpy() + + print(f"Dropout p should be 0: {rafale_model.layers[layer].attention.dropout_p}") + print(f"Testing layer {layer}") + # EMBEDDING ################################################################### + try: + np.testing.assert_allclose( + rafale_activation["input_embeddings"].numpy(), + hf_activation["input_embeddings"].numpy(), + rtol=tol, + atol=tol, + ) + + print(f"✅ embeddings OK!") + except: + print("⚠️ Embedding difference!") + + # PRE-ATTENTION NORM ###################################################### + try: + np.testing.assert_allclose( + rafale_input_activation["input_attn_norm"].numpy(), + hf_input_activation["input_attn_norm"].numpy(), + rtol=tol, + atol=tol, + ) + print(f"✅ INPUTS of pre-attention norm OK!") + except: + print("⚠️ INPUTS pre-attention norm difference") + + try: + np.testing.assert_allclose( + rafale_activation["attn_norm"].numpy(), + hf_activation["attn_norm"].numpy(), + rtol=tol, + atol=tol, + ) + print(f"✅ pre-attention norm OK!") + except: + print("⚠️ pre-attention norm difference") + + # LINEAR PROJECTION FOR ATTENTION ###################################################### + + try: + np.testing.assert_allclose( + rafale_activation["attn_inproj"].numpy(), + hf_activation["attn_inproj"].numpy(), + rtol=tol, + atol=tol, + ) + print(f"✅ attention in-projection OK!") + + except: + print(f"⚠️ attention in-projection difference") + + # INPUTS OF ATTENTION DENSE LAYER ######################################## + try: + np.testing.assert_allclose( + rafale_input_activation["attn_dense"].numpy(), + hf_input_activation["attn_dense"].numpy(), + rtol=tol, + atol=tol, + ) + print(f"✅ inputs of attention dense OK") + + except: + r = rafale_input_activation["attn_dense"].numpy() + h = hf_input_activation["attn_dense"].numpy() + + print(r.shape) + print(h.shape) + print("⚠️ inputs of attention dense difference") + + np.testing.assert_allclose( + rafale_input_activation["attn_dense"].numpy(), + hf_input_activation["attn_dense"].numpy(), + rtol=tol, + atol=tol, + ) + + # OUTPUT OF ATTENTION DENSE LAYER ######################################## + try: + np.testing.assert_allclose( + rafale_activation["attn_dense"].numpy(), + hf_activation["attn_dense"].numpy(), + rtol=tol, + atol=tol, + ) + print(f"✅ attention out dense OK!") + except: + print("⚠️ attention out 
dense difference")
+
+    try:
+        np.testing.assert_allclose(
+            rafale_activation["ffout"].numpy(),
+            hf_activation["ffout"].numpy(),
+            rtol=tol,
+            atol=tol,
+        )
+        print(f"✅ feedforward out dense OK!")
+    except:
+        print("⚠️ ff out dense difference")
+
+    # Final model output ########################################
+    try:
+        np.testing.assert_allclose(rafale_out, hf_out, rtol=tol, atol=tol)
+        print(f"🎉 Model outputs match reference implementation!")
+    except:
+        print("❌ Model outputs do not match")
+
+
+def main():
+    """ """
+    torch.manual_seed(0)
+    np.random.seed(0)
+
+    hf_pythia = GPTNeoXForCausalLM.from_pretrained("EleutherAI/pythia-14m")
+    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-14m")
+
+    rafale_pythia = DecoderWrapper(Pythia14MConfig)  # check forward pass OK
+    rafale_pythia = load_safetensors(rafale_pythia, Pythia14MConfig)
+
+    test_layer_and_outputs(rafale_pythia, hf_pythia, tokenizer)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/test_pythia_generation.py b/test/test_pythia_generation.py
new file mode 100644
index 0000000..572b218
--- /dev/null
+++ b/test/test_pythia_generation.py
@@ -0,0 +1,37 @@
+import torch
+import numpy as np
+
+# from safetensors import safe_open
+# from tokenizers import Tokenizer
+
+from rafale.datapipe import InferenceDatapipeline
+from rafale.models.decoding_strategies import greedy_decode
+from rafale.models.decoder import DecoderWrapper
+from rafale.models.configurations import Pythia14MConfig, load_safetensors
+
+
+# Example usage
+# Initialize the rafale_pythia model
+rafale_pythia = DecoderWrapper(Pythia14MConfig)
+rafale_pythia = load_safetensors(rafale_pythia, Pythia14MConfig)
+ifdp = InferenceDatapipeline("EleutherAI/pythia-14m")
+test_str = "Once upon a time,"
+
+# Define input_ids (e.g., starting with a token)
+# input_ids = torch.tensor([[Pythia14MConfig.bos_token_id]])  # Shape: (1, 1)
+
+# Define maximum sequence length for generation
+max_length = 32
+
+# Generate sequence using greedy decoding
+generated_sequence = greedy_decode(
+    rafale_pythia,
+    ifdp(test_str),
+    max_length,
+    Pythia14MConfig.eos_token_id,
+    check_repeat_ngrams=True,
+)
+
+generated_str = ifdp.ids_to_str(generated_sequence["input_ids"])
+
+print("Generated sequence:", generated_str)
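For reference, the causal mask that `DecoderWrapper._generate_causal_mask` builds for the KV-cache path exercised above can be reproduced in isolation. A minimal sketch, assuming a cache of K=2 tokens and T=3 new tokens (values chosen only for illustration):

```python
# Standalone sketch mirroring DecoderWrapper._generate_causal_mask (illustration only).
# With K=2 cached tokens and T=3 new tokens, each new token attends to every cached
# token and to new tokens up to its own position, giving a T x (K + T) boolean mask.
import torch

cache_length, new_length = 2, 3  # K cached tokens, T new tokens

causal_mask = torch.tril(torch.ones(new_length, new_length)).bool()
if cache_length > 0:
    cache_tokens_mask = torch.ones((new_length, cache_length)).bool()
    causal_mask = torch.cat((cache_tokens_mask, causal_mask), dim=1)

print(causal_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True, False],
#         [ True,  True,  True,  True,  True]])
print(causal_mask.shape)  # torch.Size([3, 5]) -> T x (K + T)
```

This rectangular mask is why `scaled_dot_product_attention` is called with `is_causal=False` and an explicit `attn_mask` in `DecoderAttentionRotaryKVCache`: the cached key/value columns stay visible to every new query, while causality is enforced only among the new tokens.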