Skip to content

Commit

Permalink
fixes #4285 -- allow full-path to pymodule with nested declarative mo…
Browse files Browse the repository at this point in the history
…dules (#4288)
  • Loading branch information
alex authored Jul 5, 2024
1 parent 5860c4f commit 9afc38a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 6 deletions.
1 change: 1 addition & 0 deletions newsfragments/4288.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
allow `#[pyo3::prelude::pymodule]` with nested declarative modules
88 changes: 82 additions & 6 deletions pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
get_doc,
pyclass::PyClassPyO3Option,
pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
utils::{Ctx, LitCStr},
utils::{Ctx, LitCStr, PyO3CratePath},
};
use proc_macro2::{Span, TokenStream};
use quote::quote;
Expand Down Expand Up @@ -183,7 +183,18 @@ pub fn pymodule_module_impl(
);
ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one `#[pymodule_init]` may be specified");
pymodule_init = Some(quote! { #ident(module)?; });
} else if has_attribute(&item_fn.attrs, "pyfunction") {
} else if has_attribute(&item_fn.attrs, "pyfunction")
|| has_attribute_with_namespace(
&item_fn.attrs,
Some(pyo3_path),
&["pyfunction"],
)
|| has_attribute_with_namespace(
&item_fn.attrs,
Some(pyo3_path),
&["prelude", "pyfunction"],
)
{
module_items.push(ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_fn.attrs));
}
Expand All @@ -193,7 +204,18 @@ pub fn pymodule_module_impl(
!has_attribute(&item_struct.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_struct.attrs, "pyclass") {
if has_attribute(&item_struct.attrs, "pyclass")
|| has_attribute_with_namespace(
&item_struct.attrs,
Some(pyo3_path),
&["pyclass"],
)
|| has_attribute_with_namespace(
&item_struct.attrs,
Some(pyo3_path),
&["prelude", "pyclass"],
)
{
module_items.push(item_struct.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_struct.attrs));
if !has_pyo3_module_declared::<PyClassPyO3Option>(
Expand All @@ -210,7 +232,14 @@ pub fn pymodule_module_impl(
!has_attribute(&item_enum.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_enum.attrs, "pyclass") {
if has_attribute(&item_enum.attrs, "pyclass")
|| has_attribute_with_namespace(&item_enum.attrs, Some(pyo3_path), &["pyclass"])
|| has_attribute_with_namespace(
&item_enum.attrs,
Some(pyo3_path),
&["prelude", "pyclass"],
)
{
module_items.push(item_enum.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_enum.attrs));
if !has_pyo3_module_declared::<PyClassPyO3Option>(
Expand All @@ -227,7 +256,14 @@ pub fn pymodule_module_impl(
!has_attribute(&item_mod.attrs, "pymodule_export"),
item.span() => "`#[pymodule_export]` may only be used on `use` statements"
);
if has_attribute(&item_mod.attrs, "pymodule") {
if has_attribute(&item_mod.attrs, "pymodule")
|| has_attribute_with_namespace(&item_mod.attrs, Some(pyo3_path), &["pymodule"])
|| has_attribute_with_namespace(
&item_mod.attrs,
Some(pyo3_path),
&["prelude", "pymodule"],
)
{
module_items.push(item_mod.ident.clone());
module_items_cfg_attrs.push(get_cfg_attributes(&item_mod.attrs));
if !has_pyo3_module_declared::<PyModulePyO3Option>(
Expand Down Expand Up @@ -555,8 +591,48 @@ fn find_and_remove_attribute(attrs: &mut Vec<syn::Attribute>, ident: &str) -> bo
found
}

enum IdentOrStr<'a> {
Str(&'a str),
Ident(syn::Ident),
}

impl<'a> PartialEq<syn::Ident> for IdentOrStr<'a> {
fn eq(&self, other: &syn::Ident) -> bool {
match self {
IdentOrStr::Str(s) => other == s,
IdentOrStr::Ident(i) => other == i,
}
}
}
fn has_attribute(attrs: &[syn::Attribute], ident: &str) -> bool {
attrs.iter().any(|attr| attr.path().is_ident(ident))
has_attribute_with_namespace(attrs, None, &[ident])
}

fn has_attribute_with_namespace(
attrs: &[syn::Attribute],
crate_path: Option<&PyO3CratePath>,
idents: &[&str],
) -> bool {
let mut segments = vec![];
if let Some(c) = crate_path {
match c {
PyO3CratePath::Given(paths) => {
for p in &paths.segments {
segments.push(IdentOrStr::Ident(p.ident.clone()));
}
}
PyO3CratePath::Default => segments.push(IdentOrStr::Str("pyo3")),
}
};
for i in idents {
segments.push(IdentOrStr::Str(i));
}

attrs.iter().any(|attr| {
segments
.iter()
.eq(attr.path().segments.iter().map(|v| &v.ident))
})
}

fn set_module_attribute(attrs: &mut Vec<syn::Attribute>, module_name: &str) {
Expand Down
11 changes: 11 additions & 0 deletions tests/test_declarative_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,9 @@ mod declarative_module {
struct Struct;
}

#[pyo3::prelude::pymodule]
mod full_path_inner {}

#[pymodule_init]
fn init(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add("double2", m.getattr("double")?)
Expand Down Expand Up @@ -247,3 +250,11 @@ fn test_module_names() {
);
})
}

#[test]
fn test_inner_module_full_path() {
Python::with_gil(|py| {
let m = declarative_module(py);
py_assert!(py, m, "m.full_path_inner");
})
}

0 comments on commit 9afc38a

Please sign in to comment.