Skip to content

Commit

Permalink
streamable: Introduce Streamable.__post_init__ processing cache (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
xdustinface authored Jun 28, 2022
1 parent 0fd2dd6 commit 82a83b7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 58 deletions.
99 changes: 45 additions & 54 deletions chia/util/streamable.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class Field:
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[StreamFunctionType]] = {}
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ParseFunctionType]] = {}
CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ConvertFunctionType]] = {}
POST_INIT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ConvertFunctionType]] = {}


def create_fields_cache(cls: Type[object]) -> Tuple[Field, ...]:
Expand Down Expand Up @@ -239,6 +240,41 @@ def function_to_convert_one_item(f_type: Type[Any]) -> ConvertFunctionType:
return lambda item: convert_primitive(f_type, item)


def post_init_process_item(f_type: Type[Any], item: Any) -> object:
if not isinstance(item, f_type):
try:
item = f_type(item)
except (TypeError, AttributeError, ValueError):
if hasattr(f_type, "from_bytes_unchecked"):
from_bytes_method: Callable[[bytes], Any] = f_type.from_bytes_unchecked
else:
from_bytes_method = f_type.from_bytes
try:
item = from_bytes_method(item)
except Exception:
item = from_bytes_method(bytes(item))
if not isinstance(item, f_type):
raise ValueError(f"Wrong type for {f_type}")
return item


def function_to_post_init_process_one_item(f_type: Type[object]) -> ConvertFunctionType:
if is_type_SpecificOptional(f_type):
process_inner_func = function_to_post_init_process_one_item(get_args(f_type)[0])
return lambda item: convert_optional(process_inner_func, item)
if is_type_Tuple(f_type):
args = get_args(f_type)
process_inner_tuple_funcs = []
for arg in args:
process_inner_tuple_funcs.append(function_to_post_init_process_one_item(arg))
return lambda items: convert_tuple(process_inner_tuple_funcs, items) # type: ignore[arg-type]
if is_type_List(f_type):
inner_type = get_args(f_type)[0]
process_inner_func = function_to_post_init_process_one_item(inner_type)
return lambda items: convert_list(process_inner_func, items) # type: ignore[arg-type]
return lambda item: post_init_process_item(f_type, item)


def recurse_jsonify(d: Any) -> Any:
"""
Makes bytes objects and unhashable types into strings with 0x, and makes large ints into
Expand Down Expand Up @@ -506,6 +542,7 @@ class Example(Streamable):
stream_functions = []
parse_functions = []
convert_functions = []
post_init_functions = []

fields = create_fields_cache(cls)
FIELDS_FOR_STREAMABLE_CLASS[cls] = fields
Expand All @@ -514,10 +551,12 @@ class Example(Streamable):
stream_functions.append(function_to_stream_one_item(field.type))
parse_functions.append(function_to_parse_one_item(field.type))
convert_functions.append(function_to_convert_one_item(field.type))
post_init_functions.append(function_to_post_init_process_one_item(field.type))

STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = stream_functions
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = parse_functions
CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = convert_functions
POST_INIT_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = post_init_functions
return cls


Expand Down Expand Up @@ -566,64 +605,16 @@ class Streamable:
Make sure to use the streamable decorator when inheriting from the Streamable class to prepare the streaming caches.
"""

def post_init_parse(self, item: Any, f_name: str, f_type: Type[Any]) -> Any:
if is_type_List(f_type):
collected_list: List[Any] = []
inner_type: Type[Any] = get_args(f_type)[0]
# wjb assert inner_type != get_args(List)[0] # type: ignore
if not is_type_List(type(item)):
raise ValueError(f"Wrong type for {f_name}, need a list.")
for el in item:
collected_list.append(self.post_init_parse(el, f_name, inner_type))
return collected_list
if is_type_SpecificOptional(f_type):
if item is None:
return None
else:
inner_type: Type = get_args(f_type)[0] # type: ignore
return self.post_init_parse(item, f_name, inner_type)
if is_type_Tuple(f_type):
collected_list = []
if not is_type_Tuple(type(item)) and not is_type_List(type(item)):
raise ValueError(f"Wrong type for {f_name}, need a tuple.")
if len(item) != len(get_args(f_type)):
raise ValueError(f"Wrong number of elements in tuple {f_name}.")
for i in range(len(item)):
inner_type = get_args(f_type)[i]
tuple_item = item[i]
collected_list.append(self.post_init_parse(tuple_item, f_name, inner_type))
return tuple(collected_list)
if not isinstance(item, f_type):
try:
item = f_type(item)
except (TypeError, AttributeError, ValueError):
if hasattr(f_type, "from_bytes_unchecked"):
from_bytes_method: Callable[[bytes], Any] = f_type.from_bytes_unchecked
else:
from_bytes_method = f_type.from_bytes
try:
item = from_bytes_method(item)
except Exception:
item = from_bytes_method(bytes(item))
if not isinstance(item, f_type):
raise ValueError(f"Wrong type for {f_name}")
return item

def __post_init__(self) -> None:
try:
fields = FIELDS_FOR_STREAMABLE_CLASS[type(self)]
except Exception:
fields = ()

fields = FIELDS_FOR_STREAMABLE_CLASS[type(self)]
process_funcs = POST_INIT_FUNCTIONS_FOR_STREAMABLE_CLASS[type(self)]

data = self.__dict__
for field in fields:
for field, process_func in zip(fields, process_funcs):
if field.name not in data:
raise ValueError(f"Field {field.name} not present")
try:
if not isinstance(data[field.name], field.type):
object.__setattr__(self, field.name, self.post_init_parse(data[field.name], field.name, field.type))
except TypeError:
# Throws a TypeError because we cannot call isinstance for subscripted generics like Optional[int]
object.__setattr__(self, field.name, self.post_init_parse(data[field.name], field.name, field.type))
object.__setattr__(self, field.name, process_func(data[field.name]))

@classmethod
def parse(cls: Type[_T_Streamable], f: BinaryIO) -> _T_Streamable:
Expand Down
8 changes: 4 additions & 4 deletions tests/core/util/test_streamable.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,13 +400,13 @@ def validate_item_type(type_in: Type[Any], item: object) -> bool:
(PostInitTestClassBasic, (1, "test", b"\00\01", b"\12" * 31, G1Element()), ValueError),
(PostInitTestClassBasic, (1, "test", b"\00\01", b"\12" * 32, b"\12" * 10), ValueError),
(PostInitTestClassBad, (1, 2), TypeError),
(PostInitTestClassList, ({"1": 1}, [[uint8(200), uint8(25)], [uint8(25)]]), ValueError),
(PostInitTestClassList, (("1", 1), [[uint8(200), uint8(25)], [uint8(25)]]), ValueError),
(PostInitTestClassList, ([1, 2, 3], [uint8(200), uint8(25)]), ValueError),
(PostInitTestClassList, ({"1": 1}, [[uint8(200), uint8(25)], [uint8(25)]]), TypeError),
(PostInitTestClassList, (("1", 1), [[uint8(200), uint8(25)], [uint8(25)]]), TypeError),
(PostInitTestClassList, ([1, 2, 3], [uint8(200), uint8(25)]), TypeError),
(PostInitTestClassTuple, ((1,), ((200, "test_2"), b"\xba" * 32)), ValueError),
(PostInitTestClassTuple, ((1, "test", 1), ((200, "test_2"), b"\xba" * 32)), ValueError),
(PostInitTestClassTuple, ((1, "test"), ({"a": 2}, b"\xba" * 32)), ValueError),
(PostInitTestClassTuple, ((1, "test"), (G1Element(), b"\xba" * 32)), ValueError),
(PostInitTestClassTuple, ((1, "test"), (G1Element(), b"\xba" * 32)), TypeError),
(PostInitTestClassOptional, ([], None, None, None), ValueError),
],
)
Expand Down

0 comments on commit 82a83b7

Please sign in to comment.