Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Target Platform Capabilities - Phase 2 #1290

Merged
merged 8 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
OpQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, LayerFilterParams


Expand Down Expand Up @@ -585,7 +586,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
_node_qc_options = node_qc_options.quantization_config_list
if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])

Expand All @@ -596,7 +597,7 @@ def filter_node_qco_by_graph(self, tpc: TargetPlatformCapabilities,
Logger.critical(f"Graph doesn't match TPC bit configurations: {self} -> {next_nodes}.") # pragma: no cover

# Verify base config match
if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
get_activation_quantization_params_fn, get_weights_quantization_params_fn
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
get_weights_quantization_fn
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import max_input_activation_n_bits
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
QuantizationConfigOptions
Expand Down Expand Up @@ -117,7 +118,7 @@ def filter_node_qco_by_graph(node: BaseNode,

if len(next_nodes):
next_nodes_qc_options = [_node.get_qco(tpc) for _node in next_nodes]
next_nodes_supported_input_bitwidth = min([op_cfg.max_input_activation_n_bits
next_nodes_supported_input_bitwidth = min([max_input_activation_n_bits(op_cfg)
for qc_opts in next_nodes_qc_options
for op_cfg in qc_opts.quantization_config_list])

Expand All @@ -128,7 +129,7 @@ def filter_node_qco_by_graph(node: BaseNode,
Logger.critical(f"Graph doesn't match TPC bit configurations: {node} -> {next_nodes}.")

# Verify base config match
if any([node_qc_options.base_config.activation_n_bits > qc_opt.base_config.max_input_activation_n_bits
if any([node_qc_options.base_config.activation_n_bits > max_input_activation_n_bits(qc_opt.base_config)
for qc_opt in next_nodes_qc_options]):
# base_config activation bits doesn't match next node supported input bit-width -> replace with
# a qco from quantization_config_list with maximum activation bit-width.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,95 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
from typing import Any, Dict
from logging import Logger
from typing import Optional

from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
TargetPlatformModel, QuantizationConfigOptions, OperatorsSetBase

def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any:

def max_input_activation_n_bits(op_quantization_config: OpQuantizationConfig) -> int:
"""
Get the maximum supported input bit-width.

Args:
op_quantization_config (OpQuantizationConfig): The configuration object from which to retrieve the maximum supported input bit-width.

Returns:
int: Maximum supported input bit-width.
"""
return max(op_quantization_config.supported_input_activation_n_bits)


def get_config_options_by_operators_set(tp_model: TargetPlatformModel,
operators_set_name: str) -> QuantizationConfigOptions:
"""
Get the QuantizationConfigOptions of an OperatorsSet by its name.

Args:
tp_model (TargetPlatformModel): The target platform model containing the operator sets and their configurations.
operators_set_name (str): The name of the OperatorsSet whose quantization configuration options are to be retrieved.

Returns:
QuantizationConfigOptions: The quantization configuration options associated with the specified OperatorsSet,
or the default quantization configuration options if the OperatorsSet is not found.
"""
for op_set in tp_model.operator_set:
if operators_set_name == op_set.name:
return op_set.qc_options
return tp_model.default_qco


def get_default_op_quantization_config(tp_model: TargetPlatformModel) -> OpQuantizationConfig:
"""
Clones the given object and edit some of its parameters.
Get the default OpQuantizationConfig of the TargetPlatformModel.

Args:
obj: An object to clone.
**kwargs: Keyword arguments to edit in the cloned object.
tp_model (TargetPlatformModel): The target platform model containing the default quantization configuration.

Returns:
Edited copy of the given object.
OpQuantizationConfig: The default quantization configuration.

Raises:
AssertionError: If the default quantization configuration list contains more than one configuration option.
"""
assert len(tp_model.default_qco.quantization_config_list) == 1, \
f"Default quantization configuration options must contain only one option, " \
f"but found {len(tp_model.default_qco.quantization_config_list)} configurations." # pragma: no cover
return tp_model.default_qco.quantization_config_list[0]


def is_opset_in_model(tp_model: TargetPlatformModel, opset_name: str) -> bool:
"""
Check whether an OperatorsSet is defined in the model.

Args:
tp_model (TargetPlatformModel): The target platform model containing the list of operator sets.
opset_name (str): The name of the OperatorsSet to check for existence.

Returns:
bool: True if an OperatorsSet with the given name exists in the target platform model,
otherwise False.
"""
return opset_name in [x.name for x in tp_model.operator_set]

obj_copy = copy.deepcopy(obj)
for k, v in kwargs.items():
assert hasattr(obj_copy,
k), f'Edit parameter is possible only for existing parameters in the given object, ' \
f'but {k} is not a parameter of {obj_copy}.'
setattr(obj_copy, k, v)
return obj_copy

def get_opset_by_name(tp_model: TargetPlatformModel, opset_name: str) -> Optional[OperatorsSetBase]:
"""
Get an OperatorsSet object from the model by its name.

Args:
tp_model (TargetPlatformModel): The target platform model containing the list of operator sets.
opset_name (str): The name of the OperatorsSet to be retrieved.

Returns:
Optional[OperatorsSetBase]: The OperatorsSet object with the specified name if found.
If no operator set with the specified name is found, None is returned.

Raises:
A critical log message if multiple operator sets with the same name are found.
"""
opset_list = [x for x in tp_model.operator_set if x.name == opset_name]
if len(opset_list) > 1:
Logger.critical(f"Found more than one OperatorsSet in TargetPlatformModel with the name {opset_name}.") # pragma: no cover
return opset_list[0] if opset_list else None
Loading
Loading