diff --git a/docs/en_US/Compression/QuickStart.rst b/docs/en_US/Compression/QuickStart.rst index 0e4b33b692..2998f23507 100644 --- a/docs/en_US/Compression/QuickStart.rst +++ b/docs/en_US/Compression/QuickStart.rst @@ -5,10 +5,13 @@ Quick Start :hidden: Tutorial + Notebook Example Model compression usually consists of three stages: 1) pre-training a model, 2) compress the model, 3) fine-tuning the model. NNI mainly focuses on the second stage and provides very simple APIs for compressing a model. Follow this guide for a quick look at how easy it is to use NNI to compress a model. +A `compression pipeline example <./compression_pipeline_example.rst>`__ with Jupyter notebook is supported and refer the code :githublink:`here `. + Model Pruning ------------- @@ -31,7 +34,7 @@ The specification of configuration can be found `here <./Tutorial.rst#specify-th Step2. Choose a pruner and compress the model ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -First instantiate the chosen pruner with your model and configuration as arguments, then invoke ``compress()`` to compress your model. Note that, some algorithms may check gradients for compressing, so we may also define an optimizer and pass it to the pruner. +First instantiate the chosen pruner with your model and configuration as arguments, then invoke ``compress()`` to compress your model. Note that, some algorithms may check gradients for compressing, so we may also define a trainer, an optimizer, a criterion and pass them to the pruner. .. code-block:: python @@ -42,19 +45,16 @@ First instantiate the chosen pruner with your model and configuration as argumen Some pruners (e.g., L1FilterPruner, FPGMPruner) prune once, some pruners (e.g., AGPPruner) prune your model iteratively, the masks are adjusted epoch by epoch during training. -Note that, ``pruner.compress`` simply adds masks on model weights, it does not include fine-tuning logic. If users want to fine tune the compressed model, they need to write the fine tune logic by themselves after ``pruner.compress``. +So if the pruners prune your model iteratively or they need training or inference to get gradients, you need pass finetuning logic to pruner. For example: .. code-block:: python - for epoch in range(1, args.epochs + 1): - pruner.update_epoch(epoch) - train(args, model, device, train_loader, optimizer_finetune, epoch) - test(model, device, test_loader) - -More APIs to control the fine-tuning can be found `here <./Tutorial.rst#apis-to-control-the-fine-tuning>`__. + from nni.algorithms.compression.pytorch.pruning import AGPPruner + pruner = AGPPruner(model, config_list, optimizer, trainer, criterion, num_iterations=10, epochs_per_iteration=1, pruning_algorithm='level') + model = pruner.compress() Step3. Export compression result ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/en_US/Compression/Tutorial.rst b/docs/en_US/Compression/Tutorial.rst index 8c60d337fc..29efacd6af 100644 --- a/docs/en_US/Compression/Tutorial.rst +++ b/docs/en_US/Compression/Tutorial.rst @@ -185,13 +185,6 @@ Please refer to `here `__ for detailed description. The exampl Control the Fine-tuning process ------------------------------- -APIs to control the fine-tuning -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Some compression algorithms control the progress of compression during fine-tuning (e.g. `AGP <../Compression/Pruner.rst#agp-pruner>`__\ ), and some algorithms need to do something after every minibatch. Therefore, we provide another two APIs for users to invoke: ``pruner.update_epoch(epoch)`` and ``pruner.step()``. - -``update_epoch`` should be invoked in every epoch, while ``step`` should be invoked after each minibatch. Note that most algorithms do not require calling the two APIs. Please refer to each algorithm's document for details. For the algorithms that do not need them, calling them is allowed but has no effect. - Enhance the fine-tuning process ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/en_US/Compression/compression_pipeline_example.ipynb b/docs/en_US/Compression/compression_pipeline_example.ipynb new file mode 100644 index 0000000000..1493a3c030 --- /dev/null +++ b/docs/en_US/Compression/compression_pipeline_example.ipynb @@ -0,0 +1,1281 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# 1. Prepare model" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 1, + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "class NaiveModel(torch.nn.Module):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.conv1 = torch.nn.Conv2d(1, 20, 5, 1)\n", + " self.conv2 = torch.nn.Conv2d(20, 50, 5, 1)\n", + " self.fc1 = torch.nn.Linear(4 * 4 * 50, 500)\n", + " self.fc2 = torch.nn.Linear(500, 10)\n", + " self.relu1 = torch.nn.ReLU6()\n", + " self.relu2 = torch.nn.ReLU6()\n", + " self.relu3 = torch.nn.ReLU6()\n", + " self.max_pool1 = torch.nn.MaxPool2d(2, 2)\n", + " self.max_pool2 = torch.nn.MaxPool2d(2, 2)\n", + "\n", + " def forward(self, x):\n", + " x = self.relu1(self.conv1(x))\n", + " x = self.max_pool1(x)\n", + " x = self.relu2(self.conv2(x))\n", + " x = self.max_pool2(x)\n", + " x = x.view(-1, x.size()[1:].numel())\n", + " x = self.relu3(self.fc1(x))\n", + " x = self.fc2(x)\n", + " return F.log_softmax(x, dim=1)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 2, + "source": [ + "# define model, optimizer, criterion, data_loader, trainer, evaluator.\n", + "\n", + "import torch.optim as optim\n", + "from torchvision import datasets, transforms\n", + "from torch.optim.lr_scheduler import StepLR\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "\n", + "model = NaiveModel().to(device)\n", + "\n", + "optimizer = optim.Adadelta(model.parameters(), lr=1)\n", + "\n", + "criterion = torch.nn.NLLLoss()\n", + "\n", + "transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])\n", + "train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)\n", + "test_dataset = datasets.MNIST('./data', train=False, transform=transform)\n", + "train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64)\n", + "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000)\n", + "\n", + "def trainer(model, optimizer, criterion, epoch):\n", + " model.train()\n", + " for batch_idx, (data, target) in enumerate(train_loader):\n", + " data, target = data.to(device), target.to(device)\n", + " optimizer.zero_grad()\n", + " output = model(data)\n", + " loss = criterion(output, target)\n", + " loss.backward()\n", + " optimizer.step()\n", + " if batch_idx % 100 == 0:\n", + " print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n", + " epoch, batch_idx * len(data), len(train_loader.dataset),\n", + " 100. * batch_idx / len(train_loader), loss.item()))\n", + "\n", + "def evaluator(model):\n", + " model.eval()\n", + " test_loss = 0\n", + " correct = 0\n", + " with torch.no_grad():\n", + " for data, target in test_loader:\n", + " data, target = data.to(device), target.to(device)\n", + " output = model(data)\n", + " test_loss += F.nll_loss(output, target, reduction='sum').item()\n", + " pred = output.argmax(dim=1, keepdim=True)\n", + " correct += pred.eq(target.view_as(pred)).sum().item()\n", + "\n", + " test_loss /= len(test_loader.dataset)\n", + " acc = 100 * correct / len(test_loader.dataset)\n", + "\n", + " print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'.format(\n", + " test_loss, correct, len(test_loader.dataset), acc))\n", + "\n", + " return acc" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 3, + "source": [ + "# pre-train model for 3 epoches.\n", + "\n", + "scheduler = StepLR(optimizer, step_size=1, gamma=0.7)\n", + "\n", + "for epoch in range(0, 3):\n", + " trainer(model, optimizer, criterion, epoch)\n", + " evaluator(model)\n", + " scheduler.step()" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Train Epoch: 0 [0/60000 (0%)]\tLoss: 2.313423\n", + "Train Epoch: 0 [6400/60000 (11%)]\tLoss: 0.091786\n", + "Train Epoch: 0 [12800/60000 (21%)]\tLoss: 0.087317\n", + "Train Epoch: 0 [19200/60000 (32%)]\tLoss: 0.036397\n", + "Train Epoch: 0 [25600/60000 (43%)]\tLoss: 0.008173\n", + "Train Epoch: 0 [32000/60000 (53%)]\tLoss: 0.047565\n", + "Train Epoch: 0 [38400/60000 (64%)]\tLoss: 0.122448\n", + "Train Epoch: 0 [44800/60000 (75%)]\tLoss: 0.036732\n", + "Train Epoch: 0 [51200/60000 (85%)]\tLoss: 0.150135\n", + "Train Epoch: 0 [57600/60000 (96%)]\tLoss: 0.109684\n", + "\n", + "Test set: Average loss: 0.0457, Accuracy: 9857/10000 (99%)\n", + "\n", + "Train Epoch: 1 [0/60000 (0%)]\tLoss: 0.020650\n", + "Train Epoch: 1 [6400/60000 (11%)]\tLoss: 0.091525\n", + "Train Epoch: 1 [12800/60000 (21%)]\tLoss: 0.019602\n", + "Train Epoch: 1 [19200/60000 (32%)]\tLoss: 0.027827\n", + "Train Epoch: 1 [25600/60000 (43%)]\tLoss: 0.019414\n", + "Train Epoch: 1 [32000/60000 (53%)]\tLoss: 0.007640\n", + "Train Epoch: 1 [38400/60000 (64%)]\tLoss: 0.051296\n", + "Train Epoch: 1 [44800/60000 (75%)]\tLoss: 0.012038\n", + "Train Epoch: 1 [51200/60000 (85%)]\tLoss: 0.121057\n", + "Train Epoch: 1 [57600/60000 (96%)]\tLoss: 0.015796\n", + "\n", + "Test set: Average loss: 0.0302, Accuracy: 9902/10000 (99%)\n", + "\n", + "Train Epoch: 2 [0/60000 (0%)]\tLoss: 0.009903\n", + "Train Epoch: 2 [6400/60000 (11%)]\tLoss: 0.062256\n", + "Train Epoch: 2 [12800/60000 (21%)]\tLoss: 0.013844\n", + "Train Epoch: 2 [19200/60000 (32%)]\tLoss: 0.014133\n", + "Train Epoch: 2 [25600/60000 (43%)]\tLoss: 0.001051\n", + "Train Epoch: 2 [32000/60000 (53%)]\tLoss: 0.006128\n", + "Train Epoch: 2 [38400/60000 (64%)]\tLoss: 0.032162\n", + "Train Epoch: 2 [44800/60000 (75%)]\tLoss: 0.007687\n", + "Train Epoch: 2 [51200/60000 (85%)]\tLoss: 0.092295\n", + "Train Epoch: 2 [57600/60000 (96%)]\tLoss: 0.006266\n", + "\n", + "Test set: Average loss: 0.0259, Accuracy: 9920/10000 (99%)\n", + "\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 4, + "source": [ + "# show all op_name and op_type in the model.\n", + "\n", + "[print('op_name: {}\\nop_type: {}\\n'.format(name, type(module))) for name, module in model.named_modules()]" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "op_name: \n", + "op_type: \n", + "\n", + "op_name: conv1\n", + "op_type: \n", + "\n", + "op_name: conv2\n", + "op_type: \n", + "\n", + "op_name: fc1\n", + "op_type: \n", + "\n", + "op_name: fc2\n", + "op_type: \n", + "\n", + "op_name: relu1\n", + "op_type: \n", + "\n", + "op_name: relu2\n", + "op_type: \n", + "\n", + "op_name: relu3\n", + "op_type: \n", + "\n", + "op_name: max_pool1\n", + "op_type: \n", + "\n", + "op_name: max_pool2\n", + "op_type: \n", + "\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[None, None, None, None, None, None, None, None, None, None]" + ] + }, + "metadata": {}, + "execution_count": 4 + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 5, + "source": [ + "# show the weight size of `conv1`.\n", + "\n", + "print(model.conv1.weight.data.size())" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([20, 1, 5, 5])\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 6, + "source": [ + "# show the weight of `conv1`.\n", + "\n", + "print(model.conv1.weight.data)" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor([[[[ 1.5338e-01, -1.1766e-01, -2.6654e-01, -2.9445e-02, -1.4650e-01],\n", + " [-1.8796e-01, -2.9882e-01, 6.9725e-02, 2.1561e-01, 6.5688e-02],\n", + " [ 1.5274e-01, -9.8471e-03, 3.2303e-01, 1.3472e-03, 1.7235e-01],\n", + " [ 1.1804e-01, 2.2535e-01, -8.3370e-02, -3.4553e-02, -1.2529e-01],\n", + " [-6.6012e-02, -2.0272e-02, -1.8797e-01, -4.6882e-02, -8.3206e-02]]],\n", + "\n", + "\n", + " [[[-1.2112e-01, 7.0756e-02, 5.0446e-02, 1.5156e-01, -2.7929e-02],\n", + " [-1.9744e-01, -2.1336e-03, 7.2534e-02, 6.2336e-02, 1.6039e-01],\n", + " [-6.7510e-02, 1.4636e-01, 7.1972e-02, -8.9118e-02, -4.0895e-02],\n", + " [ 2.9499e-02, 2.0788e-01, -1.4989e-01, 1.1668e-01, -2.8503e-01],\n", + " [ 8.1894e-02, -1.4489e-01, -4.2038e-02, -1.2794e-01, -5.0379e-02]]],\n", + "\n", + "\n", + " [[[ 3.8332e-02, -1.4270e-01, -1.9585e-01, 2.2653e-01, 1.0104e-01],\n", + " [-2.7956e-03, -1.4108e-01, -1.4694e-01, -1.3525e-01, 2.6959e-01],\n", + " [ 1.9522e-01, -1.2281e-01, -1.9173e-01, -1.8910e-02, 3.1572e-03],\n", + " [-1.0580e-01, -2.5239e-02, -5.8266e-02, -6.5815e-02, 6.6433e-02],\n", + " [ 8.9601e-02, 7.1189e-02, -2.4255e-01, 1.5746e-01, -1.4708e-01]]],\n", + "\n", + "\n", + " [[[-1.1963e-01, -1.7243e-01, -3.5174e-02, 1.4651e-01, -1.1675e-01],\n", + " [-1.3518e-01, 1.2830e-02, 7.7188e-02, 2.1060e-01, 4.0924e-02],\n", + " [-4.3364e-02, -1.9579e-01, -3.6559e-02, -6.9803e-02, 1.2380e-01],\n", + " [ 7.7321e-02, 3.7590e-02, 8.2935e-02, 2.2878e-01, 2.7859e-03],\n", + " [-1.3601e-01, -2.1167e-01, -2.3195e-01, -1.2524e-01, 1.0073e-01]]],\n", + "\n", + "\n", + " [[[-2.7300e-01, 6.8470e-02, 2.8405e-02, -4.5879e-03, -1.3735e-01],\n", + " [-8.9789e-02, -2.0209e-03, 5.0950e-03, 2.1633e-01, 2.5554e-01],\n", + " [ 5.4389e-02, 1.2262e-01, -1.5514e-01, -1.0416e-01, 1.3606e-01],\n", + " [-1.6794e-01, -2.8876e-02, 2.5900e-02, -2.4261e-02, 1.0923e-01],\n", + " [ 5.2524e-03, -4.4625e-02, -2.1327e-01, -1.7211e-01, -4.4819e-04]]],\n", + "\n", + "\n", + " [[[ 7.2378e-02, 1.5122e-01, -1.2964e-01, 4.9105e-02, -2.1639e-01],\n", + " [ 3.6547e-02, -1.5518e-02, 3.2059e-02, -3.2820e-02, 6.1231e-02],\n", + " [ 1.2514e-01, 8.0623e-02, 1.2686e-02, -1.0074e-01, 2.2836e-02],\n", + " [-2.6842e-02, 2.5578e-02, -2.5877e-01, -1.7808e-01, 7.6966e-02],\n", + " [-4.2424e-02, 4.7006e-02, -1.5486e-02, -4.2686e-02, 4.8482e-02]]],\n", + "\n", + "\n", + " [[[ 1.3081e-01, 9.9530e-02, -1.4729e-01, -1.7665e-01, -1.9757e-01],\n", + " [ 9.6603e-02, 2.2783e-02, 7.8402e-02, -2.8679e-02, 8.5252e-02],\n", + " [-1.5310e-02, 1.1605e-01, -5.8300e-02, 2.4563e-02, 1.7488e-01],\n", + " [ 6.5576e-02, -1.6325e-01, -1.1318e-01, -2.9251e-02, 6.2352e-02],\n", + " [-1.9084e-03, -1.4005e-01, -1.2363e-01, -9.7985e-02, -2.0562e-01]]],\n", + "\n", + "\n", + " [[[ 4.0772e-02, -8.2086e-02, -2.7555e-01, -3.2547e-01, -1.2226e-01],\n", + " [-5.9877e-02, 9.8567e-02, 2.5186e-01, -1.0280e-01, -2.3416e-01],\n", + " [ 8.5760e-02, 1.0896e-01, 1.4898e-01, 2.1579e-01, 8.5297e-02],\n", + " [ 5.4720e-02, -1.7226e-01, -7.2518e-02, 6.7099e-03, -1.6011e-03],\n", + " [-8.9944e-02, 1.7404e-01, -3.6985e-02, 1.8602e-01, 7.2353e-02]]],\n", + "\n", + "\n", + " [[[ 1.6276e-02, -9.6439e-02, -9.6085e-02, -2.4267e-01, -1.8521e-01],\n", + " [ 6.3310e-02, 1.7866e-01, 1.1694e-01, -1.4464e-01, -2.7711e-01],\n", + " [-2.4514e-02, 2.2222e-01, 2.1053e-01, -1.4271e-01, 8.7045e-02],\n", + " [-1.9207e-01, -5.4719e-02, -5.7775e-03, -1.0034e-05, -1.0923e-01],\n", + " [-2.4006e-02, 2.3780e-02, 1.8988e-01, 2.4734e-01, 4.8097e-02]]],\n", + "\n", + "\n", + " [[[ 1.1335e-01, -5.8451e-02, 5.2440e-02, -1.3223e-01, -2.5534e-02],\n", + " [ 9.1323e-02, -6.0707e-02, 2.3524e-01, 2.4992e-01, 8.7842e-02],\n", + " [ 2.9002e-02, 3.5379e-02, -5.9689e-02, -2.8363e-03, 1.8618e-01],\n", + " [-2.9671e-01, 8.1830e-03, 1.1076e-01, -5.4118e-02, -6.1685e-02],\n", + " [-1.7580e-01, -3.4534e-01, -3.9250e-01, -2.7569e-01, -2.6131e-01]]],\n", + "\n", + "\n", + " [[[ 1.1586e-01, -7.5997e-02, -1.4614e-01, 4.8750e-02, 1.8097e-01],\n", + " [-6.7027e-02, -1.4901e-01, -1.5614e-02, -1.0379e-02, 9.5526e-02],\n", + " [-3.2333e-02, -1.5107e-01, -1.9498e-01, 1.0083e-01, 2.2328e-01],\n", + " [-2.0692e-01, -6.3798e-02, -1.2524e-01, 1.9549e-01, 1.9682e-01],\n", + " [-2.1494e-01, 1.0475e-01, -2.4858e-02, -9.7831e-02, 1.1551e-01]]],\n", + "\n", + "\n", + " [[[ 6.3785e-02, -1.8044e-01, -1.0190e-01, -1.3588e-01, 8.5433e-02],\n", + " [ 2.0675e-01, 3.3238e-02, 9.2437e-02, 1.1799e-01, 2.1111e-01],\n", + " [-5.2138e-02, 1.5790e-01, 1.8151e-01, 8.0470e-02, 1.0131e-01],\n", + " [-4.4786e-02, 1.1771e-01, 2.1706e-02, -1.2563e-01, -2.1142e-01],\n", + " [-2.3589e-01, -2.1154e-01, -1.7890e-01, -2.7769e-01, -1.2512e-01]]],\n", + "\n", + "\n", + " [[[ 1.9133e-01, 2.4711e-01, 1.0413e-01, -1.9187e-01, -3.0991e-01],\n", + " [-1.2382e-01, 8.3641e-03, -5.6734e-02, 5.8376e-02, 2.2880e-02],\n", + " [-3.1734e-01, -1.0637e-02, -5.5974e-02, 1.0676e-01, -1.1080e-02],\n", + " [-2.2980e-01, 2.0486e-01, 1.0147e-01, 1.4484e-01, 5.2265e-02],\n", + " [ 7.4410e-02, 2.2806e-02, 8.5137e-02, -2.1809e-01, 3.1704e-02]]],\n", + "\n", + "\n", + " [[[-1.1006e-01, -2.5311e-01, 1.8925e-02, 1.0399e-02, 1.1951e-01],\n", + " [-2.1116e-01, 1.8409e-01, 3.2172e-02, 1.5962e-01, -7.9457e-02],\n", + " [ 1.1059e-01, 9.1966e-02, 1.0777e-01, -9.9132e-02, -4.4586e-02],\n", + " [-8.7919e-02, -3.7283e-02, 9.1275e-02, -3.7412e-02, 3.8875e-02],\n", + " [-4.3558e-02, 1.6196e-01, -4.7944e-03, -1.7560e-02, -1.2593e-01]]],\n", + "\n", + "\n", + " [[[ 7.6976e-02, -3.8627e-02, 1.2610e-01, 1.1994e-01, 2.1706e-03],\n", + " [ 7.4357e-02, 6.7929e-02, 3.1386e-02, 1.4606e-01, 2.1429e-01],\n", + " [-2.6569e-01, -4.2631e-04, -3.6654e-02, -3.0967e-02, -9.4961e-02],\n", + " [-2.0192e-01, -3.5423e-01, -2.5246e-01, -3.5092e-01, -2.4159e-01],\n", + " [ 1.7636e-02, 1.3744e-01, -1.0306e-01, 8.8370e-02, 7.3258e-02]]],\n", + "\n", + "\n", + " [[[ 2.0016e-01, 1.0956e-01, -5.9223e-02, 6.4871e-03, -2.4165e-01],\n", + " [ 5.6283e-02, 1.7276e-01, -2.2316e-01, -1.6699e-01, -7.0742e-02],\n", + " [ 2.6179e-01, -2.5102e-01, -2.0774e-01, -9.6413e-02, 3.4367e-02],\n", + " [-9.1882e-02, -2.9195e-01, -8.7432e-02, 1.0144e-01, -2.0559e-02],\n", + " [-2.5668e-01, -9.8016e-02, 1.1103e-01, -3.0233e-02, 1.1076e-01]]],\n", + "\n", + "\n", + " [[[ 1.0027e-03, -5.7955e-02, -2.1339e-01, -1.6729e-01, -2.0870e-01],\n", + " [ 4.2464e-02, 2.3177e-01, -6.1459e-02, -1.0905e-01, 1.7613e-02],\n", + " [-1.2282e-01, 2.1762e-01, -1.3553e-02, 2.7476e-01, 1.6703e-01],\n", + " [-5.6282e-02, 1.2731e-02, 1.0944e-01, -1.7347e-01, 4.4497e-02],\n", + " [ 5.7346e-02, -5.4657e-02, 4.8718e-02, -2.6221e-02, -2.6933e-02]]],\n", + "\n", + "\n", + " [[[ 6.7697e-02, 1.5692e-01, 2.7050e-01, 1.5936e-02, 1.7659e-01],\n", + " [-2.8899e-02, -1.4866e-01, 3.1838e-02, 1.0903e-01, 1.2292e-01],\n", + " [-1.3608e-01, -4.3198e-03, -9.8925e-02, -4.5599e-02, 1.3452e-01],\n", + " [-5.1435e-02, -2.3815e-01, -2.4151e-01, -4.8556e-02, 1.3825e-01],\n", + " [-1.2823e-01, 8.9324e-03, -1.5313e-01, -2.2933e-01, -3.4081e-02]]],\n", + "\n", + "\n", + " [[[-1.8396e-01, -6.8774e-03, -1.6675e-01, 7.1980e-03, 1.9922e-02],\n", + " [ 1.3416e-01, -1.1450e-01, -1.5277e-01, -6.5713e-02, -9.5435e-02],\n", + " [ 1.5406e-01, -9.1235e-02, -1.0880e-01, -7.1603e-02, -9.5575e-02],\n", + " [ 2.1772e-01, 8.4073e-02, -2.5264e-01, -2.1428e-01, 1.9537e-01],\n", + " [ 1.3124e-01, 7.9532e-02, -2.4044e-01, -1.5717e-01, 1.6562e-01]]],\n", + "\n", + "\n", + " [[[ 1.1849e-01, -5.0517e-03, -1.8900e-01, 1.8093e-02, 6.4660e-02],\n", + " [-1.5309e-01, -2.0106e-01, -8.6551e-02, 5.2692e-03, 1.5448e-01],\n", + " [-3.0727e-01, 4.9703e-02, -4.7637e-02, 2.9111e-01, -1.3173e-01],\n", + " [-8.5167e-02, -1.3540e-01, 2.9235e-01, 3.7895e-03, -9.4651e-02],\n", + " [-6.0694e-02, 9.6936e-02, 1.0533e-01, -6.1769e-02, -1.8086e-01]]]],\n", + " device='cuda:0')\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "# 2. Prepare config_list for pruning" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 7, + "source": [ + "# we will prune 50% weights in `conv1`.\n", + "\n", + "config_list = [{\n", + " 'sparsity': 0.5,\n", + " 'op_types': ['Conv2d'],\n", + " 'op_names': ['conv1']\n", + "}]" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "# 3. Choose a pruner and pruning" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 8, + "source": [ + "# use l1filter pruner to prune the model\n", + "\n", + "from nni.algorithms.compression.pytorch.pruning import L1FilterPruner\n", + "\n", + "# Note that if you use a compressor that need you to pass a optimizer,\n", + "# you need a new optimizer instead of you have used above, because NNI might modify the optimizer.\n", + "# And of course this modified optimizer can not be used in finetuning.\n", + "pruner = L1FilterPruner(model, config_list)" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 9, + "source": [ + "# we can find the `conv1` has been wrapped, the origin `conv1` changes to `conv1.module`.\n", + "# the weight of conv1 will modify by `weight * mask` in `forward()`. The initial mask is a `ones_like(weight)` tensor.\n", + "\n", + "[print('op_name: {}\\nop_type: {}\\n'.format(name, type(module))) for name, module in model.named_modules()]" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "op_name: \n", + "op_type: \n", + "\n", + "op_name: conv1\n", + "op_type: \n", + "\n", + "op_name: conv1.module\n", + "op_type: \n", + "\n", + "op_name: conv2\n", + "op_type: \n", + "\n", + "op_name: fc1\n", + "op_type: \n", + "\n", + "op_name: fc2\n", + "op_type: \n", + "\n", + "op_name: relu1\n", + "op_type: \n", + "\n", + "op_name: relu2\n", + "op_type: \n", + "\n", + "op_name: relu3\n", + "op_type: \n", + "\n", + "op_name: max_pool1\n", + "op_type: \n", + "\n", + "op_name: max_pool2\n", + "op_type: \n", + "\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "[None, None, None, None, None, None, None, None, None, None, None]" + ] + }, + "metadata": {}, + "execution_count": 9 + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 10, + "source": [ + "# compress the model, the mask will be updated.\n", + "\n", + "pruner.compress()" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "NaiveModel(\n", + " (conv1): PrunerModuleWrapper(\n", + " (module): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n", + " )\n", + " (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))\n", + " (fc1): Linear(in_features=800, out_features=500, bias=True)\n", + " (fc2): Linear(in_features=500, out_features=10, bias=True)\n", + " (relu1): ReLU6()\n", + " (relu2): ReLU6()\n", + " (relu3): ReLU6()\n", + " (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 10 + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 11, + "source": [ + "# show the mask size of `conv1`\n", + "\n", + "print(model.conv1.weight_mask.size())" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "torch.Size([20, 1, 5, 5])\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 12, + "source": [ + "# show the mask of `conv1`\n", + "\n", + "print(model.conv1.weight_mask)" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor([[[[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]]],\n", + "\n", + "\n", + " [[[1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.],\n", + " [1., 1., 1., 1., 1.]]],\n", + "\n", + "\n", + " [[[0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.],\n", + " [0., 0., 0., 0., 0.]]]], device='cuda:0')\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 13, + "source": [ + "# use a dummy input to apply the sparsify.\n", + "\n", + "model(torch.rand(1, 1, 28, 28).to(device))\n", + "\n", + "# the weights of `conv1` have been sparsified.\n", + "\n", + "print(model.conv1.module.weight.data)" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "tensor([[[[ 1.5338e-01, -1.1766e-01, -2.6654e-01, -2.9445e-02, -1.4650e-01],\n", + " [-1.8796e-01, -2.9882e-01, 6.9725e-02, 2.1561e-01, 6.5688e-02],\n", + " [ 1.5274e-01, -9.8471e-03, 3.2303e-01, 1.3472e-03, 1.7235e-01],\n", + " [ 1.1804e-01, 2.2535e-01, -8.3370e-02, -3.4553e-02, -1.2529e-01],\n", + " [-6.6012e-02, -2.0272e-02, -1.8797e-01, -4.6882e-02, -8.3206e-02]]],\n", + "\n", + "\n", + " [[[-0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],\n", + " [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", + "\n", + "\n", + " [[[ 3.8332e-02, -1.4270e-01, -1.9585e-01, 2.2653e-01, 1.0104e-01],\n", + " [-2.7956e-03, -1.4108e-01, -1.4694e-01, -1.3525e-01, 2.6959e-01],\n", + " [ 1.9522e-01, -1.2281e-01, -1.9173e-01, -1.8910e-02, 3.1572e-03],\n", + " [-1.0580e-01, -2.5239e-02, -5.8266e-02, -6.5815e-02, 6.6433e-02],\n", + " [ 8.9601e-02, 7.1189e-02, -2.4255e-01, 1.5746e-01, -1.4708e-01]]],\n", + "\n", + "\n", + " [[[-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00]]],\n", + "\n", + "\n", + " [[[-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", + "\n", + "\n", + " [[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],\n", + " [ 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00]]],\n", + "\n", + "\n", + " [[[ 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", + "\n", + "\n", + " [[[ 4.0772e-02, -8.2086e-02, -2.7555e-01, -3.2547e-01, -1.2226e-01],\n", + " [-5.9877e-02, 9.8567e-02, 2.5186e-01, -1.0280e-01, -2.3416e-01],\n", + " [ 8.5760e-02, 1.0896e-01, 1.4898e-01, 2.1579e-01, 8.5297e-02],\n", + " [ 5.4720e-02, -1.7226e-01, -7.2518e-02, 6.7099e-03, -1.6011e-03],\n", + " [-8.9944e-02, 1.7404e-01, -3.6985e-02, 1.8602e-01, 7.2353e-02]]],\n", + "\n", + "\n", + " [[[ 1.6276e-02, -9.6439e-02, -9.6085e-02, -2.4267e-01, -1.8521e-01],\n", + " [ 6.3310e-02, 1.7866e-01, 1.1694e-01, -1.4464e-01, -2.7711e-01],\n", + " [-2.4514e-02, 2.2222e-01, 2.1053e-01, -1.4271e-01, 8.7045e-02],\n", + " [-1.9207e-01, -5.4719e-02, -5.7775e-03, -1.0034e-05, -1.0923e-01],\n", + " [-2.4006e-02, 2.3780e-02, 1.8988e-01, 2.4734e-01, 4.8097e-02]]],\n", + "\n", + "\n", + " [[[ 1.1335e-01, -5.8451e-02, 5.2440e-02, -1.3223e-01, -2.5534e-02],\n", + " [ 9.1323e-02, -6.0707e-02, 2.3524e-01, 2.4992e-01, 8.7842e-02],\n", + " [ 2.9002e-02, 3.5379e-02, -5.9689e-02, -2.8363e-03, 1.8618e-01],\n", + " [-2.9671e-01, 8.1830e-03, 1.1076e-01, -5.4118e-02, -6.1685e-02],\n", + " [-1.7580e-01, -3.4534e-01, -3.9250e-01, -2.7569e-01, -2.6131e-01]]],\n", + "\n", + "\n", + " [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00]]],\n", + "\n", + "\n", + " [[[ 6.3785e-02, -1.8044e-01, -1.0190e-01, -1.3588e-01, 8.5433e-02],\n", + " [ 2.0675e-01, 3.3238e-02, 9.2437e-02, 1.1799e-01, 2.1111e-01],\n", + " [-5.2138e-02, 1.5790e-01, 1.8151e-01, 8.0470e-02, 1.0131e-01],\n", + " [-4.4786e-02, 1.1771e-01, 2.1706e-02, -1.2563e-01, -2.1142e-01],\n", + " [-2.3589e-01, -2.1154e-01, -1.7890e-01, -2.7769e-01, -1.2512e-01]]],\n", + "\n", + "\n", + " [[[ 1.9133e-01, 2.4711e-01, 1.0413e-01, -1.9187e-01, -3.0991e-01],\n", + " [-1.2382e-01, 8.3641e-03, -5.6734e-02, 5.8376e-02, 2.2880e-02],\n", + " [-3.1734e-01, -1.0637e-02, -5.5974e-02, 1.0676e-01, -1.1080e-02],\n", + " [-2.2980e-01, 2.0486e-01, 1.0147e-01, 1.4484e-01, 5.2265e-02],\n", + " [ 7.4410e-02, 2.2806e-02, 8.5137e-02, -2.1809e-01, 3.1704e-02]]],\n", + "\n", + "\n", + " [[[-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", + "\n", + "\n", + " [[[ 7.6976e-02, -3.8627e-02, 1.2610e-01, 1.1994e-01, 2.1706e-03],\n", + " [ 7.4357e-02, 6.7929e-02, 3.1386e-02, 1.4606e-01, 2.1429e-01],\n", + " [-2.6569e-01, -4.2631e-04, -3.6654e-02, -3.0967e-02, -9.4961e-02],\n", + " [-2.0192e-01, -3.5423e-01, -2.5246e-01, -3.5092e-01, -2.4159e-01],\n", + " [ 1.7636e-02, 1.3744e-01, -1.0306e-01, 8.8370e-02, 7.3258e-02]]],\n", + "\n", + "\n", + " [[[ 2.0016e-01, 1.0956e-01, -5.9223e-02, 6.4871e-03, -2.4165e-01],\n", + " [ 5.6283e-02, 1.7276e-01, -2.2316e-01, -1.6699e-01, -7.0742e-02],\n", + " [ 2.6179e-01, -2.5102e-01, -2.0774e-01, -9.6413e-02, 3.4367e-02],\n", + " [-9.1882e-02, -2.9195e-01, -8.7432e-02, 1.0144e-01, -2.0559e-02],\n", + " [-2.5668e-01, -9.8016e-02, 1.1103e-01, -3.0233e-02, 1.1076e-01]]],\n", + "\n", + "\n", + " [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00],\n", + " [ 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [ 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", + "\n", + "\n", + " [[[ 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]]],\n", + "\n", + "\n", + " [[[-1.8396e-01, -6.8774e-03, -1.6675e-01, 7.1980e-03, 1.9922e-02],\n", + " [ 1.3416e-01, -1.1450e-01, -1.5277e-01, -6.5713e-02, -9.5435e-02],\n", + " [ 1.5406e-01, -9.1235e-02, -1.0880e-01, -7.1603e-02, -9.5575e-02],\n", + " [ 2.1772e-01, 8.4073e-02, -2.5264e-01, -2.1428e-01, 1.9537e-01],\n", + " [ 1.3124e-01, 7.9532e-02, -2.4044e-01, -1.5717e-01, 1.6562e-01]]],\n", + "\n", + "\n", + " [[[ 0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, -0.0000e+00, 0.0000e+00, -0.0000e+00],\n", + " [-0.0000e+00, -0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00],\n", + " [-0.0000e+00, 0.0000e+00, 0.0000e+00, -0.0000e+00, -0.0000e+00]]]],\n", + " device='cuda:0')\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 14, + "source": [ + "# export the sparsified model state to './pruned_naive_mnist_l1filter.pth'.\n", + "# export the mask to './mask_naive_mnist_l1filter.pth'.\n", + "\n", + "pruner.export_model(model_path='pruned_naive_mnist_l1filter.pth', mask_path='mask_naive_mnist_l1filter.pth')" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[2021-07-26 22:26:05] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to pruned_naive_mnist_l1filter.pth\n", + "[2021-07-26 22:26:05] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to mask_naive_mnist_l1filter.pth\n" + ] + } + ], + "metadata": { + "scrolled": true + } + }, + { + "cell_type": "markdown", + "source": [ + "# 4. Speed Up" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 15, + "source": [ + "# If you use a wrapped model, don't forget to unwrap it.\n", + "\n", + "pruner._unwrap_model()\n", + "\n", + "# the model has been unwrapped.\n", + "\n", + "print(model)" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "NaiveModel(\n", + " (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))\n", + " (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))\n", + " (fc1): Linear(in_features=800, out_features=500, bias=True)\n", + " (fc2): Linear(in_features=500, out_features=10, bias=True)\n", + " (relu1): ReLU6()\n", + " (relu2): ReLU6()\n", + " (relu3): ReLU6()\n", + " (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + ")\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 16, + "source": [ + "from nni.compression.pytorch import ModelSpeedup\n", + "\n", + "m_speedup = ModelSpeedup(model, dummy_input=torch.rand(10, 1, 28, 28).to(device), masks_file='mask_naive_mnist_l1filter.pth')\n", + "m_speedup.speedup_model()" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + ":22: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!\n", + " x = x.view(-1, x.size()[1:].numel())\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) start to speed up the model\n", + "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) {'conv1': 1, 'conv2': 1}\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) dim0 sparsity: 0.500000\n", + "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) dim1 sparsity: 0.000000\n", + "[2021-07-26 22:26:18] INFO (FixMaskConflict/MainThread) Dectected conv prune dim\" 0\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) infer module masks...\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for conv1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for max_pool1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for conv2\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu2\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for max_pool2\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::view.9\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.jit_translate/MainThread) View Module output size: [-1, 800]\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for fc1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for relu3\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for fc2\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update mask for .aten::log_softmax.10\n", + "[2021-07-26 22:26:18] ERROR (nni.compression.pytorch.speedup.jit_translate/MainThread) aten::log_softmax is not Supported! Please report an issue at https://github.com/microsoft/nni. Thanks~\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for .aten::log_softmax.10\n", + "[2021-07-26 22:26:18] WARNING (nni.compression.pytorch.speedup.compressor/MainThread) Note: .aten::log_softmax.10 does not have corresponding mask inference object\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for fc2\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the fc2\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu3\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu3\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for fc1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the fc1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for .aten::view.9\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the .aten::view.9\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for max_pool2\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the max_pool2\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu2\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu2\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for conv2\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the conv2\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for max_pool1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the max_pool1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for relu1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the relu1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update indirect sparsity for conv1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Update the indirect sparsity for the conv1\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) resolve the mask conflict\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace compressed modules...\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: conv1, op_type: Conv2d)\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu1, op_type: ReLU6)\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: max_pool1, op_type: MaxPool2d)\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: conv2, op_type: Conv2d)\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu2, op_type: ReLU6)\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: max_pool2, op_type: MaxPool2d)\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::view.9, op_type: aten::view) which is func type\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: fc1, op_type: Linear)\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace linear with new in_features: 800, out_features: 500\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: relu3, op_type: ReLU6)\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) replace module (name: fc2, op_type: Linear)\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compress_modules/MainThread) replace linear with new in_features: 500, out_features: 10\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) Warning: cannot replace (name: .aten::log_softmax.10, op_type: aten::log_softmax) which is func type\n", + "[2021-07-26 22:26:18] INFO (nni.compression.pytorch.speedup.compressor/MainThread) speedup done\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 17, + "source": [ + "# the `conv1` has been replace from `Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))` to `Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))`\n", + "# and the following layer `conv2` has also changed because the input channel of `conv2` should aware the output channel of `conv1`.\n", + "\n", + "print(model)" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "NaiveModel(\n", + " (conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n", + " (conv2): Conv2d(10, 50, kernel_size=(5, 5), stride=(1, 1))\n", + " (fc1): Linear(in_features=800, out_features=500, bias=True)\n", + " (fc2): Linear(in_features=500, out_features=10, bias=True)\n", + " (relu1): ReLU6()\n", + " (relu2): ReLU6()\n", + " (relu3): ReLU6()\n", + " (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + ")\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 18, + "source": [ + "# finetune the model to recover the accuracy.\n", + "\n", + "optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n", + "\n", + "for epoch in range(0, 1):\n", + " trainer(model, optimizer, criterion, epoch)\n", + " evaluator(model)" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Train Epoch: 0 [0/60000 (0%)]\tLoss: 0.306930\n", + "Train Epoch: 0 [6400/60000 (11%)]\tLoss: 0.045807\n", + "Train Epoch: 0 [12800/60000 (21%)]\tLoss: 0.049293\n", + "Train Epoch: 0 [19200/60000 (32%)]\tLoss: 0.031464\n", + "Train Epoch: 0 [25600/60000 (43%)]\tLoss: 0.005392\n", + "Train Epoch: 0 [32000/60000 (53%)]\tLoss: 0.005652\n", + "Train Epoch: 0 [38400/60000 (64%)]\tLoss: 0.040619\n", + "Train Epoch: 0 [44800/60000 (75%)]\tLoss: 0.016515\n", + "Train Epoch: 0 [51200/60000 (85%)]\tLoss: 0.092886\n", + "Train Epoch: 0 [57600/60000 (96%)]\tLoss: 0.041380\n", + "\n", + "Test set: Average loss: 0.0257, Accuracy: 9917/10000 (99%)\n", + "\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "# 5. Prepare config_list for quantization" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 19, + "source": [ + "config_list = [{\n", + " 'quant_types': ['weight'],\n", + " 'quant_bits': {'weight': 8},\n", + " 'op_names': ['conv1', 'conv2']\n", + "}]" + ], + "outputs": [], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "# 6. Choose a quantizer and quantizing" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 20, + "source": [ + "from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer\n", + "\n", + "quantizer = QAT_Quantizer(model, config_list, optimizer)\n", + "quantizer.compress()" + ], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "NaiveModel(\n", + " (conv1): QuantizerModuleWrapper(\n", + " (module): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))\n", + " )\n", + " (conv2): QuantizerModuleWrapper(\n", + " (module): Conv2d(10, 50, kernel_size=(5, 5), stride=(1, 1))\n", + " )\n", + " (fc1): Linear(in_features=800, out_features=500, bias=True)\n", + " (fc2): Linear(in_features=500, out_features=10, bias=True)\n", + " (relu1): ReLU6()\n", + " (relu2): ReLU6()\n", + " (relu3): ReLU6()\n", + " (max_pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + " (max_pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)\n", + ")" + ] + }, + "metadata": {}, + "execution_count": 20 + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 21, + "source": [ + "# finetune the model for calibration.\n", + "\n", + "for epoch in range(0, 1):\n", + " trainer(model, optimizer, criterion, epoch)\n", + " evaluator(model)" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Train Epoch: 0 [0/60000 (0%)]\tLoss: 0.004960\n", + "Train Epoch: 0 [6400/60000 (11%)]\tLoss: 0.036269\n", + "Train Epoch: 0 [12800/60000 (21%)]\tLoss: 0.018744\n", + "Train Epoch: 0 [19200/60000 (32%)]\tLoss: 0.021916\n", + "Train Epoch: 0 [25600/60000 (43%)]\tLoss: 0.003095\n", + "Train Epoch: 0 [32000/60000 (53%)]\tLoss: 0.003947\n", + "Train Epoch: 0 [38400/60000 (64%)]\tLoss: 0.032094\n", + "Train Epoch: 0 [44800/60000 (75%)]\tLoss: 0.017358\n", + "Train Epoch: 0 [51200/60000 (85%)]\tLoss: 0.083886\n", + "Train Epoch: 0 [57600/60000 (96%)]\tLoss: 0.040433\n", + "\n", + "Test set: Average loss: 0.0247, Accuracy: 9917/10000 (99%)\n", + "\n" + ] + } + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": 22, + "source": [ + "# export the sparsified model state to './quantized_naive_mnist_l1filter.pth'.\n", + "# export the calibration config to './calibration_naive_mnist_l1filter.pth'.\n", + "\n", + "quantizer.export_model(model_path='quantized_naive_mnist_l1filter.pth', calibration_path='calibration_naive_mnist_l1filter.pth')" + ], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "[2021-07-26 22:34:41] INFO (nni.compression.pytorch.compressor/MainThread) Model state_dict saved to quantized_naive_mnist_l1filter.pth\n", + "[2021-07-26 22:34:41] INFO (nni.compression.pytorch.compressor/MainThread) Mask dict saved to calibration_naive_mnist_l1filter.pth\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "{'conv1': {'weight_bit': 8,\n", + " 'tracked_min_input': -0.42417848110198975,\n", + " 'tracked_max_input': 2.8212687969207764},\n", + " 'conv2': {'weight_bit': 8,\n", + " 'tracked_min_input': 0.0,\n", + " 'tracked_max_input': 4.246923446655273}}" + ] + }, + "metadata": {}, + "execution_count": 22 + } + ], + "metadata": {} + }, + { + "cell_type": "markdown", + "source": [ + "# 7. Speed Up" + ], + "metadata": {} + }, + { + "cell_type": "code", + "execution_count": null, + "source": [ + "# speed up with tensorRT\n", + "\n", + "engine = ModelSpeedupTensorRT(model, (32, 1, 28, 28), config=calibration_config, batchsize=32)\n", + "engine.compress()" + ], + "outputs": [], + "metadata": {} + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file