Skip to content

Commit

Permalink
fixes PyO3#4285 -- allow full-path to pymodule with nested declarativ…
Browse files Browse the repository at this point in the history
…e modules
  • Loading branch information
alex committed Jun 26, 2024
1 parent 2e2d440 commit ed72155
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 21 deletions.
1 change: 1 addition & 0 deletions newsfragments/4288.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
allow `#[pyo3::prelude::pymodule]` with nested declarative modules
47 changes: 26 additions & 21 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,30 +153,30 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
}
Item::Fn(item_fn) => {
ensure_spanned!(
!has_attribute(&item_fn.attrs, "pymodule_export"),
!has_attribute(&item_fn.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
let is_pymodule_init =
find_and_remove_attribute(&mut item_fn.attrs, "pymodule_init");
let ident = &item_fn.sig.ident;
if is_pymodule_init {
ensure_spanned!(
!has_attribute(&item_fn.attrs, "pyfunction"),
!has_attribute(&item_fn.attrs, &["pyfunction"]),
item_fn.span() => "`#[pyfunction]` cannot be used alongside `#[pymodule_init]`"
);
ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified");
pymodule_init = Some(quote! { #ident(module)?; });
} else if has_attribute(&item_fn.attrs, "pyfunction") {
} else if has_attribute(&item_fn.attrs, &["pyfunction"]) {
module_items.push(ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
}
}
Item::Struct(item_struct) => {
ensure_spanned!(
!has_attribute(&item_struct.attrs, "pymodule_export"),
!has_attribute(&item_struct.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_struct.attrs, "pyclass") {
if has_attribute(&item_struct.attrs, &["pyclass"]) {
module_items.push(item_struct.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
if !has_pyo3_module_declared::<PyClassPyO3Option>(
Expand All @@ -190,10 +190,10 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
}
Item::Enum(item_enum) => {
ensure_spanned!(
!has_attribute(&item_enum.attrs, "pymodule_export"),
!has_attribute(&item_enum.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_enum.attrs, "pyclass") {
if has_attribute(&item_enum.attrs, &["pyclass"]) {
module_items.push(item_enum.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
if !has_pyo3_module_declared::<PyClassPyO3Option>(
Expand All @@ -207,10 +207,13 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
}
Item::Mod(item_mod) => {
ensure_spanned!(
!has_attribute(&item_mod.attrs, "pymodule_export"),
!has_attribute(&item_mod.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_mod.attrs, "pymodule") {
if has_attribute(&item_mod.attrs, &["pymodule"])
|| has_attribute(&item_mod.attrs, &[ctx.pyo3_path, "pymodule"])
|| has_attribute(&item_mod.attrs, &[ctx.pyo3_path, "prelude", "pymodule"])
{
module_items.push(item_mod.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
if !has_pyo3_module_declared::<PyModulePyO3Option>(
Expand All @@ -224,61 +227,61 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
}
Item::ForeignMod(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
!has_attribute(&item.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Trait(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
!has_attribute(&item.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Const(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
!has_attribute(&item.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Static(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
!has_attribute(&item.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Macro(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
!has_attribute(&item.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::ExternCrate(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
!has_attribute(&item.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Impl(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
!has_attribute(&item.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::TraitAlias(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
!has_attribute(&item.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Type(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
!has_attribute(&item.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Item::Union(item) => {
ensure_spanned!(
!has_attribute(&item.attrs, "pymodule_export"),
!has_attribute(&item.attrs, &["pymodule_export"]),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
}
Expand Down Expand Up @@ -533,8 +536,10 @@ fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bo
found
}

fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
attrs.iter().any(|attr| attr.path().is_ident(ident))
fn has_attribute(attrs: &[syn::Attribute], ident: &[&str]) -> bool {
attrs
.iter()
.any(|attr| attr.path().segments.iter().map(|v| &v.ident).eq(ident))
}

fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
Expand Down
11 changes: 11 additions & 0 deletions tests/test_declarative_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ mod declarative_module {
struct Struct;
}

#[pyo3::prelude::pymodule]
mod full_path_inner {}

#[pymodule_init]
fn init(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add("double2", m.getattr("double")?)
Expand Down Expand Up @@ -239,3 +242,11 @@ fn test_module_names() {
);
})
}

#[test]
fn test_inner_module_full_path() {
Python::with_gil(|py| {
let m = declarative_module(py);
py_assert!(py, m, "m.full_path_inner");
})
}

0 comments on commit ed72155

Please sign in to comment.