Skip to content

Commit

Permalink
Merge pull request #1864 from davidhewitt/pymethods-protos
Browse files Browse the repository at this point in the history
pymethods: add support for protocol methods
  • Loading branch information
davidhewitt authored Sep 24, 2021
2 parents 8744ee6 + 179b5d1 commit 9fa0abe
Show file tree
Hide file tree
Showing 10 changed files with 2,477 additions and 150 deletions.
1 change: 1 addition & 0 deletions guide/src/class.md
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ impl pyo3::class::impl_::PyClassImpl for MyClass {
visitor(collector.sequence_protocol_slots());
visitor(collector.async_protocol_slots());
visitor(collector.buffer_protocol_slots());
visitor(collector.methods_protocol_slots());
}

fn get_buffer() -> Option<&'static pyo3::class::impl_::PyBufferProcs> {
Expand Down
43 changes: 34 additions & 9 deletions pyo3-macros-backend/src/method.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,17 @@ pub enum FnType {
}

impl FnType {
pub fn self_conversion(&self, cls: Option<&syn::Type>) -> TokenStream {
pub fn self_conversion(
&self,
cls: Option<&syn::Type>,
error_mode: ExtractErrorMode,
) -> TokenStream {
match self {
FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) | FnType::FnCall(st) => {
st.receiver(cls.expect("no class given for Fn with a \"self\" receiver"))
}
FnType::Getter(st) | FnType::Setter(st) | FnType::Fn(st) | FnType::FnCall(st) => st
.receiver(
cls.expect("no class given for Fn with a \"self\" receiver"),
error_mode,
),
FnType::FnNew | FnType::FnStatic | FnType::ClassAttribute => {
quote!()
}
Expand Down Expand Up @@ -128,26 +134,45 @@ pub enum SelfType {
TryFromPyCell(Span),
}

#[derive(Clone, Copy)]
pub enum ExtractErrorMode {
NotImplemented,
Raise,
}

impl SelfType {
pub fn receiver(&self, cls: &syn::Type) -> TokenStream {
pub fn receiver(&self, cls: &syn::Type, error_mode: ExtractErrorMode) -> TokenStream {
let cell = match error_mode {
ExtractErrorMode::Raise => {
quote! { _py.from_borrowed_ptr::<::pyo3::PyAny>(_slf).downcast::<::pyo3::PyCell<#cls>>()? }
}
ExtractErrorMode::NotImplemented => {
quote! {
match _py.from_borrowed_ptr::<::pyo3::PyAny>(_slf).downcast::<::pyo3::PyCell<#cls>>() {
::std::result::Result::Ok(cell) => cell,
::std::result::Result::Err(_) => return ::pyo3::callback::convert(_py, _py.NotImplemented()),
}
}
}
};
match self {
SelfType::Receiver { mutable: false } => {
quote! {
let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf);
let _cell = #cell;
let _ref = _cell.try_borrow()?;
let _slf = &_ref;
}
}
SelfType::Receiver { mutable: true } => {
quote! {
let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf);
let _cell = #cell;
let mut _ref = _cell.try_borrow_mut()?;
let _slf = &mut _ref;
}
}
SelfType::TryFromPyCell(span) => {
quote_spanned! { *span =>
let _cell = _py.from_borrowed_ptr::<::pyo3::PyCell<#cls>>(_slf);
let _cell = #cell;
#[allow(clippy::useless_conversion)] // In case _slf is PyCell<Self>
let _slf = std::convert::TryFrom::try_from(_cell)?;
}
Expand Down Expand Up @@ -442,7 +467,7 @@ impl<'a> FnSpec<'a> {
cls: Option<&syn::Type>,
) -> Result<TokenStream> {
let deprecations = &self.deprecations;
let self_conversion = self.tp.self_conversion(cls);
let self_conversion = self.tp.self_conversion(cls, ExtractErrorMode::Raise);
let self_arg = self.tp.self_arg();
let arg_names = (0..self.args.len())
.map(|pos| syn::Ident::new(&format!("arg{}", pos), Span::call_site()))
Expand Down
8 changes: 8 additions & 0 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,13 @@ fn impl_class(
),
};

let methods_protos = match methods_type {
PyClassMethodsType::Specialization => {
Some(quote! { visitor(collector.methods_protocol_slots()); })
}
PyClassMethodsType::Inventory => None,
};

let base = &attr.base;
let base_nativetype = if attr.has_extends {
quote! { <Self::BaseType as ::pyo3::class::impl_::PyClassBaseType>::BaseNativeType }
Expand Down Expand Up @@ -591,6 +598,7 @@ fn impl_class(
visitor(collector.sequence_protocol_slots());
visitor(collector.async_protocol_slots());
visitor(collector.buffer_protocol_slots());
#methods_protos
}

fn get_buffer() -> ::std::option::Option<&'static ::pyo3::class::impl_::PyBufferProcs> {
Expand Down
83 changes: 75 additions & 8 deletions pyo3-macros-backend/src/pyimpl.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) 2017-present PyO3 Project and Contributors

use std::collections::HashSet;

use crate::{
konst::{ConstAttributes, ConstSpec},
pyfunction::PyFunctionOptions,
Expand Down Expand Up @@ -37,9 +39,12 @@ pub fn impl_methods(
impls: &mut Vec<syn::ImplItem>,
methods_type: PyClassMethodsType,
) -> syn::Result<TokenStream> {
let mut new_impls = Vec::new();
let mut call_impls = Vec::new();
let mut trait_impls = Vec::new();
let mut proto_impls = Vec::new();
let mut methods = Vec::new();

let mut implemented_proto_fragments = HashSet::new();

for iimpl in impls.iter_mut() {
match iimpl {
syn::ImplItem::Method(meth) => {
Expand All @@ -49,13 +54,18 @@ pub fn impl_methods(
let attrs = get_cfg_attributes(&meth.attrs);
methods.push(quote!(#(#attrs)* #token_stream));
}
GeneratedPyMethod::New(token_stream) => {
GeneratedPyMethod::TraitImpl(token_stream) => {
let attrs = get_cfg_attributes(&meth.attrs);
new_impls.push(quote!(#(#attrs)* #token_stream));
trait_impls.push(quote!(#(#attrs)* #token_stream));
}
GeneratedPyMethod::Call(token_stream) => {
GeneratedPyMethod::SlotTraitImpl(method_name, token_stream) => {
implemented_proto_fragments.insert(method_name);
let attrs = get_cfg_attributes(&meth.attrs);
call_impls.push(quote!(#(#attrs)* #token_stream));
trait_impls.push(quote!(#(#attrs)* #token_stream));
}
GeneratedPyMethod::Proto(token_stream) => {
let attrs = get_cfg_attributes(&meth.attrs);
proto_impls.push(quote!(#(#attrs)* #token_stream))
}
}
}
Expand All @@ -80,10 +90,25 @@ pub fn impl_methods(
PyClassMethodsType::Inventory => submit_methods_inventory(ty, methods),
};

let protos_registration = match methods_type {
PyClassMethodsType::Specialization => {
Some(impl_protos(ty, proto_impls, implemented_proto_fragments))
}
PyClassMethodsType::Inventory => {
if proto_impls.is_empty() {
None
} else {
panic!(
"cannot implement protos in #[pymethods] using `multiple-pymethods` feature"
);
}
}
};

Ok(quote! {
#(#new_impls)*
#(#trait_impls)*

#(#call_impls)*
#protos_registration

#methods_registration
})
Expand Down Expand Up @@ -122,6 +147,48 @@ fn impl_py_methods(ty: &syn::Type, methods: Vec<TokenStream>) -> TokenStream {
}
}

fn impl_protos(
ty: &syn::Type,
mut proto_impls: Vec<TokenStream>,
mut implemented_proto_fragments: HashSet<String>,
) -> TokenStream {
macro_rules! try_add_shared_slot {
($first:literal, $second:literal, $slot:ident) => {{
let first_implemented = implemented_proto_fragments.remove($first);
let second_implemented = implemented_proto_fragments.remove($second);
if first_implemented || second_implemented {
proto_impls.push(quote! { ::pyo3::$slot!(#ty) })
}
}};
}

try_add_shared_slot!("__setattr__", "__delattr__", generate_pyclass_setattr_slot);
try_add_shared_slot!("__set__", "__delete__", generate_pyclass_setdescr_slot);
try_add_shared_slot!("__setitem__", "__delitem__", generate_pyclass_setitem_slot);
try_add_shared_slot!("__add__", "__radd__", generate_pyclass_add_slot);
try_add_shared_slot!("__sub__", "__rsub__", generate_pyclass_sub_slot);
try_add_shared_slot!("__mul__", "__rmul__", generate_pyclass_mul_slot);
try_add_shared_slot!("__mod__", "__rmod__", generate_pyclass_mod_slot);
try_add_shared_slot!("__divmod__", "__rdivmod__", generate_pyclass_divmod_slot);
try_add_shared_slot!("__lshift__", "__rlshift__", generate_pyclass_lshift_slot);
try_add_shared_slot!("__rshift__", "__rrshift__", generate_pyclass_rshift_slot);
try_add_shared_slot!("__and__", "__rand__", generate_pyclass_and_slot);
try_add_shared_slot!("__or__", "__ror__", generate_pyclass_or_slot);
try_add_shared_slot!("__xor__", "__rxor__", generate_pyclass_xor_slot);
try_add_shared_slot!("__matmul__", "__rmatmul__", generate_pyclass_matmul_slot);
try_add_shared_slot!("__pow__", "__rpow__", generate_pyclass_pow_slot);

quote! {
impl ::pyo3::class::impl_::PyMethodsProtocolSlots<#ty>
for ::pyo3::class::impl_::PyClassImplCollector<#ty>
{
fn methods_protocol_slots(self) -> &'static [::pyo3::ffi::PyType_Slot] {
&[#(#proto_impls),*]
}
}
}
}

fn submit_methods_inventory(ty: &syn::Type, methods: Vec<TokenStream>) -> TokenStream {
if methods.is_empty() {
return TokenStream::default();
Expand Down
Loading

0 comments on commit 9fa0abe

Please sign in to comment.