Skip to content

Commit

Permalink
Make recursive calls in object_to_json less verbose (facebook#3265)
Browse files Browse the repository at this point in the history
Summary:

As titled

Differential Revision: D56764721
  • Loading branch information
ItsMrLin authored and facebook-github-bot committed Jan 24, 2025
1 parent 26170b9 commit c306d35
Showing 1 changed file with 19 additions and 76 deletions.
95 changes: 19 additions & 76 deletions ax/storage/json_store/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import enum
from collections import OrderedDict
from collections.abc import Callable
from functools import partial
from inspect import isclass
from typing import Any

Expand All @@ -28,19 +29,15 @@
from ax.utils.common.typeutils_torch import torch_type_to_str


# pyre-fixme[3]: Return annotation cannot be `Any`.
def object_to_json( # noqa C901
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
# pyre-ignore[3]: Missing return annotation
def object_to_json(
# pyre-ignore[2]: Missing parameter annotation
obj: Any,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
# pyre-ignore[2, 24]: Missing parameter annotation, Invalid type parameters
encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
] = CORE_ENCODER_REGISTRY,
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
# pyre-fixme[24]: Generic type `type` expects 1 type parameter, use
# `typing.Type` to avoid runtime subscripting errors.
# pyre-ignore[2, 24]: Missing parameter annotation, Invalid type parameters
class_encoder_registry: dict[
type, Callable[[Any], dict[str, Any]]
] = CORE_CLASS_ENCODER_REGISTRY,
Expand All @@ -59,6 +56,11 @@ def object_to_json( # noqa C901
We then pass each item of the dictionary back into this function to
recursively convert the entire object.
"""
_object_to_json = partial(
object_to_json,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
obj = numpy_type_to_python_type(obj)
_type = type(obj)

Expand All @@ -67,14 +69,7 @@ def object_to_json( # noqa C901
for class_type in class_encoder_registry:
if issubclass(obj, class_type):
obj_dict = class_encoder_registry[class_type](obj)
return {
k: object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for k, v in obj_dict.items()
}
return {k: _object_to_json(v) for k, v in obj_dict.items()}

raise ValueError(
f"{obj} is a class. Add it to the CLASS_ENCODER_REGISTRY "
Expand All @@ -83,87 +78,36 @@ def object_to_json( # noqa C901

if _type in encoder_registry:
obj_dict = encoder_registry[_type](obj)
return {
k: object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for k, v in obj_dict.items()
}

return {k: _object_to_json(v) for k, v in obj_dict.items()}
# Python built-in types + `typing` module types
if _type in (str, int, float, bool, type(None)):
return obj
elif _type is list:
return [
object_to_json(
x,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for x in obj
]
return [_object_to_json(x) for x in obj]
elif _type is tuple:
return tuple(
object_to_json(
x,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for x in obj
)
return tuple(_object_to_json(x) for x in obj)
elif _type is dict:
return {
k: object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for k, v in obj.items()
}
return {k: _object_to_json(v) for k, v in obj.items()}
elif _is_named_tuple(obj):
return {
"__type": _type.__name__,
**{
k: object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
for k, v in obj._asdict().items()
},
**{k: _object_to_json(v) for k, v in obj._asdict().items()},
}
elif dataclasses.is_dataclass(obj):
field_names = [f.name for f in dataclasses.fields(obj)]
return {
"__type": _type.__name__,
**{
k: object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
)
k: _object_to_json(v)
for k, v in obj.__dict__.items()
if k in field_names
},
}

# Types from libraries, commonly used in Ax (e.g., numpy, pandas, torch)
elif _type is OrderedDict:
return {
"__type": _type.__name__,
"value": [
(
k,
object_to_json(
v,
encoder_registry=encoder_registry,
class_encoder_registry=class_encoder_registry,
),
)
for k, v in obj.items()
],
"value": [(k, _object_to_json(v)) for k, v in obj.items()],
}
elif _type is datetime.datetime:
return {
Expand All @@ -183,7 +127,6 @@ def object_to_json( # noqa C901
elif _type.__module__ == "torch":
# Torch does not support saving to string, so save to buffer first
return {"__type": f"torch_{_type.__name__}", "value": torch_type_to_str(obj)}

err = (
f"Object {obj} passed to `object_to_json` (of type {_type}, module: "
f"{_type.__module__}) is not registered with a corresponding encoder "
Expand Down

0 comments on commit c306d35

Please sign in to comment.