diff --git a/containers/pytorch/training/tpu/2.3/transformers/4.39.0.dev0/py310/Dockerfile b/containers/pytorch/training/tpu/2.3/transformers/4.39.0.dev0/py310/Dockerfile
new file mode 100644
index 00000000..81f92c3a
--- /dev/null
+++ b/containers/pytorch/training/tpu/2.3/transformers/4.39.0.dev0/py310/Dockerfile
@@ -0,0 +1,73 @@
+FROM us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_20240229
+# The nightly image with PyTorch 2.3 and Python 3.10
+# Read more about it here: https://github.com/pytorch/xla?tab=readme-ov-file#docker
+
+LABEL maintainer="Hugging Face"
+ARG DEBIAN_FRONTEND=noninteractive
+
+# Versions
+ARG TRANSFORMERS='4.39.0'
+ARG DIFFUSERS='0.27.1'
+ARG PEFT='0.9.0'
+ARG TRL='0.8.1'
+ARG DATASETS='2.18.0'
+ARG ACCELERATE='0.27.2'
+ARG EVALUATE='0.4.1'
+ARG NOTEBOOK='7.1.1'
+
+RUN apt-get update \
+    && apt-get install -y \
+    bzip2 \
+    curl \
+    git \
+    git-lfs \
+    tar \
+    gcc \
+    g++ \
+    libaio-dev \
+    # audio
+    libsndfile1-dev \
+    ffmpeg \
+    apt-transport-https \
+    gnupg \
+    ca-certificates \
+    && apt-get clean autoremove --yes
+
+# Update pip
+RUN pip install --upgrade pip
+
+# Install Hugging Face Libraries
+RUN pip install --upgrade --no-cache-dir \
+    transformers[sklearn,sentencepiece,vision]==${TRANSFORMERS} \
+    diffusers==${DIFFUSERS} \
+    datasets==${DATASETS} \
+    accelerate==${ACCELERATE} \
+    evaluate==${EVALUATE} \
+    peft==${PEFT} \
+    trl==${TRL} \
+    notebook==${NOTEBOOK}
+
+# Fix for saving checkpoints on TPU when using FSDPv2;
+# can be removed once https://github.com/huggingface/transformers/pull/29780 is merged and a new release is made
+RUN pip install --upgrade git+https://github.com/shub-kris/transformers.git@fix/checkpointing-on-tpu-with-fsdpv2
+
+# Install Google Cloud Dependencies
+RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" \
+    | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list && \
+    curl https://packages.cloud.google.com/apt/doc/apt-key.gpg \
+    | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - && \
+    apt-get update -y && \
+    apt-get install google-cloud-sdk -y
+
+RUN pip install --upgrade --no-cache-dir \
+    google-cloud-storage \
+    google-cloud-bigquery \
+    google-cloud-aiplatform \
+    google-cloud-pubsub \
+    google-cloud-logging
+
+# Check that the correct versions are installed
+RUN python -c "import transformers, diffusers, datasets, accelerate, evaluate, peft, trl, torch; \
+    assert all([mod.__version__ == version for mod, version in [(diffusers, '${DIFFUSERS}'), \
+    (datasets, '${DATASETS}'), (accelerate, '${ACCELERATE}'), (evaluate, '${EVALUATE}'), (peft, '${PEFT}'), (trl, '${TRL}'), \
+    (torch, '2.3.0')]])"
\ No newline at end of file
diff --git a/examples/tpu-examples/causal-language-modeling/README.md b/examples/tpu-examples/causal-language-modeling/README.md
new file mode 100644
index 00000000..003c2fe0
--- /dev/null
+++ b/examples/tpu-examples/causal-language-modeling/README.md
@@ -0,0 +1,104 @@
+# Finetune Gemma-7B using Hugging Face PyTorch TPU DLC on Google Cloud TPU (v3-8)
+
+This example demonstrates how to finetune [gemma-7b](https://huggingface.co/google/gemma-7b) using Hugging Face's DLCs on a Google Cloud single-host TPU (v3-8) VM. We use the [transformers](https://huggingface.co/docs/transformers/), [TRL](https://huggingface.co/docs/trl/en/index), and [PEFT](https://huggingface.co/docs/peft/index) libraries for fine-tuning. The dataset used for this example is [Dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k), which can be easily accessed from Hugging Face's [Datasets](https://huggingface.co/datasets) Hub.
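+
+Each Dolly-15k record provides `instruction`, `context`, and `response` fields (plus a `category` label), which the training script below assembles into a single prompt. As a quick, purely illustrative sanity check (not part of the training script), you can inspect a record locally:
+
+```python
+from datasets import load_dataset
+
+# Peek at one Dolly-15k record to see the fields used by the prompt template
+dolly = load_dataset("databricks/databricks-dolly-15k", split="train")
+print(dolly[0]["instruction"])
+print(dolly[0]["response"])
+```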
+
+## What are TPUs?
+
+Google Cloud TPUs are custom-designed AI accelerators, optimized for training and inference of large AI models. They are ideal for a variety of use cases, such as chatbots, code generation, media content generation, synthetic speech, vision services, recommendation engines, and personalization models, among others.
+
+Advantages of using TPUs include:
+
+- Designed to scale cost-efficiently for a wide range of AI workloads, spanning training, fine-tuning, and inference.
+- Optimized for TensorFlow, PyTorch, and JAX, and available in a variety of form factors, including edge devices, workstations, and cloud-based infrastructure.
+- Available in [Google Cloud](https://cloud.google.com/tpu/docs/intro-to-tpu), and integrated with [Vertex AI](https://cloud.google.com/vertex-ai/docs/training/training-with-tpu-vm) and [Google Kubernetes Engine (GKE)](https://cloud.google.com/tpu?hl=en#cloud-tpu-in-gke).
+
+## Before you begin
+
+Make sure you have the following:
+
+- A Google Cloud project with billing enabled.
+- [Google Cloud CLI](https://cloud.google.com/sdk/docs/install#linux) installed on your local machine.
+
+To install the Google Cloud CLI, you can use the following commands:
+
+```bash
+curl https://sdk.cloud.google.com | bash
+exec zsh -l
+gcloud init
+```
+
+You can configure your Google Cloud project using the following commands:
+
+```bash
+gcloud auth login
+gcloud config set project <PROJECT_ID>
+gcloud auth application-default login
+```
+
+Enable the Compute Engine and Cloud TPU APIs using the following commands:
+
+```bash
+gcloud services enable compute.googleapis.com
+gcloud services enable tpu.googleapis.com
+```
+
+## Spin up a TPU VM on Google Cloud
+
+We will be using [Cloud TPU v3-8](https://cloud.google.com/tpu/docs/v3), Google Cloud's third-generation AI accelerator, and will set up a single-host TPU (v3-8) VM to train the model.
+
+You can read more about single-host (8 chips) and multi-host (more than 8 chips) TPU VMs in [Google Cloud TPU configurations](https://cloud.google.com/tpu/docs/supported-tpu-configurations).
+
+Note: The steps to run this example differ for multi-host TPU VMs; you would need to use [SAX](https://github.com/google/saxml) for multi-host training and inference.
+
+To [set up a TPU VM](https://cloud.google.com/tpu/docs/setup-gcp-account#set-up-env), follow the steps below:
+
+```bash
+gcloud compute tpus tpu-vm create dev-tpu-vm \
+--zone=us-west4-a \
+--accelerator-type=v3-8 \
+--version tpu-ubuntu2204-base
+```
+
+After some time, the TPU VM will be created. You can see the list of TPU VMs in the [Google Cloud console](https://console.cloud.google.com/compute/tpus).
+
+## Set up the environment
+
+Once the Cloud TPU VM is up and running, you can SSH into it using the following command:
+
+```bash
+gcloud compute tpus tpu-vm ssh dev-tpu-vm --zone=us-west4-a
+```
+
+You now need to build the environment using Hugging Face's PyTorch TPU DLC [Dockerfile](https://github.com/huggingface/Google-Cloud-Containers/blob/main/containers/pytorch/training/tpu/2.3/transformers/4.39.0.dev0/py310/Dockerfile).
+You can use the following commands to build the image:
+
+```bash
+git clone https://github.com/huggingface/Google-Cloud-Containers.git
+cd Google-Cloud-Containers
+sudo docker build -t huggingface-pytorch-training-tpu-2.3.transformers.4.39.0.dev0.py310:latest -f containers/pytorch/training/tpu/2.3/transformers/4.39.0.dev0/py310/Dockerfile .
+```
+
+## Train the model
+
+Once the Docker image is built, run the Docker container to activate the environment:
+
+```bash
+sudo docker run -it -v $(pwd):/workspace --privileged huggingface-pytorch-training-tpu-2.3.transformers.4.39.0.dev0.py310:latest bash
+```
+
+Now, you can run the following commands inside the container to train the model:
+
+```bash
+export PJRT_DEVICE=TPU XLA_USE_BF16=1 XLA_USE_SPMD=1
+export HF_TOKEN=<YOUR_HF_TOKEN>
+cd /workspace
+python examples/tpu-examples/causal-language-modeling/finetune-gemma-lora-dolly.py \
+--num_epochs 3 \
+--train_batch_size 32 \
+--lr 3e-4
+```
\ No newline at end of file
diff --git a/examples/tpu-examples/causal-language-modeling/finetune-gemma-lora-dolly.py b/examples/tpu-examples/causal-language-modeling/finetune-gemma-lora-dolly.py
new file mode 100644
index 00000000..f979deff
--- /dev/null
+++ b/examples/tpu-examples/causal-language-modeling/finetune-gemma-lora-dolly.py
@@ -0,0 +1,124 @@
+import argparse
+
+import torch
+import torch_xla
+from datasets import load_dataset
+from peft import LoraConfig, PeftModel, TaskType
+from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
+from trl import SFTTrainer
+
+
+def inference(model, tokenizer):
+    prompts = [
+        "Why can camels survive for long without water?",
+        "Are the following items candy bars or gum: trident, Twix, hubba bubba, snickers, three musketeers, and wrigleys.",
+    ]
+
+    for prompt in prompts:
+        text = f"### Instruction\n {prompt}"
+        device = "cpu"
+        inputs = tokenizer(text, return_tensors="pt").to(device)
+        outputs = model.generate(
+            **inputs, max_new_tokens=50
+        )  # model.generate only supported on GPU and CPU
+        print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+        print("\n\n")
+
+
+def train_gemma(args):
+    raw_dataset = load_dataset("databricks/databricks-dolly-15k", split="train")
+    model_id = args.model_id
+
+    def format_dolly(sample):
+        instruction = f"### Instruction\n{sample['instruction']}"
+        context = (
+            f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
+        )
+        response = f"### Answer\n{sample['response']}"
+        # join all the parts together
+        prompt = "\n\n".join(
+            [i for i in [instruction, context, response] if i is not None]
+        )
+        sample["text"] = prompt
+        return sample
+
+    # apply prompt template
+    format_dataset = raw_dataset.map(
+        format_dolly, remove_columns=list(raw_dataset.features)
+    )
+
+    # Load Tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.padding_side = "right"  # to prevent warnings
+
+    # Load model
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+    lora_config = LoraConfig(
+        r=8,
+        target_modules="all-linear",
+        task_type=TaskType.CAUSAL_LM,
+        lora_alpha=16,
+        lora_dropout=0.05,
+    )
+
+    # Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
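+    # Meaning of the keys below (as understood from the Trainer's XLA FSDP support):
+    #   - fsdp_transformer_layer_cls_to_wrap: the decoder block class to shard (GemmaDecoderLayer for Gemma).
+    #   - xla / xla_fsdp_v2: run FSDP through PyTorch/XLA using the SPMD-based FSDPv2 implementation.
+    #   - xla_fsdp_grad_ckpt: apply gradient checkpointing to the wrapped layers to reduce TPU memory usage.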
+    fsdp_config = {
+        "fsdp_transformer_layer_cls_to_wrap": [
+            "GemmaDecoderLayer"  # Specify the layer to wrap according to the model's config
+        ],
+        "xla": True,
+        "xla_fsdp_v2": True,
+        "xla_fsdp_grad_ckpt": True,
+    }
+
+    # Define training arguments
+    training_args = TrainingArguments(
+        output_dir="output",
+        per_device_train_batch_size=args.train_batch_size,
+        learning_rate=args.lr,
+        num_train_epochs=args.num_epochs,
+        logging_strategy="steps",
+        logging_steps=args.logging_steps,
+        save_steps=args.save_steps,
+        optim="adafactor",
+        bf16=True,
+        dataloader_drop_last=True,  # Required for SPMD.
+        fsdp="full_shard",
+        fsdp_config=fsdp_config,
+    )
+
+    # Initialize our Trainer
+    trainer = SFTTrainer(
+        model=model,
+        peft_config=lora_config,
+        args=training_args,
+        dataset_text_field="text",
+        packing=True,
+        train_dataset=format_dataset,
+        tokenizer=tokenizer,
+    )
+    # Train the model
+    trainer.train()
+    trainer.save_model()
+
+    # Inference
+    model = AutoModelForCausalLM.from_pretrained(model_id)
+    trained_peft_model = PeftModel.from_pretrained(model, "output")
+    inference(trained_peft_model, tokenizer)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_id", default="google/gemma-7b", type=str)
+    parser.add_argument("--num_epochs", default=3, type=int)
+    parser.add_argument("--train_batch_size", default=32, type=int)
+    parser.add_argument("--lr", default=3e-4, type=float)
+    parser.add_argument("--save_steps", default=500, type=int)
+    parser.add_argument("--logging_steps", default=1, type=int)
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    train_gemma(args)
diff --git a/examples/tpu-examples/causal-language-modeling/finetune-llama2-lora-guanaco.py b/examples/tpu-examples/causal-language-modeling/finetune-llama2-lora-guanaco.py
new file mode 100644
index 00000000..b81ab03e
--- /dev/null
+++ b/examples/tpu-examples/causal-language-modeling/finetune-llama2-lora-guanaco.py
@@ -0,0 +1,105 @@
+import argparse
+
+import torch
+import torch_xla
+from datasets import load_dataset
+from peft import LoraConfig, PeftModel, TaskType
+from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
+from trl import SFTTrainer
+
+
+def inference(model, tokenizer):
+    prompts = [
+        "### Human: From now on, you will act as a nutritionist. I will ask questions about nutrition and you will reply with an explanation on how I can apply it to my daily basis. My first request: What is the main benefit of doing intermittent fastening regularly?",
+        "### Human: Was kannst Du im Vergleich zu anderen Large Language Models?",
+    ]
+
+    for prompt in prompts:
+        device = "cpu"
+        inputs = tokenizer(prompt, return_tensors="pt").to(device)
+        outputs = model.generate(
+            **inputs, max_new_tokens=50
+        )  # model.generate only supported on GPU and CPU
+        print(tokenizer.decode(outputs[0], skip_special_tokens=True))
+        print("\n\n")
+
+
+def train_llama(args):
+    raw_dataset = load_dataset("timdettmers/openassistant-guanaco", split="train")
+    model_id = args.model_id
+
+    # Load Tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    tokenizer.pad_token = (
+        tokenizer.eos_token
+    )  # Set the padding token to be the same as the end-of-sequence token
+
+    # Load model
+    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16)
+    lora_config = LoraConfig(
+        r=8,
+        task_type=TaskType.CAUSAL_LM,
+        lora_alpha=16,
+        lora_dropout=0.05,
+    )
+
+    # Set up the FSDP config. To enable FSDP via SPMD, set xla_fsdp_v2 to True.
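+    # The FSDP config mirrors the one in finetune-gemma-lora-dolly.py; only the wrapped decoder block class (LlamaDecoderLayer) differs.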
+    fsdp_config = {
+        "fsdp_transformer_layer_cls_to_wrap": [
+            "LlamaDecoderLayer"  # Specify the layer to wrap according to the model's config
+        ],
+        "xla": True,
+        "xla_fsdp_v2": True,
+        "xla_fsdp_grad_ckpt": True,
+    }
+
+    # Define training arguments
+    training_args = TrainingArguments(
+        output_dir="output",
+        per_device_train_batch_size=args.train_batch_size,
+        learning_rate=args.lr,
+        num_train_epochs=args.num_epochs,
+        logging_strategy="steps",
+        logging_steps=args.logging_steps,
+        save_steps=args.save_steps,
+        bf16=True,
+        dataloader_drop_last=True,  # Required for SPMD.
+        fsdp="full_shard",
+        fsdp_config=fsdp_config,
+    )
+
+    # Initialize our Trainer
+    trainer = SFTTrainer(
+        model=model,
+        peft_config=lora_config,
+        args=training_args,
+        dataset_text_field="text",
+        packing=True,
+        train_dataset=raw_dataset,
+        tokenizer=tokenizer,
+    )
+    # Train the model
+    trainer.train()
+    trainer.save_model()
+
+    # Inference
+    model = AutoModelForCausalLM.from_pretrained(model_id)
+    trained_peft_model = PeftModel.from_pretrained(model, "output")
+    inference(trained_peft_model, tokenizer)
+
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--model_id", default="meta-llama/Llama-2-7b-hf", type=str)
+    parser.add_argument("--num_epochs", default=3, type=int)
+    parser.add_argument("--train_batch_size", default=32, type=int)
+    parser.add_argument("--lr", default=2e-4, type=float)
+    parser.add_argument("--save_steps", default=500, type=int)
+    parser.add_argument("--logging_steps", default=1, type=int)
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    train_llama(args)
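+
+
+# Example invocation (hypothetical; mirrors the Gemma command in the README and uses this
+# script's own defaults), run from inside the training container on the TPU VM:
+#   export PJRT_DEVICE=TPU XLA_USE_BF16=1 XLA_USE_SPMD=1
+#   export HF_TOKEN=<YOUR_HF_TOKEN>  # meta-llama/Llama-2-7b-hf is a gated model
+#   python examples/tpu-examples/causal-language-modeling/finetune-llama2-lora-guanaco.py \
+#       --num_epochs 3 --train_batch_size 32 --lr 2e-4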