Skip to content

Commit

Permalink
WIP declarative module
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed May 11, 2022
1 parent 39ab95a commit c08703d
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 52 deletions.
4 changes: 3 additions & 1 deletion pyo3-macros-backend/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ mod pymethod;
mod pyproto;

pub use frompyobject::build_derive_from_pyobject;
pub use module::{process_functions_in_module, pymodule_impl, PyModuleOptions};
pub use module::{
process_functions_in_module, pymodule_function_impl, pymodule_module_impl, PyModuleOptions,
};
pub use pyclass::{build_py_class, build_py_enum, PyClassArgs};
pub use pyfunction::{build_py_function, PyFunctionOptions};
pub use pyimpl::{build_py_methods, PyClassMethodsType};
Expand Down
97 changes: 96 additions & 1 deletion pyo3-macros-backend/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::{
attributes::{
self, is_attribute_ident, take_attributes, take_pyo3_options, CrateAttribute, NameAttribute,
},
get_doc,
pyfunction::{impl_wrap_pyfunction, PyFunctionOptions},
utils::{get_pyo3_crate, PythonDoc},
};
Expand Down Expand Up @@ -59,9 +60,103 @@ impl PyModuleOptions {
}
}

pub fn pymodule_module_impl(mut module: syn::ItemMod) -> Result<TokenStream> {
let syn::ItemMod {
attrs,
vis,
ident,
mod_token,
content,
semi: _,
} = &mut module;
let items = match content {
Some((_, items)) => items,
None => bail_spanned!(module.span() => "`#[pymodule]` can only be used on inline modules"),
};
let options = PyModuleOptions::from_attrs(attrs)?;
let doc = get_doc(attrs, None);

let name = options.name.unwrap_or_else(|| ident.unraw());
let krate = get_pyo3_crate(&options.krate);
let pyinit_symbol = format!("PyInit_{}", name);

let mut module_items = Vec::new();

fn extract_use_items(source: &syn::UseTree, target: &mut Vec<Ident>) -> Result<()> {
match source {
syn::UseTree::Name(name) => target.push(name.ident.clone()),
syn::UseTree::Path(path) => extract_use_items(&path.tree, target)?,
syn::UseTree::Group(group) => {
for tree in &group.items {
extract_use_items(tree, target)?
}
}
syn::UseTree::Glob(glob) => bail_spanned!(glob.span() => "#[pyo3] cannot import glob statements"),
syn::UseTree::Rename(rename) => target.push(rename.ident.clone()),
}
Ok(())
}

let mut pymodule_init = None;

for item in items.iter_mut() {
match item {
syn::Item::Use(item_use) => {
let mut is_pyo3 = false;
item_use.attrs.retain(|attr| {
let found = attr.path.is_ident("pyo3");
is_pyo3 |= found;
!found
});
if is_pyo3 {
extract_use_items(&item_use.tree, &mut module_items)?;
}
}
syn::Item::Fn(item_fn) => {
let mut is_module_init = false;
item_fn.attrs.retain(|attr| {
let found = attr.path.is_ident("pymodule_init");
is_module_init |= found;
!found
});
if is_module_init {
ensure_spanned!(pymodule_init.is_none(), item_fn.span() => "only one pymodule_init may be specified");
let ident = &item_fn.sig.ident;
pymodule_init = Some(quote! { #ident(module)?; });
}
}
_ => {}
}
}

Ok(quote! {
#vis #mod_token #ident {
#(#items)*

pub static DEF: #krate::impl_::pymodule::ModuleDef = unsafe {
use #krate::impl_::pymodule as impl_;
impl_::ModuleDef::new(concat!(stringify!(#name), "\0"), #doc, impl_::ModuleInitializer(__pyo3_pymodule))
};

pub fn __pyo3_pymodule(_py: #krate::Python, module: &#krate::types::PyModule) -> #krate::PyResult<()> {
#(#module_items::DEF.add_to_module(module)?;)*
#pymodule_init
Ok(())
}

/// This autogenerated function is called by the python interpreter when importing
/// the module.
#[export_name = #pyinit_symbol]
pub unsafe extern "C" fn __pyo3_init() -> *mut #krate::ffi::PyObject {
DEF.module_init()
}
}
})
}

/// Generates the function that is called by the python interpreter to initialize the native
/// module
pub fn pymodule_impl(
pub fn pymodule_function_impl(
fnname: &Ident,
options: PyModuleOptions,
doc: PythonDoc,
Expand Down
42 changes: 24 additions & 18 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@ use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use pyo3_macros_backend::{
build_derive_from_pyobject, build_py_class, build_py_enum, build_py_function, build_py_methods,
get_doc, process_functions_in_module, pymodule_impl, PyClassArgs, PyClassMethodsType,
PyFunctionOptions, PyModuleOptions,
get_doc, process_functions_in_module, pymodule_function_impl, pymodule_module_impl,
PyClassArgs, PyClassMethodsType, PyFunctionOptions,
PyModuleOptions
};
use quote::quote;
use syn::{parse::Nothing, parse_macro_input};
Expand All @@ -35,25 +36,30 @@ use syn::{parse::Nothing, parse_macro_input};
pub fn pymodule(args: TokenStream, input: TokenStream) -> TokenStream {
parse_macro_input!(args as Nothing);

let mut ast = parse_macro_input!(input as syn::ItemFn);
let options = match PyModuleOptions::from_attrs(&mut ast.attrs) {
Ok(options) => options,
Err(e) => return e.into_compile_error().into(),
};

if let Err(err) = process_functions_in_module(&options, &mut ast) {
return err.into_compile_error().into();
}
if let Ok(module) = syn::parse(input.clone()) {
pymodule_module_impl(module)
.unwrap_or_compile_error()
.into()
} else {
let mut ast = parse_macro_input!(input as syn::ItemFn);
let options = match PyModuleOptions::from_attrs(&mut ast.attrs) {
Ok(options) => options,
Err(e) => return e.into_compile_error().into(),
};

let doc = get_doc(&ast.attrs, None);
if let Err(err) = process_functions_in_module(&options, &mut ast) {
return err.into_compile_error().into();
}

let expanded = pymodule_impl(&ast.sig.ident, options, doc, &ast.vis);
let doc = get_doc(&ast.attrs, None);

quote!(
#ast
#expanded
)
.into()
let expanded = pymodule_function_impl(&ast.sig.ident, options, doc, &ast.vis);
quote!(
#ast
#expanded
)
.into()
}
}

/// A proc macro used to implement Python's [dunder methods][1].
Expand Down
65 changes: 34 additions & 31 deletions pytests/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::wrap_pymodule;

pub mod buf_and_str;
pub mod datetime;
Expand All @@ -14,35 +12,40 @@ pub mod pyfunctions;
pub mod subclassing;

#[pymodule]
fn pyo3_pytests(py: Python<'_>, m: &PyModule) -> PyResult<()> {
#[cfg(not(Py_LIMITED_API))]
m.add_wrapped(wrap_pymodule!(buf_and_str::buf_and_str))?;
#[cfg(not(Py_LIMITED_API))]
m.add_wrapped(wrap_pymodule!(datetime::datetime))?;
m.add_wrapped(wrap_pymodule!(dict_iter::dict_iter))?;
m.add_wrapped(wrap_pymodule!(misc::misc))?;
m.add_wrapped(wrap_pymodule!(objstore::objstore))?;
m.add_wrapped(wrap_pymodule!(othermod::othermod))?;
m.add_wrapped(wrap_pymodule!(path::path))?;
m.add_wrapped(wrap_pymodule!(pyclasses::pyclasses))?;
m.add_wrapped(wrap_pymodule!(pyfunctions::pyfunctions))?;
m.add_wrapped(wrap_pymodule!(subclassing::subclassing))?;
mod pyo3_pytests {
use pyo3::types::{PyDict, PyModule};
use pyo3::PyResult;

// Inserting to sys.modules allows importing submodules nicely from Python
// e.g. import pyo3_pytests.buf_and_str as bas
#[pyo3]
use {
// #[cfg(not(Py_LIMITED_API))]
crate::buf_and_str::buf_and_str,
// #[cfg(not(Py_LIMITED_API))]
crate::datetime::datetime,
crate::dict_iter::dict_iter,
crate::misc::misc,
crate::objstore::objstore,
crate::othermod::othermod,
crate::path::path,
crate::pyclasses::pyclasses,
crate::pyfunctions::pyfunctions,
crate::subclassing::subclassing,
};

let sys = PyModule::import(py, "sys")?;
let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?;
sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?;
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
sys_modules.set_item("pyo3_pytests.dict_iter", m.getattr("dict_iter")?)?;
sys_modules.set_item("pyo3_pytests.misc", m.getattr("misc")?)?;
sys_modules.set_item("pyo3_pytests.objstore", m.getattr("objstore")?)?;
sys_modules.set_item("pyo3_pytests.othermod", m.getattr("othermod")?)?;
sys_modules.set_item("pyo3_pytests.path", m.getattr("path")?)?;
sys_modules.set_item("pyo3_pytests.pyclasses", m.getattr("pyclasses")?)?;
sys_modules.set_item("pyo3_pytests.pyfunctions", m.getattr("pyfunctions")?)?;
sys_modules.set_item("pyo3_pytests.subclassing", m.getattr("subclassing")?)?;

Ok(())
#[pymodule_init]
fn init(m: &PyModule) -> PyResult<()> {
let sys = PyModule::import(m.py(), "sys")?;
let sys_modules: &PyDict = sys.getattr("modules")?.downcast()?;
sys_modules.set_item("pyo3_pytests.buf_and_str", m.getattr("buf_and_str")?)?;
sys_modules.set_item("pyo3_pytests.datetime", m.getattr("datetime")?)?;
sys_modules.set_item("pyo3_pytests.dict_iter", m.getattr("dict_iter")?)?;
sys_modules.set_item("pyo3_pytests.misc", m.getattr("misc")?)?;
sys_modules.set_item("pyo3_pytests.objstore", m.getattr("objstore")?)?;
sys_modules.set_item("pyo3_pytests.othermod", m.getattr("othermod")?)?;
sys_modules.set_item("pyo3_pytests.path", m.getattr("path")?)?;
sys_modules.set_item("pyo3_pytests.pyclasses", m.getattr("pyclasses")?)?;
sys_modules.set_item("pyo3_pytests.pyfunctions", m.getattr("pyfunctions")?)?;
sys_modules.set_item("pyo3_pytests.subclassing", m.getattr("subclassing")?)?;
Ok(())
}
}
4 changes: 4 additions & 0 deletions src/impl_/pymodule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ impl ModuleDef {
}),
)
}

pub fn add_to_module(&'static self, module: &PyModule) -> PyResult<()> {
module.add_object(self.make_module(module.py())?)
}
}

#[cfg(test)]
Expand Down
12 changes: 11 additions & 1 deletion src/types/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// based on Daniel Grunwald's https://github.com/dgrunwald/rust-cpython

use crate::callback::IntoPyCallbackOutput;
use crate::err::{PyErr, PyResult};
use crate::err::{self, PyErr, PyResult};
use crate::exceptions;
use crate::ffi;
use crate::pyclass::PyClass;
Expand Down Expand Up @@ -252,6 +252,16 @@ impl PyModule {
self.setattr(name, value.into_py(self.py()))
}

pub(crate) fn add_object(&self, value: PyObject) -> PyResult<()> {
let py = self.py();
let attr_name = value.getattr(py, "__name__")?;

unsafe {
let ret = ffi::PyObject_SetAttr(self.as_ptr(), attr_name.as_ptr(), value.as_ptr());
err::error_on_minusone(py, ret)
}
}

/// Adds a new class to the module.
///
/// Notice that this method does not take an argument.
Expand Down
32 changes: 32 additions & 0 deletions tests/test_module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,3 +443,35 @@ fn test_module_doc_hidden() {
py_assert!(py, m, "m.__doc__ == ''");
})
}

/// A module written using declarative syntax.
#[pymodule]
mod declarative_module {

#[pyo3]
use super::module_with_functions;
}

#[test]
fn test_declarative_module() {
Python::with_gil(|py| {
let m = pyo3::wrap_pymodule!(declarative_module)(py).into_ref(py);
py_assert!(
py,
m,
"m.__doc__ == 'A module written using declarative syntax.'"
);

let submodule = m.getattr("module_with_functions").unwrap();
assert_eq!(
submodule
.getattr("no_parameters")
.unwrap()
.call0()
.unwrap()
.extract::<i32>()
.unwrap(),
42
);
})
}

0 comments on commit c08703d

Please sign in to comment.