Skip to content

Commit

Permalink
feat(ast_tools): allow attrs on types which do not derive the trait (#…
Browse files Browse the repository at this point in the history
…8874)

Generally, custom attributes relating to a trait e.g. `CloneIn` / `#[clone_in]` are only legal if the trait is derived on the type. Allow a `Derive` to choose to relax this restriction. This is necessary for #8875, #8876, and #8877.
  • Loading branch information
overlookmotel committed Feb 4, 2025
1 parent 7ddd219 commit f40f494
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 26 deletions.
21 changes: 15 additions & 6 deletions tasks/ast_tools/src/parse/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,25 @@ bitflags! {
/// Positions in which an attribute is legal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AttrPositions: u8 {
/// Attribute on a struct
/// Attribute on a struct which derives the trait
const Struct = 1 << 0;
/// Attribute on an enum
const Enum = 1 << 1;
/// Attribute on a struct which doesn't derive the trait
const StructNotDerived = 1 << 1;
/// Attribute on an enum which derives the trait
const Enum = 1 << 2;
/// Attribute on an enum which doesn't derive the trait
const EnumNotDerived = 1 << 3;
/// Attribute on a struct field
const StructField = 1 << 2;
const StructField = 1 << 4;
/// Attribute on an enum variant
const EnumVariant = 1 << 3;
const EnumVariant = 1 << 5;
/// Part of `#[ast]` attr e.g. `visit` in `#[ast(visit)]`
const AstAttr = 1 << 4;
const AstAttr = 1 << 6;

/// Attribute on a struct which may or may not derive the trait
const StructMaybeDerived = Self::Struct.bits() | Self::StructNotDerived.bits();
/// Attribute on an enum which may or may not derive the trait
const EnumMaybeDerived = Self::Enum.bits() | Self::EnumNotDerived.bits();
}
}

Expand Down
57 changes: 37 additions & 20 deletions tasks/ast_tools/src/parse/parse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -503,42 +503,59 @@ impl<'c> Parser<'c> {
}

if let Some((processor, positions)) = self.codegen.attr_processor(&attr_name) {
// Check attribute is legal in this position
match type_def {
// Check attribute is legal in this position and this type has the relevant trait
// `#[generate_derive]`-ed on it (unless the derive stated legal positions as
// `AttrPositions::StructNotDerived` or `AttrPositions::EnumNotDerived`)
let location = match type_def {
TypeDef::Struct(struct_def) => {
let found_in_positions = match processor {
AttrProcessor::Derive(derive_id) => {
let is_derived = struct_def.generates_derive(derive_id);
if is_derived {
AttrPositions::Struct
} else {
AttrPositions::StructNotDerived
}
}
AttrProcessor::Generator(_) => AttrPositions::StructMaybeDerived,
};

check_attr_position(
positions,
AttrPositions::Struct,
found_in_positions,
struct_def.name(),
&attr_name,
"struct",
);

AttrLocation::Struct(struct_def)
}
TypeDef::Enum(enum_def) => {
let found_in_positions = match processor {
AttrProcessor::Derive(derive_id) => {
let is_derived = enum_def.generates_derive(derive_id);
if is_derived {
AttrPositions::Enum
} else {
AttrPositions::EnumNotDerived
}
}
AttrProcessor::Generator(_) => AttrPositions::EnumMaybeDerived,
};

check_attr_position(
positions,
AttrPositions::Enum,
found_in_positions,
enum_def.name(),
&attr_name,
"enum",
);
}
_ => unreachable!(),
}

// Check this type has the relevant trait `#[generate_derive]`-ed on it
check_attr_is_derived(
processor,
type_def.generated_derives(),
type_def.name(),
&attr_name,
);

let location = match type_def {
TypeDef::Struct(struct_def) => AttrLocation::Struct(struct_def),
TypeDef::Enum(enum_def) => AttrLocation::Enum(enum_def),
AttrLocation::Enum(enum_def)
}
_ => unreachable!(),
};

let result = process_attr(processor, &attr_name, location, &attr.meta);
assert!(
result.is_ok(),
Expand Down Expand Up @@ -770,13 +787,13 @@ fn check_attr_is_derived(
/// Check attribute is in a legal position.
fn check_attr_position(
expected_positions: AttrPositions,
found_in_position: AttrPositions,
found_in_positions: AttrPositions,
type_name: &str,
attr_name: &str,
position_debug_str: &str,
) {
assert!(
expected_positions.contains(found_in_position),
expected_positions.intersects(found_in_positions),
"`{type_name}` type has `#[{attr_name}]` attribute on a {position_debug_str}, \
but `#[{attr_name}]` is not legal in this position."
);
Expand Down

0 comments on commit f40f494

Please sign in to comment.