Feature/pytorch tpu container #14

Status: Merged (22 commits, Mar 26, 2024)

Commits
- d20a488: PyTorch TPU Dockerfile (shub-kris, Feb 15, 2024)
- 84ca2b2: PyTorch TPU Dockerfile (shub-kris, Feb 15, 2024)
- 42bbd21: Add Notebook (shub-kris, Feb 16, 2024)
- e679d99: remove: unused dockerfile, sentence-transformers (shub-kris, Feb 16, 2024)
- 5447403: Add example with PyTorch_XLA TPU DLC (shub-kris, Feb 21, 2024)
- 9b2fb84: Update README to add Single-Host (shub-kris, Feb 21, 2024)
- de6c186: Add more info about different hosts (shub-kris, Feb 21, 2024)
- 4680bcd: Replace manual training script with trainer (shub-kris, Feb 22, 2024)
- 0901c59: Add Dolly, use TRL and OPT-350M (shub-kris, Feb 23, 2024)
- c2e8a3b: Push local changes (shub-kris, Feb 23, 2024)
- 815b0df: Merge branch 'feature/pytorch-tpu-transformers-example' into feature/… (shub-kris, Feb 26, 2024)
- e935b17: Change name, add transformers trainer example (shub-kris, Feb 26, 2024)
- 6747611: Update dockerfile and example according to nightly version (shub-kris, Feb 28, 2024)
- 68f94b4: Add info to add token (shub-kris, Feb 28, 2024)
- 79872fd: checkpointing is faster with this base image (shub-kris, Feb 29, 2024)
- 2a80fe2: Build transformers and trl from main for checkpointing and sfttrainer… (shub-kris, Mar 14, 2024)
- b53b200: Update README (shub-kris, Mar 14, 2024)
- e75a3dd: Change parameters (shub-kris, Mar 14, 2024)
- e3041ad: Change parameters, add example for Llama (shub-kris, Mar 15, 2024)
- 7f2e378: Rename folder (shub-kris, Mar 15, 2024)
- 6bb2197: llama-example: rename function from train_gemma to train_llama (shub-kris, Mar 15, 2024)
- b119483: update(tpu-examples): change model to gemma-7b, add inference at the … (shub-kris, Mar 25, 2024)
containers/pytorch/training/tpu/2.3/transformers/4.38.1/py310/Dockerfile
@@ -0,0 +1,69 @@
FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla@sha256:8f1dcd5b03f993e4da5c20d17c77aff6a5f22d5455f8eb042d2e4b16ac460526
# Nightly PyTorch/XLA image with PyTorch 2.3 and Python 3.10
# Read more about it here: https://github.com/pytorch/xla?tab=readme-ov-file#docker

LABEL maintainer="Hugging Face"
ARG DEBIAN_FRONTEND=noninteractive

# Versions
ARG TRANSFORMERS='4.38.1'
ARG DIFFUSERS='0.26.3'
ARG PEFT='0.8.2'
ARG TRL='0.7.11'
ARG DATASETS='2.17.1'
ARG ACCELERATE='0.27.2'
ARG EVALUATE='0.4.1'
ARG NOTEBOOK='7.1.1'

RUN apt-get update \
    && apt-get install -y \
    bzip2 \
    curl \
    git \
    git-lfs \
    tar \
    gcc \
    g++ \
    libaio-dev \
    # audio
    libsndfile1-dev \
    ffmpeg \
    apt-transport-https \
    gnupg \
    ca-certificates \
    && apt-get clean && apt-get autoremove --yes

# Update pip
RUN pip install --upgrade pip

# Install Hugging Face Libraries
RUN pip install --upgrade --no-cache-dir \
    transformers[sklearn,sentencepiece,vision]==${TRANSFORMERS} \
    diffusers==${DIFFUSERS} \
    datasets==${DATASETS} \
    accelerate==${ACCELERATE} \
    evaluate==${EVALUATE} \
    peft==${PEFT} \
    trl==${TRL} \
    notebook==${NOTEBOOK}

# Install Google Cloud dependencies
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" \
    | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
    curl https://packages.cloud.google.com/apt/doc/apt-key.gpg \
    | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - && \
    apt-get update -y && \
    apt-get install google-cloud-sdk -y

RUN pip install --upgrade --no-cache-dir \
    google-cloud-storage \
    google-cloud-bigquery \
    google-cloud-aiplatform \
    google-cloud-pubsub \
    google-cloud-logging

# Check if correct versions are installed
RUN python -c "import transformers, diffusers, datasets, accelerate, evaluate, peft, trl, torch; \
    assert all([mod.__version__ == version for mod, version in [(transformers, '${TRANSFORMERS}'), (diffusers, '${DIFFUSERS}'), \
    (datasets, '${DATASETS}'), (accelerate, '${ACCELERATE}'), (evaluate, '${EVALUATE}'), (peft, '${PEFT}'), (trl, '${TRL}'), \
    (torch, '2.3.0')]])"
examples/google-cloud-tpu-vm/causal-language-modeling/README.md
@@ -0,0 +1,104 @@
# Fine-tune Gemma-2B with the Hugging Face PyTorch TPU DLC on a Google Cloud TPU (v5e)

This example demonstrates how to fine-tune [gemma-2b](https://huggingface.co/google/gemma-2b) using Hugging Face's DLCs on a Google Cloud single-host TPU (v5e) VM. We use the [transformers](https://huggingface.co/docs/transformers/), [TRL](https://huggingface.co/docs/trl/en/index), and [PEFT](https://huggingface.co/docs/peft/index) libraries for fine-tuning. The dataset used in this example is [Dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k), which can be easily accessed from the Hugging Face [Datasets](https://huggingface.co/datasets) Hub.


## What are TPUs?

Google Cloud TPUs are custom-designed AI accelerators optimized for training and inference of large AI models. They suit a wide variety of use cases, such as chatbots, code generation, media content generation, synthetic speech, vision services, recommendation engines, and personalization models.

Advantages of using TPUs include:

- They are designed to scale cost-efficiently for a wide range of AI workloads, spanning training, fine-tuning, and inference.
- They are optimized for TensorFlow, PyTorch, and JAX, and are available in a variety of form factors, including edge devices, workstations, and cloud-based infrastructure.
- They are available in [Google Cloud](https://cloud.google.com/tpu/docs/intro-to-tpu) and are integrated with [Vertex AI](https://cloud.google.com/vertex-ai/docs/training/training-with-tpu-vm) and [Google Kubernetes Engine (GKE)](https://cloud.google.com/tpu?hl=en#cloud-tpu-in-gke).

## Before you begin

Make sure you have the following:
- A Google Cloud project with billing enabled.
<!-- - Access to Hugging Face's PyTorch TPU DLC. -->
- [Google Cloud CLI](https://cloud.google.com/sdk/docs/install#linux) installed on your local machine.

To install the Google Cloud CLI, you can use the following commands:

```bash
curl https://sdk.cloud.google.com | bash
exec zsh -l  # restart your shell (replace zsh with the shell you use, e.g. bash)
gcloud init
```
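
To confirm the CLI is installed and on your `PATH`, you can check the version:

```bash
gcloud --version
```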

You can then configure your Google Cloud project using the following commands:

```bash
gcloud auth login
gcloud config set project <PROJECT_ID>
gcloud auth application-default login
```

Enable the Compute Engine and Cloud TPU APIs using the following commands:

```bash
gcloud services enable compute.googleapis.com
gcloud services enable tpu.googleapis.com
```
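
If you want to double-check that both APIs are active, you can list the enabled services (a quick sanity check):

```bash
gcloud services list --enabled | grep -E "compute|tpu"
```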


## Spin up a TPU VM on Google Cloud

We will use [Cloud TPU v5e](https://cloud.google.com/tpu/docs/v5e-training), Google Cloud's latest-generation AI accelerator, and set up a single-host TPU (v5e) VM to train the model.

You can read more about single-host (up to 8 chips) and multi-host (more than 8 chips) TPU VMs in the [Google Cloud TPU configurations](https://cloud.google.com/tpu/docs/supported-tpu-configurations) documentation.

Note: The steps to run this example differ for multi-host TPU VMs, which require [SAX](https://github.com/google/saxml) for multi-host training and inference.

To [set up a TPU VM](https://cloud.google.com/tpu/docs/setup-gcp-account#set-up-env), follow the steps below:

<!-- TODO: Update this script to directly use the Hugging Face PyTorch TPU DLC -->

```bash
gcloud alpha compute tpus tpu-vm create dev-tpu-vm \
  --zone=us-west4-a \
  --accelerator-type=v5litepod-8 \
  --version=v2-alpha-tpuv5-lite
```

After some time, the TPU VM will be created. You can see the list of TPU VMs in the [Google Cloud console](https://console.cloud.google.com/compute/tpus).
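
You can also verify the VM from your terminal, for example:

```bash
gcloud compute tpus tpu-vm list --zone=us-west4-a
```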


## Set up the environment

Once the Cloud TPU VM is up and running, you can SSH into it using the following command:

```bash
gcloud alpha compute tpus tpu-vm ssh dev-tpu-vm --zone=us-west4-a
```

<!-- TODO: Update the link to the Dockerfile and remove the part where docker image needs to be build once DLCs are released-->
You now need to build the environment using the Hugging Face PyTorch TPU DLC [Dockerfile](https://github.com/huggingface/Google-Cloud-Containers/blob/feature/pytorch-tpu-container/containers/pytorch/training/tpu/2.3/transformers/4.38.1/py310/Dockerfile). You can use the following commands to build it:

```bash
git clone https://github.com/huggingface/Google-Cloud-Containers.git
cd Google-Cloud-Containers
sudo docker build -t huggingface-pytorch-training-tpu-2.3.transformers.4.38.1.py310:latest -f containers/pytorch/training/tpu/2.3/transformers/4.38.1/py310/Dockerfile .
```

## Train the model

Once the Docker image is built, run the container to activate the environment. You can use the following command:

```bash
sudo docker run -it -v $(pwd):/workspace --privileged huggingface-pytorch-training-tpu-2.3.transformers.4.38.1.py310:latest bash
```
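
Before launching training, you can sanity-check that the TPU is reachable from inside the container (a minimal check using the `torch_xla` preinstalled in the image; it should print the XLA device):

```bash
PJRT_DEVICE=TPU python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"
```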

Now you can train the model. `PJRT_DEVICE=TPU` selects the PJRT TPU runtime, `XLA_USE_BF16=1` runs float32 operations in bfloat16 on the TPU, and `XLA_USE_SPMD=1` enables the SPMD execution mode used by FSDPv2; `HF_TOKEN` is required because [gemma-2b](https://huggingface.co/google/gemma-2b) is a gated model. Run the following commands:

```bash
export PJRT_DEVICE=TPU XLA_USE_BF16=1 XLA_USE_SPMD=1
export HF_TOKEN=<YOUR_HF_TOKEN>
cd /workspace
python examples/google-cloud-tpu-vm/causal-language-modeling/finetune-gemma-lora-dolly.py \
  --num_epochs 3 \
  --train_batch_size 16 \
  --lr 3e-4
```
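
Checkpoints are written to the `output` directory every `--save_steps` steps (100 by default, see the script's arguments). Because the Google Cloud SDK is installed in the image, you could then copy the results to a Cloud Storage bucket, assuming the VM's service account can write to it (`<YOUR_BUCKET>` is a placeholder):

```bash
gsutil cp -r output gs://<YOUR_BUCKET>/gemma-2b-dolly
```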
examples/google-cloud-tpu-vm/causal-language-modeling/finetune-gemma-lora-dolly.py
@@ -0,0 +1,106 @@
import argparse

import torch
import torch_xla
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig, TaskType
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    TrainingArguments,
)
from trl import SFTTrainer


def train_gemma(args):
    raw_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
    model_id = "google/gemma-2b"

    def format_dolly(sample):
        instruction = f"### Instruction\n{sample['instruction']}"
        context = (
            f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
        )
        response = f"### Answer\n{sample['response']}"
        # join all the parts together
        prompt = "\n\n".join(
            [i for i in [instruction, context, response] if i is not None]
        )
        sample["text"] = prompt
        return sample

    # apply prompt template
    format_dataset = raw_dataset.map(
        format_dolly, remove_columns=list(raw_dataset.features)
    )
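    # Each sample's "text" field now looks like the following
    # (the Context block is omitted when the sample has no context):
    #
    #   ### Instruction
    #   <instruction>
    #
    #   ### Context
    #   <context>
    #
    #   ### Answer
    #   <response>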

    # Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token

    device = xm.xla_device()

    # Load model
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
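    # LoRA: train small rank-16 adapters on the attention query/value
    # projections instead of the full model; lora_alpha / r = 2 scales the
    # adapter updates relative to the base weights.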
    lora_config = LoraConfig(
        r=16,
        target_modules=["q_proj", "v_proj"],
        task_type=TaskType.CAUSAL_LM,
        lora_alpha=32,
        lora_dropout=0.05,
    )

    # Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
    fsdp_config = {
        "fsdp_transformer_layer_cls_to_wrap": [
            "GemmaDecoderLayer"  # Specify the layer to wrap according to the model's config
        ],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    }

    # Define training arguments
    training_args = TrainingArguments(
        output_dir="output",
        per_device_train_batch_size=args.train_batch_size,
        learning_rate=args.lr,
        num_train_epochs=args.num_epochs,
        logging_strategy="steps",
        logging_steps=20,
        save_steps=args.save_steps,
        dataloader_drop_last=True,  # Required for SPMD.
        fsdp="full_shard",
        fsdp_config=fsdp_config,
    )

    # Initialize our Trainer
    trainer = SFTTrainer(
        model=model,
        peft_config=lora_config,
        args=training_args,
        dataset_text_field="text",
        packing=True,
        train_dataset=format_dataset,
        tokenizer=tokenizer,
        data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
    )
    # Train the model
    trainer.train()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_epochs", default=3, type=int)
    parser.add_argument("--train_batch_size", default=16, type=int)
    parser.add_argument("--lr", default=3e-4, type=float)
    parser.add_argument("--save_steps", default=100, type=int)
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    train_gemma(args)