Skip to content

Commit

Permalink
streamable: Fix default value assignments for dataclass_from_dict (#…
Browse files Browse the repository at this point in the history
…11732)

* streamable: Use constructor in `dataclass_from_dict`

This fixes default value assignments after #10561 but also leads to less perfomance due to `__post_init__` being called which at least gets mitigated by #11730.

* tests: Test default values with `from_json_dict`

* Convert to `str`, then compare.
  • Loading branch information
xdustinface authored Jun 3, 2022
1 parent 1dccb68 commit 5c5ab11
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
12 changes: 8 additions & 4 deletions chia/util/streamable.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def dataclass_from_dict(klass: Type[Any], item: Any) -> Any:
"""
if type(item) == klass:
return item
obj = object.__new__(klass)

if klass not in CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS:
# For non-streamable dataclasses we can't populate the cache on startup, so we do it here for convert
# functions only.
Expand All @@ -144,9 +144,13 @@ def dataclass_from_dict(klass: Type[Any], item: Any) -> Any:
fields = FIELDS_FOR_STREAMABLE_CLASS[klass]
convert_funcs = CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[klass]

for field, convert_func in zip(fields, convert_funcs):
object.__setattr__(obj, field.name, convert_func(item[field.name]))
return obj
return klass(
**{
field.name: convert_func(item[field.name])
for field, convert_func in zip(fields, convert_funcs)
if field.name in item
}
)


def function_to_convert_one_item(f_type: Type[Any]) -> ConvertFunctionType:
Expand Down
30 changes: 26 additions & 4 deletions tests/core/util/test_streamable.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import io
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple, Type

import pytest
Expand Down Expand Up @@ -128,15 +128,15 @@ def test_pure_dataclasses_in_dataclass_from_dict() -> None:
"test_class, input_dict, error",
[
[TestDataclassFromDict1, {"a": "asdf", "b": "2", "c": G1Element()}, ValueError],
[TestDataclassFromDict1, {"a": 1, "b": "2"}, KeyError],
[TestDataclassFromDict1, {"a": 1, "b": "2"}, TypeError],
[TestDataclassFromDict1, {"a": 1, "b": "2", "c": "asd"}, ValueError],
[TestDataclassFromDict1, {"a": 1, "b": "2", "c": "00" * G1Element.SIZE}, ValueError],
[TestDataclassFromDict1, {"a": [], "b": "2", "c": G1Element()}, TypeError],
[TestDataclassFromDict1, {"a": {}, "b": "2", "c": G1Element()}, TypeError],
[TestDataclassFromDict2, {"a": "asdf", "b": 1.2345, "c": 1.2345}, TypeError],
[TestDataclassFromDict2, {"a": 1.2345, "b": {"a": 1, "b": "2"}, "c": 1.2345}, TypeError],
[TestDataclassFromDict2, {"a": {"a": 1, "b": "2", "c": G1Element()}, "b": {"a": 1, "b": "2"}}, KeyError],
[TestDataclassFromDict2, {"a": {"a": 1, "b": "2"}, "b": {"a": 1, "b": "2"}, "c": 1.2345}, KeyError],
[TestDataclassFromDict2, {"a": {"a": 1, "b": "2", "c": G1Element()}, "b": {"a": 1, "b": "2"}}, TypeError],
[TestDataclassFromDict2, {"a": {"a": 1, "b": "2"}, "b": {"a": 1, "b": "2"}, "c": 1.2345}, TypeError],
],
)
def test_dataclass_from_dict_failures(test_class: Type[Any], input_dict: Dict[str, Any], error: Any) -> None:
Expand All @@ -145,6 +145,28 @@ def test_dataclass_from_dict_failures(test_class: Type[Any], input_dict: Dict[st
dataclass_from_dict(test_class, input_dict)


@streamable
@dataclass(frozen=True)
class TestFromJsonDictDefaultValues(Streamable):
a: uint64 = uint64(1)
b: str = "default"
c: List[uint64] = field(default_factory=list)


@pytest.mark.parametrize(
"input_dict, output_dict",
[
[{}, {"a": 1, "b": "default", "c": []}],
[{"a": 2}, {"a": 2, "b": "default", "c": []}],
[{"b": "not_default"}, {"a": 1, "b": "not_default", "c": []}],
[{"c": [1, 2]}, {"a": 1, "b": "default", "c": [1, 2]}],
[{"a": 2, "b": "not_default", "c": [1, 2]}, {"a": 2, "b": "not_default", "c": [1, 2]}],
],
)
def test_from_json_dict_default_values(input_dict: Dict[str, object], output_dict: Dict[str, object]) -> None:
assert str(TestFromJsonDictDefaultValues.from_json_dict(input_dict).to_json_dict()) == str(output_dict)


def test_basic_list() -> None:
a = [1, 2, 3]
assert is_type_List(type(a))
Expand Down

0 comments on commit 5c5ab11

Please sign in to comment.