Skip to content

Commit

Permalink
Merge pull request #3311 from davidhewitt/frozen-receiver-error
Browse files Browse the repository at this point in the history
improve error span for mutable access to `#[pyclass(frozen)]`
  • Loading branch information
davidhewitt authored Jul 12, 2023
2 parents 398a33e + 56b7c38 commit bb05896
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 65 deletions.
76 changes: 47 additions & 29 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use quote::ToTokens;
use quote::{quote, quote_spanned};
use syn::ext::IdentExt;
use syn::spanned::Spanned;
use syn::Result;
use syn::{Result, Token};

#[derive(Clone, Debug)]
pub struct FnArg<'a> {
Expand Down Expand Up @@ -136,7 +136,7 @@ impl FnType {

#[derive(Clone, Debug)]
pub enum SelfType {
Receiver { mutable: bool },
Receiver { mutable: bool, span: Span },
TryFromPyCell(Span),
}

Expand All @@ -146,38 +146,55 @@ pub enum ExtractErrorMode {
Raise,
}

impl ExtractErrorMode {
pub fn handle_error(self, py: &syn::Ident, extract: TokenStream) -> TokenStream {
match self {
ExtractErrorMode::Raise => quote! { #extract? },
ExtractErrorMode::NotImplemented => quote! {
match #extract {
::std::result::Result::Ok(value) => value,
::std::result::Result::Err(_) => { return _pyo3::callback::convert(#py, #py.NotImplemented()); },
}
},
}
}
}

impl SelfType {
pub fn receiver(&self, cls: &syn::Type, error_mode: ExtractErrorMode) -> TokenStream {
let cell = match error_mode {
ExtractErrorMode::Raise => {
quote! { _py.from_borrowed_ptr::<_pyo3::PyAny>(_slf).downcast::<_pyo3::PyCell<#cls>>()? }
}
ExtractErrorMode::NotImplemented => {
quote! {
match _py.from_borrowed_ptr::<_pyo3::PyAny>(_slf).downcast::<_pyo3::PyCell<#cls>>() {
::std::result::Result::Ok(cell) => cell,
::std::result::Result::Err(_) => return _pyo3::callback::convert(_py, _py.NotImplemented()),
}
}
}
};
let py = syn::Ident::new("_py", Span::call_site());
let _slf = syn::Ident::new("_slf", Span::call_site());
match self {
SelfType::Receiver { mutable: false } => {
quote! {
let _cell = #cell;
let _ref = _cell.try_borrow()?;
let _slf: &#cls = &*_ref;
}
}
SelfType::Receiver { mutable: true } => {
quote! {
let _cell = #cell;
let mut _ref = _cell.try_borrow_mut()?;
let _slf: &mut #cls = &mut *_ref;
SelfType::Receiver { span, mutable } => {
let (method, mutability) = if *mutable {
(
quote_spanned! { *span => extract_pyclass_ref_mut },
Some(Token![mut](*span)),
)
} else {
(quote_spanned! { *span => extract_pyclass_ref }, None)
};
let extract = error_mode.handle_error(
&py,
quote_spanned! { *span =>
_pyo3::impl_::extract_argument::#method(
#py.from_borrowed_ptr::<_pyo3::PyAny>(#_slf),
&mut holder,
)
},
);
quote_spanned! { *span =>
let mut holder = _pyo3::impl_::extract_argument::FunctionArgumentHolder::INIT;
let #_slf: &#mutability #cls = #extract;
}
}
SelfType::TryFromPyCell(span) => {
let _slf = quote! { _slf };
let cell = error_mode.handle_error(
&py,
quote!{
_py.from_borrowed_ptr::<_pyo3::PyAny>(_slf).downcast::<_pyo3::PyCell<#cls>>()
}
);
quote_spanned! { *span =>
let _cell = #cell;
#[allow(clippy::useless_conversion)] // In case _slf is PyCell<Self>
Expand Down Expand Up @@ -247,8 +264,9 @@ pub fn parse_method_receiver(arg: &syn::FnArg) -> Result<SelfType> {
) => {
bail_spanned!(recv.span() => RECEIVER_BY_VALUE_ERR);
}
syn::FnArg::Receiver(syn::Receiver { mutability, .. }) => Ok(SelfType::Receiver {
syn::FnArg::Receiver(recv @ syn::Receiver { mutability, .. }) => Ok(SelfType::Receiver {
mutable: mutability.is_some(),
span: recv.span(),
}),
syn::FnArg::Typed(syn::PatType { ty, .. }) => {
if let syn::Type::ImplTrait(_) = &**ty {
Expand Down
7 changes: 5 additions & 2 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::borrow::Cow;

use crate::attributes::kw::frozen;
use crate::attributes::{
self, kw, take_pyo3_options, CrateAttribute, ExtendsAttribute, FreelistAttribute,
ModuleAttribute, NameAttribute, NameLitStr, TextSignatureAttribute,
Expand Down Expand Up @@ -355,7 +356,7 @@ fn impl_class(
cls,
args,
methods_type,
descriptors_to_items(cls, field_options)?,
descriptors_to_items(cls, args.options.frozen, field_options)?,
vec![],
)
.doc(doc)
Expand Down Expand Up @@ -674,6 +675,7 @@ fn extract_variant_data(variant: &mut syn::Variant) -> syn::Result<PyClassEnumVa

fn descriptors_to_items(
cls: &syn::Ident,
frozen: Option<frozen>,
field_options: Vec<(&syn::Field, FieldPyO3Options)>,
) -> syn::Result<Vec<MethodAndMethodDef>> {
let ty = syn::parse_quote!(#cls);
Expand All @@ -700,7 +702,8 @@ fn descriptors_to_items(
items.push(getter);
}

if options.set.is_some() {
if let Some(set) = options.set {
ensure_spanned!(frozen.is_none(), set.span() => "cannot use `#[pyo3(set)]` on a `frozen` class");
let setter = impl_py_setter_def(
&ty,
PropertyType::Descriptor {
Expand Down
37 changes: 11 additions & 26 deletions pyo3-macros-backend/src/pymethod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,9 +523,11 @@ pub fn impl_py_setter_def(
};

let slf = match property_type {
PropertyType::Descriptor { .. } => {
SelfType::Receiver { mutable: true }.receiver(cls, ExtractErrorMode::Raise)
PropertyType::Descriptor { .. } => SelfType::Receiver {
mutable: true,
span: Span::call_site(),
}
.receiver(cls, ExtractErrorMode::Raise),
PropertyType::Function { self_type, .. } => {
self_type.receiver(cls, ExtractErrorMode::Raise)
}
Expand Down Expand Up @@ -638,9 +640,11 @@ pub fn impl_py_getter_def(
};

let slf = match property_type {
PropertyType::Descriptor { .. } => {
SelfType::Receiver { mutable: false }.receiver(cls, ExtractErrorMode::Raise)
PropertyType::Descriptor { .. } => SelfType::Receiver {
mutable: false,
span: Span::call_site(),
}
.receiver(cls, ExtractErrorMode::Raise),
PropertyType::Function { self_type, .. } => {
self_type.receiver(cls, ExtractErrorMode::Raise)
}
Expand Down Expand Up @@ -949,8 +953,7 @@ impl Ty {
#ident.to_borrowed_any(#py)
},
),
Ty::CompareOp => handle_error(
extract_error_mode,
Ty::CompareOp => extract_error_mode.handle_error(
py,
quote! {
_pyo3::class::basic::CompareOp::from_raw(#ident)
Expand All @@ -959,8 +962,7 @@ impl Ty {
),
Ty::PySsizeT => {
let ty = arg.ty;
handle_error(
extract_error_mode,
extract_error_mode.handle_error(
py,
quote! {
::std::convert::TryInto::<#ty>::try_into(#ident).map_err(|e| _pyo3::exceptions::PyValueError::new_err(e.to_string()))
Expand All @@ -973,30 +975,13 @@ impl Ty {
}
}

fn handle_error(
extract_error_mode: ExtractErrorMode,
py: &syn::Ident,
extract: TokenStream,
) -> TokenStream {
match extract_error_mode {
ExtractErrorMode::Raise => quote! { #extract? },
ExtractErrorMode::NotImplemented => quote! {
match #extract {
::std::result::Result::Ok(value) => value,
::std::result::Result::Err(_) => { return _pyo3::callback::convert(#py, #py.NotImplemented()); },
}
},
}
}

fn extract_object(
extract_error_mode: ExtractErrorMode,
py: &syn::Ident,
name: &str,
source: TokenStream,
) -> TokenStream {
handle_error(
extract_error_mode,
extract_error_mode.handle_error(
py,
quote! {
_pyo3::impl_::extract_argument::extract_argument(
Expand Down
11 changes: 11 additions & 0 deletions tests/ui/invalid_frozen_pyclass_borrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ pub struct Foo {
field: u32,
}

#[pymethods]
impl Foo {
fn mut_method(&mut self) {}
}

fn borrow_mut_fails(foo: Py<Foo>, py: Python) {
let borrow = foo.as_ref(py).borrow_mut();
}
Expand All @@ -28,4 +33,10 @@ fn pyclass_get_of_mutable_class_fails(class: &PyCell<MutableBase>) {
class.get();
}

#[pyclass(frozen)]
pub struct SetOnFrozenClass {
#[pyo3(set)]
field: u32,
}

fn main() {}
34 changes: 26 additions & 8 deletions tests/ui/invalid_frozen_pyclass_borrow.stderr
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
error: cannot use `#[pyo3(set)]` on a `frozen` class
--> tests/ui/invalid_frozen_pyclass_borrow.rs:38:12
|
38 | #[pyo3(set)]
| ^^^

error[E0271]: type mismatch resolving `<Foo as PyClass>::Frozen == False`
--> tests/ui/invalid_frozen_pyclass_borrow.rs:11:19
|
11 | fn mut_method(&mut self) {}
| ^ expected `False`, found `True`
|
note: required by a bound in `extract_pyclass_ref_mut`
--> src/impl_/extract_argument.rs
|
| pub fn extract_pyclass_ref_mut<'a, 'py: 'a, T: PyClass<Frozen = False>>(
| ^^^^^^^^^^^^^^ required by this bound in `extract_pyclass_ref_mut`

error[E0271]: type mismatch resolving `<Foo as PyClass>::Frozen == False`
--> tests/ui/invalid_frozen_pyclass_borrow.rs:10:33
--> tests/ui/invalid_frozen_pyclass_borrow.rs:15:33
|
10 | let borrow = foo.as_ref(py).borrow_mut();
15 | let borrow = foo.as_ref(py).borrow_mut();
| ^^^^^^^^^^ expected `False`, found `True`
|
note: required by a bound in `pyo3::PyCell::<T>::borrow_mut`
Expand All @@ -11,9 +29,9 @@ note: required by a bound in `pyo3::PyCell::<T>::borrow_mut`
| ^^^^^^^^^^^^^^ required by this bound in `PyCell::<T>::borrow_mut`

error[E0271]: type mismatch resolving `<ImmutableChild as PyClass>::Frozen == False`
--> tests/ui/invalid_frozen_pyclass_borrow.rs:20:35
--> tests/ui/invalid_frozen_pyclass_borrow.rs:25:35
|
20 | let borrow = child.as_ref(py).borrow_mut();
25 | let borrow = child.as_ref(py).borrow_mut();
| ^^^^^^^^^^ expected `False`, found `True`
|
note: required by a bound in `pyo3::PyCell::<T>::borrow_mut`
Expand All @@ -23,9 +41,9 @@ note: required by a bound in `pyo3::PyCell::<T>::borrow_mut`
| ^^^^^^^^^^^^^^ required by this bound in `PyCell::<T>::borrow_mut`

error[E0271]: type mismatch resolving `<MutableBase as PyClass>::Frozen == True`
--> tests/ui/invalid_frozen_pyclass_borrow.rs:24:11
--> tests/ui/invalid_frozen_pyclass_borrow.rs:29:11
|
24 | class.get();
29 | class.get();
| ^^^ expected `True`, found `False`
|
note: required by a bound in `pyo3::Py::<T>::get`
Expand All @@ -35,9 +53,9 @@ note: required by a bound in `pyo3::Py::<T>::get`
| ^^^^^^^^^^^^^ required by this bound in `Py::<T>::get`

error[E0271]: type mismatch resolving `<MutableBase as PyClass>::Frozen == True`
--> tests/ui/invalid_frozen_pyclass_borrow.rs:28:11
--> tests/ui/invalid_frozen_pyclass_borrow.rs:33:11
|
28 | class.get();
33 | class.get();
| ^^^ expected `True`, found `False`
|
note: required by a bound in `pyo3::PyCell::<T>::get`
Expand Down

0 comments on commit bb05896

Please sign in to comment.