From c32487ecbdf59139e8af737771c82e77329ef9b6 Mon Sep 17 00:00:00 2001 From: "K.Y" Date: Tue, 5 Dec 2023 22:37:02 +0800 Subject: [PATCH] :sparkles: Improve service architecture --- flaxkv/__init__.py | 2 +- flaxkv/{base.py => core.py} | 158 ++++++++++++++++++++---------------- flaxkv/serve/app.py | 96 ++++++++++++++++++++-- flaxkv/serve/client.py | 36 +++++++- flaxkv/serve/interface.py | 17 ++++ flaxkv/serve/manager.py | 5 +- 6 files changed, 233 insertions(+), 81 deletions(-) rename flaxkv/{base.py => core.py} (84%) diff --git a/flaxkv/__init__.py b/flaxkv/__init__.py index 0af9e32..638ca54 100644 --- a/flaxkv/__init__.py +++ b/flaxkv/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .base import LevelDBDict, LMDBDict +from .core import LevelDBDict, LMDBDict from .serve.client import RemoteDictDB __version__ = "0.1.6" diff --git a/flaxkv/base.py b/flaxkv/core.py similarity index 84% rename from flaxkv/base.py rename to flaxkv/core.py index 4c25c1b..1bf6fab 100644 --- a/flaxkv/base.py +++ b/flaxkv/core.py @@ -300,6 +300,8 @@ def update(self, d: dict): raise ValueError("Input must be a dictionary.") with self._buffer_lock: for key, value in d.items(): + if self.raw: + key, value = encode(key), encode(value) self.buffer_dict[key] = value self.delete_buffer_set.discard(key) @@ -425,7 +427,10 @@ def pop(self, key, default=None): self._buffered_count += 1 if key in self.buffer_dict: value = self.buffer_dict.pop(key) - return value + if self.raw: + return decode(value) + else: + return value else: if self.raw: value = self._static_view.get(key) @@ -514,26 +519,56 @@ def close(self, write=True): self._db_manager.close() self._logger.info(f"Closed ({self._db_manager.db_type.upper()}) successfully") - @abstractmethod - def keys(self): + def _get_state_buffer_info( + self, return_key=False, return_value=False, return_dict=False, decode_raw=True + ): + with self._buffer_lock: + static_view = self._db_manager.new_static_view() + buffer_dict = self.buffer_dict.copy() + delete_buffer_set = self.delete_buffer_set.copy() + + buffer_keys, buffer_values = None, None + if return_key: + if self.raw and decode_raw: + buffer_keys = set([decode_key(i) for i in buffer_dict.keys()]) + else: + buffer_keys = set(buffer_dict.keys()) + if return_value: + if self.raw and decode_raw: + buffer_values = list([decode(i) for i in buffer_dict.values()]) + else: + buffer_values = list(self.buffer_dict.values()) + if not return_dict: + buffer_dict = None + else: + if self.raw and decode_raw: + buffer_dict = {decode_key(k): decode(v) for k, v in buffer_dict.items()} + + return buffer_dict, buffer_keys, buffer_values, delete_buffer_set, static_view + + def values(self, decode_raw=True): """ - Retrieves all the keys in the database and buffer. + Retrieves all the values in the database and buffer. Returns: - list: A list of keys. + list: A list of values """ + values_list = [] + for key, value in self.items(decode_raw): + values_list.append(value) + return values_list @abstractmethod - def values(self): + def keys(self, *args, **kwargs): """ - Retrieves all the values in the database and buffer. + Retrieves all the keys in the database and buffer. Returns: - list: A list of values + list: A list of keys. """ @abstractmethod - def items(self): + def items(self, *args, **kwargs): """ Retrieves all the key-value pairs in the database and buffer. @@ -542,7 +577,7 @@ def items(self): """ @abstractmethod - def stat(self): + def stat(self, *args, **kwargs): """ Database statistics @@ -564,12 +599,15 @@ def __init__(self, path: str, map_size=1024**3, rebuild=False, **kwargs): "lmdb", path, max_dbs=1, map_size=map_size, rebuild=rebuild, **kwargs ) - def keys(self): - with self._buffer_lock: - session = self._db_manager.new_static_view() - cursor = session.cursor() - delete_buffer_set = self.delete_buffer_set.copy() - buffer_keys = set(self.buffer_dict.keys()) + def keys(self, decode_raw=True): + ( + buffer_dict, + buffer_keys, + buffer_values, + delete_buffer_set, + session, + ) = self._get_state_buffer_info(return_key=True, decode_raw=decode_raw) + cursor = session.cursor() lmdb_keys = set( decode_key(key) for key in cursor.iternext(keys=True, values=False) @@ -578,40 +616,27 @@ def keys(self): return list(lmdb_keys.union(buffer_keys) - delete_buffer_set) - def values(self): - - with self._buffer_lock: - session = self._db_manager.new_static_view() - cursor = session.cursor() - delete_buffer_set = self.delete_buffer_set.copy() - buffer_values = list(self.buffer_dict.values()) - - lmdb_values = [] - for key, value in cursor.iternext(keys=True, values=True): - dk = decode_key(key) - if dk not in delete_buffer_set: - lmdb_values.append(decode(value)) - - session.abort() - return lmdb_values + buffer_values - - def items(self): - with self._buffer_lock: - session = self._db_manager.new_static_view() - cursor = session.cursor() - buffer_dict = self.buffer_dict.copy() - delete_buffer_set = self.delete_buffer_set.copy() - + def items(self, decode_raw=True): + ( + buffer_dict, + buffer_keys, + buffer_values, + delete_buffer_set, + session, + ) = self._get_state_buffer_info(return_dict=True, decode_raw=decode_raw) + cursor = session.cursor() _db_dict = {} for key, value in cursor.iternext(keys=True, values=True): - dk = decode_key(key) + dk = key if self.raw else decode_key(key) if dk not in delete_buffer_set: + if self.raw: + dk = decode_key(dk) _db_dict[dk] = decode(value) _db_dict.update(buffer_dict) - session.abort() + self._db_manager.close_static_view(session) return _db_dict.items() @@ -649,47 +674,40 @@ class LevelDBDict(BaseDBDict): def __init__(self, path: str, rebuild=False, **kwargs): super().__init__("leveldb", path=path, rebuild=rebuild) - def keys(self): - with self._buffer_lock: - buffer_keys = set(self.buffer_dict.keys()) - snapshot = self._db_manager.new_static_view() + def keys(self, decode_raw=True): + ( + buffer_dict, + buffer_keys, + buffer_values, + delete_buffer_set, + snapshot, + ) = self._get_state_buffer_info(return_key=True, decode_raw=decode_raw) db_keys = set(decode_key(key) for key, _ in snapshot.iterator()) snapshot.close() - return list(db_keys.union(buffer_keys)) + return list(db_keys.union(buffer_keys) - delete_buffer_set) - def values(self): - with self._buffer_lock: - snapshot = self._db_manager.new_static_view() - delete_buffer_set = self.delete_buffer_set.copy() - buffer_values = list(self.buffer_dict.values()) - - db_values = [] - for key, value in snapshot.iterator(): - dk = decode_key(key) - if dk not in delete_buffer_set: - db_values.append(decode(value)) - - snapshot.close() - - return db_values + buffer_values - - def items(self): - with self._buffer_lock: - snapshot = self._db_manager.new_static_view() - delete_buffer_set = self.delete_buffer_set.copy() - buffer_dict = self.buffer_dict.copy() + def items(self, decode_raw=True): + ( + buffer_dict, + buffer_keys, + buffer_values, + delete_buffer_set, + snapshot, + ) = self._get_state_buffer_info(return_dict=True, decode_raw=decode_raw) _db_dict = {} for key, value in snapshot.iterator(): - dk = decode_key(key) + dk = key if self.raw else decode_key(key) if dk not in delete_buffer_set: + if self.raw: + dk = decode_key(dk) _db_dict[dk] = decode(value) _db_dict.update(buffer_dict) - snapshot.close() + self._db_manager.close_static_view(snapshot) return _db_dict.items() def stat(self): diff --git a/flaxkv/serve/app.py b/flaxkv/serve/app.py index debc533..4c9c6bd 100644 --- a/flaxkv/serve/app.py +++ b/flaxkv/serve/app.py @@ -1,11 +1,20 @@ +import traceback + import msgspec from litestar import Litestar, MediaType, Request, get, post -from ..pack import encode -from .interface import AttachRequest, DetachRequest, StructSetData +from ..pack import decode, decode_key, encode +from .interface import ( + AttachRequest, + DetachRequest, + PopKeyRequest, + SetRequest, + StructSetData, + StructUpdateData, +) from .manager import DBManager -db_manager = DBManager(root_path="./FLAXKV_DB") +db_manager = DBManager(root_path="./FLAXKV_DB", raw_mode=True) @post(path="/attach") @@ -26,19 +35,46 @@ async def detach(data: DetachRequest) -> dict: return {"success": True} +@post(path="/set_value") +async def set_value(data: SetRequest) -> dict: + db = db_manager.get(data.db_name) + if db is None: + return {"success": False, "info": "db not found"} + print(data.key, data.value) + print(encode(data.key), encode(data.value)) + db[encode(data.key)] = encode(data.value) + return {"success": True} + + @post(path="/set_raw") async def set_raw(db_name: str, request: Request) -> dict: - print(f"{db_name=}") db = db_manager.get(db_name) if db is None: return {"success": False, "info": "db not found"} data = await request.body() - data = msgspec.msgpack.decode(data, type=StructSetData) + try: + data = msgspec.msgpack.decode(data, type=StructSetData) + except Exception as e: + traceback.print_exc() + return {"success": False, "info": str(e)} db[data.key] = data.value - db.write_immediately(write=True, wait=True) return {"success": True} +@post(path="/update_raw") +async def update_raw(db_name: str, request: Request) -> dict: + db = db_manager.get(db_name) + if db is None: + return {"success": False, "info": "db not found"} + data = await request.body() + try: + db.update(decode(data)) + return {"success": True} + except Exception as e: + traceback.print_exc() + return {"success": False, "info": str(e)} + + @post("/get_raw", media_type=MediaType.TEXT) async def get_raw(db_name: str, request: Request) -> bytes: db = db_manager.get(db_name) @@ -61,12 +97,53 @@ async def contains(db_name: str, request: Request) -> bytes: return encode(is_contain) +@post("/pop") +async def pop(data: PopKeyRequest) -> dict: + db = db_manager.get(data.db_name) + if db is None: + return {"success": False, "info": "db not found"} + try: + return {"success": True, "value": db.pop(encode(data.key), None)} + + except Exception as e: + traceback.print_exc() + return {"success": False, "info": str(e)} + + @get("/keys") async def get_keys(db_name: str) -> dict: db = db_manager.get(db_name) if db is None: return {"success": False, "info": "db not found"} - return {"keys": list(db.keys())} + try: + return {"keys": db.keys()} + except Exception as e: + traceback.print_exc() + return {"success": False, "info": str(e)} + + +@get("/values") +async def get_values(db_name: str) -> dict: + db = db_manager.get(db_name) + if db is None: + return {"success": False, "info": "db not found"} + try: + return {"values": db.values()} + except Exception as e: + traceback.print_exc() + return {"success": False, "info": str(e)} + + +@get("/items") +async def get_items(db_name: str) -> dict: + db = db_manager.get(db_name) + if db is None: + return {"success": False, "info": "db not found"} + try: + return dict(db.items()) + except Exception as e: + traceback.print_exc() + return {"success": False, "info": str(e)} def on_shutdown(): @@ -78,8 +155,13 @@ def on_shutdown(): attach, detach, set_raw, + update_raw, get_raw, + get_items, + get_values, + set_value, contains, + pop, get_keys, ], on_startup=[lambda: print("on_startup")], diff --git a/flaxkv/serve/client.py b/flaxkv/serve/client.py index 567dc44..8de9d6d 100644 --- a/flaxkv/serve/client.py +++ b/flaxkv/serve/client.py @@ -36,13 +36,47 @@ def set(self, key, value): url = f"{self._url}/set_raw?db_name={self._db_name}" data = {"key": encode(key), "value": encode(value)} response = self._client.post(url, data=encode(data)) - return response.read() + return response.json() + + def update(self, d: dict): + url = f"{self._url}/update_raw?db_name={self._db_name}" + response = self._client.post(url, data=encode(d)) + return response.json() + + def pop(self, key, default=None): + url = f"{self._url}/pop" + data = {"key": key, "db_name": self._db_name} + response = self._client.post(url, json=data) + result = response.json() + if result["success"]: + value = result["value"] + if value is None: + return default + return value + else: + raise + + def _items_dict(self): + url = f"{self._url}/items?db_name={self._db_name}" + response = self._client.get(url) + return response.json() + + def items(self): + return self._items_dict().items() + + def __repr__(self): + return str(self._items_dict()) def keys(self): url = f"{self._url}/keys?db_name={self._db_name}" response = self._client.get(url) return response.json()["keys"] + def values(self): + url = f"{self._url}/values?db_name={self._db_name}" + response = self._client.get(url) + return response.json()["values"] + def __contains__(self, key): url = f"{self._url}/contains?db_name={self._db_name}" response = self._client.post(url, data=encode(key)) diff --git a/flaxkv/serve/interface.py b/flaxkv/serve/interface.py index a2a9cb9..aa66e14 100644 --- a/flaxkv/serve/interface.py +++ b/flaxkv/serve/interface.py @@ -18,6 +18,23 @@ class DetachRequest: db_name: str +@dataclass +class SetRequest: + db_name: str + key: Any + value: Any + + +@dataclass +class PopKeyRequest: + db_name: str + key: Any + + class StructSetData(msgspec.Struct): key: bytes value: bytes + + +class StructUpdateData(msgspec.Struct): + dict: bytes diff --git a/flaxkv/serve/manager.py b/flaxkv/serve/manager.py index c9a33dc..59a5ce2 100644 --- a/flaxkv/serve/manager.py +++ b/flaxkv/serve/manager.py @@ -4,8 +4,9 @@ class DBManager: - def __init__(self, root_path="./FLAXKV_DB"): + def __init__(self, root_path="./FLAXKV_DB", raw_mode=True): self._db_dict = {} + self._raw_mode = raw_mode self._root_path = Path(root_path) if not self._root_path.exists(): self._root_path.mkdir(parents=True) @@ -19,7 +20,7 @@ def set_db(self, db_name: str, backend: str, rebuild: bool): path_or_url=db_path.__str__(), backend=backend, rebuild=rebuild, - raw=True, + raw=self._raw_mode, log=True, )