Add LoRA to GPT2
lakshith committed Jul 31, 2024
1 parent 0f2a9be commit 77d00f0
Showing 4 changed files with 255 additions and 84 deletions.
20 changes: 11 additions & 9 deletions labml_nn/transformers/LoRA/GPT2.py
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 from transformers import AutoTokenizer
+from labml_nn.transformers.LoRA import Linear, Embedding

 tokenizer = AutoTokenizer.from_pretrained("gpt2")

@@ -10,15 +11,16 @@
     "n_head": 12,
     "n_layer": 12,
     "n_positions": 1024,
-    "vocab_size": 50257
+    "vocab_size": 50257,
+    "device": "cuda"
 }


 class FFN(nn.Module):
     def __init__(self, dim):
         super().__init__()
-        self.c_fc = nn.Linear(config['n_embd'], dim)
-        self.c_proj = nn.Linear(dim, config['n_embd'])
+        self.c_fc = Linear(config['n_embd'], dim, r=32, bias=True)
+        self.c_proj = Linear(dim, config['n_embd'], r=32, bias=True)
         self.act = nn.functional.gelu

     def forward(self, hidden_states):
@@ -36,8 +38,8 @@ def __init__(self):
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim

-        self.c_att = nn.Linear(config['n_embd'], config['n_embd'] * 3)
-        self.c_proj = nn.Linear(config['n_embd'], config['n_embd'])
+        self.c_att = Linear(config['n_embd'], config['n_embd'] * 3, r=32, bias=True)
+        self.c_proj = Linear(config['n_embd'], config['n_embd'], r=32, bias=True)

     def _split_heads(self, tensor, num_heads, attn_head_size):
         """
@@ -100,20 +102,20 @@ class GPTModel(nn.Module):
     def __init__(self):
         super().__init__()

-        self.token_embedding = nn.Embedding(config['vocab_size'], config['n_embd'])
-        self.position_embedding = nn.Embedding(config['n_positions'], config['n_embd'])
+        self.token_embedding = Embedding(config['vocab_size'], config['n_embd'], r=32)
+        self.position_embedding = Embedding(config['n_positions'], config['n_embd'], r=32)

         self.blocks = nn.ModuleList([Block() for _ in range(config['n_layer'])])

         self.final_norm = nn.LayerNorm(config['n_embd'], eps=config['layer_norm_epsilon'])

-        self.lm_head = nn.Linear(config['n_embd'], config['vocab_size'], bias=False)
+        self.lm_head = Linear(config['n_embd'], config['vocab_size'], r=32, bias=False)

     def forward(self, input_ids):
         batch_size, input_shape = input_ids.size()

         token_embeddings = self.token_embedding(input_ids)  # B T C
-        position_ids = torch.arange(input_shape)  # T C
+        position_ids = torch.arange(input_shape, device=config['device'])  # T C
         position_embeddings = self.position_embedding(position_ids)  # B T C

         hidden_states = token_embeddings + position_embeddings
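For readers skimming the diff: `Linear` and `Embedding` imported from `labml_nn.transformers.LoRA` are the LoRA-wrapped replacements for `nn.Linear` and `nn.Embedding`. The sketch below illustrates the parametrization those call sites imply, a frozen pretrained weight plus a trainable rank-`r` update scaled by `alpha / r`. It is a minimal sketch inferred from the arguments used above (`r=32`, `bias=True`) and from the `__init__.py` hunk further down, not the repository's exact implementation; the class name `LoRALinear`, the `alpha` default, and the initialization scheme are assumptions.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Sketch of a LoRA linear layer: y = x W^T + b + (alpha / r) * (x A) B."""

    def __init__(self, in_features: int, out_features: int, r: int,
                 alpha: float = 32, bias: bool = True):  # alpha default is an assumption
        super().__init__()
        # Frozen pretrained weight (and optional bias), filled in later from a checkpoint
        self.weight = nn.Parameter(torch.empty(out_features, in_features), requires_grad=False)
        self.bias = nn.Parameter(torch.empty(out_features), requires_grad=False) if bias else None

        # Trainable low-rank factors; B starts at zero so training begins at the pretrained weights
        self.scaling = alpha / r
        self.lora_a = nn.Parameter(torch.randn(in_features, r) * 0.02)
        self.lora_b = nn.Parameter(torch.zeros(r, out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        frozen = nn.functional.linear(x, self.weight, self.bias)
        return frozen + (x @ self.lora_a @ self.lora_b) * self.scaling
```

Freezing `weight` while training only `lora_a` and `lora_b` is what keeps the number of trainable parameters small, which is the point of swapping these layers in throughout GPT2.py.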
2 changes: 1 addition & 1 deletion labml_nn/transformers/LoRA/__init__.py
@@ -53,7 +53,7 @@ def __init__(
         self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim)))
         self.weight.requires_grad = False

-        self.scaling = alpha / self.r
+        self.scaling = alpha / r
         self.lora_a = nn.Parameter(torch.empty((num_embeddings, r)))
         self.lora_b = nn.Parameter(torch.empty((r, embedding_dim)))

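The one-line fix above changes the scaling factor to `alpha / r`; the surrounding constructor does not appear to keep `r` on `self`, so `alpha / self.r` would not resolve. For orientation, here is a minimal sketch of how that `scaling` attribute is typically applied in a LoRA embedding's forward pass. The forward method is not part of this hunk, so the function below is an assumption based on the parameter shapes defined above, not the module's actual code.

```python
import torch
import torch.nn as nn


def lora_embedding_forward(layer: nn.Module, input_ids: torch.Tensor) -> torch.Tensor:
    """Frozen embedding lookup plus the scaled low-rank update (shapes assumed from the hunk above)."""
    # Frozen pretrained table: (num_embeddings, embedding_dim)
    frozen = nn.functional.embedding(input_ids, layer.weight)
    # Low-rank update: rows of lora_a (num_embeddings, r) projected through lora_b (r, embedding_dim)
    low_rank = nn.functional.embedding(input_ids, layer.lora_a) @ layer.lora_b
    return frozen + layer.scaling * low_rank
```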
55 changes: 39 additions & 16 deletions labml_nn/transformers/LoRA/experiment.ipynb
@@ -3,8 +3,8 @@
   {
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2024-07-29T07:14:27.781097Z",
-     "start_time": "2024-07-29T07:14:24.819976Z"
+     "end_time": "2024-07-31T12:22:57.496965Z",
+     "start_time": "2024-07-31T12:22:55.151730Z"
     }
    },
    "cell_type": "code",
@@ -19,8 +19,8 @@
   {
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2024-07-29T07:14:28.183960Z",
-     "start_time": "2024-07-29T07:14:27.782683Z"
+     "end_time": "2024-07-31T12:22:57.986397Z",
+     "start_time": "2024-07-31T12:22:57.498305Z"
     }
    },
    "cell_type": "code",
@@ -39,8 +39,8 @@
   "metadata": {
    "collapsed": true,
    "ExecuteTime": {
-    "end_time": "2024-07-29T07:14:29.840925Z",
-    "start_time": "2024-07-29T07:14:28.185080Z"
+    "end_time": "2024-07-31T12:22:58.562136Z",
+    "start_time": "2024-07-31T12:22:57.987296Z"
    }
   },
   "source": [
@@ -54,20 +54,38 @@
     "if unexpected_keys:\n",
     "    print(f\"Unexpected keys: {unexpected_keys}\")"
    ],
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/tmp/ipykernel_7130/2581223434.py:3: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
+      "  state_dict = torch.load('transformed.pth')\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Missing keys: ['token_embedding.lora_a', 'token_embedding.lora_b', 'position_embedding.lora_a', 'position_embedding.lora_b', 'blocks.0.attn.c_att.lora_a', 'blocks.0.attn.c_att.lora_b', 'blocks.0.attn.c_proj.lora_a', 'blocks.0.attn.c_proj.lora_b', 'blocks.0.ffn.c_fc.lora_a', 'blocks.0.ffn.c_fc.lora_b', 'blocks.0.ffn.c_proj.lora_a', 'blocks.0.ffn.c_proj.lora_b', 'blocks.1.attn.c_att.lora_a', 'blocks.1.attn.c_att.lora_b', 'blocks.1.attn.c_proj.lora_a', 'blocks.1.attn.c_proj.lora_b', 'blocks.1.ffn.c_fc.lora_a', 'blocks.1.ffn.c_fc.lora_b', 'blocks.1.ffn.c_proj.lora_a', 'blocks.1.ffn.c_proj.lora_b', 'blocks.2.attn.c_att.lora_a', 'blocks.2.attn.c_att.lora_b', 'blocks.2.attn.c_proj.lora_a', 'blocks.2.attn.c_proj.lora_b', 'blocks.2.ffn.c_fc.lora_a', 'blocks.2.ffn.c_fc.lora_b', 'blocks.2.ffn.c_proj.lora_a', 'blocks.2.ffn.c_proj.lora_b', 'blocks.3.attn.c_att.lora_a', 'blocks.3.attn.c_att.lora_b', 'blocks.3.attn.c_proj.lora_a', 'blocks.3.attn.c_proj.lora_b', 'blocks.3.ffn.c_fc.lora_a', 'blocks.3.ffn.c_fc.lora_b', 'blocks.3.ffn.c_proj.lora_a', 'blocks.3.ffn.c_proj.lora_b', 'blocks.4.attn.c_att.lora_a', 'blocks.4.attn.c_att.lora_b', 'blocks.4.attn.c_proj.lora_a', 'blocks.4.attn.c_proj.lora_b', 'blocks.4.ffn.c_fc.lora_a', 'blocks.4.ffn.c_fc.lora_b', 'blocks.4.ffn.c_proj.lora_a', 'blocks.4.ffn.c_proj.lora_b', 'blocks.5.attn.c_att.lora_a', 'blocks.5.attn.c_att.lora_b', 'blocks.5.attn.c_proj.lora_a', 'blocks.5.attn.c_proj.lora_b', 'blocks.5.ffn.c_fc.lora_a', 'blocks.5.ffn.c_fc.lora_b', 'blocks.5.ffn.c_proj.lora_a', 'blocks.5.ffn.c_proj.lora_b', 'blocks.6.attn.c_att.lora_a', 'blocks.6.attn.c_att.lora_b', 'blocks.6.attn.c_proj.lora_a', 'blocks.6.attn.c_proj.lora_b', 'blocks.6.ffn.c_fc.lora_a', 'blocks.6.ffn.c_fc.lora_b', 'blocks.6.ffn.c_proj.lora_a', 'blocks.6.ffn.c_proj.lora_b', 'blocks.7.attn.c_att.lora_a', 'blocks.7.attn.c_att.lora_b', 'blocks.7.attn.c_proj.lora_a', 'blocks.7.attn.c_proj.lora_b', 'blocks.7.ffn.c_fc.lora_a', 'blocks.7.ffn.c_fc.lora_b', 'blocks.7.ffn.c_proj.lora_a', 'blocks.7.ffn.c_proj.lora_b', 'blocks.8.attn.c_att.lora_a', 'blocks.8.attn.c_att.lora_b', 'blocks.8.attn.c_proj.lora_a', 'blocks.8.attn.c_proj.lora_b', 'blocks.8.ffn.c_fc.lora_a', 'blocks.8.ffn.c_fc.lora_b', 'blocks.8.ffn.c_proj.lora_a', 'blocks.8.ffn.c_proj.lora_b', 'blocks.9.attn.c_att.lora_a', 'blocks.9.attn.c_att.lora_b', 'blocks.9.attn.c_proj.lora_a', 'blocks.9.attn.c_proj.lora_b', 'blocks.9.ffn.c_fc.lora_a', 'blocks.9.ffn.c_fc.lora_b', 'blocks.9.ffn.c_proj.lora_a', 'blocks.9.ffn.c_proj.lora_b', 'blocks.10.attn.c_att.lora_a', 'blocks.10.attn.c_att.lora_b', 'blocks.10.attn.c_proj.lora_a', 'blocks.10.attn.c_proj.lora_b', 'blocks.10.ffn.c_fc.lora_a', 'blocks.10.ffn.c_fc.lora_b', 'blocks.10.ffn.c_proj.lora_a', 'blocks.10.ffn.c_proj.lora_b', 'blocks.11.attn.c_att.lora_a', 'blocks.11.attn.c_att.lora_b', 'blocks.11.attn.c_proj.lora_a', 'blocks.11.attn.c_proj.lora_b', 'blocks.11.ffn.c_fc.lora_a', 'blocks.11.ffn.c_fc.lora_b', 'blocks.11.ffn.c_proj.lora_a', 'blocks.11.ffn.c_proj.lora_b', 'lm_head.lora_a', 'lm_head.lora_b']\n"
+     ]
+    }
+   ],
    "execution_count": 3
   },
   {
    "metadata": {
     "ExecuteTime": {
-     "end_time": "2024-07-29T07:22:30.408855Z",
-     "start_time": "2024-07-29T07:22:30.168376Z"
+     "end_time": "2024-07-31T12:23:00.447976Z",
+     "start_time": "2024-07-31T12:22:58.566527Z"
     }
    },
    "cell_type": "code",
    "source": [
     "prompt = \"hello how are you\"\n",
     "tokenized = tokenizer(prompt, return_tensors=\"pt\")\n",
+    "tokenized['input_ids'] = tokenized['input_ids'].to('cuda')\n",
+    "model = model.to('cuda')\n",
     "\n",
     "with torch.no_grad():\n",
     "    model.eval()\n",
@@ -90,22 +108,27 @@
      ]
     }
    ],
-   "execution_count": 17
+   "execution_count": 4
  },
  {
-   "metadata": {},
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-07-31T12:23:00.452060Z",
+     "start_time": "2024-07-31T12:23:00.448904Z"
+    }
+   },
    "cell_type": "code",
-   "outputs": [],
-   "execution_count": null,
    "source": "",
-   "id": "c12776360008a974"
+   "id": "c12776360008a974",
+   "outputs": [],
+   "execution_count": 4
  }
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python (ml)",
+  "display_name": "Python 3 (ipykernel)",
   "language": "python",
-  "name": "ml"
+  "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
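The stderr captured in the notebook above is PyTorch's FutureWarning about `torch.load` defaulting to `weights_only=False`. Since the checkpoint loaded there (`transformed.pth`) is a plain state dict of tensors, the warning's own recommendation should apply; below is a hedged sketch of the quieter call, assuming the file is available locally.

```python
import torch

# Opt in to the safer weights-only unpickling that the FutureWarning recommends.
# 'transformed.pth' is the checkpoint path used in the notebook cell above; this
# sketch assumes it contains only tensors (a state dict), which weights_only requires.
state_dict = torch.load('transformed.pth', weights_only=True)
print(f"Loaded {len(state_dict)} tensors")
```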