From 7521b68fba363d4add0c772750d119e4d9815ce6 Mon Sep 17 00:00:00 2001 From: Yifan Shen Date: Fri, 19 Apr 2024 10:26:09 -0700 Subject: [PATCH] 7.2 release (#2196) Co-authored-by: yifan_shen3 --- .../converters/mil/backend/mil/helper.py | 44 ---- .../converters/mil/backend/mil/load.py | 89 ++++++- coremltools/converters/mil/backend/nn/load.py | 3 +- coremltools/converters/mil/frontend/_utils.py | 23 +- .../mil/frontend/torch/exir_utils.py | 17 +- .../converters/mil/frontend/torch/ops.py | 12 +- .../mil/frontend/torch/test/test_torch_ops.py | 230 ++++++++++++++---- .../mil/passes/defs/optimize_quantization.py | 25 +- coremltools/converters/mil/mil/program.py | 6 +- .../converters/mil/mil/types/type_tensor.py | 14 +- coremltools/models/model.py | 7 +- .../optimize/coreml/_quantization_passes.py | 61 ++--- coremltools/optimize/coreml/_utils.py | 8 +- .../coreml/test_post_training_quantization.py | 35 ++- coremltools/version.py | 2 +- reqs/docs.pip | 2 +- 16 files changed, 395 insertions(+), 183 deletions(-) diff --git a/coremltools/converters/mil/backend/mil/helper.py b/coremltools/converters/mil/backend/mil/helper.py index d6c7cd66a..c123e0ece 100644 --- a/coremltools/converters/mil/backend/mil/helper.py +++ b/coremltools/converters/mil/backend/mil/helper.py @@ -42,16 +42,6 @@ def create_valuetype_list(length, elem_shape, dtype): update_listtype(v_type.listType, length, elem_shape, dtype) return v_type -def create_valuetype_dict(key_type, value_type): - """ - Return proto.MIL_pb2.ValueType with dict (dictionaryType) set - """ - v_type = proto.MIL_pb2.ValueType() - v_type.dictionaryType.keyType.CopyFrom(types_to_proto(key_type)) - v_type.dictionaryType.valueType.CopyFrom(types_to_proto(value_type)) - return v_type - - def create_valuetype_tensor(shape, data_type): """ Return proto.MIL_pb2.ValueType with tensor (TensorType) set. @@ -261,40 +251,6 @@ def types_to_proto_primitive(valuetype): ) return types.BUILTIN_TO_PROTO_TYPES[valuetype] - -def types_to_proto(valuetype): - if types.is_tensor(valuetype): - primitive = types_to_proto_primitive(valuetype.get_primitive()) - return create_valuetype_tensor(valuetype.get_shape(), primitive) - elif types.is_tuple(valuetype): - v_type = proto.MIL_pb2.ValueType() - t_type = v_type.tupleType - for t in valuetype.T: - new_v_type = t_type.types.add() - new_v_type.CopyFrom(types_to_proto(t)) - return v_type - elif types.is_list(valuetype): - elem = valuetype.T[0] - length = valuetype.T[1] - if types.is_tensor(elem): - dtype = types_to_proto_primitive(elem.get_primitive()) - elem_shape = elem.get_shape() - elif types.is_scalar(elem): - dtype = types_to_proto_primitive(valuetype) - elem_shape = () - elif types.is_str(elem): - dtype = types_to_proto_primitive(elem) - elem_shape = () - else: - raise NotImplementedError("Only list of either tensors or scalars supported. " - "Got element of type {}".format(elem.__type_info__())) - return create_valuetype_list(length=length, elem_shape=elem_shape, dtype=dtype) - elif types.is_dict(valuetype): - return create_valuetype_dict(valuetype.T[0], valuetype.T[1]) - else: - return create_valuetype_scalar(types_to_proto_primitive(valuetype)) - - def _get_offset_by_writing_data(output_var, blob_writer): if output_var.val.dtype.kind == 'f' and output_var.val.dtype.itemsize == 4: offset = blob_writer.write_float_data(np.ascontiguousarray(output_var.val.flatten())) diff --git a/coremltools/converters/mil/backend/mil/load.py b/coremltools/converters/mil/backend/mil/load.py index b57b590df..2d4742a2a 100644 --- a/coremltools/converters/mil/backend/mil/load.py +++ b/coremltools/converters/mil/backend/mil/load.py @@ -22,7 +22,9 @@ create_immediate_value, create_list_scalarvalue, create_scalar_value, - types_to_proto, + create_valuetype_list, + create_valuetype_scalar, + create_valuetype_tensor, types_to_proto_primitive, ) from coremltools.converters.mil.backend.nn.load import _set_optional_inputs @@ -158,7 +160,7 @@ def translate_const(self, op: Operation) -> proto.MIL_pb2.Operation: attributes={"name": create_scalar_value(op.name), "val": value}, outputs=[ proto.MIL_pb2.NamedValueType( - name=output_var.name, type=types_to_proto(output_var.sym_type) + name=output_var.name, type=self.types_to_proto(output_var.sym_type) ) ], ) @@ -190,12 +192,58 @@ def translate_constexpr(self, op: Operation) -> proto.MIL_pb2.Operation: attributes=attributes, outputs=[ proto.MIL_pb2.NamedValueType( - name=output_var.name, type=types_to_proto(output_var.sym_type) + name=output_var.name, type=self.types_to_proto(output_var.sym_type) ) for output_var in op.outputs ], ) + def create_valuetype_dict(self, key_type: type, value_type: type) -> proto.MIL_pb2.ValueType: + """ + Return proto.MIL_pb2.ValueType with dict (dictionaryType) set + """ + v_type = proto.MIL_pb2.ValueType() + v_type.dictionaryType.keyType.CopyFrom(self.types_to_proto(key_type)) + v_type.dictionaryType.valueType.CopyFrom(self.types_to_proto(value_type)) + return v_type + + def types_to_proto(self, valuetype: type) -> proto.MIL_pb2.ValueType: + """ + Return proto.MIL_pb2.ValueType from PyMIL types. + """ + if types.is_tensor(valuetype): + primitive = types_to_proto_primitive(valuetype.get_primitive()) + return create_valuetype_tensor(valuetype.get_shape(), primitive) + elif types.is_tuple(valuetype): + v_type = proto.MIL_pb2.ValueType() + t_type = v_type.tupleType + for t in valuetype.T: + new_v_type = t_type.types.add() + new_v_type.CopyFrom(self.types_to_proto(t)) + return v_type + elif types.is_list(valuetype): + elem = valuetype.T[0] + length = valuetype.T[1] + if types.is_tensor(elem): + dtype = types_to_proto_primitive(elem.get_primitive()) + elem_shape = elem.get_shape() + elif types.is_scalar(elem): + dtype = types_to_proto_primitive(valuetype) + elem_shape = () + elif types.is_str(elem): + dtype = types_to_proto_primitive(elem) + elem_shape = () + else: + raise NotImplementedError( + "Only list of either tensors or scalars supported. " + "Got element of type {}".format(elem.__type_info__()) + ) + return create_valuetype_list(length=length, elem_shape=elem_shape, dtype=dtype) + elif types.is_dict(valuetype): + return self.create_valuetype_dict(valuetype.T[0], valuetype.T[1]) + else: + return create_valuetype_scalar(types_to_proto_primitive(valuetype)) + def translate_generic_op( self, op: Operation, literal_params: Optional[List[str]] = None ) -> proto.MIL_pb2.Operation: @@ -228,7 +276,7 @@ def translate_generic_op( inputs[param_name] = args outputs = [ - proto.MIL_pb2.NamedValueType(name=v.name, type=types_to_proto(v.sym_type)) + proto.MIL_pb2.NamedValueType(name=v.name, type=self.types_to_proto(v.sym_type)) for v in op.outputs ] blocks = None @@ -311,14 +359,18 @@ def feeds_to_only_constexprs(op: Operation) -> bool: literal_params = ["begins", "ends", "end_masks"] proto_ops.append(self.translate_generic_op(op, literal_params)) else: - proto_ops.append(self.translate_generic_op(op)) + # A single pymil op might be decomposed into multiple ops + ops = self.translate_generic_op(op) + if not isinstance(ops, list): + ops = [ops] + proto_ops.extend(ops) inputs = [] if not isinstance(block, Function): # Function is subclass of Block, but function's block has no input, # and hence skipping reading the block inputs. for var in block.inputs: - proto_type = types_to_proto(var.sym_type) + proto_type = self.types_to_proto(var.sym_type) inputs.append(proto.MIL_pb2.NamedValueType(name=var.name, type=proto_type)) output_names = [v.name for v in block.outputs] return proto.MIL_pb2.Block(inputs=inputs, outputs=output_names, operations=proto_ops) @@ -331,7 +383,7 @@ def convert_function(self, function: Function, opset: str) -> proto.MIL_pb2.Func inputs = [] for name, var in function.inputs.items(): - proto_type = types_to_proto(var.sym_type) + proto_type = self.types_to_proto(var.sym_type) inputs.append(proto.MIL_pb2.NamedValueType(name=name, type=proto_type)) return proto.MIL_pb2.Function( @@ -467,6 +519,15 @@ def get_additional_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: """ return {} + @staticmethod + def _try_convert_other_input_type( + input_var: Var, input_features: List[proto.Model_pb2.FeatureDescription] + ) -> bool: + """ + Try to convert an input var with additional type. + """ + return False + def get_func_input(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDescription]: """ Utils to get function input feature description. @@ -554,7 +615,7 @@ def get_func_input(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDesc input_features.append( proto.Model_pb2.FeatureDescription(name=var.name, type=input_feature_type) ) - else: + elif not self._try_convert_other_input_type(var, input_features): raise NotImplementedError(f"Unsupported input type {var.sym_type}.") if not is_input_shape_symbolic: @@ -746,6 +807,16 @@ def get_func_output(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDes return output_features + def create_model_description( + self, + input_features: List[proto.Model_pb2.FeatureDescription], + output_features: List[proto.Model_pb2.FeatureDescription], + ) -> proto.Model_pb2.ModelDescription: + """ + Create model description from input and output features + """ + return proto.Model_pb2.ModelDescription(input=input_features, output=output_features) + def get_coreml_model( self, input: Dict[str, List[proto.Model_pb2.FeatureDescription]], @@ -758,7 +829,7 @@ def get_coreml_model( # Model description input_features = input[self._DEFAULT_FUNCTION_NAME] output_features = output[self._DEFAULT_FUNCTION_NAME] - desc = proto.Model_pb2.ModelDescription(input=input_features, output=output_features) + desc = self.create_model_description(input_features, output_features) if self.classifier_config is not None: desc.predictedFeatureName = self.predicted_feature_name diff --git a/coremltools/converters/mil/backend/nn/load.py b/coremltools/converters/mil/backend/nn/load.py index 8c825cb09..a4c449e73 100644 --- a/coremltools/converters/mil/backend/nn/load.py +++ b/coremltools/converters/mil/backend/nn/load.py @@ -12,6 +12,7 @@ ImageType, RangeDim, Shape, + TensorType, ) from coremltools.converters.mil.mil import types from coremltools.converters.mil.mil.types.symbolic import any_symbolic, any_variadic, is_symbolic @@ -169,7 +170,7 @@ def _set_optional_inputs(proto, input_types): # Set default values for optional input_types default_map = {} for input_type in input_types: - if isinstance(input_type, ImageType): + if not isinstance(input_type, TensorType): continue if input_type.default_value is not None: default_map[input_type.name] = input_type.default_value diff --git a/coremltools/converters/mil/frontend/_utils.py b/coremltools/converters/mil/frontend/_utils.py index 7f5daccdc..1a4a11a69 100644 --- a/coremltools/converters/mil/frontend/_utils.py +++ b/coremltools/converters/mil/frontend/_utils.py @@ -512,7 +512,9 @@ def _concat_dims(dims, none_if_empty=False): return ab -def _lower_scaled_dot_product_attention(q: Var, k: Var, v: Var, mask: Var, name: str) -> Var: +def _lower_scaled_dot_product_attention( + q: Var, k: Var, v: Var, mask: Var, name: str, before_op: Optional[Operation] = None +) -> Var: # scale the query input embed_size = q.shape[-1] if is_symbolic(embed_size): @@ -520,25 +522,26 @@ def _lower_scaled_dot_product_attention(q: Var, k: Var, v: Var, mask: Var, name: "The embedding size, i.e. last dimension of the shape of query tensor" " cannot be symbolic, in scaled_dot_product_attention op" ) + + q, k, v = promote_input_dtypes([q, k, v]) multiplicative_scale_factor = 1 / math.sqrt(embed_size) - q, k, v, multiplicative_scale_factor = promote_input_dtypes( - [q, k, v, multiplicative_scale_factor] - ) - q = mb.mul(x=q, y=multiplicative_scale_factor) + if types.builtin_to_string(q.dtype) == "fp16": + multiplicative_scale_factor = _np.float16(multiplicative_scale_factor) + q = mb.mul(x=q, y=multiplicative_scale_factor, before_op=before_op) # multiply query and key input tensors # shape of output: (target_seq, source_seq) or (B,...,target_seq, source_seq) - attn_weights = mb.matmul(x=q, y=k, transpose_y=True) + attn_weights = mb.matmul(x=q, y=k, transpose_y=True, before_op=before_op) # add mask if applicable if mask is not None: - attn_weights = mb.add(x=attn_weights, y=mask) + attn_weights = mb.add(x=attn_weights, y=mask, before_op=before_op) # do softmax - attn_weights_normalized = mb.softmax(x=attn_weights, axis=-1) + attn_weights_normalized = mb.softmax(x=attn_weights, axis=-1, before_op=before_op) # multiply attn_weights and value tensor - res = mb.matmul(x=attn_weights_normalized, y=v, name=name) + res = mb.matmul(x=attn_weights_normalized, y=v, name=name, before_op=before_op) return res @@ -549,7 +552,7 @@ def _construct_constexpr_affine_op( axis: Optional[Union[Var, int]] = None, name: Optional[str] = None, before_op: Optional[Operation] = None, -) -> Operation: +) -> Var: """Constructs the constexpr op to represent the dequantized weight from PyTorch's data.""" # The constexpr_affine_dequantize op requires axis. if axis is None: diff --git a/coremltools/converters/mil/frontend/torch/exir_utils.py b/coremltools/converters/mil/frontend/torch/exir_utils.py index d962f260c..9cef8728a 100644 --- a/coremltools/converters/mil/frontend/torch/exir_utils.py +++ b/coremltools/converters/mil/frontend/torch/exir_utils.py @@ -87,12 +87,27 @@ def extract_inputs_from_exir_program( val = node.meta["val"] assert isinstance(val, torch.Tensor), "placeholder val must be a tensor or fake tensor" user_inputs.append(to_coreml_tensor_type(node.name, val)) + elif input_spec.kind == torch.export.graph_signature.InputKind.PARAMETER: lifted_parameters[input_spec.arg.name] = parameters[input_spec.target] + elif input_spec.kind == torch.export.graph_signature.InputKind.BUFFER: - lifted_buffers[input_spec.arg.name] = buffers[input_spec.target] + # This is a workaround on mutable buffer: Core ML does not support stateful execution, + # so ExecuTorch will pass mutable buffers as inputs/outputs to Core ML delegation, + # then in-place copy Core ML outputs into buffers + # On Core ML side, we do not have to do anything special with outputs, + # but for inputs we will need to identify ExecuTorch lifted mutable buffers + # as Core ML user inputs + if input_spec.target in exported_program.graph_signature.buffers_to_mutate.values(): + user_inputs.append( + to_coreml_tensor_type(input_spec.arg.name, buffers[input_spec.target]) + ) + else: + lifted_buffers[input_spec.arg.name] = buffers[input_spec.target] + elif input_spec.kind == torch.export.graph_signature.InputKind.CONSTANT_TENSOR: lifted_constants[input_spec.arg.name] = exported_program.constants[input_spec.target] + else: raise NotImplementedError( "Only 4 types of inputs handled yet: user input, parameter, buffer, constant. " diff --git a/coremltools/converters/mil/frontend/torch/ops.py b/coremltools/converters/mil/frontend/torch/ops.py index 42700b6b9..a1bd39900 100644 --- a/coremltools/converters/mil/frontend/torch/ops.py +++ b/coremltools/converters/mil/frontend/torch/ops.py @@ -1678,7 +1678,7 @@ def view(context, node): x = inputs[0] shape = inputs[1] - if np.prod(shape.shape) == 0: + if isinstance(shape, Var) and np.prod(shape.shape) == 0: # Reshape to empty shape (works only for scalar) is a no op assert ( np.prod(x.shape) <= 1 @@ -6694,21 +6694,15 @@ def _get_causal_attn_mask(is_causal: bool, query_var: Var, key_var: Var) -> Var: def _cast_bool_attn_mask(attn_mask: Var, query_var: Var) -> Var: """ - compute float mask as: - mask = cast(bool_mask) + (1-cast(bool_mask)) * -30k*ones(shape(bool_mask)) + compute float mask as (1 - cast(bool_mask)) * -30k """ assert is_bool(attn_mask.dtype) - shape = mb.shape(x=attn_mask) - negative_inf = mb.fill( - shape=shape, value=_np.array([-3e4]).astype(types.nptype_from_builtin(query_var.dtype)) - ) mask = mb.cast(x=attn_mask, dtype=types.builtin_to_string(query_var.dtype)) compliment_of_mask = mb.sub( x=_np.array([1.0]).astype(types.nptype_from_builtin(mask.dtype)), y=mask ) - compliment_of_mask = mb.mul(x=negative_inf, y=compliment_of_mask) - return mb.add(x=mask, y=compliment_of_mask) + return mb.mul(x=-3e4, y=compliment_of_mask) @register_torch_op def scaled_dot_product_attention(context, node): diff --git a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py index fd4513648..e9b024fff 100644 --- a/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py +++ b/coremltools/converters/mil/frontend/torch/test/test_torch_ops.py @@ -2894,10 +2894,11 @@ def test_adaptive_avg_pool2d( class TestMaxPool(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, input_shape, kernel_size, stride, padding, ceil_mode", + "compute_unit, backend, frontend, input_shape, kernel_size, stride, padding, ceil_mode", itertools.product( compute_units, backends, + frontends, [(1, 3, 15), (1, 1, 7)], [1, 3], [1, 2], @@ -2909,6 +2910,7 @@ def test_max_pool1d( self, compute_unit, backend, + frontend, input_shape, kernel_size, stride, @@ -2932,14 +2934,15 @@ def test_max_pool1d( ceil_mode=ceil_mode, ) self.run_compare_torch( - input_shape, model, backend=backend, compute_unit=compute_unit + input_shape, model, frontend=frontend, backend=backend, compute_unit=compute_unit ) @pytest.mark.parametrize( - "compute_unit, backend, input_shape, kernel_size, stride, padding, ceil_mode", + "compute_unit, backend, frontend, input_shape, kernel_size, stride, padding, ceil_mode", itertools.product( compute_units, backends, + frontends, [(1, 3, 15, 15), (1, 1, 7, 7)], [1, 3], [1, 2], @@ -2951,6 +2954,7 @@ def test_max_pool2d( self, compute_unit, backend, + frontend, input_shape, kernel_size, stride, @@ -2975,14 +2979,15 @@ def test_max_pool2d( ceil_mode=ceil_mode, ) self.run_compare_torch( - input_shape, model, backend=backend, compute_unit=compute_unit + input_shape, model, frontend=frontend, backend=backend, compute_unit=compute_unit ) @pytest.mark.parametrize( - "compute_unit, backend, input_shape, kernel_size, stride, padding, ceil_mode", + "compute_unit, backend, frontend, input_shape, kernel_size, stride, padding, ceil_mode", itertools.product( compute_units, backends, + frontends, [(1, 3, 11, 3, 11), (1, 1, 7, 4, 7)], [1, 3], [1, 2], @@ -2994,12 +2999,16 @@ def test_max_pool3d( self, compute_unit, backend, + frontend, input_shape, kernel_size, stride, padding, ceil_mode, ): + if frontend == TorchFrontend.EXIR: + pytest.xfail("TODO (rdar://115846125): handle multi-output op max_pool3d_with_indices") + if padding > kernel_size / 2: return if ceil_mode > 0 and padding == 0 and kernel_size == 1 and stride == 2: @@ -3018,16 +3027,17 @@ def test_max_pool3d( ceil_mode=ceil_mode, ) self.run_compare_torch( - input_shape, model, backend=backend, compute_unit=compute_unit + input_shape, model, frontend=frontend, backend=backend, compute_unit=compute_unit ) class TestMaximumMinimum(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, input_shapes, mode", + "compute_unit, backend, frontend, input_shapes, mode", itertools.product( compute_units, backends, + frontends, [ [(2, 5, 7, 3), (2, 5, 7, 3)], [(3, 2, 9), (3, 2, 9)], @@ -3038,7 +3048,7 @@ class TestMaximumMinimum(TorchBaseTest): ["minimum", "maximum"], ), ) - def test_minimum_maximum(self, compute_unit, backend, input_shapes, mode): + def test_minimum_maximum(self, compute_unit, backend, frontend, input_shapes, mode): class TestModel(torch.nn.Module): def forward(self, x, y): if mode == "minimum": @@ -3049,14 +3059,15 @@ def forward(self, x, y): raise ValueError("Unsupported mode: {mode}".format(mode=mode)) self.run_compare_torch( - input_shapes, TestModel(), backend=backend, compute_unit=compute_unit + input_shapes, TestModel(), frontend=frontend, backend=backend, compute_unit=compute_unit ) @pytest.mark.parametrize( - "compute_unit, backend, input_shapes, mode, xdtype, ydtype", + "compute_unit, backend, frontend, input_shapes, mode, xdtype, ydtype", itertools.product( compute_units, backends, + frontends, [ [(2, 5, 7, 3), (2, 5, 7, 3)], [(3, 2, 9), (3, 2, 9)], @@ -3070,7 +3081,7 @@ def forward(self, x, y): ), ) def test_minimum_maximum_mixed_precision( - self, compute_unit, backend, input_shapes, mode, xdtype, ydtype + self, compute_unit, backend, frontend, input_shapes, mode, xdtype, ydtype ): class TestModel(torch.nn.Module): def forward(self, x, y): @@ -3086,12 +3097,14 @@ def forward(self, x, y): self.run_compare_torch( input_shapes, TestModel(), + frontend=frontend, compute_unit=compute_unit, backend=backend, rtol=1e-6 if xdtype == ydtype and xdtype == torch.float32 else 1e-3, atol=1e-6 if xdtype == ydtype and xdtype == torch.float32 else 1e-3, ) + class TestAMaxAMin(TorchBaseTest): @pytest.mark.parametrize( "compute_unit, backend, input_shapes, mode, reduce_dim, keepdim", @@ -6312,16 +6325,12 @@ def forward(self, x): class TestSlice(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, start, end, step", + "compute_unit, backend, frontend, start, end, step", itertools.product( - compute_units, - backends, - (0, -5, None), - (7, -1, 100, None), - (1, 2, None) + compute_units, backends, frontends, (0, -5, None), (7, -1, 100, None), (1, 2, None) ), ) - def test_slice(self, compute_unit, backend, start, end, step): + def test_slice(self, compute_unit, backend, frontend, start, end, step): class SliceModel(torch.nn.Module): def forward(self, x): y = x[start : end : step] @@ -6331,18 +6340,25 @@ def forward(self, x): model.eval() self.run_compare_torch( - (9,), model, backend=backend, compute_unit=compute_unit + (9,), model, frontend=frontend, backend=backend, compute_unit=compute_unit ) @pytest.mark.skipif(_python_version() < (3, 6), reason="requires python 3.6") @pytest.mark.parametrize( - "compute_unit, backend", + "compute_unit, backend, frontend", itertools.product( compute_units, backends, + frontends, ), ) - def test_dynamic_slice(self, compute_unit, backend): + def test_dynamic_slice(self, compute_unit, backend, frontend): + if frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2189: " + "torch.export Cannot Use Dynamic Index to Slice" + ) + class DynamicSlicer(torch.nn.Module): def forward(self, x, context_length): return x[context_length:, :, :] @@ -6374,7 +6390,12 @@ def forward(self, tokens, context, context_length): TensorType(name="context_length", shape=(1,), dtype=np.int32), ] self.run_compare_torch( - inputs, model, rand_range=(0, 8), backend=backend, compute_unit=compute_unit + inputs, + model, + rand_range=(0, 8), + frontend=frontend, + backend=backend, + compute_unit=compute_unit, ) @@ -7306,10 +7327,10 @@ def forward(self, x, y): class TestWhere(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, shape", - itertools.product(compute_units, backends, [(2, 6), (3, 4, 5)]), + "compute_unit, backend, frontend, shape", + itertools.product(compute_units, backends, frontends, [(2, 6), (3, 4, 5)]), ) - def test_where_test1(self, compute_unit, backend, shape): + def test_where_test1(self, compute_unit, backend, frontend, shape): class WhereModel(nn.Module): def forward(self, x, y): return torch.where(x > 0.5, x, y) @@ -7317,14 +7338,14 @@ def forward(self, x, y): input_shape = [shape, shape] model = WhereModel() self.run_compare_torch( - input_shape, model, backend=backend, compute_unit=compute_unit + input_shape, model, frontend=frontend, backend=backend, compute_unit=compute_unit ) @pytest.mark.parametrize( - "compute_unit, backend, shape", - itertools.product(compute_units, backends, [(2, 6), (3, 4, 5)]), + "compute_unit, backend, frontend, shape", + itertools.product(compute_units, backends, frontends, [(2, 6), (3, 4, 5)]), ) - def test_where_test2(self, compute_unit, backend, shape): + def test_where_test2(self, compute_unit, backend, frontend, shape): class WhereModel(nn.Module): def forward(self, cond, x, y): return torch.where(cond, x, y) @@ -7336,6 +7357,7 @@ def forward(self, cond, x, y): self.run_compare_torch( inputs, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, expected_results=expected_results, @@ -7343,17 +7365,18 @@ def forward(self, cond, x, y): ) @pytest.mark.parametrize( - "compute_unit, backend, shapes", + "compute_unit, backend, frontend, shapes", itertools.product( compute_units, backends, + frontends, [ [(1, 2), (1, 2), (1, 1)], [(1, 2, 3), (1, 1, 1), (1, 1, 3)], ], ), ) - def test_where_test3(self, compute_unit, backend, shapes): + def test_where_test3(self, compute_unit, backend, frontend, shapes): class WhereModel(nn.Module): def forward(self, cond, x, y): return torch.where(cond, x, y) @@ -7366,6 +7389,7 @@ def forward(self, cond, x, y): self.run_compare_torch( inputs, model, + frontend=frontend, backend=backend, compute_unit=compute_unit, expected_results=expected_results, @@ -7373,10 +7397,11 @@ def forward(self, cond, x, y): ) @pytest.mark.parametrize( - "compute_unit, backend, shapes, xdtype, ydtype", + "compute_unit, backend, frontend, shapes, xdtype, ydtype", itertools.product( compute_units, backends, + frontends, [ [(1, 2), (1, 2), (1, 1)], [(1, 2, 3), (1, 2, 1), (1, 1, 3)], @@ -7385,7 +7410,7 @@ def forward(self, cond, x, y): (torch.float16, torch.float32), ), ) - def test_where_mixed_precision(self, compute_unit, backend, shapes, xdtype, ydtype): + def test_where_mixed_precision(self, compute_unit, backend, frontend, shapes, xdtype, ydtype): class WhereModel(nn.Module): def forward(self, cond, x, y): a = x.to(xdtype) @@ -7400,6 +7425,7 @@ def forward(self, cond, x, y): inputs, WhereModel(), compute_unit=compute_unit, + frontend=frontend, backend=backend, input_as_shape=False, rtol=1e-6 if xdtype == ydtype and xdtype == torch.float32 else 1e-3, @@ -7407,10 +7433,16 @@ def forward(self, cond, x, y): ) @pytest.mark.parametrize( - "compute_unit, backend, shape", - itertools.product(compute_units, backends, COMMON_SHAPES + [(10,)]), + "compute_unit, backend, frontend, shape", + itertools.product(compute_units, backends, frontends, COMMON_SHAPES + [(10,)]), ) - def test_where_single_param(self, compute_unit, backend, shape): + def test_where_single_param(self, compute_unit, backend, frontend, shape): + if frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2183: " + "Operator torch._ops.aten._assert_async.msg is not Aten Canonical" + ) + class WhereModelSingleParam(nn.Module): def forward(self, x): return torch.where(x) @@ -7429,6 +7461,7 @@ def forward(self, x): self.run_compare_torch( x, WhereModelSingleParam(), + frontend=frontend, backend=backend, input_as_shape=False, compute_unit=compute_unit, @@ -7953,7 +7986,7 @@ def test_index_put_bool_index_case_1(self, compute_unit, backend, frontend, mini "https://github.com/apple/coremltools/issues/2183: " "Operator torch._ops.aten._assert_async.msg is not Aten Canonical" ) - + class IndexPutModel(torch.nn.Module): def forward(self, x, y): y = x + 1 @@ -9392,14 +9425,15 @@ def forward(self, a, b, c, d): class TestEmbedding(TorchBaseTest): @pytest.mark.parametrize( - "compute_unit, backend, input_dtype", + "compute_unit, backend, frontend, input_dtype", itertools.product( compute_units, backends, + frontends, [np.int32, np.float32], ), ) - def test_embedding(self, compute_unit, backend, input_dtype): + def test_embedding(self, compute_unit, backend, frontend, input_dtype): num_embeddings = 4 embedding_size = 10 B = 2 @@ -9426,6 +9460,7 @@ def forward(self, x): model, expected_results=expected_results, input_as_shape=False, + frontend=frontend, backend=backend, compute_unit=compute_unit, converter_input_type=converter_input_type, @@ -10765,6 +10800,7 @@ def forward(self, x): input_as_shape=False, use_scripting=True, backend=backend, compute_unit=compute_unit) + class TestScaledDotProductAttention(TorchBaseTest): """ Tests for torch.nn.functional.scaled_dot_product_attention op @@ -10772,15 +10808,17 @@ class TestScaledDotProductAttention(TorchBaseTest): """ @pytest.mark.parametrize( - "compute_unit, backend, rank", + "compute_unit, backend, frontend, rank, dynamic", itertools.product( compute_units, backends, + frontends, [2, 3, 4, 5], + [True, False], ), ) def test_different_input_ranks_no_mask( - self, compute_unit, backend, rank, minimum_deployment_target=None + self, compute_unit, backend, frontend, rank, dynamic, minimum_deployment_target=None ): """ The query/key/value inputs can be any rank 2 or greater. @@ -10806,26 +10844,52 @@ def test_different_input_ranks_no_mask( }, ) + if dynamic: + converter_input_type = [ + ct.TensorType( + shape=(ct.RangeDim(upper_bound=10, default=batch_size),) + input_shape[1:] + ) + for _ in range(3) + ] + else: + converter_input_type = None + return self.run_compare_torch( [input_shape] * 3, model, + frontend=frontend, backend=backend, + converter_input_type=converter_input_type, compute_unit=compute_unit, minimum_deployment_target=minimum_deployment_target, )[1] @pytest.mark.parametrize( - "compute_unit, backend, seq_lengths, include_heads", + "compute_unit, backend, frontend, seq_lengths, include_heads, dynamic", itertools.product( compute_units, backends, + frontends, [(5, 5), (5, 7), (6, 4)], [False, True], + [True, False], ), ) def test_is_causal_flag( - self, compute_unit, backend, seq_lengths, include_heads, minimum_deployment_target=None + self, + compute_unit, + backend, + frontend, + seq_lengths, + include_heads, + dynamic, + minimum_deployment_target=None, ): + if frontend == TorchFrontend.EXIR: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2199: placeholder assertion error" + ) + source_seq_len, target_seq_len = seq_lengths query_shape = (2, 2, target_seq_len, 7) if include_heads else (2, target_seq_len, 7) key_shape = (2, 2, source_seq_len, 7) if include_heads else (2, source_seq_len, 7) @@ -10838,10 +10902,23 @@ def test_is_causal_flag( "is_causal": True, }, ) + + if dynamic: + converter_input_type = [ + ct.TensorType( + shape=(ct.RangeDim(upper_bound=10, default=input_shape[0]),) + input_shape[1:] + ) + for input_shape in [query_shape, key_shape, value_shape] + ] + else: + converter_input_type = None + res = self.run_compare_torch( [query_shape, key_shape, value_shape], model, + frontend=frontend, backend=backend, + converter_input_type=converter_input_type, compute_unit=compute_unit, minimum_deployment_target=minimum_deployment_target, ) @@ -10852,21 +10929,31 @@ def test_is_causal_flag( assert len(mil_prog.find_ops(op_type="band_part")) == 0 @pytest.mark.parametrize( - "compute_unit, backend, seq_lengths, bool_mask", + "compute_unit, backend, frontend, seq_lengths, bool_mask, dynamic", itertools.product( compute_units, backends, + frontends, [(5, 5), (7, 5)], [False, True], + [False, True], ), ) def test_attn_mask( - self, compute_unit, backend, seq_lengths, bool_mask, minimum_deployment_target=None + self, + compute_unit, + backend, + frontend, + seq_lengths, + bool_mask, + dynamic, + minimum_deployment_target=None, ): - if bool_mask: + if frontend == TorchFrontend.TORCHSCRIPT and bool_mask: pytest.xfail( "rdar://110499660 ([CI][Bug] test_attn_mask is occasionally failing when bool_mask = True)" ) + source_seq_len, target_seq_len = seq_lengths query_shape = (2, 3, target_seq_len, 7) key_shape = (2, 3, source_seq_len, 7) @@ -10883,26 +10970,53 @@ def test_attn_mask( mask = generate_input_data(mask_shape) model = ModuleWrapper(function=nn.functional.scaled_dot_product_attention) + + if dynamic: + converter_input_type = [ + ct.TensorType( + shape=(ct.RangeDim(upper_bound=10, default=input_data.shape[0]),) + + input_data.shape[1:] + ) + for input_data in [query, key, value, mask] + ] + else: + converter_input_type = None + self.run_compare_torch( (query, key, value, mask), model, + frontend=frontend, backend=backend, + converter_input_type=converter_input_type, compute_unit=compute_unit, minimum_deployment_target=minimum_deployment_target, input_as_shape=False, ) @pytest.mark.parametrize( - "compute_unit, backend, mask_as_input", + "compute_unit, backend, frontend, mask_as_input, dynamic", itertools.product( compute_units, backends, + frontends, + [True, False], [True, False], ), ) def test_toy_xformer_with_sdpa( - self, compute_unit, backend, mask_as_input, minimum_deployment_target=None + self, + compute_unit, + backend, + frontend, + mask_as_input, + dynamic, + minimum_deployment_target=None, ): + if frontend == TorchFrontend.EXIR and not mask_as_input: + pytest.xfail( + "https://github.com/apple/coremltools/issues/2199: placeholder assertion error" + ) + embedding_size = 32 seq_length = 16 n_heads = 4 @@ -10985,11 +11099,27 @@ def forward(self, x, mask=None): return x model = ToyTransformer() - self.run_compare_torch( + + input_shapes = ( [(batch_size, seq_length, embedding_size), (seq_length, seq_length)] if mask_as_input - else [(batch_size, seq_length, embedding_size)], + else [(batch_size, seq_length, embedding_size)] + ) + if dynamic: + converter_input_type = [ + ct.TensorType( + shape=(ct.RangeDim(upper_bound=16, default=input_shape[0]),) + input_shape[1:] + ) + for input_shape in input_shapes + ] + else: + converter_input_type = None + + self.run_compare_torch( + input_shapes, model, + converter_input_type=converter_input_type, + frontend=frontend, backend=backend, compute_unit=compute_unit, minimum_deployment_target=minimum_deployment_target, diff --git a/coremltools/converters/mil/mil/passes/defs/optimize_quantization.py b/coremltools/converters/mil/mil/passes/defs/optimize_quantization.py index cd66e421e..47ac7ac2b 100644 --- a/coremltools/converters/mil/mil/passes/defs/optimize_quantization.py +++ b/coremltools/converters/mil/mil/passes/defs/optimize_quantization.py @@ -26,23 +26,28 @@ class merge_affine_dequantize_with_consecutive_ops(AbstractGraphPass): """ This graph pass does const folding to a chain of supported ops starts with a ``constexpr_affine_dequantize`` op. More types of op are supported when quantization - is tensor-wise, and only a subset is supported for channel-wise + is tensor-wise, and only a subset is supported for channel-wise. For example - For example: - Input graph: - data -> constexpr_affine_dequantize -> transpose -> expand_dims -> out + .. code-block:: + + Input graph: + data -> constexpr_affine_dequantize -> transpose -> expand_dims -> out - Output graph: - new_data -> constexpr_affine_dequantize -> out + Output graph: + new_data -> constexpr_affine_dequantize -> out where ``new_data`` is computed by ``data -> transpose -> expand_dims``. Note that, the graph pass only supports const folding of a single linked list pattern. - For example, the following pattern will not be changed: + For example, the following pattern will not be changed + + .. code-block:: + + |-> constexpr_affine_dequantize -> transpose -> out + data -| + |-> constexpr_affine_dequantize -> reshape -> out_2 - data ---> constexpr_affine_dequantize -> transpose -> out - | - --> constexpr_affine_dequantize -> reshape -> out_2 + since the quantized data is used by multiple ``constexpr`` """ SUPPORTED_OP_TYPES_PER_TENSOR = { diff --git a/coremltools/converters/mil/mil/program.py b/coremltools/converters/mil/mil/program.py index 9a1da7cf8..2d142e9e4 100644 --- a/coremltools/converters/mil/mil/program.py +++ b/coremltools/converters/mil/mil/program.py @@ -252,11 +252,15 @@ def construct_debug_handle_to_ops_mapping(self) -> Dict: """ debug_handle_to_ops_mapping = {} for function_name, function in self.functions.items(): + if ScopeSource.EXIR_DEBUG_HANDLE not in function._essential_scope_sources: + raise NotImplementedError( + f"Function ({function_name}) must have EXIR_DEBUG_HANDLE as an essential scope source." + ) for operation in function.operations: # TODO (rdar://115846569): Handle multi-block case from EXIR if len(operation.blocks) > 0: raise NotImplementedError("Multi-block case has not been supported yet") - debug_handle = operation.scopes.get(ScopeSource.EXIR_DEBUG_HANDLE) + debug_handle = operation.scopes[ScopeSource.EXIR_DEBUG_HANDLE] if debug_handle is None: continue debug_handle = debug_handle[0] diff --git a/coremltools/converters/mil/mil/types/type_tensor.py b/coremltools/converters/mil/mil/types/type_tensor.py index bee987b1d..6782cbf6a 100644 --- a/coremltools/converters/mil/mil/types/type_tensor.py +++ b/coremltools/converters/mil/mil/types/type_tensor.py @@ -9,9 +9,15 @@ from coremltools import _logger as logger from .get_type_info import get_type_info -from .type_mapping import (builtin_to_string, is_subtype, is_tensor, - nptype_from_builtin, numpy_type_to_builtin_type, - promote_types) +from .symbolic import is_symbolic +from .type_mapping import ( + builtin_to_string, + is_subtype, + is_tensor, + nptype_from_builtin, + numpy_type_to_builtin_type, + promote_types, +) from .type_spec import Type @@ -184,6 +190,8 @@ def is_tensor_and_is_compatible(tensor_type1, tensor_type2, allow_promotion=Fals most_specific_shape.append(shape1[i]) elif shape1[i] == shape2[i]: most_specific_shape.append(shape1[i]) + elif is_symbolic(shape1[i]) or is_symbolic(shape2[i]): + most_specific_shape.append(shape1[i] if is_symbolic(shape2[i]) else shape2[i]) elif shape1[i] != shape2[i]: return False, None diff --git a/coremltools/models/model.py b/coremltools/models/model.py index 50fe739bf..ac3dcebbc 100644 --- a/coremltools/models/model.py +++ b/coremltools/models/model.py @@ -21,6 +21,7 @@ from coremltools import proto as _proto from coremltools._deps import _HAS_TF_1, _HAS_TF_2, _HAS_TORCH from coremltools.converters.mil.mil.program import Program as _Program +from coremltools.converters.mil.mil.scope import ScopeSource as _ScopeSource from .utils import ( _MLMODEL_EXTENSION, @@ -518,7 +519,11 @@ def save(self, save_path: str): ) _shutil.copytree(self.package_path, save_path) - if self._mil_program is not None: + if self._mil_program is not None and all( + [ + _ScopeSource.EXIR_DEBUG_HANDLE in function._essential_scope_sources for function in self._mil_program.functions.values() + ] + ): debug_handle_to_ops_mapping = ( self._mil_program.construct_debug_handle_to_ops_mapping() ) diff --git a/coremltools/optimize/coreml/_quantization_passes.py b/coremltools/optimize/coreml/_quantization_passes.py index 0338c2e4c..a87160d8d 100644 --- a/coremltools/optimize/coreml/_quantization_passes.py +++ b/coremltools/optimize/coreml/_quantization_passes.py @@ -37,29 +37,6 @@ OptimizationConfig, ) -""" --------------------------------- -Compression parameters wrapper - --------------------------------- -""" -class SparseParams: - def __init__(self, nonzero_data=None, mask=None, shape=None): - self.nonzero_data = nonzero_data - self.mask = mask - self.shape = shape - -class LutParams: - def __init__(self, lut=None, indices=None, shape=None): - self.lut = lut - self.indices = indices - self.shape = shape - -class AffineQuantParams: - def __init__(self, quantized_data=None, zero_point=None, scale=None, axis=None): - self.quantized_data = quantized_data - self.zero_point = zero_point - self.scale = scale - self.axis = axis """ ------------------------ @@ -254,10 +231,11 @@ class prune_weights(AbstractCompressionPass): @staticmethod def _pack_val_to_sparse_param(val): flattened_val = val.flatten() - params = SparseParams() - params.nonzero_data = flattened_val[np.where(flattened_val != 0)] - params.mask = np.packbits(np.where(flattened_val != 0, 1, 0), bitorder="little") - params.shape = val.shape + params = _utils.SparseParams( + nonzero_data=flattened_val[np.where(flattened_val != 0)], + mask=np.packbits(np.where(flattened_val != 0, 1, 0), bitorder="little"), + shape=val.shape, + ) return params @staticmethod @@ -399,12 +377,12 @@ def compress_by_nm_sparsity(val, n_m_ratio, dim): @staticmethod def decompress(params): - if not isinstance(params, SparseParams): + if not isinstance(params, _utils.SparseParams): raise ValueError("Invalid type of params") return constexpr_sparse_to_dense.decompress(params.nonzero_data, params.mask, params.shape) @staticmethod - def _create_constexpr_var(op: Operation, sparse_params: SparseParams) -> Var: + def _create_constexpr_var(op: Operation, sparse_params: _utils.SparseParams) -> Var: return mb.constexpr_sparse_to_dense( nonzero_data=sparse_params.nonzero_data, mask=sparse_params.mask, @@ -574,7 +552,7 @@ def compress_unique(val, nbits): return lut, indices @staticmethod - def compress(val, mode, nbits=None, lut_function=None) -> LutParams: + def compress(val, mode, nbits=None, lut_function=None) -> _utils.LutParams: def check_lut_parameters_are_valid(val, lut, indices): if not isinstance(lut, np.ndarray) or not isinstance(indices, np.ndarray): raise ValueError("LUT and indices must be type of numpy array.") @@ -604,20 +582,21 @@ def check_lut_parameters_are_valid(val, lut, indices): check_lut_parameters_are_valid(val, lut, indices) - params = LutParams() - params.lut = lut - params.shape = val.shape - params.indices = pack_elements_into_bits(indices, int(np.log2(lut.shape[0]))) + params = _utils.LutParams( + lut=lut, + indices=pack_elements_into_bits(indices, int(np.log2(lut.shape[0]))), + shape=val.shape, + ) return params @staticmethod def decompress(params): - if not isinstance(params, LutParams): + if not isinstance(params, _utils.LutParams): raise ValueError("Invalid type of params") return constexpr_lut_to_dense.decompress(params.lut, params.indices, params.shape) @staticmethod - def _create_constexpr_var(op: Operation, lut_params: LutParams) -> Var: + def _create_constexpr_var(op: Operation, lut_params: _utils.LutParams) -> Var: return mb.constexpr_lut_to_dense( indices=lut_params.indices, lut=lut_params.lut, @@ -742,7 +721,9 @@ def _get_quantized_data( return quantized_data, scale, zero_point @classmethod - def compress(cls, val: np.ndarray, axis: int, mode: str, dtype: type) -> AffineQuantParams: + def compress( + cls, val: np.ndarray, axis: int, mode: str, dtype: type + ) -> _utils.AffineQuantParams: if not isinstance(val, (np.ndarray, np.generic)): raise ValueError("Only numpy arrays are supported") if isinstance(dtype, np.dtype): @@ -763,11 +744,11 @@ def compress(cls, val: np.ndarray, axis: int, mode: str, dtype: type) -> AffineQ if zero_point is None: # The iOS16 constexpr_affine_dequantize op requires zero_point. zero_point = np.zeros_like(scale).astype(quantized_data.dtype) - return AffineQuantParams(quantized_data, zero_point, scale, axis) + return _utils.AffineQuantParams(quantized_data, zero_point, scale, axis) @staticmethod - def decompress(params: AffineQuantParams) -> np.ndarray: - if not isinstance(params, AffineQuantParams): + def decompress(params: _utils.AffineQuantParams) -> np.ndarray: + if not isinstance(params, _utils.AffineQuantParams): raise ValueError("Invalid type of params") return constexpr_affine_dequantize.decompress( params.quantized_data, params.zero_point, params.scale, params.axis diff --git a/coremltools/optimize/coreml/_utils.py b/coremltools/optimize/coreml/_utils.py index 75a953419..347ee343e 100644 --- a/coremltools/optimize/coreml/_utils.py +++ b/coremltools/optimize/coreml/_utils.py @@ -3,8 +3,14 @@ # Use of this source code is governed by a BSD-3-clause license that can be # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause +from collections import namedtuple +from typing import Optional, Tuple + import numpy as np -from typing import Tuple, Optional + +SparseParams = namedtuple("SparseParams", "nonzero_data mask shape") +LutParams = namedtuple("LutParams", "lut indices shape") +AffineQuantParams = namedtuple("AffineQuantParams", "quantized_data zero_point scale axis") def get_quant_range(n_bits: int, signed: bool, mode: str) -> Tuple[int, int]: diff --git a/coremltools/test/optimize/coreml/test_post_training_quantization.py b/coremltools/test/optimize/coreml/test_post_training_quantization.py index e799fcec5..dacae90e2 100644 --- a/coremltools/test/optimize/coreml/test_post_training_quantization.py +++ b/coremltools/test/optimize/coreml/test_post_training_quantization.py @@ -4,6 +4,7 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause import itertools +from typing import Tuple import numpy as np import pytest @@ -15,6 +16,7 @@ from coremltools.converters.mil.mil import Builder as mb from coremltools.converters.mil.mil import types from coremltools.converters.mil.testing_utils import get_op_types_in_program +from coremltools.optimize.coreml import _utils as optimize_utils from coremltools.optimize.coreml._post_training_quantization import CoreMLWeightMetaData from coremltools.test.ml_program.test_compression import get_test_model_and_data @@ -114,6 +116,28 @@ def create_sparse_weight(weight, target_sparsity): return np.reshape(weight, shape).astype(np.float32) +def create_quantize_friendly_weight( + weight: np.ndarray, nbits: int, signed: bool +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Create quantize friendly weight by first quantize and then de-quantize the weight.""" + axes = tuple(axis for axis in range(len(weight.shape)) if axis != 0) + quantized_weight, scale, zero_point = optimize_utils.quantize_weight( + weight, + axes, + nbits, + signed, + quantization_mode="LINEAR", + dtype=np.int8 if signed else np.uint8, + ) + scale_shape = scale.shape + tuple([1] * len(axes)) + scale = scale.reshape(scale_shape) + zero_point = zero_point.reshape(scale_shape) + dequantized_weight = scale * ( + quantized_weight.astype(np.float32) - zero_point.astype(np.float32) + ) + return dequantized_weight, scale, zero_point + + def verify_model_outputs(model, compressed_model, input_values, rtol=1e-7, atol=0): """ This utility functions does the following checks: @@ -551,6 +575,12 @@ def test_weight_pruning_threshold_based(threshold): ) def test_weight_pruning_percentile_based(percentile): model, inputs, torch_input_values, coreml_input_values = get_test_model_and_data() + # Make sure no weight element is randomed to 0, to eliminate testing noise + # e.g. in percentile 0 test case, we would expect no element gets pruned + # if there is no 0 in initial weight + with torch.no_grad(): + non0_weight = torch.where(torch.abs(model.weight) > 1e-6, model.weight, 1e-6) + model.weight.copy_(non0_weight) torchmodel = torch.jit.trace(model, torch_input_values) mlmodel = ct.convert(torchmodel, inputs=inputs, convert_to="mlprogram") mlmodel_sparsified = prune_weights(mlmodel, mode="percentile_based", target_sparsity=percentile) @@ -567,7 +597,10 @@ def test_weight_pruning_percentile_based(percentile): if percentile == 0.: assert non_sparse_data.val.size == weight.size elif percentile == 0.5: - assert non_sparse_data.val.size <= 0.51 * (weight.size) and non_sparse_data.val.size >= 0.49 * (weight.size) + lower = 0.49 * weight.size + upper = 0.51 * weight.size + actual = non_sparse_data.val.size + assert lower <= actual and actual <= upper else: assert non_sparse_data.val.size == 0 diff --git a/coremltools/version.py b/coremltools/version.py index a5f2975a9..747eb5b28 100644 --- a/coremltools/version.py +++ b/coremltools/version.py @@ -4,4 +4,4 @@ # found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause -__version__ = "7.1.2" # VERSION_STRING +__version__ = "7.2" # VERSION_STRING diff --git a/reqs/docs.pip b/reqs/docs.pip index 2cec67b27..403de41cc 100644 --- a/reqs/docs.pip +++ b/reqs/docs.pip @@ -1,7 +1,7 @@ Babel MarkupSafe Pygments -Sphinx +Sphinx==7.3.1 alabaster certifi chardet