Refactor Target Platform Capabilities - Phase 3 #1297

Open · wants to merge 5 commits into base: main
Changes from 1 commit
2 changes: 1 addition & 1 deletion model_compression_toolkit/core/common/graph/base_graph.py
@@ -105,7 +105,7 @@ def set_tpc(self,
                Logger.critical(f'MCT does not support optimizing Keras custom layers. Found a layer of type {n.type}. '
                                ' Please add the custom layer to Target Platform Capabilities (TPC), or file a feature '
                                'request or an issue if you believe this should be supported.')  # pragma: no cover
-            if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_config_list]):
+            if any([qc.default_weight_attr_config.enable_weights_quantization for qc in n.get_qco(tpc).quantization_configurations]):
                Logger.critical(f'Layer identified: {n.type}. MCT does not support weight quantization for Keras custom layers.')  # pragma: no cover

        self.tpc = tpc
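For reference, the check above rejects Keras custom layers whose TPC options enable weights quantization. A minimal runnable sketch of that predicate, using simplified stand-in dataclasses (these are hypothetical placeholders, not the real MCT config classes):

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class WeightAttrConfig:  # stand-in for the real attribute config
    enable_weights_quantization: bool = False

@dataclass
class OpQuantizationConfig:  # stand-in for the real op config
    default_weight_attr_config: WeightAttrConfig = field(default_factory=WeightAttrConfig)

@dataclass
class QuantizationConfigOptions:
    # Renamed field: was `quantization_config_list` before this commit.
    quantization_configurations: List[OpQuantizationConfig] = field(default_factory=list)

qco = QuantizationConfigOptions(
    quantization_configurations=[OpQuantizationConfig(WeightAttrConfig(True)),
                                 OpQuantizationConfig(WeightAttrConfig(False))])

# Same predicate as in set_tpc: any option that enables weights quantization
# makes the custom layer unsupported.
if any(qc.default_weight_attr_config.enable_weights_quantization
       for qc in qco.quantization_configurations):
    print("weights quantization requested for a custom layer -> reject")
```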
6 changes: 3 additions & 3 deletions model_compression_toolkit/core/common/graph/base_node.py
@@ -582,12 +582,12 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
"""
# Filter quantization config options that don't match the graph.
_base_config = node_qc_options.base_config
_node_qc_options = node_qc_options.quantization_config_list
_node_qc_options = node_qc_options.quantization_configurations
if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])
for op_cfg in qc_opts.quantization_configurations])

# Filter node's QC options that match next nodes input bit-width.
_node_qc_options = [_option for _option in _node_qc_options
Expand All @@ -599,7 +599,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
# a qco from quantization_configurations with maximum activation bit-width.
if len(_node_qc_options) > 0:
output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
_base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
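For intuition, a self-contained sketch of the filtering step above, with plain dataclasses standing in for the real QC options (the names and numbers are illustrative, not the MCT API): an option survives only if its activation bit-width fits the smallest input bit-width supported across every option of every successor.

```python
from dataclasses import dataclass

@dataclass
class OpCfg:
    activation_n_bits: int            # output activation width of this option
    max_input_activation_n_bits: int  # widest input activation it accepts

# Candidate options of the current node.
node_options = [OpCfg(16, 16), OpCfg(8, 8), OpCfg(4, 8)]
# Options of two successor nodes.
next_nodes_options = [[OpCfg(8, 8)], [OpCfg(8, 16), OpCfg(4, 8)]]

# Mirrors the min(...) above: a flat minimum over all successor options is
# the binding constraint on this node's output activation width.
supported = min(cfg.max_input_activation_n_bits
                for opts in next_nodes_options
                for cfg in opts)  # -> 8

filtered = [opt for opt in node_options if opt.activation_n_bits <= supported]
print([opt.activation_n_bits for opt in filtered])  # [8, 4]
```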
(file name cropped in this view)
@@ -101,15 +101,15 @@ def filter_node_qco_by_graph(node: BaseNode,
"""
# Filter quantization config options that don't match the graph.
_base_config = node_qc_options.base_config
_node_qc_options = node_qc_options.quantization_config_list
_node_qc_options = node_qc_options.quantization_configurations

# Build next_nodes list by appending to the node's next nodes list all nodes that are quantization preserving.
_next_nodes = graph.get_next_nodes(node)
next_nodes = []
while len(_next_nodes):
n = _next_nodes.pop(0)
qco = n.get_qco(tpc)
qp = [qc.quantization_preserving for qc in qco.quantization_config_list]
qp = [qc.quantization_preserving for qc in qco.quantization_configurations]
if not all(qp) and any(qp):
Logger.error(f'Attribute "quantization_preserving" should be the same for all QuantizaionConfigOptions in {n}.')
if qp[0]:
@@ -120,7 +120,7 @@ def filter_node_qco_by_graph(node: BaseNode,
        next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
        next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
                                                   for qc_opts in next_nodes_qc_options
-                                                   for op_cfg in qc_opts.quantization_config_list])
+                                                   for op_cfg in qc_opts.quantization_configurations])

        # Filter node's QC options that match next nodes input bit-width.
        _node_qc_options = [_option for _option in _node_qc_options
@@ -132,7 +132,7 @@ def filter_node_qco_by_graph(node: BaseNode,
        if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
                for qc_opt in next_nodes_qc_options]):
            # base_config activation bits doesn't match next node supported input bit-width -> replace with
-            # a qco from quantization_config_list with maximum activation bit-width.
+            # a qco from quantization_configurations with maximum activation bit-width.
            if len(_node_qc_options) > 0:
                output_act_bitwidth = {qco.activation_n_bits: i for i, qco in enumerate(_node_qc_options)}
                _base_config = _node_qc_options[output_act_bitwidth[max(output_act_bitwidth)]]
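A small sketch of the "see through quantization-preserving successors" loop above, with a plain adjacency dict standing in for the Graph API (all names here are placeholders, not MCT types):

```python
# conv -> reshape -> relu; reshape only preserves quantization.
graph = {"conv": ["reshape"], "reshape": ["relu"], "relu": []}
preserving = {"conv": False, "reshape": True, "relu": False}

def effective_next_nodes(start):
    """Collect successors that actually (re)quantize, skipping over
    quantization-preserving nodes the way the while-loop above does."""
    pending = list(graph[start])
    next_nodes = []
    while pending:
        n = pending.pop(0)
        if preserving[n]:
            pending.extend(graph[n])  # look past the preserving node
        else:
            next_nodes.append(n)
    return next_nodes

print(effective_next_nodes("conv"))  # ['relu']
```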
(file name cropped in this view)
@@ -392,7 +392,7 @@ def shift_negative_function(graph: Graph,
    bypass_candidate_qc.activation_quantization_cfg.activation_quantization_params[SIGNED] = False
    graph.shift_stats_collector(bypass_node, np.array(shift_value))

-    add_node_qco = add_node.get_qco(graph.tpc).quantization_config_list
+    add_node_qco = add_node.get_qco(graph.tpc).quantization_configurations
    for op_qc_idx, candidate_qc in enumerate(add_node.candidates_quantization_cfg):
        for attr in add_node.get_node_weights_attributes():
            candidate_qc.weights_quantization_cfg.get_attr_config(attr).enable_weights_quantization = False
@@ -535,7 +535,7 @@ def apply_shift_negative_correction(graph: Graph,
        # Skip substitution if QuantizationMethod is uniform.
        node_qco = n.get_qco(graph.tpc)
        if any([op_qc.activation_quantization_method is QuantizationMethod.UNIFORM
-                for op_qc in node_qco.quantization_config_list]):
+                for op_qc in node_qco.quantization_configurations]):
            continue

        if snc_node_types.apply(n):
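The skip above reads: if any candidate activation config of a node uses uniform quantization, shift-negative correction is not applied to it. A toy illustration with a stand-in enum (not the real QuantizationMethod import):

```python
from enum import Enum

class QuantizationMethod(Enum):  # simplified stand-in
    POWER_OF_TWO = 0
    SYMMETRIC = 1
    UNIFORM = 2

# Activation quantization methods of one node's candidates (made-up data).
node_methods = [QuantizationMethod.POWER_OF_TWO, QuantizationMethod.UNIFORM]

if any(m is QuantizationMethod.UNIFORM for m in node_methods):
    print("uniform candidate found -> skip shift-negative correction")
```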
(file name cropped in this view)
@@ -1,5 +1,6 @@
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema

+OperatorSetNames = schema.OperatorSetNames
Signedness = schema.Signedness
AttributeQuantizationConfig = schema.AttributeQuantizationConfig
OpQuantizationConfig = schema.OpQuantizationConfig
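This file pins public names to schema.v1, and the commit adds OperatorSetNames to the re-exports. A self-contained sketch of the same pattern, with types.SimpleNamespace standing in for the versioned schema module (the values below are invented for illustration):

```python
from types import SimpleNamespace

# Stand-in for model_compression_toolkit.target_platform_capabilities.schema.v1
schema_v1 = SimpleNamespace(
    OperatorSetNames=["CONV", "FULLY_CONNECTED"],
    Signedness=["AUTO", "SIGNED", "UNSIGNED"],
)

# Stable aliases: callers import these names from one module and never touch
# the versioned module directly, so a future schema bump is a one-file change.
OperatorSetNames = schema_v1.OperatorSetNames
Signedness = schema_v1.Signedness

print(OperatorSetNames)  # ['CONV', 'FULLY_CONNECTED']
```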
(file name cropped in this view)
@@ -64,10 +64,10 @@ def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuantizationConfig:
    Raises:
        AssertionError: If the default quantization configuration list contains more than one configuration option.
    """
-    assert len(tp_model.default_qco.quantization_config_list) == 1, \
+    assert len(tp_model.default_qco.quantization_configurations) == 1, \
        f"Default quantization configuration options must contain only one option, " \
-        f"but found {len(tp_model.default_qco.quantization_config_list)} configurations."  # pragma: no cover
-    return tp_model.default_qco.quantization_config_list[0]
+        f"but found {len(tp_model.default_qco.quantization_configurations)} configurations."  # pragma: no cover
+    return tp_model.default_qco.quantization_configurations[0]


def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool:
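A runnable sketch of the single-option invariant that get_default_op_quantization_config enforces, again with simplified placeholder types rather than the real TargetPlatformModel:

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class QCO:  # stand-in for QuantizationConfigOptions
    quantization_configurations: List[str] = field(default_factory=list)

@dataclass
class TPModel:  # stand-in for TargetPlatformModel
    default_qco: QCO = field(default_factory=QCO)

tp_model = TPModel(QCO(["default_8bit_config"]))

# Same assertion shape as above: the default options must hold exactly one config.
assert len(tp_model.default_qco.quantization_configurations) == 1, \
    f"Default quantization configuration options must contain only one option, " \
    f"but found {len(tp_model.default_qco.quantization_configurations)} configurations."
print(tp_model.default_qco.quantization_configurations[0])  # default_8bit_config
```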
@@ -82,7 +82,10 @@ def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool:
        bool: True if an OperatorsSet with the given name exists in the target platform model,
            otherwise False.
    """
-    return opset_name in [x.name for x in tp_model.operator_set]
+    if tp_model.operator_set is None:
+        return False
+    else:
+        return opset_name in [x.name for x in tp_model.operator_set]

Collaborator review comment on this change (suggested alternative):
    return tp_model.operator_set is not None and opset_name in [x.name for x in tp_model.operator_set]


def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optional[OperatorsSetBase]:
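The behavioral effect of the new None guard, shown with a namedtuple standing in for OperatorsSet (illustrative only, and with tp_model trimmed away for brevity): an unset operator_set now yields False instead of raising a TypeError on iteration.

```python
from collections import namedtuple

OperatorsSet = namedtuple("OperatorsSet", ["name"])

def is_opset_in_model(operator_set, opset_name):
    # Mirrors the guarded version in the diff above.
    if operator_set is None:
        return False
    return opset_name in [x.name for x in operator_set]

print(is_opset_in_model(None, "Conv2d"))                     # False
print(is_opset_in_model([OperatorsSet("Conv2d")], "Conv2d")) # True
```

The collaborator's suggested one-liner is behaviorally equivalent, since `and` short-circuits to False when operator_set is None.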