This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Model Compression] auto compression #3631

Merged: 18 commits, May 28, 2021

Conversation

@J-shang (Contributor) commented May 11, 2021

Depends on #3507

  • AutoCompressExperiment / AutoCompressionExperiment ?

  • AutoCompressExperiment(auto_compress_module, config, training_service) / experiment.config.module_file_path = '...' ?

@ultmaster linked an issue May 12, 2021 that may be closed by this pull request
config_dict[(quant_types, op_types, op_names)][var_name] = value

config_list = []
for key, config in config_dict.items():
Contributor:

If quant_bits is set to {'quant_bits': {'weight': 8, 'output': 8}} in the initial quantizer config, the converted config that is added into config_list here would not contain the key quant_bits. Is it still correct?

Contributor Author:

It does contain quant_bits, but I found that this kind of nested search space has a bug; I will fix it.

Contributor Author:

Fixed it.
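For illustration, the conversion being discussed in this thread could look roughly like the following sketch, which flattens a dict keyed by (quant_types, op_types, op_names) tuples into an NNI-style config_list while keeping a nested quant_bits value intact. The function name and details are hypothetical, not the PR's actual implementation.

```python
# Hypothetical sketch of the config_dict -> config_list conversion discussed
# above. Each key is a (quant_types, op_types, op_names) tuple; everything
# else in the value dict, including nested quant_bits, is copied verbatim.

def convert_config_dict(config_dict):
    config_list = []
    for (quant_types, op_types, op_names), config in config_dict.items():
        entry = {
            'quant_types': list(quant_types),
            'op_types': list(op_types),
        }
        if op_names:
            entry['op_names'] = list(op_names)
        # copy the remaining hyperparameters, including nested quant_bits
        entry.update(config)
        config_list.append(entry)
    return config_list

config_dict = {
    (('weight', 'output'), ('Conv2d',), ()): {
        'quant_bits': {'weight': 8, 'output': 8},
    },
}
converted = convert_config_dict(config_dict)
```

With this shape of conversion, the nested quant_bits key survives into the resulting config_list entry, which is the property the reviewer was asking about.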



@classmethod
@abstractmethod
def finetune_trainer(cls, compressor_type: str, algorithm_name: str) -> Optional[Callable[[Module, Optimizer], None]]:
Contributor:

What is the difference between trainer and finetune_trainer?

Contributor Author:

Some pruners need a trainer to help during the pruning process. I am not sure whether it differs from the trainer used in fine-tuning, so I reserved an interface here. If it is not necessary, I will remove it.
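As a rough illustration of the distinction described above, a user module might implement the two hooks like this. This is a plain-Python sketch (torch types elided): the method names follow the interface quoted in this PR, while the class name and the bodies are placeholders only.

```python
# Hypothetical sketch: `trainer` drives training *during* pruning for pruners
# that need it, while `finetune_trainer` retrains the model *after*
# compression. Only the hook names come from the PR; everything else is
# illustrative.

from abc import ABC, abstractmethod
from typing import Callable, Optional


class AbstractAutoCompressModule(ABC):
    @classmethod
    @abstractmethod
    def trainer(cls, compressor_type: str, algorithm_name: str) -> Optional[Callable]:
        """Trainer used inside the pruning process (e.g. by iterative pruners)."""

    @classmethod
    @abstractmethod
    def finetune_trainer(cls, compressor_type: str, algorithm_name: str) -> Optional[Callable]:
        """Trainer used to fine-tune the model after compression."""


class MyAutoCompressModule(AbstractAutoCompressModule):
    @classmethod
    def trainer(cls, compressor_type, algorithm_name):
        # only pruning algorithms get an in-loop trainer in this sketch
        if compressor_type == 'pruner':
            return lambda model, optimizer: None  # placeholder training step
        return None

    @classmethod
    def finetune_trainer(cls, compressor_type, algorithm_name):
        return lambda model, optimizer: None  # placeholder fine-tune step
```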

config_list = cls._convert_config_list(compressor_type, converted_config_dict)

model, evaluator, optimizer = auto_compress_module.model(), auto_compress_module.evaluator(), auto_compress_module.optimizer()
trainer = auto_compress_module.trainer(compressor_type, algorithm_name)
Contributor:

Just curious: I haven't found the implementation of trainer in the example. So in the shown case, do we only fine-tune the model without training it during the compression pipeline? Or do they represent the same thing, so we can ignore 'trainer'?
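For context, the flow implied by the snippet above could be sketched as follows. Apart from model, evaluator, optimizer, trainer, and trial_execute_compress, which are named in this PR, every identifier here is hypothetical, and the dummy classes exist only to make the sketch self-contained.

```python
# Illustrative sketch (not the actual PR code) of the trial-side flow: the
# engine pulls model/evaluator/optimizer/trainer from the user's module,
# builds the chosen compressor, and returns the evaluator's metric.

def trial_execute_compress(auto_compress_module, compressor_cls,
                           compressor_type, algorithm_name, config_list):
    model = auto_compress_module.model()
    evaluator = auto_compress_module.evaluator()
    optimizer = auto_compress_module.optimizer()
    trainer = auto_compress_module.trainer(compressor_type, algorithm_name)

    compressor = compressor_cls(model, config_list)
    compressed_model = compressor.compress()
    if trainer is not None:
        trainer(compressed_model, optimizer)  # e.g. iterative pruning steps
    return evaluator(compressed_model)


class _DummyModule:
    """Stand-in for a user AutoCompressModule, for illustration only."""
    @classmethod
    def model(cls):
        return {'weights': [1, 2, 3]}

    @classmethod
    def evaluator(cls):
        return lambda model: 0.9

    @classmethod
    def optimizer(cls):
        return None

    @classmethod
    def trainer(cls, compressor_type, algorithm_name):
        return None


class _DummyCompressor:
    """Stand-in compressor that returns the model unchanged."""
    def __init__(self, model, config_list):
        self._model = model

    def compress(self):
        return self._model


score = trial_execute_compress(_DummyModule, _DummyCompressor, 'pruner', 'level', [])
```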

from torch.optim import Optimizer

import nni
from nni.retiarii.utils import import_
Contributor:

I'd rather maintain another copy of import_ here than import from retiarii.

Importing from another component looks weird and makes this one less self-contained.

Contributor Author:

Added a copy.
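A minimal version of the import_ helper under discussion, the kind of local copy one might maintain instead of importing from retiarii, could look like this. This is an illustrative re-implementation, not the exact NNI code.

```python
# Sketch of a dotted-path import helper: given 'pkg.module.Name', import
# pkg.module and return its Name attribute.

import importlib


def import_(target: str):
    """Import an attribute (class, function, ...) from a dotted path string."""
    module_path, _, attr_name = target.rpartition('.')
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)
```

For example, `import_('math.sqrt')` returns the `sqrt` function from the `math` module.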

_test(_model)
_scheduler.step()

class AutoCompressModule(AbstractAutoCompressModule):
Contributor:

Where is this module used?

Contributor Author:

This module is implemented by the user and will be imported by import_ in AutoCompressEngine.trial_execute_compress().

It is strange to fix the code file name as auto_compress_module.py; I will modify this.

Contributor:

do users have to use the name "AutoCompressModule"?

Contributor Author:

Refactored; there is no longer a need to fix the name AutoCompressModule.

Contributor:

Please add docstrings for the member functions.

@J-shang marked this pull request as ready for review May 18, 2021 06:47
The main differences are as follows:

* Use a generator to help generate search space object.
* Need to implement the abstract class ``AbstractAutoCompressModule`` as ``AutoCompressModule``.
Contributor:

what is the meaning of this line?

Contributor Author:

modified to be more readable.

'bnn': BNNQuantizer
}

Implement ``AbstractAutoCompressModule``
Contributor:

-> Provide user model

Implement ``AbstractAutoCompressModule``
----------------------------------------

This class will be called by ``AutoCompressEngine`` on training service.
Contributor:

do not mention this at the beginning

Contributor Author:

removed


Similar to launch from python, the difference is no need to set ``trial_command``.
By default, ``auto_compress_module_file_name`` is set as ``./auto_compress_module.py``.
Remember that ``auto_compress_module_file_name`` is the relative file path under ``trial_code_directory``.
Contributor:

it has to be relative path, why?

Contributor Author:

This is no longer needed.

experiment.config.training_service.use_active_gpu = True

# the relative file path under trial_code_directory, which contains the class AutoCompressModule
experiment.config.auto_compress_module_file_name = './auto_compress_module.py'
Contributor:

better to put this config as AutoCompressExperiment's input argument

Contributor Author:

Refactored it.

@@ -3,7 +3,7 @@ Automatic Model Pruning using NNI Tuners

It's convenient to implement auto model pruning with NNI compression and NNI tuners

First, model compression with NNI
First, model pruning with NNI
Contributor:

can directly remove this file

Contributor Author:

removed

@@ -6,4 +6,5 @@ Advanced Usage

Framework <./Framework>
Customize a new algorithm <./CustomizeCompressor>
Automatic Model Compression <./AutoPruningUsingTuners>
Automatic Model Pruning <./AutoPruningUsingTuners>
Contributor:

remove it

Contributor Author:

removed

return _model

@classmethod
def optimizer(cls) -> torch.optim.Optimizer:
Contributor:

It seems the optimizer is not mentioned in the doc? Do users need to implement this function?

Contributor Author:

Rewrote the doc to mention the optimizer and the other interfaces.

'_type': 'choice',
'_value': compressor_choice_value
}
}
Contributor:

It is better to merge pruner_choice_list and quantizer_choice_list together; then the nesting depth is reduced by one, which is simpler and friendlier to tuning algorithms.

Contributor Author:

Merged these.
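The reviewer's suggestion can be sketched as follows: merge the pruner and quantizer algorithm lists into a single NNI choice so the nesting depth drops by one. The structure follows the {'_type': 'choice', '_value': ...} convention shown above, but the helper name and the example algorithm entries are hypothetical.

```python
# Hypothetical sketch: one flat 'choice' over all algorithms, pruners and
# quantizers alike, instead of a compressor-type choice nested above two
# separate choice lists.

def build_search_space(pruner_choice_list, quantizer_choice_list):
    merged = pruner_choice_list + quantizer_choice_list
    return {
        'algorithm_name': {
            '_type': 'choice',
            '_value': merged,
        }
    }

space = build_search_space(
    [{'_name': 'level', 'sparsity': {'_type': 'uniform', '_value': [0.1, 0.9]}}],
    [{'_name': 'qat', 'quant_bits': {'_type': 'choice', '_value': [8, 16]}}],
)
```

A tuner then samples one algorithm entry (with its hyperparameters) per trial, without first having to sample a compressor type.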

return model

@classmethod
def __compress_quantization_pipeline(cls, algorithm_name: str,
Contributor:

the two functions can be merged together

Contributor Author:

They differ slightly; please see the new update. I think it is clearer to implement them separately.

return model

@classmethod
def _compress_pipeline(cls, compressor_type: str,
Contributor:

Why is it called "pipeline"?

Contributor Author:

Renamed _compress_pipeline() -> _compress().

raise AttributeError(f'{key} is not supposed to be set in AutoCompress mode by users!')
# 'trial_code_directory' is handled differently because the path will be converted to absolute path by us
if key == 'trial_code_directory' and not (value == Path('.') or os.path.isabs(value)):
raise AttributeError(f'{key} is not supposed to be set in AutoCompress mode by users!')
Contributor:

this is not a correct error message...

Contributor:

And why must the value be Path('.') or an absolute path?

Contributor Author:

have removed this.

from pathlib import Path
from nni.algorithms.compression.pytorch.auto_compress import AutoCompressExperiment

experiment = AutoCompressExperiment('local')
Contributor:

exp = AutoCompressExperiment('local', AutoCompressModule)

Contributor Author:

Refactored it.

experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.tuner.name = 'TPE'
experiment.config.tuner.class_args['optimize_mode'] = 'maximize'
experiment.config.training_service.use_active_gpu = True
Contributor:

If I am not mistaken, this feature is for users who want to try different model compression algorithms without much effort. I think some of them would be confused by the experiment config settings if they are not familiar with NNI. Maybe we should explain what these experiment parameters are, or refer to the related NNI doc that introduces the parameters in detail.

Contributor Author:

Good suggestion; I am trying to refactor and reuse the original config to reduce effort.

Contributor Author:

Refactored; now we can use experiment = AutoCompressExperiment(AutoCompressModule, 'local'), so there is no need for a specific config.
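Putting the snippets quoted in this PR together, the refactored launch flow might look like the following config sketch. It uses the AutoCompressionXXX naming finally adopted in this PR; treat the exact import path and constructor signature as assumptions rather than the final API.

```python
# Hedged config sketch, not the authoritative API: pass the user's
# AutoCompressModule directly to the experiment instead of setting a
# module-file-path config field.

from pathlib import Path
from nni.algorithms.compression.pytorch.auto_compress import AutoCompressionExperiment

from auto_compress_module import AutoCompressModule  # user-implemented module

experiment = AutoCompressionExperiment(AutoCompressModule, 'local')
experiment.config.trial_code_directory = Path(__file__).parent
experiment.config.tuner.name = 'TPE'
experiment.config.tuner.class_args['optimize_mode'] = 'maximize'
experiment.config.training_service.use_active_gpu = True
```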


This approach is mainly a combination of compression and nni experiments.
It allows users to define compressor search space, including types, parameters, etc.
Its using experience is similar to launch the NNI experiment from python.
Contributor (@linbinskn, May 25, 2021):

Just a personal concern, maybe not correct. I think this doc mainly focuses on how to use this new feature, but it does not clearly tell users what the feature can actually help them do, so users may be confused or misunderstand it.
In my opinion, this feature helps users try different compression algorithms, including pruning and quantization algorithms, by adding them to our 'search space'. With it, users can easily choose different compression algorithms, apply them to a model, and get feedback easily and automatically. But if I were a brand-new user, after reading this doc I would not get this key point and would still miss some important information, such as:

  • What do 'search space', 'types', and 'parameters' mean?
  • What is the meaning of 'auto compress'?
  • What is the meaning of 'combination of compression and nni experiments'?
  • If I want to try different compression algorithms, will they be applied together (pruning and quantization applied to the model simultaneously) or one at a time, sequentially?

Contributor Author:

Good suggestion; I will add more description and explanation of what auto compression is and how it can help.

Contributor Author:

Rewrote the doc; more comments are welcome.

The main differences are as follows:

* Use a generator to help generate search space object.
* Need to provide the model to be compressed, and the model should have already pre-trained.
Contributor:

have already been pre-trained

Contributor Author:

Fixed it.

@QuanluZhang (Contributor):

Depends on #3507

  • AutoCompressExperiment / AutoCompressionExperiment ?
  • AutoCompressExperiment(auto_compress_module, config, training_service) / experiment.config.module_file_path = '...' ?

AutoCompressionExperiment is better

@J-shang changed the title from "[Model Compression] auto compress" to "[Model Compression] auto compression" May 27, 2021
@J-shang (Contributor Author) commented May 27, 2021

Depends on #3507

  • AutoCompressExperiment / AutoCompressionExperiment ?
  • AutoCompressExperiment(auto_compress_module, config, training_service) / experiment.config.module_file_path = '...' ?

AutoCompressionExperiment is better

Renamed AutoCompressXXX to AutoCompressionXXX.

@J-shang requested a review from QuanluZhang May 27, 2021 12:42
Generate search space
---------------------

Due to the extensive use of nested search space, we recommend a using generator to configure search space.
Contributor:

recommend using a?
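As a sketch of what the recommended generator could look like, here is a hypothetical builder that accumulates per-algorithm choices and emits the nested search-space dict, sparing users from hand-writing the nesting. The class and method names are illustrative, not the actual NNI API.

```python
# Hypothetical search-space generator sketch. Each added algorithm becomes one
# entry in a top-level NNI 'choice'; params map hyperparameter name to an NNI
# sampling spec such as {'_type': 'uniform', '_value': [0.1, 0.9]}.

class SearchSpaceGenerator:
    def __init__(self):
        self._algorithms = []

    def add_algorithm(self, name, **params):
        self._algorithms.append({'_name': name, **params})
        return self  # allow chaining

    def dumps(self):
        return {'algorithm_name': {'_type': 'choice', '_value': self._algorithms}}


space = (SearchSpaceGenerator()
         .add_algorithm('level', sparsity={'_type': 'uniform', '_value': [0.1, 0.9]})
         .add_algorithm('qat', quant_bits={'_type': 'choice', '_value': [8, 16]})
         .dumps())
```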

@ultmaster merged commit a8879dd into microsoft:master May 28, 2021
Development

Successfully merging this pull request may close these issues.

Multi and Auto Compressor
4 participants