Skip to content

Commit

Permalink
[PTQ][OV] FP8 implementation (openvinotoolkit#2283)
Browse files Browse the repository at this point in the history
### Changes

- Added FP8 implementation
- Added the `mode` parameter (of type `QuantizationMode`)

### Reason for changes

- New FP8 implementation

### Related tickets

- 119805

### Tests

- `tests/openvino/native/quantization/test_graphs.py`
- `tests/openvino/native/test_model_transformer.py`

Built on top of openvinotoolkit/openvino#21034 (**merged**).
  • Loading branch information
nikita-malininn authored Dec 15, 2023
1 parent 0c389c3 commit 5f2c20e
Show file tree
Hide file tree
Showing 71 changed files with 1,045 additions and 263 deletions.
1 change: 1 addition & 0 deletions nncf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from nncf.parameters import CompressWeightsMode as CompressWeightsMode
from nncf.parameters import DropType as DropType
from nncf.parameters import ModelType as ModelType
from nncf.parameters import QuantizationMode as QuantizationMode
from nncf.parameters import TargetDevice as TargetDevice
from nncf.quantization import QuantizationPreset as QuantizationPreset
from nncf.quantization import compress_weights as compress_weights
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/hardware/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.logging import nncf_logger
from nncf.common.quantization import quantizers as quant
from nncf.common.quantization.structs import QuantizationMode
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.utils.helpers import product_dict
from nncf.common.utils.os import safe_open
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/quantization/initialization/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from typing import Dict, List, Optional

from nncf.common.initialization.dataloader import NNCFDataLoader
from nncf.common.quantization.structs import QuantizationMode
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerGroup
from nncf.config.schemata.defaults import NUM_INIT_SAMPLES

Expand Down
2 changes: 1 addition & 1 deletion nncf/common/quantization/quantizer_propagation/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from nncf.common.quantization.quantizer_setup import QuantizationInsertionPointBase
from nncf.common.quantization.quantizer_setup import QuantizationPointId
from nncf.common.quantization.quantizer_setup import WeightQuantizationInsertionPoint
from nncf.common.quantization.structs import QuantizationMode
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.quantization.structs import UnifiedScaleType
from nncf.common.scopes import should_consider_scope
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/quantization/quantizer_propagation/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from nncf.common.quantization.quantizer_setup import SingleConfigQuantizerSetup
from nncf.common.quantization.structs import QuantizableWeightedLayerNode
from nncf.common.quantization.structs import QuantizationConstraints
from nncf.common.quantization.structs import QuantizationMode
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.quantization.structs import QuantizerGroup
from nncf.common.quantization.structs import UnifiedScaleType
Expand Down
2 changes: 1 addition & 1 deletion nncf/common/quantization/quantizer_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from nncf.common.graph import NNCFNodeName
from nncf.common.logging import nncf_logger
from nncf.common.quantization.structs import NonWeightQuantizerId
from nncf.common.quantization.structs import QuantizationMode
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.quantization.structs import QuantizerConfig
from nncf.common.quantization.structs import UnifiedScaleType
from nncf.common.quantization.structs import WeightQuantizerId
Expand Down
16 changes: 8 additions & 8 deletions nncf/common/quantization/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@


@api()
class QuantizationMode:
class QuantizationScheme:
"""
Basic enumeration for quantization mode specification.
Basic enumeration for quantization scheme specification.
:param SYMMETRIC:
:param ASYMMETRIC:
Expand All @@ -43,7 +43,7 @@ class QuantizerConfig:
def __init__(
self,
num_bits: int = QUANTIZATION_BITS,
mode: QuantizationMode = QuantizationMode.SYMMETRIC,
mode: QuantizationScheme = QuantizationScheme.SYMMETRIC,
signedness_to_force: Optional[bool] = None,
per_channel: bool = QUANTIZATION_PER_CHANNEL,
):
Expand All @@ -66,7 +66,7 @@ def __eq__(self, other):
def __str__(self):
return "B:{bits} M:{mode} SGN:{signedness} PC:{per_channel}".format(
bits=self.num_bits,
mode="S" if self.mode == QuantizationMode.SYMMETRIC else "A",
mode="S" if self.mode == QuantizationScheme.SYMMETRIC else "A",
signedness="ANY" if self.signedness_to_force is None else ("S" if self.signedness_to_force else "U"),
per_channel="Y" if self.per_channel else "N",
)
Expand All @@ -86,7 +86,7 @@ def is_valid_requantization_for(self, other: "QuantizerConfig") -> bool:
"""
fail_conditions = [
self.num_bits > other.num_bits,
self.mode is QuantizationMode.ASYMMETRIC and other.mode is QuantizationMode.SYMMETRIC,
self.mode is QuantizationScheme.ASYMMETRIC and other.mode is QuantizationScheme.SYMMETRIC,
self.signedness_to_force is None and other.signedness_to_force is not None,
self.signedness_to_force is True and other.signedness_to_force is False,
]
Expand Down Expand Up @@ -153,7 +153,7 @@ class QuantizerSpec:
"""

def __init__(
self, num_bits: int, mode: QuantizationMode, signedness_to_force: bool, narrow_range: bool, half_range: bool
self, num_bits: int, mode: QuantizationScheme, signedness_to_force: bool, narrow_range: bool, half_range: bool
):
"""
:param num_bits: Bitwidth of the quantization.
Expand Down Expand Up @@ -334,5 +334,5 @@ class QuantizationPreset(Enum):

def get_params_configured_by_preset(self, quant_group: QuantizerGroup) -> Dict:
if quant_group == QuantizerGroup.ACTIVATIONS and self == QuantizationPreset.MIXED:
return {"mode": QuantizationMode.ASYMMETRIC}
return {"mode": QuantizationMode.SYMMETRIC}
return {"mode": QuantizationScheme.ASYMMETRIC}
return {"mode": QuantizationScheme.SYMMETRIC}
16 changes: 15 additions & 1 deletion nncf/experimental/tensor/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
# limitations under the License.

import functools
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, TypeVar, Union

from nncf.experimental.tensor.enums import TensorDataType
from nncf.experimental.tensor.enums import TensorDeviceType
from nncf.experimental.tensor.tensor import Tensor
from nncf.experimental.tensor.tensor import unwrap_tensor_data

TypeInfo = TypeVar("TypeInfo")


def _tensor_guard(func: callable):
"""
Expand Down Expand Up @@ -428,6 +430,18 @@ def _binary_reverse_op_nowarn(a: Tensor, b: Union[Tensor, float], operator_fn: C
return Tensor(_binary_reverse_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn))


@functools.singledispatch
@_tensor_guard
def finfo(a: Tensor) -> TypeInfo:
    """
    Provides the machine limits of the data type held by a tensor.

    The call is re-dispatched on the data wrapped by the tensor, so the
    concrete result comes from the implementation registered for that
    kind of data.

    :param a: Tensor whose data type limits are requested.
    :return: Type information object describing the machine limits.
    """
    wrapped_data = a.data
    return finfo(wrapped_data)


def _dispatch_list(fn: "functools._SingleDispatchCallable", tensor_list: List[Tensor], *args, **kwargs):
"""
Dispatches the function to the type of the wrapped data of the first element in tensor_list.
Expand Down
5 changes: 5 additions & 0 deletions nncf/experimental/tensor/numpy_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,8 @@ def _(
# Run operator with disabled warning
with np.errstate(invalid="ignore", divide="ignore"):
return operator_fn(b, a)


@_register_numpy_types(fns.finfo)
def _(a: np.ndarray) -> np.finfo:
    # numpy implementation of `finfo`: the limits are looked up by the array's dtype.
    array_dtype = a.dtype
    return np.finfo(array_dtype)
2 changes: 1 addition & 1 deletion nncf/experimental/tensorflow/quantization/quantizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

import tensorflow as tf

from nncf.common.quantization.structs import QuantizationMode
from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode
from nncf.common.utils.registry import Registry
from nncf.tensorflow.layers.operation import InputType
from nncf.tensorflow.quantization.quantizers import AsymmetricQuantizer
Expand Down
19 changes: 12 additions & 7 deletions nncf/onnx/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Union
from typing import Optional

import onnx

Expand All @@ -18,8 +18,10 @@
from nncf.data import Dataset
from nncf.onnx.graph.nncf_graph_builder import GraphConverter
from nncf.parameters import ModelType
from nncf.parameters import QuantizationMode
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import QuantizationParameters
from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization
from nncf.quantization.telemetry_extractors import CompressionStartedWithQuantizeApi
from nncf.scopes import IgnoredScope
Expand All @@ -31,10 +33,11 @@
def quantize_impl(
model: onnx.ModelProto,
calibration_dataset: Dataset,
preset: Union[QuantizationPreset, None],
target_device: TargetDevice,
subset_size: int,
fast_bias_correction: bool,
mode: Optional[QuantizationMode] = None,
preset: Optional[QuantizationPreset] = None,
target_device: TargetDevice = TargetDevice.ANY,
subset_size: int = 300,
fast_bias_correction: bool = True,
model_type: Optional[ModelType] = None,
ignored_scope: Optional[IgnoredScope] = None,
advanced_parameters: Optional[AdvancedQuantizationParameters] = None,
Expand All @@ -44,6 +47,8 @@ def quantize_impl(
"""
if target_device == TargetDevice.CPU_SPR:
raise RuntimeError("target_device == CPU_SPR is not supported.")
if mode is not None:
raise ValueError(f"mode={mode} is not supported")
if model.opset_import[0].version < 10:
raise RuntimeError("ONNX models with opset version < 10 do not support quantization.")
if model.opset_import[0].version < 13:
Expand All @@ -53,8 +58,8 @@ def quantize_impl(
)
if advanced_parameters is None:
advanced_parameters = AdvancedQuantizationParameters()
advanced_parameters.weights_quantization_params.per_channel = False
advanced_parameters.activations_quantization_params.per_channel = False
advanced_parameters.weights_quantization_params = QuantizationParameters(per_channel=False)
advanced_parameters.activations_quantization_params = QuantizationParameters(per_channel=False)

quantization_algorithm = PostTrainingQuantization(
preset=preset,
Expand Down
4 changes: 1 addition & 3 deletions nncf/openvino/graph/metatypes/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@
]


FAKE_QUANTIZE_OPERATIONS = [
ov_metatypes.OVFakeQuantizeMetatype,
]
FAKE_QUANTIZE_OPERATIONS = [ov_metatypes.OVFakeQuantizeMetatype, ov_metatypes.OVFakeConvertMetatype]


CONSTANT_OPERATIONS = [
Expand Down
12 changes: 9 additions & 3 deletions nncf/openvino/graph/metatypes/openvino_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,12 @@ class OVFakeQuantizeMetatype(OVOpMetatype):
op_names = ["FakeQuantize"]


@OV_OPERATOR_METATYPES.register()
class OVFakeConvertMetatype(OVOpMetatype):
    """
    Metatype registered for the OpenVINO `FakeConvert` operation.
    """

    name = "FakeConvertOp"
    op_names = ["FakeConvert"]


@OV_OPERATOR_METATYPES.register()
class OVLessMetatype(OVOpMetatype):
name = "LessOp"
Expand Down Expand Up @@ -713,13 +719,13 @@ def get_operation_const_op(operation: ov.Node, const_port_id: int) -> Optional[o
# There are several cases here
# (Constant) -> (Operation)
# (Constant) -> (Convert) -> (Operation)
# (Constant) -> (Convert) -> (FakeQuantize) -> (Operation)
# (Constant) -> (Convert) -> (FakeQuantize) -> (Reshape) -> (Operation)
# (Constant) -> (Convert) -> (FakeQuantize, FakeConvert) -> (Operation)
# (Constant) -> (Convert) -> (FakeQuantize, FakeConvert) -> (Reshape) -> (Operation)
# and so on. We need to properly find the constant node, so we start with
# `node` and traverse up until the constant node is found.
queue = deque([node])
constant_node = None
allowed_propagation_types_list = ["Convert", "FakeQuantize", "Reshape"]
allowed_propagation_types_list = ["Convert", "FakeQuantize", "FakeConvert", "Reshape"]

while len(queue) != 0:
curr_node = queue.popleft()
Expand Down
Loading

0 comments on commit 5f2c20e

Please sign in to comment.