diff --git a/mashumaro/core/meta/code/builder.py b/mashumaro/core/meta/code/builder.py index 1aa449c9..4e65ae5f 100644 --- a/mashumaro/core/meta/code/builder.py +++ b/mashumaro/core/meta/code/builder.py @@ -216,6 +216,21 @@ def get_field_types( ) -> typing.Dict[str, typing.Any]: return self.__get_field_types(include_extras=include_extras) + def get_type_name_identifier( + self, + typ: typing.Optional[typing.Type], + resolved_type_params: typing.Optional[ + typing.Dict[typing.Type, typing.Type] + ] = None, + ) -> str: + field_type = type_name(typ, resolved_type_params=resolved_type_params) + + if is_local_type_name(field_type): + field_type = clean_id(field_type) + self.ensure_object_imported(typ, field_type) + + return field_type + @property # type: ignore @lru_cache() def dataclass_fields(self) -> typing.Dict[str, Field]: @@ -1257,17 +1272,13 @@ def build( ) -> FieldUnpackerCodeBlock: default = self.parent.get_field_default(fname) has_default = default is not MISSING - field_type = type_name( + field_type = self.parent.get_type_name_identifier( ftype, resolved_type_params=self.parent.get_field_resolved_type_params( fname ), ) - if is_local_type_name(field_type): - field_type = clean_id(field_type) - self.parent.ensure_object_imported(ftype, field_type) - could_be_none = ( ftype in (typing.Any, type(None), None) or is_type_var_any(self.parent.get_real_type(fname, ftype)) diff --git a/mashumaro/core/meta/types/pack.py b/mashumaro/core/meta/types/pack.py index cdbb5feb..1e181ae3 100644 --- a/mashumaro/core/meta/types/pack.py +++ b/mashumaro/core/meta/types/pack.py @@ -32,7 +32,6 @@ is_final, is_generic, is_literal, - is_local_type_name, is_named_tuple, is_new_type, is_not_required, @@ -301,17 +300,13 @@ def pack_union( with lines.indent("try:"): lines.append(f"return {packer}") lines.append("except Exception: pass") - field_type = type_name( + field_type = spec.builder.get_type_name_identifier( spec.type, resolved_type_params=spec.builder.get_field_resolved_type_params( spec.field_ctx.name ), ) - if is_local_type_name(field_type): - field_type = clean_id(field_type) - spec.builder.ensure_object_imported(spec.type, field_type) - if spec.builder.is_nailed: lines.append( "raise InvalidFieldValue(" @@ -361,17 +356,11 @@ def pack_literal(spec: ValueSpec) -> Expression: spec.copy(type=value_type, expression="value") ) if isinstance(literal_value, enum.Enum): - enum_type_name = type_name( + enum_type_name = spec.builder.get_type_name_identifier( typ=value_type, resolved_type_params=resolved_type_params, ) - if is_local_type_name(enum_type_name): - enum_type_name = clean_id(enum_type_name) - spec.builder.ensure_object_imported( - value_type, enum_type_name - ) - with lines.indent( f"if value == {enum_type_name}.{literal_value.name}:" ): @@ -382,15 +371,11 @@ def pack_literal(spec: ValueSpec) -> Expression: ): with lines.indent(f"if value == {literal_value!r}:"): lines.append(f"return {packer}") - field_type = type_name( + field_type = spec.builder.get_type_name_identifier( typ=spec.type, resolved_type_params=resolved_type_params, ) - if is_local_type_name(field_type): - field_type = clean_id(field_type) - spec.builder.ensure_object_imported(spec.type, field_type) - if spec.builder.is_nailed: lines.append( f"raise InvalidFieldValue('{spec.field_ctx.name}'," diff --git a/mashumaro/core/meta/types/unpack.py b/mashumaro/core/meta/types/unpack.py index 32ac6d8e..976cb0a2 100644 --- a/mashumaro/core/meta/types/unpack.py +++ b/mashumaro/core/meta/types/unpack.py @@ -39,7 +39,6 @@ is_final, is_generic, is_literal, - is_local_type_name, is_named_tuple, is_new_type, is_not_required, @@ -174,17 +173,13 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: with lines.indent("try:"): lines.append(f"return {unpacker}") lines.append("except Exception: pass") - field_type = type_name( + field_type = spec.builder.get_type_name_identifier( spec.type, resolved_type_params=spec.builder.get_field_resolved_type_params( spec.field_ctx.name ), ) if spec.builder.is_nailed: - if is_local_type_name(field_type): - field_type = clean_id(field_type) - spec.builder.ensure_object_imported(spec.type, field_type) - lines.append( "raise InvalidFieldValue(" f"'{spec.field_ctx.name}',{field_type},value,cls)" @@ -209,13 +204,9 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: for literal_value in get_literal_values(spec.type): if isinstance(literal_value, enum.Enum): lit_type = type(literal_value) - enum_type_name = type_name(lit_type) - - if is_local_type_name(enum_type_name): - enum_type_name = clean_id(enum_type_name) - spec.builder.ensure_object_imported( - lit_type, enum_type_name - ) + enum_type_name = spec.builder.get_type_name_identifier( + lit_type + ) with lines.indent( f"if value == {enum_type_name}.{literal_value.name}.value:" @@ -311,11 +302,7 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None: variants_map = self._get_variants_map(spec) variants_attr_holder = self._get_variants_attr_holder(spec) variants = self._get_variant_names_iterable(spec) - variants_type_expr = type_name(spec.type) - - if is_local_type_name(variants_type_expr): - variants_type_expr = clean_id(variants_type_expr) - spec.builder.ensure_object_imported(spec.type, variants_type_expr) + variants_type_expr = spec.builder.get_type_name_identifier(spec.type) if variants_attr not in variants_attr_holder.__dict__: setattr(variants_attr_holder, variants_attr, {}) @@ -583,11 +570,7 @@ def _unpack_annotated_serializable_type( ) unpacker = UnpackerRegistry.get(spec.copy(type=value_type)) - field_type = type_name(spec.type) - - if is_local_type_name(field_type): - field_type = clean_id(field_type) - spec.builder.ensure_object_imported(spec.type, field_type) + field_type = spec.builder.get_type_name_identifier(spec.type) return f"{field_type}._deserialize({unpacker})" @@ -602,10 +585,7 @@ def unpack_serializable_type(spec: ValueSpec) -> Optional[Expression]: if spec.origin_type.__use_annotations__: return _unpack_annotated_serializable_type(spec) else: - field_type = type_name(spec.type) - if is_local_type_name(field_type): - field_type = clean_id(field_type) - spec.builder.ensure_object_imported(spec.type, field_type) + field_type = spec.builder.get_type_name_identifier(spec.type) return f"{field_type}._deserialize({spec.expression})" @@ -617,11 +597,9 @@ def unpack_generic_serializable_type(spec: ValueSpec) -> Optional[Expression]: type_arg_names = ", ".join( list(map(type_name, get_args(spec.type))) ) - field_type = type_name(spec.origin_type) - - if is_local_type_name(field_type): - field_type = clean_id(field_type) - spec.builder.ensure_object_imported(spec.type, field_type) + field_type = spec.builder.get_type_name_identifier( + spec.origin_type + ) return ( f"{field_type}._deserialize({spec.expression}, " @@ -1025,10 +1003,7 @@ def unpack_named_tuple(spec: ValueSpec) -> Expression: unpackers.append(unpacker) if not defaults: - field_type = type_name(spec.type) - if is_local_type_name(field_type): - field_type = clean_id(field_type) - spec.builder.ensure_object_imported(spec.type, field_type) + field_type = spec.builder.get_type_name_identifier(spec.type) return f"{field_type}({', '.join(unpackers)})" @@ -1056,10 +1031,7 @@ def unpack_named_tuple(spec: ValueSpec) -> Expression: with lines.indent("except IndexError:"): lines.append("pass") - field_type = type_name(spec.type) - if is_local_type_name(field_type): - field_type = clean_id(field_type) - spec.builder.ensure_object_imported(spec.type, field_type) + field_type = spec.builder.get_type_name_identifier(spec.type) lines.append(f"return {field_type}(*fields)") lines.append( @@ -1240,11 +1212,7 @@ def unpack_pathlike(spec: ValueSpec) -> Optional[Expression]: spec.builder.ensure_module_imported(pathlib) return f"{type_name(pathlib.PurePath)}({spec.expression})" elif issubclass(spec.origin_type, os.PathLike): - field_type = type_name(spec.origin_type) - - if is_local_type_name(field_type): - field_type = clean_id(field_type) - spec.builder.ensure_object_imported(spec.type, field_type) + field_type = spec.builder.get_type_name_identifier(spec.origin_type) return f"{field_type}({spec.expression})" @@ -1252,10 +1220,6 @@ def unpack_pathlike(spec: ValueSpec) -> Optional[Expression]: @register def unpack_enum(spec: ValueSpec) -> Optional[Expression]: if issubclass(spec.origin_type, enum.Enum): - field_type = type_name(spec.origin_type) - - if is_local_type_name(field_type): - field_type = clean_id(field_type) - spec.builder.ensure_object_imported(spec.type, field_type) + field_type = spec.builder.get_type_name_identifier(spec.origin_type) return f"{field_type}({spec.expression})"