Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

add exclude config validate in compressor #3815

Merged
merged 11 commits into from
Jul 12, 2021
6 changes: 5 additions & 1 deletion examples/model_compress/pruning/basic_pruners_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def trainer(model, optimizer, criterion, epoch):

# Reproduced result in paper 'PRUNING FILTERS FOR EFFICIENT CONVNETS',
# Conv_1, Conv_8, Conv_9, Conv_10, Conv_11, Conv_12 are pruned with 50% sparsity, as 'VGG-16-pruned-A'
# If you want to skip some layers, you can use 'exclude' as follows.
if args.pruner == 'slim':
config_list = [{
'sparsity': args.sparsity,
Expand All @@ -252,7 +253,10 @@ def trainer(model, optimizer, criterion, epoch):
config_list = [{
'sparsity': args.sparsity,
'op_types': ['Conv2d'],
'op_names': ['feature.0', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
'op_names': ['feature.0', 'feature.10', 'feature.24', 'feature.27', 'feature.30', 'feature.34', 'feature.37']
}, {
'exclude': True,
'op_names': ['feature.10']
}]

pruner = pruner_cls(model, config_list, **kw_args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from nni.compression.pytorch import ModelSpeedup

from nni.compression.pytorch.compressor import Pruner
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from .simulated_annealing_pruner import SimulatedAnnealingPruner
from .iterative_pruner import ADMMPruner

Expand Down Expand Up @@ -130,16 +130,18 @@ def validate_config(self, model, config_list):
"""

if self._base_algo == 'level':
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
elif self._base_algo in ['l1', 'l2', 'fpgm']:
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)

schema.validate(config_list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
# Licensed under the MIT license.

import logging
from schema import And, Optional, SchemaError
from schema import And, Optional
from nni.common.graph_utils import TorchModuleGraph
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency, GroupDependency
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from nni.compression.pytorch.compressor import Pruner
from .constants import MASKER_DICT

Expand Down Expand Up @@ -82,17 +82,14 @@ def update_mask(self):
self._dependency_update_mask()

def validate_config(self, model, config_list):
schema = CompressorSchema([{
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): ['Conv2d'],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
for config in config_list:
if 'exclude' not in config and 'sparsity' not in config:
raise SchemaError('Either sparisty or exclude must be specified!')
QuanluZhang marked this conversation as resolved.
Show resolved Hide resolved

def _supported_dependency_aware(self):
raise NotImplementedError
Expand Down
28 changes: 16 additions & 12 deletions nni/algorithms/compression/pytorch/pruning/iterative_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
import torch
from schema import And, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from .constants import MASKER_DICT
from .dependency_aware_pruner import DependencyAwarePruner

Expand Down Expand Up @@ -138,10 +138,11 @@ def validate_config(self, model, config_list):
config_list : list
List on pruning configs
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 <= n <= 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 <= n <= 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
Expand Down Expand Up @@ -300,16 +301,18 @@ def validate_config(self, model, config_list):
"""

if self._base_algo == 'level':
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)
elif self._base_algo in ['l1', 'l2', 'fpgm']:
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
Expand Down Expand Up @@ -436,10 +439,11 @@ def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifyin
self.patch_optimizer_before(self._callback)

def validate_config(self, model, config_list):
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['BatchNorm2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
Expand Down
9 changes: 5 additions & 4 deletions nni/algorithms/compression/pytorch/pruning/lottery_ticket.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import logging
import torch
from schema import And, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from nni.compression.pytorch.compressor import Pruner
from .finegrained_pruning_masker import LevelPrunerMasker

Expand Down Expand Up @@ -56,11 +56,12 @@ def validate_config(self, model, config_list):
- prune_iterations : The number of rounds for the iterative pruning.
- sparsity : The final sparsity when the compression is done.
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'prune_iterations': And(int, lambda n: n > 0),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
Expand Down
14 changes: 8 additions & 6 deletions nni/algorithms/compression/pytorch/pruning/net_adapt_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from nni.utils import OptimizeMode

from nni.compression.pytorch.compressor import Pruner
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from nni.compression.pytorch.utils.num_param_counter import get_total_num_weights
from .constants_pruner import PRUNER_DICT

Expand Down Expand Up @@ -120,16 +120,18 @@ def validate_config(self, model, config_list):
"""

if self._base_algo == 'level':
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
elif self._base_algo in ['l1', 'l2', 'fpgm']:
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)

schema.validate(config_list)
Expand Down
9 changes: 5 additions & 4 deletions nni/algorithms/compression/pytorch/pruning/one_shot_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from schema import And, Optional

from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from .dependency_aware_pruner import DependencyAwarePruner

__all__ = ['LevelPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']
Expand Down Expand Up @@ -48,10 +48,11 @@ def validate_config(self, model, config_list):
config_list : list
List on pruning configs
"""
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
Expand Down
14 changes: 8 additions & 6 deletions nni/algorithms/compression/pytorch/pruning/sensitivity_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from schema import And, Optional
from nni.compression.pytorch.compressor import Pruner
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from nni.compression.pytorch.utils.sensitivity_analysis import SensitivityAnalysis

from .constants_pruner import PRUNER_DICT
Expand Down Expand Up @@ -146,16 +146,18 @@ def validate_config(self, model, config_list):
"""

if self.base_algo == 'level':
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
elif self.base_algo in ['l1', 'l2', 'fpgm']:
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)

schema.validate(config_list)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from nni.utils import OptimizeMode

from nni.compression.pytorch.compressor import Pruner
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import PrunerSchema
from .constants_pruner import PRUNER_DICT


Expand Down Expand Up @@ -115,16 +115,18 @@ def validate_config(self, model, config_list):
"""

if self._base_algo == 'level':
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
Optional('op_types'): [str],
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)
elif self._base_algo in ['l1', 'l2', 'fpgm']:
schema = CompressorSchema([{
'sparsity': And(float, lambda n: 0 < n < 1),
schema = PrunerSchema([{
Optional('sparsity'): And(float, lambda n: 0 < n < 1),
'op_types': ['Conv2d'],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, _logger)

schema.validate(config_list)
Expand Down
22 changes: 13 additions & 9 deletions nni/algorithms/compression/pytorch/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import copy
import torch
from schema import Schema, And, Or, Optional
from nni.compression.pytorch.utils.config_validation import CompressorSchema
from nni.compression.pytorch.utils.config_validation import QuantizerSchema
from nni.compression.pytorch.compressor import Quantizer, QuantForward, QuantGrad, QuantType

__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer', 'LsqQuantizer']
Expand All @@ -22,11 +22,12 @@ def __init__(self, model, config_list, optimizer=None):
self.layer_scale = {}

def validate_config(self, model, config_list):
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): ['weight'],
Optional('quant_bits'): Or(8, {'weight': 8}),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
Expand Down Expand Up @@ -183,15 +184,16 @@ def validate_config(self, model, config_list):
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32),
Optional('output'): And(int, lambda n: 0 < n < 32),
})),
Optional('quant_start_step'): And(int, lambda n: n >= 0),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
Expand Down Expand Up @@ -386,13 +388,14 @@ def validate_config(self, model, config_list):
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32)
})),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
Expand Down Expand Up @@ -493,14 +496,15 @@ def validate_config(self, model, config_list):
config_list : list of dict
List of configurations
"""
schema = CompressorSchema([{
schema = QuantizerSchema([{
Optional('quant_types'): Schema([lambda x: x in ['weight', 'output']]),
Optional('quant_bits'): Or(And(int, lambda n: 0 < n < 32), Schema({
Optional('weight'): And(int, lambda n: 0 < n < 32),
Optional('output'): And(int, lambda n: 0 < n < 32),
})),
Optional('op_types'): [str],
Optional('op_names'): [str]
Optional('op_names'): [str],
Optional('exclude'): bool
}], model, logger)

schema.validate(config_list)
Expand Down
22 changes: 22 additions & 0 deletions nni/compression/pytorch/utils/config_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,25 @@ def _modify_schema(self, data_schema, model, logger):

def validate(self, data):
self.compressor_schema.validate(data)

def validate_exclude_sparsity(data):
    """Validate a single pruner config entry: it must carry either 'sparsity'
    or 'exclude' (an excluded layer needs no sparsity target).

    Parameters
    ----------
    data : dict
        One entry of the pruner config list.

    Returns
    -------
    bool
        True when the entry is valid (required by ``schema.And``).

    Raises
    ------
    SchemaError
        If neither 'sparsity' nor 'exclude' is present.
    """
    if not ('exclude' in data or 'sparsity' in data):
        # Fixed typo in the user-facing message: 'sparisty' -> 'sparsity'.
        raise SchemaError('Either sparsity or exclude must be specified.')
    return True

def validate_exclude_quant_types_quant_bits(data):
    """Validate a single quantizer config entry: it must either be excluded,
    or specify both 'quant_types' and 'quant_bits'.

    Returns True when the entry is valid; raises SchemaError otherwise
    (the truthy return is what ``schema.And`` expects from a validator).
    """
    has_quant_settings = 'quant_types' in data and 'quant_bits' in data
    if 'exclude' not in data and not has_quant_settings:
        raise SchemaError('Either (quant_types and quant_bits) or exclude must be specified.')
    return True

class PrunerSchema(CompressorSchema):
    """Schema for pruner config lists.

    Extends CompressorSchema so that every entry must additionally contain
    either 'sparsity' or 'exclude' (checked by validate_exclude_sparsity).
    """

    def _modify_schema(self, data_schema, model, logger):
        # Let the base class do its adjustments first, then chain the
        # sparsity/exclude cross-field check onto the first entry schema.
        modified = super()._modify_schema(data_schema, model, logger)
        modified[0] = And(modified[0], validate_exclude_sparsity)
        return modified

class QuantizerSchema(CompressorSchema):
    """Schema for quantizer config lists.

    Extends CompressorSchema so that every entry must additionally contain
    either both 'quant_types' and 'quant_bits', or 'exclude'
    (checked by validate_exclude_quant_types_quant_bits).
    """

    def _modify_schema(self, data_schema, model, logger):
        # Base-class adjustments first, then chain the quantizer-specific
        # cross-field check onto the first entry schema.
        modified = super()._modify_schema(data_schema, model, logger)
        modified[0] = And(modified[0], validate_exclude_quant_types_quant_bits)
        return modified