Feature/pytorch tpu container #14
containers/pytorch/training/tpu/2.3/transformers/4.39.0.dev0/py310/Dockerfile
@@ -0,0 +1,73 @@
FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20240229
# The nightly image with PyTorch 2.3 and Python 3.10
# Read more about it here: https://github.com/pytorch/xla?tab=readme-ov-file#docker

LABEL maintainer="Hugging Face"
ARG DEBIAN_FRONTEND=noninteractive

# Versions
ARG TRANSFORMERS='4.39.0'
ARG DIFFUSERS='0.27.1'
ARG PEFT='0.9.0'
ARG TRL='0.8.1'
ARG DATASETS='2.18.0'
ARG ACCELERATE='0.27.2'
ARG EVALUATE='0.4.1'
ARG NOTEBOOK='7.1.1'

RUN apt-get update \
    && apt-get install -y \
    bzip2 \
    curl \
    git \
    git-lfs \
    tar \
    gcc \
    g++ \
    libaio-dev \
    # audio
    libsndfile1-dev \
    ffmpeg \
    apt-transport-https \
    gnupg \
    ca-certificates \
    && apt-get clean autoremove --yes

# Update pip
RUN pip install --upgrade pip

# Install Hugging Face Libraries
RUN pip install --upgrade --no-cache-dir \
    transformers[sklearn,sentencepiece,vision]==${TRANSFORMERS} \
    diffusers==${DIFFUSERS} \
    datasets==${DATASETS} \
    accelerate==${ACCELERATE} \
    evaluate==${EVALUATE} \
    peft==${PEFT} \
    trl==${TRL} \
    notebook==${NOTEBOOK}

# Fix for saving checkpoints on TPU when using FSDPv2
# Can be removed once https://github.com/huggingface/transformers/pull/29780 is merged and a new release is made
RUN pip install --upgrade git+https://github.com/shub-kris/transformers.git@fix/checkpointing-on-tpu-with-fsdpv2

# Install Google Cloud Dependencies
RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" \
    | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
    curl https://packages.cloud.google.com/apt/doc/apt-key.gpg \
    | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - && \
    apt-get update -y && \
    apt-get install google-cloud-sdk -y

RUN pip install --upgrade --no-cache-dir \
    google-cloud-storage \
    google-cloud-bigquery \
    google-cloud-aiplatform \
    google-cloud-pubsub \
    google-cloud-logging

# Check if correct versions are installed
RUN python -c "import transformers, diffusers, datasets, accelerate, evaluate, peft, trl, torch; \
    assert all([mod.__version__ == version for mod, version in [(diffusers, '${DIFFUSERS}'), \
    (datasets, '${DATASETS}'), (accelerate, '${ACCELERATE}'), (evaluate, '${EVALUATE}'), (peft, '${PEFT}'), (trl, '${TRL}'), \
    (torch, '2.3.0')]])"

@@ -0,0 +1,104 @@
# Finetune Gemma-7B using Hugging Face PyTorch TPU DLC on Google Cloud TPU (v3-8)

This example demonstrates how to finetune [gemma-7b](https://huggingface.co/google/gemma-7b) using Hugging Face's DLCs on a Google Cloud single-host TPU (v3-8) VM. We use the [transformers](https://huggingface.co/docs/transformers/), [TRL](https://huggingface.co/docs/trl/en/index), and [PEFT](https://huggingface.co/docs/peft/index) libraries to fine-tune the model. The dataset used for this example is the [Dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) dataset, which can be easily accessed from Hugging Face's [Datasets](https://huggingface.co/datasets) Hub.
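
Each Dolly-15k record provides `instruction`, `context`, and `response` fields, which the training script below joins into a single text prompt. As a minimal sketch (assuming the `datasets` library is installed locally), you can inspect a sample like this:

```python
from datasets import load_dataset

# Load the Dolly-15k instruction-tuning dataset from the Hugging Face Hub
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

# Each record has "instruction", "context", and "response" fields,
# which the training script joins into a single "text" prompt
print(dataset[0])
```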

## What are TPUs?

Google Cloud TPUs are custom-designed AI accelerators, optimized for training and inference of large AI models. They are ideal for a variety of use cases, such as chatbots, code generation, media content generation, synthetic speech, vision services, recommendation engines, and personalization models, among others.

Advantages of using TPUs include:

- Designed to scale cost-efficiently for a wide range of AI workloads, spanning training, fine-tuning, and inference.
- Optimized for TensorFlow, PyTorch, and JAX, and available in a variety of form factors, including edge devices, workstations, and cloud-based infrastructure.
- Available in [Google Cloud](https://cloud.google.com/tpu/docs/intro-to-tpu), and integrated with [Vertex AI](https://cloud.google.com/vertex-ai/docs/training/training-with-tpu-vm) and [Google Kubernetes Engine (GKE)](https://cloud.google.com/tpu?hl=en#cloud-tpu-in-gke).

## Before you begin

Make sure you have the following:

- A Google Cloud project with billing enabled.
<!-- - Access to Hugging Face's PyTorch TPU DLC. -->
- The [Google Cloud CLI](https://cloud.google.com/sdk/docs/install#linux) installed on your local machine.

To install the Google Cloud CLI, you can use the following commands:

```bash
curl https://sdk.cloud.google.com | bash
exec zsh -l
gcloud init
```

You can configure your Google Cloud project using the following commands:

```bash
gcloud auth login
gcloud config set project <PROJECT_ID>
gcloud auth application-default login
```

Enable the Compute Engine and Cloud TPU APIs using the following commands:

```bash
gcloud services enable compute.googleapis.com
gcloud services enable tpu.googleapis.com
```

## Spin up a TPU VM on Google Cloud

We will be using [Cloud TPU v3-8](https://cloud.google.com/tpu/docs/v3), Google Cloud's third-generation AI accelerator, and will set up a single-host TPU (v3-8) VM to train the model.

You can read more about single-host (8 chips) and multi-host (> 8 chips) TPU VMs in the [Google Cloud TPU configurations](https://cloud.google.com/tpu/docs/supported-tpu-configurations) documentation.

Note: The steps to run this example would differ for multi-host TPU VMs. One would need to use [SAX](https://github.com/google/saxml) for multi-host training and multi-host inference.

To [set up a TPU VM](https://cloud.google.com/tpu/docs/setup-gcp-account#set-up-env), follow the steps below:

<!-- TODO: Update this script to directly use the Hugging Face PyTorch TPU DLC -->

```bash
gcloud compute tpus tpu-vm create dev-tpu-vm \
  --zone=us-west4-a \
  --accelerator-type=v3-8 \
  --version tpu-ubuntu2204-base
```

After some time, the TPU VM will be created. You can see the list of your TPU VMs in the [Google Cloud console](https://console.cloud.google.com/compute/tpus).

## Set up the environment

Once the Cloud TPU VM is up and running, you can SSH into the VM using the following command:

```bash
gcloud compute tpus tpu-vm ssh dev-tpu-vm --zone=us-west4-a
```

<!-- TODO: Update the link to the Dockerfile and remove the part where the docker image needs to be built once DLCs are released -->
You now need to build the environment using Hugging Face's PyTorch TPU DLC [Dockerfile](https://github.com/huggingface/Google-Cloud-Containers/blob/main/containers/pytorch/training/tpu/2.3/transformers/4.39.0.dev0/py310/Dockerfile). You can use the following commands to build the image:

```bash
git clone https://github.com/huggingface/Google-Cloud-Containers.git
cd Google-Cloud-Containers
sudo docker build -t huggingface-pytorch-training-tpu-2.3.transformers.4.39.0.dev0.py310:latest -f containers/pytorch/training/tpu/2.3/transformers/4.39.0.dev0/py310/Dockerfile .
```

## Train the model

Once the Docker image is built, we need to run a Docker container in order to activate the environment. You can use the following command to run the container:

```bash
sudo docker run -it -v $(pwd):/workspace --privileged huggingface-pytorch-training-tpu-2.3.transformers.4.39.0.dev0.py310:latest bash
```
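
Before launching training, you can optionally check that the TPU cores are visible from inside the container. This is a minimal sketch, assuming it is run inside the container started above (the same `PJRT_DEVICE=TPU` variable is also exported in the training step below):

```python
import os

# Point the PJRT runtime at the TPU before importing torch_xla
os.environ.setdefault("PJRT_DEVICE", "TPU")

import torch_xla.core.xla_model as xm

# List the XLA devices backed by the host's TPU cores and print the default device
print(xm.get_xla_supported_devices())
print(xm.xla_device())
```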

Now, you can run the following commands to train the model:

```bash
export PJRT_DEVICE=TPU XLA_USE_BF16=1 XLA_USE_SPMD=1
export HF_TOKEN=<YOUR_HF_TOKEN>
cd /workspace
python examples/tpu-examples/causal-language-modeling/finetune-gemma-lora-dolly.py \
  --num_epochs 3 \
  --train_batch_size 32 \
  --lr 3e-4
```
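
The training script runs a short inference pass at the end of training to sanity-check the outputs (see the `inference` function in the script below). If you want to spot-check the saved adapter separately, here is a minimal sketch, assuming the default `google/gemma-7b` base model and the `output` directory that the script saves to:

```python
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-7b"  # the default --model_id used by the training script
tokenizer = AutoTokenizer.from_pretrained(model_id)
base_model = AutoModelForCausalLM.from_pretrained(model_id)

# Attach the fine-tuned LoRA adapter produced by trainer.save_model()
model = PeftModel.from_pretrained(base_model, "output")

# Generate a short answer on CPU to verify the model is not producing garbage
prompt = "### Instruction\nWhy can camels survive for long without water?"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```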

Review comments:

- Can we include some steps to test or evaluate the results? To make sure the model is really learning and not generating garbage?
- Yes, I will do that. Someone else and I encountered a similar error, which I am working on. Once that is fixed (it also has to do with evaluation and saving), I will add an evaluation too.
- Added a small inference function to verify that the model works as expected after being trained. Tested the trained model, and the inference generated:
- I also did it for the
- Also, updated the README and changed TPU to

examples/tpu-examples/causal-language-modeling/finetune-gemma-lora-dolly.py
@@ -0,0 +1,124 @@
import argparse

import torch
import torch_xla
from datasets import load_dataset
from peft import LoraConfig, PeftModel, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import SFTTrainer


def inference(model, tokenizer):
    prompts = [
        "Why can camels survive for long without water?",
        "Are the following items candy bars or gum: trident, Twix, hubba bubba, snickers, three musketeers, and wrigleys.",
    ]

    for prompt in prompts:
        text = f"### Instruction\n {prompt}"
        device = "cpu"
        inputs = tokenizer(text, return_tensors="pt").to(device)
        outputs = model.generate(
            **inputs, max_new_tokens=50
        )  # model.generate only supported on GPU and CPU
        print(tokenizer.decode(outputs[0], skip_special_tokens=True))
        print("\n\n")


def train_gemma(args):
    raw_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
    model_id = args.model_id

    def format_dolly(sample):
        instruction = f"### Instruction\n{sample['instruction']}"
        context = (
            f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
        )
        response = f"### Answer\n{sample['response']}"
        # join all the parts together
        prompt = "\n\n".join(
            [i for i in [instruction, context, response] if i is not None]
        )
        sample["text"] = prompt
        return sample

    # apply prompt template
    format_dataset = raw_dataset.map(
        format_dolly, remove_columns=list(raw_dataset.features)
    )

    # Load Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.padding_side = "right"  # to prevent warnings

    # Load model
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    lora_config = LoraConfig(
        r=8,
        target_modules="all-linear",
        task_type=TaskType.CAUSAL_LM,
        lora_alpha=16,
        lora_dropout=0.05,
    )

    # Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
    fsdp_config = {
        "fsdp_transformer_layer_cls_to_wrap": [
            "GemmaDecoderLayer"  # Specify the layer to wrap according to the model's config
        ],
        "xla": True,
        "xla_fsdp_v2": True,
        "xla_fsdp_grad_ckpt": True,
    }

    # Define training arguments
    training_args = TrainingArguments(
        output_dir="output",
        per_device_train_batch_size=args.train_batch_size,
        learning_rate=args.lr,
        num_train_epochs=args.num_epochs,
        logging_strategy="steps",
        logging_steps=args.logging_steps,
        save_steps=args.save_steps,
        optim="adafactor",
        bf16=True,
        dataloader_drop_last=True,  # Required for SPMD.
        fsdp="full_shard",
        fsdp_config=fsdp_config,
    )

    # Initialize our Trainer
    trainer = SFTTrainer(
        model=model,
        peft_config=lora_config,
        args=training_args,
        dataset_text_field="text",
        packing=True,
        train_dataset=format_dataset,
        tokenizer=tokenizer,
    )
    # Train the model
    trainer.train()
    trainer.save_model()

    # Inference
    model = AutoModelForCausalLM.from_pretrained(model_id)
    trained_peft_model = PeftModel.from_pretrained(model, "output")
    inference(trained_peft_model, tokenizer)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", default="google/gemma-7b", type=str)
    parser.add_argument("--num_epochs", default=3, type=int)
    parser.add_argument("--train_batch_size", default=32, type=int)
    parser.add_argument("--lr", default=3e-4, type=float)
    parser.add_argument("--save_steps", default=500, type=int)
    parser.add_argument("--logging_steps", default=1, type=int)
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    train_gemma(args)

Review comments:

- Do we know when there is the next official XLA release?
- Somewhere early next month.
- Will update the Dockerfile once there is a new release; for now, I think we can merge the PR.