Skip to content

Commit

Permalink
implement local support for more types, both pack & unpack
Browse files Browse the repository at this point in the history
add more types to test
  • Loading branch information
mishamsk committed Dec 17, 2023
1 parent 99bbd45 commit 4998a14
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 20 deletions.
6 changes: 3 additions & 3 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
is_hashable,
is_init_var,
is_literal,
is_local_type,
is_local_type_name,
is_named_tuple,
is_optional,
is_type_var_any,
Expand Down Expand Up @@ -1264,9 +1264,9 @@ def build(
),
)

if is_local_type(ftype):
if is_local_type_name(field_type):
field_type = clean_id(field_type)
self.ensure_object_imported(ftype, field_type)
self.parent.ensure_object_imported(ftype, field_type)

could_be_none = (
ftype in (typing.Any, type(None), None)
Expand Down
4 changes: 2 additions & 2 deletions mashumaro/core/meta/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,8 +389,8 @@ def is_literal(typ: Type) -> bool:
return False


def is_local_type(typ: Type) -> bool:
return "<locals>" in getattr(typ, "__qualname__", "")
def is_local_type_name(type_name: str) -> bool:
return "<locals>" in type_name


def not_none_type_arg(
Expand Down
18 changes: 18 additions & 0 deletions mashumaro/core/meta/types/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
is_final,
is_generic,
is_literal,
is_local_type_name,
is_named_tuple,
is_new_type,
is_not_required,
Expand Down Expand Up @@ -306,6 +307,11 @@ def pack_union(
spec.field_ctx.name
),
)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

if spec.builder.is_nailed:
lines.append(
"raise InvalidFieldValue("
Expand Down Expand Up @@ -359,6 +365,13 @@ def pack_literal(spec: ValueSpec) -> Expression:
typ=value_type,
resolved_type_params=resolved_type_params,
)

if is_local_type_name(enum_type_name):
enum_type_name = clean_id(enum_type_name)
spec.builder.ensure_object_imported(
value_type, enum_type_name
)

with lines.indent(
f"if value == {enum_type_name}.{literal_value.name}:"
):
Expand All @@ -373,6 +386,11 @@ def pack_literal(spec: ValueSpec) -> Expression:
typ=spec.type,
resolved_type_params=resolved_type_params,
)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

if spec.builder.is_nailed:
lines.append(
f"raise InvalidFieldValue('{spec.field_ctx.name}',"
Expand Down
74 changes: 66 additions & 8 deletions mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
is_final,
is_generic,
is_literal,
is_local_type_name,
is_named_tuple,
is_new_type,
is_not_required,
Expand Down Expand Up @@ -180,6 +181,10 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
),
)
if spec.builder.is_nailed:
if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

lines.append(
"raise InvalidFieldValue("
f"'{spec.field_ctx.name}',{field_type},value,cls)"
Expand All @@ -203,7 +208,15 @@ def get_method_prefix(self) -> str:
def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
for literal_value in get_literal_values(spec.type):
if isinstance(literal_value, enum.Enum):
enum_type_name = type_name(type(literal_value))
lit_type = type(literal_value)
enum_type_name = type_name(lit_type)

if is_local_type_name(enum_type_name):
enum_type_name = clean_id(enum_type_name)
spec.builder.ensure_object_imported(
lit_type, enum_type_name
)

with lines.indent(
f"if value == {enum_type_name}.{literal_value.name}.value:"
):
Expand Down Expand Up @@ -300,6 +313,10 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
variants = self._get_variant_names_iterable(spec)
variants_type_expr = type_name(spec.type)

if is_local_type_name(variants_type_expr):
variants_type_expr = clean_id(variants_type_expr)
spec.builder.ensure_object_imported(spec.type, variants_type_expr)

if variants_attr not in variants_attr_holder.__dict__:
setattr(variants_attr_holder, variants_attr, {})
variant_method_name = spec.builder.get_unpack_method_name(
Expand Down Expand Up @@ -565,7 +582,14 @@ def _unpack_annotated_serializable_type(
],
)
unpacker = UnpackerRegistry.get(spec.copy(type=value_type))
return f"{type_name(spec.type)}._deserialize({unpacker})"

field_type = type_name(spec.type)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

return f"{field_type}._deserialize({unpacker})"


@register
Expand All @@ -578,7 +602,12 @@ def unpack_serializable_type(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type.__use_annotations__:
return _unpack_annotated_serializable_type(spec)
else:
return f"{type_name(spec.type)}._deserialize({spec.expression})"
field_type = type_name(spec.type)
if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

return f"{field_type}._deserialize({spec.expression})"


@register
Expand All @@ -588,8 +617,14 @@ def unpack_generic_serializable_type(spec: ValueSpec) -> Optional[Expression]:
type_arg_names = ", ".join(
list(map(type_name, get_args(spec.type)))
)
field_type = type_name(spec.origin_type)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

return (
f"{type_name(spec.type)}._deserialize({spec.expression}, "
f"{field_type}._deserialize({spec.expression}, "
f"[{type_arg_names}])"
)

Expand Down Expand Up @@ -990,7 +1025,12 @@ def unpack_named_tuple(spec: ValueSpec) -> Expression:
unpackers.append(unpacker)

if not defaults:
return f"{type_name(spec.type)}({', '.join(unpackers)})"
field_type = type_name(spec.type)
if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

return f"{field_type}({', '.join(unpackers)})"

lines = CodeLines()
method_name = (
Expand All @@ -1015,7 +1055,13 @@ def unpack_named_tuple(spec: ValueSpec) -> Expression:
lines.append(f"fields.append({unpacker})")
with lines.indent("except IndexError:"):
lines.append("pass")
lines.append(f"return {type_name(spec.type)}(*fields)")

field_type = type_name(spec.type)
if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

lines.append(f"return {field_type}(*fields)")
lines.append(
f"setattr({spec.cls_attrs_name}, '{method_name}', {method_name})"
)
Expand Down Expand Up @@ -1194,10 +1240,22 @@ def unpack_pathlike(spec: ValueSpec) -> Optional[Expression]:
spec.builder.ensure_module_imported(pathlib)
return f"{type_name(pathlib.PurePath)}({spec.expression})"
elif issubclass(spec.origin_type, os.PathLike):
return f"{type_name(spec.origin_type)}({spec.expression})"
field_type = type_name(spec.origin_type)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

return f"{field_type}({spec.expression})"


@register
def unpack_enum(spec: ValueSpec) -> Optional[Expression]:
if issubclass(spec.origin_type, enum.Enum):
return f"{type_name(spec.origin_type)}({spec.expression})"
field_type = type_name(spec.origin_type)

if is_local_type_name(field_type):
field_type = clean_id(field_type)
spec.builder.ensure_object_imported(spec.type, field_type)

return f"{field_type}({spec.expression})"
83 changes: 76 additions & 7 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import dataclasses
from dataclasses import dataclass, field
from enum import Enum
from pathlib import PurePosixPath
from typing import Any, Literal, NamedTuple

import msgpack
import pytest
Expand Down Expand Up @@ -264,15 +267,81 @@ def test_kw_args_when_pos_arg_is_overridden_with_field():
assert loaded.kw1 == 4


def test_local_type():
def test_local_types():
@dataclass
class LocalType:
class LocalDataclassType:
foo: int

class LocalNamedTupleType(NamedTuple):
foo: int

class LocalPathLike(PurePosixPath):
pass

class LocalEnumType(Enum):
FOO = "foo"

class LocalSerializableType(SerializableType):
@classmethod
def _deserialize(self, value):
return LocalSerializableType()

def _serialize(self) -> Any:
return {}

def __eq__(self, __value: object) -> bool:
return isinstance(__value, LocalSerializableType)

class LocalGenericSerializableType(GenericSerializableType):
@classmethod
def _deserialize(self, value, types):
return LocalGenericSerializableType()

def _serialize(self, types) -> Any:
return {}

def __eq__(self, __value: object) -> bool:
return isinstance(__value, LocalGenericSerializableType)

@dataclass
class DataClassWithLocalType(DataClassDictMixin):
x: LocalType

obj = DataClassWithLocalType(LocalType())
assert obj.to_dict() == {"x": {}}
assert DataClassWithLocalType.from_dict({"x": {}}) == obj
x1: LocalDataclassType
x2: LocalNamedTupleType
x3: LocalPathLike
x4: LocalEnumType
x4_1: Literal[LocalEnumType.FOO]
x5: LocalSerializableType
x6: LocalGenericSerializableType

obj = DataClassWithLocalType(
x1=LocalDataclassType(foo=0),
x2=LocalNamedTupleType(foo=0),
x3=LocalPathLike("path/to/file"),
x4=LocalEnumType.FOO,
x4_1=LocalEnumType.FOO,
x5=LocalSerializableType(),
x6=LocalGenericSerializableType(),
)
assert obj.to_dict() == {
"x1": {"foo": 0},
"x2": [0],
"x3": "path/to/file",
"x4": "foo",
"x4_1": "foo",
"x5": {},
"x6": {},
}
assert (
DataClassWithLocalType.from_dict(
{
"x1": {"foo": 0},
"x2": [0],
"x3": "path/to/file",
"x4": "foo",
"x4_1": "foo",
"x5": {},
"x6": {},
}
)
== obj
)

0 comments on commit 4998a14

Please sign in to comment.