7.2 Release #2196

Merged
merged 1 commit into from Apr 19, 2024
44 changes: 0 additions & 44 deletions coremltools/converters/mil/backend/mil/helper.py
@@ -42,16 +42,6 @@ def create_valuetype_list(length, elem_shape, dtype):
update_listtype(v_type.listType, length, elem_shape, dtype)
return v_type

def create_valuetype_dict(key_type, value_type):
"""
Return proto.MIL_pb2.ValueType with dict (dictionaryType) set
"""
v_type = proto.MIL_pb2.ValueType()
v_type.dictionaryType.keyType.CopyFrom(types_to_proto(key_type))
v_type.dictionaryType.valueType.CopyFrom(types_to_proto(value_type))
return v_type


def create_valuetype_tensor(shape, data_type):
"""
Return proto.MIL_pb2.ValueType with tensor (TensorType) set.
@@ -261,40 +251,6 @@ def types_to_proto_primitive(valuetype):
)
return types.BUILTIN_TO_PROTO_TYPES[valuetype]


def types_to_proto(valuetype):
if types.is_tensor(valuetype):
primitive = types_to_proto_primitive(valuetype.get_primitive())
return create_valuetype_tensor(valuetype.get_shape(), primitive)
elif types.is_tuple(valuetype):
v_type = proto.MIL_pb2.ValueType()
t_type = v_type.tupleType
for t in valuetype.T:
new_v_type = t_type.types.add()
new_v_type.CopyFrom(types_to_proto(t))
return v_type
elif types.is_list(valuetype):
elem = valuetype.T[0]
length = valuetype.T[1]
if types.is_tensor(elem):
dtype = types_to_proto_primitive(elem.get_primitive())
elem_shape = elem.get_shape()
elif types.is_scalar(elem):
dtype = types_to_proto_primitive(valuetype)
elem_shape = ()
elif types.is_str(elem):
dtype = types_to_proto_primitive(elem)
elem_shape = ()
else:
raise NotImplementedError("Only list of either tensors or scalars supported. "
"Got element of type {}".format(elem.__type_info__()))
return create_valuetype_list(length=length, elem_shape=elem_shape, dtype=dtype)
elif types.is_dict(valuetype):
return create_valuetype_dict(valuetype.T[0], valuetype.T[1])
else:
return create_valuetype_scalar(types_to_proto_primitive(valuetype))


def _get_offset_by_writing_data(output_var, blob_writer):
if output_var.val.dtype.kind == 'f' and output_var.val.dtype.itemsize == 4:
offset = blob_writer.write_float_data(np.ascontiguousarray(output_var.val.flatten()))
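The net effect of this hunk: `create_valuetype_dict` and `types_to_proto` leave helper.py as free functions and reappear in load.py (next file) as methods on the exporter class, presumably so backend subclasses can override type translation. A toy sketch of that design move, using plain dicts in place of the real `proto.MIL_pb2.ValueType` objects:

```python
# Illustrative only: plain dicts stand in for proto.MIL_pb2.ValueType.
# Moving a module-level helper onto a class turns it into an override point.

class BaseLoader:
    def types_to_proto(self, valuetype: str) -> dict:
        # Base dispatch; recursive cases would call self.types_to_proto,
        # so subclass overrides apply to nested types too.
        if valuetype == "tensor":
            return {"kind": "tensorType"}
        return {"kind": "scalarType"}

class CustomBackendLoader(BaseLoader):
    def types_to_proto(self, valuetype: str) -> dict:
        # Handle an extra type, then defer to the base implementation.
        if valuetype == "custom":
            return {"kind": "customType"}
        return super().types_to_proto(valuetype)

print(CustomBackendLoader().types_to_proto("custom"))  # {'kind': 'customType'}
print(CustomBackendLoader().types_to_proto("tensor"))  # {'kind': 'tensorType'}
```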
89 changes: 80 additions & 9 deletions coremltools/converters/mil/backend/mil/load.py
@@ -22,7 +22,9 @@
create_immediate_value,
create_list_scalarvalue,
create_scalar_value,
types_to_proto,
create_valuetype_list,
create_valuetype_scalar,
create_valuetype_tensor,
types_to_proto_primitive,
)
from coremltools.converters.mil.backend.nn.load import _set_optional_inputs
@@ -158,7 +160,7 @@ def translate_const(self, op: Operation) -> proto.MIL_pb2.Operation:
attributes={"name": create_scalar_value(op.name), "val": value},
outputs=[
proto.MIL_pb2.NamedValueType(
name=output_var.name, type=types_to_proto(output_var.sym_type)
name=output_var.name, type=self.types_to_proto(output_var.sym_type)
)
],
)
@@ -190,12 +192,58 @@ def translate_constexpr(self, op: Operation) -> proto.MIL_pb2.Operation:
attributes=attributes,
outputs=[
proto.MIL_pb2.NamedValueType(
name=output_var.name, type=types_to_proto(output_var.sym_type)
name=output_var.name, type=self.types_to_proto(output_var.sym_type)
)
for output_var in op.outputs
],
)

def create_valuetype_dict(self, key_type: type, value_type: type) -> proto.MIL_pb2.ValueType:
"""
Return proto.MIL_pb2.ValueType with dict (dictionaryType) set
"""
v_type = proto.MIL_pb2.ValueType()
v_type.dictionaryType.keyType.CopyFrom(self.types_to_proto(key_type))
v_type.dictionaryType.valueType.CopyFrom(self.types_to_proto(value_type))
return v_type

def types_to_proto(self, valuetype: type) -> proto.MIL_pb2.ValueType:
"""
Return proto.MIL_pb2.ValueType from PyMIL types.
"""
if types.is_tensor(valuetype):
primitive = types_to_proto_primitive(valuetype.get_primitive())
return create_valuetype_tensor(valuetype.get_shape(), primitive)
elif types.is_tuple(valuetype):
v_type = proto.MIL_pb2.ValueType()
t_type = v_type.tupleType
for t in valuetype.T:
new_v_type = t_type.types.add()
new_v_type.CopyFrom(self.types_to_proto(t))
return v_type
elif types.is_list(valuetype):
elem = valuetype.T[0]
length = valuetype.T[1]
if types.is_tensor(elem):
dtype = types_to_proto_primitive(elem.get_primitive())
elem_shape = elem.get_shape()
elif types.is_scalar(elem):
dtype = types_to_proto_primitive(valuetype)
elem_shape = ()
elif types.is_str(elem):
dtype = types_to_proto_primitive(elem)
elem_shape = ()
else:
raise NotImplementedError(
"Only list of either tensors or scalars supported. "
"Got element of type {}".format(elem.__type_info__())
)
return create_valuetype_list(length=length, elem_shape=elem_shape, dtype=dtype)
elif types.is_dict(valuetype):
return self.create_valuetype_dict(valuetype.T[0], valuetype.T[1])
else:
return create_valuetype_scalar(types_to_proto_primitive(valuetype))

def translate_generic_op(
self, op: Operation, literal_params: Optional[List[str]] = None
) -> proto.MIL_pb2.Operation:
@@ -228,7 +276,7 @@ def translate_generic_op(
inputs[param_name] = args

outputs = [
proto.MIL_pb2.NamedValueType(name=v.name, type=types_to_proto(v.sym_type))
proto.MIL_pb2.NamedValueType(name=v.name, type=self.types_to_proto(v.sym_type))
for v in op.outputs
]
blocks = None
@@ -311,14 +359,18 @@ def feeds_to_only_constexprs(op: Operation) -> bool:
literal_params = ["begins", "ends", "end_masks"]
proto_ops.append(self.translate_generic_op(op, literal_params))
else:
proto_ops.append(self.translate_generic_op(op))
# A single pymil op might be decomposed into multiple ops
ops = self.translate_generic_op(op)
if not isinstance(ops, list):
ops = [ops]
proto_ops.extend(ops)

inputs = []
if not isinstance(block, Function):
# Function is subclass of Block, but function's block has no input,
# and hence skipping reading the block inputs.
for var in block.inputs:
proto_type = types_to_proto(var.sym_type)
proto_type = self.types_to_proto(var.sym_type)
inputs.append(proto.MIL_pb2.NamedValueType(name=var.name, type=proto_type))
output_names = [v.name for v in block.outputs]
return proto.MIL_pb2.Block(inputs=inputs, outputs=output_names, operations=proto_ops)
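Per the new comment in this hunk, a translator may now return either a single proto op or a list of them, and the caller normalizes the return value before extending `proto_ops`. A minimal sketch of that idiom, with a hypothetical translator standing in for `translate_generic_op`:

```python
# Hypothetical translator: decomposes some ops into several proto ops,
# returns a single op otherwise. Names are illustrative, not coremltools'.
def translate(op_name: str):
    if op_name == "gelu":
        return [f"{op_name}.part0", f"{op_name}.part1"]  # decomposed
    return f"{op_name}.proto"                            # single op

proto_ops = []
for op_name in ["matmul", "gelu"]:
    ops = translate(op_name)
    if not isinstance(ops, list):  # normalize scalar-or-list return
        ops = [ops]
    proto_ops.extend(ops)

print(proto_ops)  # ['matmul.proto', 'gelu.part0', 'gelu.part1']
```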
@@ -331,7 +383,7 @@ def convert_function(self, function: Function, opset: str) -> proto.MIL_pb2.Func

inputs = []
for name, var in function.inputs.items():
proto_type = types_to_proto(var.sym_type)
proto_type = self.types_to_proto(var.sym_type)
inputs.append(proto.MIL_pb2.NamedValueType(name=name, type=proto_type))

return proto.MIL_pb2.Function(
@@ -467,6 +519,15 @@ def get_additional_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
"""
return {}

@staticmethod
def _try_convert_other_input_type(
input_var: Var, input_features: List[proto.Model_pb2.FeatureDescription]
) -> bool:
"""
Try to convert an input var with additional type.
"""
return False

def get_func_input(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDescription]:
"""
Utils to get function input feature description.
@@ -554,7 +615,7 @@ def get_func_input(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDesc
input_features.append(
proto.Model_pb2.FeatureDescription(name=var.name, type=input_feature_type)
)
else:
elif not self._try_convert_other_input_type(var, input_features):
raise NotImplementedError(f"Unsupported input type {var.sym_type}.")

if not is_input_shape_symbolic:
@@ -746,6 +807,16 @@ def get_func_output(self, func: mil.Function) -> List[proto.Model_pb2.FeatureDes

return output_features

def create_model_description(
self,
input_features: List[proto.Model_pb2.FeatureDescription],
output_features: List[proto.Model_pb2.FeatureDescription],
) -> proto.Model_pb2.ModelDescription:
"""
Create model description from input and output features
"""
return proto.Model_pb2.ModelDescription(input=input_features, output=output_features)

def get_coreml_model(
self,
input: Dict[str, List[proto.Model_pb2.FeatureDescription]],
@@ -758,7 +829,7 @@
# Model description
input_features = input[self._DEFAULT_FUNCTION_NAME]
output_features = output[self._DEFAULT_FUNCTION_NAME]
desc = proto.Model_pb2.ModelDescription(input=input_features, output=output_features)
desc = self.create_model_description(input_features, output_features)

if self.classifier_config is not None:
desc.predictedFeatureName = self.predicted_feature_name
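Two of the additions in this file are extension points rather than behavior changes: `_try_convert_other_input_type` lets a subclass claim an input type before the base class raises `NotImplementedError`, and `create_model_description` lets a subclass customize the proto description. A toy sketch of the try-hook-else-raise pattern, with dicts standing in for the real `Var` and `FeatureDescription` types:

```python
# Illustrative sketch, not the coremltools API: dicts replace Var and
# proto.Model_pb2.FeatureDescription.
from typing import Dict, List

class BaseLoader:
    @staticmethod
    def _try_convert_other_input_type(var: Dict, input_features: List[Dict]) -> bool:
        # Base class claims nothing extra; a subclass appends a feature
        # description and returns True to take ownership of the input.
        return False

    def get_func_input(self, func_vars: List[Dict]) -> List[Dict]:
        input_features: List[Dict] = []
        for var in func_vars:
            if var["type"] in ("tensor", "image"):
                input_features.append({"name": var["name"], "type": var["type"]})
            elif not self._try_convert_other_input_type(var, input_features):
                raise NotImplementedError(f"Unsupported input type {var['type']}.")
        return input_features

class ExtendedLoader(BaseLoader):
    @staticmethod
    def _try_convert_other_input_type(var: Dict, input_features: List[Dict]) -> bool:
        if var["type"] == "other":
            input_features.append({"name": var["name"], "type": "other"})
            return True
        return False

print(ExtendedLoader().get_func_input([{"name": "x", "type": "other"}]))
# [{'name': 'x', 'type': 'other'}]
```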
3 changes: 2 additions & 1 deletion coremltools/converters/mil/backend/nn/load.py
@@ -12,6 +12,7 @@
ImageType,
RangeDim,
Shape,
TensorType,
)
from coremltools.converters.mil.mil import types
from coremltools.converters.mil.mil.types.symbolic import any_symbolic, any_variadic, is_symbolic
@@ -169,7 +170,7 @@ def _set_optional_inputs(proto, input_types):
# Set default values for optional input_types
default_map = {}
for input_type in input_types:
if isinstance(input_type, ImageType):
if not isinstance(input_type, TensorType):
continue
if input_type.default_value is not None:
default_map[input_type.name] = input_type.default_value
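The fix here inverts the guard: only `TensorType` inputs carry a usable `default_value`, so the loop must skip every non-tensor input kind rather than just `ImageType`. A toy reproduction of the difference, with simplified stand-in classes rather than the coremltools ones:

```python
# Stand-ins for the input-type hierarchy; attributes are illustrative.
class ImageType: pass

class OtherInputType: pass  # any non-tensor input kind without default_value

class TensorType:
    def __init__(self, name, default_value=None):
        self.name = name
        self.default_value = default_value

input_types = [ImageType(), OtherInputType(), TensorType("x", default_value=1.0)]

default_map = {}
for input_type in input_types:
    # The old guard `isinstance(input_type, ImageType)` would let
    # OtherInputType through and fail on the missing attribute below.
    if not isinstance(input_type, TensorType):
        continue
    if input_type.default_value is not None:
        default_map[input_type.name] = input_type.default_value

print(default_map)  # {'x': 1.0}
```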
23 changes: 13 additions & 10 deletions coremltools/converters/mil/frontend/_utils.py
@@ -512,33 +512,36 @@ def _concat_dims(dims, none_if_empty=False):
return ab


def _lower_scaled_dot_product_attention(q: Var, k: Var, v: Var, mask: Var, name: str) -> Var:
def _lower_scaled_dot_product_attention(
q: Var, k: Var, v: Var, mask: Var, name: str, before_op: Optional[Operation] = None
) -> Var:
# scale the query input
embed_size = q.shape[-1]
if is_symbolic(embed_size):
raise ValueError(
"The embedding size, i.e. last dimension of the shape of query tensor"
" cannot be symbolic, in scaled_dot_product_attention op"
)

q, k, v = promote_input_dtypes([q, k, v])
multiplicative_scale_factor = 1 / math.sqrt(embed_size)
q, k, v, multiplicative_scale_factor = promote_input_dtypes(
[q, k, v, multiplicative_scale_factor]
)
q = mb.mul(x=q, y=multiplicative_scale_factor)
if types.builtin_to_string(q.dtype) == "fp16":
multiplicative_scale_factor = _np.float16(multiplicative_scale_factor)
q = mb.mul(x=q, y=multiplicative_scale_factor, before_op=before_op)

# multiply query and key input tensors
# shape of output: (target_seq, source_seq) or (B,...,target_seq, source_seq)
attn_weights = mb.matmul(x=q, y=k, transpose_y=True)
attn_weights = mb.matmul(x=q, y=k, transpose_y=True, before_op=before_op)

# add mask if applicable
if mask is not None:
attn_weights = mb.add(x=attn_weights, y=mask)
attn_weights = mb.add(x=attn_weights, y=mask, before_op=before_op)

# do softmax
attn_weights_normalized = mb.softmax(x=attn_weights, axis=-1)
attn_weights_normalized = mb.softmax(x=attn_weights, axis=-1, before_op=before_op)

# multiply attn_weights and value tensor
res = mb.matmul(x=attn_weights_normalized, y=v, name=name)
res = mb.matmul(x=attn_weights_normalized, y=v, name=name, before_op=before_op)
return res


@@ -549,7 +552,7 @@ def _construct_constexpr_affine_op(
axis: Optional[Union[Var, int]] = None,
name: Optional[str] = None,
before_op: Optional[Operation] = None,
) -> Operation:
) -> Var:
"""Constructs the constexpr op to represent the dequantized weight from PyTorch's data."""
# The constexpr_affine_dequantize op requires axis.
if axis is None:
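For reference, the decomposition this helper emits corresponds to the standard attention formula softmax(q·kᵀ/√d + mask)·v; the diff threads `before_op` through each emitted MIL op and, when `q` is fp16, casts the scale factor to `np.float16` so the multiply does not promote the graph to fp32. A NumPy sketch of the math only (no MIL ops, shapes and dtypes chosen for illustration):

```python
import math

import numpy as np

def sdpa_reference(q, k, v, mask=None):
    # Scale the query; the embed size (last dim of q) must be concrete.
    # Building the scale in q's dtype mirrors the fp16 cast in the diff.
    scale = np.asarray(1.0 / math.sqrt(q.shape[-1]), dtype=q.dtype)
    attn = (q * scale) @ np.swapaxes(k, -1, -2)  # (..., target_seq, source_seq)
    if mask is not None:
        attn = attn + mask
    attn = np.exp(attn - attn.max(axis=-1, keepdims=True))  # stable softmax
    attn /= attn.sum(axis=-1, keepdims=True)
    return attn @ v  # (..., target_seq, embed)

q, k, v = (np.random.rand(2, 4, 8).astype(np.float16) for _ in range(3))
out = sdpa_reference(q, k, v)
print(out.shape, out.dtype)  # (2, 4, 8) float16
```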
17 changes: 16 additions & 1 deletion coremltools/converters/mil/frontend/torch/exir_utils.py
@@ -87,12 +87,27 @@ def extract_inputs_from_exir_program(
val = node.meta["val"]
assert isinstance(val, torch.Tensor), "placeholder val must be a tensor or fake tensor"
user_inputs.append(to_coreml_tensor_type(node.name, val))

elif input_spec.kind == torch.export.graph_signature.InputKind.PARAMETER:
lifted_parameters[input_spec.arg.name] = parameters[input_spec.target]

elif input_spec.kind == torch.export.graph_signature.InputKind.BUFFER:
lifted_buffers[input_spec.arg.name] = buffers[input_spec.target]
# This is a workaround on mutable buffer: Core ML does not support stateful execution,
# so ExecuTorch will pass mutable buffers as inputs/outputs to Core ML delegation,
# then in-place copy Core ML outputs into buffers
# On Core ML side, we do not have to do anything special with outputs,
# but for inputs we will need to identify ExecuTorch lifted mutable buffers
# as Core ML user inputs
if input_spec.target in exported_program.graph_signature.buffers_to_mutate.values():
user_inputs.append(
to_coreml_tensor_type(input_spec.arg.name, buffers[input_spec.target])
)
else:
lifted_buffers[input_spec.arg.name] = buffers[input_spec.target]

elif input_spec.kind == torch.export.graph_signature.InputKind.CONSTANT_TENSOR:
lifted_constants[input_spec.arg.name] = exported_program.constants[input_spec.target]

else:
raise NotImplementedError(
"Only 4 types of inputs handled yet: user input, parameter, buffer, constant. "
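The workaround hinges on `graph_signature.buffers_to_mutate`, which in torch.export maps mutated-output names to buffer names: any buffer appearing among its values is promoted to a Core ML user input instead of being lifted as a constant. A sketch of that classification, with plain dicts standing in for the ExportedProgram structures (all names illustrative):

```python
# Toy stand-ins for torch.export data structures.
buffers = {"running_mean": "tensor_a", "counter": "tensor_b"}
buffers_to_mutate = {"mutated_out_0": "counter"}  # some output writes counter

# (placeholder arg name, buffer target) pairs, as in the input specs.
buffer_specs = [("arg0_1", "running_mean"), ("arg1_1", "counter")]

user_inputs, lifted_buffers = [], {}
for arg_name, target in buffer_specs:
    if target in buffers_to_mutate.values():
        # Mutable buffer: exposed as a model input (ExecuTorch copies the
        # delegate's output back into the buffer after each call).
        user_inputs.append((arg_name, buffers[target]))
    else:
        lifted_buffers[arg_name] = buffers[target]  # stays a lifted constant

print(user_inputs)     # [('arg1_1', 'tensor_b')]
print(lifted_buffers)  # {'arg0_1': 'tensor_a'}
```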
12 changes: 3 additions & 9 deletions coremltools/converters/mil/frontend/torch/ops.py
@@ -1678,7 +1678,7 @@ def view(context, node):
x = inputs[0]
shape = inputs[1]

if np.prod(shape.shape) == 0:
if isinstance(shape, Var) and np.prod(shape.shape) == 0:
# Reshape to empty shape (works only for scalar) is a no op
assert (
np.prod(x.shape) <= 1
@@ -6694,21 +6694,15 @@ def _get_causal_attn_mask(is_causal: bool, query_var: Var, key_var: Var) -> Var:

def _cast_bool_attn_mask(attn_mask: Var, query_var: Var) -> Var:
"""
compute float mask as:
mask = cast(bool_mask) + (1-cast(bool_mask)) * -30k*ones(shape(bool_mask))
compute float mask as (1 - cast(bool_mask)) * -30k
"""
assert is_bool(attn_mask.dtype)

shape = mb.shape(x=attn_mask)
negative_inf = mb.fill(
shape=shape, value=_np.array([-3e4]).astype(types.nptype_from_builtin(query_var.dtype))
)
mask = mb.cast(x=attn_mask, dtype=types.builtin_to_string(query_var.dtype))
compliment_of_mask = mb.sub(
x=_np.array([1.0]).astype(types.nptype_from_builtin(mask.dtype)), y=mask
)
compliment_of_mask = mb.mul(x=negative_inf, y=compliment_of_mask)
return mb.add(x=mask, y=compliment_of_mask)
return mb.mul(x=-3e4, y=compliment_of_mask)

@register_torch_op
def scaled_dot_product_attention(context, node):
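The simplified bool-mask formula drops the `cast(bool_mask)` term. That term added +1 to every unmasked logit in a row, and softmax is insensitive to the change because masked logits sit near -30000 and contribute essentially zero after exponentiation either way. A NumPy check of the equivalence (illustrative, fixed toy mask):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

logits = np.random.default_rng(0).normal(size=(2, 5)).astype(np.float32)
bool_mask = np.array([[True, True, False, True, False],
                      [True, False, True, True, True]])  # True = attend

m = bool_mask.astype(np.float32)
old_mask = m + (1.0 - m) * -3e4  # previous formula
new_mask = (1.0 - m) * -3e4      # simplified formula

print(np.allclose(softmax(logits + old_mask),
                  softmax(logits + new_mask)))  # True
```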