Skip to content

Commit

Permalink
training loop
Browse files Browse the repository at this point in the history
  • Loading branch information
lakshith-403 committed Jul 29, 2024
1 parent 23b7e2e commit 0f2a9be
Showing 1 changed file with 162 additions and 0 deletions.
162 changes: 162 additions & 0 deletions labml_nn/transformers/LoRA/train.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
{
"cells": [
{
"cell_type": "code",
"id": "initial_id",
"metadata": {
"collapsed": true
},
"source": "# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Read the entire training corpus from input.txt into one string.\n",
"with open('input.txt', mode='r', encoding='utf-8') as corpus_file:\n",
"    text = corpus_file.read()"
],
"id": "3b1e507015ba6b81",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"# Load the pretrained tokenizer registered under the name 'gpt2'.\n",
"tokenizer = AutoTokenizer.from_pretrained('gpt2')\n",
"\n",
"# Encode the whole corpus in one call, asking the tokenizer not to add special tokens.\n",
"tokens = tokenizer.encode(text, add_special_tokens=False)"
],
"id": "ac8e51ae5bbfcae7",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Batching hyperparameters shared by the cells below.\n",
"batch_size = 64      # rows per batch handed to the DataLoader\n",
"context_length = 10  # tokens per row once the token list is reshaped with .view(-1, context_length)"
],
"id": "aeefcdf813e427e",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Trim the tail of the token list so its length is an exact multiple of\n",
"# batch_size * context_length; the leftover tokens are discarded.\n",
"tokens_per_batch = batch_size * context_length\n",
"num_batches = len(tokens) // tokens_per_batch\n",
"tokens = tokens[:num_batches * tokens_per_batch]"
],
"id": "a384b42274f008a2",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"import torch\n",
"\n",
"# Turn the flat token list into a tensor, then fold it into rows of context_length tokens.\n",
"token_tensor = torch.tensor(tokens)\n",
"input_ids = token_tensor.view(-1, context_length)"
],
"id": "5c4cc78ac1a02c1d",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"from torch.utils.data import DataLoader, TensorDataset\n",
"from torch.optim import Adam\n",
"\n",
"# Show the (rows, context_length) shape as a quick sanity check.\n",
"print(input_ids.shape)\n",
"\n",
"# Wrap the rows in a dataset and serve them in shuffled batches of batch_size rows.\n",
"dataset = TensorDataset(input_ids)\n",
"dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)"
],
"id": "7037fd75e2161382",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"from labml_nn.transformers.LoRA.GPT2 import GPTModel\n",
"\n",
"# Build the model with GPTModel's default constructor arguments.\n",
"# NOTE(review): whether this loads pretrained weights or starts from random\n",
"# initialisation is decided inside GPTModel and is not visible here - confirm.\n",
"model = GPTModel()"
],
"id": "a98b7baa064b8494",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Optimise every model parameter with Adam against a next-token cross-entropy loss.\n",
"optimizer = Adam(model.parameters(), lr=5e-5)\n",
"criterion = torch.nn.CrossEntropyLoss()\n",
"\n",
"# Training mode, not eval mode: this loop runs backward() and optimizer.step(), so any\n",
"# mode-dependent layers (e.g. dropout) should use their training behaviour.\n",
"# NOTE(review): whether GPTModel contains such layers is not visible here.\n",
"model.train()\n",
"epochs = 3\n",
"for epoch in range(epochs):\n",
"    for batch in dataloader:\n",
"        # Each batch arrives as a one-element sequence; element 0 is the token-id tensor.\n",
"        inputs = batch[0]\n",
"\n",
"        # Assumes model(inputs) returns logits shaped (batch, position, vocab) - TODO confirm.\n",
"        outputs = model(inputs)\n",
"\n",
"        # Next-token objective: the logits at position t are scored against the token at t + 1,\n",
"        # so drop the last position of the logits and the first token of the targets.\n",
"        # The targets are a read-only slice of inputs, so no copy is needed.\n",
"        shift_logits = outputs[..., :-1, :]\n",
"        shift_labels = inputs[..., 1:]\n",
"\n",
"        # Flatten to (N, vocab) logits and (N,) class indices, the layout CrossEntropyLoss expects.\n",
"        loss = criterion(shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1))\n",
"\n",
"        optimizer.zero_grad()\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        print(f'Epoch: {epoch + 1}, Loss: {loss.item()}')\n",
"        # NOTE(review): this break stops each epoch after its first batch, so only `epochs`\n",
"        # optimizer steps run in total. Remove it for a full pass over the data.\n",
"        break\n",
"\n",
"print(\"Training complete.\")"
],
"id": "e2f5076894770740",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "code",
"source": "",
"id": "da2d4023002648dc",
"outputs": [],
"execution_count": null
}
],
"metadata": {
"kernelspec": {
"display_name": "Python (ml)",
"language": "python",
"name": "ml"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 0f2a9be

Please sign in to comment.