Skip to content

Commit

Permalink
feat(training): added Gemma training example notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
tengomucho committed May 28, 2024
1 parent 0f74cf3 commit 6ab8651
Showing 1 changed file with 371 additions and 0 deletions.
371 changes: 371 additions & 0 deletions examples/language-modeling/gemma_tuning.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,371 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "c7227736-2685-4971-9402-d6015b5319b5",
"metadata": {},
"source": [
"# Fine-Tune Gemma on Google TPU\n",
"\n",
"This tutorial will teach how to fine-tune open LLMs like [Google Gemma](https://huggingface.co/google/gemma-7b) on Google Cloud's TPUs. In our example, we are going to leverage Hugging Face Optimum TPU, [🤗 Transformers](https://huggingface.co/docs/transformers/index) and datasets.\n",
"\n",
"### Google's TPU\n",
"\n",
"Google Cloud TPUs are custom-designed AI accelerators, which are optimized for training and inference of large AI models. They are ideal for a variety of use cases, such as chatbots, code generation, media content generation, synthetic speech, vision services, recommendation engines, personalization models, among others.\n",
"\n",
"Advantages of using TPUs include:\n",
"\n",
"* Designed to scale cost-efficiently for a wide range of AI workloads, spanning training, fine-tuning, and inference.\n",
"* Optimized for TensorFlow, PyTorch, and JAX, and are available in a variety of form factors, including edge devices, workstations, and cloud-based infrastructure.\n",
"* TPUs are available in Google Cloud, and have been integrated with Vertex AI, and Google Kubernetes Engine (GKE).\n",
"\n",
"### Environment Setup\n",
"\n",
"For this example, a single-host `v5elitepod8` TPU will be enough. To set up a TPU environment with Pytorch XLA, you can check this [Google Cloud guide](https://cloud.google.com/tpu/docs/run-calculation-pytorch) that shows how to do that.\n",
"\n",
"You can use `ssh` or `gcloud` commands to log in to the remote TPU. Enable port-forwarding for the port `8888`, e.g.:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "31227770",
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"gcloud compute tpus tpu-vm ssh $TPU_NAME \\\n",
" --zone=$ZONE \\\n",
" -- -L 8888:localhost:8888"
]
},
{
"cell_type": "markdown",
"id": "79fda238",
"metadata": {},
"source": [
"Once you have access to the TPU VM, you can clone the `optimum-tpu` repository containing the related notebook. Then you can install few packages used in this tutorial and launch the notebook:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "522b5a36",
"metadata": {
"vscode": {
"languageId": "shellscript"
}
},
"outputs": [],
"source": [
"git clone https://github.com/huggingface/optimum-tpu.git\n",
"# Install Optimum tpu\n",
"pip install -e . -f https://storage.googleapis.com/libtpu-releases/index.html\n",
"# Install TRL and PEFT for training (see later how they are used)\n",
"pip install trl peft\n",
"# Install Jupyter notebook\n",
"pip install -U jupyterlab notebook\n",
"# Optionally, install widgets extensions for better rendering\n",
"pip install ipywidgets widgetsnbextension\n",
"# Change directory and launch Jupyter notebook\n",
"cd optimum-tpu/examples/language-modeling\n",
"jupyter notebook --port 8888"
]
},
{
"cell_type": "markdown",
"id": "0927763f",
"metadata": {},
"source": [
"You should see the familiar Jupyter output where you can access from your browser:\n",
"\n",
"```\n",
"http://localhost:8888/tree?token=3ceb24619d0a2f99acf5fba41c51b475b1ddce7cadb2a133\n",
"```\n",
"\n",
"Since we are going to use the gated `gemma` model, you will need to log in to your [Hugging Face token](https://huggingface.co/settings/tokens):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "37bccce7-1ce4-4470-9e81-c15b120ef294",
"metadata": {},
"outputs": [],
"source": [
"!huggingface-cli login --token YOUR_HF_TOKEN"
]
},
{
"cell_type": "markdown",
"id": "4cc6659c",
"metadata": {},
"source": [
"### Enable FSDPv2\n",
"\n",
"To fine-tune an LLM, it might be necessary to shard the model across the TPUs to prevent memory issues and enhance tuning performances. Fully Sharded Data Parallel is an algorithm that has been implemented on Pytorch and that allows to wrap modules to distribute them.\n",
"When using Pytorch/XLA on TPUs, [FSDPv2](https://pytorch.org/xla/master/#fully-sharded-data-parallel-via-spmd) is an utility that re-expresses the famous FSDP algorithm using SPMD (Single Program Multiple Data). In `optimum-tpu` it is possible to use dedicated helpers to use FSPDv2. To enable it, you can use the dedicated function, that should be called at the beginning of the execution:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6d3c7bc2",
"metadata": {},
"outputs": [],
"source": [
"from optimum.tpu import fsdp_v2\n",
"fsdp_v2.use_fsdp_v2()"
]
},
{
"cell_type": "markdown",
"id": "bf78b865",
"metadata": {},
"source": [
"### Load and Prepare Dataset\n",
"\n",
"We will use [Dolly](https://huggingface.co/datasets/databricks/databricks-dolly-15k), an open source dataset of instruction-following records on categories outlined in the [InstructGPT](https://arxiv.org/abs/2203.02155) paper, including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization.\n",
"\n",
"We will load the dataset from the hub:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f0196b5d",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")"
]
},
{
"cell_type": "markdown",
"id": "4a2d599e",
"metadata": {},
"source": [
"We can take a look to a sample:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "12409299",
"metadata": {},
"outputs": [],
"source": [
"dataset[321]"
]
},
{
"cell_type": "markdown",
"id": "9dc05649",
"metadata": {},
"source": [
"You will obtain a result similar to this:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9c24e0b1",
"metadata": {},
"outputs": [],
"source": [
"{'instruction': 'When was the 8088 processor released?',\n",
" 'context': 'The 8086 (also called iAPX 86) is a 16-bit microprocessor chip designed by Intel between early 1976 and June 8, 1978, when it was released. The Intel 8088, released July 1, 1979, is a slightly modified chip with an external 8-bit data bus (allowing the use of cheaper and fewer supporting ICs),[note 1] and is notable as the processor used in the original IBM PC design.',\n",
" 'response': 'The Intel 8088 processor was released July 1, 1979.',\n",
" 'category': 'information_extraction'}"
]
},
{
"cell_type": "markdown",
"id": "842badf0",
"metadata": {},
"source": [
"We will define a formatting function that combines `instruction`, `context` and `response` fields, and tokenizes them in a complete prompt. We will use a tokenizer compatible with the model we intend to use."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1497e0f",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"model_id = \"google/gemma-7b\"\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id)\n",
"\n",
"def preprocess_function(sample):\n",
" instruction = f\"### Instruction\\n{sample['instruction']}\"\n",
" context = f\"### Context\\n{sample['context']}\" if len(sample[\"context\"]) > 0 else None\n",
" response = f\"### Answer\\n{sample['response']}\"\n",
" # join all the parts together\n",
" prompt = \"\\n\\n\".join([i for i in [instruction, context, response] if i is not None])\n",
" prompt += tokenizer.eos_token\n",
" sample[\"prompt\"] = prompt\n",
" return sample"
]
},
{
"cell_type": "markdown",
"id": "4d1bde72",
"metadata": {},
"source": [
"It is now possible to use this function to map the dataset, where original columns can now be removed:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "16b44a9b",
"metadata": {},
"outputs": [],
"source": [
"data = dataset.map(preprocess_function, remove_columns=list(dataset.features))"
]
},
{
"cell_type": "markdown",
"id": "90c26a40",
"metadata": {},
"source": [
"### Preparing the Model for Tuning\n",
"\n",
"We can now load the model that will be used for tuning. The dataset is now ready to be used for fine-tuning:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f18472ce",
"metadata": {},
"outputs": [],
"source": [
"from optimum.tpu import AutoModelForCausalLM\n",
"model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False)"
]
},
{
"cell_type": "markdown",
"id": "a6795dc6",
"metadata": {},
"source": [
"We're now going to use [Parameter Efficient FineTuning PEFT](https://huggingface.co/blog/peft) and [Low-Rank Adaptation (LoRA)](https://huggingface.co/papers/2106.09685) to efficiently fine tune the model on the prepared dataset. In the `LoraConfig` instance we will define the `nn.Linear` operations that will be fine tuned."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4a01f651",
"metadata": {},
"outputs": [],
"source": [
"from peft import LoraConfig\n",
"\n",
"# Set up PEFT LoRA for fine-tuning.\n",
"lora_config = LoraConfig(\n",
" r=8,\n",
" target_modules=[\"k_proj\", \"v_proj\"],\n",
" task_type=\"CAUSAL_LM\",\n",
")"
]
},
{
"cell_type": "markdown",
"id": "71243244",
"metadata": {},
"source": [
"The `optimum-tpu` dedicated function will help you obtain arguments so you can create your trainer instance."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "780f1033",
"metadata": {},
"outputs": [],
"source": [
"from trl import SFTTrainer\n",
"from transformers import TrainingArguments\n",
"\n",
"# Set up the FSDP arguments\n",
"fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)\n",
"\n",
"# Set up the trainer\n",
"trainer = SFTTrainer(\n",
" model=model,\n",
" train_dataset=data,\n",
" args=TrainingArguments(\n",
" per_device_train_batch_size=64,\n",
" num_train_epochs=32,\n",
" max_steps=-1,\n",
" output_dir=\"./output\",\n",
" optim=\"adafactor\",\n",
" logging_steps=1,\n",
" dataloader_drop_last = True, # Required for FSDPv2.\n",
" **fsdp_training_args,\n",
" ),\n",
" peft_config=lora_config,\n",
" dataset_text_field=\"prompt\",\n",
" max_seq_length=1024,\n",
" packing=True,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "6ea98f82",
"metadata": {},
"source": [
"Once everything is ready it tuning your model is as simple as calling a function!"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4c437a81",
"metadata": {},
"outputs": [],
"source": [
"trainer.train()"
]
},
{
"cell_type": "markdown",
"id": "1f92e754-9763-4df9-b532-8da93a44dfb2",
"metadata": {},
"source": [
"After this, you have successfully fine-tuned the model on the Dolly dataset."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit 6ab8651

Please sign in to comment.