diff --git a/docs/usage.md b/docs/usage.md index c732926..0d9be81 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -156,13 +156,13 @@ Unregistering the callback is done with the same `unobserve` method. ### Document events Observing changes made to a document is mostly meant to send the changes to another document, usually over the wire to a remote machine. -Changes can be serialized to binary by calling `get_update()` on the event: +Changes can be serialized to binary by getting the event's `update`: ```py from pycrdt import TransactionEvent def handle_doc_changes(event: TransactionEvent): - update: bytes = event.get_update() + update: bytes = event.update # send binary update on the wire doc.observe(handle_doc_changes) diff --git a/pyproject.toml b/pyproject.toml index b66a612..d299ea1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", @@ -30,9 +31,10 @@ classifiers = [ [project.optional-dependencies] test = [ - "pytest >=7.4.2,<8", - "y-py >=0.7.0a1,<0.8", - "mypy", + "pytest >=7.4.2,<8", + "y-py >=0.7.0a1,<0.8", + "pydantic >=2.5.2,<3", + "mypy", ] docs = [ "mkdocs", "mkdocs-material" ] diff --git a/python/pycrdt/array.py b/python/pycrdt/array.py index 90fa99d..cd21361 100644 --- a/python/pycrdt/array.py +++ b/python/pycrdt/array.py @@ -175,6 +175,11 @@ def __str__(self) -> str: with self.doc.transaction() as txn: return self.integrated.to_json(txn._txn) + def to_py(self) -> list | None: + if self._integrated is None: + return self._prelim + return list(self) + def observe(self, callback: Callable[[Any], None]) -> str: _callback = partial(observe_callback, callback, self.doc) return f"o_{self.integrated.observe(_callback)}" @@ -193,7 +198,9 @@ def unobserve(self, subscription_id: str) -> None: def observe_callback(callback: Callable[[Any], None], doc: Doc, event: Any): _event = event_types[type(event)](event, doc) + doc._txn = ReadTransaction(doc=doc, _txn=event.transaction) callback(_event) + doc._txn = None def observe_deep_callback(callback: Callable[[Any], None], doc: Doc, events: list[Any]): diff --git a/python/pycrdt/base.py b/python/pycrdt/base.py index 0ee7f0f..9d95314 100644 --- a/python/pycrdt/base.py +++ b/python/pycrdt/base.py @@ -5,7 +5,7 @@ from ._pycrdt import Doc as _Doc from ._pycrdt import Transaction as _Transaction -from .transaction import ReadTransaction, Transaction +from .transaction import Transaction if TYPE_CHECKING: from .doc import Doc @@ -17,18 +17,26 @@ class BaseDoc: _doc: _Doc + _twin_doc: BaseDoc | None _txn: Transaction | None + _Model: Any + _dict: dict[str, BaseType] def __init__( self, *, client_id: int | None = None, doc: _Doc | None = None, + Model=None, + **data, ) -> None: + super().__init__(**data) if doc is None: doc = _Doc(client_id) self._doc = doc self._txn = None + self._Model = Model + self._dict = {} class BaseType(ABC): @@ -56,6 +64,10 @@ def __init__( self._prelim = init self._integrated = None + @abstractmethod + def to_py(self) -> Any: + ... + @abstractmethod def _get_or_insert(self, name: str, doc: Doc) -> Any: ... @@ -136,7 +148,7 @@ class BaseEvent: def __init__(self, event: Any, doc: Doc): slot: str for slot in self.__slots__: - processed = process_event(getattr(event, slot), doc, event.transaction) + processed = process_event(getattr(event, slot), doc) setattr(self, slot, processed) def __str__(self): @@ -148,13 +160,13 @@ def __str__(self): return "{" + ret + "}" -def process_event(value: Any, doc: Doc, txn) -> Any: +def process_event(value: Any, doc: Doc) -> Any: if isinstance(value, list): for idx, val in enumerate(value): - value[idx] = process_event(val, doc, txn) + value[idx] = process_event(val, doc) elif isinstance(value, dict): for key, val in value.items(): - value[key] = process_event(val, doc, txn) + value[key] = process_event(val, doc) else: val_type = type(value) if val_type in base_types: @@ -164,5 +176,4 @@ def process_event(value: Any, doc: Doc, txn) -> Any: else: base_type = cast(Type[BaseType], base_types[val_type]) value = base_type(_integrated=value, _doc=doc) - doc._txn = ReadTransaction(doc=doc, _txn=txn) return value diff --git a/python/pycrdt/doc.py b/python/pycrdt/doc.py index 07cb4dc..f5f2ad0 100644 --- a/python/pycrdt/doc.py +++ b/python/pycrdt/doc.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, cast from ._pycrdt import Doc as _Doc from ._pycrdt import SubdocsEvent, TransactionEvent @@ -9,16 +9,20 @@ class Doc(BaseDoc): + def __init__( self, init: dict[str, BaseType] = {}, *, client_id: int | None = None, doc: _Doc | None = None, + Model=None, ) -> None: - super().__init__(client_id=client_id, doc=doc) + super().__init__(client_id=client_id, doc=doc, Model=Model) for k, v in init.items(): self[k] = v + if Model is not None: + self._twin_doc = Doc(init) @property def guid(self) -> int: @@ -42,6 +46,15 @@ def get_update(self, state: bytes | None = None) -> bytes: return self._doc.get_update(state) def apply_update(self, update: bytes) -> None: + if self._Model is not None: + twin_doc = cast(Doc, self._twin_doc) + twin_doc.apply_update(update) + d = {k: twin_doc[k].to_py() for k in self._Model.model_fields} + try: + self._Model(**d) + except Exception as e: + self._twin_doc = Doc(self._dict) + raise e self._doc.apply_update(update) def __setitem__(self, key: str, value: BaseType) -> None: @@ -50,6 +63,10 @@ def __setitem__(self, key: str, value: BaseType) -> None: integrated = value._get_or_insert(key, self) prelim = value._integrate(self, integrated) value._init(prelim) + self._dict[key] = value + + def __getitem__(self, key: str) -> BaseType: + return self._dict[key] def observe(self, callback: Callable[[TransactionEvent], None]) -> int: return self._doc.observe(callback) diff --git a/python/pycrdt/map.py b/python/pycrdt/map.py index 7d2ebc6..f1a90cd 100644 --- a/python/pycrdt/map.py +++ b/python/pycrdt/map.py @@ -64,6 +64,11 @@ def __str__(self) -> str: with self.doc.transaction() as txn: return self.integrated.to_json(txn._txn) + def to_py(self) -> dict | None: + if self._integrated is None: + return self._prelim + return dict(self) + def __delitem__(self, key: str) -> None: if not isinstance(key, str): raise RuntimeError("Key must be of type string") @@ -147,7 +152,9 @@ def unobserve(self, subscription_id: str) -> None: def observe_callback(callback: Callable[[Any], None], doc: Doc, event: Any): _event = event_types[type(event)](event, doc) + doc._txn = ReadTransaction(doc=doc, _txn=event.transaction) callback(_event) + doc._txn = None def observe_deep_callback(callback: Callable[[Any], None], doc: Doc, events: list[Any]): diff --git a/python/pycrdt/text.py b/python/pycrdt/text.py index efac7a2..9b3f3e6 100644 --- a/python/pycrdt/text.py +++ b/python/pycrdt/text.py @@ -46,6 +46,11 @@ def __str__(self) -> str: with self.doc.transaction() as txn: return self.integrated.get_string(txn._txn) + def to_py(self) -> str | None: + if self._integrated is None: + return self._prelim + return str(self) + def __iadd__(self, value: str) -> Text: with self.doc.transaction() as txn: if isinstance(txn, ReadTransaction): @@ -89,7 +94,13 @@ def __setitem__(self, key: int | slice, value: str) -> None: "Read-only transaction cannot be used to modify document structure" ) if isinstance(key, int): - raise RuntimeError("Single item assignment not supported") + value_len = len(value) + if value_len != 1: + raise RuntimeError( + f"Single item assigned value must have a length of 1, not {value_len}" + ) + del self[key] + self.integrated.insert(txn._txn, key, value) elif isinstance(key, slice): if key.step is not None: raise RuntimeError("Step not supported") @@ -118,7 +129,9 @@ def unobserve(self, subscription_id: str) -> None: def observe_callback(callback: Callable[[Any], None], doc: Doc, event: Any): _event = event_types[type(event)](event, doc) + doc._txn = ReadTransaction(doc=doc, _txn=event.transaction) callback(_event) + doc._txn = None class TextEvent(BaseEvent): diff --git a/python/pycrdt/transaction.py b/python/pycrdt/transaction.py index 69ff48d..9cc249e 100644 --- a/python/pycrdt/transaction.py +++ b/python/pycrdt/transaction.py @@ -1,5 +1,6 @@ from __future__ import annotations +from types import TracebackType from typing import TYPE_CHECKING from ._pycrdt import Transaction as _Transaction @@ -25,7 +26,12 @@ def __enter__(self) -> Transaction: self._doc._txn = self return self - def __exit__(self, exc_type, exc_value, exc_tb) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: self._nb -= 1 # only drop the transaction when exiting root context manager # since nested transactions reuse the root transaction diff --git a/src/doc.rs b/src/doc.rs index 384e72c..210990c 100644 --- a/src/doc.rs +++ b/src/doc.rs @@ -133,52 +133,99 @@ impl Doc { #[pyclass(unsendable)] pub struct TransactionEvent { - before_state: PyObject, - after_state: PyObject, - delete_set: PyObject, - update: PyObject, + event: *const TransactionCleanupEvent, + txn: *const TransactionMut<'static>, + before_state: Option, + after_state: Option, + delete_set: Option, + update: Option, + transaction: Option, } impl TransactionEvent { fn new(event: &TransactionCleanupEvent, txn: &TransactionMut) -> Self { - // Convert all event data into Python objects eagerly, so that we don't have to hold - // on to the transaction. - let before_state = event.before_state.encode_v1(); - let before_state: PyObject = Python::with_gil(|py| PyBytes::new(py, &before_state).into()); - let after_state = event.after_state.encode_v1(); - let after_state: PyObject = Python::with_gil(|py| PyBytes::new(py, &after_state).into()); - let delete_set = event.delete_set.encode_v1(); - let delete_set: PyObject = Python::with_gil(|py| PyBytes::new(py, &delete_set).into()); - let update = txn.encode_update_v1(); - let update = Python::with_gil(|py| PyBytes::new(py, &update).into()); - TransactionEvent { - before_state, - after_state, - delete_set, - update, - } + let event = event as *const TransactionCleanupEvent; + let txn = unsafe { std::mem::transmute::<&TransactionMut, &TransactionMut<'static>>(txn) }; + let mut transaction_event = TransactionEvent { + event, + txn, + before_state: None, + after_state: None, + delete_set: None, + update: None, + transaction: None, + }; + transaction_event.update(); + transaction_event + } + + fn event(&self) -> &TransactionCleanupEvent { + unsafe { self.event.as_ref().unwrap() } + } + fn txn(&self) -> &TransactionMut { + unsafe { self.txn.as_ref().unwrap() } } } #[pymethods] impl TransactionEvent { + #[getter] + pub fn transaction(&mut self) -> PyObject { + if let Some(transaction) = self.transaction.as_ref() { + transaction.clone() + } else { + let transaction: PyObject = Python::with_gil(|py| Transaction::from(self.txn()).into_py(py)); + self.transaction = Some(transaction.clone()); + transaction + } + } + #[getter] pub fn before_state(&mut self) -> PyObject { - self.before_state.clone() + if let Some(before_state) = &self.before_state { + before_state.clone() + } else { + let before_state = self.event().before_state.encode_v1(); + let before_state: PyObject = Python::with_gil(|py| PyBytes::new(py, &before_state).into()); + self.before_state = Some(before_state.clone()); + before_state + } } #[getter] pub fn after_state(&mut self) -> PyObject { - self.after_state.clone() + if let Some(after_state) = &self.after_state { + after_state.clone() + } else { + let after_state = self.event().after_state.encode_v1(); + let after_state: PyObject = Python::with_gil(|py| PyBytes::new(py, &after_state).into()); + self.after_state = Some(after_state.clone()); + after_state + } } #[getter] pub fn delete_set(&mut self) -> PyObject { - self.delete_set.clone() + if let Some(delete_set) = &self.delete_set { + delete_set.clone() + } else { + let delete_set = self.event().delete_set.encode_v1(); + let delete_set: PyObject = Python::with_gil(|py| PyBytes::new(py, &delete_set).into()); + self.delete_set = Some(delete_set.clone()); + delete_set + } } - pub fn get_update(&self) -> PyObject { - self.update.clone() + #[getter] + pub fn update(&mut self) -> PyObject { + if let Some(update) = &self.update { + update.clone() + } else { + let update = self.txn().encode_update_v1(); + let update: PyObject = Python::with_gil(|py| PyBytes::new(py, &update).into()); + self.update = Some(update.clone()); + update + } } } diff --git a/tests/test_doc.py b/tests/test_doc.py index d2fae8e..91b3c6d 100644 --- a/tests/test_doc.py +++ b/tests/test_doc.py @@ -86,8 +86,7 @@ def test_transaction_event(): remote_doc = Doc() for event in events: - update = event.get_update() - remote_doc.apply_update(update) + remote_doc.apply_update(event.update) remote_text = Text() remote_doc["text"] = remote_text diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..94bc517 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,55 @@ +from datetime import datetime +from typing import Tuple + +import pytest +from pycrdt import Array, Doc, Text +from pydantic import BaseModel +from pydantic import ValidationError + + +def test_model(): + remote_doc = Doc( + { + "timestamp": Text("2020-01-02T03:04:05Z"), + "dimensions": Array(["10", "20"]), + } + ) + update = remote_doc.get_update() + + class Delivery(BaseModel): + timestamp: datetime + dimensions: Tuple[int, int] + + local_doc = Doc( + { + "timestamp": Text(), + "dimensions": Array(), + }, + Model=Delivery, + ) + local_doc.apply_update(update) + + remote_doc["dimensions"][1] = "a" # "a" is not an int + update = remote_doc.get_update() + with pytest.raises(ValidationError) as exc_info: + local_doc.apply_update(update) + assert str(exc_info.value).startswith("1 validation error for Delivery\ndimensions.1\n") + + remote_doc["timestamp"][6] = "0" # invalid "00" month + update = remote_doc.get_update() + with pytest.raises(ValidationError) as exc_info: + local_doc.apply_update(update) + assert str(exc_info.value).startswith("2 validation errors for Delivery\n") + + remote_doc["dimensions"][1] = "30" # revert invalid change, and make a change + update = remote_doc.get_update() + with pytest.raises(ValidationError) as exc_info: + local_doc.apply_update(update) + assert str(exc_info.value).startswith("1 validation error for Delivery\ntimestamp\n") + + remote_doc["timestamp"][6] = "2" # revert invalid change, and make a change + update = remote_doc.get_update() + local_doc.apply_update(update) + + assert str(local_doc["timestamp"]) == "2020-02-02T03:04:05Z" + assert list(local_doc["dimensions"]) == ["10", "30"]