From 441849ab704a33c802ae1d1346e389241bb36e50 Mon Sep 17 00:00:00 2001 From: xdustinface Date: Sat, 5 Mar 2022 02:19:00 +0100 Subject: [PATCH 1/4] streamable: Cache convert functions for dict -> dataclass conversion --- chia/util/streamable.py | 127 ++++++++++++++++++++++++++++++---------- 1 file changed, 96 insertions(+), 31 deletions(-) diff --git a/chia/util/streamable.py b/chia/util/streamable.py index f48ab329f313..6b6e7b89d96a 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -74,12 +74,14 @@ class DefinitionError(StreamableError): ParseFunctionType = Callable[[BinaryIO], object] StreamFunctionType = Callable[[object, BinaryIO], None] +ConvertFunctionType = Callable[[object], object] # Caches to store the fields and (de)serialization methods for all available streamable classes. FIELDS_FOR_STREAMABLE_CLASS: Dict[Type[object], Dict[str, Type[object]]] = {} STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[StreamFunctionType]] = {} PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ParseFunctionType]] = {} +CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[Any], List[ConvertFunctionType]] = {} def is_type_List(f_type: object) -> bool: @@ -97,45 +99,105 @@ def is_type_Tuple(f_type: object) -> bool: return get_origin(f_type) == tuple or f_type == tuple -def dataclass_from_dict(klass: Type[Any], d: Any) -> Any: +def convert_optional(convert_func: ConvertFunctionType, item: Any) -> Any: + if item is None: + return None + return convert_func(item) + + +def convert_tuple(convert_funcs: List[ConvertFunctionType], items: Tuple[Any, ...]) -> Tuple[Any, ...]: + tuple_data = [] + for i in range(len(items)): + tuple_data.append(convert_funcs[i](items[i])) + return tuple(tuple_data) + + +def convert_list(convert_func: ConvertFunctionType, items: List[Any]) -> List[Any]: + list_data = [] + for item in items: + list_data.append(convert_func(item)) + return list_data + + +def convert_byte_type(f_type: Type[Any], item: Any) -> Any: + if type(item) == f_type: + return item + return f_type(hexstr_to_bytes(item)) + + +def convert_unhashable_type(f_type: Type[Any], item: Any) -> Any: + if type(item) == f_type: + return item + if hasattr(f_type, "from_bytes_unchecked"): + from_bytes_method = f_type.from_bytes_unchecked + else: + from_bytes_method = f_type.from_bytes + return from_bytes_method(hexstr_to_bytes(item)) + + +def convert_primitive(f_type: Type[Any], item: Any) -> Any: + if type(item) == f_type: + return item + return f_type(item) + + +def dataclass_from_dict(klass: Type[Any], item: Any) -> Any: + if type(item) == klass: + return item + obj = object.__new__(klass) + if klass not in CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: + # For non-streamable dataclasses we can't populate the cache on startup, so we do it here for convert + # functions only. + convert_funcs = [] + hints = get_type_hints(klass) + fields = {field.name: hints.get(field.name, field.type) for field in dataclasses.fields(klass)} + + for _, f_type in fields.items(): + convert_funcs.append(function_to_convert_one_item(f_type)) + + FIELDS_FOR_STREAMABLE_CLASS[klass] = fields + CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[klass] = convert_funcs + else: + fields = FIELDS_FOR_STREAMABLE_CLASS[klass] + convert_funcs = CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[klass] + + for field, convert_func in zip(fields, convert_funcs): + object.__setattr__(obj, field, convert_func(item[field])) + return obj + + +def function_to_convert_one_item(f_type: Type[Any]) -> ConvertFunctionType: """ Converts a dictionary based on a dataclass, into an instance of that dataclass. Recursively goes through lists, optionals, and dictionaries. """ - if is_type_SpecificOptional(klass): - # Type is optional, data is either None, or Any - if d is None: - return None - return dataclass_from_dict(get_args(klass)[0], d) - elif is_type_Tuple(klass): - # Type is tuple, can have multiple different types inside - i = 0 - klass_properties = [] - for item in d: - klass_properties.append(dataclass_from_dict(klass.__args__[i], item)) - i = i + 1 - return tuple(klass_properties) - elif dataclasses.is_dataclass(klass): + if is_type_SpecificOptional(f_type): + convert_inner_func = function_to_convert_one_item(get_args(f_type)[0]) + return lambda item: convert_optional(convert_inner_func, item) + elif is_type_Tuple(f_type): + args = get_args(f_type) + convert_inner_tuple_funcs = [] + for arg in args: + convert_inner_tuple_funcs.append(function_to_convert_one_item(arg)) + # Ignoring for now as the proper solution isn't obvious + return lambda items: convert_tuple(convert_inner_tuple_funcs, items) # type: ignore[arg-type] + elif is_type_List(f_type): + inner_type = get_args(f_type)[0] + convert_inner_func = function_to_convert_one_item(inner_type) + # Ignoring for now as the proper solution isn't obvious + return lambda items: convert_list(convert_inner_func, items) # type: ignore[arg-type] + elif dataclasses.is_dataclass(f_type): # Type is a dataclass, data is a dictionary - hints = get_type_hints(klass) - fieldtypes = {f.name: hints.get(f.name, f.type) for f in dataclasses.fields(klass)} - return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d}) - elif is_type_List(klass): - # Type is a list, data is a list - return [dataclass_from_dict(get_args(klass)[0], item) for item in d] - elif issubclass(klass, bytes): - # Type is bytes, data is a hex string - return klass(hexstr_to_bytes(d)) - elif klass.__name__ in unhashable_types: + return lambda item: dataclass_from_dict(f_type, item) + elif issubclass(f_type, bytes): + # Type is bytes, data is a hex string or bytes + return lambda item: convert_byte_type(f_type, item) + elif f_type.__name__ in unhashable_types: # Type is unhashable (bls type), so cast from hex string - if hasattr(klass, "from_bytes_unchecked"): - from_bytes_method: Callable[[bytes], Any] = klass.from_bytes_unchecked - else: - from_bytes_method = klass.from_bytes - return from_bytes_method(hexstr_to_bytes(d)) + return lambda item: convert_unhashable_type(f_type, item) else: # Type is a primitive, cast with correct class - return klass(d) + return lambda item: convert_primitive(f_type, item) @overload @@ -339,6 +401,7 @@ class Example(Streamable): stream_functions = [] parse_functions = [] + convert_functions = [] try: hints = get_type_hints(cls) fields = {field.name: hints.get(field.name, field.type) for field in dataclasses.fields(cls)} @@ -350,9 +413,11 @@ class Example(Streamable): for _, f_type in fields.items(): stream_functions.append(cls.function_to_stream_one_item(f_type)) parse_functions.append(cls.function_to_parse_one_item(f_type)) + convert_functions.append(function_to_convert_one_item(f_type)) STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = stream_functions PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = parse_functions + CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = convert_functions return cls From 01d87129862f6e19f86a0eda0546ac59c80ed41e Mon Sep 17 00:00:00 2001 From: xdustinface Date: Wed, 9 Mar 2022 21:55:22 +0100 Subject: [PATCH 2/4] tests: Test `dataclass_from_dict` with non-streamable classes --- tests/core/util/test_streamable.py | 56 +++++++++++++++++++++++++++++- 1 file changed, 55 insertions(+), 1 deletion(-) diff --git a/tests/core/util/test_streamable.py b/tests/core/util/test_streamable.py index e750da68fb63..223fa26accdb 100644 --- a/tests/core/util/test_streamable.py +++ b/tests/core/util/test_streamable.py @@ -2,9 +2,10 @@ import io from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type import pytest +from blspy import G1Element from clvm_tools import binutils from typing_extensions import Literal @@ -18,6 +19,7 @@ from chia.util.streamable import ( DefinitionError, Streamable, + dataclass_from_dict, is_type_List, is_type_SpecificOptional, parse_bool, @@ -91,6 +93,58 @@ class TestClassPlain(Streamable): a: PlainClass +@dataclass +class TestDataclassFromDict1: + a: int + b: str + c: G1Element + + +@dataclass +class TestDataclassFromDict2: + a: TestDataclassFromDict1 + b: TestDataclassFromDict1 + c: float + + +def test_pure_dataclasses_in_dataclass_from_dict() -> None: + + d1_dict = {"a": 1, "b": "2", "c": str(G1Element())} + + d1: TestDataclassFromDict1 = dataclass_from_dict(TestDataclassFromDict1, d1_dict) + assert d1.a == 1 + assert d1.b == "2" + assert d1.c == G1Element() + + d2_dict = {"a": d1, "b": d1_dict, "c": 1.2345} + + d2: TestDataclassFromDict2 = dataclass_from_dict(TestDataclassFromDict2, d2_dict) + assert d2.a == d1 + assert d2.b == d1 + assert d2.c == 1.2345 + + +@pytest.mark.parametrize( + "test_class, input_dict, error", + [ + [TestDataclassFromDict1, {"a": "asdf", "b": "2", "c": G1Element()}, ValueError], + [TestDataclassFromDict1, {"a": 1, "b": "2"}, KeyError], + [TestDataclassFromDict1, {"a": 1, "b": "2", "c": "asd"}, ValueError], + [TestDataclassFromDict1, {"a": 1, "b": "2", "c": "00" * G1Element.SIZE}, ValueError], + [TestDataclassFromDict1, {"a": [], "b": "2", "c": G1Element()}, TypeError], + [TestDataclassFromDict1, {"a": {}, "b": "2", "c": G1Element()}, TypeError], + [TestDataclassFromDict2, {"a": "asdf", "b": 1.2345, "c": 1.2345}, TypeError], + [TestDataclassFromDict2, {"a": 1.2345, "b": {"a": 1, "b": "2"}, "c": 1.2345}, TypeError], + [TestDataclassFromDict2, {"a": {"a": 1, "b": "2", "c": G1Element()}, "b": {"a": 1, "b": "2"}}, KeyError], + [TestDataclassFromDict2, {"a": {"a": 1, "b": "2"}, "b": {"a": 1, "b": "2"}, "c": 1.2345}, KeyError], + ], +) +def test_dataclass_from_dict_failures(test_class: Type[Any], input_dict: Dict[str, Any], error: Any) -> None: + + with pytest.raises(error): + dataclass_from_dict(test_class, input_dict) + + def test_basic_list() -> None: a = [1, 2, 3] assert is_type_List(type(a)) From 2c723dc20ad65226c5ed9e4421f4ef9a12859c3e Mon Sep 17 00:00:00 2001 From: dustinface <35775977+xdustinface@users.noreply.github.com> Date: Fri, 20 May 2022 00:41:28 +0200 Subject: [PATCH 3/4] `Any` -> `object` Co-authored-by: Kyle Altendorf --- chia/util/streamable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chia/util/streamable.py b/chia/util/streamable.py index 6b6e7b89d96a..166c4135e810 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -81,7 +81,7 @@ class DefinitionError(StreamableError): FIELDS_FOR_STREAMABLE_CLASS: Dict[Type[object], Dict[str, Type[object]]] = {} STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[StreamFunctionType]] = {} PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ParseFunctionType]] = {} -CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[Any], List[ConvertFunctionType]] = {} +CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ConvertFunctionType]] = {} def is_type_List(f_type: object) -> bool: From 439ed9c0026d9fa8f1053c886fa2cfe6400e3652 Mon Sep 17 00:00:00 2001 From: xdustinface Date: Fri, 20 May 2022 00:42:57 +0200 Subject: [PATCH 4/4] Move comment into `dataclass_from_dict` --- chia/util/streamable.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/chia/util/streamable.py b/chia/util/streamable.py index 166c4135e810..dbe2517ec860 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -142,6 +142,10 @@ def convert_primitive(f_type: Type[Any], item: Any) -> Any: def dataclass_from_dict(klass: Type[Any], item: Any) -> Any: + """ + Converts a dictionary based on a dataclass, into an instance of that dataclass. + Recursively goes through lists, optionals, and dictionaries. + """ if type(item) == klass: return item obj = object.__new__(klass) @@ -167,10 +171,6 @@ def dataclass_from_dict(klass: Type[Any], item: Any) -> Any: def function_to_convert_one_item(f_type: Type[Any]) -> ConvertFunctionType: - """ - Converts a dictionary based on a dataclass, into an instance of that dataclass. - Recursively goes through lists, optionals, and dictionaries. - """ if is_type_SpecificOptional(f_type): convert_inner_func = function_to_convert_one_item(get_args(f_type)[0]) return lambda item: convert_optional(convert_inner_func, item)