Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Efficient caching #37

Merged
merged 3 commits into from
May 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flaxkv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from .core import LevelDBDict, LMDBDict, RemoteDBDict

__version__ = "0.2.8"
__version__ = "0.2.9-alpha"

__all__ = [
"FlaxKV",
Expand Down
94 changes: 72 additions & 22 deletions flaxkv/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def _init(self):
self._static_view = self._db_manager.new_static_view()

self.buffer_dict = {}
self._stat_buffer_num = 0
self.delete_buffer_set = set()

self._buffered_count = 0
Expand Down Expand Up @@ -289,6 +290,8 @@ def get(self, key: Any, default=None):
"""
with self._buffer_lock:
if key in self.delete_buffer_set:
self.delete_buffer_set.discard(key)
self.buffer_dict[key] = default
return default

if key in self.buffer_dict:
Expand All @@ -297,13 +300,16 @@ def get(self, key: Any, default=None):
if self._cache_all_db:
return self._cache_dict.get(key, default)

key = self._encode_key(key)
value = self._static_view.get(key)
_encode_key = self._encode_key(key)
value = self._static_view.get(_encode_key)

if value is None:
self.buffer_dict[key] = default
return default

return value if self._raw else decode(value)
v = value if self._raw else decode(value)
self.buffer_dict[key] = v
return v

def get_db_value(self, key: str):
"""
Expand Down Expand Up @@ -358,6 +364,7 @@ def _set(self, key, value):
with self._buffer_lock:
self.buffer_dict[key] = value
self.delete_buffer_set.discard(key)
self._stat_buffer_num = len(self.buffer_dict)

self._buffered_count += 1
self._last_set_time = time.time()
Expand Down Expand Up @@ -401,7 +408,8 @@ def update(self, d: dict):
self.buffer_dict[key] = value
self.delete_buffer_set.discard(key)

self._buffered_count += 1
self._stat_buffer_num = len(self.buffer_dict)
self._buffered_count += len(d)

self._last_set_time = time.time()
# Trigger immediate write if buffer size exceeds MAX_BUFFER_SIZE
Expand Down Expand Up @@ -480,6 +488,7 @@ def _write_buffer_to_db(
f"write {self._db_manager.db_type.upper()} buffer to db successfully! "
f"current_num={current_write_num} latest_num={self._latest_write_num}"
)
self._stat_buffer_num = len(self.buffer_dict)

def __iter__(self):
"""
Expand Down Expand Up @@ -527,6 +536,9 @@ def __delitem__(self, key):
self._last_set_time = time.time()
if key in self.buffer_dict:
del self.buffer_dict[key]
# If it is in the buffer (possibly obtained through get), then _stat_buffer_num -= 1,
# and _stat_buffer_num can be negative
self._stat_buffer_num -= 1
return
else:
if self._cache_all_db:
Expand All @@ -552,6 +564,7 @@ def pop(self, key, default=None):
self._last_set_time = time.time()
if key in self.buffer_dict:
value = self.buffer_dict.pop(key)
self._stat_buffer_num -= 1
if self._raw:
return decode(value)
else:
Expand Down Expand Up @@ -727,6 +740,13 @@ def keys(self, decode_raw=True):
yield d_key
self._db_manager.close_static_view(view)

def to_dict(self, decode_raw=True):
    """
    Retrieve all key-value pairs as a dict.

    This is a convenience alias: it forwards directly to ``db_dict`` and
    returns whatever that method returns.

    Args:
        decode_raw: Passed through unchanged to ``db_dict``.

    Returns: dict
    """
    return self.db_dict(decode_raw=decode_raw)

def db_dict(self, decode_raw=True):
"""
Retrieves all the key-value pairs in the database and buffer.
Expand Down Expand Up @@ -898,18 +918,26 @@ def set_mapsize(self, map_size):
def stat(self):
if self._cache_all_db:
db_count = len(self._cache_dict)
count = db_count + self._stat_buffer_num
return {
'count': count,
'buffer': self._stat_buffer_num,
'db': db_count,
'marked_delete': len(self.delete_buffer_set),
"type": 'lmdb',
}
else:
env = self._db_manager.get_env()
stats = env.stat()
db_count = stats['entries']
buffer_count = len(self.buffer_dict.keys())
count = db_count + buffer_count
return {
'count': count,
'buffer': buffer_count,
'db': db_count,
'marked_delete': len(self.delete_buffer_set),
}
count = db_count + self._stat_buffer_num - len(self.delete_buffer_set)
return {
'count': count,
'buffer': self._stat_buffer_num,
'db': db_count,
'marked_delete': len(self.delete_buffer_set),
"type": 'lmdb',
}


class LevelDBDict(BaseDBDict):
Expand Down Expand Up @@ -959,10 +987,18 @@ def _iter_db_view(self, view, include_key=True, include_value=True):
yield key_or_value

def stat(self):
buffer_keys = set(self.buffer_dict.keys())

if self._cache_all_db:
db_keys = set(self._cache_dict.keys())
db_count = len(db_keys)
count = db_count + self._stat_buffer_num
return {
'count': count,
'buffer': self._stat_buffer_num,
'db': db_count,
'marked_delete': len(self.delete_buffer_set),
"type": 'leveldb',
}
else:
with self._buffer_lock:
view = self._db_manager.new_static_view()
Expand All @@ -971,12 +1007,21 @@ def stat(self):
view.close()

db_count = len(db_keys)
db_valid_keys = db_keys - self.delete_buffer_set
intersection_count = len(buffer_keys.intersection(db_valid_keys))
buffer_count = len(buffer_keys)
count = len(db_valid_keys) + buffer_count - intersection_count

return {'count': count, 'buffer': buffer_count, "db": db_count}
# db_valid_keys = db_keys - self.delete_buffer_set
# buffer_keys = set(self.buffer_dict.keys())
# intersection_count = len(buffer_keys.intersection(db_valid_keys))
# count = len(db_valid_keys) + self._stat_buffer_num - intersection_count
count = db_count + self._stat_buffer_num - len(self.delete_buffer_set)

# db_valid_keys = db_keys.union(buffer_keys) - self.delete_buffer_set
# count = len(db_valid_keys)
return {
'count': count,
'buffer': self._stat_buffer_num,
"db": db_count,
'marked_delete': len(self.delete_buffer_set),
'type': 'leveldb',
}


class RemoteDBDict(BaseDBDict):
Expand Down Expand Up @@ -1154,17 +1199,22 @@ def db_dict(self, decode_raw=True):
def stat(self):
if self._cache_all_db:
db_count = len(self._cache_dict)
buffer_num = self._stat_buffer_num
count = db_count + buffer_num
else:
# fixme:
env = self._db_manager.get_env()
stats = env.stat()
db_count = stats['count']
buffer_count = len(self.buffer_dict.keys())
count = db_count + buffer_count
buffer_num = self._stat_buffer_num
count = db_count + buffer_num - len(self.delete_buffer_set)

return {
'count': count,
'buffer': buffer_count,
'buffer': buffer_num,
'db': db_count,
'marked_delete': len(self.delete_buffer_set),
'type': 'remote',
}

def __repr__(self):
Expand Down
41 changes: 40 additions & 1 deletion flaxkv/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,18 @@
# limitations under the License.


from __future__ import annotations

import logging
import time
from functools import wraps
from typing import TYPE_CHECKING

from rich import print
from rich.text import Text

from .pack import encode
if TYPE_CHECKING:
from flaxkv import FlaxKV

ENABLED_MEASURE_TIME_DECORATOR = True

Expand Down Expand Up @@ -55,6 +59,8 @@ def wrapper(self, *args, **kwargs):


def msg_encoder(func):
from .pack import encode

@wraps(func)
async def wrapper(*args, **kwargs):
result = await func(*args, **kwargs)
Expand Down Expand Up @@ -99,3 +105,36 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


def cache(db: "FlaxKV | None" = None):
    """Keep a cache of previous function calls.

    Args:
        db: Mapping-like store used to hold results. It must support
            ``key in db``, ``db[key]`` and ``db[key] = value``. When omitted
            (``None``), a fresh in-memory ``dict`` private to this decorator
            is used.

    Returns:
        A decorator that memoizes the wrapped function in ``db``.

    Note:
        The cache key is ``(args, tuple(sorted(kwargs.items())))``. It does
        not include the wrapped function's identity, so two functions sharing
        the same ``db`` and called with the same arguments share one entry.
    """
    if db is None:
        db = {}

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Sort kwargs so that f(a=1, b=2) and f(b=2, a=1) hit the same entry.
            key = (args, tuple(sorted(kwargs.items())))
            # Membership test (rather than a sentinel/get) so cached ``None``
            # results are still served from the store.
            if key in db:
                return db[key]
            result = func(*args, **kwargs)
            db[key] = result
            return result

        return wrapper

    return decorator


def singleton(cls):
    """Class decorator that makes ``cls`` yield a single shared instance.

    The first call constructs ``cls(*args, **kwargs)`` and stores it; every
    later call returns that same object and ignores its arguments.

    Args:
        cls: The class to wrap.

    Returns:
        A factory function (carrying ``cls``'s metadata via ``wraps``) that
        returns the shared instance.
    """
    instances = {}

    @wraps(cls)
    def get_instance(*args, **kwargs):
        # Only the first call's arguments are ever passed to the constructor.
        if cls not in instances:
            instances[cls] = cls(*args, **kwargs)
        return instances[cls]

    return get_instance
8 changes: 5 additions & 3 deletions flaxkv/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ def check_pandas_type(obj):
return type(obj).__name__ == "DataFrame"


# def check_ext_type(obj):
# return isinstance(obj, (tuple, set))


def encode_hook(obj):
if isinstance(obj, np.ndarray):
return msgspec.msgpack.Ext(
Expand All @@ -54,8 +58,7 @@ def encode_hook(obj):
NPArray(dtype=obj.dtype.str, shape=obj.shape, data=obj.data)
),
)
elif check_pandas_type(obj):
# return msgspec.msgpack.Ext(2, pyarrow.serialize_pandas(obj).to_pybytes())
else:
return msgspec.msgpack.Ext(2, pickle.dumps(obj))
return obj

Expand All @@ -67,7 +70,6 @@ def ext_hook(type, data: memoryview):
serialized_array_rep.data, dtype=serialized_array_rep.dtype
).reshape(serialized_array_rep.shape)
elif type == 2:
# return pyarrow.deserialize_pandas(pyarrow.py_buffer(data.tobytes()))
return pickle.loads(data.tobytes())
return data

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ test = [
"litestar>=2.5.0",
"pytest",
"pytest-aiohttp",
"sparrow-python",
"uvicorn",
"httpx[http2]",
"pandas",
Expand Down
Loading
Loading