diff --git a/newsfragments/4288.fixed.md b/newsfragments/4288.fixed.md new file mode 100644 index 00000000000..105bb042276 --- /dev/null +++ b/newsfragments/4288.fixed.md @@ -0,0 +1 @@ +allow `#[pyo3::prelude::pymodule]` with nested declarative modules diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 39240aba7e8..a015dbdf7c0 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -153,7 +153,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { } Item::Fn(item_fn) => { ensure_spanned!( - !has_attribute(&item_fn.attrs, "pymodule_export"), + !has_attribute(&item_fn.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); let is_pymodule_init = @@ -161,22 +161,22 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { let ident = &item_fn.sig.ident; if is_pymodule_init { ensure_spanned!( - !has_attribute(&item_fn.attrs, "pyfunction"), + !has_attribute(&item_fn.attrs, &["pyfunction"]), item_fn.span() => "`#[pyfunction]` cannot be used alongside `#[pymodule_init]`" ); ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified"); pymodule_init = Some(quote! { #ident(module)?; }); - } else if has_attribute(&item_fn.attrs, "pyfunction") { + } else if has_attribute(&item_fn.attrs, &["pyfunction"]) { module_items.push(ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs)); } } Item::Struct(item_struct) => { ensure_spanned!( - !has_attribute(&item_struct.attrs, "pymodule_export"), + !has_attribute(&item_struct.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); - if has_attribute(&item_struct.attrs, "pyclass") { + if has_attribute(&item_struct.attrs, &["pyclass"]) { module_items.push(item_struct.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs)); if !has_pyo3_module_declared::( @@ -190,10 +190,10 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { } Item::Enum(item_enum) => { ensure_spanned!( - !has_attribute(&item_enum.attrs, "pymodule_export"), + !has_attribute(&item_enum.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); - if has_attribute(&item_enum.attrs, "pyclass") { + if has_attribute(&item_enum.attrs, &["pyclass"]) { module_items.push(item_enum.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs)); if !has_pyo3_module_declared::( @@ -207,10 +207,13 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { } Item::Mod(item_mod) => { ensure_spanned!( - !has_attribute(&item_mod.attrs, "pymodule_export"), + !has_attribute(&item_mod.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); - if has_attribute(&item_mod.attrs, "pymodule") { + if has_attribute(&item_mod.attrs, &["pymodule"]) + || has_attribute(&item_mod.attrs, &[ctx.pyo3_path, "pymodule"]) + || has_attribute(&item_mod.attrs, &[ctx.pyo3_path, "prelude", "pymodule"]) + { module_items.push(item_mod.ident.clone()); module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs)); if !has_pyo3_module_declared::( @@ -224,61 +227,61 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { } Item::ForeignMod(item) => { ensure_spanned!( - !has_attribute(&item.attrs, "pymodule_export"), + !has_attribute(&item.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); } Item::Trait(item) => { ensure_spanned!( - !has_attribute(&item.attrs, "pymodule_export"), + !has_attribute(&item.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); } Item::Const(item) => { ensure_spanned!( - !has_attribute(&item.attrs, "pymodule_export"), + !has_attribute(&item.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); } Item::Static(item) => { ensure_spanned!( - !has_attribute(&item.attrs, "pymodule_export"), + !has_attribute(&item.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); } Item::Macro(item) => { ensure_spanned!( - !has_attribute(&item.attrs, "pymodule_export"), + !has_attribute(&item.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); } Item::ExternCrate(item) => { ensure_spanned!( - !has_attribute(&item.attrs, "pymodule_export"), + !has_attribute(&item.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); } Item::Impl(item) => { ensure_spanned!( - !has_attribute(&item.attrs, "pymodule_export"), + !has_attribute(&item.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); } Item::TraitAlias(item) => { ensure_spanned!( - !has_attribute(&item.attrs, "pymodule_export"), + !has_attribute(&item.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); } Item::Type(item) => { ensure_spanned!( - !has_attribute(&item.attrs, "pymodule_export"), + !has_attribute(&item.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); } Item::Union(item) => { ensure_spanned!( - !has_attribute(&item.attrs, "pymodule_export"), + !has_attribute(&item.attrs, &["pymodule_export"]), item.span() => "`#[pymodule_export]` may only be used on `use` statements" ); } @@ -533,8 +536,10 @@ fn find_and_remove_attribute(attrs: &mut Vec, ident: &str) -> bo found } -fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool { - attrs.iter().any(|attr| attr.path().is_ident(ident)) +fn has_attribute(attrs: &[syn::Attribute], ident: &[&str]) -> bool { + attrs + .iter() + .any(|attr| attr.path().segments.iter().map(|v| &v.ident).eq(ident)) } fn set_module_attribute(attrs: &mut Vec, module_name: &str) { diff --git a/tests/test_declarative_module.rs b/tests/test_declarative_module.rs index 061d0337285..8935090244a 100644 --- a/tests/test_declarative_module.rs +++ b/tests/test_declarative_module.rs @@ -117,6 +117,9 @@ mod declarative_module { struct Struct; } + #[pyo3::prelude::pymodule] + mod full_path_inner {} + #[pymodule_init] fn init(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add("double2", m.getattr("double")?) @@ -239,3 +242,11 @@ fn test_module_names() { ); }) } + +#[test] +fn test_inner_module_full_path() { + Python::with_gil(|py| { + let m = declarative_module(py); + py_assert!(py, m, "m.full_path_inner"); + }) +}