Skip to content

Commit

Permalink
zeroize_derive: Inject where clauses; skip unused (#882)
Browse files Browse the repository at this point in the history
The where clauses I was previously injecting *were* load bearing, but
they needed to be filtered for skipped fields.

Fixes #878
  • Loading branch information
maurer authored Mar 31, 2023
1 parent 8c650ca commit 4b06984
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 3 deletions.
2 changes: 1 addition & 1 deletion zeroize/derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ proc-macro = true
[dependencies]
proc-macro2 = "1"
quote = "1"
syn = {version = "2", features = ["full", "extra-traits"]}
syn = {version = "2", features = ["full", "extra-traits", "visit"]}

[package.metadata.docs.rs]
rustdoc-args = ["--document-private-items"]
50 changes: 48 additions & 2 deletions zeroize/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ use proc_macro2::{Ident, TokenStream};
use quote::{format_ident, quote};
use syn::{
parse::{Parse, ParseStream},
parse_quote,
punctuated::Punctuated,
token::Comma,
visit::Visit,
Attribute, Data, DeriveInput, Expr, ExprLit, Field, Fields, Lit, Meta, Result, Variant,
WherePredicate,
};
Expand All @@ -36,12 +38,19 @@ pub fn derive_zeroize(input: proc_macro::TokenStream) -> proc_macro::TokenStream
fn derive_zeroize_impl(input: DeriveInput) -> TokenStream {
let attributes = ZeroizeAttrs::parse(&input);

let mut generics = input.generics.clone();

let extra_bounds = match attributes.bound {
Some(bounds) => bounds.0,
None => Default::default(),
None => attributes
.auto_params
.iter()
.map(|type_param| -> WherePredicate {
parse_quote! {#type_param: Zeroize}
})
.collect(),
};

let mut generics = input.generics.clone();
generics.make_where_clause().predicates.extend(extra_bounds);

let ty_name = &input.ident;
Expand Down Expand Up @@ -117,6 +126,8 @@ struct ZeroizeAttrs {
drop: bool,
/// Custom bounds as defined by the user
bound: Option<Bounds>,
/// Type parameters in use by fields
auto_params: Vec<Ident>,
}

/// Parsing helper for custom bounds
Expand All @@ -128,10 +139,37 @@ impl Parse for Bounds {
}
}

struct BoundAccumulator<'a> {
generics: &'a syn::Generics,
params: Vec<Ident>,
}

impl<'ast> Visit<'ast> for BoundAccumulator<'ast> {
fn visit_path(&mut self, path: &'ast syn::Path) {
if path.segments.len() != 1 {
return;
}

if let Some(segment) = path.segments.first() {
for param in &self.generics.params {
if let syn::GenericParam::Type(type_param) = param {
if type_param.ident == segment.ident && !self.params.contains(&segment.ident) {
self.params.push(type_param.ident.clone());
}
}
}
}
}
}

impl ZeroizeAttrs {
/// Parse attributes from the incoming AST
fn parse(input: &DeriveInput) -> Self {
let mut result = Self::default();
let mut bound_accumulator = BoundAccumulator {
generics: &input.generics,
params: Vec::new(),
};

for attr in &input.attrs {
result.parse_attr(attr, None, None);
Expand All @@ -147,6 +185,9 @@ impl ZeroizeAttrs {
for attr in &field.attrs {
result.parse_attr(attr, Some(variant), Some(field));
}
if !attr_skip(&field.attrs) {
bound_accumulator.visit_type(&field.ty);
}
}
}
}
Expand All @@ -155,11 +196,16 @@ impl ZeroizeAttrs {
for attr in &field.attrs {
result.parse_attr(attr, None, Some(field));
}
if !attr_skip(&field.attrs) {
bound_accumulator.visit_type(&field.ty);
}
}
}
syn::Data::Union(union_) => panic!("Unsupported untagged union {:?}", union_),
}

result.auto_params = bound_accumulator.params;

result
}

Expand Down
9 changes: 9 additions & 0 deletions zeroize/tests/zeroize_derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -351,3 +351,12 @@ fn derive_zeroize_with_marker() {

trait Marker {}
}

#[test]
// Issue #878
fn derive_zeroize_used_param() {
#[derive(Zeroize)]
struct Z<T> {
used: T,
}
}

0 comments on commit 4b06984

Please sign in to comment.