diff --git a/guide/src/function.md b/guide/src/function.md index f2a2d1b41ee..d3355a54fa6 100644 --- a/guide/src/function.md +++ b/guide/src/function.md @@ -255,13 +255,14 @@ in Python code. ### Accessing the module of a function -It is possible to access the module of a `#[pyfunction]` in the function body by passing the `pass_module` argument to the attribute: +It is possible to access the module of a `#[pyfunction]` in the function body by using `#[pyo3(pass_module)]` option: ```rust use pyo3::wrap_pyfunction; use pyo3::prelude::*; -#[pyfunction(pass_module)] +#[pyfunction] +#[pyo3(pass_module)] fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> { module.name() } diff --git a/pyo3-macros-backend/src/pyfunction.rs b/pyo3-macros-backend/src/pyfunction.rs index d125d66193f..3e3efe3c311 100644 --- a/pyo3-macros-backend/src/pyfunction.rs +++ b/pyo3-macros-backend/src/pyfunction.rs @@ -209,7 +209,7 @@ impl PyFunctionSignature { #[derive(Default)] pub struct PyFunctionOptions { - pub pass_module: bool, + pub pass_module: Option, pub name: Option, pub signature: Option, pub text_signature: Option, @@ -219,7 +219,7 @@ pub struct PyFunctionOptions { impl Parse for PyFunctionOptions { fn parse(input: ParseStream) -> Result { let mut options = PyFunctionOptions { - pass_module: false, + pass_module: None, name: None, signature: None, text_signature: None, @@ -314,10 +314,10 @@ impl PyFunctionOptions { PyFunctionOption::Name(name) => self.set_name(name)?, PyFunctionOption::PassModule(kw) => { ensure_spanned!( - !self.pass_module, + self.pass_module.is_none(), kw.span() => "`pass_module` may only be specified once" ); - self.pass_module = true; + self.pass_module = Some(kw); } PyFunctionOption::Signature(signature) => { ensure_spanned!( @@ -385,7 +385,7 @@ pub fn impl_wrap_pyfunction( .map(FnArg::parse) .collect::>>()?; - if options.pass_module { + if options.pass_module.is_some() { const PASS_MODULE_ERR: &str = "expected &PyModule as first argument with `pass_module`"; ensure_spanned!( !arguments.is_empty(), @@ -426,7 +426,7 @@ pub fn impl_wrap_pyfunction( let name = &func.sig.ident; let wrapper_ident = format_ident!("__pyo3_raw_{}", name); - let wrapper = function_c_wrapper(name, &wrapper_ident, &spec, options.pass_module)?; + let wrapper = function_c_wrapper(name, &wrapper_ident, &spec, options.pass_module.is_some())?; let (methoddef_meth, cfunc_variant) = if spec.args.is_empty() { (quote!(noargs), quote!(PyCFunction)) } else if spec.can_use_fastcall() { diff --git a/pyo3-macros-backend/src/pymethod.rs b/pyo3-macros-backend/src/pymethod.rs index 02da1c8d142..850cb2b4ca5 100644 --- a/pyo3-macros-backend/src/pymethod.rs +++ b/pyo3-macros-backend/src/pymethod.rs @@ -27,6 +27,7 @@ pub fn gen_py_method( ) -> Result { check_generic(sig)?; ensure_not_async_fn(sig)?; + ensure_function_options_valid(&options)?; let spec = FnSpec::parse(sig, &mut *meth_attrs, options)?; Ok(match &spec.tp { @@ -71,6 +72,13 @@ pub(crate) fn check_generic(sig: &syn::Signature) -> syn::Result<()> { Ok(()) } +fn ensure_function_options_valid(options: &PyFunctionOptions) -> syn::Result<()> { + if let Some(pass_module) = &options.pass_module { + bail_spanned!(pass_module.span() => "`pass_module` cannot be used on Python methods") + } + Ok(()) +} + pub fn gen_py_const(cls: &syn::Type, spec: &ConstSpec) -> TokenStream { let member = &spec.rust_ident; let deprecations = &spec.attributes.deprecations; diff --git a/pyo3-macros/src/lib.rs b/pyo3-macros/src/lib.rs index 963e1814551..1a21001f372 100644 --- a/pyo3-macros/src/lib.rs +++ b/pyo3-macros/src/lib.rs @@ -186,7 +186,7 @@ pub fn pymethods(_: TokenStream, input: TokenStream) -> TokenStream { /// | [`#[args]`][10] | Define a method's default arguments and allows the function to receive `*args` and `**kwargs`. | /// /// Methods within a `#[pymethods]` block can also be annotated with any of the `#[pyo3]` options which can -/// be used with [`#[pyfunction]`][attr.pyfunction.html]. +/// be used with [`#[pyfunction]`][attr.pyfunction.html], except for `pass_module`. /// /// For more on creating class methods see the [class section of the guide][1]. /// @@ -218,6 +218,7 @@ pub fn pymethods_with_inventory(_: TokenStream, input: TokenStream) -> TokenStre /// | :- | :- | /// | `#[pyo3(name = "...")]` | Defines the name of the function in Python. | /// | `#[pyo3(text_signature = "...")]` | Defines the `__text_signature__` attribute of the function in Python. | +/// | `#[pyo3(pass_module)]` | Passes the module containing the function as a `&PyModule` first argument to the function. | /// /// For more on exposing functions see the [function section of the guide][1]. /// diff --git a/tests/test_module.rs b/tests/test_module.rs index b2349db4b87..55ed0459132 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -48,7 +48,8 @@ fn module_with_functions(_py: Python, m: &PyModule) -> PyResult<()> { 42 } - #[pyfn(m, pass_module)] + #[pyfn(m)] + #[pyo3(pass_module)] fn with_module(module: &PyModule) -> PyResult<&str> { module.name() } @@ -340,12 +341,14 @@ fn test_module_with_constant() { }); } -#[pyfunction(pass_module)] +#[pyfunction] +#[pyo3(pass_module)] fn pyfunction_with_module(module: &PyModule) -> PyResult<&str> { module.name() } -#[pyfunction(pass_module)] +#[pyfunction] +#[pyo3(pass_module)] fn pyfunction_with_module_and_py<'a>( module: &'a PyModule, _python: Python<'a>, @@ -353,12 +356,14 @@ fn pyfunction_with_module_and_py<'a>( module.name() } -#[pyfunction(pass_module)] +#[pyfunction] +#[pyo3(pass_module)] fn pyfunction_with_module_and_arg(module: &PyModule, string: String) -> PyResult<(&str, String)> { module.name().map(|s| (s, string)) } -#[pyfunction(pass_module, string = "\"foo\"")] +#[pyfunction(string = "\"foo\"")] +#[pyo3(pass_module)] fn pyfunction_with_module_and_default_arg<'a>( module: &'a PyModule, string: &str, @@ -366,7 +371,8 @@ fn pyfunction_with_module_and_default_arg<'a>( module.name().map(|s| (s, string.into())) } -#[pyfunction(pass_module, args = "*", kwargs = "**")] +#[pyfunction(args = "*", kwargs = "**")] +#[pyo3(pass_module)] fn pyfunction_with_module_and_args_kwargs<'a>( module: &'a PyModule, args: &PyTuple, @@ -377,13 +383,24 @@ fn pyfunction_with_module_and_args_kwargs<'a>( .map(|s| (s, args.len(), kwargs.map(|d| d.len()))) } +#[pyfunction] +#[pyo3(pass_module)] +fn pyfunction_with_pass_module_in_attribute(module: &PyModule) -> PyResult<&str> { + module.name() +} + #[pymodule] fn module_with_functions_with_module(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(pyfunction_with_module, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_and_py, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_and_arg, m)?)?; m.add_function(wrap_pyfunction!(pyfunction_with_module_and_default_arg, m)?)?; - m.add_function(wrap_pyfunction!(pyfunction_with_module_and_args_kwargs, m)?) + m.add_function(wrap_pyfunction!(pyfunction_with_module_and_args_kwargs, m)?)?; + m.add_function(wrap_pyfunction!( + pyfunction_with_pass_module_in_attribute, + m + )?)?; + Ok(()) } #[test] @@ -413,4 +430,9 @@ fn test_module_functions_with_module() { "m.pyfunction_with_module_and_args_kwargs(1, x=1, y=2) \ == ('module_with_functions_with_module', 1, 2)" ); + py_assert!( + py, + m, + "m.pyfunction_with_pass_module_in_attribute() == 'module_with_functions_with_module'" + ); } diff --git a/tests/ui/invalid_pymethods.rs b/tests/ui/invalid_pymethods.rs index ebb2fb87463..528af3197c1 100644 --- a/tests/ui/invalid_pymethods.rs +++ b/tests/ui/invalid_pymethods.rs @@ -108,4 +108,10 @@ impl MyClass { async fn async_method(&self) {} } +#[pymethods] +impl MyClass { + #[pyo3(pass_module)] + fn method_cannot_pass_module(&self, m: &PyModule) {} +} + fn main() {} diff --git a/tests/ui/invalid_pymethods.stderr b/tests/ui/invalid_pymethods.stderr index 8091348e017..56f8f99d185 100644 --- a/tests/ui/invalid_pymethods.stderr +++ b/tests/ui/invalid_pymethods.stderr @@ -95,3 +95,9 @@ Additional crates such as `pyo3-asyncio` can be used to integrate async Rust and | 108 | async fn async_method(&self) {} | ^^^^^ + +error: `pass_module` cannot be used on Python methods + --> $DIR/invalid_pymethods.rs:113:12 + | +113 | #[pyo3(pass_module)] + | ^^^^^^^^^^^