Skip to content

Commit

Permalink
feat: Use streaming when retrieving from remote db
Browse files Browse the repository at this point in the history
  • Loading branch information
KenyonY committed Dec 27, 2023
1 parent 02d8a5b commit fd002a7
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 9 deletions.
2 changes: 1 addition & 1 deletion flaxkv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from .core import LevelDBDict, LMDBDict, RemoteDBDict

__version__ = "0.2.2"
__version__ = "0.2.3"

__all__ = [
"FlaxKV",
Expand Down
14 changes: 8 additions & 6 deletions flaxkv/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -982,12 +982,14 @@ def db_dict(self, decode_raw=True):
_db_dict = self._cache_dict.copy()
else:
_db_dict = {}
response: Response = view.client.get(f"/dict?db_name={self._db_name}")
if not response.is_success:
raise ValueError(
f"Failed to get items from remote db: {decode(response.read())}"
)
remote_db_dict = decode(response.read())
with view.client.stream(
"GET", f"/dict_stream?db_name={self._db_name}"
) as r:
data_stream = b""
for data in r.iter_bytes():
data_stream += data

remote_db_dict = decode(data_stream)
for dk, dv in remote_db_dict.items():
if dk not in delete_buffer_set:
_db_dict[dk] = dv
Expand Down
24 changes: 22 additions & 2 deletions flaxkv/serve/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@

from __future__ import annotations

import io
import traceback
from typing import AsyncGenerator

import msgspec
from litestar import Litestar, MediaType, Request, get, post, status_codes
from litestar.exceptions import HTTPException
from litestar.openapi import OpenAPIConfig
from litestar.response import Stream

from .. import __version__
from ..pack import encode
Expand Down Expand Up @@ -161,7 +164,7 @@ async def _keys(db_name: str) -> bytes:


@get("/dict", media_type=MediaType.TEXT)
async def _items(db_name: str) -> bytes:
async def _dict(db_name: str) -> bytes:
db = get_db(db_name)
try:
return encode(db.db_dict())
Expand All @@ -170,6 +173,22 @@ async def _items(db_name: str) -> bytes:
raise HTTPException(status_code=500, detail=str(e))


@get("/dict_stream", media_type=MediaType.TEXT)
async def _dict_stream(db_name: str) -> Stream:
    """Return the encoded contents of a database as a chunked stream.

    The database is looked up with ``get_db(db_name)``, its ``db_dict()``
    result is serialized with ``encode`` and the resulting bytes are sent
    back through a ``Stream`` response in fixed-size chunks.

    Note that the whole payload is encoded in memory first; only the
    delivery to the client is chunked.

    Raises:
        HTTPException: with status code 500 and the original error message
            as ``detail`` when encoding or building the response fails.
            The traceback is printed before re-raising.
    """

    async def _iter_chunks(payload: bytes, chunk_size=4096) -> AsyncGenerator[bytes, None]:
        # Read successive slices from an in-memory buffer; the empty-bytes
        # sentinel marks the end of the payload.
        with io.BytesIO(payload) as buffer:
            for piece in iter(lambda: buffer.read(chunk_size), b""):
                yield piece

    db = get_db(db_name)
    try:
        serialized = encode(db.db_dict())
        stream_response = Stream(_iter_chunks(serialized))
        return stream_response
    except Exception as e:
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


@get("/stat", media_type=MediaType.TEXT)
async def _stat(db_name: str) -> bytes:
db = get_db(db_name)
Expand Down Expand Up @@ -199,7 +218,8 @@ def on_shutdown():
_delete,
_delete_batch,
_keys,
_items,
_dict,
_dict_stream,
_stat,
],
on_startup=[on_startup],
Expand Down

0 comments on commit fd002a7

Please sign in to comment.