From 395e8cabd7b618f3a1a8f01fd50c7ecbad3022a1 Mon Sep 17 00:00:00 2001 From: Andrey Velichkevich Date: Tue, 12 Mar 2024 11:35:10 +0000 Subject: [PATCH] Add Fine-Tune BERT LLM Example (#2021) * Add Fine-Tune BERT LLM Example Signed-off-by: Andrey Velichkevich * Add 3 GPUs in Notebook requirements Signed-off-by: Andrey Velichkevich --------- Signed-off-by: Andrey Velichkevich --- .../Train CNN with FashionMNIST.ipynb} | 0 .../create-pytorchjob.ipynb | 0 .../language-modeling}/train_api.ipynb | 0 .../Fine Tune BERT LLM.ipynb | 683 ++++++++++++++++++ .../image-classification}/create-tfjob.ipynb | 0 sdk/python/README.md | 3 +- 6 files changed, 685 insertions(+), 1 deletion(-) rename examples/{sdk/create-pytorchjob-from-func.ipynb => pytorch/image-classification/Train CNN with FashionMNIST.ipynb} (100%) rename examples/{sdk => pytorch/image-classification}/create-pytorchjob.ipynb (100%) rename examples/{sdk => pytorch/language-modeling}/train_api.ipynb (100%) create mode 100644 examples/pytorch/text-classification/Fine Tune BERT LLM.ipynb rename examples/{sdk => tensorflow/image-classification}/create-tfjob.ipynb (100%) diff --git a/examples/sdk/create-pytorchjob-from-func.ipynb b/examples/pytorch/image-classification/Train CNN with FashionMNIST.ipynb similarity index 100% rename from examples/sdk/create-pytorchjob-from-func.ipynb rename to examples/pytorch/image-classification/Train CNN with FashionMNIST.ipynb diff --git a/examples/sdk/create-pytorchjob.ipynb b/examples/pytorch/image-classification/create-pytorchjob.ipynb similarity index 100% rename from examples/sdk/create-pytorchjob.ipynb rename to examples/pytorch/image-classification/create-pytorchjob.ipynb diff --git a/examples/sdk/train_api.ipynb b/examples/pytorch/language-modeling/train_api.ipynb similarity index 100% rename from examples/sdk/train_api.ipynb rename to examples/pytorch/language-modeling/train_api.ipynb diff --git a/examples/pytorch/text-classification/Fine Tune BERT LLM.ipynb b/examples/pytorch/text-classification/Fine Tune BERT LLM.ipynb new file mode 100644 index 0000000000..bf10215ad0 --- /dev/null +++ b/examples/pytorch/text-classification/Fine Tune BERT LLM.ipynb @@ -0,0 +1,683 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Fine-Tune BERT LLM for Sentiment Analysis with Kubeflow PyTorchJob" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This Notebook will fine-tune Bidirectional Encoder Representations from Transformers (BERT) model with Yelp dataset to analyze text sentiment using distributed training with [Kubeflow PyTorchJob](https://www.kubeflow.org/docs/components/training/overview/).\n", + "\n", + "Pretrained BERT model: https://huggingface.co/google-bert/bert-base-cased\n", + "\n", + "Yelp review full dataset: https://huggingface.co/datasets/yelp_review_full\n", + "\n", + "This Notebook requires:\n", + "\n", + "- At least **3 GPU** on your Kubernetes cluster to fine-tune BERT model on 3 workers.\n", + "- AWS S3 bucket to export fine-tuned model.\n", + "\n", + "This example is based on [the HuggingFace fine-tuning tutorial](https://huggingface.co/docs/transformers/en/training)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Install required packages\n", + "\n", + "We need to install HuggingFace packages to run this Notebook." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install transformers datasets boto3\n", + "\n", + "!pip install git+https://github.com/kubeflow/training-operator.git#subdirectory=sdk/python\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "## Get samples from Yelp reviews dataset\n", + "\n", + "The Yelp reviews full star dataset is constructed by randomly taking 130,000 training samples and 10,000 testing samples for each review star from 1 to 5.\n", + "\n", + "In total there are 650,000 training samples and 50,000 testing samples.\n", + "\n", + "We are going to use this dataset to fine-tune BERT model." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-10T00:45:19.125747Z", + "iopub.status.busy": "2024-03-10T00:45:19.125051Z", + "iopub.status.idle": "2024-03-10T00:45:21.775181Z", + "shell.execute_reply": "2024-03-10T00:45:21.774143Z", + "shell.execute_reply.started": "2024-03-10T00:45:19.125725Z" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'label': 4,\n", + " 'text': \"Top notch doctor in a top notch practice. Can't say I am surprised \"\n", + " 'when I was referred to him by another doctor who I think is '\n", + " 'wonderful and because he went to one of the best medical schools in '\n", + " 'the country. \\\\nIt is really easy to get an appointment. There is '\n", + " 'minimal wait to be seen and his bedside manner is great.'}\n", + "{'label': 1,\n", + " 'text': 'Average run of the mill store. Associates are young teens and they '\n", + " \"really don't know where anything is. Luckily I am able to get \"\n", + " 'around to find everything. Found my puppy treats and moved on.'}\n" + ] + } + ], + "source": [ + "from pprint import pprint\n", + "\n", + "from datasets import load_dataset\n", + "\n", + "# Test only 100 samples in the Notebook.\n", + "dataset = load_dataset(\"yelp_review_full\", split=\"train[:100]\")\n", + "\n", + "# Print some test data.\n", + "pprint(dataset[5])\n", + "pprint(dataset[30])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create script to fine-tune BERT model\n", + "\n", + "We need to wrap our fine-tuning script in a function to create Kubeflow PyTorchJob." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-10T00:37:51.012597Z", + "iopub.status.busy": "2024-03-10T00:37:51.012357Z", + "iopub.status.idle": "2024-03-10T00:37:51.021633Z", + "shell.execute_reply": "2024-03-10T00:37:51.020711Z", + "shell.execute_reply.started": "2024-03-10T00:37:51.012581Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "def train_func(parameters):\n", + " import os\n", + "\n", + " import boto3\n", + " import evaluate\n", + " import numpy as np\n", + " from datasets import load_dataset\n", + " from datasets.distributed import split_dataset_by_node\n", + " from transformers import (\n", + " AutoModelForSequenceClassification,\n", + " AutoTokenizer,\n", + " Trainer,\n", + " TrainingArguments,\n", + " )\n", + "\n", + " # [1] Download BERT model, tokenizer, and Yelp dataset.\n", + " print(\"-\" * 40)\n", + " print(\"Download BERT Model\")\n", + " model = AutoModelForSequenceClassification.from_pretrained(\n", + " \"bert-base-cased\",\n", + " num_labels=5,\n", + " )\n", + " tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n", + "\n", + " print(\"-\" * 40)\n", + " print(\"Download Yelp Review Dataset\")\n", + "\n", + " # Use only 4000 data samples to reduce tokenization and training time.\n", + " # Training samples - 3600, test samples - 400\n", + " # Remove split to take all samples: dataset = load_dataset(\"yelp_review_full\")\n", + " dataset = load_dataset(\"yelp_review_full\", split=\"train[:4000]\")\n", + " dataset = dataset.train_test_split(test_size=0.1, stratify_by_column=\"label\")\n", + "\n", + " # [2] Preprocess dataset.\n", + " def tokenize_function(examples):\n", + " return tokenizer(examples[\"text\"], padding=\"max_length\", truncation=True)\n", + "\n", + " # Map Yelp review dataset to BERT tokenizer.\n", + " print(\"-\" * 40)\n", + " print(\"Map Yelp review dataset to BERT Tokenizer\")\n", + " tokenized_ds = dataset.map(tokenize_function, batched=True)\n", + "\n", + " # Distribute train and test datasets between PyTorch workers.\n", + " # Every worker will process chunk of training data.\n", + " # RANK and WORLD_SIZE will be set by Kubeflow Training Operator.\n", + " RANK = int(os.environ[\"RANK\"])\n", + " WORLD_SIZE = int(os.environ[\"WORLD_SIZE\"])\n", + " distributed_ds_train = split_dataset_by_node(\n", + " tokenized_ds[\"train\"],\n", + " rank=RANK,\n", + " world_size=WORLD_SIZE,\n", + " )\n", + " distributed_ds_test = split_dataset_by_node(\n", + " tokenized_ds[\"test\"],\n", + " rank=RANK,\n", + " world_size=WORLD_SIZE,\n", + " )\n", + "\n", + " # Evaluate accuracy.\n", + " metric = evaluate.load(\"accuracy\")\n", + "\n", + " def compute_metrics(eval_pred):\n", + " logits, labels = eval_pred\n", + " predictions = np.argmax(logits, axis=-1)\n", + " return metric.compute(predictions=predictions, references=labels)\n", + "\n", + " # [3] Define Training args.\n", + " training_args = TrainingArguments(\n", + " output_dir=\"test_trainer\",\n", + " evaluation_strategy=\"epoch\",\n", + " disable_tqdm=True,\n", + " log_level=\"info\",\n", + " )\n", + "\n", + " # [4] Define Trainer.\n", + " trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=distributed_ds_train,\n", + " eval_dataset=distributed_ds_test,\n", + " compute_metrics=compute_metrics,\n", + " )\n", + "\n", + " # [5] Fine-tune model.\n", + " print(\"-\" * 40)\n", + " print(f\"Start Distributed Training. RANK: {RANK} WORLD_SIZE: {WORLD_SIZE}\")\n", + "\n", + " trainer.train()\n", + "\n", + " print(\"-\" * 40)\n", + " print(\"Training is complete\")\n", + "\n", + " # [6] Export trained model to S3 from the worker with RANK = 0.\n", + " if RANK == 0:\n", + " trainer.save_model(\"./bert\")\n", + " s3 = boto3.resource(\"s3\")\n", + " bucket = s3.Bucket(parameters[\"BUCKET\"])\n", + " bucket.upload_file(\"bert/config.json\", \"bert/config.json\")\n", + " bucket.upload_file(\"bert/model.safetensors\", \"bert/model.safetensors\")\n", + "\n", + " print(\"-\" * 40)\n", + " print(\"Model is exported to S3\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Create Kubeflow PyTorchJob to fine-tune BERT on GPUs\n", + "\n", + "Use `TrainingClient()` to create PyTorchJob which will fine-tune BERT on **3 workers** using **1 GPU** for each worker.\n", + "\n", + "Your Kubernetes cluster should have sufficient **GPU** resources available." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-10T00:37:52.743447Z", + "iopub.status.busy": "2024-03-10T00:37:52.743202Z", + "iopub.status.idle": "2024-03-10T00:37:52.749400Z", + "shell.execute_reply": "2024-03-10T00:37:52.747484Z", + "shell.execute_reply.started": "2024-03-10T00:37:52.743430Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import uuid\n", + "\n", + "# Make random name for PyTorchJob\n", + "job_name = \"fine-tune-bert-\" + str(uuid.uuid4())[:5]\n", + "\n", + "# Replace `BUCKET_NAME` with your AWS S3 bucket.\n", + "bucket = \"BUCKET_NAME\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-10T00:37:54.673961Z", + "iopub.status.busy": "2024-03-10T00:37:54.673715Z", + "iopub.status.idle": "2024-03-10T00:37:54.849353Z", + "shell.execute_reply": "2024-03-10T00:37:54.847915Z", + "shell.execute_reply.started": "2024-03-10T00:37:54.673944Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from kubeflow.training import TrainingClient\n", + "\n", + "# Create PyTorchJob\n", + "TrainingClient().create_job(\n", + " name=job_name,\n", + " train_func=train_func,\n", + " parameters={\"BUCKET\": bucket},\n", + " num_workers=3, # Number of PyTorch workers to use.\n", + " resources_per_worker={\n", + " \"cpu\": \"4\",\n", + " \"memory\": \"10G\",\n", + " \"gpu\": \"1\",\n", + " },\n", + " packages_to_install=[\n", + " \"boto3\",\n", + " \"transformers\",\n", + " \"datasets\",\n", + " \"evaluate\",\n", + " \"accelerate\",\n", + " \"scikit-learn\",\n", + " ], # PIP packages will be installed during PyTorchJob runtime.\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "### Check the PyTorchJob conditions\n", + "\n", + "Use `TrainingClient()` APIs to get information about created PyTorchJob." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-10T00:37:58.701682Z", + "iopub.status.busy": "2024-03-10T00:37:58.701338Z", + "iopub.status.idle": "2024-03-10T00:37:58.747460Z", + "shell.execute_reply": "2024-03-10T00:37:58.746536Z", + "shell.execute_reply.started": "2024-03-10T00:37:58.701664Z" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PyTorchJob Conditions\n", + "[{'last_transition_time': datetime.datetime(2024, 3, 10, 0, 37, 54, tzinfo=tzlocal()),\n", + " 'last_update_time': datetime.datetime(2024, 3, 10, 0, 37, 54, tzinfo=tzlocal()),\n", + " 'message': 'PyTorchJob fine-tune-bert-1a883 is created.',\n", + " 'reason': 'PyTorchJobCreated',\n", + " 'status': 'True',\n", + " 'type': 'Created'}, {'last_transition_time': datetime.datetime(2024, 3, 10, 0, 37, 56, tzinfo=tzlocal()),\n", + " 'last_update_time': datetime.datetime(2024, 3, 10, 0, 37, 56, tzinfo=tzlocal()),\n", + " 'message': 'PyTorchJob fine-tune-bert-1a883 is running.',\n", + " 'reason': 'PyTorchJobRunning',\n", + " 'status': 'True',\n", + " 'type': 'Running'}]\n", + "----------------------------------------\n", + "PyTorchJob is running\n" + ] + } + ], + "source": [ + "print(\"PyTorchJob Conditions\")\n", + "print(TrainingClient().get_job_conditions(job_name))\n", + "print(\"-\" * 40)\n", + "\n", + "# Wait until PyTorchJob has Running condition.\n", + "job = TrainingClient().wait_for_job_conditions(\n", + " job_name,\n", + " expected_conditions={\"Running\"},\n", + ")\n", + "print(\"PyTorchJob is running\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Get the PyTorchJob pod names\n", + "\n", + "Since we set 3 workers, PyTorchJob will create 1 master pod and 2 worker pods to execute distributed training." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-10T00:38:02.257947Z", + "iopub.status.busy": "2024-03-10T00:38:02.257697Z", + "iopub.status.idle": "2024-03-10T00:38:02.307198Z", + "shell.execute_reply": "2024-03-10T00:38:02.306329Z", + "shell.execute_reply.started": "2024-03-10T00:38:02.257930Z" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['fine-tune-bert-1a883-master-0',\n", + " 'fine-tune-bert-1a883-worker-0',\n", + " 'fine-tune-bert-1a883-worker-1']" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "TrainingClient().get_job_pod_names(job_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": { + "iopub.status.busy": "2022-09-01T20:10:25.759950Z", + "iopub.status.idle": "2022-09-01T20:10:25.760581Z", + "shell.execute_reply": "2022-09-01T20:10:25.760353Z", + "shell.execute_reply.started": "2022-09-01T20:10:25.760328Z" + }, + "tags": [] + }, + "source": [ + "### Get the PyTorchJob training logs\n", + "\n", + "Every worker processes 1200 training samples on each epoch since we distribute 3600 training samples across 3 workers." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-10T00:38:05.788903Z", + "iopub.status.busy": "2024-03-10T00:38:05.788625Z", + "iopub.status.idle": "2024-03-10T00:40:25.904118Z", + "shell.execute_reply": "2024-03-10T00:40:25.903020Z", + "shell.execute_reply.started": "2024-03-10T00:38:05.788883Z" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[Pod fine-tune-bert-1a883-master-0]: WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\n", + "[Pod fine-tune-bert-1a883-master-0]: ----------------------------------------\n", + "[Pod fine-tune-bert-1a883-master-0]: Download BERT Model\n", + "[Pod fine-tune-bert-1a883-master-0]: Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']\n", + "[Pod fine-tune-bert-1a883-master-0]: You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n", + "[Pod fine-tune-bert-1a883-master-0]: ----------------------------------------\n", + "[Pod fine-tune-bert-1a883-master-0]: Download Yelp Review Dataset\n", + "Downloading readme: 100%|██████████| 6.72k/6.72k [00:00<00:00, 26.2MB/s]\n", + "Downloading data: 100%|██████████| 299M/299M [00:05<00:00, 57.4MB/s] \n", + "Downloading data: 100%|██████████| 23.5M/23.5M [00:00<00:00, 45.3MB/s]\n", + "Generating train split: 100%|██████████| 650000/650000 [00:01<00:00, 371416.73 examples/s]\n", + "Generating test split: 100%|██████████| 50000/50000 [00:00<00:00, 363106.11 examples/s]\n", + "[Pod fine-tune-bert-1a883-master-0]: ----------------------------------------\n", + "[Pod fine-tune-bert-1a883-master-0]: Map Yelp review dataset to BERT Tokenizer\n", + "Map: 100%|██████████| 3600/3600 [00:01<00:00, 2464.94 examples/s]\n", + "Map: 100%|██████████| 400/400 [00:00<00:00, 2553.52 examples/s]\n", + "Downloading builder script: 100%|██████████| 4.20k/4.20k [00:00<00:00, 16.6MB/s]\n", + "[Pod fine-tune-bert-1a883-master-0]: /opt/conda/lib/python3.10/site-packages/accelerate/state.py:306: UserWarning: OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at 16 to improve oob performance.\n", + "[Pod fine-tune-bert-1a883-master-0]: warnings.warn(\n", + "[Pod fine-tune-bert-1a883-master-0]: ----------------------------------------\n", + "[Pod fine-tune-bert-1a883-master-0]: Start Distributed Training. RANK: 0 WORLD_SIZE: 3\n", + "[Pod fine-tune-bert-1a883-master-0]: The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`, you can safely ignore this message.\n", + "[Pod fine-tune-bert-1a883-master-0]: ***** Running training *****\n", + "[Pod fine-tune-bert-1a883-master-0]: Num examples = 1,200\n", + "[Pod fine-tune-bert-1a883-master-0]: Num Epochs = 3\n", + "[Pod fine-tune-bert-1a883-master-0]: Instantaneous batch size per device = 8\n", + "[Pod fine-tune-bert-1a883-master-0]: Total train batch size (w. parallel, distributed & accumulation) = 24\n", + "[Pod fine-tune-bert-1a883-master-0]: Gradient Accumulation steps = 1\n", + "[Pod fine-tune-bert-1a883-master-0]: Total optimization steps = 150\n", + "[Pod fine-tune-bert-1a883-master-0]: Number of trainable parameters = 108,314,117\n", + "[Pod fine-tune-bert-1a883-master-0]: [W reducer.cpp:1346] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration, which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())\n", + "[Pod fine-tune-bert-1a883-master-0]: The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`, you can safely ignore this message.\n", + "[Pod fine-tune-bert-1a883-master-0]: ***** Running Evaluation *****\n", + "[Pod fine-tune-bert-1a883-master-0]: Num examples = 134\n", + "[Pod fine-tune-bert-1a883-master-0]: Batch size = 8\n", + "[Pod fine-tune-bert-1a883-master-0]: {'eval_loss': 1.2028350830078125, 'eval_accuracy': 0.4925373134328358, 'eval_runtime': 0.5392, 'eval_samples_per_second': 248.532, 'eval_steps_per_second': 11.128, 'epoch': 1.0}\n", + "[Pod fine-tune-bert-1a883-master-0]: The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`, you can safely ignore this message.\n", + "[Pod fine-tune-bert-1a883-master-0]: ***** Running Evaluation *****\n", + "[Pod fine-tune-bert-1a883-master-0]: Num examples = 134\n", + "[Pod fine-tune-bert-1a883-master-0]: Batch size = 8\n", + "[Pod fine-tune-bert-1a883-master-0]: {'eval_loss': 0.9666597843170166, 'eval_accuracy': 0.5895522388059702, 'eval_runtime': 0.5656, 'eval_samples_per_second': 236.909, 'eval_steps_per_second': 10.608, 'epoch': 2.0}\n", + "[Pod fine-tune-bert-1a883-master-0]: The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`, you can safely ignore this message.\n", + "[Pod fine-tune-bert-1a883-master-0]: ***** Running Evaluation *****\n", + "[Pod fine-tune-bert-1a883-master-0]: Num examples = 134\n", + "[Pod fine-tune-bert-1a883-master-0]: Batch size = 8\n", + "[Pod fine-tune-bert-1a883-master-0]: {'eval_loss': 0.852095901966095, 'eval_accuracy': 0.6268656716417911, 'eval_runtime': 0.5951, 'eval_samples_per_second': 225.172, 'eval_steps_per_second': 10.082, 'epoch': 3.0}\n", + "[Pod fine-tune-bert-1a883-master-0]: Training completed. Do not forget to share your model on huggingface.co/models =)\n", + "[Pod fine-tune-bert-1a883-master-0]: {'train_runtime': 73.6766, 'train_samples_per_second': 48.862, 'train_steps_per_second': 2.036, 'train_loss': 1.166010030110677, 'epoch': 3.0}\n", + "[Pod fine-tune-bert-1a883-master-0]: ----------------------------------------\n", + "[Pod fine-tune-bert-1a883-master-0]: Training is complete\n", + "[Pod fine-tune-bert-1a883-master-0]: Saving model checkpoint to ./bert\n", + "[Pod fine-tune-bert-1a883-master-0]: Configuration saved in ./bert/config.json\n", + "[Pod fine-tune-bert-1a883-master-0]: Model weights saved in ./bert/model.safetensors\n", + "[Pod fine-tune-bert-1a883-master-0]: ----------------------------------------\n", + "[Pod fine-tune-bert-1a883-master-0]: Model is exported to S3\n" + ] + } + ], + "source": [ + "logs, _ = TrainingClient().get_job_logs(job_name, follow=True)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download the fine-tuned model\n", + "\n", + "We can download our fine-tuned BERT model from S3 to evaluate it." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-10T00:41:32.463113Z", + "iopub.status.busy": "2024-03-10T00:41:32.462861Z", + "iopub.status.idle": "2024-03-10T00:41:34.615767Z", + "shell.execute_reply": "2024-03-10T00:41:34.615101Z", + "shell.execute_reply.started": "2024-03-10T00:41:32.463095Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import boto3\n", + "\n", + "s3 = boto3.resource(\"s3\")\n", + "bucket = s3.Bucket(bucket)\n", + "\n", + "# config.json is the model metadata.\n", + "# model.safetensors is the model weights & biases.\n", + "if not os.path.exists(\"bert\"):\n", + " os.makedirs(\"bert\")\n", + "bucket.download_file(\"bert/config.json\", \"bert/config.json\")\n", + "bucket.download_file(\"bert/model.safetensors\", \"bert/model.safetensors\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "tags": [] + }, + "source": [ + "### Test the fine-tuned BERT model\n", + "\n", + "We are going to use HuggingFace pipeline to test our model.\n", + "\n", + "We will ask for sentiment analysis task for our fine-tuned LLM." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-10T00:43:29.026194Z", + "iopub.status.busy": "2024-03-10T00:43:29.025948Z", + "iopub.status.idle": "2024-03-10T00:43:29.651226Z", + "shell.execute_reply": "2024-03-10T00:43:29.650644Z", + "shell.execute_reply.started": "2024-03-10T00:43:29.026177Z" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "This is one of the best restaurants I've ever been to.\n", + "Star: 4\n", + "Score: 0.8029219508171082\n", + "---------------------------\n", + "\n", + "\n", + "I am upset by using this service. It is very expensive and quality is bad.\n", + "Star: 1\n", + "Score: 0.5185158848762512\n", + "---------------------------\n" + ] + } + ], + "source": [ + "from transformers import AutoTokenizer, pipeline\n", + "\n", + "# During fine-tuning BERT tokenizer is not changed.\n", + "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-cased\")\n", + "\n", + "# Use pipeline with sentiment-analysis task to evaluate our model.\n", + "nlp = pipeline(\"sentiment-analysis\", model=\"./bert\", tokenizer=tokenizer)\n", + "\n", + "good_review = \"This is one of the best restaurants I've ever been to.\"\n", + "bad_review = \"I am upset by using this service. It is very expensive and quality is bad.\"\n", + "\n", + "print(good_review)\n", + "res = nlp(good_review)\n", + "\n", + "print(\"Star: \", res[0][\"label\"][6])\n", + "print(\"Score: \", res[0][\"score\"])\n", + "print(\"---------------------------\\n\\n\")\n", + "\n", + "\n", + "print(bad_review)\n", + "res = nlp(bad_review)\n", + "\n", + "print(\"Star: \", res[0][\"label\"][6])\n", + "print(\"Score: \", res[0][\"score\"])\n", + "print(\"---------------------------\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-01T23:44:15.511173Z", + "iopub.status.busy": "2024-03-01T23:44:15.510932Z", + "iopub.status.idle": "2024-03-01T23:44:15.539921Z", + "shell.execute_reply": "2024-03-01T23:44:15.539352Z", + "shell.execute_reply.started": "2024-03-01T23:44:15.511155Z" + }, + "tags": [] + }, + "source": [ + "## Delete the PyTorchJob\n", + "\n", + "When PyTorchJob is finished, you can delete the resource." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-10T00:43:41.129972Z", + "iopub.status.busy": "2024-03-10T00:43:41.129720Z", + "iopub.status.idle": "2024-03-10T00:43:41.157373Z", + "shell.execute_reply": "2024-03-10T00:43:41.156125Z", + "shell.execute_reply.started": "2024-03-10T00:43:41.129955Z" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "TrainingClient().delete_job(name=job_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/examples/sdk/create-tfjob.ipynb b/examples/tensorflow/image-classification/create-tfjob.ipynb similarity index 100% rename from examples/sdk/create-tfjob.ipynb rename to examples/tensorflow/image-classification/create-tfjob.ipynb diff --git a/sdk/python/README.md b/sdk/python/README.md index ac5c75d64d..be4a521861 100644 --- a/sdk/python/README.md +++ b/sdk/python/README.md @@ -29,7 +29,8 @@ python setup.py install --user ## Getting Started -Please follow the [sample](examples/kubeflow-tfjob-sdk.ipynb) to create, update and delete TFJob. +Please follow the [Getting Started guide](https://www.kubeflow.org/docs/components/training/overview/#getting-started) +or check Training Operator [examples](../../examples). ## Documentation for API Endpoints