diff --git a/tasks/ast_tools/src/parse/attr.rs b/tasks/ast_tools/src/parse/attr.rs index 1a08800e2a6be..9eea9856e7d1e 100644 --- a/tasks/ast_tools/src/parse/attr.rs +++ b/tasks/ast_tools/src/parse/attr.rs @@ -30,16 +30,25 @@ bitflags! { /// Positions in which an attribute is legal. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct AttrPositions: u8 { - /// Attribute on a struct + /// Attribute on a struct which derives the trait const Struct = 1 << 0; - /// Attribute on an enum - const Enum = 1 << 1; + /// Attribute on a struct which doesn't derive the trait + const StructNotDerived = 1 << 1; + /// Attribute on an enum which derives the trait + const Enum = 1 << 2; + /// Attribute on an enum which doesn't derive the trait + const EnumNotDerived = 1 << 3; /// Attribute on a struct field - const StructField = 1 << 2; + const StructField = 1 << 4; /// Attribute on an enum variant - const EnumVariant = 1 << 3; + const EnumVariant = 1 << 5; /// Part of `#[ast]` attr e.g. `visit` in `#[ast(visit)]` - const AstAttr = 1 << 4; + const AstAttr = 1 << 6; + + /// Attribute on a struct which may or may not derive the trait + const StructMaybeDerived = Self::Struct.bits() | Self::StructNotDerived.bits(); + /// Attribute on an enum which may or may not derive the trait + const EnumMaybeDerived = Self::Enum.bits() | Self::EnumNotDerived.bits(); } } diff --git a/tasks/ast_tools/src/parse/parse.rs b/tasks/ast_tools/src/parse/parse.rs index 910089def1975..fed1dede8306d 100644 --- a/tasks/ast_tools/src/parse/parse.rs +++ b/tasks/ast_tools/src/parse/parse.rs @@ -503,42 +503,59 @@ impl<'c> Parser<'c> { } if let Some((processor, positions)) = self.codegen.attr_processor(&attr_name) { - // Check attribute is legal in this position - match type_def { + // Check attribute is legal in this position and this type has the relevant trait + // `#[generate_derive]`-ed on it (unless the derive stated legal positions as + // `AttrPositions::StructNotDerived` or `AttrPositions::EnumNotDerived`) + let location = match type_def { TypeDef::Struct(struct_def) => { + let found_in_positions = match processor { + AttrProcessor::Derive(derive_id) => { + let is_derived = struct_def.generates_derive(derive_id); + if is_derived { + AttrPositions::Struct + } else { + AttrPositions::StructNotDerived + } + } + AttrProcessor::Generator(_) => AttrPositions::StructMaybeDerived, + }; + check_attr_position( positions, - AttrPositions::Struct, + found_in_positions, struct_def.name(), &attr_name, "struct", ); + + AttrLocation::Struct(struct_def) } TypeDef::Enum(enum_def) => { + let found_in_positions = match processor { + AttrProcessor::Derive(derive_id) => { + let is_derived = enum_def.generates_derive(derive_id); + if is_derived { + AttrPositions::Enum + } else { + AttrPositions::EnumNotDerived + } + } + AttrProcessor::Generator(_) => AttrPositions::EnumMaybeDerived, + }; + check_attr_position( positions, - AttrPositions::Enum, + found_in_positions, enum_def.name(), &attr_name, "enum", ); - } - _ => unreachable!(), - } - - // Check this type has the relevant trait `#[generate_derive]`-ed on it - check_attr_is_derived( - processor, - type_def.generated_derives(), - type_def.name(), - &attr_name, - ); - let location = match type_def { - TypeDef::Struct(struct_def) => AttrLocation::Struct(struct_def), - TypeDef::Enum(enum_def) => AttrLocation::Enum(enum_def), + AttrLocation::Enum(enum_def) + } _ => unreachable!(), }; + let result = process_attr(processor, &attr_name, location, &attr.meta); assert!( result.is_ok(), @@ -770,13 +787,13 @@ fn check_attr_is_derived( /// Check attribute is in a legal position. fn check_attr_position( expected_positions: AttrPositions, - found_in_position: AttrPositions, + found_in_positions: AttrPositions, type_name: &str, attr_name: &str, position_debug_str: &str, ) { assert!( - expected_positions.contains(found_in_position), + expected_positions.intersects(found_in_positions), "`{type_name}` type has `#[{attr_name}]` attribute on a {position_debug_str}, \ but `#[{attr_name}]` is not legal in this position." );