From 34d24a4eb73b3ca260fbbf65a1428eef60a93ec4 Mon Sep 17 00:00:00 2001
From: Elad Cohen <78862769+elad-c@users.noreply.github.com>
Date: Mon, 23 Sep 2024 15:30:30 +0300
Subject: [PATCH] Support weight quantization for tf.gather in TPC imx500.v4.
 (#1226)

Support weight quantization for tf.gather in TPC imx500.v4.
---
 .../tpc_models/imx500_tpc/v4/tp_model.py      | 52 ++++++++----
 .../tpc_models/imx500_tpc/v4/tpc_keras.py     | 49 +++++------
 .../tpc_models/imx500_tpc/v4/tpc_pytorch.py   | 82 ++++++++++---------
 .../const_quantization_test.py                | 13 ++-
 .../const_representation_test.py              |  5 +-
 5 files changed, 121 insertions(+), 80 deletions(-)
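For context before the diffs: the user-visible effect of this patch is that a tf.gather whose params input is a constant is no longer left unquantized under the imx500 v4 TPC; the constant table is quantized as a weight, while the op itself stays quantization-preserving on the activation path. A minimal sketch of the kind of Keras model this now covers (hypothetical shapes, not taken from the patch):

    import numpy as np
    import tensorflow as tf

    # A constant lookup table indexed by argmax results; under imx500 v4 the
    # table is now treated as a quantizable weight of the gather op.
    table = np.random.random((32 * 32,)).astype(np.float32)

    inputs = tf.keras.layers.Input(shape=(32, 32, 16))
    inds = tf.argmax(tf.reshape(inputs, (-1, 32 * 32, 16)), axis=1)  # integer indices
    x = tf.gather(table, inds)  # params is a constant -> weight quantization
    model = tf.keras.models.Model(inputs=inputs, outputs=x)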
diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py
index 72ea15029..63fba3743 100644
--- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py
+++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tp_model.py
@@ -24,6 +24,23 @@
 
 tp = mct.target_platform
 
+OPSET_NO_QUANTIZATION = "NoQuantization"
+OPSET_QUANTIZATION_PRESERVING = "QuantizationPreserving"
+OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS = "DimensionManipulationOpsWithWeights"
+OPSET_DIMENSION_MANIPULATION_OPS = "DimensionManipulationOps"
+OPSET_MERGE_OPS = "MergeOps"
+OPSET_CONV = "Conv"
+OPSET_FULLY_CONNECTED = "FullyConnected"
+OPSET_ANY_RELU = "AnyReLU"
+OPSET_ADD = "Add"
+OPSET_SUB = "Sub"
+OPSET_MUL = "Mul"
+OPSET_DIV = "Div"
+OPSET_PRELU = "PReLU"
+OPSET_SWISH = "Swish"
+OPSET_SIGMOID = "Sigmoid"
+OPSET_TANH = "Tanh"
+
 
 def get_tp_model() -> TargetPlatformModel:
     """
@@ -189,6 +206,10 @@ def generate_tp_model(default_config: OpQuantizationConfig,
                                                                              const_config_input16_per_tensor],
                                                                             base_config=const_config_input16_per_tensor)
 
+    qpreserving_const_config = const_config.clone_and_edit(enable_activation_quantization=False,
+                                                           quantization_preserving=True)
+    qpreserving_const_config_options = tp.QuantizationConfigOptions([qpreserving_const_config])
+
     # Create a TargetPlatformModel and set its default quantization config.
     # This default configuration will be used for all operations
     # unless specified otherwise (see OperatorsSet, for example):
@@ -207,39 +228,40 @@ def generate_tp_model(default_config: OpQuantizationConfig,
         # May suit for operations like: Dropout, Reshape, etc.
         default_qco = tp.get_default_quantization_config_options()
-        tp.OperatorsSet("NoQuantization",
+        tp.OperatorsSet(OPSET_NO_QUANTIZATION,
                         default_qco.clone_and_edit(enable_activation_quantization=False)
                         .clone_and_edit_weight_attribute(enable_weights_quantization=False))
-        tp.OperatorsSet("QuantizationPreserving",
+        tp.OperatorsSet(OPSET_QUANTIZATION_PRESERVING,
                         default_qco.clone_and_edit(enable_activation_quantization=False,
                                                    quantization_preserving=True)
                         .clone_and_edit_weight_attribute(enable_weights_quantization=False))
-        tp.OperatorsSet("DimensionManipulationOps",
+        tp.OperatorsSet(OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, qpreserving_const_config_options)
+        tp.OperatorsSet(OPSET_DIMENSION_MANIPULATION_OPS,
                         default_qco.clone_and_edit(enable_activation_quantization=False,
                                                    quantization_preserving=True,
                                                    supported_input_activation_n_bits=(8, 16))
                         .clone_and_edit_weight_attribute(enable_weights_quantization=False))
-        tp.OperatorsSet("MergeOps", const_configuration_options_inout16_per_tensor)
+        tp.OperatorsSet(OPSET_MERGE_OPS, const_configuration_options_inout16_per_tensor)
 
         # Create Mixed-Precision quantization configuration options from the given list of OpQuantizationConfig objects
         mixed_precision_configuration_options = tp.QuantizationConfigOptions(mixed_precision_cfg_list,
                                                                              base_config=base_config)
 
         # Define operator sets that use mixed_precision_configuration_options:
-        conv = tp.OperatorsSet("Conv", mixed_precision_configuration_options)
-        fc = tp.OperatorsSet("FullyConnected", mixed_precision_configuration_options)
+        conv = tp.OperatorsSet(OPSET_CONV, mixed_precision_configuration_options)
+        fc = tp.OperatorsSet(OPSET_FULLY_CONNECTED, mixed_precision_configuration_options)
 
         # Define operations sets without quantization configuration
         # options (useful for creating fusing patterns, for example):
-        any_relu = tp.OperatorsSet("AnyReLU")
-        add = tp.OperatorsSet("Add", const_configuration_options_inout16)
-        sub = tp.OperatorsSet("Sub", const_configuration_options_inout16)
-        mul = tp.OperatorsSet("Mul", const_configuration_options_inout16)
-        div = tp.OperatorsSet("Div", const_configuration_options)
-        prelu = tp.OperatorsSet("PReLU")
-        swish = tp.OperatorsSet("Swish")
-        sigmoid = tp.OperatorsSet("Sigmoid")
-        tanh = tp.OperatorsSet("Tanh")
+        any_relu = tp.OperatorsSet(OPSET_ANY_RELU)
+        add = tp.OperatorsSet(OPSET_ADD, const_configuration_options_inout16)
+        sub = tp.OperatorsSet(OPSET_SUB, const_configuration_options_inout16)
+        mul = tp.OperatorsSet(OPSET_MUL, const_configuration_options_inout16)
+        div = tp.OperatorsSet(OPSET_DIV, const_configuration_options)
+        prelu = tp.OperatorsSet(OPSET_PRELU)
+        swish = tp.OperatorsSet(OPSET_SWISH)
+        sigmoid = tp.OperatorsSet(OPSET_SIGMOID)
+        tanh = tp.OperatorsSet(OPSET_TANH)
 
         # Combine multiple operators into a single operator to avoid quantization between
         # them. To do this we define fusing patterns using the OperatorsSets that were created.
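A rough way to read the new qpreserving_const_config: the op gets no activation quantizer of its own (enable_activation_quantization=False with quantization_preserving=True), but a constant input is still fake-quantized as a weight. Conceptually, and only as an illustration (this is not MCT's actual quantizer code):

    import numpy as np

    def fake_quant_weight(w, n_bits=8):
        # Symmetric per-tensor fake quantization, for illustration only.
        scale = np.abs(w).max() / (2 ** (n_bits - 1) - 1)
        return np.round(w / scale) * scale

    table = np.random.random((1024,)).astype(np.float32)
    q_table = fake_quant_weight(table)   # constant params input: quantized as a weight
    indices = np.array([3, 17, 512])
    out = q_table[indices]               # gather output: inherits its input's quantization
                                         # instead of getting a new activation quantizer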
diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py
index b403d6453..76ff28af6 100644
--- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py
+++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_keras.py
@@ -35,6 +35,10 @@
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import get_tp_model
 import model_compression_toolkit as mct
 from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4 import __version__ as TPC_VERSION
+from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import OPSET_NO_QUANTIZATION, \
+    OPSET_QUANTIZATION_PRESERVING, OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, OPSET_DIMENSION_MANIPULATION_OPS, \
+    OPSET_MERGE_OPS, OPSET_CONV, OPSET_FULLY_CONNECTED, OPSET_ANY_RELU, OPSET_ADD, OPSET_SUB, OPSET_MUL, OPSET_DIV, \
+    OPSET_PRELU, OPSET_SWISH, OPSET_SIGMOID, OPSET_TANH
 
 tp = mct.target_platform
 
@@ -74,10 +78,8 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
                      Dropout,
                      MaxPooling2D,
                      tf.split,
-                     tf.gather,
                      tf.cast,
                      tf.unstack,
-                     tf.compat.v1.gather,
                      tf.__operators__.getitem,
                      tf.strided_slice]
     quantization_preserving_list_16bit_input = [Reshape,
@@ -90,11 +92,12 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
         no_quant_list.append(SSDPostProcess)
 
     with keras_tpc:
-        tp.OperationsSetToLayers("NoQuantization", no_quant_list)
-        tp.OperationsSetToLayers("QuantizationPreserving", quantization_preserving)
-        tp.OperationsSetToLayers("DimensionManipulationOps", quantization_preserving_list_16bit_input)
-        tp.OperationsSetToLayers("MergeOps", [tf.stack, tf.concat, Concatenate])
-        tp.OperationsSetToLayers("Conv",
+        tp.OperationsSetToLayers(OPSET_NO_QUANTIZATION, no_quant_list)
+        tp.OperationsSetToLayers(OPSET_QUANTIZATION_PRESERVING, quantization_preserving)
+        tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS, quantization_preserving_list_16bit_input)
+        tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, [tf.gather, tf.compat.v1.gather])
+        tp.OperationsSetToLayers(OPSET_MERGE_OPS, [tf.stack, tf.concat, Concatenate])
+        tp.OperationsSetToLayers(OPSET_CONV,
                                  [Conv2D,
                                   DepthwiseConv2D,
                                   Conv2DTranspose,
@@ -111,23 +114,23 @@ def generate_keras_tpc(name: str, tp_model: tp.TargetPlatformModel):
                                                    DepthwiseConv2D: KERAS_DEPTHWISE_KERNEL,
                                                    tf.nn.depthwise_conv2d: KERAS_DEPTHWISE_KERNEL}, default_value=KERAS_KERNEL),
                                                BIAS_ATTR: DefaultDict(default_value=BIAS)})
-        tp.OperationsSetToLayers("FullyConnected", [Dense],
+        tp.OperationsSetToLayers(OPSET_FULLY_CONNECTED, [Dense],
                                  attr_mapping={KERNEL_ATTR: DefaultDict(default_value=KERAS_KERNEL),
                                                BIAS_ATTR: DefaultDict(default_value=BIAS)})
-        tp.OperationsSetToLayers("AnyReLU", [tf.nn.relu,
-                                             tf.nn.relu6,
-                                             tf.nn.leaky_relu,
-                                             ReLU,
-                                             LeakyReLU,
-                                             tp.LayerFilterParams(Activation, activation="relu"),
-                                             tp.LayerFilterParams(Activation, activation="leaky_relu")])
-        tp.OperationsSetToLayers("Add", [tf.add, Add])
-        tp.OperationsSetToLayers("Sub", [tf.subtract, Subtract])
-        tp.OperationsSetToLayers("Mul", [tf.math.multiply, Multiply])
-        tp.OperationsSetToLayers("Div", [tf.math.divide, tf.math.truediv])
-        tp.OperationsSetToLayers("PReLU", [PReLU])
-        tp.OperationsSetToLayers("Swish", [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")])
-        tp.OperationsSetToLayers("Sigmoid", [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")])
-        tp.OperationsSetToLayers("Tanh", [tf.nn.tanh, tp.LayerFilterParams(Activation, activation="tanh")])
+        tp.OperationsSetToLayers(OPSET_ANY_RELU, [tf.nn.relu,
+                                                  tf.nn.relu6,
+                                                  tf.nn.leaky_relu,
+                                                  ReLU,
+                                                  LeakyReLU,
+                                                  tp.LayerFilterParams(Activation, activation="relu"),
+                                                  tp.LayerFilterParams(Activation, activation="leaky_relu")])
+        tp.OperationsSetToLayers(OPSET_ADD, [tf.add, Add])
+        tp.OperationsSetToLayers(OPSET_SUB, [tf.subtract, Subtract])
+        tp.OperationsSetToLayers(OPSET_MUL, [tf.math.multiply, Multiply])
+        tp.OperationsSetToLayers(OPSET_DIV, [tf.math.divide, tf.math.truediv])
+        tp.OperationsSetToLayers(OPSET_PRELU, [PReLU])
+        tp.OperationsSetToLayers(OPSET_SWISH, [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")])
+        tp.OperationsSetToLayers(OPSET_SIGMOID, [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")])
+        tp.OperationsSetToLayers(OPSET_TANH, [tf.nn.tanh, tp.LayerFilterParams(Activation, activation="tanh")])
 
     return keras_tpc
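Both gather entry points are mapped to the new opset above, so either spelling below gets its constant table quantized under this TPC (a sketch; the tensors are hypothetical):

    import numpy as np
    import tensorflow as tf

    table = np.random.random((128,)).astype(np.float32)  # constant params -> weight
    idx = tf.constant([0, 5, 17])                        # integer indices

    y1 = tf.gather(table, idx)            # -> OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS
    y2 = tf.compat.v1.gather(table, idx)  # legacy entry point, same opset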
tp.OperationsSetToLayers("Sigmoid", [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")]) - tp.OperationsSetToLayers("Tanh", [tf.nn.tanh, tp.LayerFilterParams(Activation, activation="tanh")]) + tp.OperationsSetToLayers(OPSET_ANY_RELU, [tf.nn.relu, + tf.nn.relu6, + tf.nn.leaky_relu, + ReLU, + LeakyReLU, + tp.LayerFilterParams(Activation, activation="relu"), + tp.LayerFilterParams(Activation, activation="leaky_relu")]) + tp.OperationsSetToLayers(OPSET_ADD, [tf.add, Add]) + tp.OperationsSetToLayers(OPSET_SUB, [tf.subtract, Subtract]) + tp.OperationsSetToLayers(OPSET_MUL, [tf.math.multiply, Multiply]) + tp.OperationsSetToLayers(OPSET_DIV, [tf.math.divide, tf.math.truediv]) + tp.OperationsSetToLayers(OPSET_PRELU, [PReLU]) + tp.OperationsSetToLayers(OPSET_SWISH, [tf.nn.swish, tp.LayerFilterParams(Activation, activation="swish")]) + tp.OperationsSetToLayers(OPSET_SIGMOID, [tf.nn.sigmoid, tp.LayerFilterParams(Activation, activation="sigmoid")]) + tp.OperationsSetToLayers(OPSET_TANH, [tf.nn.tanh, tp.LayerFilterParams(Activation, activation="tanh")]) return keras_tpc diff --git a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py index 9b4bf4e91..f2a10b3d0 100644 --- a/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py +++ b/model_compression_toolkit/target_platform_capabilities/tpc_models/imx500_tpc/v4/tpc_pytorch.py @@ -29,6 +29,10 @@ from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import get_tp_model import model_compression_toolkit as mct from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4 import __version__ as TPC_VERSION +from model_compression_toolkit.target_platform_capabilities.tpc_models.imx500_tpc.v4.tp_model import OPSET_NO_QUANTIZATION, \ + OPSET_QUANTIZATION_PRESERVING, OPSET_DIMENSION_MANIPULATION_OPS_WITH_WEIGHTS, OPSET_DIMENSION_MANIPULATION_OPS, \ + OPSET_MERGE_OPS, OPSET_CONV, OPSET_FULLY_CONNECTED, OPSET_ANY_RELU, OPSET_ADD, OPSET_SUB, OPSET_MUL, OPSET_DIV, \ + OPSET_PRELU, OPSET_SWISH, OPSET_SIGMOID, OPSET_TANH tp = mct.target_platform @@ -65,49 +69,49 @@ def generate_pytorch_tpc(name: str, tp_model: tp.TargetPlatformModel): BIAS_ATTR: DefaultDict(default_value=BIAS)} with pytorch_tpc: - tp.OperationsSetToLayers("NoQuantization", [torch.Tensor.size, - equal, - argmax, - topk]) - tp.OperationsSetToLayers("QuantizationPreserving", [Dropout, - dropout, - split, - chunk, - unbind, - gather, - MaxPool2d]) - tp.OperationsSetToLayers("DimensionManipulationOps", [Flatten, - flatten, - operator.getitem, - reshape, - unsqueeze, - squeeze, - permute, - transpose]) - tp.OperationsSetToLayers("MergeOps", + tp.OperationsSetToLayers(OPSET_NO_QUANTIZATION, [torch.Tensor.size, + equal, + argmax, + topk]) + tp.OperationsSetToLayers(OPSET_QUANTIZATION_PRESERVING, [Dropout, + dropout, + split, + chunk, + unbind, + gather, + MaxPool2d]) + tp.OperationsSetToLayers(OPSET_DIMENSION_MANIPULATION_OPS, [Flatten, + flatten, + operator.getitem, + reshape, + unsqueeze, + squeeze, + permute, + transpose]) + tp.OperationsSetToLayers(OPSET_MERGE_OPS, [torch.stack, torch.cat, torch.concat, torch.concatenate]) - tp.OperationsSetToLayers("Conv", [Conv2d, ConvTranspose2d], + tp.OperationsSetToLayers(OPSET_CONV, [Conv2d, ConvTranspose2d], attr_mapping=pytorch_linear_attr_mapping) - tp.OperationsSetToLayers("FullyConnected", [Linear], + 
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py
index c0ba52620..df8929552 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/const_quantization_test.py
@@ -189,6 +189,13 @@ def create_networks(self):
         x3 = tf.add_n([x1, as_const(x), x2])
         x1 = tf.reshape(tf.stack([as_const(x1), x1, as_const(x1)], axis=1), (-1, 3*x1.shape[1], x1.shape[2], x1.shape[3]))
         x = tf.concat([x1, x2, as_const(x3), x3], 1)
+        ind_select_const = np.zeros((192*32, 38))
+        ind_select_const[4, :] = 100
+        x1 = tf.add(x, ind_select_const.reshape((192, 32, 38)))
+        inds = tf.argmax(tf.reshape(x1, (-1, 192 * 32, 38)), axis=1)
+        b = tf.reshape(tf.gather(np.random.random((192 * 32 * 38,)).astype(np.float32), inds), (-1, 1, 1, 38))
+        x = tf.add(x, b)
+
         return tf.keras.models.Model(inputs=inputs, outputs=x)
 
     def compare(self, quantized_model, float_model, input_x=None, quantization_info=None):
@@ -196,10 +203,12 @@ def compare(self, quantized_model, float_model, input_x=None, quantization_info=
         y_hat = quantized_model.predict(input_x)
         self.unit_test.assertTrue(y.shape == y_hat.shape, msg=f'out shape is not as expected!')
         cs = cosine_similarity(y, y_hat)
-        self.unit_test.assertTrue(np.isclose(cs, 1, atol=1e-2), msg=f'fail cosine similarity check:{cs}')
+        # atol is rather high because quantization error in the tensor that feeds tf.argmax can change the selected
+        # indices, which in turn can cause a large difference between the float and quantized tf.gather outputs.
+        self.unit_test.assertTrue(np.isclose(cs, 1, atol=0.01), msg=f'fail cosine similarity check:{cs}')
 
         # check quantization layers:
-        for op in [tf.concat, tf.stack, layers.Add, layers.Multiply, layers.Concatenate]:
+        for op in [tf.concat, tf.stack, layers.Add, layers.Multiply, layers.Concatenate, tf.gather, tf.compat.v1.gather]:
             for qlayer in get_layers_from_model_by_type(quantized_model, op):
                 self.unit_test.assertTrue(isinstance(qlayer, KerasQuantizationWrapper),
                                           msg=f"{op} should be quantized.")
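The cosine_similarity helper used above is imported from the repo's test utilities; an equivalent standalone definition, assuming it compares flattened outputs, would be:

    import numpy as np

    def cosine_similarity(a, b, eps=1e-8):
        # Flatten both outputs and compare directions; 1.0 means identical up to scale.
        a, b = np.asarray(a).flatten(), np.asarray(b).flatten()
        return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + eps))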
diff --git a/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py b/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py
index bbefb4bcd..11754f504 100644
--- a/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py
+++ b/tests/keras_tests/feature_networks_tests/feature_networks/const_representation_test.py
@@ -155,7 +155,10 @@ def get_tpc(self):
     def create_networks(self):
         as_const = lambda v: np.random.random(v.shape.as_list()).astype(np.float32)
         inputs = layers.Input(shape=self.get_input_shapes()[0][1:])
-        x = layers.Concatenate()([inputs, np.random.random((1, 32, 32, 3)), inputs, np.random.random((1, 32, 32, 3))])
+        inds = tf.reshape(tf.argmax(tf.reshape(inputs, (-1, 32 * 32, 16)), axis=1), (-1, 1, 1, 16))
+        b = tf.gather(np.random.random((2000,)).astype(np.float32), inds)
+        x = tf.add(inputs, b)
+        x = layers.Concatenate()([x, np.random.random((1, 32, 32, 3)), x, np.random.random((1, 32, 32, 3))])
         x1 = layers.Add()([np.random.random((1, x.shape[-1])), x, np.random.random((1, x.shape[-1]))])
         x2 = layers.Multiply()([x, np.random.random((1, x.shape[-1])), x, np.random.random((1, x.shape[-1]))])
         x3 = tf.add_n([x1, as_const(x), x2])
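End to end, the new behavior can be exercised by quantizing such a model against the v4 TPC. A sketch of the flow (API names follow MCT's public interface; treat the exact signatures and the wrapper check as assumptions rather than part of this patch):

    import numpy as np
    import tensorflow as tf
    import model_compression_toolkit as mct
    from mct_quantizers import KerasQuantizationWrapper

    # Same gather-on-a-constant-table pattern as in the tests above.
    table = np.random.random((32 * 32,)).astype(np.float32)
    inputs = tf.keras.layers.Input(shape=(32, 32, 16))
    inds = tf.argmax(tf.reshape(inputs, (-1, 32 * 32, 16)), axis=1)
    model = tf.keras.models.Model(inputs=inputs, outputs=tf.gather(table, inds))

    def rep_data():
        # Tiny representative dataset, enough for a smoke run.
        for _ in range(2):
            yield [np.random.random((1, 32, 32, 16)).astype(np.float32)]

    tpc = mct.get_target_platform_capabilities('tensorflow', 'imx500', target_platform_version='v4')
    q_model, _ = mct.ptq.keras_post_training_quantization(model, rep_data,
                                                          target_platform_capabilities=tpc)

    # tf.gather should now appear wrapped, i.e. its constant table was quantized as a weight.
    print(any(isinstance(l, KerasQuantizationWrapper) for l in q_model.layers))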