diff --git a/examples/model_compress/pruning/basic_pruners_torch.py b/examples/model_compress/pruning/basic_pruners_torch.py index 1da4c6994f..551c0257f3 100644 --- a/examples/model_compress/pruning/basic_pruners_torch.py +++ b/examples/model_compress/pruning/basic_pruners_torch.py @@ -243,6 +243,7 @@ def trainer(model, optimizer, criterion, epoch): # Reproduced result in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS', # Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A' + # If you want to skip some layers, you can use 'exclude' as follows. if args.pruner == 'slim': config_list = [{ 'sparsity': args.sparsity, @@ -252,7 +253,10 @@ def trainer(model, optimizer, criterion, epoch): config_list = [{ 'sparsity': args.sparsity, 'op_types': ['Conv2d'], - 'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37'] + 'op_names': ['feature.0', 'feature.10', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37'] + }, { + 'exclude': True, + 'op_names': ['feature.10'] }] pruner = pruner_cls(model, config_list, **kw_args) diff --git a/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py b/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py index 207a8aa2f9..9b082ab504 100644 --- a/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/auto_compress_pruner.py @@ -11,7 +11,7 @@ from nni.compression.pytorch import ModelSpeedup from nni.compression.pytorch.compressor import Pruner -from nni.compression.pytorch.utils.config_validation import CompressorSchema +from nni.compression.pytorch.utils.config_validation import PrunerSchema from .simulated_annealing_pruner import SimulatedAnnealingPruner from .iterative_pruner import ADMMPruner @@ -130,16 +130,18 @@ def validate_config(self, model, config_list): """ if self._base_algo == 'level': - schema = CompressorSchema([{ - 'sparsity': And(float, 
lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), Optional('op_types'): [str], Optional('op_names'): [str], + Optional('exclude'): bool }], model, _logger) elif self._base_algo in ['l1', 'l2', 'fpgm']: - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), 'op_types': ['Conv2d'], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, _logger) schema.validate(config_list) diff --git a/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py index c0ca053a7d..c6d14f8cbf 100644 --- a/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/dependency_aware_pruner.py @@ -2,10 +2,10 @@ # Licensed under the MIT license. import logging -from schema import And, Optional, SchemaError +from schema import And, Optional from nni.common.graph_utils import TorchModuleGraph from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency -from nni.compression.pytorch.utils.config_validation import CompressorSchema +from nni.compression.pytorch.utils.config_validation import PrunerSchema from nni.compression.pytorch.compressor import Pruner from .constants import MASKER_DICT @@ -82,7 +82,7 @@ def update_mask(self): self._dependency_update_mask() def validate_config(self, model, config_list): - schema = CompressorSchema([{ + schema = PrunerSchema([{ Optional('sparsity'): And(float, lambda n: 0 < n < 1), Optional('op_types'): ['Conv2d'], Optional('op_names'): [str], @@ -90,9 +90,6 @@ def validate_config(self, model, config_list): }], model, logger) schema.validate(config_list) - for config in config_list: - if 'exclude' not in config and 'sparsity' not in config: - raise SchemaError('Either sparisty or 
exclude must be specified!') def _supported_dependency_aware(self): raise NotImplementedError diff --git a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py index 82940ad737..b95f11ea6a 100644 --- a/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/iterative_pruner.py @@ -5,7 +5,7 @@ import copy import torch from schema import And, Optional -from nni.compression.pytorch.utils.config_validation import CompressorSchema +from nni.compression.pytorch.utils.config_validation import PrunerSchema from .constants import MASKER_DICT from .dependency_aware_pruner import DependencyAwarePruner @@ -138,10 +138,11 @@ def validate_config(self, model, config_list): config_list : list List on pruning configs """ - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 <= n <= 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 <= n <= 1), Optional('op_types'): [str], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, logger) schema.validate(config_list) @@ -300,16 +301,18 @@ def validate_config(self, model, config_list): """ if self._base_algo == 'level': - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), Optional('op_types'): [str], Optional('op_names'): [str], + Optional('exclude'): bool }], model, logger) elif self._base_algo in ['l1', 'l2', 'fpgm']: - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), 'op_types': ['Conv2d'], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, logger) schema.validate(config_list) @@ -436,10 +439,11 @@ def __init__(self, model, config_list, optimizer, trainer, 
criterion, sparsifyin self.patch_optimizer_before(self._callback) def validate_config(self, model, config_list): - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), 'op_types': ['BatchNorm2d'], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, logger) schema.validate(config_list) diff --git a/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py b/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py index caa1c831e6..0e09fae904 100644 --- a/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py +++ b/nni/algorithms/compression/pytorch/pruning/lottery_ticket.py @@ -5,7 +5,7 @@ import logging import torch from schema import And, Optional -from nni.compression.pytorch.utils.config_validation import CompressorSchema +from nni.compression.pytorch.utils.config_validation import PrunerSchema from nni.compression.pytorch.compressor import Pruner from .finegrained_pruning_masker import LevelPrunerMasker @@ -56,11 +56,12 @@ def validate_config(self, model, config_list): - prune_iterations : The number of rounds for the iterative pruning. - sparsity : The final sparsity when the compression is done. 
""" - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), 'prune_iterations': And(int, lambda n: n > 0), Optional('op_types'): [str], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, logger) schema.validate(config_list) diff --git a/nni/algorithms/compression/pytorch/pruning/net_adapt_pruner.py b/nni/algorithms/compression/pytorch/pruning/net_adapt_pruner.py index 08416319ea..4087bb0c57 100644 --- a/nni/algorithms/compression/pytorch/pruning/net_adapt_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/net_adapt_pruner.py @@ -11,7 +11,7 @@ from nni.utils import OptimizeMode from nni.compression.pytorch.compressor import Pruner -from nni.compression.pytorch.utils.config_validation import CompressorSchema +from nni.compression.pytorch.utils.config_validation import PrunerSchema from nni.compression.pytorch.utils.num_param_counter import get_total_num_weights from .constants_pruner import PRUNER_DICT @@ -120,16 +120,18 @@ def validate_config(self, model, config_list): """ if self._base_algo == 'level': - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), Optional('op_types'): [str], Optional('op_names'): [str], + Optional('exclude'): bool }], model, _logger) elif self._base_algo in ['l1', 'l2', 'fpgm']: - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), 'op_types': ['Conv2d'], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, _logger) schema.validate(config_list) diff --git a/nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py b/nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py index c17a5ddafa..39b2201aa2 100644 --- 
a/nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py @@ -4,7 +4,7 @@ import logging from schema import And, Optional -from nni.compression.pytorch.utils.config_validation import CompressorSchema +from nni.compression.pytorch.utils.config_validation import PrunerSchema from .dependency_aware_pruner import DependencyAwarePruner __all__ = ['LevelPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner'] @@ -48,10 +48,11 @@ def validate_config(self, model, config_list): config_list : list List on pruning configs """ - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), Optional('op_types'): [str], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, logger) schema.validate(config_list) diff --git a/nni/algorithms/compression/pytorch/pruning/sensitivity_pruner.py b/nni/algorithms/compression/pytorch/pruning/sensitivity_pruner.py index ea0a725004..a8193a4649 100644 --- a/nni/algorithms/compression/pytorch/pruning/sensitivity_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/sensitivity_pruner.py @@ -9,7 +9,7 @@ from schema import And, Optional from nni.compression.pytorch.compressor import Pruner -from nni.compression.pytorch.utils.config_validation import CompressorSchema +from nni.compression.pytorch.utils.config_validation import PrunerSchema from nni.compression.pytorch.utils.sensitivity_analysis import SensitivityAnalysis from .constants_pruner import PRUNER_DICT @@ -146,16 +146,18 @@ def validate_config(self, model, config_list): """ if self.base_algo == 'level': - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), Optional('op_types'): [str], Optional('op_names'): [str], + Optional('exclude'): bool }], model, _logger) 
elif self.base_algo in ['l1', 'l2', 'fpgm']: - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), 'op_types': ['Conv2d'], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, _logger) schema.validate(config_list) diff --git a/nni/algorithms/compression/pytorch/pruning/simulated_annealing_pruner.py b/nni/algorithms/compression/pytorch/pruning/simulated_annealing_pruner.py index b371b2c6fb..b501e34aef 100644 --- a/nni/algorithms/compression/pytorch/pruning/simulated_annealing_pruner.py +++ b/nni/algorithms/compression/pytorch/pruning/simulated_annealing_pruner.py @@ -13,7 +13,7 @@ from nni.utils import OptimizeMode from nni.compression.pytorch.compressor import Pruner -from nni.compression.pytorch.utils.config_validation import CompressorSchema +from nni.compression.pytorch.utils.config_validation import PrunerSchema from .constants_pruner import PRUNER_DICT @@ -115,16 +115,18 @@ def validate_config(self, model, config_list): """ if self._base_algo == 'level': - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), Optional('op_types'): [str], Optional('op_names'): [str], + Optional('exclude'): bool }], model, _logger) elif self._base_algo in ['l1', 'l2', 'fpgm']: - schema = CompressorSchema([{ - 'sparsity': And(float, lambda n: 0 < n < 1), + schema = PrunerSchema([{ + Optional('sparsity'): And(float, lambda n: 0 < n < 1), 'op_types': ['Conv2d'], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, _logger) schema.validate(config_list) diff --git a/nni/algorithms/compression/pytorch/quantization/quantizers.py b/nni/algorithms/compression/pytorch/quantization/quantizers.py index 54fc2df295..dca1ef778e 100644 --- 
a/nni/algorithms/compression/pytorch/quantization/quantizers.py +++ b/nni/algorithms/compression/pytorch/quantization/quantizers.py @@ -5,7 +5,7 @@ import copy import torch from schema import Schema, And, Or, Optional -from nni.compression.pytorch.utils.config_validation import CompressorSchema +from nni.compression.pytorch.utils.config_validation import QuantizerSchema from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType __all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer'] @@ -22,11 +22,12 @@ def __init__(self, model, config_list, optimizer=None): self.layer_scale = {} def validate_config(self, model, config_list): - schema = CompressorSchema([{ + schema = QuantizerSchema([{ Optional('quant_types'): ['weight'], Optional('quant_bits'): Or(8, {'weight': 8}), Optional('op_types'): [str], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, logger) schema.validate(config_list) @@ -183,7 +184,7 @@ def validate_config(self, model, config_list): config_list : list of dict List of configurations """ - schema = CompressorSchema([{ + schema = QuantizerSchema([{ Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]), Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({ Optional('weight'): And(int, lambda n: 0 < n < 32), @@ -191,7 +192,8 @@ def validate_config(self, model, config_list): })), Optional('quant_start_step'): And(int, lambda n: n >= 0), Optional('op_types'): [str], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, logger) schema.validate(config_list) @@ -386,13 +388,14 @@ def validate_config(self, model, config_list): config_list : list of dict List of configurations """ - schema = CompressorSchema([{ + schema = QuantizerSchema([{ Optional('quant_types'): Schema([lambda x: x in ['weight']]), Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), 
Schema({ Optional('weight'): And(int, lambda n: 0 < n < 32) })), Optional('op_types'): [str], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, logger) schema.validate(config_list) @@ -493,14 +496,15 @@ def validate_config(self, model, config_list): config_list : list of dict List of configurations """ - schema = CompressorSchema([{ + schema = QuantizerSchema([{ Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]), Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({ Optional('weight'): And(int, lambda n: 0 < n < 32), Optional('output'): And(int, lambda n: 0 < n < 32), })), Optional('op_types'): [str], - Optional('op_names'): [str] + Optional('op_names'): [str], + Optional('exclude'): bool }], model, logger) schema.validate(config_list) diff --git a/nni/compression/pytorch/utils/config_validation.py b/nni/compression/pytorch/utils/config_validation.py index 3b9f93f962..930e4e686e 100644 --- a/nni/compression/pytorch/utils/config_validation.py +++ b/nni/compression/pytorch/utils/config_validation.py @@ -51,3 +51,25 @@ def _modify_schema(self, data_schema, model, logger): def validate(self, data): self.compressor_schema.validate(data) + +def validate_exclude_sparsity(data): + if not ('exclude' in data or 'sparsity' in data): + raise SchemaError('Either sparsity or exclude must be specified.') + return True + +def validate_exclude_quant_types_quant_bits(data): + if not ('exclude' in data or ('quant_types' in data and 'quant_bits' in data)): + raise SchemaError('Either (quant_types and quant_bits) or exclude must be specified.') + return True + +class PrunerSchema(CompressorSchema): + def _modify_schema(self, data_schema, model, logger): + data_schema = super()._modify_schema(data_schema, model, logger) + data_schema[0] = And(data_schema[0], lambda d: validate_exclude_sparsity(d)) + return data_schema + +class QuantizerSchema(CompressorSchema): + def _modify_schema(self, data_schema, 
model, logger): + data_schema = super()._modify_schema(data_schema, model, logger) + data_schema[0] = And(data_schema[0], lambda d: validate_exclude_quant_types_quant_bits(d)) + return data_schema