diff --git a/Cargo.toml b/Cargo.toml index c9e2255..3b9ec2d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,8 +10,10 @@ repository = "https://github.com/rust-osdev/volatile" edition = "2021" [dependencies] +volatile-macro = { version = "0.5.2", optional = true, path = "volatile-macro" } [features] +derive = ["dep:volatile-macro"] # Enable unstable features; requires Rust nightly; might break on compiler updates unstable = [] # Enable unstable and experimental features; requires Rust nightly; might break on compiler updates @@ -28,3 +30,6 @@ pre-release-commit-message = "Release version {{version}}" [package.metadata.docs.rs] features = ["unstable"] + +[workspace] +members = ["volatile-macro"] diff --git a/src/lib.rs b/src/lib.rs index eb4ee4f..d5f8a31 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -42,6 +42,64 @@ #![warn(missing_docs)] #![deny(unsafe_op_in_unsafe_fn)] +/// A derive macro for method-based accesses to volatile structures. +/// +/// This macro allows you to access the fields of a volatile structure via methods that enforce access limitations. +/// It is also more easily chainable than [`map_field`]. +/// +/// # Examples +/// +/// ``` +/// use volatile::access::ReadOnly; +/// use volatile::{Volatile, VolatilePtr, VolatileRef}; +/// +/// #[repr(C)] +/// #[derive(Volatile, Default)] +/// pub struct DeviceConfig { +/// feature_select: u32, +/// #[access(ReadOnly)] +/// feature: u32, +/// } +/// +/// let mut device_config = DeviceConfig::default(); +/// let mut volatile_ref = VolatileRef::from_mut_ref(&mut device_config); +/// let mut volatile_ptr = volatile_ref.as_mut_ptr(); +/// +/// volatile_ptr.feature_select().write(42); +/// assert_eq!(volatile_ptr.feature_select().read(), 42); +/// +/// // This does not compile, because we specified `#[access(ReadOnly)]` for this field. +/// // volatile_ptr.feature().write(42); +/// +/// // A real device might have changed the value, though. +/// assert_eq!(volatile_ptr.feature().read(), 0); +/// ``` +/// +/// # Details +/// +/// This macro generates a new trait (`{T}Volatile`) and implements it for `VolatilePtr<'a, T, ReadWrite>`. +/// The example above results in (roughly) the following code: +/// +/// ``` +/// pub trait DeviceConfigVolatile<'a> { +/// fn feature_select(self) -> VolatilePtr<'a, u32, ReadWrite>; +/// +/// fn feature(self) -> VolatilePtr<'a, u32, ReadOnly>; +/// } +/// +/// impl<'a> DeviceConfigVolatile<'a> for VolatilePtr<'a, DeviceConfig, ReadWrite> { +/// fn feature_select(self) -> VolatilePtr<'a, u32, ReadWrite> { +/// map_field!(self.feature_select).restrict() +/// } +/// +/// fn feature(self) -> VolatilePtr<'a, u32, ReadOnly> { +/// map_field!(self.feature).restrict() +/// } +/// } +/// ``` +#[cfg(feature = "derive")] +pub use volatile_macro::Volatile; + pub use volatile_ptr::VolatilePtr; pub use volatile_ref::VolatileRef; diff --git a/volatile-macro/Cargo.toml b/volatile-macro/Cargo.toml new file mode 100644 index 0000000..24c4ba2 --- /dev/null +++ b/volatile-macro/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "volatile-macro" +version = "0.5.2" +authors = ["Martin Kröning "] +edition = "2021" +description = "Procedural macros for the volatile crate." +repository = "https://github.com/rust-osdev/volatile" +license = "MIT OR Apache-2.0" +keywords = ["volatile"] +categories = ["no-std", "no-std::no-alloc"] + +[lib] +proc-macro = true + +[dependencies] +proc-macro2 = "1" +quote = "1" +syn = { version = "2", features = ["full"] } diff --git a/volatile-macro/src/lib.rs b/volatile-macro/src/lib.rs new file mode 100644 index 0000000..efa08d7 --- /dev/null +++ b/volatile-macro/src/lib.rs @@ -0,0 +1,26 @@ +use proc_macro::TokenStream; +use proc_macro2::TokenStream as TokenStream2; +use quote::ToTokens; +use syn::parse_macro_input; + +macro_rules! bail { + ($span:expr, $($tt:tt)*) => { + return Err(syn::Error::new_spanned($span, format!($($tt)*))) + }; +} + +mod volatile; + +#[proc_macro_derive(Volatile, attributes(access))] +pub fn derive_volatile(item: TokenStream) -> TokenStream { + match volatile::derive_volatile(parse_macro_input!(item)) { + Ok(items) => { + let mut tokens = TokenStream2::new(); + for item in &items { + item.to_tokens(&mut tokens); + } + tokens.into() + } + Err(e) => e.to_compile_error().into(), + } +} diff --git a/volatile-macro/src/volatile.rs b/volatile-macro/src/volatile.rs new file mode 100644 index 0000000..4d89a12 --- /dev/null +++ b/volatile-macro/src/volatile.rs @@ -0,0 +1,182 @@ +use quote::format_ident; +use syn::{ + parse_quote, Attribute, Fields, Ident, Item, ItemImpl, ItemStruct, ItemTrait, Path, Result, + Signature, Visibility, +}; + +fn validate_input(input: &ItemStruct) -> Result<()> { + if !matches!(&input.fields, Fields::Named(_)) { + bail!( + &input.fields, + "#[derive(Volatile)] can only be used on structs with named fields" + ); + } + + if !input.generics.params.is_empty() { + bail!( + &input.generics, + "#[derive(Volatile)] cannot be used with generic structs" + ); + } + + let mut valid_repr = false; + for attr in &input.attrs { + if attr.path().is_ident("repr") { + let ident = attr.parse_args::()?; + if ident == "C" || ident == "transparent" { + valid_repr = true; + } + } + } + if !valid_repr { + bail!( + &input.ident, + "#[derive(Volatile)] structs must be `#[repr(C)]` or `#[repr(transparent)]`" + ); + } + + Ok(()) +} + +fn parse_attrs(fields: &Fields) -> Result>> { + let mut attrss = vec![]; + + for field in fields.iter() { + let mut attrs = vec![]; + for attr in &field.attrs { + if attr.path().is_ident("doc") { + attrs.push(attr.clone()); + } + } + attrss.push(attrs); + } + + Ok(attrss) +} + +fn parse_sigs(fields: &Fields) -> Result> { + let mut sigs = vec![]; + + for field in fields.iter() { + let ident = field.ident.as_ref().unwrap(); + let ty = &field.ty; + + let mut access: Path = parse_quote! { ::volatile::access::ReadWrite }; + for attr in &field.attrs { + if attr.path().is_ident("access") { + access = attr.parse_args()?; + } + } + + let sig = parse_quote! { + fn #ident(self) -> ::volatile::VolatilePtr<'a, #ty, #access> + }; + sigs.push(sig); + } + + Ok(sigs) +} + +fn emit_trait( + vis: &Visibility, + ident: &Ident, + attrs: &[Vec], + sigs: &[Signature], +) -> Result { + let item_trait = parse_quote! { + #[allow(non_camel_case_types)] + #vis trait #ident <'a> { + #( + #(#attrs)* + #sigs; + )* + } + }; + + Ok(item_trait) +} + +fn emit_impl(trait_ident: &Ident, struct_ident: &Ident, sigs: &[Signature]) -> Result { + let fields = sigs.iter().map(|sig| &sig.ident); + + let item_impl = parse_quote! { + #[automatically_derived] + impl<'a> #trait_ident<'a> for ::volatile::VolatilePtr<'a, #struct_ident, ::volatile::access::ReadWrite> { + #( + #sigs { + ::volatile::map_field!(self.#fields).restrict() + } + )* + } + }; + + Ok(item_impl) +} + +pub fn derive_volatile(input: ItemStruct) -> Result> { + validate_input(&input)?; + let attrs = parse_attrs(&input.fields)?; + let sigs = parse_sigs(&input.fields)?; + let trait_ident = format_ident!("{}Volatile", input.ident); + + let item_trait = emit_trait(&input.vis, &trait_ident, &attrs, &sigs)?; + let item_impl = emit_impl(&item_trait.ident, &input.ident, &sigs)?; + Ok(vec![Item::Trait(item_trait), Item::Impl(item_impl)]) +} + +#[cfg(test)] +mod tests { + use quote::{quote, ToTokens}; + + use super::*; + + #[test] + fn test_derive() -> Result<()> { + let input = parse_quote! { + #[repr(C)] + #[derive(Volatile, Default)] + pub struct DeviceConfig { + feature_select: u32, + /// This is a good field. + #[access(ReadOnly)] + feature: u32, + } + }; + + let result = derive_volatile(input)?; + + let expected_trait = quote! { + #[allow(non_camel_case_types)] + pub trait DeviceConfigVolatile<'a> { + fn feature_select(self) -> ::volatile::VolatilePtr<'a, u32, ::volatile::access::ReadWrite>; + + /// This is a good field. + fn feature(self) -> ::volatile::VolatilePtr<'a, u32, ReadOnly>; + } + }; + + let expected_impl = quote! { + #[automatically_derived] + impl<'a> DeviceConfigVolatile<'a> for ::volatile::VolatilePtr<'a, DeviceConfig, ::volatile::access::ReadWrite> { + fn feature_select(self) -> ::volatile::VolatilePtr<'a, u32, ::volatile::access::ReadWrite> { + ::volatile::map_field!(self.feature_select).restrict() + } + + fn feature(self) -> ::volatile::VolatilePtr<'a, u32, ReadOnly> { + ::volatile::map_field!(self.feature).restrict() + } + } + }; + + assert_eq!( + expected_trait.to_string(), + result[0].to_token_stream().to_string() + ); + assert_eq!( + expected_impl.to_string(), + result[1].to_token_stream().to_string() + ); + + Ok(()) + } +}