Skip to content

Commit

Permalink
use builder method to get safe type identifier
Browse files Browse the repository at this point in the history
  • Loading branch information
mishamsk committed Dec 23, 2023
1 parent 4998a14 commit f33a528
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 73 deletions.
21 changes: 16 additions & 5 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,21 @@ def get_field_types(
) -> typing.Dict[str, typing.Any]:
return self.__get_field_types(include_extras=include_extras)

def get_type_name_identifier(
self,
typ: typing.Optional[typing.Type],
resolved_type_params: typing.Optional[
typing.Dict[typing.Type, typing.Type]
] = None,
) -> str:
field_type = type_name(typ, resolved_type_params=resolved_type_params)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
self.ensure_object_imported(typ, field_type)

return field_type

@property # type: ignore
@lru_cache()
def dataclass_fields(self) -> typing.Dict[str, Field]:
Expand Down Expand Up @@ -1257,17 +1272,13 @@ def build(
) -> FieldUnpackerCodeBlock:
default = self.parent.get_field_default(fname)
has_default = default is not MISSING
field_type = type_name(
field_type = self.parent.get_type_name_identifier(
ftype,
resolved_type_params=self.parent.get_field_resolved_type_params(
fname
),
)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
self.parent.ensure_object_imported(ftype, field_type)

could_be_none = (
ftype in (typing.Any, type(None), None)
or is_type_var_any(self.parent.get_real_type(fname, ftype))
Expand Down
21 changes: 3 additions & 18 deletions mashumaro/core/meta/types/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
is_final,
is_generic,
is_literal,
is_local_type_name,
is_named_tuple,
is_new_type,
is_not_required,
Expand Down Expand Up @@ -301,17 +300,13 @@ def pack_union(
with lines.indent("try:"):
lines.append(f"return {packer}")
lines.append("except Exception: pass")
field_type = type_name(
field_type = spec.builder.get_type_name_identifier(
spec.type,
resolved_type_params=spec.builder.get_field_resolved_type_params(
spec.field_ctx.name
),
)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

if spec.builder.is_nailed:
lines.append(
"raise InvalidFieldValue("
Expand Down Expand Up @@ -361,17 +356,11 @@ def pack_literal(spec: ValueSpec) -> Expression:
spec.copy(type=value_type, expression="value")
)
if isinstance(literal_value, enum.Enum):
enum_type_name = type_name(
enum_type_name = spec.builder.get_type_name_identifier(
typ=value_type,
resolved_type_params=resolved_type_params,
)

if is_local_type_name(enum_type_name):
enum_type_name = clean_id(enum_type_name)
spec.builder.ensure_object_imported(
value_type, enum_type_name
)

with lines.indent(
f"if value == {enum_type_name}.{literal_value.name}:"
):
Expand All @@ -382,15 +371,11 @@ def pack_literal(spec: ValueSpec) -> Expression:
):
with lines.indent(f"if value == {literal_value!r}:"):
lines.append(f"return {packer}")
field_type = type_name(
field_type = spec.builder.get_type_name_identifier(
typ=spec.type,
resolved_type_params=resolved_type_params,
)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

if spec.builder.is_nailed:
lines.append(
f"raise InvalidFieldValue('{spec.field_ctx.name}',"
Expand Down
64 changes: 14 additions & 50 deletions mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
is_final,
is_generic,
is_literal,
is_local_type_name,
is_named_tuple,
is_new_type,
is_not_required,
Expand Down Expand Up @@ -174,17 +173,13 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
with lines.indent("try:"):
lines.append(f"return {unpacker}")
lines.append("except Exception: pass")
field_type = type_name(
field_type = spec.builder.get_type_name_identifier(
spec.type,
resolved_type_params=spec.builder.get_field_resolved_type_params(
spec.field_ctx.name
),
)
if spec.builder.is_nailed:
if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

lines.append(
"raise InvalidFieldValue("
f"'{spec.field_ctx.name}',{field_type},value,cls)"
Expand All @@ -209,13 +204,9 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
for literal_value in get_literal_values(spec.type):
if isinstance(literal_value, enum.Enum):
lit_type = type(literal_value)
enum_type_name = type_name(lit_type)

if is_local_type_name(enum_type_name):
enum_type_name = clean_id(enum_type_name)
spec.builder.ensure_object_imported(
lit_type, enum_type_name
)
enum_type_name = spec.builder.get_type_name_identifier(
lit_type
)

with lines.indent(
f"if value == {enum_type_name}.{literal_value.name}.value:"
Expand Down Expand Up @@ -311,11 +302,7 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
variants_map = self._get_variants_map(spec)
variants_attr_holder = self._get_variants_attr_holder(spec)
variants = self._get_variant_names_iterable(spec)
variants_type_expr = type_name(spec.type)

if is_local_type_name(variants_type_expr):
variants_type_expr = clean_id(variants_type_expr)
spec.builder.ensure_object_imported(spec.type, variants_type_expr)
variants_type_expr = spec.builder.get_type_name_identifier(spec.type)

if variants_attr not in variants_attr_holder.__dict__:
setattr(variants_attr_holder, variants_attr, {})
Expand Down Expand Up @@ -583,11 +570,7 @@ def _unpack_annotated_serializable_type(
)
unpacker = UnpackerRegistry.get(spec.copy(type=value_type))

field_type = type_name(spec.type)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)
field_type = spec.builder.get_type_name_identifier(spec.type)

return f"{field_type}._deserialize({unpacker})"

Expand All @@ -602,10 +585,7 @@ def unpack_serializable_type(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type.__use_annotations__:
return _unpack_annotated_serializable_type(spec)
else:
field_type = type_name(spec.type)
if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)
field_type = spec.builder.get_type_name_identifier(spec.type)

return f"{field_type}._deserialize({spec.expression})"

Expand All @@ -617,11 +597,9 @@ def unpack_generic_serializable_type(spec: ValueSpec) -> Optional[Expression]:
type_arg_names = ", ".join(
list(map(type_name, get_args(spec.type)))
)
field_type = type_name(spec.origin_type)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)
field_type = spec.builder.get_type_name_identifier(
spec.origin_type
)

return (
f"{field_type}._deserialize({spec.expression}, "
Expand Down Expand Up @@ -1025,10 +1003,7 @@ def unpack_named_tuple(spec: ValueSpec) -> Expression:
unpackers.append(unpacker)

if not defaults:
field_type = type_name(spec.type)
if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)
field_type = spec.builder.get_type_name_identifier(spec.type)

return f"{field_type}({', '.join(unpackers)})"

Expand Down Expand Up @@ -1056,10 +1031,7 @@ def unpack_named_tuple(spec: ValueSpec) -> Expression:
with lines.indent("except IndexError:"):
lines.append("pass")

field_type = type_name(spec.type)
if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)
field_type = spec.builder.get_type_name_identifier(spec.type)

lines.append(f"return {field_type}(*fields)")
lines.append(
Expand Down Expand Up @@ -1240,22 +1212,14 @@ def unpack_pathlike(spec: ValueSpec) -> Optional[Expression]:
spec.builder.ensure_module_imported(pathlib)
return f"{type_name(pathlib.PurePath)}({spec.expression})"
elif issubclass(spec.origin_type, os.PathLike):
field_type = type_name(spec.origin_type)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)
field_type = spec.builder.get_type_name_identifier(spec.origin_type)

return f"{field_type}({spec.expression})"


@register
def unpack_enum(spec: ValueSpec) -> Optional[Expression]:
if issubclass(spec.origin_type, enum.Enum):
field_type = type_name(spec.origin_type)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)
field_type = spec.builder.get_type_name_identifier(spec.origin_type)

return f"{field_type}({spec.expression})"

0 comments on commit f33a528

Please sign in to comment.