From 5c5ab111ad9900fec4e4f6c2de50b8036dbb19ad Mon Sep 17 00:00:00 2001 From: dustinface <35775977+xdustinface@users.noreply.github.com> Date: Fri, 3 Jun 2022 19:14:57 +0200 Subject: [PATCH] streamable: Fix default value assignments for `dataclass_from_dict` (#11732) * streamable: Use constructor in `dataclass_from_dict` This fixes default value assignments after #10561 but also leads to less perfomance due to `__post_init__` being called which at least gets mitigated by #11730. * tests: Test default values with `from_json_dict` * Convert to `str`, then compare. --- chia/util/streamable.py | 12 ++++++++---- tests/core/util/test_streamable.py | 30 ++++++++++++++++++++++++++---- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/chia/util/streamable.py b/chia/util/streamable.py index acb86e2be685..d65719addc25 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -132,7 +132,7 @@ def dataclass_from_dict(klass: Type[Any], item: Any) -> Any: """ if type(item) == klass: return item - obj = object.__new__(klass) + if klass not in CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: # For non-streamable dataclasses we can't populate the cache on startup, so we do it here for convert # functions only. @@ -144,9 +144,13 @@ def dataclass_from_dict(klass: Type[Any], item: Any) -> Any: fields = FIELDS_FOR_STREAMABLE_CLASS[klass] convert_funcs = CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[klass] - for field, convert_func in zip(fields, convert_funcs): - object.__setattr__(obj, field.name, convert_func(item[field.name])) - return obj + return klass( + **{ + field.name: convert_func(item[field.name]) + for field, convert_func in zip(fields, convert_funcs) + if field.name in item + } + ) def function_to_convert_one_item(f_type: Type[Any]) -> ConvertFunctionType: diff --git a/tests/core/util/test_streamable.py b/tests/core/util/test_streamable.py index 223fa26accdb..24c0cbf70243 100644 --- a/tests/core/util/test_streamable.py +++ b/tests/core/util/test_streamable.py @@ -1,7 +1,7 @@ from __future__ import annotations import io -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Tuple, Type import pytest @@ -128,15 +128,15 @@ def test_pure_dataclasses_in_dataclass_from_dict() -> None: "test_class, input_dict, error", [ [TestDataclassFromDict1, {"a": "asdf", "b": "2", "c": G1Element()}, ValueError], - [TestDataclassFromDict1, {"a": 1, "b": "2"}, KeyError], + [TestDataclassFromDict1, {"a": 1, "b": "2"}, TypeError], [TestDataclassFromDict1, {"a": 1, "b": "2", "c": "asd"}, ValueError], [TestDataclassFromDict1, {"a": 1, "b": "2", "c": "00" * G1Element.SIZE}, ValueError], [TestDataclassFromDict1, {"a": [], "b": "2", "c": G1Element()}, TypeError], [TestDataclassFromDict1, {"a": {}, "b": "2", "c": G1Element()}, TypeError], [TestDataclassFromDict2, {"a": "asdf", "b": 1.2345, "c": 1.2345}, TypeError], [TestDataclassFromDict2, {"a": 1.2345, "b": {"a": 1, "b": "2"}, "c": 1.2345}, TypeError], - [TestDataclassFromDict2, {"a": {"a": 1, "b": "2", "c": G1Element()}, "b": {"a": 1, "b": "2"}}, KeyError], - [TestDataclassFromDict2, {"a": {"a": 1, "b": "2"}, "b": {"a": 1, "b": "2"}, "c": 1.2345}, KeyError], + [TestDataclassFromDict2, {"a": {"a": 1, "b": "2", "c": G1Element()}, "b": {"a": 1, "b": "2"}}, TypeError], + [TestDataclassFromDict2, {"a": {"a": 1, "b": "2"}, "b": {"a": 1, "b": "2"}, "c": 1.2345}, TypeError], ], ) def test_dataclass_from_dict_failures(test_class: Type[Any], input_dict: Dict[str, Any], error: Any) -> None: @@ -145,6 +145,28 @@ def test_dataclass_from_dict_failures(test_class: Type[Any], input_dict: Dict[st dataclass_from_dict(test_class, input_dict) +@streamable +@dataclass(frozen=True) +class TestFromJsonDictDefaultValues(Streamable): + a: uint64 = uint64(1) + b: str = "default" + c: List[uint64] = field(default_factory=list) + + +@pytest.mark.parametrize( + "input_dict, output_dict", + [ + [{}, {"a": 1, "b": "default", "c": []}], + [{"a": 2}, {"a": 2, "b": "default", "c": []}], + [{"b": "not_default"}, {"a": 1, "b": "not_default", "c": []}], + [{"c": [1, 2]}, {"a": 1, "b": "default", "c": [1, 2]}], + [{"a": 2, "b": "not_default", "c": [1, 2]}, {"a": 2, "b": "not_default", "c": [1, 2]}], + ], +) +def test_from_json_dict_default_values(input_dict: Dict[str, object], output_dict: Dict[str, object]) -> None: + assert str(TestFromJsonDictDefaultValues.from_json_dict(input_dict).to_json_dict()) == str(output_dict) + + def test_basic_list() -> None: a = [1, 2, 3] assert is_type_List(type(a))