Skip to content

Commit

Permalink
Migrate some conversions to extract_bound
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Jan 28, 2024
1 parent 595ca4b commit ffaa03e
Show file tree
Hide file tree
Showing 5 changed files with 90 additions and 75 deletions.
13 changes: 9 additions & 4 deletions src/conversions/num_complex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,11 @@
//! result = get_eigenvalues(m11,m12,m21,m22)
//! assert result == [complex(1,-1), complex(-2,0)]
//! ```
#[cfg(any(Py_LIMITED_API, PyPy))]
use crate::types::any::PyAnyMethods;
use crate::{
ffi, types::PyComplex, FromPyObject, PyAny, PyErr, PyObject, PyResult, Python, ToPyObject,
ffi, types::PyComplex, Bound, FromPyObject, PyAny, PyErr, PyObject, PyResult, Python,
ToPyObject,
};
use num_complex::Complex;
use std::os::raw::c_double;
Expand Down Expand Up @@ -131,8 +134,8 @@ macro_rules! complex_conversion {
}

#[cfg_attr(docsrs, doc(cfg(feature = "num-complex")))]
impl<'source> FromPyObject<'source> for Complex<$float> {
fn extract(obj: &'source PyAny) -> PyResult<Complex<$float>> {
impl FromPyObject<'_> for Complex<$float> {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Complex<$float>> {
#[cfg(not(any(Py_LIMITED_API, PyPy)))]
unsafe {
let val = ffi::PyComplex_AsCComplex(obj.as_ptr());
Expand All @@ -146,12 +149,14 @@ macro_rules! complex_conversion {

#[cfg(any(Py_LIMITED_API, PyPy))]
unsafe {
let complex;
let obj = if obj.is_instance_of::<PyComplex>() {
obj
} else if let Some(method) =
obj.lookup_special(crate::intern!(obj.py(), "__complex__"))?
{
method.call0()?
complex = method.call0()?;
&complex
} else {
// `obj` might still implement `__float__` or `__index__`, which will be
// handled by `PyComplex_{Real,Imag}AsDouble`, including propagating any
Expand Down
7 changes: 2 additions & 5 deletions src/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1425,11 +1425,8 @@ where
T: PyTypeInfo,
{
/// Extracts `Self` from the source `PyObject`.
fn extract(ob: &'a PyAny) -> PyResult<Self> {
ob.as_borrowed()
.downcast()
.map(Clone::clone)
.map_err(Into::into)
fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
ob.downcast().map(Clone::clone).map_err(Into::into)
}
}

Expand Down
123 changes: 67 additions & 56 deletions src/types/any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,51 +136,6 @@ impl PyAny {
.map(Bound::into_gil_ref)
}

/// Retrieve an attribute value, skipping the instance dictionary during the lookup but still
/// binding the object to the instance.
///
/// This is useful when trying to resolve Python's "magic" methods like `__getitem__`, which
/// are looked up starting from the type object. This returns an `Option` as it is not
/// typically a direct error for the special lookup to fail, as magic methods are optional in
/// many situations in which they might be called.
///
/// To avoid repeated temporary allocations of Python strings, the [`intern!`] macro can be used
/// to intern `attr_name`.
#[allow(dead_code)] // Currently only used with num-complex+abi3, so dead without that.
pub(crate) fn lookup_special<N>(&self, attr_name: N) -> PyResult<Option<&PyAny>>
where
N: IntoPy<Py<PyString>>,
{
let py = self.py();
let self_type = self.get_type();
let attr = if let Ok(attr) = self_type.getattr(attr_name) {
attr
} else {
return Ok(None);
};

// Manually resolve descriptor protocol.
if cfg!(Py_3_10)
|| unsafe { ffi::PyType_HasFeature(attr.get_type_ptr(), ffi::Py_TPFLAGS_HEAPTYPE) } != 0
{
// This is the preferred faster path, but does not work on static types (generally,
// types defined in extension modules) before Python 3.10.
unsafe {
let descr_get_ptr = ffi::PyType_GetSlot(attr.get_type_ptr(), ffi::Py_tp_descr_get);
if descr_get_ptr.is_null() {
return Ok(Some(attr));
}
let descr_get: ffi::descrgetfunc = std::mem::transmute(descr_get_ptr);
let ret = descr_get(attr.as_ptr(), self.as_ptr(), self_type.as_ptr());
py.from_owned_ptr_or_err(ret).map(Some)
}
} else if let Ok(descr_get) = attr.get_type().getattr(crate::intern!(py, "__get__")) {
descr_get.call1((attr, self, self_type)).map(Some)
} else {
Ok(Some(attr))
}
}

/// Sets an attribute value.
///
/// This is equivalent to the Python expression `self.attr_name = value`.
Expand Down Expand Up @@ -1666,9 +1621,9 @@ pub trait PyAnyMethods<'py> {
/// Extracts some type from the Python object.
///
/// This is a wrapper function around [`FromPyObject::extract()`].
fn extract<'a, D>(&'a self) -> PyResult<D>
fn extract<D>(&self) -> PyResult<D>
where
D: FromPyObject<'a>;
D: FromPyObject<'py>;

/// Returns the reference count for the Python object.
fn get_refcnt(&self) -> isize;
Expand Down Expand Up @@ -2202,11 +2157,11 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
std::mem::transmute(self)
}

fn extract<'a, D>(&'a self) -> PyResult<D>
fn extract<D>(&self) -> PyResult<D>
where
D: FromPyObject<'a>,
D: FromPyObject<'py>,
{
FromPyObject::extract(self.as_gil_ref())
FromPyObject::extract_bound(self)
}

fn get_refcnt(&self) -> isize {
Expand Down Expand Up @@ -2293,13 +2248,64 @@ impl<'py> PyAnyMethods<'py> for Bound<'py, PyAny> {
}
}

impl<'py> Bound<'py, PyAny> {
/// Retrieve an attribute value, skipping the instance dictionary during the lookup but still
/// binding the object to the instance.
///
/// This is useful when trying to resolve Python's "magic" methods like `__getitem__`, which
/// are looked up starting from the type object. This returns an `Option` as it is not
/// typically a direct error for the special lookup to fail, as magic methods are optional in
/// many situations in which they might be called.
///
/// To avoid repeated temporary allocations of Python strings, the [`intern!`] macro can be used
/// to intern `attr_name`.
#[allow(dead_code)] // Currently only used with num-complex+abi3, so dead without that.
pub(crate) fn lookup_special<N>(&self, attr_name: N) -> PyResult<Option<Bound<'py, PyAny>>>
where
N: IntoPy<Py<PyString>>,
{
let py = self.py();
let self_type = self.get_type().as_borrowed();
let attr = if let Ok(attr) = self_type.getattr(attr_name) {
attr
} else {
return Ok(None);
};

// Manually resolve descriptor protocol.
if cfg!(Py_3_10)
|| unsafe { ffi::PyType_HasFeature(attr.get_type_ptr(), ffi::Py_TPFLAGS_HEAPTYPE) } != 0
{
// This is the preferred faster path, but does not work on static types (generally,
// types defined in extension modules) before Python 3.10.
unsafe {
let descr_get_ptr = ffi::PyType_GetSlot(attr.get_type_ptr(), ffi::Py_tp_descr_get);
if descr_get_ptr.is_null() {
return Ok(Some(attr));
}
let descr_get: ffi::descrgetfunc = std::mem::transmute(descr_get_ptr);
let ret = descr_get(attr.as_ptr(), self.as_ptr(), self_type.as_ptr());
ret.assume_owned_or_err(py).map(Some)
}
} else if let Ok(descr_get) = attr
.get_type()
.as_borrowed()
.getattr(crate::intern!(py, "__get__"))
{
descr_get.call1((attr, self, self_type)).map(Some)
} else {
Ok(Some(attr))
}
}
}

#[cfg(test)]
#[cfg_attr(not(feature = "gil-refs"), allow(deprecated))]
mod tests {
use crate::{
basic::CompareOp,
types::{IntoPyDict, PyAny, PyBool, PyList, PyLong, PyModule},
PyTypeInfo, Python, ToPyObject,
types::{any::PyAnyMethods, IntoPyDict, PyAny, PyBool, PyList, PyLong, PyModule},
PyNativeType, PyTypeInfo, Python, ToPyObject,
};

#[test]
Expand Down Expand Up @@ -2344,8 +2350,13 @@ class NonHeapNonDescriptorInt:
.unwrap();

let int = crate::intern!(py, "__int__");
let eval_int =
|obj: &PyAny| obj.lookup_special(int)?.unwrap().call0()?.extract::<u32>();
let eval_int = |obj: &PyAny| {
obj.as_borrowed()
.lookup_special(int)?
.unwrap()
.call0()?
.extract::<u32>()
};

let simple = module.getattr("SimpleInt").unwrap().call0().unwrap();
assert_eq!(eval_int(simple).unwrap(), 1);
Expand All @@ -2354,7 +2365,7 @@ class NonHeapNonDescriptorInt:
let no_descriptor = module.getattr("NoDescriptorInt").unwrap().call0().unwrap();
assert_eq!(eval_int(no_descriptor).unwrap(), 1);
let missing = module.getattr("NoInt").unwrap().call0().unwrap();
assert!(missing.lookup_special(int).unwrap().is_none());
assert!(missing.as_borrowed().lookup_special(int).unwrap().is_none());
// Note the instance override should _not_ call the instance method that returns 2,
// because that's not how special lookups are meant to work.
let instance_override = module.getattr("instance_override").unwrap();
Expand All @@ -2364,7 +2375,7 @@ class NonHeapNonDescriptorInt:
.unwrap()
.call0()
.unwrap();
assert!(descriptor_error.lookup_special(int).is_err());
assert!(descriptor_error.as_borrowed().lookup_special(int).is_err());
let nonheap_nondescriptor = module
.getattr("NonHeapNonDescriptorInt")
.unwrap()
Expand Down
10 changes: 6 additions & 4 deletions src/types/boolobject.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ use crate::{
PyObject, PyResult, Python, ToPyObject,
};

use super::any::PyAnyMethods;

/// Represents a Python `bool`.
#[repr(transparent)]
pub struct PyBool(PyAny);
Expand Down Expand Up @@ -75,8 +77,8 @@ impl IntoPy<PyObject> for bool {
/// Converts a Python `bool` to a Rust `bool`.
///
/// Fails with `TypeError` if the input is not a Python `bool`.
impl<'source> FromPyObject<'source> for bool {
fn extract(obj: &'source PyAny) -> PyResult<Self> {
impl FromPyObject<'_> for bool {
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
let err = match obj.downcast::<PyBool>() {
Ok(obj) => return Ok(obj.is_true()),
Err(err) => err,
Expand All @@ -87,7 +89,7 @@ impl<'source> FromPyObject<'source> for bool {
.name()
.map_or(false, |name| name == "numpy.bool_")
{
let missing_conversion = |obj: &PyAny| {
let missing_conversion = |obj: &Bound<'_, PyAny>| {
PyTypeError::new_err(format!(
"object of type '{}' does not define a '__bool__' conversion",
obj.get_type()
Expand Down Expand Up @@ -117,7 +119,7 @@ impl<'source> FromPyObject<'source> for bool {
.lookup_special(crate::intern!(obj.py(), "__bool__"))?
.ok_or_else(|| missing_conversion(obj))?;

let obj = meth.call0()?.downcast::<PyBool>()?;
let obj = meth.call0()?.downcast_into::<PyBool>()?;
return Ok(obj.is_true());
}
}
Expand Down
12 changes: 6 additions & 6 deletions src/types/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -917,13 +917,13 @@ mod tests {

assert_eq!(iter.size_hint(), (3, Some(3)));

assert_eq!(1_i32, iter.next().unwrap().extract::<'_, i32>().unwrap());
assert_eq!(1, iter.next().unwrap().extract::<i32>().unwrap());
assert_eq!(iter.size_hint(), (2, Some(2)));

assert_eq!(2_i32, iter.next().unwrap().extract::<'_, i32>().unwrap());
assert_eq!(2, iter.next().unwrap().extract::<i32>().unwrap());
assert_eq!(iter.size_hint(), (1, Some(1)));

assert_eq!(3_i32, iter.next().unwrap().extract::<'_, i32>().unwrap());
assert_eq!(3, iter.next().unwrap().extract::<i32>().unwrap());
assert_eq!(iter.size_hint(), (0, Some(0)));

assert!(iter.next().is_none());
Expand All @@ -940,13 +940,13 @@ mod tests {

assert_eq!(iter.size_hint(), (3, Some(3)));

assert_eq!(3_i32, iter.next().unwrap().extract::<'_, i32>().unwrap());
assert_eq!(3, iter.next().unwrap().extract::<i32>().unwrap());
assert_eq!(iter.size_hint(), (2, Some(2)));

assert_eq!(2_i32, iter.next().unwrap().extract::<'_, i32>().unwrap());
assert_eq!(2, iter.next().unwrap().extract::<i32>().unwrap());
assert_eq!(iter.size_hint(), (1, Some(1)));

assert_eq!(1_i32, iter.next().unwrap().extract::<'_, i32>().unwrap());
assert_eq!(1, iter.next().unwrap().extract::<i32>().unwrap());
assert_eq!(iter.size_hint(), (0, Some(0)));

assert!(iter.next().is_none());
Expand Down

0 comments on commit ffaa03e

Please sign in to comment.