Skip to content

Commit

Permalink
Issue add super group config (#1708)
Browse files Browse the repository at this point in the history
* Add supergroup config

Signed-off-by: Harshita Mangal <quic_mangal@quicinc.com>

* Improve readability

Signed-off-by: Harshita Mangal <quic_mangal@quicinc.com>

Signed-off-by: Harshita Mangal <quic_mangal@quicinc.com>
  • Loading branch information
quic-mangal committed Nov 14, 2022
1 parent f93c120 commit a422782
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -400,8 +400,7 @@ def _set_op_type_configs(self, op_configs: OpTypeType):
:param op_configs: Dictionary containing configurations for ops of certain types
"""

@classmethod
def _build_supergroup_patterns(cls, supergroup_config: SupergroupType, callback: SupergroupConfigCallback,
def _build_supergroup_patterns(self, supergroup_config: SupergroupType, callback: SupergroupConfigCallback,
onnx_conn_graph_type_mapper: OnnxConnectedGraphTypeMapper) \
-> List[PatternType]:
"""
Expand All @@ -411,9 +410,17 @@ def _build_supergroup_patterns(cls, supergroup_config: SupergroupType, callback:
"""
op_list = supergroup_config[ConfigDictKeys.OP_LIST]
list_of_permutations = _build_list_of_permutations(op_list, onnx_conn_graph_type_mapper)
return self._build_list_of_pattern(list_of_permutations, callback)

@staticmethod
def _build_list_of_pattern(list_of_op_names: List[List[str]], callback: SupergroupConfigCallback) -> \
List[PatternType]:
"""
Builds list of patterns given a list of op names
"""
list_of_patterns = []
for permutation in list_of_permutations:
list_of_patterns.append(PatternType(pattern=permutation, action=callback))
for op_names in list_of_op_names:
list_of_patterns.append(PatternType(pattern=op_names, action=callback))
return list_of_patterns

@abstractmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,11 @@

from onnx import onnx_pb
from aimet_common.defs import QuantizationDataType
from aimet_common.quantsim_config.json_config_importer import ConfigDictKeys, ConfigType, OpType, ParamType, OpTypeType
from aimet_common.graph_searcher import GraphSearcher
from aimet_common.quantsim_config.json_config_importer import ConfigDictKeys, ConfigType, OpType, ParamType, OpTypeType, \
SupergroupType
from aimet_common.quantsim_config.quantsim_config import QuantSimConfigurator as AimetCommonQuantSimConfigurator, \
get_setting_type
get_setting_type, SupergroupConfigCallback as AimetCommonSupergroupConfigCallback
from aimet_common.utils import AimetLogger
from aimet_onnx.meta.connectedgraph import ConnectedGraph
from aimet_onnx.utils import get_product_name_from_quantized_name
Expand All @@ -63,9 +65,27 @@ def __init__(self):
self.parameter_quantizers = []


class SupergroupConfigCallback(AimetCommonSupergroupConfigCallback):
""" Class acting as a callback for when supergroups are found """

def __init__(self, model: onnx_pb.ModelProto, op_to_quantizers: Dict):
super().__init__()
self._model = model
self._op_to_quantizers = op_to_quantizers

def __call__(self, _, op_list: List[str]):
# Turn off output quantizaters for all ops except for the last op
# Assumes op list is at least of length two
for op in op_list[:-1]:
output_quantizers = self._op_to_quantizers[op.dotted_name].output_quantizers
for output_quantizer in output_quantizers:
output_quantizer.enabled = False


class QuantSimConfigurator(AimetCommonQuantSimConfigurator):
""" Class for parsing and applying
quantsim configurations from json config file """

def __init__(self, model: onnx_pb.ModelProto, conn_graph: ConnectedGraph, config_file: str, quantsim_output_bw: int,
quantsim_param_bw: int, quantsim_data_type: QuantizationDataType = QuantizationDataType.int):
super().__init__(config_file, quantsim_data_type, quantsim_output_bw, quantsim_param_bw)
Expand Down Expand Up @@ -194,19 +214,32 @@ def _set_op_type_configs(self, op_configs: OpTypeType):
op_config = op_configs[op.type]
self._set_config_for_op(op_name, op_to_quantizer, op_config, modified_quantize_ops)

def _set_supergroup_configs(self, supergroups_configs):
def _set_supergroup_configs(self, supergroups_configs: List[SupergroupType]):
"""
Set supergroup specific configurations (fourth level of specificity in configuration file)
:param supergroups_configs: Configurations for supergroups
"""
patterns_with_callbacks = []
for supergroup_config in supergroups_configs:
callback = SupergroupConfigCallback(self._model, self._op_to_quantizers)
op_list = supergroup_config[ConfigDictKeys.OP_LIST]

# Op list consists of patterns to be searched for, we pass a list of op_list to be compatible with build_list
patterns = self._build_list_of_pattern([op_list], callback)
for pattern in patterns:
patterns_with_callbacks.append(pattern)

if patterns_with_callbacks:
graph_searcher = GraphSearcher(self._conn_graph, patterns_with_callbacks)
graph_searcher.find_all_patterns_in_graph_apply_actions()

def _set_model_input_configs(self, model_input_configs):
def _set_model_input_configs(self, model_input_configs: ConfigType):
"""
Set model input specific configurations (fifth level of specificity in configuration file)
:param model_input_configs: Configuration for model inputs
"""

def _set_model_output_configs(self, model_output_configs):
def _set_model_output_configs(self, model_output_configs: ConfigType):
"""
Set model output specific configurations (sixth level of specificity in configuration file)
:param model_output_configs: Configuration for model outputs
Expand Down Expand Up @@ -343,15 +376,15 @@ def _set_strict_symmetric(self, use_strict_symmetric: bool):
Set strict symmetric configuration for all quantizers in the model.
:param use_strict_symmetric: True or False setting for using strict symmetric mode
"""
for _, quantizer in self._quant_ops_dict.items():
for quantizer in self._quant_ops_dict.values():
quantizer.use_strict_symmetric = use_strict_symmetric

def _set_unsigned_symmetric(self, use_unsigned_symmetric: bool):
"""
Set unsigned symmetric configuration for all quantizers in the model.
:param use_unsigned_symmetric: True or False setting for using unsigned symmetric mode
"""
for _, quantizer in self._quant_ops_dict.items():
for quantizer in self._quant_ops_dict.values():
quantizer.use_unsigned_symmetric = use_unsigned_symmetric

def _generate_and_apply_op_instance_specific_config(self):
Expand Down
50 changes: 47 additions & 3 deletions TrainingExtensions/onnx/test/python/test_quantsim_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def test_qs_config_dummy_model(self):
assert sim.qc_quantize_op_dict['fc_w'].enabled == True
assert sim.qc_quantize_op_dict['fc_b'].enabled == False
assert sim.qc_quantize_op_dict['input'].enabled == False
assert sim.qc_quantize_op_dict['3'].enabled == True
assert sim.qc_quantize_op_dict['3'].enabled == False
assert sim.qc_quantize_op_dict['4'].enabled == True
assert sim.qc_quantize_op_dict['5'].enabled == True
assert sim.qc_quantize_op_dict['6'].enabled == True
Expand Down Expand Up @@ -160,7 +160,6 @@ def test_op_level_config(self):
"model_input": {},
"model_output": {}
}

if not os.path.exists('./data'):
os.makedirs('./data')
with open('./data/quantsim_config.json', 'w') as f:
Expand All @@ -172,6 +171,51 @@ def test_op_level_config(self):
assert sim.qc_quantize_op_dict['input'].enabled == True
assert sim.qc_quantize_op_dict['input'].use_symmetric_encodings == False

def test_parse_config_file_supergroups(self):
""" Test that supergroup quantization parameters are set correctly when using json config file """
model = test_models.build_dummy_model()

quantsim_config = {
"defaults": {
"ops": {
"is_output_quantized": "True",
"is_symmetric": "False"
},
"params": {
"is_quantized": "False",
"is_symmetric": "False"
}
},
"params": {},
"op_type": {},
"supergroups": [
{
"op_list": ["Conv", "Relu"]
},
{
"op_list": ["Relu", "MaxPool"]
},
],
"model_input": {},
"model_output": {}
}

if not os.path.exists('./data'):
os.makedirs('./data')
with open('./data/quantsim_config.json', 'w') as f:
json.dump(quantsim_config, f)
sim = QuantizationSimModel(model, config_file='./data/quantsim_config.json')

# 3 in conv output, 4 is relu output (even though it was not touched with Conv, relu pattern, it was disabled for
# relu maxpool pattern
for name in ['3', '4',]:
assert sim.qc_quantize_op_dict[name].enabled == False

assert sim.qc_quantize_op_dict['5'].enabled == True

if os.path.exists('./data/quantsim_config.json'):
os.remove('./data/quantsim_config.json')

def test_parse_config_file_symmetric_modes(self):
""" Test that model output quantization parameters are set correctly when using json config file """
model = test_models.build_dummy_model()
Expand Down Expand Up @@ -200,6 +244,6 @@ def test_parse_config_file_symmetric_modes(self):
json.dump(quantsim_config, f)
sim = QuantizationSimModel(model, config_file='./data/quantsim_config.json')

for _, quantizer in sim.qc_quantize_op_dict.items():
for quantizer in sim.qc_quantize_op_dict.values():
assert quantizer.use_strict_symmetric == True
assert quantizer.use_unsigned_symmetric == False

0 comments on commit a422782

Please sign in to comment.