Skip to content

Commit

Permalink
implementation complete
Browse files Browse the repository at this point in the history
  • Loading branch information
tssweeney committed Sep 17, 2024
1 parent 49d1547 commit d65dc58
Showing 1 changed file with 75 additions and 3 deletions.
78 changes: 75 additions & 3 deletions weave/trace/serialize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import typing
from typing import Any
import weakref
from typing import Any, Optional

from weave.trace import custom_objs
from weave.trace.object_record import ObjectRecord
Expand Down Expand Up @@ -97,13 +99,71 @@ def from_json(obj: Any, project_id: str, server: TraceServerInterface) -> Any:
# _[to/from]_json_custom_weave_type are used to serialize and deserialize
# objects that have custom serialization logic. These are NOT weave.Objects,
# but rather things like PIL Images. These methods are inverses of each other.
# Importantly we can actually cache the results of both directions so that we
# don't have to do the work more than once.


class CustomWeaveTypeSerializationCache:
    """Bidirectional, best-effort cache for custom weave type serialization.

    A caller can:
    - store a pair of (deserialized object, serialized dict)
    - retrieve the serialized dict for a deserialized object
    - retrieve the deserialized object for a serialized dict

    The cache only ever holds weak references to the deserialized objects,
    so it never keeps them alive; entries disappear once the object is
    garbage collected.

    When keying by object:
        The key is the object's ``id`` combined with its ``hash`` when the
        object is hashable. Including the hash means an object whose hash
        changes after being updated no longer hits its old entry. Because
        ``id`` values can be reused once an object is collected, every entry
        also records a weak reference to the object it was stored for, and a
        lookup only hits when that reference still points at the very same
        object.

    When keying by dict:
        The dict is stringified deterministically (JSON with sorted keys), so
        equal dicts hit the same entry regardless of key order.

    Caching is best effort: objects that cannot be weakly referenced (for
    example ``int`` or ``str`` instances) and dicts that cannot be dumped to
    JSON are silently not cached, and lookups for them are misses.
    """

    # object key -> (weak reference to the object, its serialized dict).
    # A plain dict is required here: the serialized dicts themselves cannot be
    # weakly referenced, so they cannot be WeakValueDictionary values.
    _obj_to_dict: dict
    # deterministic JSON string of the serialized dict -> deserialized object
    _dict_to_obj: weakref.WeakValueDictionary

    def __init__(self) -> None:
        self._obj_to_dict = {}
        self._dict_to_obj = weakref.WeakValueDictionary()

    def store(self, obj: Any, serialized_dict: dict) -> None:
        """Record that ``obj`` serializes to ``serialized_dict`` (both ways).

        Does nothing if ``obj`` cannot be weakly referenced. If only the dict
        cannot be turned into a key, the object -> dict direction is still
        stored.
        """
        obj_key = self._get_obj_key(obj)
        table = self._obj_to_dict

        def _evict(dead_ref: Any, key: Any = obj_key, table: dict = table) -> None:
            # Drop the entry when the object dies, but only if the slot still
            # belongs to this reference (it may have been overwritten since).
            entry = table.get(key)
            if entry is not None and entry[0] is dead_ref:
                del table[key]

        try:
            obj_ref = weakref.ref(obj, _evict)
        except TypeError:
            # The object does not support weak references; caching it by id
            # alone would be unsafe, so skip it entirely.
            return
        table[obj_key] = (obj_ref, serialized_dict)

        try:
            dict_key = self._get_dict_key(serialized_dict)
        except (TypeError, ValueError):
            # Not JSON-serializable (or circular): no reverse entry.
            return
        self._dict_to_obj[dict_key] = obj

    def get_serialized_dict(self, obj: Any) -> Optional[dict]:
        """Return the dict stored for ``obj``, or ``None`` on a miss."""
        entry = self._obj_to_dict.get(self._get_obj_key(obj))
        if entry is None:
            return None
        obj_ref, serialized_dict = entry
        # Guard against a recycled id: the entry must be for this exact object.
        if obj_ref() is not obj:
            return None
        return serialized_dict

    def get_deserialized_obj(self, serialized_dict: dict) -> Optional[Any]:
        """Return the live object stored for an equal dict, or ``None``."""
        try:
            dict_key = self._get_dict_key(serialized_dict)
        except (TypeError, ValueError):
            return None
        return self._dict_to_obj.get(dict_key)

    def _get_obj_key(self, obj: Any) -> Any:
        """Build the object-side key: ``(id, hash)`` or just ``id``."""
        try:
            return (id(obj), hash(obj))
        except TypeError:
            # Unhashable object: fall back to identity only.
            return id(obj)

    def _get_dict_key(self, d: dict) -> str:
        """Build the dict-side key; raises if ``d`` is not JSON-serializable."""
        return json.dumps(d, sort_keys=True)


# Module-level cache instance shared by _to_json_custom_weave_type and
# _from_json_custom_weave_type.
_custom_weave_type_cache = CustomWeaveTypeSerializationCache()


def _to_json_custom_weave_type(
obj: Any, project_id: str, server: TraceServerInterface
) -> dict:
# Check if the object is already in the cache
cached_result = _custom_weave_type_cache.get_serialized_dict(obj)
if cached_result is not None:
return cached_result

encoded = custom_objs.encode_custom_obj(obj)
if encoded is None:
raise ValueError(f"No encoder for object: {obj}")
Expand All @@ -121,11 +181,23 @@ def _to_json_custom_weave_type(
load_op_uri = encoded.get("load_op")
if load_op_uri:
result["load_op"] = load_op_uri

# Store the result in the cache
_custom_weave_type_cache.store(obj, result)
return result


def _from_json_custom_weave_type(
    obj: dict, project_id: str, server: TraceServerInterface
) -> Any:
    """Decode a serialized custom weave type dict back into its object.

    The module-level serialization cache is consulted first; on a miss the
    referenced files are loaded, the object is decoded, and the result is
    stored in the cache before being returned.

    Args:
        obj: The serialized dict. Must contain ``"files"`` and
            ``"weave_type"``; ``"load_op"`` is optional.
        project_id: Forwarded to ``_load_custom_obj_files``.
        server: Forwarded to ``_load_custom_obj_files``.

    Returns:
        The cached object for an equal dict when one exists, otherwise the
        value returned by ``custom_objs.decode_custom_obj``.

    Raises:
        KeyError: If ``obj`` lacks ``"files"`` or ``"weave_type"``.
    """
    # Check if the serialized dict is already in the cache
    cached_result = _custom_weave_type_cache.get_deserialized_obj(obj)
    if cached_result is not None:
        return cached_result

    files = _load_custom_obj_files(project_id, server, obj["files"])
    result = custom_objs.decode_custom_obj(
        obj["weave_type"], files, obj.get("load_op")
    )

    # Store the result in the cache. Caching is best effort: the cache relies
    # on weak references, and values that cannot be weakly referenced raise
    # TypeError, which must not fail an otherwise successful decode.
    try:
        _custom_weave_type_cache.store(result, obj)
    except TypeError:
        pass
    return result

0 comments on commit d65dc58

Please sign in to comment.