Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue add super group config #1708

Merged
merged 2 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,7 @@ def _set_op_type_configs(self, op_configs: OpTypeType):
:param op_configs: Dictionary containing configurations for ops of certain types
"""

@classmethod
def _build_supergroup_patterns(cls, supergroup_config: SupergroupType, callback: SupergroupConfigCallback,
def _build_supergroup_patterns(self, supergroup_config: SupergroupType, callback: SupergroupConfigCallback,
onnx_conn_graph_type_mapper: OnnxConnectedGraphTypeMapper) \
-> List[PatternType]:
"""
Expand All @@ -411,9 +410,17 @@ def _build_supergroup_patterns(cls, supergroup_config: SupergroupType, callback:
"""
op_list = supergroup_config[ConfigDictKeys.OP_LIST]
list_of_permutations = _build_list_of_permutations(op_list, onnx_conn_graph_type_mapper)
return self._build_list_of_pattern(list_of_permutations, callback)

@staticmethod
def _build_list_of_pattern(list_of_op_names: List[List[str]], callback: SupergroupConfigCallback) -> \
List[PatternType]:
"""
Builds list of patterns given a list of op names
"""
list_of_patterns = []
for permutation in list_of_permutations:
list_of_patterns.append(PatternType(pattern=permutation, action=callback))
for op_names in list_of_op_names:
list_of_patterns.append(PatternType(pattern=op_names, action=callback))
return list_of_patterns

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@

from onnx import onnx_pb
from aimet_common.defs import QuantizationDataType
from aimet_common.quantsim_config.json_config_importer import ConfigDictKeys, ConfigType, OpType, ParamType, OpTypeType
from aimet_common.graph_searcher import GraphSearcher
from aimet_common.quantsim_config.json_config_importer import ConfigDictKeys, ConfigType, OpType, ParamType, OpTypeType, \
SupergroupType
from aimet_common.quantsim_config.quantsim_config import QuantSimConfigurator as AimetCommonQuantSimConfigurator, \
get_setting_type
get_setting_type, SupergroupConfigCallback as AimetCommonSupergroupConfigCallback
from aimet_common.utils import AimetLogger
from aimet_onnx.meta.connectedgraph import ConnectedGraph
from aimet_onnx.utils import get_product_name_from_quantized_name
Expand All @@ -63,9 +65,27 @@ def __init__(self):
self.parameter_quantizers = []


class SupergroupConfigCallback(AimetCommonSupergroupConfigCallback):
""" Class acting as a callback for when supergroups are found """

def __init__(self, model: onnx_pb.ModelProto, op_to_quantizers: Dict):
super().__init__()
self._model = model
self._op_to_quantizers = op_to_quantizers

def __call__(self, _, op_list: List[str]):
# Turn off output quantizaters for all ops except for the last op
# Assumes op list is at least of length two
for op in op_list[:-1]:
output_quantizers = self._op_to_quantizers[op.dotted_name].output_quantizers
for output_quantizer in output_quantizers:
output_quantizer.enabled = False


class QuantSimConfigurator(AimetCommonQuantSimConfigurator):
""" Class for parsing and applying
quantsim configurations from json config file """

def __init__(self, model: onnx_pb.ModelProto, conn_graph: ConnectedGraph, config_file: str, quantsim_output_bw: int,
quantsim_param_bw: int, quantsim_data_type: QuantizationDataType = QuantizationDataType.int):
super().__init__(config_file, quantsim_data_type, quantsim_output_bw, quantsim_param_bw)
Expand Down Expand Up @@ -194,19 +214,32 @@ def _set_op_type_configs(self, op_configs: OpTypeType):
op_config = op_configs[op.type]
self._set_config_for_op(op_name, op_to_quantizer, op_config, modified_quantize_ops)

def _set_supergroup_configs(self, supergroups_configs):
def _set_supergroup_configs(self, supergroups_configs: List[SupergroupType]):
"""
Set supergroup specific configurations (fourth level of specificity in configuration file)
:param supergroups_configs: Configurations for supergroups
"""
patterns_with_callbacks = []
for supergroup_config in supergroups_configs:
callback = SupergroupConfigCallback(self._model, self._op_to_quantizers)
op_list = supergroup_config[ConfigDictKeys.OP_LIST]

# Op list consists of patterns to be searched for, we pass a list of op_list to be compatible with build_list
patterns = self._build_list_of_pattern([op_list], callback)
for pattern in patterns:
patterns_with_callbacks.append(pattern)

if patterns_with_callbacks:
graph_searcher = GraphSearcher(self._conn_graph, patterns_with_callbacks)
graph_searcher.find_all_patterns_in_graph_apply_actions()

def _set_model_input_configs(self, model_input_configs):
def _set_model_input_configs(self, model_input_configs: ConfigType):
"""
Set model input specific configurations (fifth level of specificity in configuration file)
:param model_input_configs: Configuration for model inputs
"""

def _set_model_output_configs(self, model_output_configs):
def _set_model_output_configs(self, model_output_configs: ConfigType):
"""
Set model output specific configurations (sixth level of specificity in configuration file)
:param model_output_configs: Configuration for model outputs
Expand Down Expand Up @@ -343,15 +376,15 @@ def _set_strict_symmetric(self, use_strict_symmetric: bool):
Set strict symmetric configuration for all quantizers in the model.
:param use_strict_symmetric: True or False setting for using strict symmetric mode
"""
for _, quantizer in self._quant_ops_dict.items():
for quantizer in self._quant_ops_dict.values():
quantizer.use_strict_symmetric = use_strict_symmetric

def _set_unsigned_symmetric(self, use_unsigned_symmetric: bool):
"""
Set unsigned symmetric configuration for all quantizers in the model.
:param use_unsigned_symmetric: True or False setting for using unsigned symmetric mode
"""
for _, quantizer in self._quant_ops_dict.items():
for quantizer in self._quant_ops_dict.values():
quantizer.use_unsigned_symmetric = use_unsigned_symmetric

def _generate_and_apply_op_instance_specific_config(self):
Expand Down
50 changes: 47 additions & 3 deletions TrainingExtensions/onnx/test/python/test_quantsim_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_qs_config_dummy_model(self):
assert sim.qc_quantize_op_dict['fc_w'].enabled == True
assert sim.qc_quantize_op_dict['fc_b'].enabled == False
assert sim.qc_quantize_op_dict['input'].enabled == False
assert sim.qc_quantize_op_dict['3'].enabled == True
assert sim.qc_quantize_op_dict['3'].enabled == False
assert sim.qc_quantize_op_dict['4'].enabled == True
assert sim.qc_quantize_op_dict['5'].enabled == True
assert sim.qc_quantize_op_dict['6'].enabled == True
Expand Down Expand Up @@ -160,7 +160,6 @@ def test_op_level_config(self):
"model_input": {},
"model_output": {}
}

if not os.path.exists('./data'):
os.makedirs('./data')
with open('./data/quantsim_config.json', 'w') as f:
Expand All @@ -172,6 +171,51 @@ def test_op_level_config(self):
assert sim.qc_quantize_op_dict['input'].enabled == True
assert sim.qc_quantize_op_dict['input'].use_symmetric_encodings == False

def test_parse_config_file_supergroups(self):
""" Test that supergroup quantization parameters are set correctly when using json config file """
model = test_models.build_dummy_model()

quantsim_config = {
"defaults": {
"ops": {
"is_output_quantized": "True",
"is_symmetric": "False"
},
"params": {
"is_quantized": "False",
"is_symmetric": "False"
}
},
"params": {},
"op_type": {},
"supergroups": [
{
"op_list": ["Conv", "Relu"]
},
{
"op_list": ["Relu", "MaxPool"]
},
],
"model_input": {},
"model_output": {}
}

if not os.path.exists('./data'):
os.makedirs('./data')
with open('./data/quantsim_config.json', 'w') as f:
json.dump(quantsim_config, f)
sim = QuantizationSimModel(model, config_file='./data/quantsim_config.json')

# 3 in conv output, 4 is relu output (even though it was not touched with Conv, relu pattern, it was disabled for
# relu maxpool pattern
for name in ['3', '4',]:
assert sim.qc_quantize_op_dict[name].enabled == False

assert sim.qc_quantize_op_dict['5'].enabled == True

if os.path.exists('./data/quantsim_config.json'):
os.remove('./data/quantsim_config.json')

def test_parse_config_file_symmetric_modes(self):
""" Test that model output quantization parameters are set correctly when using json config file """
model = test_models.build_dummy_model()
Expand Down Expand Up @@ -200,6 +244,6 @@ def test_parse_config_file_symmetric_modes(self):
json.dump(quantsim_config, f)
sim = QuantizationSimModel(model, config_file='./data/quantsim_config.json')

for _, quantizer in sim.qc_quantize_op_dict.items():
for quantizer in sim.qc_quantize_op_dict.values():
assert quantizer.use_strict_symmetric == True
assert quantizer.use_unsigned_symmetric == False