From 951568ca5b52dfa9120340eb19ce059d550f9a74 Mon Sep 17 00:00:00 2001 From: liord Date: Tue, 10 Dec 2024 11:00:43 +0200 Subject: [PATCH 1/8] **Refactor Target Platform Capabilities - Phase 2** - Convert all schema classes to immutable dataclasses, replacing existing methods with equivalent dataclass methods (e.g., `replace`). - Ensure all schema classes are strictly immutable to enhance reliability and maintain consistency. - Update target platform model versions to align with the new class structure. - Refactor tests to support and validate the updated class types and functionality. --- .../target_platform_capabilities/schema/v1.py | 889 +++++++++--------- .../tpc_models/imx500_tpc/v1/tp_model.py | 11 +- .../tpc_models/imx500_tpc/v1_lut/tp_model.py | 6 +- .../tpc_models/imx500_tpc/v1_pot/tp_model.py | 6 +- .../tpc_models/imx500_tpc/v2/tp_model.py | 11 +- .../tpc_models/imx500_tpc/v2_lut/tp_model.py | 6 +- .../tpc_models/imx500_tpc/v3/tp_model.py | 11 +- .../tpc_models/imx500_tpc/v3_lut/tp_model.py | 6 +- .../tpc_models/imx500_tpc/v4/tp_model.py | 18 +- .../tpc_models/tflite_tpc/v1/tp_model.py | 4 +- tests/common_tests/test_tp_model.py | 14 +- .../tflite_int8/imx500_int8_tp_model.py | 6 +- .../feature_networks/activation_16bit_test.py | 10 +- .../feature_networks/manual_bit_selection.py | 12 +- .../function_tests/test_layer_fusing.py | 4 +- .../test_quant_config_filtering.py | 6 +- .../non_parallel_tests/test_keras_tp_model.py | 2 +- .../function_tests/layer_fusing_test.py | 2 +- .../function_tests/test_pytorch_tp_model.py | 2 +- .../test_quant_config_filtering.py | 6 +- .../feature_models/activation_16bit_test.py | 13 +- .../feature_models/manual_bit_selection.py | 14 +- 22 files changed, 525 insertions(+), 534 deletions(-) diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py index 019454642..16934a3ad 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py @@ -12,23 +12,72 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -import copy - +from dataclasses import replace, dataclass, asdict, field from enum import Enum - -import pprint - from typing import Dict, Any, Union, Tuple, List, Optional - from mct_quantizers import QuantizationMethod from model_compression_toolkit.constants import FLOAT_BITWIDTH - from model_compression_toolkit.logger import Logger from model_compression_toolkit.target_platform_capabilities.constants import OPS_SET_LIST -from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import \ - get_current_tp_model, _current_tp_model -from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import clone_and_edit_object_params + _current_tp_model + +class OperatorSetNames(Enum): + OPSET_NO_QUANTIZATION = "NoQuantization" + OPSET_QUANTIZATION_PRESERVING = "QuantizationPreserving" + OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS = "DimensionManipulationOpsWithWeights" + OPSET_DIMENSION_MANIPULATION_OPS = "DimensionManipulationOps" + OPSET_MERGE_OPS = "MergeOps" + OPSET_CONV = "Conv" + OPSET_DEPTHWISE_CONV = "DepthwiseConv2D" + OPSET_CONV_TRANSPOSE = "ConvTraspose" + OPSET_FULLY_CONNECTED = "FullyConnected" + OPSET_CONCATENATE = "Concatenate" + OPSET_STACK = "Stack" + OPSET_UNSTACK = "Unstack" + OPSET_GATHER = "Gather" + OPSET_EXPAND = "Expend" + OPSET_BATCH_NORM = "BatchNorm" + OPSET_ANY_RELU = "AnyReLU" + OPSET_ADD = "Add" + OPSET_SUB = "Sub" + OPSET_MUL = "Mul" + OPSET_DIV = "Div" + OPSET_MIN_MAX = "MinMax" + OPSET_PRELU = "PReLU" + OPSET_SWISH = "Swish" + OPSET_SIGMOID = "Sigmoid" + OPSET_TANH = "Tanh" + OPSET_GELU = "Gelu" + OPSET_HARDSIGMOID = "HardSigmoid" + OPSET_HARDSWISH = "HardSwish" + OPSET_FLATTEN = "Flatten" + OPSET_GET_ITEM = "GetItem" + OPSET_RESHAPE = "Reshape" + OPSET_UNSQUEEZE = "Unsqueeze" + OPSET_SQUEEZE = "Squeeze" + OPSET_PERMUTE = "Permute" + OPSET_TRANSPOSE = "Transpose" + OPSET_DROPOUT = "Dropout" + OPSET_SPLIT = "Split" + OPSET_CHUNK = "Chunk" + OPSET_UNBIND = "Unbind" + OPSET_MAXPOOL = "MaxPool" + OPSET_SIZE = "Size" + OPSET_SHAPE = "Shape" + OPSET_EQUAL = "Equal" + OPSET_ARGMAX = "ArgMax" + OPSET_TOPK = "TopK" + OPSET_FAKE_QUANT_WITH_MIN_MAX_VARS = "FakeQuantWithMinMaxVars" + OPSET_COMBINED_NON_MAX_SUPPRESSION = "CombinedNonMaxSuppression" + OPSET_CROPPING2D = "Cropping2D" + OPSET_ZERO_PADDING2d = "ZeroPadding2D" + OPSET_CAST = "Cast" + OPSET_STRIDED_SLICE = "StridedSlice" + + @classmethod + def get_values(cls): + return [v.value for v in cls] class Signedness(Enum): @@ -44,451 +93,431 @@ class Signedness(Enum): UNSIGNED = 2 +@dataclass(frozen=True) class AttributeQuantizationConfig: """ - Hold the quantization configuration of a weight attribute of a layer. - """ - def __init__(self, - weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO, - weights_n_bits: int = FLOAT_BITWIDTH, - weights_per_channel_threshold: bool = False, - enable_weights_quantization: bool = False, - lut_values_bitwidth: Union[int, None] = None, # If None - set 8 in hptq, o.w use it - ): - """ - Initializes an attribute quantization config. + Holds the quantization configuration of a weight attribute of a layer. - Args: - weights_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for weights quantization. - weights_n_bits (int): Number of bits to quantize the coefficients. 
- weights_per_channel_threshold (bool): Whether to quantize the weights per-channel or not (per-tensor). - enable_weights_quantization (bool): Whether to quantize the model weights or not. - lut_values_bitwidth (int): Number of bits to use when quantizing in look-up-table. - - """ - - self.weights_quantization_method = weights_quantization_method - self.weights_n_bits = weights_n_bits - self.weights_per_channel_threshold = weights_per_channel_threshold - self.enable_weights_quantization = enable_weights_quantization - self.lut_values_bitwidth = lut_values_bitwidth + Attributes: + weights_quantization_method (QuantizationMethod): The method to use from QuantizationMethod for weights quantization. + weights_n_bits (int): Number of bits to quantize the coefficients. + weights_per_channel_threshold (bool): Indicates whether to quantize the weights per-channel or per-tensor. + enable_weights_quantization (bool): Indicates whether to quantize the model weights or not. + lut_values_bitwidth (Union[int, None]): Number of bits to use when quantizing in a look-up table. + If None, defaults to 8 in hptq; otherwise, it uses the provided value. + """ + weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO + weights_n_bits: int = FLOAT_BITWIDTH + weights_per_channel_threshold: bool = False + enable_weights_quantization: bool = False + lut_values_bitwidth: Union[int, None] = None - def clone_and_edit(self, **kwargs): + def __post_init__(self): """ - Clone the quantization config and edit some of its attributes. + Post-initialization processing for input validation. - Args: - **kwargs: Keyword arguments to edit the configuration to clone. - - Returns: - Edited quantization configuration. + Raises: + Logger critical if attributes are of incorrect type or have invalid values. """ + if not isinstance(self.weights_n_bits, int) or self.weights_n_bits < 1: + Logger.critical("weights_n_bits must be a positive integer.") + if not isinstance(self.enable_weights_quantization, bool): + Logger.critical("enable_weights_quantization must be a boolean.") + if self.lut_values_bitwidth is not None and not isinstance(self.lut_values_bitwidth, int): + Logger.critical("lut_values_bitwidth must be an integer or None.") - return clone_and_edit_object_params(self, **kwargs) - - def __eq__(self, other): + def clone_and_edit(self, **kwargs) -> 'AttributeQuantizationConfig': """ - Is this configuration equal to another object. + Clone the current AttributeQuantizationConfig and edit some of its attributes. Args: - other: Object to compare. + **kwargs: Keyword arguments representing the attributes to edit in the cloned instance. Returns: - - Whether this configuration is equal to another object or not. + AttributeQuantizationConfig: A new instance of AttributeQuantizationConfig with updated attributes. """ - if not isinstance(other, AttributeQuantizationConfig): - return False # pragma: no cover - return self.weights_quantization_method == other.weights_quantization_method and \ - self.weights_n_bits == other.weights_n_bits and \ - self.weights_per_channel_threshold == other.weights_per_channel_threshold and \ - self.enable_weights_quantization == other.enable_weights_quantization and \ - self.lut_values_bitwidth == other.lut_values_bitwidth + return replace(self, **kwargs) +@dataclass(frozen=True) class OpQuantizationConfig: """ OpQuantizationConfig is a class to configure the quantization parameters of an operator. 
-    """
-
-    def __init__(self,
-                 default_weight_attr_config: AttributeQuantizationConfig,
-                 attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig],
-                 activation_quantization_method: QuantizationMethod,
-                 activation_n_bits: int,
-                 supported_input_activation_n_bits: Union[int, Tuple[int]],
-                 enable_activation_quantization: bool,
-                 quantization_preserving: bool,
-                 fixed_scale: float,
-                 fixed_zero_point: int,
-                 simd_size: int,
-                 signedness: Signedness
-                 ):
-        """
-
-        Args:
-            default_weight_attr_config (AttributeQuantizationConfig): A default attribute quantization configuration for the operation.
-            attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration.
-            activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization.
-            activation_n_bits (int): Number of bits to quantize the activations.
-            supported_input_activation_n_bits (int or Tuple[int]): Number of bits that operator accepts as input.
-            enable_activation_quantization (bool): Whether to quantize the model activations or not.
-            quantization_preserving (bool): Whether quantization parameters should be the same for an operator's input and output.
-            fixed_scale (float): Scale to use for an operator quantization parameters.
-            fixed_zero_point (int): Zero-point to use for an operator quantization parameters.
-            simd_size (int): Per op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction.
-            signedness (bool): Set activation quantization signedness.
-
-        """
-        self.default_weight_attr_config = default_weight_attr_config
-        self.attr_weights_configs_mapping = attr_weights_configs_mapping
+    Attributes:
+        default_weight_attr_config (AttributeQuantizationConfig): A default attribute quantization configuration for the operation.
+        attr_weights_configs_mapping (Dict[str, AttributeQuantizationConfig]): A mapping between an op attribute name and its quantization configuration.
+        activation_quantization_method (QuantizationMethod): Which method to use from QuantizationMethod for activation quantization.
+        activation_n_bits (int): Number of bits to quantize the activations.
+        supported_input_activation_n_bits (int or Tuple[int]): Number of bits that the operator accepts as input.
+        enable_activation_quantization (bool): Whether to quantize the model activations or not.
+        quantization_preserving (bool): Whether quantization parameters should be the same for an operator's input and output.
+        fixed_scale (float): Scale to use for an operator's quantization parameters.
+        fixed_zero_point (int): Zero-point to use for an operator's quantization parameters.
+        simd_size (int): Per-op integer representing the Single Instruction, Multiple Data (SIMD) width of an operator. It indicates the number of data elements that can be fetched and processed simultaneously in a single instruction.
+        signedness (Signedness): Set activation quantization signedness.
- self.activation_quantization_method = activation_quantization_method - self.activation_n_bits = activation_n_bits - if isinstance(supported_input_activation_n_bits, tuple): - self.supported_input_activation_n_bits = supported_input_activation_n_bits - elif isinstance(supported_input_activation_n_bits, int): - self.supported_input_activation_n_bits = (supported_input_activation_n_bits,) - else: - Logger.critical(f"Supported_input_activation_n_bits only accepts int or tuple of ints, but got {type(supported_input_activation_n_bits)}") # pragma: no cover - self.enable_activation_quantization = enable_activation_quantization - self.quantization_preserving = quantization_preserving - self.fixed_scale = fixed_scale - self.fixed_zero_point = fixed_zero_point - self.signedness = signedness - self.simd_size = simd_size + """ + default_weight_attr_config: AttributeQuantizationConfig + attr_weights_configs_mapping: Dict[str, AttributeQuantizationConfig] + activation_quantization_method: QuantizationMethod + activation_n_bits: int + supported_input_activation_n_bits: Union[int, Tuple[int]] + enable_activation_quantization: bool + quantization_preserving: bool + fixed_scale: float + fixed_zero_point: int + simd_size: int + signedness: Signedness + + def __post_init__(self) -> None: + """ + Post-initialization processing for input validation. + + Raises: + Logger critical if supported_input_activation_n_bits is not an int or a tuple of ints. + """ + if isinstance(self.supported_input_activation_n_bits, int): + object.__setattr__(self, 'supported_input_activation_n_bits', (self.supported_input_activation_n_bits,)) + elif not isinstance(self.supported_input_activation_n_bits, tuple): + Logger.critical( + f"Supported_input_activation_n_bits only accepts int or tuple of ints, but got {type(self.supported_input_activation_n_bits)}") # pragma: no cover - def get_info(self): + def get_info(self) -> Dict[str, Any]: """ + Get information about the quantization configuration. - Returns: Info about the quantization configuration as a dictionary. - + Returns: + dict: Information about the quantization configuration as a dictionary. """ - return self.__dict__ # pragma: no cover + return asdict(self) # pragma: no cover - def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs): + def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs) -> 'OpQuantizationConfig': """ Clone the quantization config and edit some of its attributes. + Args: - attr_to_edit: A mapping between attributes names to edit and their parameters that - should be edited to a new value. + attr_to_edit (Dict[str, Dict[str, Any]]): A mapping between attribute names to edit and their parameters that + should be edited to a new value. **kwargs: Keyword arguments to edit the configuration to clone. Returns: - Edited quantization configuration. + OpQuantizationConfig: Edited quantization configuration. 
""" - qc = clone_and_edit_object_params(self, **kwargs) - - # optionally: editing specific parameters in the config of specified attributes - edited_attrs = copy.deepcopy(qc.attr_weights_configs_mapping) - for attr_name, attr_cfg in qc.attr_weights_configs_mapping.items(): - if attr_name in attr_to_edit: - edited_attrs[attr_name] = attr_cfg.clone_and_edit(**attr_to_edit[attr_name]) - - qc.attr_weights_configs_mapping = edited_attrs + # Clone and update top-level attributes + updated_config = replace(self, **kwargs) - return qc + # Clone and update nested immutable dataclasses in `attr_weights_configs_mapping` + updated_attr_mapping = { + attr_name: (attr_cfg.clone_and_edit(**attr_to_edit[attr_name]) + if attr_name in attr_to_edit else attr_cfg) + for attr_name, attr_cfg in updated_config.attr_weights_configs_mapping.items() + } - def __eq__(self, other): - """ - Is this configuration equal to another object. - Args: - other: Object to compare. - - Returns: - Whether this configuration is equal to another object or not. - """ - if not isinstance(other, OpQuantizationConfig): - return False # pragma: no cover - return self.default_weight_attr_config == other.default_weight_attr_config and \ - self.attr_weights_configs_mapping == other.attr_weights_configs_mapping and \ - self.activation_quantization_method == other.activation_quantization_method and \ - self.activation_n_bits == other.activation_n_bits and \ - self.supported_input_activation_n_bits == other.supported_input_activation_n_bits and \ - self.enable_activation_quantization == other.enable_activation_quantization and \ - self.signedness == other.signedness and \ - self.simd_size == other.simd_size + # Return a new instance with the updated attribute mapping + return replace(updated_config, attr_weights_configs_mapping=updated_attr_mapping) @property def max_input_activation_n_bits(self) -> int: """ - Get maximum supported input bit-width. - - Returns: Maximum supported input bit-width. + Get the maximum supported input bit-width. + Returns: + int: Maximum supported input bit-width. """ return max(self.supported_input_activation_n_bits) +@dataclass(frozen=True) class QuantizationConfigOptions: """ + QuantizationConfigOptions wraps a set of quantization configurations to consider during the quantization of an operator. - Wrap a set of quantization configurations to consider during the quantization - of an operator. - + Attributes: + quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather. + base_config (Union[OpQuantizationConfig, None]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner. """ - def __init__(self, - quantization_config_list: List[OpQuantizationConfig], - base_config: OpQuantizationConfig = None): - """ + quantization_config_list: List[OpQuantizationConfig] + base_config: Union[OpQuantizationConfig, None] = None - Args: - quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather. - base_config (OpQuantizationConfig): Fallback OpQuantizationConfig to use when optimizing the model in a non mixed-precision manner. - """ - - assert isinstance(quantization_config_list, - list), f"'QuantizationConfigOptions' options list must be a list, but received: {type(quantization_config_list)}." - for cfg in quantization_config_list: - assert isinstance(cfg, OpQuantizationConfig),\ - f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}." 
- self.quantization_config_list = quantization_config_list - if len(quantization_config_list) > 1: - assert base_config is not None, \ - f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization." - assert any([base_config is cfg for cfg in quantization_config_list]), \ - f"'base_config' must be included in the quantization config options list." - # Enforce base_config to be a reference to an instance in quantization_config_list. - self.base_config = base_config - elif len(quantization_config_list) == 1: - assert base_config is None or base_config == quantization_config_list[0], "'base_config' should be included in 'quantization_config_list'" - # Set base_config to be a reference to the first instance in quantization_config_list. - self.base_config = quantization_config_list[0] - else: - raise AssertionError("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.") - - def __eq__(self, other): - """ - Is this QCOptions equal to another object. - Args: - other: Object to compare. + def __post_init__(self) -> None: + """ + Post-initialization processing for input validation. - Returns: - Whether this QCOptions equal to another object or not. + Raises: + Logger critical if quantization_config_list is not a list, contains invalid elements, or if base_config is not set correctly. """ + # Validate `quantization_config_list` + if not isinstance(self.quantization_config_list, list): + Logger.critical( + f"'quantization_config_list' must be a list, but received: {type(self.quantization_config_list)}.") + for cfg in self.quantization_config_list: + if not isinstance(cfg, OpQuantizationConfig): + Logger.critical( + f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}.") - if not isinstance(other, QuantizationConfigOptions): - return False - if len(self.quantization_config_list) != len(other.quantization_config_list): - return False - for qc, other_qc in zip(self.quantization_config_list, other.quantization_config_list): - if qc != other_qc: - return False - return True + # Handle base_config + if len(self.quantization_config_list) > 1: + if self.base_config is None: + Logger.critical(f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization.") + if not any(self.base_config == cfg for cfg in self.quantization_config_list): + Logger.critical(f"'base_config' must be included in the quantization config options list.") + elif len(self.quantization_config_list) == 1: + if self.base_config is None: + object.__setattr__(self, 'base_config', self.quantization_config_list[0]) + elif self.base_config != self.quantization_config_list[0]: + Logger.critical( + "'base_config' should be the same as the sole item in 'quantization_config_list'.") - def clone_and_edit(self, **kwargs): - qc_options = copy.deepcopy(self) - for qc in qc_options.quantization_config_list: - self.__edit_quantization_configuration(qc, kwargs) - return qc_options + elif len(self.quantization_config_list) == 0: + Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.") - def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs): + def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions': """ - Clones the quantization configurations and edits some of their attributes' parameters. + Clone the quantization configuration options and edit attributes in each configuration. 
Args: - attrs: attributes names to clone their configurations. If None is provided, updating the configurations - of all attributes in the operation attributes config mapping. - **kwargs: Keyword arguments to edit in the attributes configuration. + **kwargs: Keyword arguments to edit in each configuration. Returns: - QuantizationConfigOptions with edited attributes configurations. + A new instance of QuantizationConfigOptions with updated configurations. + """ + updated_base_config = replace(self.base_config, **kwargs) + updated_configs_list = [ + replace(cfg, **kwargs) for cfg in self.quantization_config_list + ] + return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs_list) + def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) -> 'QuantizationConfigOptions': """ + Clones the quantization configurations and edits some of their attributes' parameters. - qc_options = copy.deepcopy(self) + Args: + attrs (List[str]): Attributes names to clone and edit their configurations. If None, updates all attributes. + **kwargs: Keyword arguments to edit in the attributes configuration. - for qc in qc_options.quantization_config_list: + Returns: + QuantizationConfigOptions: A new instance of QuantizationConfigOptions with edited attributes configurations. + """ + updated_base_config = self.base_config + updated_configs = [] + for qc in self.quantization_config_list: if attrs is None: attrs_to_update = list(qc.attr_weights_configs_mapping.keys()) else: - if not isinstance(attrs, List): # pragma: no cover - Logger.critical(f"Expected a list of attributes but received {type(attrs)}.") attrs_to_update = attrs - + # Ensure all attributes exist in the config for attr in attrs_to_update: - if qc.attr_weights_configs_mapping.get(attr) is None: # pragma: no cover - Logger.critical(f'Editing attributes is only possible for existing attributes in the configuration\'s ' - f'weights config mapping; {attr} does not exist in {qc}.') - self.__edit_quantization_configuration(qc.attr_weights_configs_mapping[attr], kwargs) - return qc_options + if attr not in qc.attr_weights_configs_mapping: + Logger.critical(f"{attr} does not exist in {qc}.") + updated_attr_mapping = { + attr: qc.attr_weights_configs_mapping[attr].clone_and_edit(**kwargs) + for attr in attrs_to_update + } + if qc == updated_base_config: + updated_base_config = replace(updated_base_config, attr_weights_configs_mapping=updated_attr_mapping) + updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping)) + return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs) - def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Union[Dict[str, str], None]): + def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Union[Dict[str, str], None]) -> 'QuantizationConfigOptions': """ - Clones the quantization configuration options and edits the keys in each configuration attributes config mapping, - based on the given attributes names mapping. + Clones the quantization configurations and updates keys in attribute config mappings. Args: - layer_attrs_mapping: A mapping between attributes names. + layer_attrs_mapping (Union[Dict[str, str], None]): A mapping between attribute names. Returns: - QuantizationConfigOptions with edited attributes names. - + QuantizationConfigOptions: A new instance of QuantizationConfigOptions with updated attribute keys. 
""" - qc_options = copy.deepcopy(self) - - # Extract the list of existing quantization configurations from qc_options - - # Check if the base_config is already included in the quantization configuration list - # If not, add base_config to the list of configurations to update - cfgs_to_update = [cfg for cfg in qc_options.quantization_config_list] - if not any(qc_options.base_config is cfg for cfg in cfgs_to_update): - # TODO: add test for this case - cfgs_to_update.append(qc_options.base_config) - - for qc in cfgs_to_update: + updated_configs = [] + new_base_config = self.base_config + for qc in self.quantization_config_list: if layer_attrs_mapping is None: - qc.attr_weights_configs_mapping = {} - else: new_attr_mapping = {} - for attr in list(qc.attr_weights_configs_mapping.keys()): - new_key = layer_attrs_mapping.get(attr) - if new_key is None: # pragma: no cover - Logger.critical(f"Attribute \'{attr}\' does not exist in the provided attribute mapping.") - - new_attr_mapping[new_key] = qc.attr_weights_configs_mapping.pop(attr) - - qc.attr_weights_configs_mapping.update(new_attr_mapping) - - return qc_options + else: + new_attr_mapping = { + layer_attrs_mapping.get(attr, attr): cfg + for attr, cfg in qc.attr_weights_configs_mapping.items() + } + if qc == self.base_config: + new_base_config = replace(qc, attr_weights_configs_mapping=new_attr_mapping) + updated_configs.append(replace(qc, attr_weights_configs_mapping=new_attr_mapping)) + return replace(self, base_config=new_base_config, quantization_config_list=updated_configs) - def __edit_quantization_configuration(self, qc, kwargs): - for k, v in kwargs.items(): - assert hasattr(qc, - k), (f'Editing is only possible for existing attributes in the configuration; ' - f'{k} is not an attribute of {qc}.') - setattr(qc, k, v) + def get_info(self) -> Dict[str, Any]: + """ + Get detailed information about each quantization configuration option. - def get_info(self): + Returns: + dict: Information about the quantization configuration options as a dictionary. + """ return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_config_list)} +@dataclass(frozen=True) class TargetPlatformModelComponent: """ - Component of TargetPlatformModel (Fusing, OperatorsSet, etc.) + Component of TargetPlatformModel (Fusing, OperatorsSet, etc.). """ - def __init__(self, name: str): - """ - Args: - name: Name of component. + def __post_init__(self) -> None: + """ + Post-initialization to register the component with the current TargetPlatformModel. """ - self.name = name _current_tp_model.get().append_component(self) def get_info(self) -> Dict[str, Any]: """ + Get information about the component to display. - Returns: Get information about the component to display (return an empty dictionary. - the actual component should fill it with info). - + Returns: + Dict[str, Any]: Returns an empty dictionary. The actual component should override + this method to provide relevant information. """ return {} +@dataclass(frozen=True) class OperatorsSetBase(TargetPlatformModelComponent): """ - Base class to represent a set of operators. + Base class to represent a set of a target platform model component of operator set types. + Inherits from TargetPlatformModelComponent. """ - def __init__(self, name: str): + def __post_init__(self) -> None: """ - - Args: - name: Name of OperatorsSet. + Post-initialization to ensure the component is registered with the TargetPlatformModel. 
+ Calls the parent class's __post_init__ method to append this component to the current TargetPlatformModel. """ - super().__init__(name=name) + super().__post_init__() +@dataclass(frozen=True) class OperatorsSet(OperatorsSetBase): - def __init__(self, - name: str, - qc_options: QuantizationConfigOptions = None): - """ - Set of operators that are represented by a unique label. + """ + Set of operators that are represented by a unique label. - Args: - name (str): Set's label (must be unique in a TargetPlatformModel). - qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations. - """ + Attributes: + name (str): The set's label (must be unique within a TargetPlatformModel). + qc_options (QuantizationConfigOptions): Configuration options to use for this set of operations. + If None, it represents a fusing set. + is_default (bool): Indicates whether this set is the default quantization configuration + for the TargetPlatformModel or a fusing set. + """ + name: str + qc_options: QuantizationConfigOptions = None - super().__init__(name) - self.qc_options = qc_options - is_fusing_set = qc_options is None - self.is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set + def __post_init__(self) -> None: + """ + Post-initialization processing to mark the operator set as default if applicable. + Calls the parent class's __post_init__ method and sets `is_default` to True + if this set corresponds to the default quantization configuration for the + TargetPlatformModel or if it is a fusing set. - def get_info(self) -> Dict[str,Any]: """ + super().__post_init__() + is_fusing_set = self.qc_options is None + is_default = _current_tp_model.get().default_qco == self.qc_options or is_fusing_set + object.__setattr__(self, 'is_default', is_default) - Returns: Info about the set as a dictionary. + def get_info(self) -> Dict[str, Any]: + """ + Get information about the set as a dictionary. + Returns: + Dict[str, Any]: A dictionary containing the set name and + whether it is the default quantization configuration. """ return {"name": self.name, "is_default_qc": self.is_default} +@dataclass(frozen=True) class OperatorSetConcat(OperatorsSetBase): """ Concatenate a list of operator sets to treat them similarly in different places (like fusing). + + Attributes: + op_set_list (List[OperatorsSet]): List of operator sets to group. + qc_options (None): Configuration options for the set, always None for concatenated sets. + name (str): Concatenated name generated from the names of the operator sets in the list. """ - def __init__(self, *opsets: OperatorsSet): - """ - Group a list of operation sets. + op_set_list: List[OperatorsSet] = field(default_factory=list) + qc_options: None = field(default=None, init=False) + name: str = None - Args: - *opsets (OperatorsSet): List of operator sets to group. + def __post_init__(self) -> None: """ - name = "_".join([a.name for a in opsets]) - super().__init__(name=name) - self.op_set_list = opsets - self.qc_options = None # Concat have no qc options + Post-initialization processing to generate the concatenated name and set it as the `name` attribute. - def get_info(self) -> Dict[str,Any]: + Calls the parent class's __post_init__ method and creates a concatenated name + by joining the names of all operator sets in `op_set_list`. 
""" + super().__post_init__() + # Generate the concatenated name from the operator sets + concatenated_name = "_".join([op.name for op in self.op_set_list]) + # Set the inherited name attribute using `object.__setattr__` since the dataclass is frozen + object.__setattr__(self, "name", concatenated_name) - Returns: Info about the sets group as a dictionary. + def get_info(self) -> Dict[str, Any]: + """ + Get information about the concatenated set as a dictionary. + Returns: + Dict[str, Any]: A dictionary containing the concatenated name and + the list of names of the operator sets in `op_set_list`. """ return {"name": self.name, OPS_SET_LIST: [s.name for s in self.op_set_list]} +@dataclass(frozen=True) class Fusing(TargetPlatformModelComponent): """ - Fusing defines a list of operators that should be combined and treated as a single operator, - hence no quantization is applied between them. + Fusing defines a list of operators that should be combined and treated as a single operator, + hence no quantization is applied between them. + + Attributes: + operator_groups_list (Tuple[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, + each being either an OperatorSetConcat or an OperatorsSet. + name (str): The name for the Fusing instance. If not provided, it is generated from the operator groups' names. """ + operator_groups_list: Tuple[Union[OperatorsSet, OperatorSetConcat]] + name: str = None - def __init__(self, - operator_groups_list: List[Union[OperatorsSet, OperatorSetConcat]], - name: str = None): - """ - Args: - operator_groups_list (List[Union[OperatorsSet, OperatorSetConcat]]): A list of operator groups, each being either an OperatorSetConcat or an OperatorsSet. - name (str): The name for the Fusing instance. If not provided, it's generated from the operator groups' names. + def __post_init__(self) -> None: """ - assert isinstance(operator_groups_list, - list), f'List of operator groups should be of type list but is {type(operator_groups_list)}' - assert len(operator_groups_list) >= 2, f'Fusing can not be created for a single operators group' + Post-initialization processing for input validation and name generation. - # Generate a name from the operator groups if no name is provided - if name is None: - name = '_'.join([x.name for x in operator_groups_list]) + Calls the parent class's __post_init__ method, validates the operator_groups_list, + and generates the name if not explicitly provided. - super().__init__(name) - self.operator_groups_list = operator_groups_list + Raises: + Logger critical if operator_groups_list is not a list or if it contains fewer than two operators. + """ + super().__post_init__() + # Validate the operator_groups_list + if not isinstance(self.operator_groups_list, list): + Logger.critical( + f"List of operator groups should be of type list but is {type(self.operator_groups_list)}.") + if len(self.operator_groups_list) < 2: + Logger.critical("Fusing cannot be created for a single operator.") + + # if self.name is None: + # Generate the name from the operator groups if not provided + generated_name = '_'.join([x.name for x in self.operator_groups_list]) + object.__setattr__(self, 'name', generated_name) def contains(self, other: Any) -> bool: """ Determines if the current Fusing instance contains another Fusing instance. Args: - other: The other Fusing instance to check against. + other (Any): The other Fusing instance to check against. Returns: - A boolean indicating whether the other instance is contained within this one. 
+ bool: True if the other Fusing instance is contained within this one, False otherwise. """ if not isinstance(other, Fusing): return False @@ -506,81 +535,72 @@ def contains(self, other: Any) -> bool: # Other Fusing instance is not contained return False - def get_info(self): + def get_info(self) -> Union[Dict[str, str], str]: """ Retrieves information about the Fusing instance, including its name and the sequence of operator groups. Returns: - A dictionary with the Fusing instance's name as the key and the sequence of operator groups as the value, - or just the sequence of operator groups if no name is set. + Union[Dict[str, str], str]: A dictionary with the Fusing instance's name as the key + and the sequence of operator groups as the value, + or just the sequence of operator groups if no name is set. """ if self.name is not None: return {self.name: ' -> '.join([x.name for x in self.operator_groups_list])} return ' -> '.join([x.name for x in self.operator_groups_list]) -class TargetPlatformModel(ImmutableClass): +@dataclass(frozen=True) +class TargetPlatformModel: """ Represents the hardware configuration used for quantized model inference. - This model defines: - - The operators and their associated quantization configurations. - - Fusing patterns, enabling multiple operators to be combined into a single operator - for optimization during inference. - - Versioning support through minor and patch versions for backward compatibility. - Attributes: - SCHEMA_VERSION (int): The schema version of the target platform model. + default_qco (QuantizationConfigOptions): Default quantization configuration options for the model. + tpc_minor_version (Optional[int]): Minor version of the Target Platform Configuration. + tpc_patch_version (Optional[int]): Patch version of the Target Platform Configuration. + tpc_platform_type (Optional[str]): Type of the platform for the Target Platform Configuration. + add_metadata (bool): Flag to determine if metadata should be added. + name (str): Name of the Target Platform Model. + operator_set (List[OperatorsSetBase]): List of operator sets within the model. + fusing_patterns (List[Fusing]): List of fusing patterns for the model. + is_simd_padding (bool): Indicates if SIMD padding is applied. + SCHEMA_VERSION (int): Version of the schema for the Target Platform Model. """ - SCHEMA_VERSION = 1 - def __init__(self, - default_qco: QuantizationConfigOptions, - tpc_minor_version: Optional[int], - tpc_patch_version: Optional[int], - tpc_platform_type: Optional[str], - add_metadata: bool = True, - name="default_tp_model"): + default_qco: QuantizationConfigOptions + tpc_minor_version: Optional[int] + tpc_patch_version: Optional[int] + tpc_platform_type: Optional[str] + add_metadata: bool = True + name: str = "default_tp_model" + operator_set: List[OperatorsSetBase] = field(default_factory=list) + fusing_patterns: List[Fusing] = field(default_factory=list) + is_simd_padding: bool = False + + SCHEMA_VERSION: int = 1 + + def __post_init__(self) -> None: """ + Post-initialization processing for input validation. - Args: - default_qco (QuantizationConfigOptions): Default QuantizationConfigOptions to use for operators that their QuantizationConfigOptions are not defined in the model. - tpc_minor_version (Optional[int]): The minor version of the target platform capabilities. - tpc_patch_version (Optional[int]): The patch version of the target platform capabilities. - tpc_platform_type (Optional[str]): The platform type of the target platform capabilities. 
- add_metadata (bool): Whether to add metadata to the model or not. - name (str): Name of the model. - - Raises: - AssertionError: If the provided `default_qco` does not contain exactly one quantization configuration. - """ - - super().__init__() - self.tpc_minor_version = tpc_minor_version - self.tpc_patch_version = tpc_patch_version - self.tpc_platform_type = tpc_platform_type - self.add_metadata = add_metadata - self.name = name - self.operator_set = [] - assert isinstance(default_qco, QuantizationConfigOptions), \ - "default_qco must be an instance of QuantizationConfigOptions" - assert len(default_qco.quantization_config_list) == 1, \ - "Default QuantizationConfigOptions must contain exactly one option." - - self.default_qco = default_qco - self.fusing_patterns = [] - self.is_simd_padding = False - - def get_config_options_by_operators_set(self, - operators_set_name: str) -> QuantizationConfigOptions: - """ - Get the QuantizationConfigOptions of a OperatorsSet by the OperatorsSet name. - If the name is not in the model, the default QuantizationConfigOptions is returned. + Raises: + Logger critical if the default_qco is not an instance of QuantizationConfigOptions + or if it contains more than one quantization configuration. + """ + # Validate `default_qco` + if not isinstance(self.default_qco, QuantizationConfigOptions): + Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") + if len(self.default_qco.quantization_config_list) != 1: + Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") + + def get_config_options_by_operators_set(self, operators_set_name: str) -> QuantizationConfigOptions: + """ + Get the QuantizationConfigOptions of an OperatorsSet by its name. Args: - operators_set_name: Name of OperatorsSet to get. + operators_set_name (str): Name of the OperatorsSet to get. Returns: - QuantizationConfigOptions to use for ops in OperatorsSet named operators_set_name. + QuantizationConfigOptions: Quantization configuration options for the given OperatorsSet name. """ for op_set in self.operator_set: if operators_set_name == op_set.name: @@ -589,143 +609,114 @@ def get_config_options_by_operators_set(self, def get_default_op_quantization_config(self) -> OpQuantizationConfig: """ + Get the default OpQuantizationConfig of the TargetPlatformModel. - Returns: The default OpQuantizationConfig of the TargetPlatformModel. - + Returns: + OpQuantizationConfig: The default quantization configuration. """ assert len(self.default_qco.quantization_config_list) == 1, \ - f'Default quantization configuration options must contain only one option,' \ - f' but found {len(get_current_tp_model().default_qco.quantization_config_list)} configurations.' + f"Default quantization configuration options must contain only one option, " \ + f"but found {len(self.default_qco.quantization_config_list)} configurations." return self.default_qco.quantization_config_list[0] - def is_opset_in_model(self, - opset_name: str) -> bool: + def is_opset_in_model(self, opset_name: str) -> bool: """ - Check whether an operators set is defined in the model or not. + Check whether an OperatorsSet is defined in the model. Args: - opset_name: Operators set name to check. + opset_name (str): Name of the OperatorsSet to check. Returns: - Whether an operators set is defined in the model or not. + bool: True if the OperatorsSet exists, False otherwise. 
""" return opset_name in [x.name for x in self.operator_set] - def get_opset_by_name(self, - opset_name: str) -> OperatorsSetBase: + def get_opset_by_name(self, opset_name: str) -> Optional[OperatorsSetBase]: """ Get an OperatorsSet object from the model by its name. - If name is not in the model - None is returned. Args: - opset_name: OperatorsSet name to retrieve. + opset_name (str): Name of the OperatorsSet to retrieve. Returns: - OperatorsSet object with the name opset_name, or None if opset_name is not in the model. + Optional[OperatorsSetBase]: The OperatorsSet object with the given name, + or None if not found in the model. """ - opset_list = [x for x in self.operator_set if x.name == opset_name] - assert len(opset_list) <= 1, f'Found more than one OperatorsSet in' \ - f' TargetPlatformModel with the name {opset_name}. ' \ - f'OperatorsSet name must be unique.' - if len(opset_list) == 0: # opset_name is not in the model. - return None + if len(opset_list) > 1: + Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.") + return opset_list[0] if opset_list else None - return opset_list[0] # There's one opset with that name - - def append_component(self, - tp_model_component: TargetPlatformModelComponent): + def append_component(self, tp_model_component: TargetPlatformModelComponent) -> None: """ - Attach a TargetPlatformModel component to the model. Components can be for example: - Fusing, OperatorsSet, etc. + Attach a TargetPlatformModel component to the model (like Fusing or OperatorsSet). Args: - tp_model_component: Component to attach to the model. + tp_model_component (TargetPlatformModelComponent): Component to attach to the model. + Raises: + Logger critical if the component is not an instance of Fusing or OperatorsSetBase. """ if isinstance(tp_model_component, Fusing): self.fusing_patterns.append(tp_model_component) elif isinstance(tp_model_component, OperatorsSetBase): self.operator_set.append(tp_model_component) else: # pragma: no cover - Logger.critical(f'Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.') - - def __enter__(self): - """ - Start defining the TargetPlatformModel using 'with'. - - Returns: Initialized TargetPlatformModel object. + Logger.critical( + f"Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.") + def get_info(self) -> Dict[str, Any]: """ - _current_tp_model.set(self) - return self + Get a dictionary summarizing the TargetPlatformModel properties. - def __exit__(self, exc_type, exc_value, tb): - """ - Finish defining the TargetPlatformModel at the end of the 'with' clause. - Returns the final and immutable TargetPlatformModel instance. + Returns: + Dict[str, Any]: Summary of the TargetPlatformModel properties. """ + return { + "Model name": self.name, + "Default quantization config": self.get_default_op_quantization_config().get_info(), + "Operators sets": [o.get_info() for o in self.operator_set], + "Fusing patterns": [f.get_info() for f in self.fusing_patterns], + } - if exc_value is not None: - print(exc_value, exc_value.args) - raise exc_value - self.__validate_model() # Assert that model is valid. - _current_tp_model.reset() - self.initialized_done() # Make model immutable. - return self - - def __validate_model(self): + def __validate_model(self) -> None: """ + Validate the model's configuration to ensure its integrity. - Assert model is valid. 
- Model is invalid if, for example, it contains multiple operator sets with the same name, - as their names should be unique. - + Raises: + Logger critical if the model contains multiple operator sets with the same name. """ opsets_names = [op.name for op in self.operator_set] if len(set(opsets_names)) != len(opsets_names): - Logger.critical(f'Operator Sets must have unique names.') + Logger.critical("Operator Sets must have unique names.") - def get_default_config(self) -> OpQuantizationConfig: + def __enter__(self) -> 'TargetPlatformModel': """ + Start defining the TargetPlatformModel using a 'with' statement. Returns: - - """ - assert len(self.default_qco.quantization_config_list) == 1, \ - f'Default quantization configuration options must contain only one option,' \ - f' but found {len(self.default_qco.quantization_config_list)} configurations.' - return self.default_qco.quantization_config_list[0] - - def get_info(self) -> Dict[str, Any]: + TargetPlatformModel: The initialized TargetPlatformModel object. """ + _current_tp_model.set(self) + return self - Returns: Dictionary that summarizes the TargetPlatformModel properties (for display purposes). - - """ - return {"Model name": self.name, - "Default quantization config": self.get_default_config().get_info(), - "Operators sets": [o.get_info() for o in self.operator_set], - "Fusing patterns": [f.get_info() for f in self.fusing_patterns] - } - - def show(self): - """ - - Display the TargetPlatformModel. - - """ - pprint.pprint(self.get_info(), sort_dicts=False) - - def set_simd_padding(self, - is_simd_padding: bool): + def __exit__(self, exc_type, exc_value, tb) -> 'TargetPlatformModel': """ - Set flag is_simd_padding to indicate whether this TP model defines - that padding due to SIMD constrains occurs. + Finalize and validate the TargetPlatformModel at the end of the 'with' clause. Args: - is_simd_padding: Whether this TP model defines that padding due to SIMD constrains occurs. + exc_type: Exception type, if any occurred. + exc_value: Exception value, if any occurred. + tb: Traceback object, if an exception occurred. - """ - self.is_simd_padding = is_simd_padding + Raises: + The exception raised in the 'with' block, if any. + Returns: + TargetPlatformModel: The validated TargetPlatformModel object. + """ + if exc_value is not None: + raise exc_value + self.__validate_model() + _current_tp_model.reset() + return self diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py index 27d032c29..f9e94f81d 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1/tp_model.py @@ -164,7 +164,8 @@ def generate_tp_model(default_config: OpQuantizationConfig, tpc_patch_version=0, tpc_platform_type=IMX500_TP_MODEL, name=name, - add_metadata=False) + add_metadata=False, + is_simd_padding=True) # To start defining the model's components (such as operator sets, and fusing patterns), # use 'with' the TargetPlatformModel instance, and create them as below: @@ -175,8 +176,6 @@ def generate_tp_model(default_config: OpQuantizationConfig, # be used for operations that will be attached to this set's label. # Otherwise, it will be a configure-less set (used in fusing): - generated_tpc.set_simd_padding(is_simd_padding=True) - # May suit for operations like: Dropout, Reshape, etc. 
default_qco = tp.get_default_quantization_config_options() schema.OperatorsSet("NoQuantization", @@ -206,9 +205,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. # To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py index 9da497022..707fa76e1 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_lut/tp_model.py @@ -201,9 +201,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. # To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py index 24f3e6eae..032a42c6a 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v1_pot/tp_model.py @@ -197,9 +197,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. 
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py index 947c1608f..ae7056b99 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2/tp_model.py @@ -166,7 +166,8 @@ def generate_tp_model(default_config: OpQuantizationConfig, tpc_patch_version=0, tpc_platform_type=IMX500_TP_MODEL, add_metadata=True, - name=name) + name=name, + is_simd_padding=True) # To start defining the model's components (such as operator sets, and fusing patterns), # use 'with' the TargetPlatformModel instance, and create them as below: @@ -177,8 +178,6 @@ def generate_tp_model(default_config: OpQuantizationConfig, # be used for operations that will be attached to this set's label. # Otherwise, it will be a configure-less set (used in fusing): - generated_tpm.set_simd_padding(is_simd_padding=True) - # May suit for operations like: Dropout, Reshape, etc. default_qco = tp.get_default_quantization_config_options() schema.OperatorsSet("NoQuantization", @@ -208,9 +207,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. # To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py index 31ba2d9ab..187ef1100 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v2_lut/tp_model.py @@ -203,9 +203,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. 
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py index b053ea9eb..5e07cb7d9 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3/tp_model.py @@ -187,7 +187,8 @@ def generate_tp_model(default_config: OpQuantizationConfig, tpc_patch_version=0, tpc_platform_type=IMX500_TP_MODEL, add_metadata=True, - name=name) + name=name, + is_simd_padding=True) # To start defining the model's components (such as operator sets, and fusing patterns), # use 'with' the TargetPlatformModel instance, and create them as below: @@ -198,8 +199,6 @@ def generate_tp_model(default_config: OpQuantizationConfig, # be used for operations that will be attached to this set's label. # Otherwise, it will be a configure-less set (used in fusing): - generated_tpm.set_simd_padding(is_simd_padding=True) - # May suit for operations like: Dropout, Reshape, etc. default_qco = tp.get_default_quantization_config_options() schema.OperatorsSet("NoQuantization", @@ -231,9 +230,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. # To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py index 9102fcc02..8b25c33c2 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v3_lut/tp_model.py @@ -214,9 +214,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. 
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py index 7c056778e..2f658d2f8 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py @@ -15,7 +15,7 @@ from typing import List, Tuple import model_compression_toolkit as mct -import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema +import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema from model_compression_toolkit.constants import FLOAT_BITWIDTH from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS, \ IMX500_TP_MODEL @@ -235,7 +235,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, tpc_minor_version=4, tpc_patch_version=0, tpc_platform_type=IMX500_TP_MODEL, - add_metadata=True, name=name) + add_metadata=True, + name=name, + is_simd_padding=True) # To start defining the model's components (such as operator sets, and fusing patterns), # use 'with' the TargetPlatformModel instance, and create them as below: @@ -246,8 +248,6 @@ def generate_tp_model(default_config: OpQuantizationConfig, # be used for operations that will be attached to this set's label. # Otherwise, it will be a configure-less set (used in fusing): - generated_tpm.set_simd_padding(is_simd_padding=True) - # May suit for operations like: Dropout, Reshape, etc. default_qco = tp.get_default_quantization_config_options() schema.OperatorsSet(OPSET_NO_QUANTIZATION, @@ -294,11 +294,11 @@ def generate_tp_model(default_config: OpQuantizationConfig, # Combine multiple operators into a single operator to avoid quantization between # them. To do this we define fusing patterns using the OperatorsSets that were created. 
# To group multiple sets with regard to fusing, an OperatorSetConcat can be created - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, - tanh, gelu, hardswish, hardsigmoid) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid, tanh, gelu, - hardswish, hardsigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, + tanh, gelu, hardswish, hardsigmoid]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid, tanh, gelu, + hardswish, hardsigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) # ------------------- # # Fusions diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py index b0a69c6e7..d269d7f4e 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/tflite_tpc/v1/tp_model.py @@ -180,11 +180,11 @@ def generate_tp_model(default_config: OpQuantizationConfig, fixed_zero_point=-128, fixed_scale=1 / 256)) conv2d = schema.OperatorsSet("Conv2d") - kernel = schema.OperatorSetConcat(conv2d, fc) + kernel = schema.OperatorSetConcat([conv2d, fc]) relu = schema.OperatorsSet("Relu") elu = schema.OperatorsSet("Elu") - activations_to_fuse = schema.OperatorSetConcat(relu, elu) + activations_to_fuse = schema.OperatorSetConcat([relu, elu]) batch_norm = schema.OperatorsSet("BatchNorm") bias_add = schema.OperatorsSet("BiasAdd") diff --git a/tests/common_tests/test_tp_model.py b/tests/common_tests/test_tp_model.py index ed8a52b59..5fea3155b 100644 --- a/tests/common_tests/test_tp_model.py +++ b/tests/common_tests/test_tp_model.py @@ -55,7 +55,7 @@ def test_immutable_tp(self): with model: schema.OperatorsSet("opset") model.operator_set = [] - self.assertEqual('Immutable class. Can\'t edit attributes.', str(e.exception)) + self.assertEqual("cannot assign to field 'operator_set'", str(e.exception)) def test_default_options_more_than_single_qc(self): test_qco = schema.QuantizationConfigOptions([TEST_QC, TEST_QC], base_config=TEST_QC) @@ -76,8 +76,6 @@ def test_tp_model_show(self): with tpm: a = schema.OperatorsSet("opA") - tpm.show() - class OpsetTest(unittest.TestCase): @@ -114,7 +112,7 @@ def test_opset_concat(self): b = schema.OperatorsSet('opset_B', get_default_quantization_config_options().clone_and_edit(activation_n_bits=2)) schema.OperatorsSet('opset_C') # Just add it without using it in concat - schema.OperatorSetConcat(a, b) + schema.OperatorSetConcat([a, b]) self.assertEqual(len(hm.operator_set), 4) self.assertTrue(hm.is_opset_in_model("opset_A_opset_B")) self.assertTrue(hm.get_config_options_by_operators_set('opset_A_opset_B') is None) @@ -136,14 +134,14 @@ def test_non_unique_opset(self): class QCOptionsTest(unittest.TestCase): def test_empty_qc_options(self): - with self.assertRaises(AssertionError) as e: + with self.assertRaises(Exception) as e: schema.QuantizationConfigOptions([]) self.assertEqual( "'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. 
The provided list is empty.", str(e.exception)) def test_list_of_no_qc(self): - with self.assertRaises(AssertionError) as e: + with self.assertRaises(Exception) as e: schema.QuantizationConfigOptions([TEST_QC, 3]) self.assertEqual( 'Each option must be an instance of \'OpQuantizationConfig\', but found an object of type: .', @@ -186,7 +184,7 @@ def test_fusing_single_opset(self): add = schema.OperatorsSet("add") with self.assertRaises(Exception) as e: schema.Fusing([add]) - self.assertEqual('Fusing can not be created for a single operators group', str(e.exception)) + self.assertEqual('Fusing cannot be created for a single operator.', str(e.exception)) def test_fusing_contains(self): hm = schema.TargetPlatformModel( @@ -220,7 +218,7 @@ def test_fusing_contains_with_opset_concat(self): conv = schema.OperatorsSet("conv") add = schema.OperatorsSet("add") tanh = schema.OperatorsSet("tanh") - add_tanh = schema.OperatorSetConcat(add, tanh) + add_tanh = schema.OperatorSetConcat([add, tanh]) schema.Fusing([conv, add]) schema.Fusing([conv, add_tanh]) schema.Fusing([conv, add, tanh]) diff --git a/tests/keras_tests/exporter_tests/tflite_int8/imx500_int8_tp_model.py b/tests/keras_tests/exporter_tests/tflite_int8/imx500_int8_tp_model.py index 8e8f2eac4..209287fbf 100644 --- a/tests/keras_tests/exporter_tests/tflite_int8/imx500_int8_tp_model.py +++ b/tests/keras_tests/exporter_tests/tflite_int8/imx500_int8_tp_model.py @@ -95,9 +95,9 @@ def generate_tp_model(default_config: OpQuantizationConfig, swish = schema.OperatorsSet("Swish") sigmoid = schema.OperatorsSet("Sigmoid") tanh = schema.OperatorsSet("Tanh") - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, prelu, sigmoid, tanh) - activations_after_fc_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid) - any_binary = schema.OperatorSetConcat(add, sub, mul, div) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, prelu, sigmoid, tanh]) + activations_after_fc_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid]) + any_binary = schema.OperatorSetConcat([add, sub, mul, div]) schema.Fusing([conv, activations_after_conv_to_fuse]) schema.Fusing([fc, activations_after_fc_to_fuse]) schema.Fusing([any_binary, any_relu]) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py index 7b4e86d05..2218a8d16 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/activation_16bit_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +from dataclasses import replace + import numpy as np import tensorflow as tf @@ -34,8 +36,8 @@ def get_tpc(self): tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v4') # Force Mul base_config to 16bit only mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) - mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] - tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config + base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config) return tpc def create_networks(self): @@ -67,8 +69,8 @@ class Activation16BitMixedPrecisionTest(Activation16BitTest): def get_tpc(self): tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v3') mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) - mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] - tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config + base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config) mul_op_set.qc_options.quantization_config_list.extend( [mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4), mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)]) diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py b/tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py index 9c35e1582..243316a21 100644 --- a/tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py +++ b/tests/keras_tests/feature_networks_tests/feature_networks/manual_bit_selection.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +from dataclasses import replace + import numpy as np import tensorflow as tf @@ -133,9 +135,8 @@ def get_tpc(self): tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v3') # Force Mul base_config to 16bit only mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) - mul_op_set.qc_options.base_config = \ - [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] - tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config + base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config) return tpc def create_networks(self): @@ -159,9 +160,8 @@ class Manual16BitWidthSelectionMixedPrecisionTest(Manual16BitWidthSelectionTest) def get_tpc(self): tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v3') mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) - mul_op_set.qc_options.base_config = \ - [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] - tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config + base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config) mul_op_set.qc_options.quantization_config_list.extend( [mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4), mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)]) diff --git a/tests/keras_tests/function_tests/test_layer_fusing.py b/tests/keras_tests/function_tests/test_layer_fusing.py index 0c8a5b2e6..f55c31d4f 100644 --- a/tests/keras_tests/function_tests/test_layer_fusing.py +++ b/tests/keras_tests/function_tests/test_layer_fusing.py @@ -120,7 +120,7 @@ def get_tpc_2(): swish = schema.OperatorsSet("Swish") sigmoid = schema.OperatorsSet("Sigmoid") tanh = schema.OperatorsSet("Tanh") - activations_after_conv_to_fuse = schema.OperatorSetConcat(any_relu, swish, sigmoid, tanh) + activations_after_conv_to_fuse = schema.OperatorSetConcat([any_relu, swish, sigmoid, tanh]) # Define fusions schema.Fusing([conv, activations_after_conv_to_fuse]) @@ -161,7 +161,7 @@ def get_tpc_4(): any_relu = schema.OperatorsSet("AnyReLU") add = schema.OperatorsSet("Add") swish = schema.OperatorsSet("Swish") - activations_to_fuse = schema.OperatorSetConcat(any_relu, swish) + activations_to_fuse = schema.OperatorSetConcat([any_relu, swish]) # Define fusions schema.Fusing([conv, activations_to_fuse]) schema.Fusing([conv, add, activations_to_fuse]) diff --git a/tests/keras_tests/function_tests/test_quant_config_filtering.py b/tests/keras_tests/function_tests/test_quant_config_filtering.py index c9365c103..6e5c3c871 100644 --- a/tests/keras_tests/function_tests/test_quant_config_filtering.py +++ b/tests/keras_tests/function_tests/test_quant_config_filtering.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== - +from dataclasses import replace import unittest import numpy as np @@ -44,8 +44,8 @@ def get_tpc_default_16bit(): tpc = mct.get_target_platform_capabilities(TENSORFLOW, IMX500_TP_MODEL, 'v3') # Force Mul base_config to 16bit only mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) - mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] - tpc.layer2qco[tf.multiply].base_config = mul_op_set.qc_options.base_config + base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[tf.multiply] = replace(tpc.layer2qco[tf.multiply], base_config=base_config) return tpc def test_config_filtering(self): diff --git a/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py b/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py index 6f4478aff..add49fd26 100644 --- a/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py +++ b/tests/keras_tests/non_parallel_tests/test_keras_tp_model.py @@ -130,7 +130,7 @@ def test_get_layers_by_opconcat(self): with hm: op_obj_a = schema.OperatorsSet('opsetA') op_obj_b = schema.OperatorsSet('opsetB') - op_concat = schema.OperatorSetConcat(op_obj_a, op_obj_b) + op_concat = schema.OperatorSetConcat([op_obj_a, op_obj_b]) fw_tp = TargetPlatformCapabilities(hm) with fw_tp: diff --git a/tests/pytorch_tests/function_tests/layer_fusing_test.py b/tests/pytorch_tests/function_tests/layer_fusing_test.py index ccf131ddd..6ecdca713 100644 --- a/tests/pytorch_tests/function_tests/layer_fusing_test.py +++ b/tests/pytorch_tests/function_tests/layer_fusing_test.py @@ -229,7 +229,7 @@ def get_tpc(self): any_relu = schema.OperatorsSet("AnyReLU") add = schema.OperatorsSet("Add") swish = schema.OperatorsSet("Swish") - activations_to_fuse = schema.OperatorSetConcat(any_relu, swish) + activations_to_fuse = schema.OperatorSetConcat([any_relu, swish]) # Define fusions schema.Fusing([conv, activations_to_fuse]) schema.Fusing([conv, add, activations_to_fuse]) diff --git a/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py b/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py index cb7c7647d..68c597f13 100644 --- a/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py +++ b/tests/pytorch_tests/function_tests/test_pytorch_tp_model.py @@ -168,7 +168,7 @@ def test_get_layers_by_opconcat(self): with hm: op_obj_a = schema.OperatorsSet('opsetA') op_obj_b = schema.OperatorsSet('opsetB') - op_concat = schema.OperatorSetConcat(op_obj_a, op_obj_b) + op_concat = schema.OperatorSetConcat([op_obj_a, op_obj_b]) fw_tp = TargetPlatformCapabilities(hm) with fw_tp: diff --git a/tests/pytorch_tests/function_tests/test_quant_config_filtering.py b/tests/pytorch_tests/function_tests/test_quant_config_filtering.py index e2754302e..d26bfe3f9 100644 --- a/tests/pytorch_tests/function_tests/test_quant_config_filtering.py +++ b/tests/pytorch_tests/function_tests/test_quant_config_filtering.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== - +from dataclasses import replace import unittest import model_compression_toolkit as mct @@ -34,8 +34,8 @@ def get_tpc_default_16bit(): tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v3') # Force Mul base_config to 16bit only mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) - mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] - tpc.layer2qco[torch.multiply].base_config = mul_op_set.qc_options.base_config + base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[torch.multiply] = replace(tpc.layer2qco[torch.multiply], base_config=base_config) return tpc def test_config_filtering(self): diff --git a/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py b/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py index cfc6fa2e8..6d2196053 100644 --- a/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py +++ b/tests/pytorch_tests/model_tests/feature_models/activation_16bit_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +from dataclasses import replace + from operator import mul import torch @@ -62,10 +64,9 @@ def forward(self, x): def set_16bit_as_default(tpc, required_op_set, required_ops_list): - op_set = get_op_set(required_op_set, tpc.tp_model.operator_set) - op_set.qc_options.base_config = [l for l in op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] for op in required_ops_list: - tpc.layer2qco[op].base_config = [l for l in tpc.layer2qco[op].quantization_config_list if l.activation_n_bits == 16][0] + base_config = [l for l in tpc.layer2qco[op].quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[op] = replace(tpc.layer2qco[op], base_config=base_config) class Activation16BitTest(BasePytorchFeatureNetworkTest): @@ -106,9 +107,9 @@ class Activation16BitMixedPrecisionTest(Activation16BitTest): def get_tpc(self): tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v3') mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) - mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] - tpc.layer2qco[torch.mul].base_config = mul_op_set.qc_options.base_config - tpc.layer2qco[mul].base_config = mul_op_set.qc_options.base_config + base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[torch.mul] = replace(tpc.layer2qco[torch.mul], base_config=base_config) + tpc.layer2qco[mul] = replace(tpc.layer2qco[mul], base_config=base_config) mul_op_set.qc_options.quantization_config_list.extend( [mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4), mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)]) diff --git a/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py b/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py index 8d2207974..3178785f2 100644 --- a/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py +++ b/tests/pytorch_tests/model_tests/feature_models/manual_bit_selection.py @@ -12,6 +12,8 @@ # See the License for the specific language governing 
permissions and # limitations under the License. # ============================================================================== +from dataclasses import replace + from operator import mul import inspect @@ -186,9 +188,9 @@ class Manual16BitTest(ManualBitWidthByLayerNameTest): def get_tpc(self): tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v3') mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) - mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] - tpc.layer2qco[torch.mul].base_config = mul_op_set.qc_options.base_config - tpc.layer2qco[mul].base_config = mul_op_set.qc_options.base_config + base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[torch.mul] = replace(tpc.layer2qco[torch.mul], base_config=base_config) + tpc.layer2qco[mul] = replace(tpc.layer2qco[mul] , base_config=base_config) return {'mixed_precision_activation_model': tpc} def create_feature_network(self, input_shape): @@ -200,9 +202,9 @@ class Manual16BitTestMixedPrecisionTest(ManualBitWidthByLayerNameTest): def get_tpc(self): tpc = mct.get_target_platform_capabilities(PYTORCH, IMX500_TP_MODEL, 'v3') mul_op_set = get_op_set('Mul', tpc.tp_model.operator_set) - mul_op_set.qc_options.base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] - tpc.layer2qco[torch.mul].base_config = mul_op_set.qc_options.base_config - tpc.layer2qco[mul].base_config = mul_op_set.qc_options.base_config + base_config = [l for l in mul_op_set.qc_options.quantization_config_list if l.activation_n_bits == 16][0] + tpc.layer2qco[torch.mul] = replace(tpc.layer2qco[torch.mul], base_config=base_config) + tpc.layer2qco[mul] = replace(tpc.layer2qco[mul], base_config=base_config) mul_op_set.qc_options.quantization_config_list.extend( [mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=4), mul_op_set.qc_options.base_config.clone_and_edit(activation_n_bits=2)]) From f183073c2ae7271bf283085f02a2bb5b201e22b7 Mon Sep 17 00:00:00 2001 From: liord Date: Tue, 10 Dec 2024 12:32:07 +0200 Subject: [PATCH 2/8] Add "# pragma: no cover" to Logger critical errors --- .../target_platform_capabilities/schema/v1.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py index 16934a3ad..a2be27f5a 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py @@ -120,11 +120,11 @@ def __post_init__(self): Logger critical if attributes are of incorrect type or have invalid values. 
""" if not isinstance(self.weights_n_bits, int) or self.weights_n_bits < 1: - Logger.critical("weights_n_bits must be a positive integer.") + Logger.critical("weights_n_bits must be a positive integer.") # pragma: no cover if not isinstance(self.enable_weights_quantization, bool): - Logger.critical("enable_weights_quantization must be a boolean.") + Logger.critical("enable_weights_quantization must be a boolean.") # pragma: no cover if self.lut_values_bitwidth is not None and not isinstance(self.lut_values_bitwidth, int): - Logger.critical("lut_values_bitwidth must be an integer or None.") + Logger.critical("lut_values_bitwidth must be an integer or None.") # pragma: no cover def clone_and_edit(self, **kwargs) -> 'AttributeQuantizationConfig': """ @@ -251,27 +251,27 @@ def __post_init__(self) -> None: # Validate `quantization_config_list` if not isinstance(self.quantization_config_list, list): Logger.critical( - f"'quantization_config_list' must be a list, but received: {type(self.quantization_config_list)}.") + f"'quantization_config_list' must be a list, but received: {type(self.quantization_config_list)}.") # pragma: no cover for cfg in self.quantization_config_list: if not isinstance(cfg, OpQuantizationConfig): Logger.critical( - f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}.") + f"Each option must be an instance of 'OpQuantizationConfig', but found an object of type: {type(cfg)}.") # pragma: no cover # Handle base_config if len(self.quantization_config_list) > 1: if self.base_config is None: - Logger.critical(f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization.") + Logger.critical(f"For multiple configurations, a 'base_config' is required for non-mixed-precision optimization.") # pragma: no cover if not any(self.base_config == cfg for cfg in self.quantization_config_list): - Logger.critical(f"'base_config' must be included in the quantization config options list.") + Logger.critical(f"'base_config' must be included in the quantization config options list.") # pragma: no cover elif len(self.quantization_config_list) == 1: if self.base_config is None: object.__setattr__(self, 'base_config', self.quantization_config_list[0]) elif self.base_config != self.quantization_config_list[0]: Logger.critical( - "'base_config' should be the same as the sole item in 'quantization_config_list'.") + "'base_config' should be the same as the sole item in 'quantization_config_list'.") # pragma: no cover elif len(self.quantization_config_list) == 0: - Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. The provided list is empty.") + Logger.critical("'QuantizationConfigOptions' requires at least one 'OpQuantizationConfig'. 
The provided list is empty.") # pragma: no cover def clone_and_edit(self, **kwargs) -> 'QuantizationConfigOptions': """ @@ -500,9 +500,9 @@ def __post_init__(self) -> None: # Validate the operator_groups_list if not isinstance(self.operator_groups_list, list): Logger.critical( - f"List of operator groups should be of type list but is {type(self.operator_groups_list)}.") + f"List of operator groups should be of type list but is {type(self.operator_groups_list)}.") # pragma: no cover if len(self.operator_groups_list) < 2: - Logger.critical("Fusing cannot be created for a single operator.") + Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover # if self.name is None: # Generate the name from the operator groups if not provided @@ -588,9 +588,9 @@ def __post_init__(self) -> None: """ # Validate `default_qco` if not isinstance(self.default_qco, QuantizationConfigOptions): - Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") + Logger.critical("'default_qco' must be an instance of QuantizationConfigOptions.") # pragma: no cover if len(self.default_qco.quantization_config_list) != 1: - Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") + Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover def get_config_options_by_operators_set(self, operators_set_name: str) -> QuantizationConfigOptions: """ @@ -661,9 +661,9 @@ def append_component(self, tp_model_component: TargetPlatformModelComponent) -> self.fusing_patterns.append(tp_model_component) elif isinstance(tp_model_component, OperatorsSetBase): self.operator_set.append(tp_model_component) - else: # pragma: no cover + else: Logger.critical( - f"Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.") + f"Attempted to append an unrecognized TargetPlatformModelComponent of type: {type(tp_model_component)}.") # pragma: no cover def get_info(self) -> Dict[str, Any]: """ @@ -688,7 +688,7 @@ def __validate_model(self) -> None: """ opsets_names = [op.name for op in self.operator_set] if len(set(opsets_names)) != len(opsets_names): - Logger.critical("Operator Sets must have unique names.") + Logger.critical("Operator Sets must have unique names.") # pragma: no cover def __enter__(self) -> 'TargetPlatformModel': """ From eb5bbb292f0f215987585d08fd19714c965874e8 Mon Sep 17 00:00:00 2001 From: liord Date: Tue, 10 Dec 2024 12:49:10 +0200 Subject: [PATCH 3/8] Fix OperatorSetNames Enum --- .../target_platform_capabilities/schema/v1.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py index a2be27f5a..306da315f 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py @@ -23,11 +23,6 @@ _current_tp_model class OperatorSetNames(Enum): - OPSET_NO_QUANTIZATION = "NoQuantization" - OPSET_QUANTIZATION_PRESERVING = "QuantizationPreserving" - OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS = "DimensionManipulationOpsWithWeights" - OPSET_DIMENSION_MANIPULATION_OPS = "DimensionManipulationOps" - OPSET_MERGE_OPS = "MergeOps" OPSET_CONV = "Conv" OPSET_DEPTHWISE_CONV = "DepthwiseConv2D" OPSET_CONV_TRANSPOSE = "ConvTraspose" @@ -38,7 +33,10 @@ class OperatorSetNames(Enum): OPSET_GATHER = "Gather" OPSET_EXPAND = "Expend" 
OPSET_BATCH_NORM = "BatchNorm" - OPSET_ANY_RELU = "AnyReLU" + OPSET_RELU = "ReLU" + OPSET_RELU6 = "ReLU6" + OPSET_LEAKY_RELU = "LEAKYReLU" + OPSET_HARD_TANH = "HardTanh" OPSET_ADD = "Add" OPSET_SUB = "Sub" OPSET_MUL = "Mul" From 52b00d02bd9518c0279e134aac0c10a345cfa744 Mon Sep 17 00:00:00 2001 From: liord Date: Tue, 10 Dec 2024 15:17:37 +0200 Subject: [PATCH 4/8] Add show function to tp model --- .../target_platform_capabilities/schema/v1.py | 10 ++++++++++ tests/common_tests/test_tp_model.py | 1 + 2 files changed, 11 insertions(+) diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py index 306da315f..be4954b7a 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== +import pprint + from dataclasses import replace, dataclass, asdict, field from enum import Enum from typing import Dict, Any, Union, Tuple, List, Optional @@ -718,3 +720,11 @@ def __exit__(self, exc_type, exc_value, tb) -> 'TargetPlatformModel': self.__validate_model() _current_tp_model.reset() return self + + def show(self) -> None: + """ + + Display the TargetPlatformModel. + + """ + pprint.pprint(self.get_info(), sort_dicts=False) \ No newline at end of file diff --git a/tests/common_tests/test_tp_model.py b/tests/common_tests/test_tp_model.py index 5fea3155b..5b1cd5799 100644 --- a/tests/common_tests/test_tp_model.py +++ b/tests/common_tests/test_tp_model.py @@ -76,6 +76,7 @@ def test_tp_model_show(self): with tpm: a = schema.OperatorsSet("opA") + tpm.show() class OpsetTest(unittest.TestCase): From 0e71a8710fd8ad894f55e13f98499b83119a51ec Mon Sep 17 00:00:00 2001 From: liord Date: Wed, 11 Dec 2024 12:38:32 +0200 Subject: [PATCH 5/8] Remove functionality from schema to schema_functions --- .../core/common/graph/base_node.py | 5 +- .../set_node_quantization_config.py | 5 +- .../schema/schema_functions.py | 109 +++++++++++++++++- .../target_platform_capabilities/schema/v1.py | 101 +++------------- .../operations_to_layers.py | 8 +- .../target_platform_capabilities.py | 8 +- .../tpc_models/imx500_tpc/v4/tp_model.py | 2 +- tests/common_tests/test_tp_model.py | 16 +-- 8 files changed, 151 insertions(+), 103 deletions(-) diff --git a/model_compression_toolkit/core/common/graph/base_node.py b/model_compression_toolkit/core/common/graph/base_node.py index ecad233dd..b47a03f1f 100644 --- a/model_compression_toolkit/core/common/graph/base_node.py +++ b/model_compression_toolkit/core/common/graph/base_node.py @@ -24,6 +24,7 @@ from model_compression_toolkit.logger import Logger from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \ OpQuantizationConfig +from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, LayerFilterParams @@ -585,7 +586,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities, _node_qc_options = node_qc_options.quantization_config_list if len(next_nodes): next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes] - next_nodes_supported_input_bitwidth = 
min([op_cfg.max_input_activation_n_bits + next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg) for qc_opts in next_nodes_qc_options for op_cfg in qc_opts.quantization_config_list]) @@ -596,7 +597,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities, Logger.critical(f"Graph doesn't match TPC bit configurations: {self} -> {next_nodes}.") # pragma: no cover # Verify base config match - if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits + if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config) for qc_opt in next_nodes_qc_options]): # base_config activation bits doesn't match next node supported input bit-width -> replace with # a qco from quantization_config_list with maximum activation bit-width. diff --git a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py index 2e6cc8d9d..5d4d18441 100644 --- a/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py +++ b/model_compression_toolkit/core/common/quantization/set_node_quantization_config.py @@ -32,6 +32,7 @@ get_activation_quantization_params_fn, get_weights_quantization_params_fn from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \ get_weights_quantization_fn +from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \ QuantizationConfigOptions @@ -117,7 +118,7 @@ def filter_node_qco_by_graph(node: BaseNode, if len(next_nodes): next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes] - next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits + next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg) for qc_opts in next_nodes_qc_options for op_cfg in qc_opts.quantization_config_list]) @@ -128,7 +129,7 @@ def filter_node_qco_by_graph(node: BaseNode, Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.") # Verify base config match - if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits + if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config) for qc_opt in next_nodes_qc_options]): # base_config activation bits doesn't match next node supported input bit-width -> replace with # a qco from quantization_config_list with maximum activation bit-width. diff --git a/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py b/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py index 105136647..03b26e2d9 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py @@ -12,8 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== +from logging import Logger import copy -from typing import Any, Dict +from typing import Any, Dict, Optional + +from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \ + TargetPlatformModel, QuantizationConfigOptions, OperatorsSetBase def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any: @@ -35,3 +39,106 @@ def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any: f'but {k} is not a parameter of {obj_copy}.' setattr(obj_copy, k, v) return obj_copy + + +def get_config_options_by_operators_set(self, operators_set_name: str) -> QuantizationConfigOptions: + """ + Get the QuantizationConfigOptions of an OperatorsSet by its name. + + Args: + operators_set_name (str): Name of the OperatorsSet to get. + + Returns: + QuantizationConfigOptions: Quantization configuration options for the given OperatorsSet name. + """ + for op_set in self.operator_set: + if operators_set_name == op_set.name: + return op_set.qc_options + return self.default_qco + + +def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) -> int: + """ + Get the maximum supported input bit-width. + + Args: + op_quantization_config (OpQuantizationConfig): The configuration object from which to retrieve the maximum supported input bit-width. + + Returns: + int: Maximum supported input bit-width. + """ + return max(op_quantization_config.supported_input_activation_n_bits) + + +def get_config_options_by_operators_set(tp_model: TargetPlatformModel, + operators_set_name: str) -> QuantizationConfigOptions: + """ + Get the QuantizationConfigOptions of an OperatorsSet by its name. + + Args: + tp_model (TargetPlatformModel): The target platform model containing the operator sets and their configurations. + operators_set_name (str): The name of the OperatorsSet whose quantization configuration options are to be retrieved. + + Returns: + QuantizationConfigOptions: The quantization configuration options associated with the specified OperatorsSet, + or the default quantization configuration options if the OperatorsSet is not found. + """ + for op_set in tp_model.operator_set: + if operators_set_name == op_set.name: + return op_set.qc_options + return tp_model.default_qco + + +def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuantizationConfig: + """ + Get the default OpQuantizationConfig of the TargetPlatformModel. + + Args: + tp_model (TargetPlatformModel): The target platform model containing the default quantization configuration. + + Returns: + OpQuantizationConfig: The default quantization configuration. + + Raises: + AssertionError: If the default quantization configuration list contains more than one configuration option. + """ + assert len(tp_model.default_qco.quantization_config_list) == 1, \ + f"Default quantization configuration options must contain only one option, " \ + f"but found {len(tp_model.default_qco.quantization_config_list)} configurations." + return tp_model.default_qco.quantization_config_list[0] + + +def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool: + """ + Check whether an OperatorsSet is defined in the model. + + Args: + tp_model (TargetPlatformModel): The target platform model containing the list of operator sets. + opset_name (str): The name of the OperatorsSet to check for existence. 
+ + Returns: + bool: True if an OperatorsSet with the given name exists in the target platform model, + otherwise False. + """ + return opset_name in [x.name for x in tp_model.operator_set] + + +def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optional[OperatorsSetBase]: + """ + Get an OperatorsSet object from the model by its name. + + Args: + tp_model (TargetPlatformModel): The target platform model containing the list of operator sets. + opset_name (str): The name of the OperatorsSet to be retrieved. + + Returns: + Optional[OperatorsSetBase]: The OperatorsSet object with the specified name if found. + If no operator set with the specified name is found, None is returned. + + Raises: + A critical log message if multiple operator sets with the same name are found. + """ + opset_list = [x for x in tp_model.operator_set if x.name == opset_name] + if len(opset_list) > 1: + Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.") + return opset_list[0] if opset_list else None diff --git a/model_compression_toolkit/target_platform_capabilities/schema/v1.py b/model_compression_toolkit/target_platform_capabilities/schema/v1.py index be4954b7a..4353a7d98 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/v1.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/v1.py @@ -103,14 +103,14 @@ class AttributeQuantizationConfig: weights_n_bits (int): Number of bits to quantize the coefficients. weights_per_channel_threshold (bool): Indicates whether to quantize the weights per-channel or per-tensor. enable_weights_quantization (bool): Indicates whether to quantize the model weights or not. - lut_values_bitwidth (Union[int, None]): Number of bits to use when quantizing in a look-up table. + lut_values_bitwidth (Optional[int]): Number of bits to use when quantizing in a look-up table. If None, defaults to 8 in hptq; otherwise, it uses the provided value. """ weights_quantization_method: QuantizationMethod = QuantizationMethod.POWER_OF_TWO weights_n_bits: int = FLOAT_BITWIDTH weights_per_channel_threshold: bool = False enable_weights_quantization: bool = False - lut_values_bitwidth: Union[int, None] = None + lut_values_bitwidth: Optional[int] = None def __post_init__(self): """ @@ -170,7 +170,7 @@ class OpQuantizationConfig: simd_size: int signedness: Signedness - def __post_init__(self) -> None: + def __post_init__(self): """ Post-initialization processing for input validation. @@ -218,16 +218,6 @@ def clone_and_edit(self, attr_to_edit: Dict[str, Dict[str, Any]] = {}, **kwargs) # Return a new instance with the updated attribute mapping return replace(updated_config, attr_weights_configs_mapping=updated_attr_mapping) - @property - def max_input_activation_n_bits(self) -> int: - """ - Get the maximum supported input bit-width. - - Returns: - int: Maximum supported input bit-width. - """ - return max(self.supported_input_activation_n_bits) - @dataclass(frozen=True) class QuantizationConfigOptions: @@ -236,12 +226,12 @@ class QuantizationConfigOptions: Attributes: quantization_config_list (List[OpQuantizationConfig]): List of possible OpQuantizationConfig to gather. - base_config (Union[OpQuantizationConfig, None]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner. + base_config (Optional[OpQuantizationConfig]): Fallback OpQuantizationConfig to use when optimizing the model in a non-mixed-precision manner. 
""" quantization_config_list: List[OpQuantizationConfig] - base_config: Union[OpQuantizationConfig, None] = None + base_config: Optional[OpQuantizationConfig] = None - def __post_init__(self) -> None: + def __post_init__(self): """ Post-initialization processing for input validation. @@ -320,12 +310,12 @@ def clone_and_edit_weight_attribute(self, attrs: List[str] = None, **kwargs) -> updated_configs.append(replace(qc, attr_weights_configs_mapping=updated_attr_mapping)) return replace(self, base_config=updated_base_config, quantization_config_list=updated_configs) - def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Union[Dict[str, str], None]) -> 'QuantizationConfigOptions': + def clone_and_map_weights_attr_keys(self, layer_attrs_mapping: Optional[Dict[str, str]]) -> 'QuantizationConfigOptions': """ Clones the quantization configurations and updates keys in attribute config mappings. Args: - layer_attrs_mapping (Union[Dict[str, str], None]): A mapping between attribute names. + layer_attrs_mapping (Optional[Dict[str, str]]): A mapping between attribute names. Returns: QuantizationConfigOptions: A new instance of QuantizationConfigOptions with updated attribute keys. @@ -361,7 +351,7 @@ class TargetPlatformModelComponent: Component of TargetPlatformModel (Fusing, OperatorsSet, etc.). """ - def __post_init__(self) -> None: + def __post_init__(self): """ Post-initialization to register the component with the current TargetPlatformModel. """ @@ -384,7 +374,7 @@ class OperatorsSetBase(TargetPlatformModelComponent): Base class to represent a set of a target platform model component of operator set types. Inherits from TargetPlatformModelComponent. """ - def __post_init__(self) -> None: + def __post_init__(self): """ Post-initialization to ensure the component is registered with the TargetPlatformModel. Calls the parent class's __post_init__ method to append this component to the current TargetPlatformModel. @@ -407,7 +397,7 @@ class OperatorsSet(OperatorsSetBase): name: str qc_options: QuantizationConfigOptions = None - def __post_init__(self) -> None: + def __post_init__(self): """ Post-initialization processing to mark the operator set as default if applicable. @@ -447,7 +437,7 @@ class OperatorSetConcat(OperatorsSetBase): qc_options: None = field(default=None, init=False) name: str = None - def __post_init__(self) -> None: + def __post_init__(self): """ Post-initialization processing to generate the concatenated name and set it as the `name` attribute. @@ -486,7 +476,7 @@ class Fusing(TargetPlatformModelComponent): operator_groups_list: Tuple[Union[OperatorsSet, OperatorSetConcat]] name: str = None - def __post_init__(self) -> None: + def __post_init__(self): """ Post-initialization processing for input validation and name generation. @@ -504,7 +494,6 @@ def __post_init__(self) -> None: if len(self.operator_groups_list) < 2: Logger.critical("Fusing cannot be created for a single operator.") # pragma: no cover - # if self.name is None: # Generate the name from the operator groups if not provided generated_name = '_'.join([x.name for x in self.operator_groups_list]) object.__setattr__(self, 'name', generated_name) @@ -578,7 +567,7 @@ class TargetPlatformModel: SCHEMA_VERSION: int = 1 - def __post_init__(self) -> None: + def __post_init__(self): """ Post-initialization processing for input validation. 
@@ -592,62 +581,7 @@ def __post_init__(self) -> None: if len(self.default_qco.quantization_config_list) != 1: Logger.critical("Default QuantizationConfigOptions must contain exactly one option.") # pragma: no cover - def get_config_options_by_operators_set(self, operators_set_name: str) -> QuantizationConfigOptions: - """ - Get the QuantizationConfigOptions of an OperatorsSet by its name. - - Args: - operators_set_name (str): Name of the OperatorsSet to get. - - Returns: - QuantizationConfigOptions: Quantization configuration options for the given OperatorsSet name. - """ - for op_set in self.operator_set: - if operators_set_name == op_set.name: - return op_set.qc_options - return self.default_qco - - def get_default_op_quantization_config(self) -> OpQuantizationConfig: - """ - Get the default OpQuantizationConfig of the TargetPlatformModel. - - Returns: - OpQuantizationConfig: The default quantization configuration. - """ - assert len(self.default_qco.quantization_config_list) == 1, \ - f"Default quantization configuration options must contain only one option, " \ - f"but found {len(self.default_qco.quantization_config_list)} configurations." - return self.default_qco.quantization_config_list[0] - - def is_opset_in_model(self, opset_name: str) -> bool: - """ - Check whether an OperatorsSet is defined in the model. - - Args: - opset_name (str): Name of the OperatorsSet to check. - - Returns: - bool: True if the OperatorsSet exists, False otherwise. - """ - return opset_name in [x.name for x in self.operator_set] - - def get_opset_by_name(self, opset_name: str) -> Optional[OperatorsSetBase]: - """ - Get an OperatorsSet object from the model by its name. - - Args: - opset_name (str): Name of the OperatorsSet to retrieve. - - Returns: - Optional[OperatorsSetBase]: The OperatorsSet object with the given name, - or None if not found in the model. - """ - opset_list = [x for x in self.operator_set if x.name == opset_name] - if len(opset_list) > 1: - Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.") - return opset_list[0] if opset_list else None - - def append_component(self, tp_model_component: TargetPlatformModelComponent) -> None: + def append_component(self, tp_model_component: TargetPlatformModelComponent): """ Attach a TargetPlatformModel component to the model (like Fusing or OperatorsSet). @@ -674,12 +608,11 @@ def get_info(self) -> Dict[str, Any]: """ return { "Model name": self.name, - "Default quantization config": self.get_default_op_quantization_config().get_info(), "Operators sets": [o.get_info() for o in self.operator_set], "Fusing patterns": [f.get_info() for f in self.fusing_patterns], } - def __validate_model(self) -> None: + def __validate_model(self): """ Validate the model's configuration to ensure its integrity. @@ -721,7 +654,7 @@ def __exit__(self, exc_type, exc_value, tb) -> 'TargetPlatformModel': _current_tp_model.reset() return self - def show(self) -> None: + def show(self): """ Display the TargetPlatformModel. 
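Taken together, the v1.py hunks above apply a single pattern: each schema class becomes a frozen dataclass, __post_init__ validates its inputs and fills derived fields through object.__setattr__ (the only write path a frozen instance allows), and callers obtain edited copies via dataclasses.replace instead of mutating in place. The following minimal sketch shows the pattern in isolation; ToyOpConfig is a hypothetical stand-in, not an MCT class:

    from dataclasses import dataclass, replace
    from typing import Optional

    @dataclass(frozen=True)
    class ToyOpConfig:
        # Hypothetical stand-in for the frozen schema classes in v1.py.
        activation_n_bits: int = 8
        name: Optional[str] = None

        def __post_init__(self):
            # Input validation, in the style of the __post_init__ hooks above.
            if not isinstance(self.activation_n_bits, int) or self.activation_n_bits < 1:
                raise ValueError("activation_n_bits must be a positive integer.")
            # Frozen instances reject normal assignment, so derived defaults
            # are written through object.__setattr__, as the schema classes do.
            if self.name is None:
                object.__setattr__(self, 'name', f'toy_{self.activation_n_bits}bit')

    cfg = ToyOpConfig()
    cfg16 = replace(cfg, activation_n_bits=16)  # edited copy; cfg is unchanged
    assert (cfg.activation_n_bits, cfg16.activation_n_bits) == (8, 16)

    try:
        cfg.activation_n_bits = 4  # direct mutation now fails
    except AttributeError as err:
        print(err)  # cannot assign to field 'activation_n_bits'

This is also why the immutability test now expects "cannot assign to field 'operator_set'", and why the test diffs above swap in-place edits of base_config for tpc.layer2qco[op] = replace(tpc.layer2qco[op], base_config=base_config).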
diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py index 669a068a7..aa378ff16 100644 --- a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py +++ b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/operations_to_layers.py @@ -16,6 +16,8 @@ from typing import List, Any, Dict from model_compression_toolkit.logger import Logger +from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \ + get_config_options_by_operators_set, is_opset_in_model from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorsSetBase, OperatorSetConcat @@ -137,14 +139,14 @@ def validate_op_sets(self): f'is of type {type(ops2layers)}' # Assert that opset in the current TargetPlatformCapabilities and has a unique name. - is_opset_in_model = _current_tpc.get().tp_model.is_opset_in_model(ops2layers.name) - assert is_opset_in_model, f'{ops2layers.name} is not defined in the target platform model that is associated with the target platform capabilities.' + opset_in_model = is_opset_in_model(_current_tpc.get().tp_model, ops2layers.name) + assert opset_in_model, f'{ops2layers.name} is not defined in the target platform model that is associated with the target platform capabilities.' assert not (ops2layers.name in existing_opset_names), f'OperationsSetToLayers names should be unique, but {ops2layers.name} appears to violate it.' existing_opset_names.append(ops2layers.name) # Assert that a layer does not appear in more than a single OperatorsSet in the TargetPlatformModel. 
for layer in ops2layers.layers: - qco_by_opset_name = _current_tpc.get().tp_model.get_config_options_by_operators_set(ops2layers.name) + qco_by_opset_name = get_config_options_by_operators_set(_current_tpc.get().tp_model, ops2layers.name) if layer in existing_layers: Logger.critical(f'Found layer {layer.__name__} in more than one ' f'OperatorsSet') # pragma: no cover diff --git a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py index ef0cd5713..924069c82 100644 --- a/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py +++ b/model_compression_toolkit/target_platform_capabilities/target_platform/targetplatform2framework/target_platform_capabilities.py @@ -19,6 +19,8 @@ from typing import List, Any, Dict, Tuple from model_compression_toolkit.logger import Logger +from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \ + get_config_options_by_operators_set, get_default_op_quantization_config, get_opset_by_name from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.operations_to_layers import \ OperationsToLayers, OperationsSetToLayers from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent @@ -64,7 +66,7 @@ def get_layers_by_opset_name(self, opset_name: str) -> List[Any]: Returns: List of layers/LayerFilterParams that are attached to the opset name. """ - opset = self.tp_model.get_opset_by_name(opset_name) + opset = get_opset_by_name(self.tp_model, opset_name) if opset is None: Logger.warning(f'{opset_name} was not found in TargetPlatformCapabilities.') return None @@ -165,7 +167,7 @@ def get_default_op_qc(self) -> OpQuantizationConfig: to the TargetPlatformCapabilities. 
""" - return self.tp_model.get_default_op_quantization_config() + return get_default_op_quantization_config(self.tp_model) def _get_config_options_mapping(self) -> Tuple[Dict[Any, QuantizationConfigOptions], @@ -181,7 +183,7 @@ def _get_config_options_mapping(self) -> Tuple[Dict[Any, QuantizationConfigOptio filterlayer2qco = {} for op2layers in self.op_sets_to_layers.op_sets_to_layers: for l in op2layers.layers: - qco = self.tp_model.get_config_options_by_operators_set(op2layers.name) + qco = get_config_options_by_operators_set(self.tp_model, op2layers.name) if qco is None: qco = self.tp_model.default_qco diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py index 2f658d2f8..9ca5f4643 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py @@ -271,7 +271,7 @@ def generate_tp_model(default_config: OpQuantizationConfig, base_config=base_config) # Define operator sets that use mixed_precision_configuration_options: - conv = schema.OperatorsSet(OPSET_CONV, mixed_precision_configuration_options) + conv = schema.OperatorsSet(schema.OPS_SET_LIST.OPSET_CONV, mixed_precision_configuration_options) fc = schema.OperatorsSet(OPSET_FULLY_CONNECTED, mixed_precision_configuration_options) schema.OperatorsSet(OPSET_BATCH_NORM, default_config_options_16bit) diff --git a/tests/common_tests/test_tp_model.py b/tests/common_tests/test_tp_model.py index 5b1cd5799..4e96a13df 100644 --- a/tests/common_tests/test_tp_model.py +++ b/tests/common_tests/test_tp_model.py @@ -20,6 +20,8 @@ from model_compression_toolkit.constants import FLOAT_BITWIDTH from model_compression_toolkit.core.common import BaseNode from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR +from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import \ + get_config_options_by_operators_set, is_opset_in_model from model_compression_toolkit.target_platform_capabilities.target_platform import \ get_default_quantization_config_options from tests.common_tests.helpers.generate_test_tp_model import generate_test_attr_configs, generate_test_op_qc @@ -92,13 +94,13 @@ def test_opset_qco(self): qco_3bit = get_default_quantization_config_options().clone_and_edit(activation_n_bits=3) schema.OperatorsSet(opset_name, qco_3bit) - for op_qc in hm.get_config_options_by_operators_set(opset_name).quantization_config_list: + for op_qc in get_config_options_by_operators_set(hm, opset_name).quantization_config_list: self.assertEqual(op_qc.activation_n_bits, 3) - self.assertTrue(hm.is_opset_in_model(opset_name)) - self.assertFalse(hm.is_opset_in_model("ShouldNotBeInModel")) - self.assertEqual(hm.get_config_options_by_operators_set(opset_name), qco_3bit) - self.assertEqual(hm.get_config_options_by_operators_set("ShouldNotBeInModel"), + self.assertTrue(is_opset_in_model(hm, opset_name)) + self.assertFalse(is_opset_in_model(hm, "ShouldNotBeInModel")) + self.assertEqual(get_config_options_by_operators_set(hm, opset_name), qco_3bit) + self.assertEqual(get_config_options_by_operators_set(hm, "ShouldNotBeInModel"), hm.default_qco) def test_opset_concat(self): @@ -115,8 +117,8 @@ def test_opset_concat(self): schema.OperatorsSet('opset_C') # Just add it without using it in concat schema.OperatorSetConcat([a, b]) 
self.assertEqual(len(hm.operator_set), 4) - self.assertTrue(hm.is_opset_in_model("opset_A_opset_B")) - self.assertTrue(hm.get_config_options_by_operators_set('opset_A_opset_B') is None) + self.assertTrue(is_opset_in_model(hm, "opset_A_opset_B")) + self.assertTrue(get_config_options_by_operators_set(hm, 'opset_A_opset_B') is None) def test_non_unique_opset(self): hm = schema.TargetPlatformModel( From 24ea028c625525ea0de77d63a9d4a2621fc763f6 Mon Sep 17 00:00:00 2001 From: liord Date: Wed, 11 Dec 2024 12:58:42 +0200 Subject: [PATCH 6/8] Fix typo --- .../tpc_models/imx500_tpc/v4/tp_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py index 9ca5f4643..2f658d2f8 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py @@ -271,7 +271,7 @@ def generate_tp_model(default_config: OpQuantizationConfig, base_config=base_config) # Define operator sets that use mixed_precision_configuration_options: - conv = schema.OperatorsSet(schema.OPS_SET_LIST.OPSET_CONV, mixed_precision_configuration_options) + conv = schema.OperatorsSet(OPSET_CONV, mixed_precision_configuration_options) fc = schema.OperatorsSet(OPSET_FULLY_CONNECTED, mixed_precision_configuration_options) schema.OperatorsSet(OPSET_BATCH_NORM, default_config_options_16bit) From 276baccfa7694b01be710ea78ca648e03115fe7d Mon Sep 17 00:00:00 2001 From: liord Date: Wed, 11 Dec 2024 15:09:01 +0200 Subject: [PATCH 7/8] Delete unused function --- .../schema/schema_functions.py | 28 ++----------------- 1 file changed, 3 insertions(+), 25 deletions(-) diff --git a/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py b/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py index 03b26e2d9..bebf31b60 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py @@ -13,34 +13,12 @@ # limitations under the License. # ============================================================================== from logging import Logger -import copy -from typing import Any, Dict, Optional +from typing import Optional from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \ TargetPlatformModel, QuantizationConfigOptions, OperatorsSetBase -def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any: - """ - Clones the given object and edit some of its parameters. - - Args: - obj: An object to clone. - **kwargs: Keyword arguments to edit in the cloned object. - - Returns: - Edited copy of the given object. - """ - - obj_copy = copy.deepcopy(obj) - for k, v in kwargs.items(): - assert hasattr(obj_copy, - k), f'Edit parameter is possible only for existing parameters in the given object, ' \ - f'but {k} is not a parameter of {obj_copy}.' - setattr(obj_copy, k, v) - return obj_copy - - def get_config_options_by_operators_set(self, operators_set_name: str) -> QuantizationConfigOptions: """ Get the QuantizationConfigOptions of an OperatorsSet by its name. 
@@ -104,7 +82,7 @@ def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuant """ assert len(tp_model.default_qco.quantization_config_list) == 1, \ f"Default quantization configuration options must contain only one option, " \ - f"but found {len(tp_model.default_qco.quantization_config_list)} configurations." + f"but found {len(tp_model.default_qco.quantization_config_list)} configurations." # pragma: no cover return tp_model.default_qco.quantization_config_list[0] @@ -140,5 +118,5 @@ def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optiona """ opset_list = [x for x in tp_model.operator_set if x.name == opset_name] if len(opset_list) > 1: - Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.") + Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.") # pragma: no cover return opset_list[0] if opset_list else None From 274e8e314c194b4b8a35319876b14c6ce018fea5 Mon Sep 17 00:00:00 2001 From: liord Date: Thu, 12 Dec 2024 10:47:59 +0200 Subject: [PATCH 8/8] Delete unused function --- .../schema/schema_functions.py | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py b/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py index bebf31b60..84633abb3 100644 --- a/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py +++ b/model_compression_toolkit/target_platform_capabilities/schema/schema_functions.py @@ -19,22 +19,6 @@ TargetPlatformModel, QuantizationConfigOptions, OperatorsSetBase -def get_config_options_by_operators_set(self, operators_set_name: str) -> QuantizationConfigOptions: - """ - Get the QuantizationConfigOptions of an OperatorsSet by its name. - - Args: - operators_set_name (str): Name of the OperatorsSet to get. - - Returns: - QuantizationConfigOptions: Quantization configuration options for the given OperatorsSet name. - """ - for op_set in self.operator_set: - if operators_set_name == op_set.name: - return op_set.qc_options - return self.default_qco - - def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) -> int: """ Get the maximum supported input bit-width.
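
Taken together, the hunks above perform one mechanical migration: lookups that used to be methods on TargetPlatformModel (get_config_options_by_operators_set, get_opset_by_name, get_default_op_quantization_config, is_opset_in_model) become standalone helpers in schema_functions that take the model as their first argument, keeping the frozen dataclass schema free of behavior. Below is a minimal, self-contained sketch of that pattern, not the real MCT schema — TPModel and its string-valued fields are illustrative stand-ins:

    from dataclasses import dataclass
    from typing import Tuple

    @dataclass(frozen=True)
    class OperatorsSet:
        name: str
        qc_options: str  # stand-in for QuantizationConfigOptions

    @dataclass(frozen=True)
    class TPModel:
        default_qco: str  # stand-in for the default QuantizationConfigOptions
        operator_set: Tuple[OperatorsSet, ...] = ()

    def get_config_options_by_operators_set(tp_model: TPModel, name: str) -> str:
        # Module-level lookup replacing the old instance method; falls back to
        # the model's default options when the opset name is not registered.
        for op_set in tp_model.operator_set:
            if op_set.name == name:
                return op_set.qc_options
        return tp_model.default_qco

    model = TPModel("default-qco", (OperatorsSet("Conv", "8bit-qco"),))
    assert get_config_options_by_operators_set(model, "Conv") == "8bit-qco"
    assert get_config_options_by_operators_set(model, "Missing") == "default-qco"

The fallback to default_qco mirrors the deleted method's behavior, which the updated test pins down via get_config_options_by_operators_set(hm, "ShouldNotBeInModel") == hm.default_qco.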
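
PATCH 7/8 also removes clone_and_edit_object_params, a deepcopy-and-setattr helper that cannot survive the move to frozen dataclasses: setattr on a frozen instance raises FrozenInstanceError. Its natural stand-in — an assumption here, not shown in these hunks — is dataclasses.replace, which returns an edited copy and rejects unknown field names. A short sketch with a hypothetical frozen config class:

    from dataclasses import dataclass, replace

    @dataclass(frozen=True)
    class AttrConfig:  # hypothetical frozen schema class, for illustration only
        weights_n_bits: int = 32
        enable_weights_quantization: bool = False

    cfg = AttrConfig()
    # replace() builds an edited copy; direct assignment would raise
    # FrozenInstanceError, and a misspelled field name raises TypeError.
    edited = replace(cfg, weights_n_bits=8, enable_weights_quantization=True)
    assert cfg.weights_n_bits == 32          # original is untouched
    assert edited.weights_n_bits == 8 and edited.enable_weights_quantization

Unlike the deleted helper's assert-based hasattr check, replace() raises TypeError for an unknown field, so that validation comes for free.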