diff --git a/CHANGELOG.md b/CHANGELOG.md index c7be6095cb6..1648e4c0122 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `PyAny::py` as a convenience for `PyNativeType::py`. [#1751](https://github.com/PyO3/pyo3/pull/1751) - Add implementation of `std::ops::Index` for `PyList`, `PyTuple` and `PySequence`. [#1825](https://github.com/PyO3/pyo3/pull/1825) - Add range indexing implementations of `std::ops::Index` for `PyList`, `PyTuple` and `PySequence`. [#1829](https://github.com/PyO3/pyo3/pull/1829) +- Add `PyMapping` type to represent the Python mapping protocol. [#1844](https://github.com/PyO3/pyo3/pull/1844) - Add commonly-used sequence methods to `PyList` and `PyTuple`. [#1849](https://github.com/PyO3/pyo3/pull/1849) - Add `as_sequence` methods to `PyList` and `PyTuple`. [#1860](https://github.com/PyO3/pyo3/pull/1860) diff --git a/guide/src/conversions/tables.md b/guide/src/conversions/tables.md index 7a8eaccfa97..6baf7dca677 100644 --- a/guide/src/conversions/tables.md +++ b/guide/src/conversions/tables.md @@ -35,6 +35,7 @@ The table below contains the Python type and the corresponding function argument | `datetime.timedelta` | - | `&PyDelta` | | `typing.Optional[T]` | `Option` | - | | `typing.Sequence[T]` | `Vec` | `&PySequence` | +| `typing.Mapping[K, V]` | `HashMap`, `BTreeMap`, `hashbrown::HashMap`[^2], `indexmap::IndexMap`[^3] | `&PyMapping` | | `typing.Iterator[Any]` | - | `&PyIterator` | | `typing.Union[...]` | See [`#[derive(FromPyObject)]`](traits.html#deriving-a-hrefhttpsdocsrspyo3latestpyo3conversiontraitfrompyobjecthtmlfrompyobjecta-for-enums) | - | diff --git a/src/ffi/abstract_.rs b/src/ffi/abstract_.rs index bb38f8ae367..2efe1d3c5d7 100644 --- a/src/ffi/abstract_.rs +++ b/src/ffi/abstract_.rs @@ -78,7 +78,9 @@ extern "C" { pub fn PyObject_GetItem(o: *mut PyObject, key: *mut PyObject) -> *mut PyObject; #[cfg_attr(PyPy, link_name = "PyPyObject_SetItem")] pub fn PyObject_SetItem(o: *mut PyObject, key: *mut PyObject, v: *mut PyObject) -> c_int; + #[cfg_attr(PyPy, link_name = "PyPyObject_DelItemString")] pub fn PyObject_DelItemString(o: *mut PyObject, key: *const c_char) -> c_int; + #[cfg_attr(PyPy, link_name = "PyPyObject_DelItem")] pub fn PyObject_DelItem(o: *mut PyObject, key: *mut PyObject) -> c_int; } @@ -300,6 +302,7 @@ pub unsafe fn PyMapping_DelItem(o: *mut PyObject, key: *mut PyObject) -> c_int { extern "C" { #[cfg_attr(PyPy, link_name = "PyPyMapping_HasKeyString")] pub fn PyMapping_HasKeyString(o: *mut PyObject, key: *const c_char) -> c_int; + #[cfg_attr(PyPy, link_name = "PyPyMapping_HasKey")] pub fn PyMapping_HasKey(o: *mut PyObject, key: *mut PyObject) -> c_int; #[cfg_attr(PyPy, link_name = "PyPyMapping_Keys")] pub fn PyMapping_Keys(o: *mut PyObject) -> *mut PyObject; diff --git a/src/types/dict.rs b/src/types/dict.rs index 89e29bb4d67..112630e0364 100644 --- a/src/types/dict.rs +++ b/src/types/dict.rs @@ -12,6 +12,8 @@ use std::collections::{BTreeMap, HashMap}; use std::ptr::NonNull; use std::{cmp, collections, hash}; +use super::PyMapping; + /// Represents a Python `dict`. #[repr(transparent)] pub struct PyDict(PyAny); @@ -178,6 +180,11 @@ impl PyDict { pos: 0, } } + + /// Returns `self` cast as a `PyMapping`. + pub fn as_mapping(&self) -> &PyMapping { + unsafe { PyMapping::try_from_unchecked(self) } + } } pub struct PyDictIterator<'py> { @@ -762,4 +769,25 @@ mod tests { assert_eq!(py_map.get_item("b").unwrap().extract::().unwrap(), 2); }); } + + #[test] + fn dict_as_mapping() { + Python::with_gil(|py| { + let mut map = HashMap::::new(); + map.insert(1, 1); + + let py_map = map.into_py_dict(py); + + assert_eq!(py_map.as_mapping().len().unwrap(), 1); + assert_eq!( + py_map + .as_mapping() + .get_item(1) + .unwrap() + .extract::() + .unwrap(), + 1 + ); + }); + } } diff --git a/src/types/mapping.rs b/src/types/mapping.rs new file mode 100644 index 00000000000..ca907f67a4c --- /dev/null +++ b/src/types/mapping.rs @@ -0,0 +1,259 @@ +// Copyright (c) 2017-present PyO3 Project and Contributors + +use crate::err::{PyDowncastError, PyErr, PyResult}; +use crate::types::{PyAny, PySequence}; +use crate::AsPyPointer; +use crate::{ffi, ToPyObject}; +use crate::{PyTryFrom, ToBorrowedObject}; + +/// Represents a reference to a Python object supporting the mapping protocol. +#[repr(transparent)] +pub struct PyMapping(PyAny); +pyobject_native_type_named!(PyMapping); +pyobject_native_type_extract!(PyMapping); + +impl PyMapping { + /// Returns the number of objects in the mapping. + /// + /// This is equivalent to the Python expression `len(self)`. + #[inline] + pub fn len(&self) -> PyResult { + let v = unsafe { ffi::PyMapping_Size(self.as_ptr()) }; + if v == -1 { + Err(PyErr::api_call_failed(self.py())) + } else { + Ok(v as usize) + } + } + + /// Returns whether the mapping is empty. + #[inline] + pub fn is_empty(&self) -> PyResult { + self.len().map(|l| l == 0) + } + + /// Gets the item in self with key `key`. + /// + /// Returns an `Err` if the item with specified key is not found, usually `KeyError`. + /// + /// This is equivalent to the Python expression `self[key]`. + #[inline] + pub fn get_item(&self, key: K) -> PyResult<&PyAny> + where + K: ToBorrowedObject, + { + PyAny::get_item(self, key) + } + + /// Sets the item in self with key `key`. + /// + /// This is equivalent to the Python expression `self[key] = value`. + #[inline] + pub fn set_item(&self, key: K, value: V) -> PyResult<()> + where + K: ToPyObject, + V: ToPyObject, + { + PyAny::set_item(self, key, value) + } + + /// Deletes the item with key `key`. + /// + /// This is equivalent to the Python statement `del self[key]`. + #[inline] + pub fn del_item(&self, key: K) -> PyResult<()> + where + K: ToBorrowedObject, + { + PyAny::del_item(self, key) + } + + /// Returns a sequence containing all keys in the mapping. + #[inline] + pub fn keys(&self) -> PyResult<&PySequence> { + unsafe { + self.py() + .from_owned_ptr_or_err(ffi::PyMapping_Keys(self.as_ptr())) + } + } + + /// Returns a sequence containing all values in the mapping. + #[inline] + pub fn values(&self) -> PyResult<&PySequence> { + unsafe { + self.py() + .from_owned_ptr_or_err(ffi::PyMapping_Values(self.as_ptr())) + } + } + + /// Returns a sequence of tuples of all (key, value) pairs in the mapping. + #[inline] + pub fn items(&self) -> PyResult<&PySequence> { + unsafe { + self.py() + .from_owned_ptr_or_err(ffi::PyMapping_Items(self.as_ptr())) + } + } +} + +impl<'v> PyTryFrom<'v> for PyMapping { + fn try_from>(value: V) -> Result<&'v PyMapping, PyDowncastError<'v>> { + let value = value.into(); + unsafe { + if ffi::PyMapping_Check(value.as_ptr()) != 0 { + Ok(::try_from_unchecked(value)) + } else { + Err(PyDowncastError::new(value, "Mapping")) + } + } + } + + #[inline] + fn try_from_exact>(value: V) -> Result<&'v PyMapping, PyDowncastError<'v>> { + ::try_from(value) + } + + #[inline] + unsafe fn try_from_unchecked>(value: V) -> &'v PyMapping { + let ptr = value.into() as *const _ as *const PyMapping; + &*ptr + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use crate::{exceptions::PyKeyError, types::PyTuple, Python}; + + use super::*; + + #[test] + fn test_len() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + assert_eq!(0, mapping.len().unwrap()); + assert!(mapping.is_empty().unwrap()); + + v.insert(7, 32); + let ob = v.to_object(py); + let mapping2 = ::try_from(ob.as_ref(py)).unwrap(); + assert_eq!(1, mapping2.len().unwrap()); + assert!(!mapping2.is_empty().unwrap()); + }); + } + + #[test] + fn test_get_item() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + assert_eq!( + 32, + mapping.get_item(7i32).unwrap().extract::().unwrap() + ); + assert!(mapping + .get_item(8i32) + .unwrap_err() + .is_instance::(py)); + }); + } + + #[test] + fn test_set_item() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + assert!(mapping.set_item(7i32, 42i32).is_ok()); // change + assert!(mapping.set_item(8i32, 123i32).is_ok()); // insert + assert_eq!( + 42i32, + mapping.get_item(7i32).unwrap().extract::().unwrap() + ); + assert_eq!( + 123i32, + mapping.get_item(8i32).unwrap().extract::().unwrap() + ); + }); + } + + #[test] + fn test_del_item() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + assert!(mapping.del_item(7i32).is_ok()); + assert_eq!(0, mapping.len().unwrap()); + assert!(mapping + .get_item(7i32) + .unwrap_err() + .is_instance::(py)); + }); + } + + #[test] + fn test_items() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + v.insert(8, 42); + v.insert(9, 123); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + // Can't just compare against a vector of tuples since we don't have a guaranteed ordering. + let mut key_sum = 0; + let mut value_sum = 0; + for el in mapping.items().unwrap().iter().unwrap() { + let tuple = el.unwrap().cast_as::().unwrap(); + key_sum += tuple.get_item(0).unwrap().extract::().unwrap(); + value_sum += tuple.get_item(1).unwrap().extract::().unwrap(); + } + assert_eq!(7 + 8 + 9, key_sum); + assert_eq!(32 + 42 + 123, value_sum); + }); + } + + #[test] + fn test_keys() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + v.insert(8, 42); + v.insert(9, 123); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + // Can't just compare against a vector of tuples since we don't have a guaranteed ordering. + let mut key_sum = 0; + for el in mapping.keys().unwrap().iter().unwrap() { + key_sum += el.unwrap().extract::().unwrap(); + } + assert_eq!(7 + 8 + 9, key_sum); + }); + } + + #[test] + fn test_values() { + Python::with_gil(|py| { + let mut v = HashMap::new(); + v.insert(7, 32); + v.insert(8, 42); + v.insert(9, 123); + let ob = v.to_object(py); + let mapping = ::try_from(ob.as_ref(py)).unwrap(); + // Can't just compare against a vector of tuples since we don't have a guaranteed ordering. + let mut values_sum = 0; + for el in mapping.values().unwrap().iter().unwrap() { + values_sum += el.unwrap().extract::().unwrap(); + } + assert_eq!(32 + 42 + 123, values_sum); + }); + } +} diff --git a/src/types/mod.rs b/src/types/mod.rs index 3c82b2e26f7..cb1d91fa9da 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -17,6 +17,7 @@ pub use self::floatob::PyFloat; pub use self::function::{PyCFunction, PyFunction}; pub use self::iterator::PyIterator; pub use self::list::PyList; +pub use self::mapping::PyMapping; pub use self::module::PyModule; pub use self::num::PyLong; pub use self::num::PyLong as PyInt; @@ -231,6 +232,7 @@ mod floatob; mod function; mod iterator; mod list; +mod mapping; mod module; mod num; mod sequence;