Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

YMap Items View #65

Merged
merged 2 commits into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 175 additions & 18 deletions src/y_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl YMap {
entry.ok_or_else(|| PyKeyError::new_err(format!("{key}")))
}

/// Returns an iterator that can be used to traverse over all entries stored within this
/// Returns an item view that can be used to traverse over all entries stored within this
/// instance of `YMap`. Order of entry is not specified.
///
/// Example:
Expand All @@ -199,23 +199,23 @@ impl YMap {
/// for (key, value) in map.entries(txn)):
/// print(key, value)
/// ```
pub fn items(&self) -> YMapIterator {
match &self.0 {
SharedType::Integrated(val) => unsafe {
let this: *const Map = val;
let shared_iter = InnerYMapIterator::Integrated((*this).iter());
YMapIterator(ManuallyDrop::new(shared_iter))
},
SharedType::Prelim(val) => unsafe {
let this: *const HashMap<String, PyObject> = val;
let shared_iter = InnerYMapIterator::Prelim((*this).iter());
YMapIterator(ManuallyDrop::new(shared_iter))
},
}
pub fn items(&self) -> ItemView {
ItemView(&self.0)
}

pub fn __iter__(&self) -> YMapKeyIterator {
YMapKeyIterator(self.items())
pub fn keys(&self) -> KeyView {
let inner: *const _ = &self.0;
KeyView(inner)
}

pub fn __iter__(&self) -> KeyIterator {
let inner: *const _ = &self.0;
KeyIterator(YMapIterator::from(inner))
}

pub fn values(&self) -> ValueView {
let inner: *const _ = &self.0;
ValueView(inner)
}

pub fn observe(&mut self, f: PyObject) -> PyResult<ShallowSubscription> {
Expand Down Expand Up @@ -267,6 +267,131 @@ impl YMap {
}
}

#[pyclass(unsendable)]
pub struct ItemView(*const SharedType<Map, HashMap<String, PyObject>>);

#[pymethods]
impl ItemView {
fn __iter__(slf: PyRef<Self>) -> YMapIterator {
YMapIterator::from(slf.0)
}

fn __len__(&self) -> usize {
unsafe {
match &*self.0 {
SharedType::Integrated(map) => map.len() as usize,
SharedType::Prelim(map) => map.len(),
}
}
}

fn __str__(&self) -> String {
let vals: String = YMapIterator::from(self.0)
.map(|(key, val)| format!("({key}, {val})"))
.collect::<Vec<String>>()
.join(", ");
format!("{{{vals}}}")
}

fn __repr__(&self) -> String {
let data = self.__str__();
format!("ItemView({data})")
}

fn __contains__(&self, el: PyObject) -> bool {
let kv: Result<(String, PyObject), _> = Python::with_gil(|py| el.extract(py));
kv.ok()
.and_then(|(key, value)| unsafe {
match &*self.0 {
SharedType::Integrated(map) if map.contains(&key) => map.get(&key).map(|v| {
Python::with_gil(|py| v.into_py(py).as_ref(py).eq(value)).unwrap_or(false)
}),
SharedType::Prelim(map) if map.contains_key(&key) => map
.get(&key)
.map(|v| Python::with_gil(|py| v.as_ref(py).eq(value).unwrap_or(false))),
_ => None,
}
})
.unwrap_or(false)
}
}

#[pyclass(unsendable)]
pub struct KeyView(*const SharedType<Map, HashMap<String, PyObject>>);

#[pymethods]
impl KeyView {
fn __iter__(slf: PyRef<Self>) -> KeyIterator {
KeyIterator(YMapIterator::from(slf.0))
}

fn __len__(&self) -> usize {
unsafe {
match &*self.0 {
SharedType::Integrated(map) => map.len() as usize,
SharedType::Prelim(map) => map.len(),
}
}
}

fn __str__(&self) -> String {
let vals: String = YMapIterator::from(self.0)
.map(|(key, _)| key)
.collect::<Vec<String>>()
.join(", ");
format!("{{{vals}}}")
}

fn __repr__(&self) -> String {
let data = self.__str__();
format!("KeyView({data})")
}

fn __contains__(&self, el: PyObject) -> bool {
let key: Result<String, _> = Python::with_gil(|py| el.extract(py));
key.ok()
.map(|key| unsafe {
match &*self.0 {
SharedType::Integrated(map) => map.contains(&key),
SharedType::Prelim(map) => map.contains_key(&key),
}
})
.unwrap_or(false)
}
}

#[pyclass(unsendable)]
pub struct ValueView(*const SharedType<Map, HashMap<String, PyObject>>);

#[pymethods]
impl ValueView {
fn __iter__(slf: PyRef<Self>) -> ValueIterator {
ValueIterator(YMapIterator::from(slf.0))
}

fn __len__(&self) -> usize {
unsafe {
match &*self.0 {
SharedType::Integrated(map) => map.len() as usize,
SharedType::Prelim(map) => map.len(),
}
}
}

fn __str__(&self) -> String {
let vals: String = YMapIterator::from(self.0)
.map(|(_, v)| v.to_string())
.collect::<Vec<String>>()
.join(", ");
format!("{{{vals}}}")
}

fn __repr__(&self) -> String {
let data = self.__str__();
format!("ValueView({data})")
}
}

pub enum InnerYMapIterator {
Integrated(MapIter<'static>),
Prelim(std::collections::hash_map::Iter<'static, String, PyObject>),
Expand All @@ -281,6 +406,25 @@ impl Drop for YMapIterator {
}
}

impl From<*const SharedType<Map, HashMap<String, PyObject>>> for YMapIterator {
fn from(inner_map_ptr: *const SharedType<Map, HashMap<String, PyObject>>) -> Self {
unsafe {
match &*inner_map_ptr {
SharedType::Integrated(val) => {
let this: *const Map = val;
let shared_iter = InnerYMapIterator::Integrated((*this).iter());
YMapIterator(ManuallyDrop::new(shared_iter))
}
SharedType::Prelim(val) => {
let this: *const HashMap<String, PyObject> = val;
let shared_iter = InnerYMapIterator::Prelim((*this).iter());
YMapIterator(ManuallyDrop::new(shared_iter))
}
}
}
}
}

impl Iterator for YMapIterator {
type Item = (String, PyObject);

Expand All @@ -305,10 +449,10 @@ impl YMapIterator {
}

#[pyclass(unsendable)]
pub struct YMapKeyIterator(YMapIterator);
pub struct KeyIterator(YMapIterator);

#[pymethods]
impl YMapKeyIterator {
impl KeyIterator {
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}
Expand All @@ -317,6 +461,19 @@ impl YMapKeyIterator {
}
}

#[pyclass(unsendable)]
pub struct ValueIterator(YMapIterator);

#[pymethods]
impl ValueIterator {
fn __iter__(slf: PyRef<Self>) -> PyRef<Self> {
slf
}
fn __next__(mut slf: PyRefMut<Self>) -> Option<PyObject> {
slf.0.next().map(|(_, v)| v)
}
}

/// Event generated by `YMap.observe` method. Emitted during transaction commit phase.
#[pyclass(unsendable)]
pub struct YMapEvent {
Expand Down
75 changes: 60 additions & 15 deletions tests/test_y_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@


def test_get():
import y_py as Y

d = Y.YDoc()
m = d.get_map("map")

Expand Down Expand Up @@ -108,24 +110,67 @@ def test_pop():
assert value == "value2"


def test_iterator():
def test_items_view():
d = Y.YDoc()
x = d.get_map("test")
m = d.get_map("test")

with d.begin_transaction() as txn:
x.set(txn, "a", 1)
x.set(txn, "b", 2)
x.set(txn, "c", 3)
expected = {"a": 1, "b": 2, "c": 3}
for (key, val) in x.items():
v = expected[key]
assert val == v
del expected[key]

expected = {"a": 1, "b": 2, "c": 3}
for key in x:
assert key in expected
assert key in x
vals = {"a": 1, "b": 2, "c": 3}
m.update(txn, vals)
items = m.items()
# Ensure that the item view is a multi use iterator
for _ in range(2):
expected = vals.copy()
for (key, val) in items:
v = expected[key]
assert val == v
del expected[key]

assert len(items) == 3
assert ("b", 2) in items

# Ensure that the item view stays up to date with map state
m.set(txn, "d", 4)
assert ("d", 4) in items


def test_keys_values():
d = Y.YDoc()
m = d.get_map("test")
expected_keys = list("abc")
expected_values = list(range(1, 4))
with d.begin_transaction() as txn:
m.update(txn, zip(expected_keys, expected_values))

# Ensure basic iteration works
for key in m:
assert key in expected_keys
assert key in m

# Ensure keys can be iterated over multiple times
keys = m.keys()
for _ in range(2):
for key in keys:
assert key in expected_keys
assert key in keys

values = m.values()

for _ in range(2):
for val in values:
assert val in expected_values
assert val in values

# Ensure keys and values reflect updates to map
with d.begin_transaction() as txn:
m.set(txn, "d", 4)

assert "d" in keys
assert 4 in values

# Ensure key view operations
assert len(keys) == 4
assert len(values) == 4


def test_observer():
Expand Down
Loading