Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make recursive calls in object_to_json less verbose #3265

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 19 additions & 76 deletions ax/storage/json_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import enum
from collections import OrderedDict
from collections.abc import Callable
from functools import partial
from inspect import isclass
from typing import Any

Expand All @@ -28,19 +29,15 @@
from ax.utils.common.typeutils_torch import torch_type_to_str


# pyre-fixme[3]: Return annotation cannot be `Any`.
def object_to_json( # noqa C901
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
# pyre-ignore[3]: Missing return annotation
def object_to_json(
# pyre-ignore[2]: Missing parameter annotation
obj: Any,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
# pyre-ignore[2, 24]: Missing parameter annotation, Invalid type parameters
encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
] = CORE_ENCODER_REGISTRY,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
# pyre-ignore[2, 24]: Missing parameter annotation, Invalid type parameters
class_encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
] = CORE_CLASS_ENCODER_REGISTRY,
Expand All @@ -59,6 +56,11 @@ def object_to_json( # noqa C901
We then pass each item of the dictionary back into this function to
recursively convert the entire object.
"""
_object_to_json = partial(
object_to_json,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
obj = numpy_type_to_python_type(obj)
_type = type(obj)

Expand All @@ -67,14 +69,7 @@ def object_to_json( # noqa C901
for class_type in class_encoder_registry:
if issubclass(obj, class_type):
obj_dict = class_encoder_registry[class_type](obj)
return {
k: object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for k, v in obj_dict.items()
}
return {k: _object_to_json(v) for k, v in obj_dict.items()}

raise ValueError(
f"{obj} is a class. Add it to the CLASS_ENCODER_REGISTRY "
Expand All @@ -83,87 +78,36 @@ def object_to_json( # noqa C901

if _type in encoder_registry:
obj_dict = encoder_registry[_type](obj)
return {
k: object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for k, v in obj_dict.items()
}

return {k: _object_to_json(v) for k, v in obj_dict.items()}
# Python built-in types + `typing` module types
if _type in (str, int, float, bool, type(None)):
return obj
elif _type is list:
return [
object_to_json(
x,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for x in obj
]
return [_object_to_json(x) for x in obj]
elif _type is tuple:
return tuple(
object_to_json(
x,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for x in obj
)
return tuple(_object_to_json(x) for x in obj)
elif _type is dict:
return {
k: object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for k, v in obj.items()
}
return {k: _object_to_json(v) for k, v in obj.items()}
elif _is_named_tuple(obj):
return {
"__type": _type.__name__,
**{
k: object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for k, v in obj._asdict().items()
},
**{k: _object_to_json(v) for k, v in obj._asdict().items()},
}
elif dataclasses.is_dataclass(obj):
field_names = [f.name for f in dataclasses.fields(obj)]
return {
"__type": _type.__name__,
**{
k: object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
k: _object_to_json(v)
for k, v in obj.__dict__.items()
if k in field_names
},
}

# Types from libraries, commonly used in Ax (e.g., numpy, pandas, torch)
elif _type is OrderedDict:
return {
"__type": _type.__name__,
"value": [
(
k,
object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
),
)
for k, v in obj.items()
],
"value": [(k, _object_to_json(v)) for k, v in obj.items()],
}
elif _type is datetime.datetime:
return {
Expand All @@ -183,7 +127,6 @@ def object_to_json( # noqa C901
elif _type.__module__ == "torch":
# Torch does not support saving to string, so save to buffer first
return {"__type": f"torch_{_type.__name__}", "value": torch_type_to_str(obj)}

err = (
f"Object {obj} passed to `object_to_json` (of type {_type}, module: "
f"{_type.__module__}) is not registered with a corresponding encoder "
Expand Down
Loading