Skip to content

Commit

Permalink
Implement YStore versioning
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Nov 21, 2022
1 parent ebf67f3 commit 8ea0486
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 21 deletions.
21 changes: 18 additions & 3 deletions tests/test_ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from ypy_websocket.ystore import SQLiteYStore, TempFileYStore
from ypy_websocket.ystore import SQLiteYStore, TempFileYStore, YDocVersionMismatch


class MetadataCallback:
Expand All @@ -22,13 +22,16 @@ class MyTempFileYStore(TempFileYStore):
prefix_dir = "test_temp_"


MY_SQLITE_YSTORE_DB_PATH = str(Path(tempfile.mkdtemp(prefix="test_sql_")) / "ystore.db")


class MySQLiteYStore(SQLiteYStore):
db_path = str(Path(tempfile.mkdtemp(prefix="test_sql_")) / "ystore.db")
db_path = MY_SQLITE_YSTORE_DB_PATH


@pytest.mark.asyncio
@pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore))
async def test_file_ystore(YStore):
async def test_ystore(YStore):
store_name = "my_store"
ystore = YStore(store_name, metadata_callback=MetadataCallback())
data = [b"foo", b"bar", b"baz"]
Expand All @@ -44,3 +47,15 @@ async def test_file_ystore(YStore):
assert d == data[i] # data
assert m == bytes(i) # metadata
i += 1


@pytest.mark.asyncio
@pytest.mark.parametrize("YStore", (MyTempFileYStore, MySQLiteYStore))
async def test_version(YStore):
store_name = "my_store"
prev_version = YStore.version
YStore.version = -1
ystore = YStore(store_name, metadata_callback=MetadataCallback())
with pytest.raises(YDocVersionMismatch):
await ystore.write(b"foo")
YStore.version = prev_version
71 changes: 53 additions & 18 deletions ypy_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,25 @@
from typing import AsyncIterator, Callable, Optional, Tuple

import aiofiles # type: ignore
import aiofiles.os # type: ignore
import aiosqlite # type: ignore
import y_py as Y

from .yutils import Decoder, write_var_uint


class YDocVersionMismatch(Exception):
pass


class YDocNotFound(Exception):
pass


class BaseYStore(ABC):

metadata_callback: Optional[Callable] = None
version = 1

@abstractmethod
def __init__(self, path: str, metadata_callback=None):
Expand Down Expand Up @@ -57,14 +63,28 @@ def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
self.metadata_callback = metadata_callback
self.lock = asyncio.Lock()

async def check_version(self, read_file: bool = False) -> Optional[bytes]:
async with aiofiles.open(self.path, "rb") as f:
header = await f.read(8)
if header == b"VERSION:":
version = int(await f.readline())
if version != self.version:
raise YDocVersionMismatch(
f"YStore version mismatch: got '{version}' in file, but supported is '{self.version}'"
)
else:
raise YDocVersionMismatch("YStore has no version")
if read_file:
return await f.read()
return None

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
async with self.lock:
try:
async with aiofiles.open(self.path, "rb") as f:
data = await f.read()
except BaseException:
if not await aiofiles.os.path.exists(self.path):
raise YDocNotFound
data = await self.check_version(read_file=True)
is_data = True
assert data is not None
for d in Decoder(data).read_messages():
if is_data:
update = d
Expand All @@ -75,13 +95,17 @@ async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore

async def write(self, data: bytes) -> None:
parent = Path(self.path).parent
if not parent.exists():
parent.mkdir(parents=True)
mode = "wb"
else:
await aiofiles.os.makedirs(parent, exist_ok=True)
if await aiofiles.os.path.exists(self.path):
await self.check_version()
mode = "ab"
else:
mode = "wb"
async with self.lock:
async with aiofiles.open(self.path, mode) as f:
if mode == "wb":
version = f"VERSION:{self.version}\n".encode()
await f.write(version)
data_len = write_var_uint(len(data))
await f.write(data_len + data)
metadata = await self.get_metadata()
Expand Down Expand Up @@ -127,24 +151,35 @@ class MySQLiteYStore(SQLiteYStore):

db_path: str = "ystore.db"
path: str
db_created: asyncio.Event
db_initialized: asyncio.Task

def __init__(self, path: str, metadata_callback: Optional[Callable] = None):
self.path = path
self.metadata_callback = metadata_callback
self.db_created = asyncio.Event()
asyncio.create_task(self.create_db())
self.db_initialized = asyncio.create_task(self.init_db())

async def create_db(self):
async def init_db(self):
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
"CREATE TABLE IF NOT EXISTS yupdates (path TEXT, yupdate BLOB, metadata BLOB, timestamp TEXT)"
cursor = await db.execute(
"SELECT count(name) FROM sqlite_master WHERE type='table' and name='yupdates'"
)
await db.commit()
self.db_created.set()
table_exists = (await cursor.fetchone())[0]
if table_exists:
cursor = await db.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
raise YDocVersionMismatch(
f"YStore version mismatch: got '{version}' in DB, but supported is '{self.version}'"
)
else:
await db.execute(
"CREATE TABLE yupdates (path TEXT, yupdate BLOB, metadata BLOB, timestamp TEXT)"
)
await db.execute(f"PRAGMA user_version = {self.version}")
await db.commit()

async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
await self.db_created.wait()
await self.db_initialized
try:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
Expand All @@ -160,7 +195,7 @@ async def read(self) -> AsyncIterator[Tuple[bytes, bytes]]: # type: ignore
raise YDocNotFound

async def write(self, data: bytes) -> None:
await self.db_created.wait()
await self.db_initialized
metadata = await self.get_metadata()
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
Expand Down

0 comments on commit 8ea0486

Please sign in to comment.