From ccd04475a30409f226a1e044239c6b7e35260e98 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Tue, 2 Jul 2024 07:24:47 -0400 Subject: [PATCH] refs #4286 -- allow setting submodule on declarative pymodules (#4301) --- guide/src/module.md | 4 +- newsfragments/4301.added.md | 1 + pyo3-macros-backend/src/attributes.rs | 2 + pyo3-macros-backend/src/module.rs | 55 ++++++++++++++++++++------- pyo3-macros/src/lib.rs | 24 ++++++++++-- tests/test_declarative_module.rs | 10 ++++- 6 files changed, 77 insertions(+), 19 deletions(-) create mode 100644 newsfragments/4301.added.md diff --git a/guide/src/module.md b/guide/src/module.md index 8c6049270cb..2c4039a6e76 100644 --- a/guide/src/module.md +++ b/guide/src/module.md @@ -154,6 +154,8 @@ The `#[pymodule]` macro automatically sets the `module` attribute of the `#[pycl For nested modules, the name of the parent module is automatically added. In the following example, the `Unit` class will have for `module` `my_extension.submodule` because it is properly nested but the `Ext` class will have for `module` the default `builtins` because it not nested. + +You can provide the `submodule` argument to `pymodule()` for modules that are not top-level modules. ```rust # mod declarative_module_module_attr_test { use pyo3::prelude::*; @@ -168,7 +170,7 @@ mod my_extension { #[pymodule_export] use super::Ext; - #[pymodule] + #[pymodule(submodule)] mod submodule { use super::*; // This is a submodule diff --git a/newsfragments/4301.added.md b/newsfragments/4301.added.md new file mode 100644 index 00000000000..2ee759c28b5 --- /dev/null +++ b/newsfragments/4301.added.md @@ -0,0 +1 @@ +allow setting `submodule` on declarative `#[pymodule]`s diff --git a/pyo3-macros-backend/src/attributes.rs b/pyo3-macros-backend/src/attributes.rs index 02af17b618b..6a45ee875e3 100644 --- a/pyo3-macros-backend/src/attributes.rs +++ b/pyo3-macros-backend/src/attributes.rs @@ -37,6 +37,7 @@ pub mod kw { syn::custom_keyword!(set_all); syn::custom_keyword!(signature); syn::custom_keyword!(subclass); + syn::custom_keyword!(submodule); syn::custom_keyword!(text_signature); syn::custom_keyword!(transparent); syn::custom_keyword!(unsendable); @@ -178,6 +179,7 @@ pub type ModuleAttribute = KeywordAttribute; pub type NameAttribute = KeywordAttribute; pub type RenameAllAttribute = KeywordAttribute; pub type TextSignatureAttribute = KeywordAttribute; +pub type SubmoduleAttribute = kw::submodule; impl Parse for KeywordAttribute { fn parse(input: ParseStream<'_>) -> Result { diff --git a/pyo3-macros-backend/src/module.rs b/pyo3-macros-backend/src/module.rs index 4ce3023cbb4..faa7032de80 100644 --- a/pyo3-macros-backend/src/module.rs +++ b/pyo3-macros-backend/src/module.rs @@ -3,6 +3,7 @@ use crate::{ attributes::{ self, take_attributes, take_pyo3_options, CrateAttribute, ModuleAttribute, NameAttribute, + SubmoduleAttribute, }, get_doc, pyclass::PyClassPyO3Option, @@ -27,6 +28,7 @@ pub struct PyModuleOptions { krate: Option, name: Option, module: Option, + is_submodule: bool, } impl PyModuleOptions { @@ -38,6 +40,7 @@ impl PyModuleOptions { PyModulePyO3Option::Name(name) => options.set_name(name.value.0)?, PyModulePyO3Option::Crate(path) => options.set_crate(path)?, PyModulePyO3Option::Module(module) => options.set_module(module)?, + PyModulePyO3Option::Submodule(submod) => options.set_submodule(submod)?, } } @@ -73,9 +76,22 @@ impl PyModuleOptions { self.module = Some(name); Ok(()) } + + fn set_submodule(&mut self, submod: SubmoduleAttribute) -> Result<()> { + ensure_spanned!( + !self.is_submodule, + submod.span() => "`submodule` may only be specified once" + ); + + self.is_submodule = true; + Ok(()) + } } -pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { +pub fn pymodule_module_impl( + mut module: syn::ItemMod, + mut is_submodule: bool, +) -> Result { let syn::ItemMod { attrs, vis, @@ -100,6 +116,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { } else { name.to_string() }; + is_submodule = is_submodule || options.is_submodule; let mut module_items = Vec::new(); let mut module_items_cfg_attrs = Vec::new(); @@ -297,7 +314,7 @@ pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result { ) } }}; - let initialization = module_initialization(&name, ctx, module_def); + let initialization = module_initialization(&name, ctx, module_def, is_submodule); Ok(quote!( #(#attrs)* #vis mod #ident { @@ -331,7 +348,7 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result let vis = &function.vis; let doc = get_doc(&function.attrs, None, ctx); - let initialization = module_initialization(&name, ctx, quote! { MakeDef::make_def() }); + let initialization = module_initialization(&name, ctx, quote! { MakeDef::make_def() }, false); // Module function called with optional Python<'_> marker as first arg, followed by the module. let mut module_args = Vec::new(); @@ -396,28 +413,37 @@ pub fn pymodule_function_impl(mut function: syn::ItemFn) -> Result }) } -fn module_initialization(name: &syn::Ident, ctx: &Ctx, module_def: TokenStream) -> TokenStream { +fn module_initialization( + name: &syn::Ident, + ctx: &Ctx, + module_def: TokenStream, + is_submodule: bool, +) -> TokenStream { let Ctx { pyo3_path, .. } = ctx; let pyinit_symbol = format!("PyInit_{}", name); let name = name.to_string(); let pyo3_name = LitCStr::new(CString::new(name).unwrap(), Span::call_site(), ctx); - quote! { + let mut result = quote! { #[doc(hidden)] pub const __PYO3_NAME: &'static ::std::ffi::CStr = #pyo3_name; pub(super) struct MakeDef; #[doc(hidden)] pub static _PYO3_DEF: #pyo3_path::impl_::pymodule::ModuleDef = #module_def; - - /// This autogenerated function is called by the python interpreter when importing - /// the module. - #[doc(hidden)] - #[export_name = #pyinit_symbol] - pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject { - #pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py)) - } + }; + if !is_submodule { + result.extend(quote! { + /// This autogenerated function is called by the python interpreter when importing + /// the module. + #[doc(hidden)] + #[export_name = #pyinit_symbol] + pub unsafe extern "C" fn __pyo3_init() -> *mut #pyo3_path::ffi::PyObject { + #pyo3_path::impl_::trampoline::module_init(|py| _PYO3_DEF.make_module(py)) + } + }); } + result } /// Finds and takes care of the #[pyfn(...)] in `#[pymodule]` @@ -557,6 +583,7 @@ fn has_pyo3_module_declared( } enum PyModulePyO3Option { + Submodule(SubmoduleAttribute), Crate(CrateAttribute), Name(NameAttribute), Module(ModuleAttribute), @@ -571,6 +598,8 @@ impl Parse for PyModulePyO3Option { input.parse().map(PyModulePyO3Option::Crate) } else if lookahead.peek(attributes::kw::module) { input.parse().map(PyModulePyO3Option::Module) + } else if lookahead.peek(attributes::kw::submodule) { + input.parse().map(PyModulePyO3Option::Submodule) } else { Err(lookahead.error()) } diff --git a/pyo3-macros/src/lib.rs b/pyo3-macros/src/lib.rs index 8dbf2782d5b..95e983079f1 100644 --- a/pyo3-macros/src/lib.rs +++ b/pyo3-macros/src/lib.rs @@ -3,7 +3,7 @@ #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))] use proc_macro::TokenStream; -use proc_macro2::TokenStream as TokenStream2; +use proc_macro2::{Span, TokenStream as TokenStream2}; use pyo3_macros_backend::{ build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods, pymodule_function_impl, pymodule_module_impl, PyClassArgs, PyClassMethodsType, @@ -35,10 +35,26 @@ use syn::{parse::Nothing, parse_macro_input, Item}; /// [1]: https://pyo3.rs/latest/module.html #[proc_macro_attribute] pub fn pymodule(args: TokenStream, input: TokenStream) -> TokenStream { - parse_macro_input!(args as Nothing); match parse_macro_input!(input as Item) { - Item::Mod(module) => pymodule_module_impl(module), - Item::Fn(function) => pymodule_function_impl(function), + Item::Mod(module) => { + let is_submodule = match parse_macro_input!(args as Option) { + Some(i) if i == "submodule" => true, + Some(_) => { + return syn::Error::new( + Span::call_site(), + "#[pymodule] only accepts submodule as an argument", + ) + .into_compile_error() + .into(); + } + None => false, + }; + pymodule_module_impl(module, is_submodule) + } + Item::Fn(function) => { + parse_macro_input!(args as Nothing); + pymodule_function_impl(function) + } unsupported => Err(syn::Error::new_spanned( unsupported, "#[pymodule] only supports modules and functions.", diff --git a/tests/test_declarative_module.rs b/tests/test_declarative_module.rs index 061d0337285..0bf426a52cc 100644 --- a/tests/test_declarative_module.rs +++ b/tests/test_declarative_module.rs @@ -49,6 +49,10 @@ create_exception!( "Some description." ); +#[pymodule] +#[pyo3(submodule)] +mod external_submodule {} + /// A module written using declarative syntax. #[pymodule] mod declarative_module { @@ -70,6 +74,9 @@ mod declarative_module { #[pymodule_export] use super::some_module::SomeException; + #[pymodule_export] + use super::external_submodule; + #[pymodule] mod inner { use super::*; @@ -108,7 +115,7 @@ mod declarative_module { } } - #[pymodule] + #[pymodule(submodule)] #[pyo3(module = "custom_root")] mod inner_custom_root { use super::*; @@ -174,6 +181,7 @@ fn test_declarative_module() { py_assert!(py, m, "hasattr(m, 'LocatedClass')"); py_assert!(py, m, "isinstance(m.inner.Struct(), m.inner.Struct)"); py_assert!(py, m, "isinstance(m.inner.Enum.A, m.inner.Enum)"); + py_assert!(py, m, "hasattr(m, 'external_submodule')") }) }