-
-
Notifications
You must be signed in to change notification settings - Fork 501
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(ast): split
ast
macro into multiple files (#5791)
Pure refactor. Split implementation of `#[ast]` macro into multiple files. This means each file works with a single version of `TokenStream`, rather than having `proc_macro::TokenStream` and `proc_macro2::TokenStream` both in use in a single file.
- Loading branch information
1 parent
52c6409
commit dc10eaf
Showing
2 changed files
with
101 additions
and
99 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
use proc_macro2::TokenStream; | ||
use quote::quote; | ||
|
||
pub fn ast(input: &syn::Item) -> TokenStream { | ||
let (head, tail) = match input { | ||
syn::Item::Enum(enum_) => (enum_repr(enum_), assert_generated_derives(&enum_.attrs)), | ||
syn::Item::Struct(struct_) => { | ||
(quote!(#[repr(C)]), assert_generated_derives(&struct_.attrs)) | ||
} | ||
_ => unreachable!(), | ||
}; | ||
|
||
quote! { | ||
#[derive(::oxc_ast_macros::Ast)] | ||
#head | ||
#input | ||
#tail | ||
} | ||
} | ||
|
||
/// If `enum_` has any non-unit variant, returns `#[repr(C, u8)]`, otherwise returns `#[repr(u8)]`. | ||
fn enum_repr(enum_: &syn::ItemEnum) -> TokenStream { | ||
if enum_.variants.iter().any(|var| !matches!(var.fields, syn::Fields::Unit)) { | ||
quote!(#[repr(C, u8)]) | ||
} else { | ||
quote!(#[repr(u8)]) | ||
} | ||
} | ||
|
||
/// Generate assertions that traits used in `#[generate_derive]` are in scope. | ||
/// | ||
/// e.g. for `#[generate_derive(GetSpan)]`, it generates: | ||
/// | ||
/// ```rs | ||
/// const _: () = { | ||
/// { | ||
/// trait AssertionTrait: ::oxc_span::GetSpan {} | ||
/// impl<T: GetSpan> AssertionTrait for T {} | ||
/// } | ||
/// }; | ||
/// ``` | ||
/// | ||
/// If `GetSpan` is not in scope, or it is not the correct `oxc_span::GetSpan`, | ||
/// this will raise a compilation error. | ||
fn assert_generated_derives(attrs: &[syn::Attribute]) -> TokenStream { | ||
#[inline] | ||
fn parse(attr: &syn::Attribute) -> impl Iterator<Item = syn::Ident> { | ||
attr.parse_args_with( | ||
syn::punctuated::Punctuated::<syn::Ident, syn::token::Comma>::parse_terminated, | ||
) | ||
.expect("`generate_derive` only accepts traits as single segment paths, Found an invalid argument") | ||
.into_iter() | ||
} | ||
|
||
// TODO: benchmark this to see if a lazy static cell containing `HashMap` would perform better. | ||
#[inline] | ||
fn abs_trait( | ||
ident: &syn::Ident, | ||
) -> (/* absolute type path */ TokenStream, /* possible generics */ TokenStream) { | ||
#[cold] | ||
fn invalid_derive(ident: &syn::Ident) -> ! { | ||
panic!( | ||
"Invalid derive trait(generate_derive): {ident}.\n\ | ||
Help: If you are trying to implement a new `generate_derive` trait, \ | ||
Make sure to add it to the list below." | ||
) | ||
} | ||
|
||
if ident == "CloneIn" { | ||
(quote!(::oxc_allocator::CloneIn), quote!(<'static>)) | ||
} else if ident == "GetSpan" { | ||
(quote!(::oxc_span::GetSpan), TokenStream::default()) | ||
} else if ident == "GetSpanMut" { | ||
(quote!(::oxc_span::GetSpanMut), TokenStream::default()) | ||
} else if ident == "ContentEq" { | ||
(quote!(::oxc_span::cmp::ContentEq), TokenStream::default()) | ||
} else if ident == "ContentHash" { | ||
(quote!(::oxc_span::hash::ContentHash), TokenStream::default()) | ||
} else { | ||
invalid_derive(ident) | ||
} | ||
} | ||
|
||
// NOTE: At this level we don't care if a trait is derived multiple times, It is the | ||
// responsibility of the `ast_tools` to raise errors for those. | ||
let assertion = | ||
attrs.iter().filter(|attr| attr.path().is_ident("generate_derive")).flat_map(parse).map( | ||
|derive| { | ||
let (abs_derive, generics) = abs_trait(&derive); | ||
quote! {{ | ||
// NOTE: these are wrapped in a scope to avoid the need for unique identifiers. | ||
trait AssertionTrait: #abs_derive #generics {} | ||
impl<T: #derive #generics> AssertionTrait for T {} | ||
}} | ||
}, | ||
); | ||
quote!(const _: () = { #(#assertion)* };) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters