Refactor Target Platform Capabilities - Phase 3 (#1297)
* Refactor Target Platform Capabilities - Phase 3
Remove context manager functionality from the Target Platform Model.
Fix all tests and TP models.

---------

Co-authored-by: liord <lior.dikstein@altair-semi.com>
Co-authored-by: Ofir Gordon <ofirg6@gmail.com>
3 people authored Dec 22, 2024
1 parent ef3270b commit b125f96
Showing 45 changed files with 1,115 additions and 1,264 deletions.
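The change that recurs across every file below is the rename of the QuantizationConfigOptions attribute quantization_config_list to quantization_configurations. The following is a minimal, hedged sketch of the rename from a caller's point of view; the classes are simplified stubs, not the real MCT types, and only the two attribute names taken from the diff are authoritative.

# Hedged sketch with stub classes (not the real MCT types) showing the renamed attribute in use.
from dataclasses import dataclass
from typing import List


@dataclass
class WeightAttrConfigStub:
    enable_weights_quantization: bool


@dataclass
class OpQuantizationConfigStub:
    default_weight_attr_config: WeightAttrConfigStub
    activation_n_bits: int


@dataclass
class QuantizationConfigOptionsStub:
    # Renamed in this commit: previously `quantization_config_list`.
    quantization_configurations: List[OpQuantizationConfigStub]


qco = QuantizationConfigOptionsStub(
    quantization_configurations=[
        OpQuantizationConfigStub(WeightAttrConfigStub(True), activation_n_bits=8),
        OpQuantizationConfigStub(WeightAttrConfigStub(True), activation_n_bits=4),
    ]
)

# Call sites iterate the renamed attribute, as in base_graph.py below.
weights_quantized = any(
    qc.default_weight_attr_config.enable_weights_quantization
    for qc in qco.quantization_configurations
)
print(weights_quantized)  # True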
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_graph.py
@@ -105,7 +105,7 @@ def set_tpc(self,
Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
' Please add the custom layer to Target Platform Capabilities (TPC), or file a feature '
'request or an issue if you believe this should be supported.') # pragma: no cover
-if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_config_list]):
+if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_configurations]):
Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.') # pragma: no cover

self.tpc = tpc
6 changes: 3 additions & 3 deletions model_compression_toolkit/core/common/graph/base_node.py
@@ -582,12 +582,12 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
"""
# Filter quantization config options that don't match the graph.
_base_config = node_qc_options.base_config
-_node_qc_options = node_qc_options.quantization_config_list
+_node_qc_options = node_qc_options.quantization_configurations
if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
-for op_cfg in qc_opts.quantization_config_list])
+for op_cfg in qc_opts.quantization_configurations])

# Filter node's QC options that match next nodes input bit-width.
_node_qc_options = [_option for _option in _node_qc_options
@@ -599,7 +599,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
-# a qco from quantization_config_list with maximum activation bit-width.
+# a qco from quantization_configurations with maximum activation bit-width.
if len(_node_qc_options) > 0:
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
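To make the filtering above easier to follow, here is a hedged, simplified sketch using plain integers in place of MCT's config objects; the "keep an option only if its activation bit-width fits all following nodes" comparison is an assumption inferred from the surrounding comments, since the exact condition line is elided in this diff.

# Hedged, simplified sketch of the bit-width filtering in filter_node_qco_by_graph.
node_candidate_activation_bits = [8, 4, 2]   # this node's candidate activation bit-widths
next_nodes_max_input_bits = [[8, 4], [8]]    # per next node: max input bits supported by each of its options

# Mirrors the min over all options of all next nodes of max_input_activation_n_bits(op_cfg).
next_nodes_supported_input_bitwidth = min(
    bits
    for per_node_options in next_nodes_max_input_bits
    for bits in per_node_options
)

# Keep only candidates that the following nodes can accept as input (assumed condition).
filtered_options = [
    bits for bits in node_candidate_activation_bits
    if bits <= next_nodes_supported_input_bitwidth
]
print(next_nodes_supported_input_bitwidth)  # 4
print(filtered_options)                     # [4, 2]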
@@ -101,15 +101,15 @@ def filter_node_qco_by_graph(node: BaseNode,
"""
# Filter quantization config options that don't match the graph.
_base_config = node_qc_options.base_config
-_node_qc_options = node_qc_options.quantization_config_list
+_node_qc_options = node_qc_options.quantization_configurations

# Build next_nodes list by appending to the node's next nodes list all nodes that are quantization preserving.
_next_nodes = graph.get_next_nodes(node)
next_nodes = []
while len(_next_nodes):
n = _next_nodes.pop(0)
qco = n.get_qco(tpc)
-qp = [qc.quantization_preserving for qc in qco.quantization_config_list]
+qp = [qc.quantization_preserving for qc in qco.quantization_configurations]
if not all(qp) and any(qp):
Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizationConfigOptions in {n}.')
if qp[0]:
@@ -120,7 +120,7 @@ def filter_node_qco_by_graph(node: BaseNode,
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
-for op_cfg in qc_opts.quantization_config_list])
+for op_cfg in qc_opts.quantization_configurations])

# Filter node's QC options that match next nodes input bit-width.
_node_qc_options = [_option for _option in _node_qc_options
@@ -132,7 +132,7 @@ def filter_node_qco_by_graph(node: BaseNode,
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
-# a qco from quantization_config_list with maximum activation bit-width.
+# a qco from quantization_configurations with maximum activation bit-width.
if len(_node_qc_options) > 0:
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
@@ -392,7 +392,7 @@ def shift_negative_function(graph: Graph,
bypass_candidate_qc.activation_quantization_cfg.activation_quantization_params[SIGNED] = False
graph.shift_stats_collector(bypass_node, np.array(shift_value))

-add_node_qco = add_node.get_qco(graph.tpc).quantization_config_list
+add_node_qco = add_node.get_qco(graph.tpc).quantization_configurations
for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
for attr in add_node.get_node_weights_attributes():
candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
@@ -535,7 +535,7 @@ def apply_shift_negative_correction(graph: Graph,
# Skip substitution if QuantizationMethod is uniform.
node_qco = n.get_qco(graph.tpc)
if any([op_qc.activation_quantization_method is QuantizationMethod.UNIFORM
-for op_qc in node_qco.quantization_config_list]):
+for op_qc in node_qco.quantization_configurations]):
continue

if snc_node_types.apply(n):
@@ -1,5 +1,6 @@
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema

OperatorSetNames = schema.OperatorSetNames
Signedness = schema.Signedness
AttributeQuantizationConfig = schema.AttributeQuantizationConfig
OpQuantizationConfig = schema.OpQuantizationConfig
@@ -64,10 +64,10 @@ def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuant
Raises:
AssertionError: If the default quantization configuration list contains more than one configuration option.
"""
-assert len(tp_model.default_qco.quantization_config_list) == 1, \
+assert len(tp_model.default_qco.quantization_configurations) == 1, \
f"Default quantization configuration options must contain only one option, " \
f"but found {len(tp_model.default_qco.quantization_config_list)} configurations." # pragma: no cover
return tp_model.default_qco.quantization_config_list[0]
f"but found {len(tp_model.default_qco.quantization_configurations)} configurations." # pragma: no cover
return tp_model.default_qco.quantization_configurations[0]
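For context, a hedged stand-in sketch of the contract this helper enforces after the rename; the stub below replaces the real TargetPlatformModel/QuantizationConfigOptions types and only mirrors the single-configuration assertion shown above.

# Hedged sketch: stub types stand in for TargetPlatformModel and QuantizationConfigOptions.
from dataclasses import dataclass
from typing import Tuple


@dataclass
class DefaultQcoStub:
    quantization_configurations: Tuple[int, ...]  # ints stand in for OpQuantizationConfig objects


def default_config(default_qco: DefaultQcoStub) -> int:
    # Mirrors the assertion above: the default options must hold exactly one configuration.
    assert len(default_qco.quantization_configurations) == 1, (
        f"Default quantization configuration options must contain only one option, "
        f"but found {len(default_qco.quantization_configurations)} configurations."
    )
    return default_qco.quantization_configurations[0]


print(default_config(DefaultQcoStub(quantization_configurations=(8,))))  # 8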


def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool:
Expand All @@ -82,8 +82,7 @@ def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool:
bool: True if an OperatorsSet with the given name exists in the target platform model,
otherwise False.
"""
-return opset_name in [x.name for x in tp_model.operator_set]

+return tp_model.operator_set is not None and opset_name in [x.name for x in tp_model.operator_set]

def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optional[OperatorsSetBase]:
"""