From 7168309464e4ff9eeea9839436121f90bacbbe63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Tue, 25 Aug 2020 00:00:12 +0200 Subject: [PATCH 1/7] Derive FromPyObject --- pyo3-derive-backend/src/frompy.rs | 428 ++++++++++++++++++++++++++++++ pyo3-derive-backend/src/lib.rs | 2 + pyo3cls/src/lib.rs | 14 +- src/prelude.rs | 2 +- tests/test_frompyobject.rs | 308 +++++++++++++++++++++ 5 files changed, 751 insertions(+), 3 deletions(-) create mode 100644 pyo3-derive-backend/src/frompy.rs create mode 100644 tests/test_frompyobject.rs diff --git a/pyo3-derive-backend/src/frompy.rs b/pyo3-derive-backend/src/frompy.rs new file mode 100644 index 00000000000..e60a0e4541e --- /dev/null +++ b/pyo3-derive-backend/src/frompy.rs @@ -0,0 +1,428 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::token::Paren; +use syn::{ + parse_quote, Attribute, DataEnum, DeriveInput, Expr, ExprCall, Fields, Ident, PatTuple, Result, + Variant, +}; + +/// Describes derivation input of an enum. +#[derive(Debug)] +struct Enum<'a> { + enum_ident: &'a Ident, + vars: Vec>, +} + +impl<'a> Enum<'a> { + /// Construct a new enum representation. + /// + /// `data_enum` is the `syn` representation of the input enum, `ident` is the + /// `Identifier` of the enum. + fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result { + if data_enum.variants.is_empty() { + return Err(syn::Error::new_spanned( + &data_enum.variants, + "Cannot derive FromPyObject for empty enum.", + )); + } + let vars = data_enum + .variants + .iter() + .map(Container::from_variant) + .collect::>>()?; + + Ok(Enum { + enum_ident: ident, + vars, + }) + } + + /// Build derivation body for enums. + fn derive_enum(&self) -> TokenStream { + let mut var_extracts = Vec::new(); + let mut error_names = String::new(); + for (i, var) in self.vars.iter().enumerate() { + let ext = match &var.style { + Style::Struct(tups) => self.build_struct_variant(tups, var.ident), + Style::StructNewtype(ident) => { + self.build_transparent_variant(var.ident, Some(ident)) + } + Style::Tuple(len) => self.build_tuple_variant(var.ident, *len), + Style::TupleNewtype => self.build_transparent_variant(var.ident, None), + }; + var_extracts.push(ext); + error_names.push_str(&var.err_name); + if i < self.vars.len() - 1 { + error_names.push_str(", "); + } + } + quote!( + #(#var_extracts)* + let type_name = obj.get_type().name(); + let from = obj + .repr() + .map(|s| format!("{} ({})", s.to_string_lossy(), type_name)) + .unwrap_or_else(|_| type_name.to_string()); + let err_msg = format!("Can't convert {} to {}", from, #error_names); + Err(::pyo3::exceptions::PyTypeError::py_err(err_msg)) + ) + } + + /// Build match for tuple struct variant. + fn build_tuple_variant(&self, var_ident: &Ident, len: usize) -> TokenStream { + let enum_ident = self.enum_ident; + let mut ext: Punctuated = Punctuated::new(); + let mut fields: Punctuated = Punctuated::new(); + let mut field_pats = PatTuple { + attrs: vec![], + paren_token: Paren::default(), + elems: Default::default(), + }; + for i in 0..len { + ext.push(parse_quote!(slice[#i].extract())); + let ident = Ident::new(&format!("_field{}", i), Span::call_site()); + field_pats.elems.push(parse_quote!(Ok(#ident))); + fields.push(ident); + } + + quote!( + match <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj) { + Ok(s) => { + if s.len() == #len { + let slice = s.as_slice(); + if let (#field_pats) = (#ext) { + return Ok(#enum_ident::#var_ident(#fields)) + } + } + }, + Err(_) => {} + } + ) + } + + /// Build match for transparent enum variants. + fn build_transparent_variant( + &self, + var_ident: &Ident, + field_ident: Option<&Ident>, + ) -> TokenStream { + let enum_ident = self.enum_ident; + if let Some(ident) = field_ident { + quote!( + if let Ok(#ident) = obj.extract() { + return Ok(#enum_ident::#var_ident{#ident}) + } + ) + } else { + quote!( + if let Ok(inner) = obj.extract() { + return Ok(#enum_ident::#var_ident(inner)) + } + ) + } + } + + /// Build match for struct variant with named fields. + fn build_struct_variant( + &self, + tups: &[(&'a Ident, ExprCall)], + var_ident: &Ident, + ) -> TokenStream { + let enum_ident = self.enum_ident; + let mut field_pats = PatTuple { + attrs: vec![], + paren_token: Paren::default(), + elems: Default::default(), + }; + let mut fields: Punctuated = Punctuated::new(); + let mut ext: Punctuated = Punctuated::new(); + for (ident, ext_fn) in tups { + field_pats.elems.push(parse_quote!(Ok(#ident))); + fields.push(parse_quote!(#ident)); + ext.push(parse_quote!(obj.#ext_fn.and_then(|o| o.extract()))); + } + quote!(if let #field_pats = #ext { + return Ok(#enum_ident::#var_ident{#fields}); + }) + } +} + +/// Container Style +/// +/// Covers Structs, Tuplestructs and corresponding Newtypes. +#[derive(Clone, Debug)] +enum Style<'a> { + /// Struct Container, e.g. `struct Foo { a: String }` + /// + /// Variant contains the list of field identifiers and the corresponding extraction call. + Struct(Vec<(&'a Ident, ExprCall)>), + /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }` + /// + /// The field specified by the identifier is extracted directly from the object. + StructNewtype(&'a Ident), + /// Tuple struct, e.g. `struct Foo(String)`. + /// + /// Fields are extracted from a tuple. + Tuple(usize), + /// Tuple newtype, e.g. `#[transparent] struct Foo(String)` + /// + /// The wrapped field is directly extracted from the object. + TupleNewtype, +} + +/// Data container +/// +/// Either describes a struct or an enum variant. +#[derive(Debug)] +struct Container<'a> { + ident: &'a Ident, + style: Style<'a>, + err_name: String, +} + +impl<'a> Container<'a> { + /// Construct a container from an enum Variant. + /// + /// Fails if the variant has no fields or incompatible attributes. + fn from_variant(var: &'a Variant) -> Result { + Self::new(&var.fields, &var.ident, &var.attrs) + } + + /// Construct a container based on fields, identifier and attributes. + /// + /// Fails if the variant has no fields or incompatible attributes. + fn new(fields: &'a Fields, ident: &'a Ident, attrs: &'a [Attribute]) -> Result { + let transparent = attrs.iter().any(|a| a.path.is_ident("transparent")); + if transparent { + Self::check_transparent_len(fields)?; + } + let style = match fields { + Fields::Unnamed(unnamed) => { + if transparent { + Style::TupleNewtype + } else { + Style::Tuple(unnamed.unnamed.len()) + } + } + Fields::Named(named) => { + if transparent { + let field = named + .named + .iter() + .next() + .expect("Check for len 1 is done above"); + let ident = field + .ident + .as_ref() + .expect("Named fields should have identifiers"); + Style::StructNewtype(ident) + } else { + let mut fields = Vec::new(); + for field in named.named.iter() { + let ident = field + .ident + .as_ref() + .expect("Named fields should have identifiers"); + fields.push((ident, ext_fn(&field.attrs, ident)?)) + } + Style::Struct(fields) + } + } + Fields::Unit => { + return Err(syn::Error::new_spanned( + &fields, + "Cannot derive FromPyObject for Unit structs and variants", + )) + } + }; + let err_name = maybe_renamed_err(&attrs)? + .map(|s| s.value()) + .unwrap_or_else(|| ident.to_string()); + + let v = Container { + ident: &ident, + style, + err_name, + }; + Ok(v) + } + + /// Build derivation body for a struct. + fn derive_struct(&self) -> TokenStream { + match &self.style { + Style::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)), + Style::TupleNewtype => self.build_newtype_struct(None), + Style::Tuple(len) => self.build_tuple_struct(*len), + Style::Struct(tups) => self.build_struct(tups), + } + } + + fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream { + if let Some(ident) = field_ident { + quote!( + Ok(Self{#ident: obj.extract()?}) + ) + } else { + quote!(Ok(Self(obj.extract()?))) + } + } + + fn build_tuple_struct(&self, len: usize) -> TokenStream { + let mut fields: Punctuated = Punctuated::new(); + for i in 0..len { + fields.push(quote!(slice[#i].extract()?)); + } + quote!( + let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?; + let seq_len = s.len(); + if seq_len != #len { + let msg = format!( + "Expected tuple of length {}, but got length {}.", + #len, + seq_len + ); + return Err(::pyo3::exceptions::PyValueError::py_err(msg)) + } + let slice = s.as_slice(); + Ok(Self(#fields)) + ) + } + + fn build_struct(&self, tups: &[(&Ident, syn::ExprCall)]) -> TokenStream { + let mut fields: Punctuated = Punctuated::new(); + for (ident, ext_fn) in tups { + fields.push(quote!(#ident: obj.#ext_fn?.extract()?)); + } + quote!(Ok(Self{#fields})) + } + + fn check_transparent_len(fields: &Fields) -> Result<()> { + if fields.len() != 1 { + return Err(syn::Error::new_spanned( + fields, + "Transparent structs and variants can only have 1 field", + )); + } + Ok(()) + } +} + +/// Get the extraction function that's called on the input object. +/// +/// Valid arguments are `get_item`, `get_attr` which are called with the +/// stringified field identifier or a function call on `PyAny`, e.g. `get_attr("attr")` +fn ext_fn(attrs: &[Attribute], field_ident: &Ident) -> Result { + let attr = if let Some(attr) = attrs.iter().find(|a| a.path.is_ident("extract")) { + attr + } else { + return Ok(parse_quote!(getattr(stringify!(#field_ident)))); + }; + if let Ok(ident) = attr.parse_args::() { + if ident != "getattr" && ident != "get_item" { + Err(syn::Error::new_spanned( + ident, + "Only get_item and getattr are valid for extraction.", + )) + } else { + let arg = field_ident.to_string(); + Ok(parse_quote!(#ident(#arg))) + } + } else if let Ok(call) = attr.parse_args() { + Ok(call) + } else { + Err(syn::Error::new_spanned( + attr, + "Only get_item and getattr are valid for extraction,\ + both can be passed with or without an argument, e.g. \ + #[extract(getattr(\"attr\")] and #[extract(getattr)]", + )) + } +} + +/// Returns the name of the variant for the error message if no variants match. +fn maybe_renamed_err(attrs: &[syn::Attribute]) -> Result> { + for attr in attrs { + if !attr.path.is_ident("rename_err") { + continue; + } + let attr = attr.parse_meta()?; + if let syn::Meta::NameValue(nv) = &attr { + match &nv.lit { + syn::Lit::Str(s) => { + return Ok(Some(s.clone())); + } + _ => { + return Err(syn::Error::new_spanned( + attr, + "rename_err attribute must be string literal: #[rename_err=\"Name\"]", + )) + } + } + } + } + Ok(None) +} + +fn verify_and_get_lifetime(generics: &syn::Generics) -> Result> { + let lifetimes = generics.lifetimes().collect::>(); + if lifetimes.len() > 1 { + return Err(syn::Error::new_spanned( + &generics, + "Only a single lifetime parameter can be specified.", + )); + } + Ok(lifetimes.into_iter().next()) +} + +/// Derive FromPyObject for enums and structs. +/// +/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier +/// * At least one field, in case of `#[transparent]`, exactly one field +/// * At least one variant for enums. +/// * Fields of input structs and enums must implement `FromPyObject` +/// * Derivation for structs with generic fields like `struct Foo(T)` +/// adds `T: FromPyObject` on the derived implementation. +pub fn build_derive_from_pyobject(tokens: &mut DeriveInput) -> Result { + let mut trait_generics = tokens.generics.clone(); + let generics = &tokens.generics; + let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? { + lt.clone() + } else { + trait_generics.params.push(parse_quote!('source)); + parse_quote!('source) + }; + let mut where_clause: syn::WhereClause = parse_quote!(where); + for param in generics.type_params() { + let gen_ident = ¶m.ident; + where_clause + .predicates + .push(parse_quote!(#gen_ident: FromPyObject<#lt_param>)) + } + let derives = match &tokens.data { + syn::Data::Enum(en) => { + let en = Enum::new(en, &tokens.ident)?; + en.derive_enum() + } + syn::Data::Struct(st) => { + let st = Container::new(&st.fields, &tokens.ident, &tokens.attrs)?; + st.derive_struct() + } + _ => { + return Err(syn::Error::new_spanned( + tokens, + "FromPyObject can only be derived for structs and enums.", + )) + } + }; + + let ident = &tokens.ident; + Ok(quote!( + #[automatically_derived] + impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause { + fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult { + #derives + } + } + )) +} diff --git a/pyo3-derive-backend/src/lib.rs b/pyo3-derive-backend/src/lib.rs index 5695adf42e8..78db37368fb 100644 --- a/pyo3-derive-backend/src/lib.rs +++ b/pyo3-derive-backend/src/lib.rs @@ -4,6 +4,7 @@ #![recursion_limit = "1024"] mod defs; +mod frompy; mod konst; mod method; mod module; @@ -15,6 +16,7 @@ mod pymethod; mod pyproto; mod utils; +pub use frompy::build_derive_from_pyobject; pub use module::{add_fn_to_module, process_functions_in_module, py_init}; pub use pyclass::{build_py_class, PyClassArgs}; pub use pyfunction::{build_py_function, PyFunctionAttr}; diff --git a/pyo3cls/src/lib.rs b/pyo3cls/src/lib.rs index 795423cbff3..2377ec03bbe 100644 --- a/pyo3cls/src/lib.rs +++ b/pyo3cls/src/lib.rs @@ -5,8 +5,8 @@ extern crate proc_macro; use proc_macro::TokenStream; use pyo3_derive_backend::{ - build_py_class, build_py_function, build_py_methods, build_py_proto, get_doc, - process_functions_in_module, py_init, PyClassArgs, PyFunctionAttr, + build_derive_from_pyobject, build_py_class, build_py_function, build_py_methods, + build_py_proto, get_doc, process_functions_in_module, py_init, PyClassArgs, PyFunctionAttr, }; use quote::quote; use syn::parse_macro_input; @@ -91,3 +91,13 @@ pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream { ) .into() } + +#[proc_macro_derive(FromPyObject, attributes(transparent, extract, rename_err))] +pub fn derive_from_py_object(item: TokenStream) -> TokenStream { + let mut ast = parse_macro_input!(item as syn::DeriveInput); + let expanded = build_derive_from_pyobject(&mut ast).unwrap_or_else(|e| e.to_compile_error()); + quote!( + #expanded + ) + .into() +} diff --git a/src/prelude.rs b/src/prelude.rs index e3bc5f0b8fe..8046096f7f2 100644 --- a/src/prelude.rs +++ b/src/prelude.rs @@ -20,4 +20,4 @@ pub use crate::{FromPyObject, IntoPy, IntoPyPointer, PyTryFrom, PyTryInto, ToPyO // PyModule is only part of the prelude because we need it for the pymodule function pub use crate::types::{PyAny, PyModule}; #[cfg(feature = "macros")] -pub use pyo3cls::{pyclass, pyfunction, pymethods, pymodule, pyproto}; +pub use pyo3cls::{pyclass, pyfunction, pymethods, pymodule, pyproto, FromPyObject}; diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs new file mode 100644 index 00000000000..b1bf0cddb47 --- /dev/null +++ b/tests/test_frompyobject.rs @@ -0,0 +1,308 @@ +use pyo3::exceptions::PyValueError; +use pyo3::prelude::*; +use pyo3::types::{PyDict, PyString, PyTuple}; +use pyo3::{PyErrValue, PyMappingProtocol}; + +#[macro_use] +mod common; + +#[derive(Debug, FromPyObject)] +pub struct A<'a> { + #[extract(getattr)] + s: String, + #[extract(get_item)] + t: &'a PyString, + #[extract(getattr("foo"))] + p: &'a PyAny, +} + +#[pyclass] +pub struct PyA { + #[pyo3(get)] + s: String, + #[pyo3(get)] + foo: Option, +} + +#[pyproto] +impl PyMappingProtocol for PyA { + fn __getitem__(&self, key: String) -> pyo3::PyResult { + if key == "t" { + Ok("bar".into()) + } else { + Err(PyValueError::py_err("Failed")) + } + } +} + +#[test] +fn test_named_fields_struct() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let pya = PyA { + s: "foo".into(), + foo: None, + }; + let py_c = Py::new(py, pya).unwrap(); + let a: A = FromPyObject::extract(py_c.as_ref(py)).expect("Failed to extract A from PyA"); + assert_eq!(a.s, "foo"); + assert_eq!(a.t.to_string_lossy(), "bar"); + assert!(a.p.is_none()); +} + +#[derive(Debug, FromPyObject)] +#[transparent] +pub struct B { + test: String, +} + +#[test] +fn test_transparent_named_field_struct() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let test = "test".into_py(py); + let b: B = FromPyObject::extract(test.as_ref(py)).expect("Failed to extract B from String"); + assert_eq!(b.test, "test"); + let test: PyObject = 1.into_py(py); + let b = B::extract(test.as_ref(py)); + assert!(b.is_err()) +} + +#[derive(Debug, FromPyObject)] +#[transparent] +pub struct D { + test: T, +} + +#[test] +fn test_generic_transparent_named_field_struct() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let test = "test".into_py(py); + let d: D = + D::extract(test.as_ref(py)).expect("Failed to extract D from String"); + assert_eq!(d.test, "test"); + let test = 1usize.into_py(py); + let d: D = D::extract(test.as_ref(py)).expect("Failed to extract D from String"); + assert_eq!(d.test, 1); +} + +#[derive(Debug, FromPyObject)] +pub struct E { + test: T, + test2: T2, +} + +#[pyclass] +pub struct PyE { + #[pyo3(get)] + test: String, + #[pyo3(get)] + test2: usize, +} + +#[test] +fn test_generic_named_fields_struct() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let pye = PyE { + test: "test".into(), + test2: 2, + } + .into_py(py); + + let e: E = + E::extract(pye.as_ref(py)).expect("Failed to extract E from PyE"); + assert_eq!(e.test, "test"); + assert_eq!(e.test2, 2); + let e = E::::extract(pye.as_ref(py)); + assert!(e.is_err()); +} + +#[derive(Debug, FromPyObject)] +pub struct C { + #[extract(getattr("test"))] + test: String, +} + +#[test] +fn test_named_field_with_ext_fn() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let pyc = PyE { + test: "foo".into(), + test2: 0, + } + .into_py(py); + let c = C::extract(pyc.as_ref(py)).expect("Failed to extract C from PyE"); + assert_eq!(c.test, "foo"); +} + +#[derive(FromPyObject)] +pub struct Tuple(String, usize); + +#[test] +fn test_tuple_struct() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let tup = PyTuple::new(py, &[1.into_py(py), "test".into_py(py)]); + let tup = Tuple::extract(tup.as_ref()); + assert!(tup.is_err()); + let tup = PyTuple::new(py, &["test".into_py(py), 1.into_py(py)]); + let tup = Tuple::extract(tup.as_ref()).expect("Failed to extract Tuple from PyTuple"); + assert_eq!(tup.0, "test"); + assert_eq!(tup.1, 1); +} + +#[derive(FromPyObject)] +pub struct TransparentTuple(String); + +#[test] +fn test_transparent_tuple_struct() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let tup = PyTuple::new(py, &[1.into_py(py)]); + let tup = TransparentTuple::extract(tup.as_ref()); + assert!(tup.is_err()); + let tup = PyTuple::new(py, &["test".into_py(py)]); + let tup = TransparentTuple::extract(tup.as_ref()) + .expect("Failed to extract TransparentTuple from PyTuple"); + assert_eq!(tup.0, "test"); +} + +#[derive(Debug, FromPyObject)] +pub enum Foo<'a> { + TupleVar(usize, String), + StructVar { + test: &'a PyString, + }, + #[transparent] + TransparentTuple(usize), + #[transparent] + TransparentStructVar { + a: Option, + }, + StructVarGetAttrArg { + #[extract(getattr("bla"))] + a: bool, + }, + StructWithGetItem { + #[extract(get_item)] + a: String, + }, + StructWithGetItemArg { + #[extract(get_item("foo"))] + a: String, + }, + #[transparent] + CatchAll(&'a PyAny), +} + +#[pyclass] +pub struct PyBool { + #[pyo3(get)] + bla: bool, +} + +#[test] +fn test_enum() { + let gil = Python::acquire_gil(); + let py = gil.python(); + + let tup = PyTuple::new(py, &[1.into_py(py), "test".into_py(py)]); + let f = Foo::extract(tup.as_ref()).expect("Failed to extract Foo from tuple"); + match f { + Foo::TupleVar(test, test2) => { + assert_eq!(test, 1); + assert_eq!(test2, "test"); + } + _ => panic!("Expected extracting Foo::TupleVar, got {:?}", f), + } + + let pye = PyE { + test: "foo".into(), + test2: 0, + } + .into_py(py); + let f = Foo::extract(pye.as_ref(py)).expect("Failed to extract Foo from PyE"); + match f { + Foo::StructVar { test } => assert_eq!(test.to_string_lossy(), "foo"), + _ => panic!("Expected extracting Foo::StructVar, got {:?}", f), + } + + let int: PyObject = 1.into_py(py); + let f = Foo::extract(int.as_ref(py)).expect("Failed to extract Foo from int"); + match f { + Foo::TransparentTuple(test) => assert_eq!(test, 1), + _ => panic!("Expected extracting Foo::TransparentTuple, got {:?}", f), + } + let none = py.None(); + let f = Foo::extract(none.as_ref(py)).expect("Failed to extract Foo from int"); + match f { + Foo::TransparentStructVar { a } => assert!(a.is_none()), + _ => panic!("Expected extracting Foo::TransparentStructVar, got {:?}", f), + } + + let pybool = PyBool { bla: true }.into_py(py); + let f = Foo::extract(pybool.as_ref(py)).expect("Failed to extract Foo from PyBool"); + match f { + Foo::StructVarGetAttrArg { a } => assert!(a), + _ => panic!("Expected extracting Foo::StructVarGetAttrArg, got {:?}", f), + } + + let dict = PyDict::new(py); + dict.set_item("a", "test").expect("Failed to set item"); + let f = Foo::extract(dict.as_ref()).expect("Failed to extract Foo from dict"); + match f { + Foo::StructWithGetItem { a } => assert_eq!(a, "test"), + _ => panic!("Expected extracting Foo::StructWithGetItem, got {:?}", f), + } + + let dict = PyDict::new(py); + dict.set_item("foo", "test").expect("Failed to set item"); + let f = Foo::extract(dict.as_ref()).expect("Failed to extract Foo from dict"); + match f { + Foo::StructWithGetItemArg { a } => assert_eq!(a, "test"), + _ => panic!("Expected extracting Foo::StructWithGetItemArg, got {:?}", f), + } + + let dict = PyDict::new(py); + let f = Foo::extract(dict.as_ref()).expect("Failed to extract Foo from dict"); + match f { + Foo::CatchAll(any) => { + let d = <&PyDict>::extract(any).expect("Expected pydict"); + assert!(d.is_empty()); + } + _ => panic!("Expected extracting Foo::CatchAll, got {:?}", f), + } +} + +#[derive(FromPyObject)] +pub enum Bar { + #[rename_err = "str"] + A(String), + #[rename_err = "uint"] + B(usize), + #[rename_err = "int"] + C(isize), +} + +#[test] +fn test_err_rename() { + let gil = Python::acquire_gil(); + let py = gil.python(); + let dict = PyDict::new(py); + let f = Bar::extract(dict.as_ref()); + assert!(f.is_err()); + match f { + Ok(_) => {} + Err(e) => match e.pvalue { + PyErrValue::ToObject(to) => { + let o = to.to_object(py); + let s = String::extract(o.as_ref(py)).expect("Err val is not a string"); + assert_eq!(s, "Can't convert {} (dict) to str, uint, int") + } + _ => panic!("Expected PyErrValue::ToObject"), + }, + } +} From 60fe4925f531c673e189aedfe8d9886a55bd104a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Wed, 26 Aug 2020 22:13:14 +0200 Subject: [PATCH 2/7] '#[derive(FromPyObject)]` changes suggested by @davidwhewitt. --- CHANGELOG.md | 1 + pyo3-derive-backend/src/from_pyobject.rs | 496 +++++++++++++++++++++++ pyo3-derive-backend/src/frompy.rs | 428 ------------------- pyo3-derive-backend/src/lib.rs | 4 +- pyo3cls/src/lib.rs | 6 +- tests/test_frompyobject.rs | 39 +- 6 files changed, 522 insertions(+), 452 deletions(-) create mode 100644 pyo3-derive-backend/src/from_pyobject.rs delete mode 100644 pyo3-derive-backend/src/frompy.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index e8e2dd1aa2f..3b3e0592681 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. - Add `Py::as_ref` and `Py::into_ref`. [#1098](https://github.com/PyO3/pyo3/pull/1098) - Add optional implementations of `ToPyObject`, `IntoPy`, and `FromPyObject` for [hashbrown](https://crates.io/crates/hashbrown)'s `HashMap` and `HashSet` types. The `hashbrown` feature must be enabled for these implementations to be built. [#1114](https://github.com/PyO3/pyo3/pull/1114/) - Allow other `Result` types when using `#[pyfunction]`. [#1106](https://github.com/PyO3/pyo3/issues/1106). +- Add `#[derive(FromPyObject)]` macro for enums and structs. [#1065](https://github.com/PyO3/pyo3/pull/1065) ### Changed - Exception types have been renamed from e.g. `RuntimeError` to `PyRuntimeError`, and are now only accessible by `&T` or `Py` similar to other Python-native types. The old names continue to exist but are deprecated. [#1024](https://github.com/PyO3/pyo3/pull/1024) diff --git a/pyo3-derive-backend/src/from_pyobject.rs b/pyo3-derive-backend/src/from_pyobject.rs new file mode 100644 index 00000000000..1f222c4c5a2 --- /dev/null +++ b/pyo3-derive-backend/src/from_pyobject.rs @@ -0,0 +1,496 @@ +use proc_macro2::{Span, TokenStream}; +use quote::quote; +use syn::punctuated::Punctuated; +use syn::{parse_quote, Attribute, DataEnum, DeriveInput, ExprCall, Fields, Ident, Result}; + +/// Describes derivation input of an enum. +#[derive(Debug)] +struct Enum<'a> { + enum_ident: &'a Ident, + variants: Vec>, +} + +impl<'a> Enum<'a> { + /// Construct a new enum representation. + /// + /// `data_enum` is the `syn` representation of the input enum, `ident` is the + /// `Identifier` of the enum. + fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result { + if data_enum.variants.is_empty() { + return Err(syn::Error::new_spanned( + &data_enum.variants, + "Cannot derive FromPyObject for empty enum.", + )); + } + let vars = data_enum + .variants + .iter() + .map(|variant| { + let attrs = ContainerAttribute::parse_attrs(&variant.attrs)?; + let var_ident = &variant.ident; + Container::new( + &variant.fields, + parse_quote!(#ident::#var_ident), + attrs, + true, + ) + }) + .collect::>>()?; + + Ok(Enum { + enum_ident: ident, + variants: vars, + }) + } + + /// Build derivation body for enums. + fn build(&self) -> TokenStream { + let mut var_extracts = Vec::new(); + let mut error_names = String::new(); + for (i, var) in self.variants.iter().enumerate() { + let struct_derive = var.build(); + let ext = quote!( + let maybe_ret = || -> ::pyo3::PyResult { + #struct_derive + }(); + if maybe_ret.is_ok() { + return maybe_ret + } + ); + + var_extracts.push(ext); + error_names.push_str(&var.err_name); + if i < self.variants.len() - 1 { + error_names.push_str(", "); + } + } + quote!( + #(#var_extracts)* + let type_name = obj.get_type().name(); + let from = obj + .repr() + .map(|s| format!("{} ({})", s.to_string_lossy(), type_name)) + .unwrap_or_else(|_| type_name.to_string()); + let err_msg = format!("Can't convert {} to {}", from, #error_names); + Err(::pyo3::exceptions::PyTypeError::py_err(err_msg)) + ) + } +} + +/// Container Style +/// +/// Covers Structs, Tuplestructs and corresponding Newtypes. +#[derive(Debug)] +enum ContainerType<'a> { + /// Struct Container, e.g. `struct Foo { a: String }` + /// + /// Variant contains the list of field identifiers and the corresponding extraction call. + Struct(Vec<(&'a Ident, FieldAttribute)>), + /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }` + /// + /// The field specified by the identifier is extracted directly from the object. + StructNewtype(&'a Ident), + /// Tuple struct, e.g. `struct Foo(String)`. + /// + /// Fields are extracted from a tuple. + Tuple(usize), + /// Tuple newtype, e.g. `#[transparent] struct Foo(String)` + /// + /// The wrapped field is directly extracted from the object. + TupleNewtype, +} + +/// Data container +/// +/// Either describes a struct or an enum variant. +#[derive(Debug)] +struct Container<'a> { + path: syn::Path, + ty: ContainerType<'a>, + err_name: String, + is_enum_variant: bool, +} + +impl<'a> Container<'a> { + /// Construct a container based on fields, identifier and attributes. + /// + /// Fails if the variant has no fields or incompatible attributes. + fn new( + fields: &'a Fields, + path: syn::Path, + attrs: Vec, + is_enum_variant: bool, + ) -> Result { + let transparent = attrs.iter().any(ContainerAttribute::transparent); + if transparent { + Self::check_transparent_len(fields)?; + } + let style = match (fields, transparent) { + (Fields::Unnamed(_), true) => ContainerType::TupleNewtype, + (Fields::Unnamed(unnamed), false) => ContainerType::Tuple(unnamed.unnamed.len()), + (Fields::Named(named), true) => { + let field = named + .named + .iter() + .next() + .expect("Check for len 1 is done above"); + let ident = field + .ident + .as_ref() + .expect("Named fields should have identifiers"); + ContainerType::StructNewtype(ident) + } + (Fields::Named(named), false) => { + let mut fields = Vec::new(); + for field in named.named.iter() { + let ident = field + .ident + .as_ref() + .expect("Named fields should have identifiers"); + let attr = FieldAttribute::parse_attrs(&field.attrs)? + .unwrap_or_else(|| FieldAttribute::Ident(parse_quote!(getattr))); + fields.push((ident, attr)) + } + ContainerType::Struct(fields) + } + (Fields::Unit, _) => { + return Err(syn::Error::new_spanned( + &fields, + "Cannot derive FromPyObject for Unit structs and variants", + )) + } + }; + let err_name = attrs + .iter() + .find_map(|a| a.annotation()) + .cloned() + .unwrap_or_else(|| path.segments.last().unwrap().ident.to_string()); + + let v = Container { + path, + ty: style, + err_name, + is_enum_variant, + }; + Ok(v) + } + + fn verify_struct_container_attrs(attrs: &'a [ContainerAttribute]) -> Result<()> { + for attr in attrs { + match attr { + ContainerAttribute::Transparent => continue, + ContainerAttribute::ErrorAnnotation(_) => { + return Err(syn::Error::new( + Span::call_site(), + "Annotating error messages for structs is \ + not supported. Remove the annotation attribute.", + )) + } + } + } + Ok(()) + } + + /// Build derivation body for a struct. + fn build(&self) -> TokenStream { + match &self.ty { + ContainerType::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)), + ContainerType::TupleNewtype => self.build_newtype_struct(None), + ContainerType::Tuple(len) => self.build_tuple_struct(*len), + ContainerType::Struct(tups) => self.build_struct(tups), + } + } + + fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream { + let self_ty = &self.path; + if let Some(ident) = field_ident { + quote!( + Ok(#self_ty{#ident: obj.extract()?}) + ) + } else { + quote!(Ok(#self_ty(obj.extract()?))) + } + } + + fn build_tuple_struct(&self, len: usize) -> TokenStream { + let self_ty = &self.path; + let mut fields: Punctuated = Punctuated::new(); + for i in 0..len { + fields.push(quote!(slice[#i].extract()?)); + } + let msg = if self.is_enum_variant { + quote!(format!( + "Expected tuple of length {}, but got length {}.", + #len, + s.len() + )) + } else { + quote!("") + }; + quote!( + let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?; + if s.len() != #len { + return Err(::pyo3::exceptions::PyValueError::py_err(#msg)) + } + let slice = s.as_slice(); + Ok(#self_ty(#fields)) + ) + } + + fn build_struct(&self, tups: &[(&Ident, FieldAttribute)]) -> TokenStream { + let self_ty = &self.path; + let mut fields: Punctuated = Punctuated::new(); + for (ident, attr) in tups { + let ext_fn = match attr { + FieldAttribute::IdentWithArg(expr) => quote!(#expr), + FieldAttribute::Ident(meth) => { + let arg = ident.to_string(); + quote!(#meth(#arg)) + } + }; + fields.push(quote!(#ident: obj.#ext_fn?.extract()?)); + } + quote!(Ok(#self_ty{#fields})) + } + + fn check_transparent_len(fields: &Fields) -> Result<()> { + if fields.len() != 1 { + return Err(syn::Error::new_spanned( + fields, + "Transparent structs and variants can only have 1 field", + )); + } + Ok(()) + } +} + +/// Attributes for deriving FromPyObject scoped on containers. +#[derive(Clone, Debug, PartialEq)] +enum ContainerAttribute { + /// Treat the Container as a Wrapper, directly extract its fields from the input object. + Transparent, + /// Change the name of an enum variant in the generated error message. + ErrorAnnotation(String), +} + +impl ContainerAttribute { + /// Return whether this attribute is `Transparent` + fn transparent(&self) -> bool { + match self { + ContainerAttribute::Transparent => true, + _ => false, + } + } + + /// Convenience method to access `ErrorAnnotation`. + fn annotation(&self) -> Option<&String> { + match self { + ContainerAttribute::ErrorAnnotation(s) => Some(s), + _ => None, + } + } + + /// Parse valid container arguments + /// + /// Fails if any are invalid. + fn parse_attrs(value: &[Attribute]) -> Result> { + let mut attrs = Vec::new(); + let list = get_pyo3_meta_list(value)?; + for meta in list.nested { + if let syn::NestedMeta::Meta(metaitem) = &meta { + match metaitem { + syn::Meta::Path(p) if p.is_ident("transparent") => { + attrs.push(ContainerAttribute::Transparent) + } + syn::Meta::NameValue(nv) if nv.path.is_ident("annotation") => { + if let syn::Lit::Str(s) = &nv.lit { + attrs.push(ContainerAttribute::ErrorAnnotation(s.value())) + } else { + return Err(syn::Error::new_spanned( + &nv.lit, + "Expected string literal.", + )); + } + } + _ => (), + } + } else { + return Err(syn::Error::new_spanned( + meta, + "Unknown container attribute, expected `transparent` or \ + `annotation(\"err_name\")`", + )); + } + } + Ok(attrs) + } +} + +/// Attributes for deriving FromPyObject scoped on fields. +#[derive(Clone, Debug)] +enum FieldAttribute { + /// How a specific field should be extracted. + Ident(Ident), + IdentWithArg(ExprCall), +} + +impl FieldAttribute { + /// Extract the field attribute. + /// + /// Currently fails if more than 1 attribute is passed in `pyo3` + fn parse_attrs(attrs: &[Attribute]) -> Result> { + let list = get_pyo3_meta_list(attrs)?; + if list.nested.len() > 1 { + return Err(syn::Error::new_spanned( + list, + "Only one of `item`, `attribute` can be provided, possibly as \ + a key-value pair: `attribute = \"name\"`.", + )); + } + let meta = if let Some(attr) = list.nested.first() { + attr + } else { + return Ok(None); + }; + if let syn::NestedMeta::Meta(metaitem) = meta { + let path = metaitem.path(); + let ident = Self::check_valid_ident(path)?; + match metaitem { + syn::Meta::NameValue(nv) => Self::get_ident_with_arg(ident, &nv.lit).map(Some), + syn::Meta::Path(_) => Ok(Some(FieldAttribute::Ident(parse_quote!(#ident)))), + _ => Err(syn::Error::new_spanned( + metaitem, + "`item` or `attribute` need to be passed alone or as key-value \ + pairs, e.g. `attribute = \"name\"`.", + )), + } + } else { + Err(syn::Error::new_spanned(meta, "Unexpected literal.")) + } + } + + /// Verify the attribute path and return it if it is valid. + fn check_valid_ident(path: &syn::Path) -> Result { + if path.is_ident("item") { + Ok(parse_quote!(get_item)) + } else if path.is_ident("attribute") { + Ok(parse_quote!(getattr)) + } else { + Err(syn::Error::new_spanned( + path, + "Expected `item` or `attribute`", + )) + } + } + + /// Try to build `IdentWithArg` based on identifier and literal. + fn get_ident_with_arg(ident: Ident, lit: &syn::Lit) -> Result { + if ident == "getattr" { + if let syn::Lit::Str(s) = lit { + return Ok(FieldAttribute::IdentWithArg(parse_quote!(#ident(#s)))); + } else { + return Err(syn::Error::new_spanned(lit, "Expected string literal.")); + } + } + if ident == "get_item" { + return Ok(FieldAttribute::IdentWithArg(parse_quote!(#ident(#lit)))); + } + + // path is already checked in the `parse_attrs` loop, returning the error here anyways. + Err(syn::Error::new_spanned( + ident, + "Expected `item` or `attribute`.", + )) + } +} + +/// Extract pyo3 metalist, flattens multiple lists into a single one. +fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result { + let mut list: Punctuated = Punctuated::new(); + for value in attrs { + match value.parse_meta()? { + syn::Meta::List(ml) if value.path.is_ident("pyo3") => { + for meta in ml.nested { + list.push(meta); + } + } + _ => { + return Err(syn::Error::new_spanned( + value, + "Expected `pyo3()` attribute.", + )) + } + } + } + Ok(syn::MetaList { + path: parse_quote!(pyo3), + paren_token: syn::token::Paren::default(), + nested: list, + }) +} + +fn verify_and_get_lifetime(generics: &syn::Generics) -> Result> { + let lifetimes = generics.lifetimes().collect::>(); + if lifetimes.len() > 1 { + return Err(syn::Error::new_spanned( + &generics, + "FromPyObject can only be derived with at most one lifetime parameter.", + )); + } + Ok(lifetimes.into_iter().next()) +} + +/// Derive FromPyObject for enums and structs. +/// +/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier +/// * At least one field, in case of `#[transparent]`, exactly one field +/// * At least one variant for enums. +/// * Fields of input structs and enums must implement `FromPyObject` +/// * Derivation for structs with generic fields like `struct Foo(T)` +/// adds `T: FromPyObject` on the derived implementation. +pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result { + let mut trait_generics = tokens.generics.clone(); + let generics = &tokens.generics; + let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? { + lt.clone() + } else { + trait_generics.params.push(parse_quote!('source)); + parse_quote!('source) + }; + let mut where_clause: syn::WhereClause = parse_quote!(where); + for param in generics.type_params() { + let gen_ident = ¶m.ident; + where_clause + .predicates + .push(parse_quote!(#gen_ident: FromPyObject<#lt_param>)) + } + let derives = match &tokens.data { + syn::Data::Enum(en) => { + let en = Enum::new(en, &tokens.ident)?; + en.build() + } + syn::Data::Struct(st) => { + let attrs = ContainerAttribute::parse_attrs(&tokens.attrs)?; + Container::verify_struct_container_attrs(&attrs)?; + let ident = &tokens.ident; + let st = Container::new(&st.fields, parse_quote!(#ident), attrs, false)?; + st.build() + } + _ => { + return Err(syn::Error::new_spanned( + tokens, + "FromPyObject can only be derived for structs and enums.", + )) + } + }; + + let ident = &tokens.ident; + Ok(quote!( + #[automatically_derived] + impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause { + fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult { + #derives + } + } + )) +} diff --git a/pyo3-derive-backend/src/frompy.rs b/pyo3-derive-backend/src/frompy.rs deleted file mode 100644 index e60a0e4541e..00000000000 --- a/pyo3-derive-backend/src/frompy.rs +++ /dev/null @@ -1,428 +0,0 @@ -use proc_macro2::{Span, TokenStream}; -use quote::quote; -use syn::punctuated::Punctuated; -use syn::token::Paren; -use syn::{ - parse_quote, Attribute, DataEnum, DeriveInput, Expr, ExprCall, Fields, Ident, PatTuple, Result, - Variant, -}; - -/// Describes derivation input of an enum. -#[derive(Debug)] -struct Enum<'a> { - enum_ident: &'a Ident, - vars: Vec>, -} - -impl<'a> Enum<'a> { - /// Construct a new enum representation. - /// - /// `data_enum` is the `syn` representation of the input enum, `ident` is the - /// `Identifier` of the enum. - fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result { - if data_enum.variants.is_empty() { - return Err(syn::Error::new_spanned( - &data_enum.variants, - "Cannot derive FromPyObject for empty enum.", - )); - } - let vars = data_enum - .variants - .iter() - .map(Container::from_variant) - .collect::>>()?; - - Ok(Enum { - enum_ident: ident, - vars, - }) - } - - /// Build derivation body for enums. - fn derive_enum(&self) -> TokenStream { - let mut var_extracts = Vec::new(); - let mut error_names = String::new(); - for (i, var) in self.vars.iter().enumerate() { - let ext = match &var.style { - Style::Struct(tups) => self.build_struct_variant(tups, var.ident), - Style::StructNewtype(ident) => { - self.build_transparent_variant(var.ident, Some(ident)) - } - Style::Tuple(len) => self.build_tuple_variant(var.ident, *len), - Style::TupleNewtype => self.build_transparent_variant(var.ident, None), - }; - var_extracts.push(ext); - error_names.push_str(&var.err_name); - if i < self.vars.len() - 1 { - error_names.push_str(", "); - } - } - quote!( - #(#var_extracts)* - let type_name = obj.get_type().name(); - let from = obj - .repr() - .map(|s| format!("{} ({})", s.to_string_lossy(), type_name)) - .unwrap_or_else(|_| type_name.to_string()); - let err_msg = format!("Can't convert {} to {}", from, #error_names); - Err(::pyo3::exceptions::PyTypeError::py_err(err_msg)) - ) - } - - /// Build match for tuple struct variant. - fn build_tuple_variant(&self, var_ident: &Ident, len: usize) -> TokenStream { - let enum_ident = self.enum_ident; - let mut ext: Punctuated = Punctuated::new(); - let mut fields: Punctuated = Punctuated::new(); - let mut field_pats = PatTuple { - attrs: vec![], - paren_token: Paren::default(), - elems: Default::default(), - }; - for i in 0..len { - ext.push(parse_quote!(slice[#i].extract())); - let ident = Ident::new(&format!("_field{}", i), Span::call_site()); - field_pats.elems.push(parse_quote!(Ok(#ident))); - fields.push(ident); - } - - quote!( - match <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj) { - Ok(s) => { - if s.len() == #len { - let slice = s.as_slice(); - if let (#field_pats) = (#ext) { - return Ok(#enum_ident::#var_ident(#fields)) - } - } - }, - Err(_) => {} - } - ) - } - - /// Build match for transparent enum variants. - fn build_transparent_variant( - &self, - var_ident: &Ident, - field_ident: Option<&Ident>, - ) -> TokenStream { - let enum_ident = self.enum_ident; - if let Some(ident) = field_ident { - quote!( - if let Ok(#ident) = obj.extract() { - return Ok(#enum_ident::#var_ident{#ident}) - } - ) - } else { - quote!( - if let Ok(inner) = obj.extract() { - return Ok(#enum_ident::#var_ident(inner)) - } - ) - } - } - - /// Build match for struct variant with named fields. - fn build_struct_variant( - &self, - tups: &[(&'a Ident, ExprCall)], - var_ident: &Ident, - ) -> TokenStream { - let enum_ident = self.enum_ident; - let mut field_pats = PatTuple { - attrs: vec![], - paren_token: Paren::default(), - elems: Default::default(), - }; - let mut fields: Punctuated = Punctuated::new(); - let mut ext: Punctuated = Punctuated::new(); - for (ident, ext_fn) in tups { - field_pats.elems.push(parse_quote!(Ok(#ident))); - fields.push(parse_quote!(#ident)); - ext.push(parse_quote!(obj.#ext_fn.and_then(|o| o.extract()))); - } - quote!(if let #field_pats = #ext { - return Ok(#enum_ident::#var_ident{#fields}); - }) - } -} - -/// Container Style -/// -/// Covers Structs, Tuplestructs and corresponding Newtypes. -#[derive(Clone, Debug)] -enum Style<'a> { - /// Struct Container, e.g. `struct Foo { a: String }` - /// - /// Variant contains the list of field identifiers and the corresponding extraction call. - Struct(Vec<(&'a Ident, ExprCall)>), - /// Newtype struct container, e.g. `#[transparent] struct Foo { a: String }` - /// - /// The field specified by the identifier is extracted directly from the object. - StructNewtype(&'a Ident), - /// Tuple struct, e.g. `struct Foo(String)`. - /// - /// Fields are extracted from a tuple. - Tuple(usize), - /// Tuple newtype, e.g. `#[transparent] struct Foo(String)` - /// - /// The wrapped field is directly extracted from the object. - TupleNewtype, -} - -/// Data container -/// -/// Either describes a struct or an enum variant. -#[derive(Debug)] -struct Container<'a> { - ident: &'a Ident, - style: Style<'a>, - err_name: String, -} - -impl<'a> Container<'a> { - /// Construct a container from an enum Variant. - /// - /// Fails if the variant has no fields or incompatible attributes. - fn from_variant(var: &'a Variant) -> Result { - Self::new(&var.fields, &var.ident, &var.attrs) - } - - /// Construct a container based on fields, identifier and attributes. - /// - /// Fails if the variant has no fields or incompatible attributes. - fn new(fields: &'a Fields, ident: &'a Ident, attrs: &'a [Attribute]) -> Result { - let transparent = attrs.iter().any(|a| a.path.is_ident("transparent")); - if transparent { - Self::check_transparent_len(fields)?; - } - let style = match fields { - Fields::Unnamed(unnamed) => { - if transparent { - Style::TupleNewtype - } else { - Style::Tuple(unnamed.unnamed.len()) - } - } - Fields::Named(named) => { - if transparent { - let field = named - .named - .iter() - .next() - .expect("Check for len 1 is done above"); - let ident = field - .ident - .as_ref() - .expect("Named fields should have identifiers"); - Style::StructNewtype(ident) - } else { - let mut fields = Vec::new(); - for field in named.named.iter() { - let ident = field - .ident - .as_ref() - .expect("Named fields should have identifiers"); - fields.push((ident, ext_fn(&field.attrs, ident)?)) - } - Style::Struct(fields) - } - } - Fields::Unit => { - return Err(syn::Error::new_spanned( - &fields, - "Cannot derive FromPyObject for Unit structs and variants", - )) - } - }; - let err_name = maybe_renamed_err(&attrs)? - .map(|s| s.value()) - .unwrap_or_else(|| ident.to_string()); - - let v = Container { - ident: &ident, - style, - err_name, - }; - Ok(v) - } - - /// Build derivation body for a struct. - fn derive_struct(&self) -> TokenStream { - match &self.style { - Style::StructNewtype(ident) => self.build_newtype_struct(Some(&ident)), - Style::TupleNewtype => self.build_newtype_struct(None), - Style::Tuple(len) => self.build_tuple_struct(*len), - Style::Struct(tups) => self.build_struct(tups), - } - } - - fn build_newtype_struct(&self, field_ident: Option<&Ident>) -> TokenStream { - if let Some(ident) = field_ident { - quote!( - Ok(Self{#ident: obj.extract()?}) - ) - } else { - quote!(Ok(Self(obj.extract()?))) - } - } - - fn build_tuple_struct(&self, len: usize) -> TokenStream { - let mut fields: Punctuated = Punctuated::new(); - for i in 0..len { - fields.push(quote!(slice[#i].extract()?)); - } - quote!( - let s = <::pyo3::types::PyTuple as ::pyo3::conversion::PyTryFrom>::try_from(obj)?; - let seq_len = s.len(); - if seq_len != #len { - let msg = format!( - "Expected tuple of length {}, but got length {}.", - #len, - seq_len - ); - return Err(::pyo3::exceptions::PyValueError::py_err(msg)) - } - let slice = s.as_slice(); - Ok(Self(#fields)) - ) - } - - fn build_struct(&self, tups: &[(&Ident, syn::ExprCall)]) -> TokenStream { - let mut fields: Punctuated = Punctuated::new(); - for (ident, ext_fn) in tups { - fields.push(quote!(#ident: obj.#ext_fn?.extract()?)); - } - quote!(Ok(Self{#fields})) - } - - fn check_transparent_len(fields: &Fields) -> Result<()> { - if fields.len() != 1 { - return Err(syn::Error::new_spanned( - fields, - "Transparent structs and variants can only have 1 field", - )); - } - Ok(()) - } -} - -/// Get the extraction function that's called on the input object. -/// -/// Valid arguments are `get_item`, `get_attr` which are called with the -/// stringified field identifier or a function call on `PyAny`, e.g. `get_attr("attr")` -fn ext_fn(attrs: &[Attribute], field_ident: &Ident) -> Result { - let attr = if let Some(attr) = attrs.iter().find(|a| a.path.is_ident("extract")) { - attr - } else { - return Ok(parse_quote!(getattr(stringify!(#field_ident)))); - }; - if let Ok(ident) = attr.parse_args::() { - if ident != "getattr" && ident != "get_item" { - Err(syn::Error::new_spanned( - ident, - "Only get_item and getattr are valid for extraction.", - )) - } else { - let arg = field_ident.to_string(); - Ok(parse_quote!(#ident(#arg))) - } - } else if let Ok(call) = attr.parse_args() { - Ok(call) - } else { - Err(syn::Error::new_spanned( - attr, - "Only get_item and getattr are valid for extraction,\ - both can be passed with or without an argument, e.g. \ - #[extract(getattr(\"attr\")] and #[extract(getattr)]", - )) - } -} - -/// Returns the name of the variant for the error message if no variants match. -fn maybe_renamed_err(attrs: &[syn::Attribute]) -> Result> { - for attr in attrs { - if !attr.path.is_ident("rename_err") { - continue; - } - let attr = attr.parse_meta()?; - if let syn::Meta::NameValue(nv) = &attr { - match &nv.lit { - syn::Lit::Str(s) => { - return Ok(Some(s.clone())); - } - _ => { - return Err(syn::Error::new_spanned( - attr, - "rename_err attribute must be string literal: #[rename_err=\"Name\"]", - )) - } - } - } - } - Ok(None) -} - -fn verify_and_get_lifetime(generics: &syn::Generics) -> Result> { - let lifetimes = generics.lifetimes().collect::>(); - if lifetimes.len() > 1 { - return Err(syn::Error::new_spanned( - &generics, - "Only a single lifetime parameter can be specified.", - )); - } - Ok(lifetimes.into_iter().next()) -} - -/// Derive FromPyObject for enums and structs. -/// -/// * Max 1 lifetime specifier, will be tied to `FromPyObject`'s specifier -/// * At least one field, in case of `#[transparent]`, exactly one field -/// * At least one variant for enums. -/// * Fields of input structs and enums must implement `FromPyObject` -/// * Derivation for structs with generic fields like `struct Foo(T)` -/// adds `T: FromPyObject` on the derived implementation. -pub fn build_derive_from_pyobject(tokens: &mut DeriveInput) -> Result { - let mut trait_generics = tokens.generics.clone(); - let generics = &tokens.generics; - let lt_param = if let Some(lt) = verify_and_get_lifetime(generics)? { - lt.clone() - } else { - trait_generics.params.push(parse_quote!('source)); - parse_quote!('source) - }; - let mut where_clause: syn::WhereClause = parse_quote!(where); - for param in generics.type_params() { - let gen_ident = ¶m.ident; - where_clause - .predicates - .push(parse_quote!(#gen_ident: FromPyObject<#lt_param>)) - } - let derives = match &tokens.data { - syn::Data::Enum(en) => { - let en = Enum::new(en, &tokens.ident)?; - en.derive_enum() - } - syn::Data::Struct(st) => { - let st = Container::new(&st.fields, &tokens.ident, &tokens.attrs)?; - st.derive_struct() - } - _ => { - return Err(syn::Error::new_spanned( - tokens, - "FromPyObject can only be derived for structs and enums.", - )) - } - }; - - let ident = &tokens.ident; - Ok(quote!( - #[automatically_derived] - impl#trait_generics ::pyo3::FromPyObject<#lt_param> for #ident#generics #where_clause { - fn extract(obj: &#lt_param ::pyo3::PyAny) -> ::pyo3::PyResult { - #derives - } - } - )) -} diff --git a/pyo3-derive-backend/src/lib.rs b/pyo3-derive-backend/src/lib.rs index 78db37368fb..2a943850588 100644 --- a/pyo3-derive-backend/src/lib.rs +++ b/pyo3-derive-backend/src/lib.rs @@ -4,7 +4,7 @@ #![recursion_limit = "1024"] mod defs; -mod frompy; +mod from_pyobject; mod konst; mod method; mod module; @@ -16,7 +16,7 @@ mod pymethod; mod pyproto; mod utils; -pub use frompy::build_derive_from_pyobject; +pub use from_pyobject::build_derive_from_pyobject; pub use module::{add_fn_to_module, process_functions_in_module, py_init}; pub use pyclass::{build_py_class, PyClassArgs}; pub use pyfunction::{build_py_function, PyFunctionAttr}; diff --git a/pyo3cls/src/lib.rs b/pyo3cls/src/lib.rs index 2377ec03bbe..bade299b750 100644 --- a/pyo3cls/src/lib.rs +++ b/pyo3cls/src/lib.rs @@ -92,10 +92,10 @@ pub fn pyfunction(attr: TokenStream, input: TokenStream) -> TokenStream { .into() } -#[proc_macro_derive(FromPyObject, attributes(transparent, extract, rename_err))] +#[proc_macro_derive(FromPyObject, attributes(pyo3, extract))] pub fn derive_from_py_object(item: TokenStream) -> TokenStream { - let mut ast = parse_macro_input!(item as syn::DeriveInput); - let expanded = build_derive_from_pyobject(&mut ast).unwrap_or_else(|e| e.to_compile_error()); + let ast = parse_macro_input!(item as syn::DeriveInput); + let expanded = build_derive_from_pyobject(&ast).unwrap_or_else(|e| e.to_compile_error()); quote!( #expanded ) diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs index b1bf0cddb47..23b42a4ccd0 100644 --- a/tests/test_frompyobject.rs +++ b/tests/test_frompyobject.rs @@ -8,11 +8,11 @@ mod common; #[derive(Debug, FromPyObject)] pub struct A<'a> { - #[extract(getattr)] + #[pyo3(attribute)] s: String, - #[extract(get_item)] + #[pyo3(item)] t: &'a PyString, - #[extract(getattr("foo"))] + #[pyo3(attribute = "foo")] p: &'a PyAny, } @@ -51,7 +51,7 @@ fn test_named_fields_struct() { } #[derive(Debug, FromPyObject)] -#[transparent] +#[pyo3(transparent)] pub struct B { test: String, } @@ -69,7 +69,7 @@ fn test_transparent_named_field_struct() { } #[derive(Debug, FromPyObject)] -#[transparent] +#[pyo3(transparent)] pub struct D { test: T, } @@ -121,7 +121,7 @@ fn test_generic_named_fields_struct() { #[derive(Debug, FromPyObject)] pub struct C { - #[extract(getattr("test"))] + #[pyo3(attribute = "test")] test: String, } @@ -155,17 +155,18 @@ fn test_tuple_struct() { } #[derive(FromPyObject)] +#[pyo3(transparent)] pub struct TransparentTuple(String); #[test] fn test_transparent_tuple_struct() { let gil = Python::acquire_gil(); let py = gil.python(); - let tup = PyTuple::new(py, &[1.into_py(py)]); - let tup = TransparentTuple::extract(tup.as_ref()); + let tup: PyObject = 1.into_py(py); + let tup = TransparentTuple::extract(tup.as_ref(py)); assert!(tup.is_err()); - let tup = PyTuple::new(py, &["test".into_py(py)]); - let tup = TransparentTuple::extract(tup.as_ref()) + let test = "test".into_py(py); + let tup = TransparentTuple::extract(test.as_ref(py)) .expect("Failed to extract TransparentTuple from PyTuple"); assert_eq!(tup.0, "test"); } @@ -176,25 +177,25 @@ pub enum Foo<'a> { StructVar { test: &'a PyString, }, - #[transparent] + #[pyo3(transparent)] TransparentTuple(usize), - #[transparent] + #[pyo3(transparent)] TransparentStructVar { a: Option, }, StructVarGetAttrArg { - #[extract(getattr("bla"))] + #[pyo3(attribute = "bla")] a: bool, }, StructWithGetItem { - #[extract(get_item)] + #[pyo3(item)] a: String, }, StructWithGetItemArg { - #[extract(get_item("foo"))] + #[pyo3(item = "foo")] a: String, }, - #[transparent] + #[pyo3(transparent)] CatchAll(&'a PyAny), } @@ -279,11 +280,11 @@ fn test_enum() { #[derive(FromPyObject)] pub enum Bar { - #[rename_err = "str"] + #[pyo3(annotation = "str")] A(String), - #[rename_err = "uint"] + #[pyo3(annotation = "uint")] B(usize), - #[rename_err = "int"] + #[pyo3(annotation = "int", transparent)] C(isize), } From 7781bb78de1f599de72f1315664ff5bd28fde911 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Sat, 29 Aug 2020 10:40:05 +0200 Subject: [PATCH 3/7] Specify item key and attr name as arguments. --- pyo3-derive-backend/src/from_pyobject.rs | 135 ++++++++++++++--------- tests/test_frompyobject.rs | 8 +- 2 files changed, 86 insertions(+), 57 deletions(-) diff --git a/pyo3-derive-backend/src/from_pyobject.rs b/pyo3-derive-backend/src/from_pyobject.rs index 1f222c4c5a2..b24331d20dd 100644 --- a/pyo3-derive-backend/src/from_pyobject.rs +++ b/pyo3-derive-backend/src/from_pyobject.rs @@ -1,7 +1,7 @@ use proc_macro2::{Span, TokenStream}; use quote::quote; use syn::punctuated::Punctuated; -use syn::{parse_quote, Attribute, DataEnum, DeriveInput, ExprCall, Fields, Ident, Result}; +use syn::{parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Meta, Result}; /// Describes derivation input of an enum. #[derive(Debug)] @@ -148,7 +148,7 @@ impl<'a> Container<'a> { .as_ref() .expect("Named fields should have identifiers"); let attr = FieldAttribute::parse_attrs(&field.attrs)? - .unwrap_or_else(|| FieldAttribute::Ident(parse_quote!(getattr))); + .unwrap_or_else(|| FieldAttribute::GetAttr(None)); fields.push((ident, attr)) } ContainerType::Struct(fields) @@ -242,10 +242,19 @@ impl<'a> Container<'a> { let mut fields: Punctuated = Punctuated::new(); for (ident, attr) in tups { let ext_fn = match attr { - FieldAttribute::IdentWithArg(expr) => quote!(#expr), - FieldAttribute::Ident(meth) => { - let arg = ident.to_string(); - quote!(#meth(#arg)) + FieldAttribute::GetAttr(name) => { + if let Some(name) = name.as_ref() { + quote!(getattr(#name)) + } else { + quote!(getattr(stringify!(#ident))) + } + } + FieldAttribute::GetItem(key) => { + if let Some(key) = key.as_ref() { + quote!(get_item(#key)) + } else { + quote!(get_item(stringify!(#ident))) + } } }; fields.push(quote!(#ident: obj.#ext_fn?.extract()?)); @@ -329,9 +338,8 @@ impl ContainerAttribute { /// Attributes for deriving FromPyObject scoped on fields. #[derive(Clone, Debug)] enum FieldAttribute { - /// How a specific field should be extracted. - Ident(Ident), - IdentWithArg(ExprCall), + GetItem(Option), + GetAttr(Option), } impl FieldAttribute { @@ -340,68 +348,89 @@ impl FieldAttribute { /// Currently fails if more than 1 attribute is passed in `pyo3` fn parse_attrs(attrs: &[Attribute]) -> Result> { let list = get_pyo3_meta_list(attrs)?; + if list.nested.is_empty() { + return Ok(None); + } if list.nested.len() > 1 { return Err(syn::Error::new_spanned( list, - "Only one of `item`, `attribute` can be provided, possibly as \ - a key-value pair: `attribute = \"name\"`.", + "Only one of `item`, `attribute` can be provided, possibly with an \ + additional argument: `item(\"key\")` or `attribute(\"name\").", )); } - let meta = if let Some(attr) = list.nested.first() { - attr - } else { - return Ok(None); - }; - if let syn::NestedMeta::Meta(metaitem) = meta { - let path = metaitem.path(); - let ident = Self::check_valid_ident(path)?; - match metaitem { - syn::Meta::NameValue(nv) => Self::get_ident_with_arg(ident, &nv.lit).map(Some), - syn::Meta::Path(_) => Ok(Some(FieldAttribute::Ident(parse_quote!(#ident)))), - _ => Err(syn::Error::new_spanned( - metaitem, - "`item` or `attribute` need to be passed alone or as key-value \ - pairs, e.g. `attribute = \"name\"`.", - )), + let metaitem = list.nested.into_iter().next().unwrap(); + let meta = match metaitem { + syn::NestedMeta::Meta(meta) => meta, + syn::NestedMeta::Lit(lit) => { + return Err(syn::Error::new_spanned( + lit, + "Expected `attribute` or `item`, not a literal.", + )) } - } else { - Err(syn::Error::new_spanned(meta, "Unexpected literal.")) - } - } - - /// Verify the attribute path and return it if it is valid. - fn check_valid_ident(path: &syn::Path) -> Result { - if path.is_ident("item") { - Ok(parse_quote!(get_item)) - } else if path.is_ident("attribute") { - Ok(parse_quote!(getattr)) + }; + let path = meta.path(); + if path.is_ident("attribute") { + Ok(Some(FieldAttribute::GetAttr(Self::attribute_arg(meta)?))) + } else if path.is_ident("item") { + Ok(Some(FieldAttribute::GetItem(Self::item_arg(meta)?))) } else { Err(syn::Error::new_spanned( - path, - "Expected `item` or `attribute`", + meta, + "Expected `attribute` or `item`.", )) } } - /// Try to build `IdentWithArg` based on identifier and literal. - fn get_ident_with_arg(ident: Ident, lit: &syn::Lit) -> Result { - if ident == "getattr" { - if let syn::Lit::Str(s) = lit { - return Ok(FieldAttribute::IdentWithArg(parse_quote!(#ident(#s)))); - } else { - return Err(syn::Error::new_spanned(lit, "Expected string literal.")); + fn attribute_arg(meta: syn::Meta) -> syn::Result> { + let arg_list = match meta { + syn::Meta::List(list) => list, + syn::Meta::Path(_) => return Ok(None), + Meta::NameValue(nv) => { + let err_msg = "Expected a string literal or no argument: `pyo3(attribute(\"name\") or `pyo3(attribute)`"; + return Err(syn::Error::new_spanned(nv, err_msg)); } + }; + if arg_list.nested.len() != 1 { + return Err(syn::Error::new_spanned( + arg_list, + "Expected a single string literal.", + )); } - if ident == "get_item" { - return Ok(FieldAttribute::IdentWithArg(parse_quote!(#ident(#lit)))); + let first = arg_list.nested.first().unwrap(); + if let syn::NestedMeta::Lit(lit) = first { + if let syn::Lit::Str(litstr) = lit { + return Ok(Some(parse_quote!(#litstr))); + } } - - // path is already checked in the `parse_attrs` loop, returning the error here anyways. Err(syn::Error::new_spanned( - ident, - "Expected `item` or `attribute`.", + first, + "Expected a single string literal.", )) } + + fn item_arg(meta: syn::Meta) -> syn::Result> { + let arg_list = match meta { + syn::Meta::List(list) => list, + syn::Meta::Path(_) => return Ok(None), + Meta::NameValue(nv) => { + return Err(syn::Error::new_spanned( + nv, + "Expected a literal or no argument: `pyo3(item(\"key\") or `pyo3(item)`", + )) + } + }; + if arg_list.nested.len() != 1 { + return Err(syn::Error::new_spanned( + arg_list, + "Expected a single literal.", + )); + } + let first = arg_list.nested.first().unwrap(); + if let syn::NestedMeta::Lit(lit) = first { + return Ok(Some(parse_quote!(#lit))); + } + Err(syn::Error::new_spanned(first, "Expected a literal.")) + } } /// Extract pyo3 metalist, flattens multiple lists into a single one. diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs index 23b42a4ccd0..ca9104cf987 100644 --- a/tests/test_frompyobject.rs +++ b/tests/test_frompyobject.rs @@ -12,7 +12,7 @@ pub struct A<'a> { s: String, #[pyo3(item)] t: &'a PyString, - #[pyo3(attribute = "foo")] + #[pyo3(attribute("foo"))] p: &'a PyAny, } @@ -121,7 +121,7 @@ fn test_generic_named_fields_struct() { #[derive(Debug, FromPyObject)] pub struct C { - #[pyo3(attribute = "test")] + #[pyo3(attribute("test"))] test: String, } @@ -184,7 +184,7 @@ pub enum Foo<'a> { a: Option, }, StructVarGetAttrArg { - #[pyo3(attribute = "bla")] + #[pyo3(attribute("bla"))] a: bool, }, StructWithGetItem { @@ -192,7 +192,7 @@ pub enum Foo<'a> { a: String, }, StructWithGetItemArg { - #[pyo3(item = "foo")] + #[pyo3(item("foo"))] a: String, }, #[pyo3(transparent)] From a8c5379eff27c407b27d34325cfe8df26805616e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Sun, 30 Aug 2020 12:54:13 +0200 Subject: [PATCH 4/7] Add compile fail tests for FromPyObject derives + some fixes. Fix some error messages and accidental passes. --- pyo3-derive-backend/src/from_pyobject.rs | 122 +++++++++------- tests/test_compile_error.rs | 1 + tests/ui/invalid_frompy_derive.rs | 156 ++++++++++++++++++++ tests/ui/invalid_frompy_derive.stderr | 173 +++++++++++++++++++++++ 4 files changed, 401 insertions(+), 51 deletions(-) create mode 100644 tests/ui/invalid_frompy_derive.rs create mode 100644 tests/ui/invalid_frompy_derive.stderr diff --git a/pyo3-derive-backend/src/from_pyobject.rs b/pyo3-derive-backend/src/from_pyobject.rs index b24331d20dd..b28a82579e1 100644 --- a/pyo3-derive-backend/src/from_pyobject.rs +++ b/pyo3-derive-backend/src/from_pyobject.rs @@ -1,6 +1,7 @@ use proc_macro2::{Span, TokenStream}; -use quote::quote; +use quote::{quote, ToTokens}; use syn::punctuated::Punctuated; +use syn::spanned::Spanned; use syn::{parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Meta, Result}; /// Describes derivation input of an enum. @@ -17,8 +18,8 @@ impl<'a> Enum<'a> { /// `Identifier` of the enum. fn new(data_enum: &'a DataEnum, ident: &'a Ident) -> Result { if data_enum.variants.is_empty() { - return Err(syn::Error::new_spanned( - &data_enum.variants, + return Err(spanned_err( + &ident, "Cannot derive FromPyObject for empty enum.", )); } @@ -121,6 +122,12 @@ impl<'a> Container<'a> { attrs: Vec, is_enum_variant: bool, ) -> Result { + if fields.is_empty() { + return Err(spanned_err( + fields, + "Cannot derive FromPyObject for empty structs and variants.", + )); + } let transparent = attrs.iter().any(ContainerAttribute::transparent); if transparent { Self::check_transparent_len(fields)?; @@ -154,10 +161,11 @@ impl<'a> Container<'a> { ContainerType::Struct(fields) } (Fields::Unit, _) => { - return Err(syn::Error::new_spanned( + // covered by length check above + return Err(spanned_err( &fields, "Cannot derive FromPyObject for Unit structs and variants", - )) + )); } }; let err_name = attrs @@ -175,16 +183,30 @@ impl<'a> Container<'a> { Ok(v) } - fn verify_struct_container_attrs(attrs: &'a [ContainerAttribute]) -> Result<()> { + fn verify_struct_container_attrs( + attrs: &'a [ContainerAttribute], + original: &[Attribute], + ) -> Result<()> { for attr in attrs { match attr { ContainerAttribute::Transparent => continue, ContainerAttribute::ErrorAnnotation(_) => { + let span = original + .iter() + .map(|a| a.span()) + .fold(None, |mut acc: Option, span| { + if let Some(all) = acc.as_mut() { + all.join(span) + } else { + Some(span) + } + }) + .unwrap_or_else(Span::call_site); return Err(syn::Error::new( - Span::call_site(), + span, "Annotating error messages for structs is \ not supported. Remove the annotation attribute.", - )) + )); } } } @@ -264,7 +286,7 @@ impl<'a> Container<'a> { fn check_transparent_len(fields: &Fields) -> Result<()> { if fields.len() != 1 { - return Err(syn::Error::new_spanned( + return Err(spanned_err( fields, "Transparent structs and variants can only have 1 field", )); @@ -315,16 +337,18 @@ impl ContainerAttribute { if let syn::Lit::Str(s) = &nv.lit { attrs.push(ContainerAttribute::ErrorAnnotation(s.value())) } else { - return Err(syn::Error::new_spanned( - &nv.lit, - "Expected string literal.", - )); + return Err(spanned_err(&nv.lit, "Expected string literal.")); } } - _ => (), + other => { + return Err(spanned_err( + other, + "Expected `transparent` or `annotation = \"name\"`", + )) + } } } else { - return Err(syn::Error::new_spanned( + return Err(spanned_err( meta, "Unknown container attribute, expected `transparent` or \ `annotation(\"err_name\")`", @@ -352,8 +376,8 @@ impl FieldAttribute { return Ok(None); } if list.nested.len() > 1 { - return Err(syn::Error::new_spanned( - list, + return Err(spanned_err( + list.nested, "Only one of `item`, `attribute` can be provided, possibly with an \ additional argument: `item(\"key\")` or `attribute(\"name\").", )); @@ -362,7 +386,7 @@ impl FieldAttribute { let meta = match metaitem { syn::NestedMeta::Meta(meta) => meta, syn::NestedMeta::Lit(lit) => { - return Err(syn::Error::new_spanned( + return Err(spanned_err( lit, "Expected `attribute` or `item`, not a literal.", )) @@ -374,10 +398,7 @@ impl FieldAttribute { } else if path.is_ident("item") { Ok(Some(FieldAttribute::GetItem(Self::item_arg(meta)?))) } else { - Err(syn::Error::new_spanned( - meta, - "Expected `attribute` or `item`.", - )) + Err(spanned_err(meta, "Expected `attribute` or `item`.")) } } @@ -387,25 +408,25 @@ impl FieldAttribute { syn::Meta::Path(_) => return Ok(None), Meta::NameValue(nv) => { let err_msg = "Expected a string literal or no argument: `pyo3(attribute(\"name\") or `pyo3(attribute)`"; - return Err(syn::Error::new_spanned(nv, err_msg)); + return Err(spanned_err(nv, err_msg)); } }; - if arg_list.nested.len() != 1 { - return Err(syn::Error::new_spanned( - arg_list, - "Expected a single string literal.", - )); + let arg_msg = "Expected a single string literal argument."; + if arg_list.nested.is_empty() { + return Err(spanned_err(arg_list, arg_msg)); + } else if arg_list.nested.len() > 1 { + return Err(spanned_err(arg_list.nested, arg_msg)); } let first = arg_list.nested.first().unwrap(); if let syn::NestedMeta::Lit(lit) = first { if let syn::Lit::Str(litstr) = lit { + if litstr.value().is_empty() { + return Err(spanned_err(litstr, "Attribute name cannot be empty.")); + } return Ok(Some(parse_quote!(#litstr))); } } - Err(syn::Error::new_spanned( - first, - "Expected a single string literal.", - )) + Err(spanned_err(first, arg_msg)) } fn item_arg(meta: syn::Meta) -> syn::Result> { @@ -413,26 +434,30 @@ impl FieldAttribute { syn::Meta::List(list) => list, syn::Meta::Path(_) => return Ok(None), Meta::NameValue(nv) => { - return Err(syn::Error::new_spanned( + return Err(spanned_err( nv, "Expected a literal or no argument: `pyo3(item(\"key\") or `pyo3(item)`", )) } }; - if arg_list.nested.len() != 1 { - return Err(syn::Error::new_spanned( - arg_list, - "Expected a single literal.", - )); + let arg_msg = "Expected a single literal argument."; + if arg_list.nested.is_empty() { + return Err(spanned_err(arg_list, arg_msg)); + } else if arg_list.nested.len() > 1 { + return Err(spanned_err(arg_list.nested, arg_msg)); } let first = arg_list.nested.first().unwrap(); if let syn::NestedMeta::Lit(lit) = first { return Ok(Some(parse_quote!(#lit))); } - Err(syn::Error::new_spanned(first, "Expected a literal.")) + Err(spanned_err(first, arg_msg)) } } +fn spanned_err(tokens: T, msg: &str) -> syn::Error { + syn::Error::new_spanned(tokens, msg) +} + /// Extract pyo3 metalist, flattens multiple lists into a single one. fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result { let mut list: Punctuated = Punctuated::new(); @@ -443,12 +468,7 @@ fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result { list.push(meta); } } - _ => { - return Err(syn::Error::new_spanned( - value, - "Expected `pyo3()` attribute.", - )) - } + _ => continue, } } Ok(syn::MetaList { @@ -461,9 +481,9 @@ fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result { fn verify_and_get_lifetime(generics: &syn::Generics) -> Result> { let lifetimes = generics.lifetimes().collect::>(); if lifetimes.len() > 1 { - return Err(syn::Error::new_spanned( + return Err(spanned_err( &generics, - "FromPyObject can only be derived with at most one lifetime parameter.", + "FromPyObject can be derived with at most one lifetime parameter.", )); } Ok(lifetimes.into_iter().next()) @@ -500,15 +520,15 @@ pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result { } syn::Data::Struct(st) => { let attrs = ContainerAttribute::parse_attrs(&tokens.attrs)?; - Container::verify_struct_container_attrs(&attrs)?; + Container::verify_struct_container_attrs(&attrs, &tokens.attrs)?; let ident = &tokens.ident; let st = Container::new(&st.fields, parse_quote!(#ident), attrs, false)?; st.build() } - _ => { - return Err(syn::Error::new_spanned( + syn::Data::Union(_) => { + return Err(spanned_err( tokens, - "FromPyObject can only be derived for structs and enums.", + "FromPyObject can not be derived for unions.", )) } }; diff --git a/tests/test_compile_error.rs b/tests/test_compile_error.rs index f61d72ecf2a..653a80d56b5 100644 --- a/tests/test_compile_error.rs +++ b/tests/test_compile_error.rs @@ -2,6 +2,7 @@ #[test] fn test_compile_errors() { let t = trybuild::TestCases::new(); + t.compile_fail("tests/ui/invalid_frompy_derive.rs"); t.compile_fail("tests/ui/invalid_macro_args.rs"); t.compile_fail("tests/ui/invalid_property_args.rs"); t.compile_fail("tests/ui/invalid_pyclass_args.rs"); diff --git a/tests/ui/invalid_frompy_derive.rs b/tests/ui/invalid_frompy_derive.rs new file mode 100644 index 00000000000..0a54aeca0fb --- /dev/null +++ b/tests/ui/invalid_frompy_derive.rs @@ -0,0 +1,156 @@ +use pyo3::prelude::FromPyObject; + +#[derive(FromPyObject)] +struct Foo(); + +#[derive(FromPyObject)] +struct Foo2 {} + +#[derive(FromPyObject)] +enum EmptyEnum {} + +#[derive(FromPyObject)] +enum EnumWithEmptyTupleVar { + EmptyTuple(), + Valid(String), +} + +#[derive(FromPyObject)] +enum EnumWithEmptyStructVar { + EmptyStruct {}, + Valid(String), +} + +#[derive(FromPyObject)] +#[pyo3(transparent)] +struct EmptyTransparentTup(); + +#[derive(FromPyObject)] +#[pyo3(transparent)] +struct EmptyTransparentStruct {} + +#[derive(FromPyObject)] +enum EnumWithTransparentEmptyTupleVar { + #[pyo3(transparent)] + EmptyTuple(), + Valid(String), +} + +#[derive(FromPyObject)] +enum EnumWithTransparentEmptyStructVar { + #[pyo3(transparent)] + EmptyStruct {}, + Valid(String), +} + +#[derive(FromPyObject)] +#[pyo3(transparent)] +struct TransparentTupTooManyFields(String, String); + +#[derive(FromPyObject)] +#[pyo3(transparent)] +struct TransparentStructTooManyFields { + foo: String, + bar: String, +} + +#[derive(FromPyObject)] +enum EnumWithTransparentTupleTooMany { + #[pyo3(transparent)] + EmptyTuple(String, String), + Valid(String), +} + +#[derive(FromPyObject)] +enum EnumWithTransparentStructTooMany { + #[pyo3(transparent)] + EmptyStruct { + foo: String, + bar: String, + }, + Valid(String), +} + +#[derive(FromPyObject)] +struct UnknownAttribute { + #[pyo3(attr)] + a: String, +} + +#[derive(FromPyObject)] +struct InvalidAttributeArg { + #[pyo3(attribute(1))] + a: String, +} + +#[derive(FromPyObject)] +struct TooManyAttributeArgs { + #[pyo3(attribute("a", "b"))] + a: String, +} + +#[derive(FromPyObject)] +struct EmptyAttributeArg { + #[pyo3(attribute(""))] + a: String, +} + +#[derive(FromPyObject)] +struct NoAttributeArg { + #[pyo3(attribute())] + a: String, +} + +#[derive(FromPyObject)] +struct TooManyitemArgs { + #[pyo3(item("a", "b"))] + a: String, +} + +#[derive(FromPyObject)] +struct NoItemArg { + #[pyo3(item())] + a: String, +} + +#[derive(FromPyObject)] +struct ItemAndAttribute { + #[pyo3(item, attribute)] + a: String, +} + +#[derive(FromPyObject)] +#[pyo3(unknown = "should not work")] +struct UnknownContainerAttr { + a: String, +} + +#[derive(FromPyObject)] +#[pyo3(annotation = "should not work")] +struct AnnotationOnStruct { + a: String, +} + +#[derive(FromPyObject)] +enum InvalidAnnotatedEnum { + #[pyo3(annotation = 1)] + Foo(String), +} + +#[derive(FromPyObject)] +enum TooManyLifetimes<'a, 'b> { + Foo(&'a str), + Bar(&'b str), +} + +#[derive(FromPyObject)] +union Union { + a: usize, +} + +#[derive(FromPyObject)] +enum UnitEnum { + Unit, +} + +fn main() {} diff --git a/tests/ui/invalid_frompy_derive.stderr b/tests/ui/invalid_frompy_derive.stderr new file mode 100644 index 00000000000..9c08ec65fba --- /dev/null +++ b/tests/ui/invalid_frompy_derive.stderr @@ -0,0 +1,173 @@ +error: Cannot derive FromPyObject for empty structs and variants. + --> $DIR/invalid_frompy_derive.rs:4:11 + | +4 | struct Foo(); + | ^^ + +error: Cannot derive FromPyObject for empty structs and variants. + --> $DIR/invalid_frompy_derive.rs:7:13 + | +7 | struct Foo2 {} + | ^^ + +error: Cannot derive FromPyObject for empty enum. + --> $DIR/invalid_frompy_derive.rs:10:6 + | +10 | enum EmptyEnum {} + | ^^^^^^^^^ + +error: Cannot derive FromPyObject for empty structs and variants. + --> $DIR/invalid_frompy_derive.rs:14:15 + | +14 | EmptyTuple(), + | ^^ + +error: Cannot derive FromPyObject for empty structs and variants. + --> $DIR/invalid_frompy_derive.rs:20:17 + | +20 | EmptyStruct {}, + | ^^ + +error: Cannot derive FromPyObject for empty structs and variants. + --> $DIR/invalid_frompy_derive.rs:26:27 + | +26 | struct EmptyTransparentTup(); + | ^^ + +error: Cannot derive FromPyObject for empty structs and variants. + --> $DIR/invalid_frompy_derive.rs:30:31 + | +30 | struct EmptyTransparentStruct {} + | ^^ + +error: Cannot derive FromPyObject for empty structs and variants. + --> $DIR/invalid_frompy_derive.rs:35:15 + | +35 | EmptyTuple(), + | ^^ + +error: Cannot derive FromPyObject for empty structs and variants. + --> $DIR/invalid_frompy_derive.rs:42:17 + | +42 | EmptyStruct {}, + | ^^ + +error: Transparent structs and variants can only have 1 field + --> $DIR/invalid_frompy_derive.rs:48:35 + | +48 | struct TransparentTupTooManyFields(String, String); + | ^^^^^^^^^^^^^^^^ + +error: Transparent structs and variants can only have 1 field + --> $DIR/invalid_frompy_derive.rs:52:39 + | +52 | struct TransparentStructTooManyFields { + | _______________________________________^ +53 | | foo: String, +54 | | bar: String, +55 | | } + | |_^ + +error: Transparent structs and variants can only have 1 field + --> $DIR/invalid_frompy_derive.rs:60:15 + | +60 | EmptyTuple(String, String), + | ^^^^^^^^^^^^^^^^ + +error: Transparent structs and variants can only have 1 field + --> $DIR/invalid_frompy_derive.rs:67:17 + | +67 | EmptyStruct { + | _________________^ +68 | | foo: String, +69 | | bar: String, +70 | | }, + | |_____^ + +error: Expected `attribute` or `item`. + --> $DIR/invalid_frompy_derive.rs:76:12 + | +76 | #[pyo3(attr)] + | ^^^^ + +error: Expected a single string literal argument. + --> $DIR/invalid_frompy_derive.rs:82:22 + | +82 | #[pyo3(attribute(1))] + | ^ + +error: Expected a single string literal argument. + --> $DIR/invalid_frompy_derive.rs:88:22 + | +88 | #[pyo3(attribute("a", "b"))] + | ^^^^^^^^ + +error: Attribute name cannot be empty. + --> $DIR/invalid_frompy_derive.rs:94:22 + | +94 | #[pyo3(attribute(""))] + | ^^ + +error: Expected a single string literal argument. + --> $DIR/invalid_frompy_derive.rs:100:12 + | +100 | #[pyo3(attribute())] + | ^^^^^^^^^^^ + +error: Expected a single literal argument. + --> $DIR/invalid_frompy_derive.rs:106:17 + | +106 | #[pyo3(item("a", "b"))] + | ^^^^^^^^ + +error: Expected a single literal argument. + --> $DIR/invalid_frompy_derive.rs:112:12 + | +112 | #[pyo3(item())] + | ^^^^^^ + +error: Only one of `item`, `attribute` can be provided, possibly with an additional argument: `item("key")` or `attribute("name"). + --> $DIR/invalid_frompy_derive.rs:118:12 + | +118 | #[pyo3(item, attribute)] + | ^^^^^^^^^^^^^^^ + +error: Expected `transparent` or `annotation = "name"` + --> $DIR/invalid_frompy_derive.rs:123:8 + | +123 | #[pyo3(unknown = "should not work")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error: Annotating error messages for structs is not supported. Remove the annotation attribute. + --> $DIR/invalid_frompy_derive.rs:129:1 + | +129 | #[pyo3(annotation = "should not work")] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +error: Expected string literal. + --> $DIR/invalid_frompy_derive.rs:136:25 + | +136 | #[pyo3(annotation = 1)] + | ^ + +error: FromPyObject can be derived with at most one lifetime parameter. + --> $DIR/invalid_frompy_derive.rs:141:22 + | +141 | enum TooManyLifetimes<'a, 'b> { + | ^^^^^^^^ + +error: FromPyObject can not be derived for unions. + --> $DIR/invalid_frompy_derive.rs:147:1 + | +147 | / union Union { +148 | | a: usize, +149 | | } + | |_^ + +error: Cannot derive FromPyObject for empty structs and variants. + --> $DIR/invalid_frompy_derive.rs:151:10 + | +151 | #[derive(FromPyObject)] + | ^^^^^^^^^^^^ + | + = note: this error originates in a derive macro (in Nightly builds, run with -Z macro-backtrace for more info) From 53a858c5c1a17494923610622bbacbdf4f2c1460 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Sun, 30 Aug 2020 12:55:15 +0200 Subject: [PATCH 5/7] Add documentation for FromPyObject derivation. --- guide/src/conversions.md | 141 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 141 insertions(+) diff --git a/guide/src/conversions.md b/guide/src/conversions.md index 41ddb6454b8..b861b6ca0d0 100644 --- a/guide/src/conversions.md +++ b/guide/src/conversions.md @@ -119,6 +119,147 @@ mutable references, you have to extract the PyO3 reference wrappers [`PyRef`] and [`PyRefMut`]. They work like the reference wrappers of `std::cell::RefCell` and ensure (at runtime) that Rust borrows are allowed. +#### Deriving [`FromPyObject`] + +[`FromPyObject`] can be automatically derived for many kinds of structs and enums +if the member types themselves implement `FromPyObject`. This even includes members +with a generic type `T: FromPyObject`. Derivation for empty enums, enum variants and +structs is not supported. + +#### Deriving [`FromPyObject`] for structs + +``` +use pyo3::prelude::*; + +#[derive(FromPyObject)] +struct RustyStruct { + my_string: String, +} +``` + +The derivation generates code that will per default access the attribute `my_string` on +the Python object, i.e. `obj.getattr("my_string")`, and call `extract()` on the attribute. +It is also possible to access the value on the Python object through `obj.get_item("my_string")` +by setting the attribute `pyo3(item)` on the field: +``` +use pyo3::prelude::*; + +#[derive(FromPyObject)] +struct RustyStruct { + #[pyo3(item)] + my_string: String, +} +``` + +The argument passed to `getattr` and `get_item` can also be configured: + +``` +use pyo3::prelude::*; + +#[derive(FromPyObject)] +struct RustyStruct { + #[pyo3(item("key"))] + string_in_mapping: String, + #[pyo3(attribute("name"))] + string_attr: String, +} +``` + +This tries to extract `string_attr` from the attribute `name` and `string_in_mapping` +from a mapping with the key `"key"`. The arguments for `attribute` are restricted to +non-empty string literals while `item` can take any valid literal that implements +`ToBorrowedObject`. + +#### Deriving [`FromPyObject`] for tuple structs + +Tuple structs are also supported but do not allow customizing the extraction. The input is +always assumed to be a Python tuple with the same length as the Rust type, the `n`th field +is extracted from the `n`th item in the Python tuple. + +``` +use pyo3::prelude::*; + +#[derive(FromPyObject)] +struct RustyTuple(String, String); +``` + +#### Deriving [`FromPyObject`] for wrapper types + +The `pyo3(transparent)` attribute can be used on structs with exactly one field. This results +in extracting directly from the input object, i.e. `obj.extract()`, rather than trying to access +an item or attribute. +``` +use pyo3::prelude::*; + +#[derive(FromPyObject)] +#[pyo3(transparent)] +struct RustyTransparentTuple(String); + +#[derive(FromPyObject)] +#[pyo3(transparent)] +struct RustyTransparentStruct { + inner: String, +} +``` + +#### Deriving [`FromPyObject`] for enums + +The `FromPyObject` derivation for enums generates code that tries to extract the variants in the +order of the fields. As soon as a variant can be extracted succesfully, that variant is returned. + +The same customizations and restrictions described for struct derivations apply to enum variants, +i.e. a tuple variant assumes that the input is a Python tuple, and a struct variant defaults to +extracting fields as attributes but can be configured in the same manner. The `transparent` +attribute can be applied to single-field-variants. + +``` +use pyo3::prelude::*; + +#[derive(FromPyObject)] +enum RustyEnum<'a> { + #[pyo3(transparent)] + Int(usize), // input is a positive int + #[pyo3(transparent)] + String(String), // input is a string + IntTuple(usize, usize), // input is a 2-tuple with positive ints + StringIntTuple(String, usize), // innput is a 2-tuple with String and int + Coordinates3d { // needs to be in front of 2d + x: usize, + y: usize, + z: usize, + }, + Coordinates2d { // only gets checked if the input did not have `z` + #[pyo3(attribute("x"))] + a: usize, + #[pyo3(attribute("y"))] + b: usize, + }, + #[pyo3(transparent)] + CatchAll(&'a PyAny), // This extraction never fails +} +``` + +If none of the enum variants match, a `PyValueError` containing the names of the +tested variants is returned. The names reported in the error message can be customized +through the `pyo3(annotation = "name")` attribute, e.g. to use conventional Python type +names: + +``` +use pyo3::prelude::*; + +#[derive(FromPyObject)] +enum RustyEnum { + #[pyo3(transparent, annotation = "str")] + String(String), + #[pyo3(transparent, annotation = "int")] + Int(isize), +} +``` + +If the input is neither a string nor an integer, the error message will be: +`"Can't convert to str, int"`, where `` is replaced by the type name and +`repr()` of the input object. + ### `IntoPy` This trait defines the to-python conversion for a Rust type. It is usually implemented as From 7a9f4a163311844f6d72caf006aa1c398c545a29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Sun, 30 Aug 2020 15:33:50 +0200 Subject: [PATCH 6/7] FromPyObject derive suggestions by @kngwyu --- pyo3-derive-backend/src/from_pyobject.rs | 84 ++++++++++-------------- tests/ui/invalid_frompy_derive.stderr | 4 +- 2 files changed, 37 insertions(+), 51 deletions(-) diff --git a/pyo3-derive-backend/src/from_pyobject.rs b/pyo3-derive-backend/src/from_pyobject.rs index b28a82579e1..b32408e75d7 100644 --- a/pyo3-derive-backend/src/from_pyobject.rs +++ b/pyo3-derive-backend/src/from_pyobject.rs @@ -2,7 +2,7 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, ToTokens}; use syn::punctuated::Punctuated; use syn::spanned::Spanned; -use syn::{parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Meta, Result}; +use syn::{parse_quote, Attribute, DataEnum, DeriveInput, Fields, Ident, Meta, MetaList, Result}; /// Describes derivation input of an enum. #[derive(Debug)] @@ -264,20 +264,10 @@ impl<'a> Container<'a> { let mut fields: Punctuated = Punctuated::new(); for (ident, attr) in tups { let ext_fn = match attr { - FieldAttribute::GetAttr(name) => { - if let Some(name) = name.as_ref() { - quote!(getattr(#name)) - } else { - quote!(getattr(stringify!(#ident))) - } - } - FieldAttribute::GetItem(key) => { - if let Some(key) = key.as_ref() { - quote!(get_item(#key)) - } else { - quote!(get_item(stringify!(#ident))) - } - } + FieldAttribute::GetAttr(Some(name)) => quote!(getattr(#name)), + FieldAttribute::GetAttr(None) => quote!(getattr(stringify!(#ident))), + FieldAttribute::GetItem(Some(key)) => quote!(get_item(#key)), + FieldAttribute::GetItem(None) => quote!(get_item(stringify!(#ident))), }; fields.push(quote!(#ident: obj.#ext_fn?.extract()?)); } @@ -330,10 +320,10 @@ impl ContainerAttribute { for meta in list.nested { if let syn::NestedMeta::Meta(metaitem) = &meta { match metaitem { - syn::Meta::Path(p) if p.is_ident("transparent") => { + Meta::Path(p) if p.is_ident("transparent") => { attrs.push(ContainerAttribute::Transparent) } - syn::Meta::NameValue(nv) if nv.path.is_ident("annotation") => { + Meta::NameValue(nv) if nv.path.is_ident("annotation") => { if let syn::Lit::Str(s) = &nv.lit { attrs.push(ContainerAttribute::ErrorAnnotation(s.value())) } else { @@ -372,17 +362,17 @@ impl FieldAttribute { /// Currently fails if more than 1 attribute is passed in `pyo3` fn parse_attrs(attrs: &[Attribute]) -> Result> { let list = get_pyo3_meta_list(attrs)?; - if list.nested.is_empty() { - return Ok(None); - } - if list.nested.len() > 1 { - return Err(spanned_err( - list.nested, - "Only one of `item`, `attribute` can be provided, possibly with an \ - additional argument: `item(\"key\")` or `attribute(\"name\").", - )); - } - let metaitem = list.nested.into_iter().next().unwrap(); + let metaitem = match list.nested.len() { + 0 => return Ok(None), + 1 => list.nested.into_iter().next().unwrap(), + _ => { + return Err(spanned_err( + list.nested, + "Only one of `item`, `attribute` can be provided, possibly with an \ + additional argument: `item(\"key\")` or `attribute(\"name\").", + )) + } + }; let meta = match metaitem { syn::NestedMeta::Meta(meta) => meta, syn::NestedMeta::Lit(lit) => { @@ -402,37 +392,33 @@ impl FieldAttribute { } } - fn attribute_arg(meta: syn::Meta) -> syn::Result> { + fn attribute_arg(meta: Meta) -> syn::Result> { let arg_list = match meta { - syn::Meta::List(list) => list, - syn::Meta::Path(_) => return Ok(None), + Meta::List(list) => list, + Meta::Path(_) => return Ok(None), Meta::NameValue(nv) => { let err_msg = "Expected a string literal or no argument: `pyo3(attribute(\"name\") or `pyo3(attribute)`"; return Err(spanned_err(nv, err_msg)); } }; let arg_msg = "Expected a single string literal argument."; - if arg_list.nested.is_empty() { - return Err(spanned_err(arg_list, arg_msg)); - } else if arg_list.nested.len() > 1 { - return Err(spanned_err(arg_list.nested, arg_msg)); - } - let first = arg_list.nested.first().unwrap(); - if let syn::NestedMeta::Lit(lit) = first { - if let syn::Lit::Str(litstr) = lit { - if litstr.value().is_empty() { - return Err(spanned_err(litstr, "Attribute name cannot be empty.")); - } - return Ok(Some(parse_quote!(#litstr))); + let first = match arg_list.nested.len() { + 1 => arg_list.nested.first().unwrap(), + _ => return Err(spanned_err(arg_list, arg_msg)), + }; + if let syn::NestedMeta::Lit(syn::Lit::Str(litstr)) = first { + if litstr.value().is_empty() { + return Err(spanned_err(litstr, "Attribute name cannot be empty.")); } + return Ok(Some(parse_quote!(#litstr))); } Err(spanned_err(first, arg_msg)) } - fn item_arg(meta: syn::Meta) -> syn::Result> { + fn item_arg(meta: Meta) -> syn::Result> { let arg_list = match meta { - syn::Meta::List(list) => list, - syn::Meta::Path(_) => return Ok(None), + Meta::List(list) => list, + Meta::Path(_) => return Ok(None), Meta::NameValue(nv) => { return Err(spanned_err( nv, @@ -459,11 +445,11 @@ fn spanned_err(tokens: T, msg: &str) -> syn::Error { } /// Extract pyo3 metalist, flattens multiple lists into a single one. -fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result { +fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result { let mut list: Punctuated = Punctuated::new(); for value in attrs { match value.parse_meta()? { - syn::Meta::List(ml) if value.path.is_ident("pyo3") => { + Meta::List(ml) if value.path.is_ident("pyo3") => { for meta in ml.nested { list.push(meta); } @@ -471,7 +457,7 @@ fn get_pyo3_meta_list(attrs: &[Attribute]) -> Result { _ => continue, } } - Ok(syn::MetaList { + Ok(MetaList { path: parse_quote!(pyo3), paren_token: syn::token::Paren::default(), nested: list, diff --git a/tests/ui/invalid_frompy_derive.stderr b/tests/ui/invalid_frompy_derive.stderr index 9c08ec65fba..a0d1a60b176 100644 --- a/tests/ui/invalid_frompy_derive.stderr +++ b/tests/ui/invalid_frompy_derive.stderr @@ -97,10 +97,10 @@ error: Expected a single string literal argument. | ^ error: Expected a single string literal argument. - --> $DIR/invalid_frompy_derive.rs:88:22 + --> $DIR/invalid_frompy_derive.rs:88:12 | 88 | #[pyo3(attribute("a", "b"))] - | ^^^^^^^^ + | ^^^^^^^^^^^^^^^^^^^ error: Attribute name cannot be empty. --> $DIR/invalid_frompy_derive.rs:94:22 From 0f32f886b8756f6f89c6033f0370c3c2f598b44f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Sebastian=20P=C3=BCtz?= Date: Sun, 30 Aug 2020 19:04:42 +0200 Subject: [PATCH 7/7] More FromPyObject derive suggestions by @davidhewitt --- guide/src/conversions.md | 42 ++++++++++++++++++++---- pyo3-derive-backend/src/from_pyobject.rs | 34 ++++++++++--------- tests/test_frompyobject.rs | 3 +- tests/ui/invalid_frompy_derive.stderr | 4 +-- 4 files changed, 57 insertions(+), 26 deletions(-) diff --git a/guide/src/conversions.md b/guide/src/conversions.md index b861b6ca0d0..f46da697f68 100644 --- a/guide/src/conversions.md +++ b/guide/src/conversions.md @@ -183,17 +183,28 @@ use pyo3::prelude::*; struct RustyTuple(String, String); ``` +Tuple structs with a single field are treated as wrapper types which are described in the +following section. To override this behaviour and ensure that the input is in fact a tuple, +specify the struct as +``` +use pyo3::prelude::*; + +#[derive(FromPyObject)] +struct RustyTuple((String,)); +``` + #### Deriving [`FromPyObject`] for wrapper types The `pyo3(transparent)` attribute can be used on structs with exactly one field. This results in extracting directly from the input object, i.e. `obj.extract()`, rather than trying to access -an item or attribute. +an item or attribute. This behaviour is enabled per default for newtype structs and tuple-variants +with a single field. + ``` use pyo3::prelude::*; #[derive(FromPyObject)] -#[pyo3(transparent)] -struct RustyTransparentTuple(String); +struct RustyTransparentTupleStruct(String); #[derive(FromPyObject)] #[pyo3(transparent)] @@ -217,12 +228,10 @@ use pyo3::prelude::*; #[derive(FromPyObject)] enum RustyEnum<'a> { - #[pyo3(transparent)] Int(usize), // input is a positive int - #[pyo3(transparent)] String(String), // input is a string IntTuple(usize, usize), // input is a 2-tuple with positive ints - StringIntTuple(String, usize), // innput is a 2-tuple with String and int + StringIntTuple(String, usize), // input is a 2-tuple with String and int Coordinates3d { // needs to be in front of 2d x: usize, y: usize, @@ -257,9 +266,28 @@ enum RustyEnum { ``` If the input is neither a string nor an integer, the error message will be: -`"Can't convert to str, int"`, where `` is replaced by the type name and +`"Can't convert to Union[str, int]"`, where `` is replaced by the type name and `repr()` of the input object. +#### `#[derive(FromPyObject)]` Container Attributes +- `pyo3(transparent)` + - extract the field directly from the object as `obj.extract()` instead of `get_item()` or + `getattr()` + - Newtype structs and tuple-variants are treated as transparent per default. + - only supported for single-field structs and enum variants +- `pyo3(annotation = "name")` + - changes the name of the failed variant in the generated error message in case of failure. + - e.g. `pyo3("int")` reports the variant's type as `int`. + - only supported for enum variants + +#### `#[derive(FromPyObject)]` Field Attributes +- `pyo3(attribute)`, `pyo3(attribute("name"))` + - retrieve the field from an attribute, possibly with a custom name specified as an argument + - argument must be a string-literal. +- `pyo3(item)`, `pyo3(item("key"))` + - retrieve the field from a mapping, possibly with the custom key specified as an argument. + - can be any literal that implements `ToBorrowedObject` + ### `IntoPy` This trait defines the to-python conversion for a Rust type. It is usually implemented as diff --git a/pyo3-derive-backend/src/from_pyobject.rs b/pyo3-derive-backend/src/from_pyobject.rs index b32408e75d7..f7688e2b3e4 100644 --- a/pyo3-derive-backend/src/from_pyobject.rs +++ b/pyo3-derive-backend/src/from_pyobject.rs @@ -65,6 +65,11 @@ impl<'a> Enum<'a> { error_names.push_str(", "); } } + let error_names = if self.variants.len() > 1 { + format!("Union[{}]", error_names) + } else { + error_names + }; quote!( #(#var_extracts)* let type_name = obj.get_type().name(); @@ -134,7 +139,13 @@ impl<'a> Container<'a> { } let style = match (fields, transparent) { (Fields::Unnamed(_), true) => ContainerType::TupleNewtype, - (Fields::Unnamed(unnamed), false) => ContainerType::Tuple(unnamed.unnamed.len()), + (Fields::Unnamed(unnamed), false) => { + if unnamed.unnamed.len() == 1 { + ContainerType::TupleNewtype + } else { + ContainerType::Tuple(unnamed.unnamed.len()) + } + } (Fields::Named(named), true) => { let field = named .named @@ -321,7 +332,8 @@ impl ContainerAttribute { if let syn::NestedMeta::Meta(metaitem) = &meta { match metaitem { Meta::Path(p) if p.is_ident("transparent") => { - attrs.push(ContainerAttribute::Transparent) + attrs.push(ContainerAttribute::Transparent); + continue; } Meta::NameValue(nv) if nv.path.is_ident("annotation") => { if let syn::Lit::Str(s) = &nv.lit { @@ -329,21 +341,13 @@ impl ContainerAttribute { } else { return Err(spanned_err(&nv.lit, "Expected string literal.")); } + continue; } - other => { - return Err(spanned_err( - other, - "Expected `transparent` or `annotation = \"name\"`", - )) - } + _ => {} // return Err below } - } else { - return Err(spanned_err( - meta, - "Unknown container attribute, expected `transparent` or \ - `annotation(\"err_name\")`", - )); } + + return Err(spanned_err(meta, "Unrecognized `pyo3` container attribute")); } Ok(attrs) } @@ -514,7 +518,7 @@ pub fn build_derive_from_pyobject(tokens: &DeriveInput) -> Result { syn::Data::Union(_) => { return Err(spanned_err( tokens, - "FromPyObject can not be derived for unions.", + "#[derive(FromPyObject)] is not supported for unions.", )) } }; diff --git a/tests/test_frompyobject.rs b/tests/test_frompyobject.rs index ca9104cf987..0fa937cc853 100644 --- a/tests/test_frompyobject.rs +++ b/tests/test_frompyobject.rs @@ -155,7 +155,6 @@ fn test_tuple_struct() { } #[derive(FromPyObject)] -#[pyo3(transparent)] pub struct TransparentTuple(String); #[test] @@ -301,7 +300,7 @@ fn test_err_rename() { PyErrValue::ToObject(to) => { let o = to.to_object(py); let s = String::extract(o.as_ref(py)).expect("Err val is not a string"); - assert_eq!(s, "Can't convert {} (dict) to str, uint, int") + assert_eq!(s, "Can't convert {} (dict) to Union[str, uint, int]") } _ => panic!("Expected PyErrValue::ToObject"), }, diff --git a/tests/ui/invalid_frompy_derive.stderr b/tests/ui/invalid_frompy_derive.stderr index a0d1a60b176..4eaf1218586 100644 --- a/tests/ui/invalid_frompy_derive.stderr +++ b/tests/ui/invalid_frompy_derive.stderr @@ -132,7 +132,7 @@ error: Only one of `item`, `attribute` can be provided, possibly with an additio 118 | #[pyo3(item, attribute)] | ^^^^^^^^^^^^^^^ -error: Expected `transparent` or `annotation = "name"` +error: Unrecognized `pyo3` container attribute --> $DIR/invalid_frompy_derive.rs:123:8 | 123 | #[pyo3(unknown = "should not work")] @@ -156,7 +156,7 @@ error: FromPyObject can be derived with at most one lifetime parameter. 141 | enum TooManyLifetimes<'a, 'b> { | ^^^^^^^^ -error: FromPyObject can not be derived for unions. +error: #[derive(FromPyObject)] is not supported for unions. --> $DIR/invalid_frompy_derive.rs:147:1 | 147 | / union Union {