Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Improve service architecture #6

Merged
merged 1 commit into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flaxkv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .base import LevelDBDict, LMDBDict
from .core import LevelDBDict, LMDBDict
from .serve.client import RemoteDictDB

__version__ = "0.1.6"
Expand Down
158 changes: 88 additions & 70 deletions flaxkv/base.py → flaxkv/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,8 @@ def update(self, d: dict):
raise ValueError("Input must be a dictionary.")
with self._buffer_lock:
for key, value in d.items():
if self.raw:
key, value = encode(key), encode(value)
self.buffer_dict[key] = value
self.delete_buffer_set.discard(key)

Expand Down Expand Up @@ -425,7 +427,10 @@ def pop(self, key, default=None):
self._buffered_count += 1
if key in self.buffer_dict:
value = self.buffer_dict.pop(key)
return value
if self.raw:
return decode(value)
else:
return value
else:
if self.raw:
value = self._static_view.get(key)
Expand Down Expand Up @@ -514,26 +519,56 @@ def close(self, write=True):
self._db_manager.close()
self._logger.info(f"Closed ({self._db_manager.db_type.upper()}) successfully")

@abstractmethod
def keys(self):
def _get_state_buffer_info(
self, return_key=False, return_value=False, return_dict=False, decode_raw=True
):
with self._buffer_lock:
static_view = self._db_manager.new_static_view()
buffer_dict = self.buffer_dict.copy()
delete_buffer_set = self.delete_buffer_set.copy()

buffer_keys, buffer_values = None, None
if return_key:
if self.raw and decode_raw:
buffer_keys = set([decode_key(i) for i in buffer_dict.keys()])
else:
buffer_keys = set(buffer_dict.keys())
if return_value:
if self.raw and decode_raw:
buffer_values = list([decode(i) for i in buffer_dict.values()])
else:
buffer_values = list(self.buffer_dict.values())
if not return_dict:
buffer_dict = None
else:
if self.raw and decode_raw:
buffer_dict = {decode_key(k): decode(v) for k, v in buffer_dict.items()}

return buffer_dict, buffer_keys, buffer_values, delete_buffer_set, static_view

def values(self, decode_raw=True):
"""
Retrieves all the keys in the database and buffer.
Retrieves all the values in the database and buffer.

Returns:
list: A list of keys.
list: A list of values
"""
values_list = []
for key, value in self.items(decode_raw):
values_list.append(value)
return values_list

@abstractmethod
def values(self):
def keys(self, *args, **kwargs):
"""
Retrieves all the values in the database and buffer.
Retrieves all the keys in the database and buffer.

Returns:
list: A list of values
list: A list of keys.
"""

@abstractmethod
def items(self):
def items(self, *args, **kwargs):
"""
Retrieves all the key-value pairs in the database and buffer.

Expand All @@ -542,7 +577,7 @@ def items(self):
"""

@abstractmethod
def stat(self):
def stat(self, *args, **kwargs):
"""
Database statistics

Expand All @@ -564,12 +599,15 @@ def __init__(self, path: str, map_size=1024**3, rebuild=False, **kwargs):
"lmdb", path, max_dbs=1, map_size=map_size, rebuild=rebuild, **kwargs
)

def keys(self):
with self._buffer_lock:
session = self._db_manager.new_static_view()
cursor = session.cursor()
delete_buffer_set = self.delete_buffer_set.copy()
buffer_keys = set(self.buffer_dict.keys())
def keys(self, decode_raw=True):
(
buffer_dict,
buffer_keys,
buffer_values,
delete_buffer_set,
session,
) = self._get_state_buffer_info(return_key=True, decode_raw=decode_raw)
cursor = session.cursor()

lmdb_keys = set(
decode_key(key) for key in cursor.iternext(keys=True, values=False)
Expand All @@ -578,40 +616,27 @@ def keys(self):

return list(lmdb_keys.union(buffer_keys) - delete_buffer_set)

def values(self):

with self._buffer_lock:
session = self._db_manager.new_static_view()
cursor = session.cursor()
delete_buffer_set = self.delete_buffer_set.copy()
buffer_values = list(self.buffer_dict.values())

lmdb_values = []
for key, value in cursor.iternext(keys=True, values=True):
dk = decode_key(key)
if dk not in delete_buffer_set:
lmdb_values.append(decode(value))

session.abort()
return lmdb_values + buffer_values

def items(self):
with self._buffer_lock:
session = self._db_manager.new_static_view()
cursor = session.cursor()
buffer_dict = self.buffer_dict.copy()
delete_buffer_set = self.delete_buffer_set.copy()

def items(self, decode_raw=True):
(
buffer_dict,
buffer_keys,
buffer_values,
delete_buffer_set,
session,
) = self._get_state_buffer_info(return_dict=True, decode_raw=decode_raw)
cursor = session.cursor()
_db_dict = {}

for key, value in cursor.iternext(keys=True, values=True):
dk = decode_key(key)
dk = key if self.raw else decode_key(key)
if dk not in delete_buffer_set:
if self.raw:
dk = decode_key(dk)
_db_dict[dk] = decode(value)

_db_dict.update(buffer_dict)

session.abort()
self._db_manager.close_static_view(session)

return _db_dict.items()

Expand Down Expand Up @@ -649,47 +674,40 @@ class LevelDBDict(BaseDBDict):
def __init__(self, path: str, rebuild=False, **kwargs):
super().__init__("leveldb", path=path, rebuild=rebuild)

def keys(self):
with self._buffer_lock:
buffer_keys = set(self.buffer_dict.keys())
snapshot = self._db_manager.new_static_view()
def keys(self, decode_raw=True):
(
buffer_dict,
buffer_keys,
buffer_values,
delete_buffer_set,
snapshot,
) = self._get_state_buffer_info(return_key=True, decode_raw=decode_raw)

db_keys = set(decode_key(key) for key, _ in snapshot.iterator())
snapshot.close()

return list(db_keys.union(buffer_keys))
return list(db_keys.union(buffer_keys) - delete_buffer_set)

def values(self):
with self._buffer_lock:
snapshot = self._db_manager.new_static_view()
delete_buffer_set = self.delete_buffer_set.copy()
buffer_values = list(self.buffer_dict.values())

db_values = []
for key, value in snapshot.iterator():
dk = decode_key(key)
if dk not in delete_buffer_set:
db_values.append(decode(value))

snapshot.close()

return db_values + buffer_values

def items(self):
with self._buffer_lock:
snapshot = self._db_manager.new_static_view()
delete_buffer_set = self.delete_buffer_set.copy()
buffer_dict = self.buffer_dict.copy()
def items(self, decode_raw=True):
(
buffer_dict,
buffer_keys,
buffer_values,
delete_buffer_set,
snapshot,
) = self._get_state_buffer_info(return_dict=True, decode_raw=decode_raw)

_db_dict = {}
for key, value in snapshot.iterator():
dk = decode_key(key)
dk = key if self.raw else decode_key(key)
if dk not in delete_buffer_set:
if self.raw:
dk = decode_key(dk)
_db_dict[dk] = decode(value)

_db_dict.update(buffer_dict)

snapshot.close()
self._db_manager.close_static_view(snapshot)
return _db_dict.items()

def stat(self):
Expand Down
96 changes: 89 additions & 7 deletions flaxkv/serve/app.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
import traceback

import msgspec
from litestar import Litestar, MediaType, Request, get, post

from ..pack import encode
from .interface import AttachRequest, DetachRequest, StructSetData
from ..pack import decode, decode_key, encode
from .interface import (
AttachRequest,
DetachRequest,
PopKeyRequest,
SetRequest,
StructSetData,
StructUpdateData,
)
from .manager import DBManager

db_manager = DBManager(root_path="./FLAXKV_DB")
db_manager = DBManager(root_path="./FLAXKV_DB", raw_mode=True)


@post(path="/attach")
Expand All @@ -26,19 +35,46 @@ async def detach(data: DetachRequest) -> dict:
return {"success": True}


@post(path="/set_value")
async def set_value(data: SetRequest) -> dict:
db = db_manager.get(data.db_name)
if db is None:
return {"success": False, "info": "db not found"}
print(data.key, data.value)
print(encode(data.key), encode(data.value))
db[encode(data.key)] = encode(data.value)
return {"success": True}


@post(path="/set_raw")
async def set_raw(db_name: str, request: Request) -> dict:
print(f"{db_name=}")
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
data = await request.body()
data = msgspec.msgpack.decode(data, type=StructSetData)
try:
data = msgspec.msgpack.decode(data, type=StructSetData)
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}
db[data.key] = data.value
db.write_immediately(write=True, wait=True)
return {"success": True}


@post(path="/update_raw")
async def update_raw(db_name: str, request: Request) -> dict:
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
data = await request.body()
try:
db.update(decode(data))
return {"success": True}
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


@post("/get_raw", media_type=MediaType.TEXT)
async def get_raw(db_name: str, request: Request) -> bytes:
db = db_manager.get(db_name)
Expand All @@ -61,12 +97,53 @@ async def contains(db_name: str, request: Request) -> bytes:
return encode(is_contain)


@post("/pop")
async def pop(data: PopKeyRequest) -> dict:
db = db_manager.get(data.db_name)
if db is None:
return {"success": False, "info": "db not found"}
try:
return {"success": True, "value": db.pop(encode(data.key), None)}

except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


@get("/keys")
async def get_keys(db_name: str) -> dict:
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
return {"keys": list(db.keys())}
try:
return {"keys": db.keys()}
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


@get("/values")
async def get_values(db_name: str) -> dict:
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
try:
return {"values": db.values()}
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


@get("/items")
async def get_items(db_name: str) -> dict:
db = db_manager.get(db_name)
if db is None:
return {"success": False, "info": "db not found"}
try:
return dict(db.items())
except Exception as e:
traceback.print_exc()
return {"success": False, "info": str(e)}


def on_shutdown():
Expand All @@ -78,8 +155,13 @@ def on_shutdown():
attach,
detach,
set_raw,
update_raw,
get_raw,
get_items,
get_values,
set_value,
contains,
pop,
get_keys,
],
on_startup=[lambda: print("on_startup")],
Expand Down
Loading
Loading