diff --git a/pyo3-macros-backend/src/pyclass.rs b/pyo3-macros-backend/src/pyclass.rs index 2a7395796dc..5d17834b12b 100644 --- a/pyo3-macros-backend/src/pyclass.rs +++ b/pyo3-macros-backend/src/pyclass.rs @@ -15,6 +15,13 @@ use syn::parse::{Parse, ParseStream}; use syn::punctuated::Punctuated; use syn::{parse_quote, spanned::Spanned, Expr, Result, Token}; +/// If the class is derived from a Rust `struct` or `enum`. +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum PyClassKind { + Struct, + Enum, +} + /// The parsed arguments of the pyclass macro pub struct PyClassArgs { pub freelist: Option, @@ -27,22 +34,28 @@ pub struct PyClassArgs { pub has_extends: bool, pub has_unsendable: bool, pub module: Option, + pub class_kind: PyClassKind, } -impl Parse for PyClassArgs { - fn parse(input: ParseStream) -> Result { - let mut slf = PyClassArgs::default(); - +impl PyClassArgs { + fn parse(input: ParseStream, kind: PyClassKind) -> Result { + let mut slf = PyClassArgs::new(kind); let vars = Punctuated::::parse_terminated(input)?; for expr in vars { slf.add_expr(&expr)?; } Ok(slf) } -} -impl Default for PyClassArgs { - fn default() -> Self { + pub fn parse_stuct_args(input: ParseStream) -> syn::Result { + Self::parse(input, PyClassKind::Struct) + } + + pub fn parse_enum_args(input: ParseStream) -> syn::Result { + Self::parse(input, PyClassKind::Enum) + } + + fn new(class_kind: PyClassKind) -> Self { PyClassArgs { freelist: None, name: None, @@ -54,11 +67,10 @@ impl Default for PyClassArgs { is_basetype: false, has_extends: false, has_unsendable: false, + class_kind, } } -} -impl PyClassArgs { /// Adda single expression from the comma separated list in the attribute, which is /// either a single word or an assignment expression fn add_expr(&mut self, expr: &Expr) -> Result<()> { @@ -116,6 +128,11 @@ impl PyClassArgs { }, "extends" => match unwrap_group(&**right) { syn::Expr::Path(exp) => { + if self.class_kind == PyClassKind::Enum { + return Err( + err_spanned!( assign.span() => "enums cannot extend from other classes" ), + ); + } self.base = syn::TypePath { path: exp.path.clone(), qself: None, @@ -150,6 +167,11 @@ impl PyClassArgs { self.has_weaklist = true; } "subclass" => { + if self.class_kind == PyClassKind::Enum { + return Err( + err_spanned!(exp.span() => "enums can't be inherited by other classes"), + ); + } self.is_basetype = true; } "dict" => { @@ -496,8 +518,8 @@ struct VariantPyO3<'a> { } pub fn build_py_enum( - _args: PyClassArgs, enum_: &syn::ItemEnum, + args: PyClassArgs, method_type: PyClassMethodsType, ) -> syn::Result { let variants: Vec = enum_ @@ -505,17 +527,18 @@ pub fn build_py_enum( .iter() .map(|v| extract_variant_data(v)) .collect::>()?; - impl_enum(enum_, variants, method_type) + impl_enum(&enum_, args, variants, method_type) } fn impl_enum( enum_: &syn::ItemEnum, + attrs: PyClassArgs, variants: Vec, methods_type: PyClassMethodsType, ) -> syn::Result { let enum_name = &enum_.ident; let doc = utils::get_doc(&enum_.attrs, None); - let enum_cls = impl_enum_class(enum_name, doc, methods_type)?; + let enum_cls = impl_enum_class(enum_name, &attrs, doc, methods_type)?; let variant_consts = variants .iter() @@ -548,16 +571,17 @@ fn impl_const(enum_: &syn::Ident, cls: &syn::Ident) -> syn::Result } fn impl_enum_class( - typ: &syn::Ident, + cls: &syn::Ident, + _attr: &PyClassArgs, doc: PythonDoc, methods_type: PyClassMethodsType, ) -> syn::Result { - let clsname = typ.to_string(); - let extractext = impl_extractext(&typ); - let pyclassimpl_impl = PyClassImplBuilder::new(&typ, methods_type).doc(doc).build(); + let clsname = cls.to_string(); + let extractext = impl_extractext(cls); + let pyclassimpl_impl = PyClassImplBuilder::new(cls, methods_type).doc(doc).build(); Ok(quote! { - unsafe impl pyo3::type_object::PyTypeInfo for #typ { + unsafe impl pyo3::type_object::PyTypeInfo for #cls { type AsRefTarget = pyo3::PyCell; const NAME: &'static str = #clsname; const MODULE: Option<&'static str> = None; @@ -569,7 +593,7 @@ fn impl_enum_class( } } - impl pyo3::PyClass for #typ { + impl pyo3::PyClass for #cls { type Dict = pyo3::pyclass_slots::PyClassDummySlot ; type WeakRef = pyo3::pyclass_slots::PyClassDummySlot; type BaseNativeType = pyo3::PyAny; @@ -653,7 +677,7 @@ fn impl_descriptors( } /// Builds an implementation for `pyo3::class::impl_::PyClassImpl` -pub struct PyClassImplBuilder<'a> { +struct PyClassImplBuilder<'a> { cls: &'a syn::Ident, doc: Option, is_gc: bool, diff --git a/pyo3-macros/src/lib.rs b/pyo3-macros/src/lib.rs index 143114cc6dd..9ff9e49a04d 100644 --- a/pyo3-macros/src/lib.rs +++ b/pyo3-macros/src/lib.rs @@ -116,11 +116,10 @@ pub fn pyproto(_: TokenStream, input: TokenStream) -> TokenStream { #[proc_macro_attribute] pub fn pyclass(attr: TokenStream, input: TokenStream) -> TokenStream { use syn::Item; - let args = parse_macro_input!(attr as PyClassArgs); let item = parse_macro_input!(input as Item); match item { - Item::Struct(struct_) => pyclass_impl(args, struct_, methods_type()), - Item::Enum(enum_) => pyclass_enum_impl(args, enum_, methods_type()), + Item::Struct(struct_) => pyclass_impl(attr, struct_, methods_type()), + Item::Enum(enum_) => pyclass_enum_impl(attr, enum_, methods_type()), unsupported => { syn::Error::new_spanned(unsupported, "#[pyclass] only supports structs and enums.") .to_compile_error() @@ -209,10 +208,11 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream { } fn pyclass_impl( - args: PyClassArgs, + attrs: TokenStream, mut ast: syn::ItemStruct, methods_type: PyClassMethodsType, ) -> TokenStream { + let args = parse_macro_input!(attrs with PyClassArgs::parse_stuct_args); let expanded = build_py_class(&mut ast, &args, methods_type).unwrap_or_else(|e| e.to_compile_error()); @@ -224,12 +224,13 @@ fn pyclass_impl( } fn pyclass_enum_impl( - args: PyClassArgs, + attr: TokenStream, enum_: syn::ItemEnum, methods_type: PyClassMethodsType, ) -> TokenStream { + let args = parse_macro_input!(attr with PyClassArgs::parse_enum_args); let expanded = - build_py_enum(args, &enum_, methods_type).unwrap_or_else(|e| e.into_compile_error()); + build_py_enum(&enum_, args, methods_type).unwrap_or_else(|e| e.into_compile_error()); quote!( #enum_ diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index 607bcf20620..93a74d48bae 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -19,6 +19,7 @@ fn _test_compile_errors() { t.compile_fail("tests/ui/invalid_need_module_arg_position.rs"); t.compile_fail("tests/ui/invalid_property_args.rs"); t.compile_fail("tests/ui/invalid_pyclass_args.rs"); + t.compile_fail("tests/ui/invalid_pyclass_enum.rs"); t.compile_fail("tests/ui/invalid_pyclass_item.rs"); t.compile_fail("tests/ui/invalid_pyfunctions.rs"); t.compile_fail("tests/ui/invalid_pymethods.rs"); diff --git a/tests/ui/invalid_pyclass_enum.rs b/tests/ui/invalid_pyclass_enum.rs new file mode 100644 index 00000000000..62f2a3d6bd5 --- /dev/null +++ b/tests/ui/invalid_pyclass_enum.rs @@ -0,0 +1,15 @@ +use pyo3::prelude::*; + +#[pyclass(subclass)] +enum NotBaseClass { + x, + y, +} + +#[pyclass(extends = PyList)] +enum NotDrivedClass { + x, + y, +} + +fn main() {} diff --git a/tests/ui/invalid_pyclass_enum.stderr b/tests/ui/invalid_pyclass_enum.stderr new file mode 100644 index 00000000000..2dd0e737e4b --- /dev/null +++ b/tests/ui/invalid_pyclass_enum.stderr @@ -0,0 +1,11 @@ +error: enums can't be inherited by other classes + --> tests/ui/invalid_pyclass_enum.rs:3:11 + | +3 | #[pyclass(subclass)] + | ^^^^^^^^ + +error: enums cannot extend from other classes + --> tests/ui/invalid_pyclass_enum.rs:9:11 + | +9 | #[pyclass(extends = PyList)] + | ^^^^^^^