Skip to content

Commit

Permalink
Add tests for FromPyObject implementation for HashSet/BTreeSet
Browse files Browse the repository at this point in the history
  • Loading branch information
kngwyu committed Mar 30, 2020
1 parent 85de698 commit 810d3b6
Showing 1 changed file with 37 additions and 19 deletions.
56 changes: 37 additions & 19 deletions src/types/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
use crate::err::{self, PyErr, PyResult};
use crate::internal_tricks::Unsendable;
use crate::{
ffi, AsPyPointer, FromPyObject, PyAny, PyNativeType, PyObject, PyTryFrom, Python,
ToBorrowedObject, ToPyObject,
ffi, AsPyPointer, FromPyObject, PyAny, PyNativeType, PyObject, Python, ToBorrowedObject,
ToPyObject,
};
use std::cmp;
use std::collections::{BTreeSet, HashSet};
Expand Down Expand Up @@ -189,27 +189,19 @@ where
K: FromPyObject<'source> + cmp::Eq + hash::Hash,
S: hash::BuildHasher + Default,
{
fn extract(ob: &'source PyAny) -> Result<Self, PyErr> {
let set = <PySet as PyTryFrom>::try_from(ob)?;
let mut ret = HashSet::default();
for k in set.iter() {
ret.insert(K::extract(k)?);
}
Ok(ret)
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let set: &PySet = ob.downcast()?;
set.iter().map(K::extract).collect()
}
}

impl<'source, K> FromPyObject<'source> for BTreeSet<K>
where
K: FromPyObject<'source> + cmp::Ord,
{
fn extract(ob: &'source PyAny) -> Result<Self, PyErr> {
let set = <PySet as PyTryFrom>::try_from(ob)?;
let mut ret = BTreeSet::default();
for k in set.iter() {
ret.insert(K::extract(k)?);
}
Ok(ret)
fn extract(ob: &'source PyAny) -> PyResult<Self> {
let set: &PySet = ob.downcast()?;
set.iter().map(K::extract).collect()
}
}

Expand Down Expand Up @@ -279,9 +271,9 @@ impl<'a> std::iter::IntoIterator for &'a PyFrozenSet {
#[cfg(test)]
mod test {
use super::{PyFrozenSet, PySet};
use crate::instance::AsPyRef;
use crate::{ObjectProtocol, PyTryFrom, Python, ToPyObject};
use std::collections::HashSet;
use crate::{AsPyRef, ObjectProtocol, PyTryFrom, Python, ToPyObject};
use std::collections::{BTreeSet, HashSet};
use std::iter::FromIterator;

#[test]
fn test_set_new() {
Expand Down Expand Up @@ -433,4 +425,30 @@ mod test {
assert_eq!(1i32, el.extract::<i32>().unwrap());
}
}

#[test]
fn test_extract_hashset() {
let gil = Python::acquire_gil();
let py = gil.python();

let set = PySet::new(py, &[1, 2, 3, 4, 5]).unwrap();
let hash_set: HashSet<usize> = set.extract().unwrap();
assert_eq!(
hash_set,
HashSet::from_iter([1, 2, 3, 4, 5].iter().copied())
);
}

#[test]
fn test_extract_btreeset() {
let gil = Python::acquire_gil();
let py = gil.python();

let set = PySet::new(py, &[1, 2, 3, 4, 5]).unwrap();
let hash_set: BTreeSet<usize> = set.extract().unwrap();
assert_eq!(
hash_set,
BTreeSet::from_iter([1, 2, 3, 4, 5].iter().copied())
);
}
}

0 comments on commit 810d3b6

Please sign in to comment.