Skip to content

Commit

Permalink
PR fixes:
Browse files Browse the repository at this point in the history
1. Created mct_current_schema.py for one place to update the schema version, and replaced all the imports in mct to work with this location.
2. Replaced metadata dictionary with dataclass.
  • Loading branch information
liord committed Dec 2, 2024
1 parent 6e3bfb5 commit be71127
Show file tree
Hide file tree
Showing 68 changed files with 166 additions and 175 deletions.
6 changes: 0 additions & 6 deletions model_compression_toolkit/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,6 @@
TENSORFLOW = 'tensorflow'
PYTORCH = 'pytorch'

# Metadata fields
MCT_VERSION = 'mct_version'
TPC_MINOR_VERSION = 'tpc_minor_version'
TPC_PATCH_VERSION = 'tpc_patch_version'
TPC_PLATFORM_TYPE = 'tpc_platform_type'
TPC_SCHEMA = 'tpc_schema'

WEIGHTS_SIGNED = True
# Minimal threshold to use for quantization ranges:
Expand Down
4 changes: 2 additions & 2 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
ACTIVATION_N_BITS_ATTRIBUTE, FP32_BYTES_PER_PARAMETER
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions, \
OpQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities, LayerFilterParams
from model_compression_toolkit.target_platform_capabilities.schema.v1 import OpQuantizationConfig, \
QuantizationConfigOptions


class BaseNode:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from model_compression_toolkit.core.common.graph.edge import EDGE_SINK_INDEX
from model_compression_toolkit.core.graph_prep_runner import graph_preparation_runner
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.schema.v1 import QuantizationConfigOptions
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions


def compute_resource_utilization_data(in_model: Any,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from model_compression_toolkit.core import QuantizationConfig
from model_compression_toolkit.core.common.quantization.node_quantization_config import BaseNodeQuantizationConfig, \
NodeWeightsQuantizationConfig, NodeActivationQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig, \
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig, \
OpQuantizationConfig
from model_compression_toolkit.logger import Logger

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from model_compression_toolkit.core.common.quantization.quantization_config import QuantizationConfig, \
QuantizationErrorMethod
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig, \
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig, \
OpQuantizationConfig


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Dict, Union

from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.schema.v1 import Signedness
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import Signedness
from model_compression_toolkit.core.common.collectors.statistics_collector import BaseStatsCollector
from model_compression_toolkit.core.common.quantization import quantization_params_generation
from model_compression_toolkit.core.common.node_prior_info import NodePriorInfo
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from model_compression_toolkit.core.common.quantization.quantization_fn_selection import \
get_weights_quantization_fn
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities
from model_compression_toolkit.target_platform_capabilities.schema.v1 import OpQuantizationConfig, \
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OpQuantizationConfig, \
QuantizationConfigOptions


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from model_compression_toolkit.core.common import BaseNode, Graph
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig


def apply_activation_bias_correction_to_graph(graph: Graph,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from model_compression_toolkit.core.common import Graph, BaseNode
from model_compression_toolkit.core.common.framework_implementation import FrameworkImplementation
from model_compression_toolkit.core.common.quantization.node_quantization_config import WeightsAttrQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig


def apply_bias_correction_to_graph(graph_to_apply_bias_correction: Graph,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from model_compression_toolkit.core.common.graph.base_node import BaseNode
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig


class BatchNormalizationReconstruction(common.BaseSubstitution):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from model_compression_toolkit.constants import THRESHOLD, SIGNED, SHIFT_NEGATIVE_NON_LINEAR_NUM_BITS
from model_compression_toolkit.core.common.graph.graph_matchers import NodeOperationMatcher
from model_compression_toolkit.target_platform_capabilities.target_platform import QuantizationMethod
from model_compression_toolkit.target_platform_capabilities.schema.v1 import AttributeQuantizationConfig
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import AttributeQuantizationConfig
from model_compression_toolkit.core.common.quantization.set_node_quantization_config import create_node_activation_qc, \
set_quantization_configs_to_node
from model_compression_toolkit.core.common.quantization.core_config import CoreConfig
Expand Down
26 changes: 13 additions & 13 deletions model_compression_toolkit/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass, asdict

from typing import Dict, Any
from model_compression_toolkit.constants import MCT_VERSION, TPC_MINOR_VERSION, OPERATORS_SCHEDULING, \
FUSED_NODES_MAPPING, \
CUTS, MAX_CUT, OP_ORDER, OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS, TPC_SCHEMA, \
TPC_PATCH_VERSION, TPC_PLATFORM_TYPE
from model_compression_toolkit.constants import OPERATORS_SCHEDULING, FUSED_NODES_MAPPING, CUTS, MAX_CUT, OP_ORDER, \
OP_RECORD, SHAPE, NODE_OUTPUT_INDEX, NODE_NAME, TOTAL_SIZE, MEM_ELEMENTS
from model_compression_toolkit.core.common.graph.memory_graph.compute_graph_max_cut import SchedulerInfo
from model_compression_toolkit.target_platform_capabilities.target_platform import TargetPlatformCapabilities

Expand Down Expand Up @@ -50,15 +49,16 @@ def get_versions_dict(tpc) -> Dict:
"""
# imported inside to avoid circular import error
from model_compression_toolkit import __version__ as mct_version
tpc_minor_version = f'{tpc.tp_model.tpc_minor_version}'
tpc_patch_version = f'{tpc.tp_model.tpc_patch_version}'
tpc_platform_type = f'{tpc.tp_model.tpc_platform_type}'
tpc_schema = f'{tpc.tp_model.SCHEMA_VERSION}'
return {MCT_VERSION: mct_version,
TPC_MINOR_VERSION: tpc_minor_version,
TPC_PATCH_VERSION: tpc_patch_version,
TPC_PLATFORM_TYPE: tpc_platform_type,
TPC_SCHEMA: tpc_schema}

@dataclass
class TPCVersions:
mct_version: str
tpc_minor_version: str = f'{tpc.tp_model.tpc_minor_version}'
tpc_patch_version: str = f'{tpc.tp_model.tpc_patch_version}'
tpc_platform_type: str = f'{tpc.tp_model.tpc_platform_type}'
tpc_schema: str = f'{tpc.tp_model.SCHEMA_VERSION}'

return asdict(TPCVersions(mct_version))


def get_scheduler_metadata(scheduler_info: SchedulerInfo) -> Dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema

Signedness = schema.Signedness
AttributeQuantizationConfig = schema.AttributeQuantizationConfig
OpQuantizationConfig = schema.OpQuantizationConfig
QuantizationConfigOptions = schema.QuantizationConfigOptions
OperatorsSetBase = schema.OperatorsSetBase
OperatorsSet = schema.OperatorsSet
OperatorSetConcat= schema.OperatorSetConcat
Fusing = schema.Fusing
TargetPlatformModel = schema.TargetPlatformModel
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright 2024 Sony Semiconductor Israel, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import copy
from typing import Any, Dict


def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any:
"""
Clones the given object and edit some of its parameters.
Args:
obj: An object to clone.
**kwargs: Keyword arguments to edit in the cloned object.
Returns:
Edited copy of the given object.
"""

obj_copy = copy.deepcopy(obj)
for k, v in kwargs.items():
assert hasattr(obj_copy,
k), f'Edit parameter is possible only for existing parameters in the given object, ' \
f'but {k} is not a parameter of {obj_copy}.'
setattr(obj_copy, k, v)
return obj_copy
47 changes: 24 additions & 23 deletions model_compression_toolkit/target_platform_capabilities/schema/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,7 @@
from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import \
get_current_tp_model, _current_tp_model
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model_component import \
TargetPlatformModelComponent
from model_compression_toolkit.target_platform_capabilities.schema.schema_functions import clone_and_edit_object_params


class Signedness(Enum):
Expand All @@ -45,27 +44,6 @@ class Signedness(Enum):
UNSIGNED = 2


def clone_and_edit_object_params(obj: Any, **kwargs: Dict) -> Any:
"""
Clones the given object and edit some of its parameters.
Args:
obj: An object to clone.
**kwargs: Keyword arguments to edit in the cloned object.
Returns:
Edited copy of the given object.
"""

obj_copy = copy.deepcopy(obj)
for k, v in kwargs.items():
assert hasattr(obj_copy,
k), f'Edit parameter is possible only for existing parameters in the given object, ' \
f'but {k} is not a parameter of {obj_copy}.'
setattr(obj_copy, k, v)
return obj_copy


class AttributeQuantizationConfig:
"""
Hold the quantization configuration of a weight attribute of a layer.
Expand Down Expand Up @@ -387,6 +365,29 @@ def get_info(self):
return {f'option {i}': cfg.get_info() for i, cfg in enumerate(self.quantization_config_list)}


class TargetPlatformModelComponent:
"""
Component of TargetPlatformModel (Fusing, OperatorsSet, etc.)
"""
def __init__(self, name: str):
"""
Args:
name: Name of component.
"""
self.name = name
_current_tp_model.get().append_component(self)

def get_info(self) -> Dict[str, Any]:
"""
Returns: Get information about the component to display (return an empty dictionary.
the actual component should fill it with info).
"""
return {}


class OperatorsSetBase(TargetPlatformModelComponent):
"""
Base class to represent a set of operators.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.attribute_filter import AttributeFilter
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework import TargetPlatformCapabilities, OperationsSetToLayers, Smaller, SmallerEq, NotEq, Eq, GreaterEq, Greater, LayerFilterParams, OperationsToLayers, get_current_tpc
from model_compression_toolkit.target_platform_capabilities.target_platform.target_platform_model import get_default_quantization_config_options
from model_compression_toolkit.target_platform_capabilities.schema.v1 import TargetPlatformModel, OperatorsSet, \
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, OperatorsSet, \
OperatorSetConcat, Signedness, AttributeQuantizationConfig, OpQuantizationConfig, QuantizationConfigOptions, Fusing

from mct_quantizers import QuantizationMethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ==============================================================================

from model_compression_toolkit.target_platform_capabilities.target_platform.current_tp_model import get_current_tp_model
from model_compression_toolkit.target_platform_capabilities.schema.v1 import QuantizationConfigOptions
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import QuantizationConfigOptions


def get_default_quantization_config_options() -> QuantizationConfigOptions:
Expand All @@ -28,15 +28,3 @@ def get_default_quantization_config_options() -> QuantizationConfigOptions:
return get_current_tp_model().default_qco


def get_default_quantization_config():
"""
Returns: The default OpQuantizationConfig of the model. This is the OpQuantizationConfig
to use when a layer's options is queried and it wasn't specified in the TargetPlatformCapabilities.
This OpQuantizationConfig is the single option in the default QuantizationConfigOptions.
"""

return get_current_tp_model().get_default_op_quantization_config()


This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from model_compression_toolkit.logger import Logger
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
from model_compression_toolkit.target_platform_capabilities.schema.v1 import OperatorsSetBase, OperatorSetConcat
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import OperatorsSetBase, OperatorSetConcat
from model_compression_toolkit import DefaultDict


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.target_platform_capabilities_component import TargetPlatformCapabilitiesComponent
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.layer_filter_params import LayerFilterParams
from model_compression_toolkit.target_platform_capabilities.immutable import ImmutableClass
from model_compression_toolkit.target_platform_capabilities.schema.v1 import TargetPlatformModel, OperatorsSetBase, \
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, OperatorsSetBase, \
OpQuantizationConfig, QuantizationConfigOptions
from model_compression_toolkit.target_platform_capabilities.target_platform.targetplatform2framework.current_tpc import _current_tpc

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from typing import List, Tuple

import model_compression_toolkit as mct
import model_compression_toolkit.target_platform_capabilities.schema.v1 as schema
import model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema as schema
from model_compression_toolkit.constants import FLOAT_BITWIDTH
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, WEIGHTS_N_BITS, \
IMX500_TP_MODEL
from model_compression_toolkit.target_platform_capabilities.schema.v1 import TargetPlatformModel, Signedness, \
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel, Signedness, \
AttributeQuantizationConfig, OpQuantizationConfig

tp = mct.target_platform
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import tensorflow as tf
from packaging import version

from model_compression_toolkit.target_platform_capabilities.schema.v1 import TargetPlatformModel
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
from model_compression_toolkit.defaultdict import DefaultDict
from model_compression_toolkit.verify_packages import FOUND_SONY_CUSTOM_LAYERS
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, KERAS_DEPTHWISE_KERNEL, \
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from torch.nn import ReLU, ReLU6, PReLU, SiLU, Sigmoid, Tanh, Hardswish, LeakyReLU
from torch.nn.functional import relu, relu6, prelu, silu, hardtanh, hardswish, leaky_relu

from model_compression_toolkit.target_platform_capabilities.schema.v1 import TargetPlatformModel
from model_compression_toolkit.target_platform_capabilities.schema.mct_current_schema import TargetPlatformModel
from model_compression_toolkit.defaultdict import DefaultDict
from model_compression_toolkit.target_platform_capabilities.constants import KERNEL_ATTR, BIAS_ATTR, PYTORCH_KERNEL, \
BIAS
Expand Down
Loading

0 comments on commit be71127

Please sign in to comment.