diff --git a/mindnlp/peft_lora_mindnlp.ipynb b/mindnlp/peft_lora_mindnlp.ipynb deleted file mode 100644 index 51f133c95..000000000 --- a/mindnlp/peft_lora_mindnlp.ipynb +++ /dev/null @@ -1,2572 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "48608ac7-71cd-4859-9d27-aac6b162d2b0", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for type is zero.\n", - " setattr(self, word, getattr(machar, word).flat[0])\n", - "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", - " return self._float_to_str(self.smallest_subnormal)\n", - "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:499: UserWarning: The value of the smallest subnormal for type is zero.\n", - " setattr(self, word, getattr(machar, word).flat[0])\n", - "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/numpy/core/getlimits.py:89: UserWarning: The value of the smallest subnormal for type is zero.\n", - " return self._float_to_str(self.smallest_subnormal)\n", - "/home/ma-user/anaconda3/envs/MindSpore/lib/python3.9/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n", - "Building prefix dict from the default dictionary ...\n", - "Loading model from cache /tmp/jieba.cache\n", - "Loading model cost 1.289 seconds.\n", - "Prefix dict has been built successfully.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dataset column: ['image', 'label']\n", - "dataset size: 5000\n", - "dataset batch size: 1\n" - ] - } - ], - "source": [ - "import mindspore\n", - "import mindnlp\n", - "import numpy as np\n", - "from mindspore import context, Tensor\n", - "from mindnlp.dataset import load_dataset\n", - "dataset = load_dataset(\"food101\", split=\"train[:5000]\")\n", - "\n", - "def show_dataset_info(dataset):\n", - " print(\"dataset column: {}\".format(dataset.get_col_names()))\n", - " print(\"dataset size: {}\".format(dataset.get_dataset_size()))\n", - " print(\"dataset batch size: {}\".format(dataset.get_batch_size()))\n", - "show_dataset_info(dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "e3f69b69-c1f4-4c60-b07a-d4a63784e711", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dataset column: ['image', 'label']\n", - "dataset size: 4500\n", - "dataset batch size: 1\n", - "dataset column: ['image', 'label']\n", - "dataset size: 500\n", - "dataset batch size: 1\n" - ] - } - ], - "source": [ - "train_ds, val_ds = dataset.split([0.9, 0.1])\n", - "show_dataset_info(train_ds)\n", - "show_dataset_info(val_ds)\n", - "\n", - "from mindnlp.transformers import ViTImageProcessor\n", - "image_processor = ViTImageProcessor.from_pretrained(\"google/vit-base-patch16-224-in21k\")\n", - "\n", - "def transform(image, label):\n", - " # Process the image with the image processor\n", - " processed_output = image_processor(image, return_tensors='np')\n", - "\n", - " # Get 'pixel_values' and drop the extra batch dimension\n", - " pixel_values = processed_output['pixel_values']\n", - " if len(pixel_values.shape) == 4 and pixel_values.shape[0] == 1:\n",
- " pixel_values = np.squeeze(pixel_values, axis=0) # remove the leading batch dimension, shape becomes (3, 224, 224)\n", - " \n", - " labels = np.array([label], dtype=np.int32)\n", - " return pixel_values, labels" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "5af33ab7-520d-4911-a36e-20bd56da693f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "dataset column: ['pixel_values', 'labels']\n", - "dataset size: 281\n", - "dataset batch size: 16\n", - "dataset column: ['pixel_values', 'labels']\n", - "dataset size: 31\n", - "dataset batch size: 16\n" - ] - } - ], - "source": [ - "# Process the training set\n", - "train_ds = train_ds.map(operations=transform, input_columns=[\"image\", \"label\"], output_columns=[\"pixel_values\", \"labels\"])\n", - "train_ds = train_ds.batch(batch_size=16, drop_remainder=True)\n", - "\n", - "# Process the validation set\n", - "val_ds = val_ds.map(operations=transform, input_columns=[\"image\", \"label\"], output_columns=[\"pixel_values\", \"labels\"])\n", - "val_ds = val_ds.batch(batch_size=16, drop_remainder=True)\n", - "\n", - "show_dataset_info(train_ds)\n", - "show_dataset_info(val_ds)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "760a10e1-567c-488f-b77a-a1a5b5d47394", - "metadata": {}, - "outputs": [], - "source": [ - "# Define the list of Food101 class names (all 101 classes must be listed)\n", - "class_names = [\n", - " 'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',\n", - " 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',\n", - " 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',\n", - " 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla',\n", - " 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', 'clam_chowder',\n", - " 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', 'cup_cakes',\n", - " 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict',\n", - " 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras',\n", - " 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', 'fried_rice',\n", - " 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', 'grilled_cheese_sandwich',\n", - " 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', 'hot_and_sour_soup',\n", - " 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', 'lasagna',\n", - " 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', 'macarons', 'miso_soup',\n", - " 'mussels', 'nachos', 'omelette', 'onion_rings', 'oysters',\n", - " 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck',\n", - " 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib',\n", - " 'pulled_pork_sandwich', 'ramen', 'ravioli', 'red_velvet_cake', 'risotto',\n", - " 'samosa', 'sashimi', 'scallops', 'seaweed_salad', 'shrimp_and_grits',\n", - " 'spaghetti_bolognese', 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake',\n", - " 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare',\n", - " 'waffles'\n", - "]\n", - "\n", - "# Create the label2id and id2label dictionaries\n", - "label2id = {name: idx for idx, name in enumerate(class_names)}\n", - "id2label = {idx: name for idx, name in enumerate(class_names)}" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "381475e6-7060-4c75-8ff9-8ea09d4a8c42", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[MS_ALLOC_CONF]Runtime config: enable_vmm:True vmm_align_size:2MB\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Some
weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n", - "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "trainable params: 85876325 || all params: 85876325 || trainable%: 100.00\n" - ] - } - ], - "source": [ - "def print_trainable_parameters(model):\n", - " \"\"\"\n", - " Prints the number of trainable parameters in the model.\n", - " \"\"\"\n", - " trainable_params = 0\n", - " all_param = 0\n", - " for _, param in model.named_parameters():\n", - " all_param += param.numel()\n", - " if param.requires_grad:\n", - " trainable_params += param.numel()\n", - " print(\n", - " f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}\"\n", - " )\n", - "from mindnlp.transformers import ViTForImageClassification\n", - "model = ViTForImageClassification.from_pretrained(\n", - " \"google/vit-base-patch16-224-in21k\",\n", - " num_labels=101,\n", - " label2id=label2id,\n", - " id2label=id2label,\n", - " ignore_mismatched_sizes=True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint\n", - ")\n", - "print_trainable_parameters(model)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "68fed2fb-0060-434a-9e51-316b232ae52f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "trainable params: 667493 || all params: 86543818 || trainable%: 0.77\n" - ] - } - ], - "source": [ - "from mindnlp.peft import LoraConfig, get_peft_model\n", - "config = LoraConfig(\n", - " r=16,\n", - " lora_alpha=16,\n", - " target_modules=[\"query\", \"value\"],\n", - " lora_dropout=0.1,\n", - " bias=\"none\",\n", - " modules_to_save=[\"classifier\"],\n", - ")\n", - "lora_model = get_peft_model(model, config)\n", - "print_trainable_parameters(lora_model)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "f2dd558d-d4ca-4ebf-ba9c-6342766d2a4e", - "metadata": {}, - "outputs": [], - "source": [ - "from mindnlp.engine import Trainer, TrainingArguments\n", - "\n", - "training_args = TrainingArguments(\n", - " output_dir=\"./vit-base-food101\",\n", - " per_device_train_batch_size=128,\n", - " evaluation_strategy=\"epoch\",\n", - " num_train_epochs=5,\n", - " fp16=True,\n", - " save_steps=100,\n", - " eval_steps=100,\n", - " logging_steps=10,\n", - " learning_rate=5e-3,\n", - " save_total_limit=2,\n", - " remove_unused_columns=False,\n", - " load_best_model_at_end=False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "70c986d1-fb52-464f-a1ed-5497dd94f8b9", - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import evaluate\n", - "\n", - "metric = evaluate.load(\"accuracy\")\n", - "# the compute_metrics function takes a Named Tuple as input:\n", - "# predictions, which are the logits of the model as Numpy arrays,\n", - "# and label_ids, which are the ground-truth labels as Numpy arrays.\n", - "def compute_metrics(eval_pred):\n", - " \"\"\"Computes accuracy on a batch of predictions\"\"\"\n", - " predictions = np.argmax(eval_pred.predictions, axis=1)\n", - " return metric.compute(predictions=predictions, references=eval_pred.label_ids)\n", - "\n", - "trainer = Trainer(\n", - " model=lora_model,\n", - " args=training_args,\n", - " 
compute_metrics=compute_metrics,\n", - " train_dataset=train_ds,\n", - " eval_dataset=val_ds,\n", - " tokenizer=image_processor,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "fb0114ce-dabc-4704-bf88-ad9b9971f166", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 1%| | 10/1405 [00:08<07:54, 2.94it/s] " - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 2.4978, 'learning_rate': 0.004964412811387901, 'epoch': 0.04}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 1%|▏ | 20/1405 [00:10<05:57, 3.88it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.4403, 'learning_rate': 0.004928825622775801, 'epoch': 0.07}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 2%|▏ | 30/1405 [00:13<05:59, 3.82it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3283, 'learning_rate': 0.004893238434163701, 'epoch': 0.11}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 3%|▎ | 40/1405 [00:16<06:02, 3.77it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.2904, 'learning_rate': 0.004857651245551602, 'epoch': 0.14}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 4%|▎ | 50/1405 [00:18<05:58, 3.78it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3089, 'learning_rate': 0.004822064056939502, 'epoch': 0.18}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 4%|▍ | 60/1405 [00:21<05:57, 3.76it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.4635, 'learning_rate': 0.004786476868327403, 'epoch': 0.21}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 5%|▍ | 70/1405 [00:24<05:54, 3.76it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3133, 'learning_rate': 0.004750889679715303, 'epoch': 0.25}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 6%|▌ | 80/1405 [00:26<05:57, 3.70it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.5148, 'learning_rate': 0.004715302491103203, 'epoch': 0.28}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 6%|▋ | 90/1405 [00:29<05:18, 4.13it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3734, 'learning_rate': 0.0046797153024911034, 'epoch': 0.32}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 7%|▋ | 100/1405 [00:31<05:15, 4.13it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3448, 'learning_rate': 0.004644128113879003, 'epoch': 0.36}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 8%|▊ | 110/1405 [00:35<05:30, 3.92it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.4002, 'learning_rate': 0.004608540925266904, 'epoch': 0.39}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 9%|▊ | 120/1405 [00:37<05:26, 3.94it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.473, 'learning_rate': 0.004572953736654804, 'epoch': 0.43}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 9%|▉ | 130/1405 
[00:40<05:21, 3.96it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.2682, 'learning_rate': 0.004537366548042704, 'epoch': 0.46}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 10%|▉ | 140/1405 [00:42<05:22, 3.93it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3435, 'learning_rate': 0.004501779359430605, 'epoch': 0.5}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 11%|█ | 150/1405 [00:45<05:21, 3.90it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3911, 'learning_rate': 0.004466192170818505, 'epoch': 0.53}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 11%|█▏ | 160/1405 [00:47<05:03, 4.11it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.2626, 'learning_rate': 0.004430604982206406, 'epoch': 0.57}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 12%|█▏ | 170/1405 [00:50<04:45, 4.32it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.4019, 'learning_rate': 0.004395017793594306, 'epoch': 0.6}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 13%|█▎ | 180/1405 [00:52<04:49, 4.24it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3315, 'learning_rate': 0.004359430604982207, 'epoch': 0.64}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 14%|█▎ | 190/1405 [00:54<04:55, 4.11it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3978, 'learning_rate': 0.004323843416370107, 'epoch': 0.68}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 14%|█▍ | 200/1405 [00:57<05:03, 3.97it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3234, 'learning_rate': 0.004288256227758008, 'epoch': 0.71}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 15%|█▍ | 210/1405 [01:00<05:08, 3.87it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.4354, 'learning_rate': 0.004252669039145908, 'epoch': 0.75}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 16%|█▌ | 220/1405 [01:03<04:59, 3.96it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3388, 'learning_rate': 0.004217081850533808, 'epoch': 0.78}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 16%|█▋ | 231/1405 [01:05<04:47, 4.09it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3434, 'learning_rate': 0.004181494661921708, 'epoch': 0.82}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 17%|█▋ | 241/1405 [01:08<04:14, 4.57it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.2394, 'learning_rate': 0.004145907473309608, 'epoch': 0.85}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 18%|█▊ | 251/1405 [01:10<04:10, 4.60it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.2386, 'learning_rate': 0.004110320284697509, 'epoch': 0.89}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 19%|█▊ | 261/1405 [01:12<04:03, 4.69it/s]" - ] - }, - { - "name": "stdout", - "output_type": 
"stream", - "text": [ - "{'loss': 0.4592, 'learning_rate': 0.004074733096085409, 'epoch': 0.93}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 19%|█▉ | 271/1405 [01:14<04:08, 4.57it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.5556, 'learning_rate': 0.004039145907473309, 'epoch': 0.96}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " 20%|██ | 281/1405 [01:16<03:59, 4.70it/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'loss': 0.3026, 'learning_rate': 0.00400355871886121, 'epoch': 1.0}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n", - " 0%| | 0/31 [00:00 is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.\n" - ] - }, - { - "data": { - "text/plain": [ - "ViTImageProcessor {\n", - " \"do_normalize\": true,\n", - " \"do_rescale\": true,\n", - " \"do_resize\": true,\n", - " \"image_mean\": [\n", - " 0.5,\n", - " 0.5,\n", - " 0.5\n", - " ],\n", - " \"image_processor_type\": \"ViTImageProcessor\",\n", - " \"image_std\": [\n", - " 0.5,\n", - " 0.5,\n", - " 0.5\n", - " ],\n", - " \"resample\": 2,\n", - " \"rescale_factor\": 0.00392156862745098,\n", - " \"size\": {\n", - " \"height\": 224,\n", - " \"width\": 224\n", - " }\n", - "}" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from transformers import AutoImageProcessor\n", - "\n", - "image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)\n", - "image_processor" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "from torchvision.transforms import (\n", - " CenterCrop,\n", - " Compose,\n", - " Normalize,\n", - " RandomHorizontalFlip,\n", - " RandomResizedCrop,\n", - " Resize,\n", - " ToTensor,\n", - ")\n", - "\n", - "normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)\n", - "train_transforms = Compose(\n", - " [\n", - " RandomResizedCrop(image_processor.size[\"height\"]),\n", - " RandomHorizontalFlip(),\n", - " ToTensor(),\n", - " normalize,\n", - " ]\n", - ")\n", - "\n", - "val_transforms = Compose(\n", - " [\n", - " Resize(image_processor.size[\"height\"]),\n", - " CenterCrop(image_processor.size[\"height\"]),\n", - " ToTensor(),\n", - " normalize,\n", - " ]\n", - ")\n", - "\n", - "\n", - "def preprocess_train(example_batch):\n", - " \"\"\"Apply train_transforms across a batch.\"\"\"\n", - " example_batch[\"pixel_values\"] = [train_transforms(image.convert(\"RGB\")) for image in example_batch[\"image\"]]\n", - " return example_batch\n", - "\n", - "\n", - "def preprocess_val(example_batch):\n", - " \"\"\"Apply val_transforms across a batch.\"\"\"\n", - " example_batch[\"pixel_values\"] = [val_transforms(image.convert(\"RGB\")) for image in example_batch[\"image\"]]\n", - " return example_batch" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "# split up training into training + validation\n", - "splits = dataset.train_test_split(test_size=0.1)\n", - "train_ds = splits[\"train\"]\n", - "val_ds = splits[\"test\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "train_ds.set_transform(preprocess_train)\n", - "val_ds.set_transform(preprocess_val)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - 
"metadata": {}, - "outputs": [], - "source": [ - "def print_trainable_parameters(model):\n", - " \"\"\"\n", - " Prints the number of trainable parameters in the model.\n", - " \"\"\"\n", - " trainable_params = 0\n", - " all_param = 0\n", - " for _, param in model.named_parameters():\n", - " all_param += param.numel()\n", - " if param.requires_grad:\n", - " trainable_params += param.numel()\n", - " print(\n", - " f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param:.2f}\"\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']\n", - "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" - ] - }, - { - "data": { - "text/plain": [ - "ViTForImageClassification(\n", - " (vit): ViTModel(\n", - " (embeddings): ViTEmbeddings(\n", - " (patch_embeddings): ViTPatchEmbeddings(\n", - " (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))\n", - " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (encoder): ViTEncoder(\n", - " (layer): ModuleList(\n", - " (0-11): 12 x ViTLayer(\n", - " (attention): ViTSdpaAttention(\n", - " (attention): ViTSdpaSelfAttention(\n", - " (query): Linear(in_features=768, out_features=768, bias=True)\n", - " (key): Linear(in_features=768, out_features=768, bias=True)\n", - " (value): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (output): ViTSelfOutput(\n", - " (dense): Linear(in_features=768, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " )\n", - " (intermediate): ViTIntermediate(\n", - " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", - " (intermediate_act_fn): GELUActivation()\n", - " )\n", - " (output): ViTOutput(\n", - " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " )\n", - " (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " )\n", - " )\n", - " )\n", - " (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", - " )\n", - " (classifier): Linear(in_features=768, out_features=101, bias=True)\n", - ")" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from transformers import AutoModelForImageClassification, TrainingArguments, Trainer\n", - "\n", - "model = AutoModelForImageClassification.from_pretrained(\n", - " model_checkpoint,\n", - " label2id=label2id,\n", - " id2label=id2label,\n", - " ignore_mismatched_sizes=True, # provide this in case you're planning to fine-tune an already fine-tuned checkpoint\n", - ")\n", - "# print_trainable_parameters(model)\n", - "model" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "trainable params: 667493 || all params: 86543818 || trainable%: 0.77\n" - ] - } - ], - "source": [ - "from peft import LoraConfig, get_peft_model\n", - "\n", - "config = LoraConfig(\n", - " 
r=16,\n", - " lora_alpha=16,\n", - " target_modules=[\"query\", \"value\"],\n", - " lora_dropout=0.1,\n", - " bias=\"none\",\n", - " modules_to_save=[\"classifier\"],\n", - ")\n", - "lora_model = get_peft_model(model, config)\n", - "print_trainable_parameters(lora_model)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "from transformers import TrainingArguments, Trainer\n", - "\n", - "\n", - "model_name = model_checkpoint.split(\"/\")[-1]\n", - "batch_size = 128\n", - "\n", - "args = TrainingArguments(\n", - " f\"{model_name}-finetuned-lora-food101\",\n", - " remove_unused_columns=False,\n", - " eval_strategy=\"epoch\",\n", - " save_strategy=\"epoch\",\n", - " learning_rate=5e-3,\n", - " per_device_train_batch_size=batch_size,\n", - " gradient_accumulation_steps=4,\n", - " per_device_eval_batch_size=batch_size,\n", - " fp16=True,\n", - " num_train_epochs=5,\n", - " logging_steps=10,\n", - " load_best_model_at_end=True,\n", - " metric_for_best_model=\"accuracy\",\n", - " push_to_hub=True,\n", - " label_names=[\"labels\"],\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import evaluate\n", - "\n", - "metric = evaluate.load(\"accuracy\")\n", - "\n", - "\n", - "# the compute_metrics function takes a Named Tuple as input:\n", - "# predictions, which are the logits of the model as Numpy arrays,\n", - "# and label_ids, which are the ground-truth labels as Numpy arrays.\n", - "def compute_metrics(eval_pred):\n", - " \"\"\"Computes accuracy on a batch of predictions\"\"\"\n", - " predictions = np.argmax(eval_pred.predictions, axis=1)\n", - " return metric.compute(predictions=predictions, references=eval_pred.label_ids)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "\n", - "\n", - "def collate_fn(examples):\n", - " pixel_values = torch.stack([example[\"pixel_values\"] for example in examples])\n", - " labels = torch.tensor([example[\"label\"] for example in examples])\n", - " return {\"pixel_values\": pixel_values, \"labels\": labels}" - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "c:\\Users\\19895\\AppData\\Local\\conda\\conda\\envs\\tg\\lib\\site-packages\\accelerate\\accelerator.py:494: FutureWarning: `torch.cuda.amp.GradScaler(args...)` is deprecated. Please use `torch.amp.GradScaler('cuda', args...)` instead.\n", - " self.scaler = torch.cuda.amp.GradScaler(**kwargs)\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5b7c23601ba246aebd325ef408f8f0b4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - " 0%| | 0/45 [00:00