Skip to content

Commit

Permalink
refactor code generation
Browse files Browse the repository at this point in the history
  • Loading branch information
newcomertv committed May 1, 2024
1 parent ac88ea8 commit 023d3b7
Showing 1 changed file with 83 additions and 40 deletions.
123 changes: 83 additions & 40 deletions pyo3-macros-backend/src/pyclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1044,19 +1044,20 @@ fn impl_complex_enum_struct_variant_cls(
Ok((cls_impl, field_getters, Vec::new()))
}

fn impl_complex_enum_tuple_variant_match_arms(
fn impl_complex_enum_tuple_variant_field_getters(
ctx: &Ctx,
variant: &PyClassEnumTupleVariant<'_>,
enum_name: &syn::Ident,
variant_cls_type: &syn::Type,
variant_ident: &&Ident,
field_names: &mut Vec<Ident>,
fields_with_types: &mut Vec<TokenStream>,
field_getters: &mut Vec<MethodAndMethodDef>,
field_getter_impls: &mut Vec<TokenStream>,
) -> Result<()> {
) -> Result<(Vec<MethodAndMethodDef>, Vec<TokenStream>)> {
let Ctx { pyo3_path } = ctx;

let mut field_getters = vec![];
let mut field_getter_impls = vec![];

for (index, field) in variant.fields.iter().enumerate() {
let field_name = format_ident!("_{}", index);
let field_type = field.ty;
Expand Down Expand Up @@ -1091,60 +1092,48 @@ fn impl_complex_enum_tuple_variant_match_arms(
field_getter_impls.push(field_getter_impl);
}

Ok(())
Ok((field_getters, field_getter_impls))
}

fn impl_complex_enum_tuple_variant_cls(
enum_name: &syn::Ident,
variant: &PyClassEnumTupleVariant<'_>,
fn impl_complex_enum_tuple_variant_len(
ctx: &Ctx,
) -> Result<(TokenStream, Vec<MethodAndMethodDef>, Vec<MethodAndSlotDef>)> {
_variant: &PyClassEnumTupleVariant<'_>,
_enum_name: &syn::Ident,
variant_cls_type: &syn::Type,
_variant_ident: &&Ident,
num_fields: usize,
) -> Result<(MethodAndSlotDef, TokenStream)> {
let Ctx { pyo3_path } = ctx;
let variant_ident = &variant.ident;
let variant_cls = gen_complex_enum_variant_class_ident(enum_name, variant.ident);
let variant_cls_type = parse_quote!(#variant_cls);

let mut slots = vec![];

// represents the index of the field
let mut field_names: Vec<Ident> = vec![];
let mut fields_with_types: Vec<TokenStream> = vec![];
let mut field_getters: Vec<MethodAndMethodDef> = vec![];
let mut field_getter_impls: Vec<TokenStream> = vec![];

impl_complex_enum_tuple_variant_match_arms(
ctx,
variant,
enum_name,
&variant_cls_type,
&variant_ident,
&mut field_names,
&mut fields_with_types,
&mut field_getters,
&mut field_getter_impls,
)?;

let num_fields = variant.fields.len();

let mut len_signature: syn::Signature =
syn::parse_quote!(fn __len__(slf: PyRef<Self>) -> PyResult<usize>);
let variant_len = crate::pymethod::impl_py_len_def(&variant_cls_type, ctx, &mut len_signature)?;

slots.push(variant_len);

let len_method_impl = quote! {
fn __len__(slf: #pyo3_path::PyRef<Self>) -> #pyo3_path::PyResult<usize> {
Ok(#num_fields)
}
};

Ok((variant_len, len_method_impl))
}

fn impl_complex_enum_tuple_variant_getitem(
ctx: &Ctx,
_variant: &PyClassEnumTupleVariant<'_>,
_enum_name: &syn::Ident,
variant_cls: &syn::Ident,
variant_cls_type: &syn::Type,
_variant_ident: &&Ident,
num_fields: usize,
) -> Result<(MethodAndSlotDef, TokenStream)> {
let Ctx { pyo3_path } = ctx;

let mut get_item_signature: syn::Signature =
syn::parse_quote!(fn __getitem__(slf: PyRef<Self>, idx: usize) -> PyResult<PyObject>);
let variant_len =
let variant_getitem =
crate::pymethod::impl_py_getitem_def(&variant_cls_type, ctx, &mut get_item_signature)?;

slots.push(variant_len);

let match_arms: Vec<_> = (0..num_fields)
.map(|i| {
let field_access = format_ident!("_{}", i);
Expand Down Expand Up @@ -1178,6 +1167,60 @@ fn impl_complex_enum_tuple_variant_cls(
}
};

Ok((variant_getitem, get_item_method_impl))
}

fn impl_complex_enum_tuple_variant_cls(
enum_name: &syn::Ident,
variant: &PyClassEnumTupleVariant<'_>,
ctx: &Ctx,
) -> Result<(TokenStream, Vec<MethodAndMethodDef>, Vec<MethodAndSlotDef>)> {
let Ctx { pyo3_path } = ctx;
let variant_ident = &variant.ident;
let variant_cls = gen_complex_enum_variant_class_ident(enum_name, variant.ident);
let variant_cls_type = parse_quote!(#variant_cls);

let mut slots = vec![];

// represents the index of the field
let mut field_names: Vec<Ident> = vec![];
let mut fields_with_types: Vec<TokenStream> = vec![];

let (field_getters, field_getter_impls) = impl_complex_enum_tuple_variant_field_getters(
ctx,
variant,
enum_name,
&variant_cls_type,
&variant_ident,
&mut field_names,
&mut fields_with_types,
)?;

let num_fields = variant.fields.len();

let (variant_len, len_method_impl) = impl_complex_enum_tuple_variant_len(
ctx,
variant,
enum_name,
&variant_cls_type,
&variant_ident,
num_fields,
)?;

slots.push(variant_len);

let (variant_getitem, getitem_method_impl) = impl_complex_enum_tuple_variant_getitem(
ctx,
variant,
enum_name,
&variant_cls,
&variant_cls_type,
&variant_ident,
num_fields,
)?;

slots.push(variant_getitem);

let cls_impl = quote! {
#[doc(hidden)]
#[allow(non_snake_case)]
Expand All @@ -1189,7 +1232,7 @@ fn impl_complex_enum_tuple_variant_cls(

#len_method_impl

#get_item_method_impl
#getitem_method_impl

#(#field_getter_impls)*
}
Expand Down

0 comments on commit 023d3b7

Please sign in to comment.