Skip to content

Commit

Permalink
Merge pull request #6 from KenyonY/serve/add_methods
Browse files Browse the repository at this point in the history
✨ Improve service architecture
  • Loading branch information
KenyonY authored Dec 5, 2023
2 parents 74cde66 + c32487e commit 6d250bb
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 81 deletions.
2 changes: 1 addition & 1 deletion flaxkv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import LevelDBDict, LMDBDict
from .core import LevelDBDict, LMDBDict
from .serve.client import RemoteDictDB

__version__ = "0.1.6"
Expand Down
158 changes: 88 additions & 70 deletions flaxkv/base.py → flaxkv/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ def update(self, d: dict):
raise ValueError("Input must be a dictionary.")
with self._buffer_lock:
for key, value in d.items():
if self.raw:
key, value = encode(key), encode(value)
self.buffer_dict[key] = value
self.delete_buffer_set.discard(key)

Expand Down Expand Up @@ -425,7 +427,10 @@ def pop(self, key, default=None):
self._buffered_count += 1
if key in self.buffer_dict:
value = self.buffer_dict.pop(key)
return value
if self.raw:
return decode(value)
else:
return value
else:
if self.raw:
value = self._static_view.get(key)
Expand Down Expand Up @@ -514,26 +519,56 @@ def close(self, write=True):
self._db_manager.close()
self._logger.info(f"Closed ({self._db_manager.db_type.upper()}) successfully")

@abstractmethod
def keys(self):
def _get_state_buffer_info(
self, return_key=False, return_value=False, return_dict=False, decode_raw=True
):
with self._buffer_lock:
static_view = self._db_manager.new_static_view()
buffer_dict = self.buffer_dict.copy()
delete_buffer_set = self.delete_buffer_set.copy()

buffer_keys, buffer_values = None, None
if return_key:
if self.raw and decode_raw:
buffer_keys = set([decode_key(i) for i in buffer_dict.keys()])
else:
buffer_keys = set(buffer_dict.keys())
if return_value:
if self.raw and decode_raw:
buffer_values = list([decode(i) for i in buffer_dict.values()])
else:
buffer_values = list(self.buffer_dict.values())
if not return_dict:
buffer_dict = None
else:
if self.raw and decode_raw:
buffer_dict = {decode_key(k): decode(v) for k, v in buffer_dict.items()}

return buffer_dict, buffer_keys, buffer_values, delete_buffer_set, static_view

def values(self, decode_raw=True):
"""
Retrieves all the keys in the database and buffer.
Retrieves all the values in the database and buffer.
Returns:
list: A list of keys.
list: A list of values
"""
values_list = []
for key, value in self.items(decode_raw):
values_list.append(value)
return values_list

@abstractmethod
def values(self):
def keys(self, *args, **kwargs):
"""
Retrieves all the values in the database and buffer.
Retrieves all the keys in the database and buffer.
Returns:
list: A list of values
list: A list of keys.
"""

@abstractmethod
def items(self):
def items(self, *args, **kwargs):
"""
Retrieves all the key-value pairs in the database and buffer.
Expand All @@ -542,7 +577,7 @@ def items(self):
"""

@abstractmethod
def stat(self):
def stat(self, *args, **kwargs):
"""
Database statistics
Expand All @@ -564,12 +599,15 @@ def __init__(self, path: str, map_size=1024**3, rebuild=False, **kwargs):
"lmdb", path, max_dbs=1, map_size=map_size, rebuild=rebuild, **kwargs
)

def keys(self):
with self._buffer_lock:
session = self._db_manager.new_static_view()
cursor = session.cursor()
delete_buffer_set = self.delete_buffer_set.copy()
buffer_keys = set(self.buffer_dict.keys())
def keys(self, decode_raw=True):
(
buffer_dict,
buffer_keys,
buffer_values,
delete_buffer_set,
session,
) = self._get_state_buffer_info(return_key=True, decode_raw=decode_raw)
cursor = session.cursor()

lmdb_keys = set(
decode_key(key) for key in cursor.iternext(keys=True, values=False)
Expand All @@ -578,40 +616,27 @@ def keys(self):

return list(lmdb_keys.union(buffer_keys) - delete_buffer_set)

def values(self):

with self._buffer_lock:
session = self._db_manager.new_static_view()
cursor = session.cursor()
delete_buffer_set = self.delete_buffer_set.copy()
buffer_values = list(self.buffer_dict.values())

lmdb_values = []
for key, value in cursor.iternext(keys=True, values=True):
dk = decode_key(key)
if dk not in delete_buffer_set:
lmdb_values.append(decode(value))

session.abort()
return lmdb_values + buffer_values

def items(self):
with self._buffer_lock:
session = self._db_manager.new_static_view()
cursor = session.cursor()
buffer_dict = self.buffer_dict.copy()
delete_buffer_set = self.delete_buffer_set.copy()

def items(self, decode_raw=True):
(
buffer_dict,
buffer_keys,
buffer_values,
delete_buffer_set,
session,
) = self._get_state_buffer_info(return_dict=True, decode_raw=decode_raw)
cursor = session.cursor()
_db_dict = {}

for key, value in cursor.iternext(keys=True, values=True):
dk = decode_key(key)
dk = key if self.raw else decode_key(key)
if dk not in delete_buffer_set:
if self.raw:
dk = decode_key(dk)
_db_dict[dk] = decode(value)

_db_dict.update(buffer_dict)

session.abort()
self._db_manager.close_static_view(session)

return _db_dict.items()

Expand Down Expand Up @@ -649,47 +674,40 @@ class LevelDBDict(BaseDBDict):
def __init__(self, path: str, rebuild=False, **kwargs):
super().__init__("leveldb", path=path, rebuild=rebuild)

def keys(self):
with self._buffer_lock:
buffer_keys = set(self.buffer_dict.keys())
snapshot = self._db_manager.new_static_view()
def keys(self, decode_raw=True):
(
buffer_dict,
buffer_keys,
buffer_values,
delete_buffer_set,
snapshot,
) = self._get_state_buffer_info(return_key=True, decode_raw=decode_raw)

db_keys = set(decode_key(key) for key, _ in snapshot.iterator())
snapshot.close()

return list(db_keys.union(buffer_keys))
return list(db_keys.union(buffer_keys) - delete_buffer_set)

def values(self):
with self._buffer_lock:
snapshot = self._db_manager.new_static_view()
delete_buffer_set = self.delete_buffer_set.copy()
buffer_values = list(self.buffer_dict.values())

db_values = []
for key, value in snapshot.iterator():
dk = decode_key(key)
if dk not in delete_buffer_set:
db_values.append(decode(value))

snapshot.close()

return db_values + buffer_values

def items(self):
with self._buffer_lock:
snapshot = self._db_manager.new_static_view()
delete_buffer_set = self.delete_buffer_set.copy()
buffer_dict = self.buffer_dict.copy()
def items(self, decode_raw=True):
(
buffer_dict,
buffer_keys,
buffer_values,
delete_buffer_set,
snapshot,
) = self._get_state_buffer_info(return_dict=True, decode_raw=decode_raw)

_db_dict = {}
for key, value in snapshot.iterator():
dk = decode_key(key)
dk = key if self.raw else decode_key(key)
if dk not in delete_buffer_set:
if self.raw:
dk = decode_key(dk)
_db_dict[dk] = decode(value)

_db_dict.update(buffer_dict)

snapshot.close()
self._db_manager.close_static_view(snapshot)
return _db_dict.items()

def stat(self):
Expand Down
96 changes: 89 additions & 7 deletions flaxkv/serve/app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import traceback

import msgspec
from litestar import Litestar, MediaType, Request, get, post

from ..pack import encode
from .interface import AttachRequest, DetachRequest, StructSetData
from ..pack import decode, decode_key, encode
from .interface import (
AttachRequest,
DetachRequest,
PopKeyRequest,
SetRequest,
StructSetData,
StructUpdateData,
)
from .manager import DBManager

db_manager = DBManager(root_path="./FLAXKV_DB")
db_manager = DBManager(root_path="./FLAXKV_DB", raw_mode=True)


@post(path="/attach")
Expand All @@ -26,19 +35,46 @@ async def detach(data: DetachRequest) -> dict:
return {"success": True}


@post(path="/set_value")
async def set_value(data: SetRequest) -> dict:
db = db_manager.get(data.db_name)
if db is None:
return {"success": False, "info": "db not found"}
print(data.key, data.value)
print(encode(data.key), encode(data.value))
db[encode(data.key)] = encode(data.value)
return {"success": True}


@post(path="/set_raw")
async def set_raw(db_name: str, request: Request) -> dict:
print(f"{db_name=}")
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
data = await request.body()
data = msgspec.msgpack.decode(data, type=StructSetData)
try:
data = msgspec.msgpack.decode(data, type=StructSetData)
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}
db[data.key] = data.value
db.write_immediately(write=True, wait=True)
return {"success": True}


@post(path="/update_raw")
async def update_raw(db_name: str, request: Request) -> dict:
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
data = await request.body()
try:
db.update(decode(data))
return {"success": True}
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


@post("/get_raw", media_type=MediaType.TEXT)
async def get_raw(db_name: str, request: Request) -> bytes:
db = db_manager.get(db_name)
Expand All @@ -61,12 +97,53 @@ async def contains(db_name: str, request: Request) -> bytes:
return encode(is_contain)


@post("/pop")
async def pop(data: PopKeyRequest) -> dict:
db = db_manager.get(data.db_name)
if db is None:
return {"success": False, "info": "db not found"}
try:
return {"success": True, "value": db.pop(encode(data.key), None)}

except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


@get("/keys")
async def get_keys(db_name: str) -> dict:
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
return {"keys": list(db.keys())}
try:
return {"keys": db.keys()}
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


@get("/values")
async def get_values(db_name: str) -> dict:
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
try:
return {"values": db.values()}
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


@get("/items")
async def get_items(db_name: str) -> dict:
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
try:
return dict(db.items())
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


def on_shutdown():
Expand All @@ -78,8 +155,13 @@ def on_shutdown():
attach,
detach,
set_raw,
update_raw,
get_raw,
get_items,
get_values,
set_value,
contains,
pop,
get_keys,
],
on_startup=[lambda: print("on_startup")],
Expand Down
Loading

0 comments on commit 6d250bb

Please sign in to comment.