Skip to content

Commit

Permalink
Merge pull request #194 from Fatal1ty/numbers-encoding
Browse files Browse the repository at this point in the history
Improve union encoding performance
  • Loading branch information
Fatal1ty authored Mar 9, 2024
2 parents 8799fa4 + 7499e87 commit f7a70e7
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 19 deletions.
62 changes: 47 additions & 15 deletions mashumaro/core/meta/types/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import (
Any,
Callable,
Dict,
ForwardRef,
List,
Optional,
Expand All @@ -29,6 +30,7 @@
get_class_that_defines_method,
get_function_return_annotation,
get_literal_values,
get_type_origin,
is_final,
is_generic,
is_literal,
Expand Down Expand Up @@ -292,14 +294,50 @@ def pack_union(
lines.append(f"def {method_name}({method_args}, {default_kwargs}):")
else:
lines.append(f"def {method_name}({method_args}):")
packers: List[str] = []
packer_arg_types: Dict[str, List[Type]] = {}
for type_arg in args:
packer = PackerRegistry.get(
spec.copy(type=type_arg, expression="value")
)
if packer not in packers:
if packer == "value":
packers.insert(0, packer)
else:
packers.append(packer)
packer_arg_types.setdefault(packer, []).append(type_arg)

if len(packers) == 1 and packers[0] == "value":
return spec.expression

with lines.indent():
for packer in (
PackerRegistry.get(spec.copy(type=type_arg, expression="value"))
for type_arg in args
):
with lines.indent("try:"):
lines.append(f"return {packer}")
lines.append("except Exception: pass")
for packer in packers:
packer_arg_type_names = []
for packer_arg_type in packer_arg_types[packer]:
if is_generic(packer_arg_type):
packer_arg_type = get_type_origin(packer_arg_type)
packer_arg_type_name = clean_id(type_name(packer_arg_type))
spec.builder.ensure_object_imported(
packer_arg_type, packer_arg_type_name
)
if packer_arg_type_name not in packer_arg_type_names:
packer_arg_type_names.append(packer_arg_type_name)
if len(packer_arg_type_names) > 1:
packer_arg_type_check = (
f"in ({', '.join(packer_arg_type_names)})"
)
else:
packer_arg_type_check = f"is {packer_arg_type_names[0]}"
if packer == "value":
with lines.indent(
f"if value.__class__ {packer_arg_type_check}:"
):
lines.append(f"return {packer}")
else:
with lines.indent("try:"):
lines.append(f"return {packer}")
with lines.indent("except Exception:"):
lines.append("pass")
field_type = spec.builder.get_type_name_identifier(
typ=spec.type,
resolved_type_params=spec.builder.get_field_resolved_type_params(
Expand Down Expand Up @@ -484,14 +522,8 @@ def pack_special_typing_primitive(spec: ValueSpec) -> Optional[Expression]:


@register
def pack_number(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type in (int, float):
return f"{type_name(spec.origin_type)}({spec.expression})"


@register
def pack_bool_and_none(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type in (bool, NoneType, None):
def pack_number_and_bool_and_none(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type in (int, float, bool, NoneType, None):
return spec.expression


Expand Down
2 changes: 1 addition & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_debug_true_option(mocker):

@dataclass
class _(DataClassDictMixin):
union: Union[int, str]
union: Union[int, str, MyNamedTuple]
typed_dict: TypedDictRequiredKeys
named_tuple: MyNamedTupleWithDefaults
literal: Literal[1, 2, 3]
Expand Down
3 changes: 0 additions & 3 deletions tests/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,9 +1467,6 @@ class DataClass(DataClassDictMixin):
assert DataClass.from_dict({"x": 42}) == obj
assert obj.to_dict() == {"x": 42, "y": None, "z": 42}

with pytest.raises(TypeError):
DataClass(x=42, z=None).to_dict()


def test_dataclass_with_optional_list_with_optional_ints():
@dataclass
Expand Down
11 changes: 11 additions & 0 deletions tests/test_union.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from dataclasses import dataclass
from itertools import permutations
from typing import Any, Dict, List, Union

import pytest

from mashumaro import DataClassDictMixin
from mashumaro.codecs.basic import encode
from tests.utils import same_types


@dataclass
Expand Down Expand Up @@ -32,3 +35,11 @@ class DataClass(DataClassDictMixin):
instance = DataClass(x=test_case.loaded)
assert DataClass.from_dict({"x": test_case.dumped}) == instance
assert instance.to_dict() == {"x": test_case.dumped}


def test_union_encoding():
for variants in permutations((int, float, str, bool)):
for value in (1, 2.0, 3.1, "4", "5.0", True, False):
encoded = encode(value, Union[variants])
assert value == encoded
assert same_types(value, encoded)

0 comments on commit f7a70e7

Please sign in to comment.