Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

streamable: Cache convert functions from dataclass_from_dict #10561

Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 96 additions & 31 deletions chia/util/streamable.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,14 @@ class DefinitionError(StreamableError):

# A parse function takes a `BinaryIO` and returns an object.
ParseFunctionType = Callable[[BinaryIO], object]
# A stream function takes an object plus a `BinaryIO` and returns nothing.
StreamFunctionType = Callable[[object, BinaryIO], None]
# A convert function takes one object and returns one object (see `function_to_convert_one_item`).
ConvertFunctionType = Callable[[object], object]


# Caches to store the fields and the stream/parse/convert functions for all available streamable classes.
# Each function list holds one entry per field, in the same order as the fields of the class.
# `dataclass_from_dict` additionally fills the fields and convert function caches for non-streamable dataclasses
# the first time it sees them.
FIELDS_FOR_STREAMABLE_CLASS: Dict[Type[object], Dict[str, Type[object]]] = {}
STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[StreamFunctionType]] = {}
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[object], List[ParseFunctionType]] = {}
CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS: Dict[Type[Any], List[ConvertFunctionType]] = {}
xdustinface marked this conversation as resolved.
Show resolved Hide resolved


def is_type_List(f_type: object) -> bool:
Expand All @@ -97,45 +99,105 @@ def is_type_Tuple(f_type: object) -> bool:
return get_origin(f_type) == tuple or f_type == tuple


def dataclass_from_dict(klass: Type[Any], d: Any) -> Any:
def convert_optional(convert_func: ConvertFunctionType, item: Any) -> Any:
    """
    Convert the value of an optional field.

    `None` is passed through untouched, every other value is handed to `convert_func`.
    """
    return None if item is None else convert_func(item)


def convert_tuple(convert_funcs: List[ConvertFunctionType], items: Tuple[Any, ...]) -> Tuple[Any, ...]:
    """
    Convert every element of `items` with the convert function at the same position.

    `convert_funcs` holds one convert function per tuple position and `items` holds the raw elements. The result
    is a tuple with the converted elements in their original order.

    Raises `ValueError` if the number of items differs from the number of convert functions.
    """
    expected_size, actual_size = len(convert_funcs), len(items)
    if expected_size != actual_size:
        # Without this check too few items would silently produce a tuple which is shorter than its declared
        # type and too many items would surface as a bare `IndexError`.
        raise ValueError(f"Invalid tuple size, expected {expected_size} items but got {actual_size}")
    return tuple(convert_func(item) for convert_func, item in zip(convert_funcs, items))


def convert_list(convert_func: ConvertFunctionType, items: List[Any]) -> List[Any]:
    """
    Apply `convert_func` to every element of `items` and return the results as a new list, in the same order.
    """
    return [convert_func(entry) for entry in items]


def convert_byte_type(f_type: Type[Any], item: Any) -> Any:
    """
    Convert `item` into the bytes based type `f_type`.

    An item whose type is exactly `f_type` is returned as is. Anything else is passed through
    `hexstr_to_bytes` and the result is wrapped in `f_type`.
    """
    return item if type(item) == f_type else f_type(hexstr_to_bytes(item))


def convert_unhashable_type(f_type: Type[Any], item: Any) -> Any:
    """
    Convert `item` into the unhashable type `f_type`.

    An item whose type is exactly `f_type` is returned as is. Anything else is passed through
    `hexstr_to_bytes` and the resulting bytes are handed to `f_type.from_bytes_unchecked` if the type provides
    it, otherwise to `f_type.from_bytes`.
    """
    if type(item) == f_type:
        return item
    # Resolve the constructor before touching the item, preferring the unchecked variant.
    method_name = "from_bytes_unchecked" if hasattr(f_type, "from_bytes_unchecked") else "from_bytes"
    return getattr(f_type, method_name)(hexstr_to_bytes(item))


def convert_primitive(f_type: Type[Any], item: Any) -> Any:
    """
    Cast `item` to the primitive type `f_type`.

    An item whose type is exactly `f_type` is returned as is, anything else is cast by calling `f_type(item)`.
    """
    return item if type(item) == f_type else f_type(item)


def dataclass_from_dict(klass: Type[Any], item: Any) -> Any:
    """
    Build an instance of the dataclass `klass` from `item`.

    An item whose type is exactly `klass` is returned unchanged. Otherwise `item` is indexed by field name and
    every value is run through the convert function cached for that field. The instance is created with
    `object.__new__` and filled with `object.__setattr__`, so `klass.__init__` is not called.
    """
    if type(item) == klass:
        return item
    instance = object.__new__(klass)
    if klass in CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS:
        fields = FIELDS_FOR_STREAMABLE_CLASS[klass]
        convert_funcs = CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[klass]
    else:
        # For non-streamable dataclasses we can't populate the cache on startup, so we do it here for convert
        # functions only.
        type_hints = get_type_hints(klass)
        fields = {field.name: type_hints.get(field.name, field.type) for field in dataclasses.fields(klass)}
        convert_funcs = [function_to_convert_one_item(field_type) for field_type in fields.values()]

        FIELDS_FOR_STREAMABLE_CLASS[klass] = fields
        CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[klass] = convert_funcs

    for field_name, convert_func in zip(fields, convert_funcs):
        object.__setattr__(instance, field_name, convert_func(item[field_name]))
    return instance


def function_to_convert_one_item(f_type: Type[Any]) -> ConvertFunctionType:
    """
    Build the convert function for one field of type `f_type`.

    The type is inspected once, here, and the returned callable converts a single raw item into a value of that
    type. Optionals, tuples and lists recursively build the convert functions for their inner types, dataclasses
    are delegated to `dataclass_from_dict`.
    """
    if is_type_SpecificOptional(f_type):
        convert_inner_func = function_to_convert_one_item(get_args(f_type)[0])
        return lambda item: convert_optional(convert_inner_func, item)
    elif is_type_Tuple(f_type):
        # Type is tuple, can have multiple different types inside, so build one convert function per position
        convert_inner_tuple_funcs = [function_to_convert_one_item(arg) for arg in get_args(f_type)]
        # Ignoring for now as the proper solution isn't obvious
        return lambda items: convert_tuple(convert_inner_tuple_funcs, items)  # type: ignore[arg-type]
    elif is_type_List(f_type):
        inner_type = get_args(f_type)[0]
        convert_inner_func = function_to_convert_one_item(inner_type)
        # Ignoring for now as the proper solution isn't obvious
        return lambda items: convert_list(convert_inner_func, items)  # type: ignore[arg-type]
    elif dataclasses.is_dataclass(f_type):
        # Type is a dataclass, data is a dictionary
        return lambda item: dataclass_from_dict(f_type, item)
    elif issubclass(f_type, bytes):
        # Type is bytes, data is a hex string or bytes
        return lambda item: convert_byte_type(f_type, item)
    elif f_type.__name__ in unhashable_types:
        # Type is unhashable (bls type), so cast from hex string
        return lambda item: convert_unhashable_type(f_type, item)
    else:
        # Type is a primitive, cast with correct class
        return lambda item: convert_primitive(f_type, item)


@overload
Expand Down Expand Up @@ -339,6 +401,7 @@ class Example(Streamable):

stream_functions = []
parse_functions = []
convert_functions = []
try:
hints = get_type_hints(cls)
fields = {field.name: hints.get(field.name, field.type) for field in dataclasses.fields(cls)}
Expand All @@ -350,9 +413,11 @@ class Example(Streamable):
for _, f_type in fields.items():
stream_functions.append(cls.function_to_stream_one_item(f_type))
parse_functions.append(cls.function_to_parse_one_item(f_type))
convert_functions.append(function_to_convert_one_item(f_type))

STREAM_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = stream_functions
PARSE_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = parse_functions
CONVERT_FUNCTIONS_FOR_STREAMABLE_CLASS[cls] = convert_functions
return cls


Expand Down
56 changes: 55 additions & 1 deletion tests/core/util/test_streamable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

import io
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Type

import pytest
from blspy import G1Element
from clvm_tools import binutils
from typing_extensions import Literal

Expand All @@ -18,6 +19,7 @@
from chia.util.streamable import (
DefinitionError,
Streamable,
dataclass_from_dict,
is_type_List,
is_type_SpecificOptional,
parse_bool,
Expand Down Expand Up @@ -91,6 +93,58 @@ class TestClassPlain(Streamable):
a: PlainClass


# Plain dataclass (only decorated with `@dataclass`) used as input type for the `dataclass_from_dict` tests.
@dataclass
class TestDataclassFromDict1:
    a: int
    b: str
    c: G1Element


# Plain dataclass with two nested `TestDataclassFromDict1` fields, used for the `dataclass_from_dict` tests.
@dataclass
class TestDataclassFromDict2:
    a: TestDataclassFromDict1
    b: TestDataclassFromDict1
    c: float


def test_pure_dataclasses_in_dataclass_from_dict() -> None:
    """Check that `dataclass_from_dict` builds plain dataclasses, including nested ones, from dictionaries."""

    # The `G1Element` field is provided as `str(G1Element())`.
    d1_dict = {"a": 1, "b": "2", "c": str(G1Element())}

    d1: TestDataclassFromDict1 = dataclass_from_dict(TestDataclassFromDict1, d1_dict)
    assert d1.a == 1
    assert d1.b == "2"
    assert d1.c == G1Element()

    # Field "a" gets an already constructed instance, field "b" the raw dictionary; both must compare equal to d1.
    d2_dict = {"a": d1, "b": d1_dict, "c": 1.2345}

    d2: TestDataclassFromDict2 = dataclass_from_dict(TestDataclassFromDict2, d2_dict)
    assert d2.a == d1
    assert d2.b == d1
    assert d2.c == 1.2345


# Each case is [class to build, input dictionary, exception type expected from `dataclass_from_dict`].
@pytest.mark.parametrize(
    "test_class, input_dict, error",
    [
        [TestDataclassFromDict1, {"a": "asdf", "b": "2", "c": G1Element()}, ValueError],
        [TestDataclassFromDict1, {"a": 1, "b": "2"}, KeyError],
        [TestDataclassFromDict1, {"a": 1, "b": "2", "c": "asd"}, ValueError],
        [TestDataclassFromDict1, {"a": 1, "b": "2", "c": "00" * G1Element.SIZE}, ValueError],
        [TestDataclassFromDict1, {"a": [], "b": "2", "c": G1Element()}, TypeError],
        [TestDataclassFromDict1, {"a": {}, "b": "2", "c": G1Element()}, TypeError],
        [TestDataclassFromDict2, {"a": "asdf", "b": 1.2345, "c": 1.2345}, TypeError],
        [TestDataclassFromDict2, {"a": 1.2345, "b": {"a": 1, "b": "2"}, "c": 1.2345}, TypeError],
        [TestDataclassFromDict2, {"a": {"a": 1, "b": "2", "c": G1Element()}, "b": {"a": 1, "b": "2"}}, KeyError],
        [TestDataclassFromDict2, {"a": {"a": 1, "b": "2"}, "b": {"a": 1, "b": "2"}, "c": 1.2345}, KeyError],
    ],
)
def test_dataclass_from_dict_failures(test_class: Type[Any], input_dict: Dict[str, Any], error: Any) -> None:
    """Check that `dataclass_from_dict` raises `error` when building `test_class` from `input_dict`."""

    with pytest.raises(error):
        dataclass_from_dict(test_class, input_dict)


def test_basic_list() -> None:
a = [1, 2, 3]
assert is_type_List(type(a))
Expand Down