diff --git a/chia/util/streamable.py b/chia/util/streamable.py index f48ab329f313..dbe2517ec860 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -74,12 +74,14 @@ class DefinitionError(StreamableError): ParseFunctionType = Callable[[BinaryIO], object] StreamFunctionType = Callable[[object, BinaryIO], None] +ConvertFunctionType = Callable[[object], object] # Caches to store the fields and (de)serialization methods for all available streamable classes. FIELDS_FOR_STREAMABLE_CLASS: Dict[Type[object], Dict[str, Type[object]]] = {} STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[StreamFunctionType]] = {} PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ParseFunctionType]] = {} +CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ConvertFunctionType]] = {} def is_type_List(f_type: object) -> bool: @@ -97,45 +99,105 @@ def is_type_Tuple(f_type: object) -> bool: return get_origin(f_type) == tuple or f_type == tuple -def dataclass_from_dict(klass: Type[Any], d: Any) -> Any: +def convert_optional(convert_func: ConvertFunctionType, item: Any) -> Any: + if item is None: + return None + return convert_func(item) + + +def convert_tuple(convert_funcs: List[ConvertFunctionType], items: Tuple[Any, ...]) -> Tuple[Any, ...]: + tuple_data = [] + for i in range(len(items)): + tuple_data.append(convert_funcs[i](items[i])) + return tuple(tuple_data) + + +def convert_list(convert_func: ConvertFunctionType, items: List[Any]) -> List[Any]: + list_data = [] + for item in items: + list_data.append(convert_func(item)) + return list_data + + +def convert_byte_type(f_type: Type[Any], item: Any) -> Any: + if type(item) == f_type: + return item + return f_type(hexstr_to_bytes(item)) + + +def convert_unhashable_type(f_type: Type[Any], item: Any) -> Any: + if type(item) == f_type: + return item + if hasattr(f_type, "from_bytes_unchecked"): + from_bytes_method = f_type.from_bytes_unchecked + else: + from_bytes_method = f_type.from_bytes + return from_bytes_method(hexstr_to_bytes(item)) + + +def convert_primitive(f_type: Type[Any], item: Any) -> Any: + if type(item) == f_type: + return item + return f_type(item) + + +def dataclass_from_dict(klass: Type[Any], item: Any) -> Any: """ Converts a dictionary based on a dataclass, into an instance of that dataclass. Recursively goes through lists, optionals, and dictionaries. """ - if is_type_SpecificOptional(klass): - # Type is optional, data is either None, or Any - if d is None: - return None - return dataclass_from_dict(get_args(klass)[0], d) - elif is_type_Tuple(klass): - # Type is tuple, can have multiple different types inside - i = 0 - klass_properties = [] - for item in d: - klass_properties.append(dataclass_from_dict(klass.__args__[i], item)) - i = i + 1 - return tuple(klass_properties) - elif dataclasses.is_dataclass(klass): - # Type is a dataclass, data is a dictionary + if type(item) == klass: + return item + obj = object.__new__(klass) + if klass not in CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: + # For non-streamable dataclasses we can't populate the cache on startup, so we do it here for convert + # functions only. + convert_funcs = [] hints = get_type_hints(klass) - fieldtypes = {f.name: hints.get(f.name, f.type) for f in dataclasses.fields(klass)} - return klass(**{f: dataclass_from_dict(fieldtypes[f], d[f]) for f in d}) - elif is_type_List(klass): - # Type is a list, data is a list - return [dataclass_from_dict(get_args(klass)[0], item) for item in d] - elif issubclass(klass, bytes): - # Type is bytes, data is a hex string - return klass(hexstr_to_bytes(d)) - elif klass.__name__ in unhashable_types: + fields = {field.name: hints.get(field.name, field.type) for field in dataclasses.fields(klass)} + + for _, f_type in fields.items(): + convert_funcs.append(function_to_convert_one_item(f_type)) + + FIELDS_FOR_STREAMABLE_CLASS[klass] = fields + CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[klass] = convert_funcs + else: + fields = FIELDS_FOR_STREAMABLE_CLASS[klass] + convert_funcs = CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[klass] + + for field, convert_func in zip(fields, convert_funcs): + object.__setattr__(obj, field, convert_func(item[field])) + return obj + + +def function_to_convert_one_item(f_type: Type[Any]) -> ConvertFunctionType: + if is_type_SpecificOptional(f_type): + convert_inner_func = function_to_convert_one_item(get_args(f_type)[0]) + return lambda item: convert_optional(convert_inner_func, item) + elif is_type_Tuple(f_type): + args = get_args(f_type) + convert_inner_tuple_funcs = [] + for arg in args: + convert_inner_tuple_funcs.append(function_to_convert_one_item(arg)) + # Ignoring for now as the proper solution isn't obvious + return lambda items: convert_tuple(convert_inner_tuple_funcs, items) # type: ignore[arg-type] + elif is_type_List(f_type): + inner_type = get_args(f_type)[0] + convert_inner_func = function_to_convert_one_item(inner_type) + # Ignoring for now as the proper solution isn't obvious + return lambda items: convert_list(convert_inner_func, items) # type: ignore[arg-type] + elif dataclasses.is_dataclass(f_type): + # Type is a dataclass, data is a dictionary + return lambda item: dataclass_from_dict(f_type, item) + elif issubclass(f_type, bytes): + # Type is bytes, data is a hex string or bytes + return lambda item: convert_byte_type(f_type, item) + elif f_type.__name__ in unhashable_types: # Type is unhashable (bls type), so cast from hex string - if hasattr(klass, "from_bytes_unchecked"): - from_bytes_method: Callable[[bytes], Any] = klass.from_bytes_unchecked - else: - from_bytes_method = klass.from_bytes - return from_bytes_method(hexstr_to_bytes(d)) + return lambda item: convert_unhashable_type(f_type, item) else: # Type is a primitive, cast with correct class - return klass(d) + return lambda item: convert_primitive(f_type, item) @overload @@ -339,6 +401,7 @@ class Example(Streamable): stream_functions = [] parse_functions = [] + convert_functions = [] try: hints = get_type_hints(cls) fields = {field.name: hints.get(field.name, field.type) for field in dataclasses.fields(cls)} @@ -350,9 +413,11 @@ class Example(Streamable): for _, f_type in fields.items(): stream_functions.append(cls.function_to_stream_one_item(f_type)) parse_functions.append(cls.function_to_parse_one_item(f_type)) + convert_functions.append(function_to_convert_one_item(f_type)) STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = stream_functions PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = parse_functions + CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = convert_functions return cls diff --git a/tests/core/util/test_streamable.py b/tests/core/util/test_streamable.py index e750da68fb63..223fa26accdb 100644 --- a/tests/core/util/test_streamable.py +++ b/tests/core/util/test_streamable.py @@ -2,9 +2,10 @@ import io from dataclasses import dataclass -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Type import pytest +from blspy import G1Element from clvm_tools import binutils from typing_extensions import Literal @@ -18,6 +19,7 @@ from chia.util.streamable import ( DefinitionError, Streamable, + dataclass_from_dict, is_type_List, is_type_SpecificOptional, parse_bool, @@ -91,6 +93,58 @@ class TestClassPlain(Streamable): a: PlainClass +@dataclass +class TestDataclassFromDict1: + a: int + b: str + c: G1Element + + +@dataclass +class TestDataclassFromDict2: + a: TestDataclassFromDict1 + b: TestDataclassFromDict1 + c: float + + +def test_pure_dataclasses_in_dataclass_from_dict() -> None: + + d1_dict = {"a": 1, "b": "2", "c": str(G1Element())} + + d1: TestDataclassFromDict1 = dataclass_from_dict(TestDataclassFromDict1, d1_dict) + assert d1.a == 1 + assert d1.b == "2" + assert d1.c == G1Element() + + d2_dict = {"a": d1, "b": d1_dict, "c": 1.2345} + + d2: TestDataclassFromDict2 = dataclass_from_dict(TestDataclassFromDict2, d2_dict) + assert d2.a == d1 + assert d2.b == d1 + assert d2.c == 1.2345 + + +@pytest.mark.parametrize( + "test_class, input_dict, error", + [ + [TestDataclassFromDict1, {"a": "asdf", "b": "2", "c": G1Element()}, ValueError], + [TestDataclassFromDict1, {"a": 1, "b": "2"}, KeyError], + [TestDataclassFromDict1, {"a": 1, "b": "2", "c": "asd"}, ValueError], + [TestDataclassFromDict1, {"a": 1, "b": "2", "c": "00" * G1Element.SIZE}, ValueError], + [TestDataclassFromDict1, {"a": [], "b": "2", "c": G1Element()}, TypeError], + [TestDataclassFromDict1, {"a": {}, "b": "2", "c": G1Element()}, TypeError], + [TestDataclassFromDict2, {"a": "asdf", "b": 1.2345, "c": 1.2345}, TypeError], + [TestDataclassFromDict2, {"a": 1.2345, "b": {"a": 1, "b": "2"}, "c": 1.2345}, TypeError], + [TestDataclassFromDict2, {"a": {"a": 1, "b": "2", "c": G1Element()}, "b": {"a": 1, "b": "2"}}, KeyError], + [TestDataclassFromDict2, {"a": {"a": 1, "b": "2"}, "b": {"a": 1, "b": "2"}, "c": 1.2345}, KeyError], + ], +) +def test_dataclass_from_dict_failures(test_class: Type[Any], input_dict: Dict[str, Any], error: Any) -> None: + + with pytest.raises(error): + dataclass_from_dict(test_class, input_dict) + + def test_basic_list() -> None: a = [1, 2, 3] assert is_type_List(type(a))