Skip to content

Commit

Permalink
squash! Untify #[pyenum] and #[pyclass]
Browse files Browse the repository at this point in the history
  • Loading branch information
jovenlin0527 committed Nov 18, 2021
1 parent 55c2a72 commit fd73d79
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 25 deletions.
62 changes: 43 additions & 19 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{parse_quote, spanned::Spanned, Expr, Result, Token};

/// If the class is derived from a Rust `struct` or `enum`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum PyClassKind {
Struct,
Enum,
}

/// The parsed arguments of the pyclass macro
pub struct PyClassArgs {
pub freelist: Option<syn::Expr>,
Expand All @@ -27,22 +34,28 @@ pub struct PyClassArgs {
pub has_extends: bool,
pub has_unsendable: bool,
pub module: Option<syn::LitStr>,
pub class_kind: PyClassKind,
}

impl Parse for PyClassArgs {
fn parse(input: ParseStream) -> Result<Self> {
let mut slf = PyClassArgs::default();

impl PyClassArgs {
fn parse(input: ParseStream, kind: PyClassKind) -> Result<Self> {
let mut slf = PyClassArgs::new(kind);
let vars = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
for expr in vars {
slf.add_expr(&expr)?;
}
Ok(slf)
}
}

impl Default for PyClassArgs {
fn default() -> Self {
pub fn parse_stuct_args(input: ParseStream) -> syn::Result<Self> {
Self::parse(input, PyClassKind::Struct)
}

pub fn parse_enum_args(input: ParseStream) -> syn::Result<Self> {
Self::parse(input, PyClassKind::Enum)
}

fn new(class_kind: PyClassKind) -> Self {
PyClassArgs {
freelist: None,
name: None,
Expand All @@ -54,11 +67,10 @@ impl Default for PyClassArgs {
is_basetype: false,
has_extends: false,
has_unsendable: false,
class_kind,
}
}
}

impl PyClassArgs {
/// Adda single expression from the comma separated list in the attribute, which is
/// either a single word or an assignment expression
fn add_expr(&mut self, expr: &Expr) -> Result<()> {
Expand Down Expand Up @@ -116,6 +128,11 @@ impl PyClassArgs {
},
"extends" => match unwrap_group(&**right) {
syn::Expr::Path(exp) => {
if self.class_kind == PyClassKind::Enum {
return Err(
err_spanned!( assign.span() => "enums cannot extend from other classes" ),
);
}
self.base = syn::TypePath {
path: exp.path.clone(),
qself: None,
Expand Down Expand Up @@ -150,6 +167,11 @@ impl PyClassArgs {
self.has_weaklist = true;
}
"subclass" => {
if self.class_kind == PyClassKind::Enum {
return Err(
err_spanned!(exp.span() => "enums can't be inherited by other classes"),
);
}
self.is_basetype = true;
}
"dict" => {
Expand Down Expand Up @@ -496,26 +518,27 @@ struct VariantPyO3<'a> {
}

pub fn build_py_enum(
_args: PyClassArgs,
enum_: &syn::ItemEnum,
args: PyClassArgs,
method_type: PyClassMethodsType,
) -> syn::Result<TokenStream> {
let variants: Vec<VariantPyO3> = enum_
.variants
.iter()
.map(|v| extract_variant_data(v))
.collect::<syn::Result<_>>()?;
impl_enum(enum_, variants, method_type)
impl_enum(&enum_, args, variants, method_type)
}

fn impl_enum(
enum_: &syn::ItemEnum,
attrs: PyClassArgs,
variants: Vec<VariantPyO3>,
methods_type: PyClassMethodsType,
) -> syn::Result<TokenStream> {
let enum_name = &enum_.ident;
let doc = utils::get_doc(&enum_.attrs, None);
let enum_cls = impl_enum_class(enum_name, doc, methods_type)?;
let enum_cls = impl_enum_class(enum_name, &attrs, doc, methods_type)?;

let variant_consts = variants
.iter()
Expand Down Expand Up @@ -548,16 +571,17 @@ fn impl_const(enum_: &syn::Ident, cls: &syn::Ident) -> syn::Result<TokenStream>
}

fn impl_enum_class(
typ: &syn::Ident,
cls: &syn::Ident,
_attr: &PyClassArgs,
doc: PythonDoc,
methods_type: PyClassMethodsType,
) -> syn::Result<TokenStream> {
let clsname = typ.to_string();
let extractext = impl_extractext(&typ);
let pyclassimpl_impl = PyClassImplBuilder::new(&typ, methods_type).doc(doc).build();
let clsname = cls.to_string();
let extractext = impl_extractext(cls);
let pyclassimpl_impl = PyClassImplBuilder::new(cls, methods_type).doc(doc).build();

Ok(quote! {
unsafe impl pyo3::type_object::PyTypeInfo for #typ {
unsafe impl pyo3::type_object::PyTypeInfo for #cls {
type AsRefTarget = pyo3::PyCell<Self>;
const NAME: &'static str = #clsname;
const MODULE: Option<&'static str> = None;
Expand All @@ -569,7 +593,7 @@ fn impl_enum_class(
}
}

impl pyo3::PyClass for #typ {
impl pyo3::PyClass for #cls {
type Dict = pyo3::pyclass_slots::PyClassDummySlot ;
type WeakRef = pyo3::pyclass_slots::PyClassDummySlot;
type BaseNativeType = pyo3::PyAny;
Expand Down Expand Up @@ -653,7 +677,7 @@ fn impl_descriptors(
}

/// Builds an implementation for `pyo3::class::impl_::PyClassImpl`
pub struct PyClassImplBuilder<'a> {
struct PyClassImplBuilder<'a> {
cls: &'a syn::Ident,
doc: Option<PythonDoc>,
is_gc: bool,
Expand Down
13 changes: 7 additions & 6 deletions pyo3-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,10 @@ pub fn pyproto(_: TokenStream, input: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn pyclass(attr: TokenStream, input: TokenStream) -> TokenStream {
use syn::Item;
let args = parse_macro_input!(attr as PyClassArgs);
let item = parse_macro_input!(input as Item);
match item {
Item::Struct(struct_) => pyclass_impl(args, struct_, methods_type()),
Item::Enum(enum_) => pyclass_enum_impl(args, enum_, methods_type()),
Item::Struct(struct_) => pyclass_impl(attr, struct_, methods_type()),
Item::Enum(enum_) => pyclass_enum_impl(attr, enum_, methods_type()),
unsupported => {
syn::Error::new_spanned(unsupported, "#[pyclass] only supports structs and enums.")
.to_compile_error()
Expand Down Expand Up @@ -209,10 +208,11 @@ pub fn derive_from_py_object(item: TokenStream) -> TokenStream {
}

fn pyclass_impl(
args: PyClassArgs,
attrs: TokenStream,
mut ast: syn::ItemStruct,
methods_type: PyClassMethodsType,
) -> TokenStream {
let args = parse_macro_input!(attrs with PyClassArgs::parse_stuct_args);
let expanded =
build_py_class(&mut ast, &args, methods_type).unwrap_or_else(|e| e.to_compile_error());

Expand All @@ -224,12 +224,13 @@ fn pyclass_impl(
}

fn pyclass_enum_impl(
args: PyClassArgs,
attr: TokenStream,
enum_: syn::ItemEnum,
methods_type: PyClassMethodsType,
) -> TokenStream {
let args = parse_macro_input!(attr with PyClassArgs::parse_enum_args);
let expanded =
build_py_enum(args, &enum_, methods_type).unwrap_or_else(|e| e.into_compile_error());
build_py_enum(&enum_, args, methods_type).unwrap_or_else(|e| e.into_compile_error());

quote!(
#enum_
Expand Down
1 change: 1 addition & 0 deletions tests/test_compile_error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ fn _test_compile_errors() {
t.compile_fail("tests/ui/invalid_need_module_arg_position.rs");
t.compile_fail("tests/ui/invalid_property_args.rs");
t.compile_fail("tests/ui/invalid_pyclass_args.rs");
t.compile_fail("tests/ui/invalid_pyclass_enum.rs");
t.compile_fail("tests/ui/invalid_pyclass_item.rs");
t.compile_fail("tests/ui/invalid_pyfunctions.rs");
t.compile_fail("tests/ui/invalid_pymethods.rs");
Expand Down
15 changes: 15 additions & 0 deletions tests/ui/invalid_pyclass_enum.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use pyo3::prelude::*;

#[pyclass(subclass)]
enum NotBaseClass {
x,
y,
}

#[pyclass(extends = PyList)]
enum NotDrivedClass {
x,
y,
}

fn main() {}
11 changes: 11 additions & 0 deletions tests/ui/invalid_pyclass_enum.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
error: enums can't be inherited by other classes
--> tests/ui/invalid_pyclass_enum.rs:3:11
|
3 | #[pyclass(subclass)]
| ^^^^^^^^

error: enums cannot extend from other classes
--> tests/ui/invalid_pyclass_enum.rs:9:11
|
9 | #[pyclass(extends = PyList)]
| ^^^^^^^

0 comments on commit fd73d79

Please sign in to comment.