diff --git a/chia/util/streamable.py b/chia/util/streamable.py index a073ddb9169a..5b948816e6e9 100644 --- a/chia/util/streamable.py +++ b/chia/util/streamable.py @@ -74,6 +74,7 @@ class Field: STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[StreamFunctionType]] = {} PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ParseFunctionType]] = {} CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ConvertFunctionType]] = {} +POST_INIT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ConvertFunctionType]] = {} def create_fields_cache(cls: Type[object]) -> Tuple[Field, ...]: @@ -239,6 +240,41 @@ def function_to_convert_one_item(f_type: Type[Any]) -> ConvertFunctionType: return lambda item: convert_primitive(f_type, item) +def post_init_process_item(f_type: Type[Any], item: Any) -> object: + if not isinstance(item, f_type): + try: + item = f_type(item) + except (TypeError, AttributeError, ValueError): + if hasattr(f_type, "from_bytes_unchecked"): + from_bytes_method: Callable[[bytes], Any] = f_type.from_bytes_unchecked + else: + from_bytes_method = f_type.from_bytes + try: + item = from_bytes_method(item) + except Exception: + item = from_bytes_method(bytes(item)) + if not isinstance(item, f_type): + raise ValueError(f"Wrong type for {f_type}") + return item + + +def function_to_post_init_process_one_item(f_type: Type[object]) -> ConvertFunctionType: + if is_type_SpecificOptional(f_type): + process_inner_func = function_to_post_init_process_one_item(get_args(f_type)[0]) + return lambda item: convert_optional(process_inner_func, item) + if is_type_Tuple(f_type): + args = get_args(f_type) + process_inner_tuple_funcs = [] + for arg in args: + process_inner_tuple_funcs.append(function_to_post_init_process_one_item(arg)) + return lambda items: convert_tuple(process_inner_tuple_funcs, items) # type: ignore[arg-type] + if is_type_List(f_type): + inner_type = get_args(f_type)[0] + process_inner_func = function_to_post_init_process_one_item(inner_type) + return lambda items: convert_list(process_inner_func, items) # type: ignore[arg-type] + return lambda item: post_init_process_item(f_type, item) + + def recurse_jsonify(d: Any) -> Any: """ Makes bytes objects and unhashable types into strings with 0x, and makes large ints into @@ -506,6 +542,7 @@ class Example(Streamable): stream_functions = [] parse_functions = [] convert_functions = [] + post_init_functions = [] fields = create_fields_cache(cls) FIELDS_FOR_STREAMABLE_CLASS[cls] = fields @@ -514,10 +551,12 @@ class Example(Streamable): stream_functions.append(function_to_stream_one_item(field.type)) parse_functions.append(function_to_parse_one_item(field.type)) convert_functions.append(function_to_convert_one_item(field.type)) + post_init_functions.append(function_to_post_init_process_one_item(field.type)) STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = stream_functions PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = parse_functions CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = convert_functions + POST_INIT_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = post_init_functions return cls @@ -566,64 +605,16 @@ class Streamable: Make sure to use the streamable decorator when inheriting from the Streamable class to prepare the streaming caches. """ - def post_init_parse(self, item: Any, f_name: str, f_type: Type[Any]) -> Any: - if is_type_List(f_type): - collected_list: List[Any] = [] - inner_type: Type[Any] = get_args(f_type)[0] - # wjb assert inner_type != get_args(List)[0] # type: ignore - if not is_type_List(type(item)): - raise ValueError(f"Wrong type for {f_name}, need a list.") - for el in item: - collected_list.append(self.post_init_parse(el, f_name, inner_type)) - return collected_list - if is_type_SpecificOptional(f_type): - if item is None: - return None - else: - inner_type: Type = get_args(f_type)[0] # type: ignore - return self.post_init_parse(item, f_name, inner_type) - if is_type_Tuple(f_type): - collected_list = [] - if not is_type_Tuple(type(item)) and not is_type_List(type(item)): - raise ValueError(f"Wrong type for {f_name}, need a tuple.") - if len(item) != len(get_args(f_type)): - raise ValueError(f"Wrong number of elements in tuple {f_name}.") - for i in range(len(item)): - inner_type = get_args(f_type)[i] - tuple_item = item[i] - collected_list.append(self.post_init_parse(tuple_item, f_name, inner_type)) - return tuple(collected_list) - if not isinstance(item, f_type): - try: - item = f_type(item) - except (TypeError, AttributeError, ValueError): - if hasattr(f_type, "from_bytes_unchecked"): - from_bytes_method: Callable[[bytes], Any] = f_type.from_bytes_unchecked - else: - from_bytes_method = f_type.from_bytes - try: - item = from_bytes_method(item) - except Exception: - item = from_bytes_method(bytes(item)) - if not isinstance(item, f_type): - raise ValueError(f"Wrong type for {f_name}") - return item - def __post_init__(self) -> None: - try: - fields = FIELDS_FOR_STREAMABLE_CLASS[type(self)] - except Exception: - fields = () + + fields = FIELDS_FOR_STREAMABLE_CLASS[type(self)] + process_funcs = POST_INIT_FUNCTIONS_FOR_STREAMABLE_CLASS[type(self)] + data = self.__dict__ - for field in fields: + for field, process_func in zip(fields, process_funcs): if field.name not in data: raise ValueError(f"Field {field.name} not present") - try: - if not isinstance(data[field.name], field.type): - object.__setattr__(self, field.name, self.post_init_parse(data[field.name], field.name, field.type)) - except TypeError: - # Throws a TypeError because we cannot call isinstance for subscripted generics like Optional[int] - object.__setattr__(self, field.name, self.post_init_parse(data[field.name], field.name, field.type)) + object.__setattr__(self, field.name, process_func(data[field.name])) @classmethod def parse(cls: Type[_T_Streamable], f: BinaryIO) -> _T_Streamable: diff --git a/tests/core/util/test_streamable.py b/tests/core/util/test_streamable.py index 4fef08b2cd98..9413c1f6e4e5 100644 --- a/tests/core/util/test_streamable.py +++ b/tests/core/util/test_streamable.py @@ -400,13 +400,13 @@ def validate_item_type(type_in: Type[Any], item: object) -> bool: (PostInitTestClassBasic, (1, "test", b"\00\01", b"\12" * 31, G1Element()), ValueError), (PostInitTestClassBasic, (1, "test", b"\00\01", b"\12" * 32, b"\12" * 10), ValueError), (PostInitTestClassBad, (1, 2), TypeError), - (PostInitTestClassList, ({"1": 1}, [[uint8(200), uint8(25)], [uint8(25)]]), ValueError), - (PostInitTestClassList, (("1", 1), [[uint8(200), uint8(25)], [uint8(25)]]), ValueError), - (PostInitTestClassList, ([1, 2, 3], [uint8(200), uint8(25)]), ValueError), + (PostInitTestClassList, ({"1": 1}, [[uint8(200), uint8(25)], [uint8(25)]]), TypeError), + (PostInitTestClassList, (("1", 1), [[uint8(200), uint8(25)], [uint8(25)]]), TypeError), + (PostInitTestClassList, ([1, 2, 3], [uint8(200), uint8(25)]), TypeError), (PostInitTestClassTuple, ((1,), ((200, "test_2"), b"\xba" * 32)), ValueError), (PostInitTestClassTuple, ((1, "test", 1), ((200, "test_2"), b"\xba" * 32)), ValueError), (PostInitTestClassTuple, ((1, "test"), ({"a": 2}, b"\xba" * 32)), ValueError), - (PostInitTestClassTuple, ((1, "test"), (G1Element(), b"\xba" * 32)), ValueError), + (PostInitTestClassTuple, ((1, "test"), (G1Element(), b"\xba" * 32)), TypeError), (PostInitTestClassOptional, ([], None, None, None), ValueError), ], )