Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Model Compression] auto compression #3631

Merged
merged 18 commits into from
May 28, 2021
101 changes: 101 additions & 0 deletions docs/en_US/Compression/AutoCompression.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
Auto Compression with NNI Experiment
====================================

This approach is mainly a combination of compression and nni experiments.
It allows users to define compressor search space, including types, parameters, etc.
Its using experience is similar to launch the NNI experiment from python.
Copy link
Contributor

@linbinskn linbinskn May 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just personal concern, maybe not correct. I think this doc mainly focuses on how to use this new feature, but it doesn't tell users what things this feature can actually help them do clearly and users will be confused or misunderstand.
In my opinion, this feature can help users try different compression algorithms including pruning algorithms and quantization algorithms by adding them into our 'search space'. By using it, users can easily choose different compression algorithms and apply them to model to get feedback easily and automatically. But If I am a brand new user, after reading this doc, I can't get this key point and still miss some important information such as

  • what 'search space', 'types' and parameters mean?
  • what is the meaning of 'auto compress'?
  • what is the meaning of 'combination of compression and nni experiments'?
  • If I want to try different compression algorithms, will they be applied together(apply pruning algorithm and quantization to model simultaneously) or single sequentially?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion, will add more descriptions and explanations about auto compress and what this can help.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rewrite the doc and welcome more comments.

The main differences are as follows:

* Use a generator to help generate search space object.
* Need to implement the abstract class ``AbstractAutoCompressModule`` as ``AutoCompressModule``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what is the meaning of this line?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

modified to be more readable.

* No need to set ``trial_command``, additional need to set ``auto_compress_module_file_name``.

Generate search space
---------------------

Due to the extensive use of nested search space, we recommend using generator to configure search space.
The following is an example. Using ``add_pruner_config()`` and ``add_quantizer_config()`` add subconfig, then ``dumps()`` search space dict.

.. code-block:: python

from nni.algorithms.compression.pytorch.auto_compress import AutoCompressSearchSpaceGenerator

generator = AutoCompressSearchSpaceGenerator()
generator.add_pruner_config('level', [
{
"sparsity": {
"_type": "uniform",
"_value": [0.01, 0.99]
},
'op_types': ['default']
}
])
generator.add_quantizer_config('qat', [
{
'quant_types': ['weight', 'output'],
'quant_bits': {
'weight': 8,
'output': 8
},
'op_types': ['Conv2d', 'Linear']
}])

search_space = generator.dumps()

Now we support the following pruners and quantizers:

.. code-block:: python

PRUNER_DICT = {
'level': LevelPruner,
'slim': SlimPruner,
'l1': L1FilterPruner,
'l2': L2FilterPruner,
'fpgm': FPGMPruner,
'taylorfo': TaylorFOWeightFilterPruner,
'apoz': ActivationAPoZRankFilterPruner,
'mean_activation': ActivationMeanRankFilterPruner
}

QUANTIZER_DICT = {
'naive': NaiveQuantizer,
'qat': QAT_Quantizer,
'dorefa': DoReFaQuantizer,
'bnn': BNNQuantizer
}

Implement ``AbstractAutoCompressModule``
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

-> Provide user model

----------------------------------------

This class will be called by ``AutoCompressEngine`` on training service.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do not mention this at the beginning

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

Users need to implement at least ``model()`` and ``evaluator``, and naming the class as ``AutoCompressModule``.
The path of file that contains the ``AutoCompressModule`` needs to be specified in experiment config.
The full abstract interface refers to :githublink:`interface.py <nni/algorithms/compression/pytorch/auto_compress/interface.py>`.
An example of ``AutoCompressModule`` implementation refers to :githublink:`auto_compress_module.py <examples/model_compress/auto_compress/torch/auto_compress_module.py>`.

Launch NNI experiment
---------------------

Similar to launch from python, the difference is no need to set ``trial_command``.
By default, ``auto_compress_module_file_name`` is set as ``./auto_compress_module.py``.
Remember that ``auto_compress_module_file_name`` is the relative file path under ``trial_code_directory``.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it has to be relative path, why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need anymore


.. code-block:: python

from pathlib import Path
from nni.algorithms.compression.pytorch.auto_compress import AutoCompressExperiment

experiment = AutoCompressExperiment('local')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

exp = AutoCompressExperiment('local', AutoCompressModule)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactor it.

experiment.config.experiment_name = 'auto compress torch example'
experiment.config.trial_concurrency = 1
experiment.config.max_trial_number = 10
experiment.config.search_space = search_space
experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.tuner.name = 'TPE'
experiment.config.tuner.class_args['optimize_mode'] = 'maximize'
experiment.config.training_service.use_active_gpu = True
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I am not mistaken, this feature is for users who want to try different model compression algorithms without many effort. I think some of they would be confused about the experiment config setting if they are not familiar with NNI. Maybe we should tell user what these experiment parameters are or refer to related NNI doc which introduces parameters in detail.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good suggestion, trying to refactor and use the original config for less effort.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactor and now we can use experiment = AutoCompressExperiment(AutoCompressModule, 'local'), no need to use a specific config.


# the relative file path under trial_code_directory, which contains the class AutoCompressModule
experiment.config.auto_compress_module_file_name = './auto_compress_module.py'
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to put this config as AutoCompressExperiment's input argument

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactor it


experiment.run(8088)
2 changes: 1 addition & 1 deletion docs/en_US/Compression/AutoPruningUsingTuners.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Automatic Model Pruning using NNI Tuners

It's convenient to implement auto model pruning with NNI compression and NNI tuners

First, model compression with NNI
First, model pruning with NNI
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can directly remove this file

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

---------------------------------

You can easily compress a model with NNI compression. Take pruning for example, you can prune a pretrained model with L2FilterPruner like this
Expand Down
3 changes: 2 additions & 1 deletion docs/en_US/Compression/advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ Advanced Usage

Framework <./Framework>
Customize a new algorithm <./CustomizeCompressor>
Automatic Model Compression <./AutoPruningUsingTuners>
Automatic Model Pruning <./AutoPruningUsingTuners>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

Automatic Model Compression (Beta) <./AutoCompression>
120 changes: 120 additions & 0 deletions examples/model_compress/auto_compress/torch/auto_compress_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR
from torchvision import datasets, transforms

from nni.algorithms.compression.pytorch.auto_compress import AbstractAutoCompressModule

torch.manual_seed(1)

class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output

_use_cuda = torch.cuda.is_available()

_train_kwargs = {'batch_size': 64}
_test_kwargs = {'batch_size': 1000}
if _use_cuda:
_cuda_kwargs = {'num_workers': 1,
'pin_memory': True,
'shuffle': True}
_train_kwargs.update(_cuda_kwargs)
_test_kwargs.update(_cuda_kwargs)

_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])

_dataset1 = datasets.MNIST('./data', train=True, download=True, transform=_transform)
_dataset2 = datasets.MNIST('./data', train=False, transform=_transform)
_train_loader = torch.utils.data.DataLoader(_dataset1, **_train_kwargs)
_test_loader = torch.utils.data.DataLoader(_dataset2, **_test_kwargs)

_device = torch.device("cuda" if _use_cuda else "cpu")
_epoch = 2

def _train(model, optimizer):
model.train()
for data, target in _train_loader:
data, target = data.to(_device), target.to(_device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()

def _test(model):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in _test_loader:
data, target = data.to(_device), target.to(_device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(_test_loader.dataset)
acc = 100 * correct / len(_test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(_test_loader.dataset), acc))
return acc

_model = LeNet().to(_device)

_pre_train_optimizer = optim.Adadelta(_model.parameters(), lr=1)
_scheduler = StepLR(_pre_train_optimizer, step_size=1, gamma=0.7)
for _ in range(_epoch):
_train(_model, _pre_train_optimizer)
_test(_model)
_scheduler.step()

class AutoCompressModule(AbstractAutoCompressModule):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is this module used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This module is implemented by user, and will import by import_ in AutoCompressEngine.trial_execute_compress().

It is strange to fix the code file name auto_compress_module.py, I will modify this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do users have to use the name "AutoCompressModule"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactor and no need to fix name AutoCompressModule.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please add docstring for the member functions

@classmethod
def model(cls) -> nn.Module:
return _model

@classmethod
def optimizer(cls) -> torch.optim.Optimizer:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems you do not mention optimizer in doc? do users need to implement this function?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rewritten the doc and mention optimizer and other interfaces.

return torch.optim.SGD(_model.parameters(), lr=0.01)

@classmethod
def evaluator(cls) -> Callable[[nn.Module], float]:
return _test

@classmethod
def finetune_trainer(cls, compressor_type: str, algorithm_name: str) -> Optional[Callable[[nn.Module, optim.Optimizer], None]]:
def _trainer(model, optimizer):
for _ in range(_epoch):
_train(model, optimizer)
return _trainer
51 changes: 51 additions & 0 deletions examples/model_compress/auto_compress/torch/auto_compress_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pathlib import Path

from nni.algorithms.compression.pytorch.auto_compress import AutoCompressExperiment, AutoCompressSearchSpaceGenerator

generator = AutoCompressSearchSpaceGenerator()
generator.add_pruner_config('level', [
{
"sparsity": {
"_type": "uniform",
"_value": [0.01, 0.99]
},
'op_types': ['default']
}
])
generator.add_pruner_config('l1', [
{
"sparsity": {
"_type": "uniform",
"_value": [0.01, 0.99]
},
'op_types': ['Conv2d']
}
])
generator.add_quantizer_config('qat', [
{
'quant_types': ['weight', 'output'],
'quant_bits': {
'weight': 8,
'output': 8
},
'op_types': ['Conv2d', 'Linear']
}])
search_space = generator.dumps()

experiment = AutoCompressExperiment('local')
experiment.config.experiment_name = 'auto compress torch example'
experiment.config.trial_concurrency = 1
experiment.config.max_trial_number = 10
experiment.config.search_space = search_space
experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.tuner.name = 'TPE'
experiment.config.tuner.class_args['optimize_mode'] = 'maximize'
experiment.config.training_service.use_active_gpu = True

# the relative file path under trial_code_directory, which contains the class AutoCompressModule
experiment.config.auto_compress_module_file_name = './auto_compress_module.py'

experiment.run(8088)
6 changes: 6 additions & 0 deletions nni/algorithms/compression/pytorch/auto_compress/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .experiment import AutoCompressExperimentConfig, AutoCompressExperiment
from .interface import AbstractAutoCompressModule
from .utils import AutoCompressSearchSpaceGenerator
Loading