[PTQ][OV] FP8 implementation #2283

Merged

Changes shown from 10 commits; 38 commits in total.

8522249  Initial commit (nikita-malininn, Nov 20, 2023)
d316ac8  Add activation (nikita-malininn, Nov 21, 2023)
1a2a224  Update data types (nikita-malininn, Nov 22, 2023)
434705a  Unify Fake nodes creation (nikita-malininn, Nov 22, 2023)
b243a92  Remove extra (nikita-malininn, Nov 22, 2023)
f0154a4  Added description (nikita-malininn, Nov 22, 2023)
097aaee  Fix tests (nikita-malininn, Nov 22, 2023)
9db1e70  Add tests for model transformer (nikita-malininn, Nov 23, 2023)
f473c57  Calibrate fix (nikita-malininn, Nov 23, 2023)
12cd199  Added graph test (nikita-malininn, Nov 23, 2023)
123bee3  Fix test (nikita-malininn, Nov 23, 2023)
50e807d  Merge remote-tracking branch 'openvinotoolkit/develop' into nm/fp8_im… (nikita-malininn, Nov 24, 2023)
a11e367  Changes to enable FBC (nikita-malininn, Nov 29, 2023)
5a03419  Update with QuantizationMode (nikita-malininn, Nov 30, 2023)
df9591c  Fix type name (nikita-malininn, Nov 30, 2023)
83bc6d9  Fix API check (nikita-malininn, Nov 30, 2023)
23bab25  Fix naming (nikita-malininn, Nov 30, 2023)
001094f  Added AdvancedParameters (nikita-malininn, Dec 1, 2023)
5371afe  Fix MinMax (nikita-malininn, Dec 1, 2023)
65c25a6  Global change mode to scheme (nikita-malininn, Dec 6, 2023)
ebdcda2  Placed mode parameter before preset (nikita-malininn, Dec 6, 2023)
9eaebcb  Update tests (nikita-malininn, Dec 6, 2023)
88e9b89  Fix tests (nikita-malininn, Dec 6, 2023)
4301113  Revert "Global change mode to scheme" (nikita-malininn, Dec 7, 2023)
6c0325b  Remove estimator params redefining (nikita-malininn, Dec 7, 2023)
0cac099  Merge branch 'develop' into nm/fp8_implementation (nikita-malininn, Dec 7, 2023)
0d02ccb  Fix tests after merge (nikita-malininn, Dec 7, 2023)
d46f604  Rollback QuantizationScheme renaming, change QuantizationMode import (nikita-malininn, Dec 7, 2023)
288281a  Rollback not needed changes (nikita-malininn, Dec 7, 2023)
540b84f  Fix tests after rollback (nikita-malininn, Dec 7, 2023)
2511f1a  Fix tests again (nikita-malininn, Dec 7, 2023)
e8d5944  Update QuantizationScheme name, import (nikita-malininn, Dec 8, 2023)
98574fe  Apply comments (nikita-malininn, Dec 12, 2023)
08e7735  Merge branch 'develop' into nm/fp8_implementation (nikita-malininn, Dec 12, 2023)
82c1df8  Fix (nikita-malininn, Dec 12, 2023)
096d501  Fix tests (nikita-malininn, Dec 12, 2023)
7512da2  Fix for per-channel activations (nikita-malininn, Dec 13, 2023)
d6dedc2  Merge openvinotoolkit/develop into nm/fp8_implementation (nikita-malininn, Dec 15, 2023)

16 changes: 15 additions & 1 deletion nncf/experimental/tensor/functions.py
@@ -10,13 +10,15 @@
# limitations under the License.

import functools
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, TypeVar, Union

from nncf.experimental.tensor.enums import TensorDataType
from nncf.experimental.tensor.enums import TensorDeviceType
from nncf.experimental.tensor.tensor import Tensor
from nncf.experimental.tensor.tensor import unwrap_tensor_data

TypeInfo = TypeVar("TypeInfo")


def _tensor_guard(func: callable):
"""
@@ -428,6 +430,18 @@ def _binary_reverse_op_nowarn(a: Tensor, b: Union[Tensor, float], operator_fn: C
return Tensor(_binary_reverse_op_nowarn(a.data, unwrap_tensor_data(b), operator_fn))


@functools.singledispatch
@_tensor_guard
def finfo(a: Tensor) -> TypeInfo:
"""
Returns machine limits for tensor type.

:param a: Tensor.
:return: TypeInfo.
"""
return finfo(a.data)


def _dispatch_list(fn: "functools._SingleDispatchCallable", tensor_list: List[Tensor], *args, **kwargs):
"""
Dispatches the function to the type of the wrapped data of the first element in tensor_list.
5 changes: 5 additions & 0 deletions nncf/experimental/tensor/numpy_functions.py
@@ -206,3 +206,8 @@ def _(
# Run operator with disabled warning
with np.errstate(invalid="ignore", divide="ignore"):
return operator_fn(b, a)


@_register_numpy_types(fns.finfo)
def _(a: np.ndarray) -> np.finfo:
return np.finfo(a.dtype)
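
For orientation, a minimal usage sketch of the new `finfo` dispatch (illustrative only; it assumes a NumPy-backed NNCF `Tensor`, which is the case handled by the overload registered above):

```python
import numpy as np

from nncf.experimental.tensor import functions as fns
from nncf.experimental.tensor.tensor import Tensor

# Wrap a NumPy array; fns.finfo dispatches on the type of the wrapped data.
t = Tensor(np.zeros((2, 2), dtype=np.float16))

info = fns.finfo(t)  # resolves to np.finfo(np.float16) via the registered overload
print(info.min, info.max)  # -65504.0 65504.0
```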
135 changes: 92 additions & 43 deletions nncf/openvino/graph/model_transformer.py
@@ -17,6 +17,7 @@
import openvino.runtime as ov
from openvino._pyopenvino import DescriptorTensor
from openvino.runtime import opset9 as opset
from openvino.runtime import opset13

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.common.graph.model_transformer import TModel
@@ -34,6 +35,8 @@
from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
from nncf.openvino.graph.transformations.commands import OVUpdateIfBodyCommand
from nncf.openvino.graph.transformations.commands import OVWeightUpdateCommand
from nncf.quantization.advanced_parameters import Mode
from nncf.quantization.fake_quantize import FakeConvertParameters
from nncf.quantization.fake_quantize import FakeQuantizeParameters


@@ -58,6 +61,11 @@ def __init__(self, model: TModel):
(OVExtractIfBodyCommand, self._apply_extract_if_body_transformation),
]

@staticmethod
def _convert_to_fp16(data):
clip_data = np.clip(data, np.finfo(np.float16).min, np.finfo(np.float16).max)
return clip_data.astype(np.float16)

@staticmethod
def _get_name_to_node_mapping(model: ov.Model) -> Dict[str, ov.Node]:
"""
@@ -231,81 +239,122 @@ def _apply_quantizer_insertion_transformations(
"""
name_to_node_mapping = OVModelTransformer._get_name_to_node_mapping(model)
for transformation in transformations:
OVModelTransformer._insert_fake_quantize_op(transformation, name_to_node_mapping)
OVModelTransformer._insert_fake_op(transformation, name_to_node_mapping)
return model

@staticmethod
def convert_params_to_fp16(
fq_params: FakeQuantizeParameters,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
def _create_fake_quantize(
op_output: ov.Output,
fake_quantize_params: FakeQuantizeParameters,
fake_quantize_name: str,
convert_to_fp16: bool,
) -> ov.Node:
"""
Creates FakeQuantize node.

:param op_output: Output of the previous node.
:param fake_quantize_params: FakeQuantizeParameters instance.
:param fake_quantize_name: New layer name.
:param convert_to_fp16: Whether to convert parameters to FP16 or not.
:return: ov.Node instance.
"""

input_low = fake_quantize_params.input_low.data
input_high = fake_quantize_params.input_high.data
output_low = fake_quantize_params.output_low.data
output_high = fake_quantize_params.output_high.data
levels = fake_quantize_params.levels

if convert_to_fp16:
input_low = OVModelTransformer._convert_to_fp16(input_low)
input_high = OVModelTransformer._convert_to_fp16(input_high)
output_low = OVModelTransformer._convert_to_fp16(output_low)
output_high = OVModelTransformer._convert_to_fp16(output_high)

return opset.fake_quantize(
op_output, input_low, input_high, output_low, output_high, levels, name=fake_quantize_name
)

@staticmethod
def _create_fake_convert(
op_output: ov.Output, fake_convert_params: FakeConvertParameters, fake_convert_name: str, convert_to_fp16: bool
) -> ov.Node:
"""
Converts FakeQuantize parameters to FP16 precision.
Creates FakeConvert node.

:param fq_params: FakeQuantize node attributes.
:return: FakeQuantize parameters in FP16 precision.
:param op_output: Output of the previous node.
:param fake_convert_params: FakeConvertParameters instance.
:param fake_convert_name: New layer name.
:param convert_to_fp16: Whether to convert parameters to FP16 or not.
:return: ov.Node instance.
"""

def _convert_to_fp16(data):
clip_data = np.clip(data, np.finfo(np.float16).min, np.finfo(np.float16).max)
return clip_data.astype(np.float16)
scale = fake_convert_params.scale.data
shift = fake_convert_params.shift.data

input_low = _convert_to_fp16(fq_params.input_low.data)
input_high = _convert_to_fp16(fq_params.input_high.data)
output_low = _convert_to_fp16(fq_params.output_low.data)
output_high = _convert_to_fp16(fq_params.output_high.data)
return input_low, input_high, output_low, output_high
if convert_to_fp16:
scale = OVModelTransformer._convert_to_fp16(scale)
shift = OVModelTransformer._convert_to_fp16(shift)

return opset13.fake_convert(op_output, scale, shift, name=fake_convert_name)

@staticmethod
def _insert_fake_quantize_op(
transformation: OVQuantizerInsertionCommand, name_to_node_mapping: Dict[str, ov.Node]
) -> None:
def _insert_fake_op(transformation: OVQuantizerInsertionCommand, name_to_node_mapping: Dict[str, ov.Node]) -> None:
"""
Inserts FakeQuantize Operation to a model which name_to_node_mapping is passed.
Inserts a Fake (Quantize or Convert) operation into the model whose name_to_node_mapping is passed.

:param transformation: FakeQuantize insertion command.
:param transformation: Fake (Quantize or Convert) insertion command.
:param name_to_node_mapping: Mapping from node name to node instance.
"""
fq_params = transformation.quantizer_parameters
input_low = fq_params.input_low.data
input_high = fq_params.input_high.data
output_low = fq_params.output_low.data
output_high = fq_params.output_high.data
levels = fq_params.levels
fake_op_params = transformation.fake_op_parameters
mode = transformation.mode
if mode not in [Mode.FQ, Mode.FP8]:
raise RuntimeError(f"Incorrect Mode {mode}")

node_name = transformation.target_point.target_node_name
target_node = name_to_node_mapping[node_name]
port_id = transformation.target_point.port_id
transform_type = transformation.target_point.type

name = "weights" if transform_type == TargetType.OPERATION_WITH_WEIGHTS else "input"

if transform_type in [TargetType.PRE_LAYER_OPERATION, TargetType.OPERATION_WITH_WEIGHTS]:
inp_node = target_node.input(port_id)
input_node_output = inp_node.get_source_output()
data_type = inp_node.get_element_type()
if data_type == ov.Type(np.float16):
input_low, input_high, output_low, output_high = OVModelTransformer.convert_params_to_fp16(fq_params)
name = "fq_weights" if transform_type == TargetType.OPERATION_WITH_WEIGHTS else "fq_input"
fq_name = f"{node_name}/{name}_{port_id}"

fq = None
fake_op = None
if transform_type == TargetType.OPERATION_WITH_WEIGHTS:
# If the nodes share one weight tensor, we should have only one quantizer on that
for out in input_node_output.get_target_inputs():
if out.get_node().get_type_name() == "FakeQuantize":
fq = out.get_node()
if fq is None:
fq = opset.fake_quantize(
input_node_output, input_low, input_high, output_low, output_high, levels, name=fq_name
)
inp_node.replace_source_output(fq.output(0))
if out.get_node().get_type_name() in ["FakeQuantize", "FakeConvert"]:
fake_op = out.get_node()
if fake_op is None:
convert_to_fp16 = data_type == ov.Type(np.float16)
fake_op_name = f"{node_name}/{mode}_{name}_{port_id}"
if mode == Mode.FQ:
fake_op = OVModelTransformer._create_fake_quantize(
input_node_output, fake_op_params, fake_op_name, convert_to_fp16
)
elif mode == Mode.FP8:
fake_op = OVModelTransformer._create_fake_convert(
input_node_output, fake_op_params, fake_op_name, convert_to_fp16
)
inp_node.replace_source_output(fake_op.output(0))
elif transform_type == TargetType.POST_LAYER_OPERATION:
output = target_node.output(port_id)
data_type = output.get_element_type()
if data_type == ov.Type(np.float16):
input_low, input_high, output_low, output_high = OVModelTransformer.convert_params_to_fp16(fq_params)
target_inputs = output.get_target_inputs()
fq_name = f"{node_name}/fq_output_{port_id}"
fq = opset.fake_quantize(output, input_low, input_high, output_low, output_high, levels, name=fq_name)
convert_to_fp16 = data_type == ov.Type(np.float16)
fake_op_name = f"{node_name}/{mode}_output_{port_id}"
if mode == Mode.FQ:
fake_op = OVModelTransformer._create_fake_quantize(
output, fake_op_params, fake_op_name, convert_to_fp16
)
elif mode == Mode.FP8:
fake_op = OVModelTransformer._create_fake_convert(output, fake_op_params, fake_op_name, convert_to_fp16)
for inp_node in target_inputs:
inp_node.replace_source_output(fq.output(0))
inp_node.replace_source_output(fake_op.output(0))
else:
raise RuntimeError(f"Incorrect target point type {transform_type}")

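
As a standalone illustration of the two node types the transformer can now emit, here is a minimal sketch built directly with the OpenVINO opsets used above (parameter values and node names are made up; only the call shapes mirror `_create_fake_quantize` and `_create_fake_convert`):

```python
import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset9 as opset
from openvino.runtime import opset13

data = opset.parameter([1, 3, 8, 8], dtype=np.float32, name="input")

# Mode.FQ: FakeQuantize takes four range constants plus the number of levels.
fq = opset.fake_quantize(
    data, np.float32(-1.0), np.float32(1.0), np.float32(-1.0), np.float32(1.0), 256, name="input/fq_input_0"
)

# Mode.FP8: FakeConvert takes only a scale and a shift (hence the extra opset13 import above).
fc = opset13.fake_convert(data, np.float32(1.0), np.float32(0.0), name="input/fc_input_0")

model = ov.Model([fq, fc], [data])
```

The node names above follow the `{node_name}/{mode}_{name}_{port_id}` pattern used by `_insert_fake_op`, but are otherwise arbitrary.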
14 changes: 11 additions & 3 deletions nncf/openvino/graph/transformations/commands.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import List, Union

import numpy as np
import openvino.runtime as ov
@@ -20,6 +20,8 @@
from nncf.common.graph.transformations.commands import TransformationCommand
from nncf.common.graph.transformations.commands import TransformationType
from nncf.openvino.graph.node_utils import InplaceInsertionFnType
from nncf.quantization.advanced_parameters import Mode
from nncf.quantization.fake_quantize import FakeConvertParameters
from nncf.quantization.fake_quantize import FakeQuantizeParameters


@@ -84,9 +86,15 @@ def union(self, other: "TransformationCommand") -> "TransformationCommand":


class OVQuantizerInsertionCommand(OVInsertionCommand):
def __init__(self, target_point: OVTargetPoint, quantizer_parameters: FakeQuantizeParameters):
def __init__(
self,
target_point: OVTargetPoint,
fake_op_parameters: Union[FakeQuantizeParameters, FakeConvertParameters],
mode: Mode = Mode.FQ,
):
super().__init__(target_point)
self.quantizer_parameters = quantizer_parameters
self.fake_op_parameters = fake_op_parameters
self.mode = mode

def union(self, other: "TransformationCommand") -> "TransformationCommand":
# Have a look at nncf/torch/graph/transformations/commands/PTInsertionCommand
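
A hypothetical construction of the extended command, showing how the widened parameter type and the new `mode` argument fit together. The target point, node name, and the `FakeConvertParameters` constructor arguments are assumptions for illustration; the diff only shows that `scale` and `shift` are read from the parameters:

```python
import numpy as np

from nncf.common.graph.transformations.commands import TargetType
from nncf.experimental.tensor.tensor import Tensor
from nncf.openvino.graph.transformations.commands import OVQuantizerInsertionCommand
from nncf.openvino.graph.transformations.commands import OVTargetPoint
from nncf.quantization.advanced_parameters import Mode
from nncf.quantization.fake_quantize import FakeConvertParameters

# Hypothetical target: weight input of a node named "MatMul_1".
target_point = OVTargetPoint(TargetType.OPERATION_WITH_WEIGHTS, "MatMul_1", port_id=1)

# FakeConvert needs only a scale and a shift (assumed constructor, based on the fields used above).
fc_params = FakeConvertParameters(
    scale=Tensor(np.array(1.0, dtype=np.float32)),
    shift=Tensor(np.array(0.0, dtype=np.float32)),
)

command = OVQuantizerInsertionCommand(target_point, fc_params, mode=Mode.FP8)
```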
17 changes: 17 additions & 0 deletions nncf/quantization/advanced_parameters.py
@@ -130,6 +130,22 @@ class AdvancedSmoothQuantParameters:
matmul: float = 0.95


@api()
@dataclass
class Mode:
"""
Contains values corresponding to the available quantization modes.

:param FQ: Mode that inserts FakeQuantize operations.
:type FQ: str
:param FP8: Mode that inserts FakeConvert operations.
:type FP8: str
"""

FQ: str = "fq"
FP8: str = "fc"


@api()
@dataclass
class AdvancedQuantizationParameters:
@@ -176,6 +192,7 @@ class AdvancedQuantizationParameters:
inplace_statistics: bool = True
disable_channel_alignment: bool = True
disable_bias_correction: bool = False
mode: Mode = Mode.FQ

# Advanced Quantization parameters
activations_quantization_params: QuantizationParameters = field(default_factory=QuantizationParameters)
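
With the new field in place, FP8 insertion would be requested through the advanced parameters, roughly as sketched below (the model and calibration dataset are assumed to exist; everything else keeps its defaults):

```python
import nncf
from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters
from nncf.quantization.advanced_parameters import Mode

# model: an ov.Model, calibration_dataset: an nncf.Dataset -- both assumed to exist.
quantized_model = nncf.quantize(
    model,
    calibration_dataset,
    advanced_parameters=AdvancedQuantizationParameters(mode=Mode.FP8),
)
```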
23 changes: 19 additions & 4 deletions nncf/quantization/algorithms/min_max/algorithm.py
@@ -47,10 +47,12 @@
from nncf.common.utils.backend import get_backend
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import Mode
from nncf.quantization.advanced_parameters import OverflowFix
from nncf.quantization.advanced_parameters import QuantizationParameters
from nncf.quantization.advanced_parameters import changes_asdict
from nncf.quantization.algorithms.algorithm import Algorithm
from nncf.quantization.fake_quantize import calculate_convert_parameters
from nncf.quantization.fake_quantize import calculate_quantizer_parameters
from nncf.quantization.fake_quantize import get_quantizer_narrow_range
from nncf.quantization.passes import transform_to_inference_graph
@@ -104,6 +106,7 @@ def __init__(
overflow_fix: OverflowFix = OverflowFix.FIRST_LAYER,
quantize_outputs: bool = False,
inplace_statistics: bool = True,
mode: Mode = Mode.FQ,
activations_quantization_params: Optional[QuantizationParameters] = None,
weights_quantization_params: Optional[QuantizationParameters] = None,
activations_range_estimator_params: Optional[RangeEstimatorParameters] = None,
Expand Down Expand Up @@ -133,6 +136,8 @@ def __init__(
:param inplace_statistics: Defines whether to calculate quantizers statistics
by backend graph operations or by default Python implementation, defaults
to True.
:param mode: Defines the mode for the algorithm: FakeConvert (FP8), FakeQuantize (FQ), etc.
Defaults to FQ.
:param activations_quantization_params: Quantization parameters for model
activations.
:param weights_quantization_params: Quantization parameters for model weights.
@@ -145,6 +150,7 @@
self._target_device = target_device
self._subset_size = subset_size
self._model_type = model_type
self._mode = mode
self._ignored_scope = IgnoredScope() if ignored_scope is None else ignored_scope
self._overflow_fix = overflow_fix
self._quantize_outputs = quantize_outputs
@@ -722,9 +728,12 @@ def filter_func(point: StatisticPoint) -> bool:
qconfig = quantization_target_points[quantization_target_point]
q_group = QuantizerGroup.ACTIVATIONS
narrow_range = get_quantizer_narrow_range(qconfig, q_group)
parameters = calculate_quantizer_parameters(unified_values, qconfig, q_group, narrow_range)
if self._mode == Mode.FP8:
parameters = calculate_convert_parameters(unified_values, is_activation=True)
elif self._mode == Mode.FQ:
parameters = calculate_quantizer_parameters(unified_values, qconfig, q_group, narrow_range)
command = self._backend_entity.create_quantizer_insertion_command(
graph, quantization_target_point, qconfig, parameters
graph, quantization_target_point, qconfig, parameters, self._mode
)
transformation_layout.register(command)
unified_ops_list.add(quantization_target_point)
@@ -750,9 +759,15 @@ def filter_func(point: StatisticPoint) -> bool:
statistics = tensor_collector.get_statistics()
if statistics.min_values is None or statistics.max_values is None:
raise RuntimeError(f"Statistics were not collected for the node {target_node_name}")
parameters = calculate_quantizer_parameters(statistics, qconfig, quant_group, narrow_range, half_range)
if self._mode == Mode.FP8:
is_activation = quant_group == QuantizerGroup.ACTIVATIONS
parameters = calculate_convert_parameters(statistics, is_activation=is_activation)
elif self._mode == Mode.FQ:
parameters = calculate_quantizer_parameters(
statistics, qconfig, quant_group, narrow_range, half_range
)
command = self._backend_entity.create_quantizer_insertion_command(
graph, quantization_target_point, qconfig, parameters
graph, quantization_target_point, qconfig, parameters, self._mode
)
transformation_layout.register(command)
if not transformation_layout.transformations:
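
The same switch reaches the algorithm itself; a minimal sketch of constructing it directly with FP8 mode (direct use of MinMaxQuantization outside nncf.quantize is illustrative only, and the other arguments keep their defaults):

```python
from nncf.quantization.advanced_parameters import Mode
from nncf.quantization.algorithms.min_max.algorithm import MinMaxQuantization

# With mode=Mode.FP8 the algorithm computes FakeConvert parameters via
# calculate_convert_parameters(); with the default Mode.FQ it computes
# FakeQuantize parameters via calculate_quantizer_parameters(), as in the diff above.
min_max = MinMaxQuantization(subset_size=300, mode=Mode.FP8)
```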