Merge openvinotoolkit/develop into nm/fp8_implementation
nikita-malininn committed Dec 15, 2023
2 parents 7512da2 + 0c389c3 commit d6dedc2
Showing 34 changed files with 753 additions and 267 deletions.
81 changes: 78 additions & 3 deletions nncf/experimental/common/pruning/operations.py
@@ -230,10 +230,13 @@ def mask_propagation(
cls, node: NNCFNode, graph: NNCFGraph, tensor_processor: Type[NNCFPruningBaseTensorProcessor]
) -> None:
input_masks = get_input_masks(node, graph)
node.attributes["output_mask"] = cls._get_output_mask(input_masks)
input_shapes = [x.tensor_shape for x in graph.get_input_edges(node)]
node.attributes["output_mask"] = cls._get_output_mask(input_masks, input_shapes)

@classmethod
def _get_output_mask(cls, input_masks: List[Optional[PropagationMask]]) -> Optional[PropagationMask]:
def _get_output_mask(
cls, input_masks: List[Optional[PropagationMask]], input_shapes: List[Tuple[int, ...]]
) -> Optional[PropagationMask]:
if not input_masks:
return None
output_mask = None
@@ -246,11 +249,16 @@ def _get_output_mask(cls, input_masks: List[Optional[PropagationMask]]) -> Optio
"node_name={node.node_name}"
)
output_mask = input_masks[0]
elif any(m is None for m in input_masks) and any(m is not None for m in input_masks):
# Case when one of the input masks is None and the other is not
output_mask = cls._propagate_single_mask(input_masks, input_shapes)
if output_mask is None:
cls.invalidate_masks(input_masks)
elif any(not m for m in input_masks):
# Non-empty masks are required on all branches to propagate the pruning mask properly;
# otherwise the masks are invalidated
cls.invalidate_masks(input_masks)
else:
elif all(m is not None for m in input_masks):
# Each branch/mask should have a single group along the same dimension. These groups are joined, all others
# are invalidated.
output_mask = PropagationMask()
@@ -275,6 +283,73 @@ def _get_output_mask(cls, input_masks: List[Optional[PropagationMask]]) -> Optio
group.invalidate()
return output_mask

@classmethod
def _propagate_single_mask(
cls, input_masks: List[Optional[PropagationMask]], input_shapes: List[Tuple[int, ...]]
) -> Optional[PropagationMask]:
"""
Attempts to propagate a mask when one of the input masks is None.
:param input_masks: List of propagation masks for each input of the element-wise operation.
:param input_shapes: List of tensor shapes for each input.
:return: An instance of PropagationMask or None.
"""
if cls._are_broadcast_dims_in_both_shapes(input_shapes):
return None

none_mask_ind = input_masks.index(None)
mask_ind = 0 if none_mask_ind else 1

dims_diff = len(input_shapes[mask_ind]) - len(input_shapes[none_mask_ind])
padded_none_mask_shape = (1,) * dims_diff + input_shapes[none_mask_ind]

dims_shift = min(dims_diff, 0)
for dim in input_masks[mask_ind].dim_groups_map:
if padded_none_mask_shape[dim - dims_shift] != 1:
return None

output_mask = PropagationMask()
for dim, groups in input_masks[mask_ind].dim_groups_map.items():
output_mask.dim_groups_map[dim - dims_shift] = groups

return output_mask

@staticmethod
def _are_broadcast_dims_in_both_shapes(shapes) -> bool:
"""
Mask propagation is not supported if both shapes would be broadcast by the element-wise operation.
Example:
(1, 10), (10, 1) -> True
(1, 10), (10,) -> False
:param shapes: Shapes of the two input tensors.
:return: True if both shapes have broadcast dimensions, otherwise False.
"""
shape_a = shapes[0]
shape_b = shapes[1]
shape_a_size_diff = len(shape_a) - len(shape_b)
shape_b_size_diff = -shape_a_size_diff
broadcasted_dims_1 = set()
broadcasted_dims_2 = set()

for i in range(len(shape_a)):
shifted_elem = i + shape_b_size_diff
if shifted_elem >= 0 and shape_a[i] == 1 and shape_b[shifted_elem] != 1:
broadcasted_dims_1.add(i)
if shifted_elem < 0 and shape_a[i] != 1:
broadcasted_dims_2.add(shifted_elem)

for i in range(len(shape_b)):
shifted_elem = i + shape_a_size_diff
if shifted_elem >= 0 and shape_b[i] == 1 and shape_a[shifted_elem] != 1:
broadcasted_dims_2.add(i)
if shifted_elem < 0 and shape_b[i] != 1:
broadcasted_dims_1.add(shifted_elem)

return bool(broadcasted_dims_1) and bool(broadcasted_dims_2)


class GatherPruningOp(BasePruningOp):
@classmethod
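The broadcast rule behind _are_broadcast_dims_in_both_shapes above can be checked in isolation. The snippet below is an illustrative re-implementation applied to the docstring's example shapes (the function name and shapes are ours, not an NNCF API): a single non-None mask is propagated only when at most one of the two operands would be broadcast.

# Illustrative sketch of the broadcast check above; both_shapes_broadcast is not an NNCF API.
from typing import Tuple


def both_shapes_broadcast(shape_a: Tuple[int, ...], shape_b: Tuple[int, ...]) -> bool:
    # Right-align the shapes, padding the shorter one with leading 1s (NumPy-style broadcasting).
    diff = len(shape_a) - len(shape_b)
    padded_a = (1,) * max(-diff, 0) + shape_a
    padded_b = (1,) * max(diff, 0) + shape_b
    a_is_broadcast = any(a == 1 and b != 1 for a, b in zip(padded_a, padded_b))
    b_is_broadcast = any(b == 1 and a != 1 for a, b in zip(padded_a, padded_b))
    return a_is_broadcast and b_is_broadcast


print(both_shapes_broadcast((1, 10), (10, 1)))  # True  -> both inputs broadcast, masks are invalidated
print(both_shapes_broadcast((1, 10), (10,)))    # False -> the non-None mask can still be propagated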
7 changes: 3 additions & 4 deletions nncf/openvino/graph/model_transformer.py
@@ -16,8 +16,7 @@
import numpy as np
import openvino.runtime as ov
from openvino._pyopenvino import DescriptorTensor
from openvino.runtime import opset9 as opset
from openvino.runtime import opset13
from openvino.runtime import opset13 as opset

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.common.graph.model_transformer import TModel
@@ -317,7 +316,7 @@ def _create_fake_convert(
shift = OVModelTransformer._convert_to_fp16(shift)

destination_type = fake_convert_params.destination_type.value
return opset13.fake_convert(
return opset.fake_convert(
data=op_output,
scale=scale,
shift=shift,
@@ -483,7 +482,7 @@ def _set_const_value(node_with_const: ov.Node, const_port_id: int, const_value:
const_dtype = const_node.data.dtype
const_value = np.reshape(const_value, const_shape).astype(const_dtype)

# TODO(andrey-churkin): Replace on opset13.constant() in a future release
# TODO(andrey-churkin): Replace with opset13.constant() in the 2023.3 release
new_const_node = ov.op.Constant(const_value, shared_memory=True)
new_const_node.set_friendly_name(const_node.get_friendly_name())
const_port.replace_source_output(new_const_node.output(0))
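For reference, a minimal hedged sketch of building a FakeConvert node with the opset13 Python API used in _create_fake_convert above. The parameter shape, constant values, and the "f8e4m3" destination type are illustrative assumptions, not taken from this diff.

# Hedged sketch, not NNCF code: assumes an OpenVINO release that ships opset13
# and that "f8e4m3" is a supported destination_type value.
import numpy as np
from openvino.runtime import opset13 as opset

data = opset.parameter([1, 3, 8, 8], dtype=np.float32, name="data")
scale = opset.constant(np.ones((1, 3, 1, 1), dtype=np.float32))
shift = opset.constant(np.zeros((1, 3, 1, 1), dtype=np.float32))

# Same keyword arguments as in _create_fake_convert above.
fake_convert = opset.fake_convert(data=data, scale=scale, shift=shift, destination_type="f8e4m3")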
2 changes: 1 addition & 1 deletion nncf/openvino/graph/node_utils.py
@@ -13,7 +13,7 @@

import numpy as np
import openvino.runtime as ov
import openvino.runtime.opset9 as opset
import openvino.runtime.opset13 as opset

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.graph import NNCFNode
@@ -14,7 +14,7 @@

import numpy as np
import openvino.runtime as ov
from openvino.runtime import opset9 as opset
from openvino.runtime import opset13 as opset

from nncf.common.graph import NNCFNode
from nncf.common.graph.operator_metatypes import OperatorMetatype
15 changes: 6 additions & 9 deletions nncf/torch/dynamic_graph/context.py
@@ -11,9 +11,10 @@

import threading
import weakref
from collections import defaultdict
from collections import deque
from contextlib import contextmanager
from typing import Callable, Dict, List, Optional
from typing import Callable, DefaultDict, List, Optional

import torch

@@ -91,8 +92,8 @@ class TracingContext:
def __init__(self):
self.graph = DynamicGraph()

self._post_hooks = {}
self._pre_hooks: Dict[PreHookId, List[Callable]] = {}
self._post_hooks: DefaultDict[OperationAddress, List[Callable]] = defaultdict(list)
self._pre_hooks: DefaultDict[PreHookId, List[Callable]] = defaultdict(list)
self._num_nested_hooks = 0

self._threading = CopySafeThreadingVars()
@@ -261,9 +262,7 @@ def pop_scope(self):

def register_pre_hooks(self, fn_list: List[Callable], op_address: OperationAddress, input_port_id: int):
pre_hook_id = PreHookId(op_address, input_port_id)
if pre_hook_id in self._pre_hooks:
raise KeyError("Pre hook for context {} is already registered".format(str(pre_hook_id)))
self._pre_hooks[pre_hook_id] = fn_list
self._pre_hooks[pre_hook_id].extend(fn_list)

def execute_pre_hooks(self, op_address: OperationAddress, op_inputs: OperatorInput) -> OperatorInput:
in_op = getattr(self, "in_operator", False)
@@ -282,9 +281,7 @@ def execute_pre_hooks(self, op_address: OperationAddress, op_inputs: OperatorInp
return op_inputs

def register_post_hooks(self, fn_list: List[Callable], op_address: OperationAddress):
if op_address in self._post_hooks:
raise KeyError("Post hook for context {} is already registered".format(str(op_address)))
self._post_hooks[op_address] = fn_list
self._post_hooks[op_address].extend(fn_list)

def execute_post_hooks(self, op_address: OperationAddress, outputs):
in_op = getattr(self, "in_operator", False)
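The hook containers in TracingContext are now defaultdict(list), and registration extends the stored list instead of raising KeyError, so a second registration for the same address accumulates hooks. A standalone sketch of that pattern follows; HookStorage is illustrative, not the NNCF implementation.

# Illustrative sketch of the defaultdict-based registration above; HookStorage is not an NNCF class.
from collections import defaultdict
from typing import Callable, DefaultDict, List


class HookStorage:
    def __init__(self) -> None:
        self._post_hooks: DefaultDict[str, List[Callable]] = defaultdict(list)

    def register_post_hooks(self, fn_list: List[Callable], op_address: str) -> None:
        # Previously a plain dict raised KeyError on a repeated address; now hooks accumulate.
        self._post_hooks[op_address].extend(fn_list)


storage = HookStorage()
storage.register_post_hooks([lambda x: x + 1], "relu_0")
storage.register_post_hooks([lambda x: x * 2], "relu_0")
print(len(storage._post_hooks["relu_0"]))  # 2 -> both hooks are kept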
3 changes: 0 additions & 3 deletions nncf/torch/dynamic_graph/patch_pytorch.py
@@ -58,8 +58,6 @@ class FunctionsToPatchWithoutTracing:
"as_tensor",
"copysign",
"copysign_",
"detach",
"detach_",
"empty",
"ones",
"ones_like",
@@ -112,7 +110,6 @@ class FunctionsToPatchWithoutTracing:
"storage",
"storage_offset",
"stride",
"to",
"get_device",
]

12 changes: 10 additions & 2 deletions nncf/torch/graph/graph_builder.py
@@ -27,11 +27,19 @@ class GraphBuilder:
def __init__(self, custom_forward_fn: Callable[[torch.nn.Module], Any]):
self.custom_forward_fn = custom_forward_fn

def build_dynamic_graph(
self,
model: torch.nn.Module,
context_to_use: Optional[TracingContext] = None,
as_eval: bool = False,
) -> DynamicGraph:
tracer = GraphTracer(self.custom_forward_fn)
return tracer.trace_graph(model, context_to_use, as_eval)

def build_graph(
self, model: torch.nn.Module, context_to_use: Optional[TracingContext] = None, as_eval: bool = False
) -> PTNNCFGraph:
tracer = GraphTracer(self.custom_forward_fn)
dynamic_graph = tracer.trace_graph(model, context_to_use, as_eval)
dynamic_graph = self.build_dynamic_graph(model, context_to_use, as_eval)
return GraphConverter.convert(dynamic_graph)


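build_graph is split so that the raw DynamicGraph is now available through the public build_dynamic_graph step before conversion to PTNNCFGraph. A hedged usage sketch, assuming that tracing a plain nn.Module with a dummy input works in the target environment; the model, input shape, and lambda forward function are illustrative.

# Hedged usage sketch based on the signatures shown above; not taken from NNCF tests or docs.
import torch
from nncf.torch.graph.graph_builder import GraphBuilder

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU())
builder = GraphBuilder(custom_forward_fn=lambda m: m(torch.ones(1, 8)))

dynamic_graph = builder.build_dynamic_graph(model)  # raw DynamicGraph, newly exposed step
nncf_graph = builder.build_graph(model)             # PTNNCFGraph, as before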
22 changes: 20 additions & 2 deletions nncf/torch/graph/operator_metatypes.py
@@ -154,8 +154,8 @@ class PTNoopMetatype(PTOperatorMetatype):
external_op_names = [name]
module_to_function_names = {
NamespaceTarget.TORCH_NN_FUNCTIONAL: [],
NamespaceTarget.TORCH_TENSOR: ["contiguous"],
NamespaceTarget.TORCH: ["clone"],
NamespaceTarget.TORCH_TENSOR: ["contiguous", "clone", "detach", "detach_", "to"],
NamespaceTarget.TORCH: ["clone", "detach", "detach_"],
}


@@ -315,6 +315,7 @@ class PTDeformConv2dMetatype(PTOperatorMetatype):
name = "DeformConv2dOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["deform_conv2d"]}
subtypes = [PTModuleDeformConv2dMetatype]
num_expected_input_edges = 4


@PT_OPERATOR_METATYPES.register()
@@ -323,6 +324,7 @@ class PTModuleLinearMetatype(PTModuleOperatorSubtype):
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["linear"], NamespaceTarget.TORCH: ["addmm"]}
hw_config_names = [HWConfigOpName.MATMUL]
output_channel_axis = -1
num_expected_input_edges = 2


@PT_OPERATOR_METATYPES.register()
@@ -332,55 +334,64 @@ class PTLinearMetatype(PTOperatorMetatype):
hw_config_names = [HWConfigOpName.MATMUL]
subtypes = [PTModuleLinearMetatype]
output_channel_axis = -1
num_expected_input_edges = 2


@PT_OPERATOR_METATYPES.register()
class PTHardTanhMetatype(PTOperatorMetatype):
name = "HardTanhOP"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["hardtanh"]}
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
class PTHardSwishMetatype(PTOperatorMetatype):
name = "HardSwishOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["hardswish"]}
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
class PTHardSigmoidMetatype(PTOperatorMetatype):
name = "HardSigmoidOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["hardsigmoid"]}
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
class PTTanhMetatype(PTOperatorMetatype):
name = "TanhOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["tanh"], NamespaceTarget.TORCH: ["tanh"]}
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
class PTELUMetatype(PTOperatorMetatype):
name = "EluOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["elu", "elu_"]}
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
class PTPRELUMetatype(PTOperatorMetatype):
name = "PReluOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["prelu"]}
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
class PTLeakyRELUMetatype(PTOperatorMetatype):
name = "LeakyReluOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["leaky_relu"]}
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
class PTModuleLayerNormMetatype(PTModuleOperatorSubtype):
name = "LayerNormOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["layer_norm"]}
hw_config_names = [HWConfigOpName.MVN]
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
Expand All @@ -389,6 +400,7 @@ class PTLayerNormMetatype(PTOperatorMetatype):
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["layer_norm"]}
hw_config_names = [HWConfigOpName.MVN]
subtypes = [PTModuleLayerNormMetatype]
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
@@ -488,6 +500,7 @@ class PTFloorDivMetatype(PTOperatorMetatype):
NamespaceTarget.TORCH_TENSOR: ["__floordiv__", "__ifloordiv__", "__rfloordiv__"],
NamespaceTarget.TORCH: ["floor_divide"],
}
num_expected_input_edges = 2


@PT_OPERATOR_METATYPES.register()
@@ -910,26 +923,30 @@ class PTInterpolateMetatype(PTOperatorMetatype):
name = "InterpolateOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"]}
hw_config_names = [HWConfigOpName.INTERPOLATE]
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
class PTRepeatMetatype(PTOperatorMetatype):
name = "RepeatOp"
module_to_function_names = {NamespaceTarget.TORCH: ["repeat_interleave"]}
hw_config_names = [HWConfigOpName.TILE]
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
class PTPixelShuffleMetatype(PTOperatorMetatype):
name = "PixelShuffleOp"
module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["pixel_shuffle"]}
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
class PTSumMetatype(PTOperatorMetatype):
name = "SumOp"
module_to_function_names = {NamespaceTarget.TORCH_TENSOR: ["sum"], NamespaceTarget.TORCH: ["sum"]}
hw_config_names = [HWConfigOpName.REDUCESUM]
num_expected_input_edges = 1


@PT_OPERATOR_METATYPES.register()
Expand All @@ -939,6 +956,7 @@ class PTReduceL2(PTOperatorMetatype):
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["normalize"], # note: normalize is for general L_p normalization
}
hw_config_names = [HWConfigOpName.REDUCEL2]
num_expected_input_edges = 1


def get_operator_metatypes() -> List[Type[OperatorMetatype]]: